use self::{future::RouteFuture, not_found::NotFound, path_router::PathRouter};
#[cfg(feature = "tokio")]
use crate::extract::connect_info::IntoMakeServiceWithConnectInfo;
use crate::{
body::{Body, HttpBody},
boxed::BoxedIntoRoute,
handler::Handler,
util::try_downcast,
};
use axum_core::response::{IntoResponse, Response};
use http::Request;
use std::{
convert::Infallible,
fmt,
task::{Context, Poll},
};
use sync_wrapper::SyncWrapper;
use tower_layer::Layer;
use tower_service::Service;
pub mod future;
pub mod method_routing;
mod into_make_service;
mod method_filter;
mod not_found;
pub(crate) mod path_router;
mod route;
mod strip_prefix;
pub(crate) mod url_params;
#[cfg(test)]
mod tests;
pub use self::{into_make_service::IntoMakeService, method_filter::MethodFilter, route::Route};
pub use self::method_routing::{
any, any_service, delete, delete_service, get, get_service, head, head_service, on, on_service,
options, options_service, patch, patch_service, post, post_service, put, put_service, trace,
trace_service, MethodRouter,
};
macro_rules! panic_on_err {
($expr:expr) => {
match $expr {
Ok(x) => x,
Err(err) => panic!("{err}"),
}
};
}
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub(crate) struct RouteId(u32);
#[must_use]
pub struct Router<S = (), B = Body> {
path_router: PathRouter<S, B, false>,
fallback_router: PathRouter<S, B, true>,
default_fallback: bool,
catch_all_fallback: Fallback<S, B>,
}
impl<S, B> Clone for Router<S, B> {
fn clone(&self) -> Self {
Self {
path_router: self.path_router.clone(),
fallback_router: self.fallback_router.clone(),
default_fallback: self.default_fallback,
catch_all_fallback: self.catch_all_fallback.clone(),
}
}
}
impl<S, B> Default for Router<S, B>
where
B: HttpBody + Send + 'static,
S: Clone + Send + Sync + 'static,
{
fn default() -> Self {
Self::new()
}
}
impl<S, B> fmt::Debug for Router<S, B> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("Router")
.field("path_router", &self.path_router)
.field("fallback_router", &self.fallback_router)
.field("default_fallback", &self.default_fallback)
.field("catch_all_fallback", &self.catch_all_fallback)
.finish()
}
}
pub(crate) const NEST_TAIL_PARAM: &str = "__private__axum_nest_tail_param";
pub(crate) const NEST_TAIL_PARAM_CAPTURE: &str = "/*__private__axum_nest_tail_param";
pub(crate) const FALLBACK_PARAM: &str = "__private__axum_fallback";
pub(crate) const FALLBACK_PARAM_PATH: &str = "/*__private__axum_fallback";
impl<S, B> Router<S, B>
where
B: HttpBody + Send + 'static,
S: Clone + Send + Sync + 'static,
{
pub fn new() -> Self {
Self {
path_router: Default::default(),
fallback_router: PathRouter::new_fallback(),
default_fallback: true,
catch_all_fallback: Fallback::Default(Route::new(NotFound)),
}
}
#[doc = include_str!("../docs/routing/route.md")]
#[track_caller]
pub fn route(mut self, path: &str, method_router: MethodRouter<S, B>) -> Self {
panic_on_err!(self.path_router.route(path, method_router));
self
}
#[doc = include_str!("../docs/routing/route_service.md")]
pub fn route_service<T>(mut self, path: &str, service: T) -> Self
where
T: Service<Request<B>, Error = Infallible> + Clone + Send + 'static,
T::Response: IntoResponse,
T::Future: Send + 'static,
{
let service = match try_downcast::<Router<S, B>, _>(service) {
Ok(_) => {
panic!(
"Invalid route: `Router::route_service` cannot be used with `Router`s. \
Use `Router::nest` instead"
);
}
Err(service) => service,
};
panic_on_err!(self.path_router.route_service(path, service));
self
}
#[doc = include_str!("../docs/routing/nest.md")]
#[track_caller]
pub fn nest(mut self, path: &str, router: Router<S, B>) -> Self {
let Router {
path_router,
fallback_router,
default_fallback,
catch_all_fallback: _,
} = router;
panic_on_err!(self.path_router.nest(path, path_router));
if !default_fallback {
panic_on_err!(self.fallback_router.nest(path, fallback_router));
}
self
}
#[track_caller]
pub fn nest_service<T>(mut self, path: &str, service: T) -> Self
where
T: Service<Request<B>, Error = Infallible> + Clone + Send + 'static,
T::Response: IntoResponse,
T::Future: Send + 'static,
{
panic_on_err!(self.path_router.nest_service(path, service));
self
}
#[doc = include_str!("../docs/routing/merge.md")]
#[track_caller]
pub fn merge<R>(mut self, other: R) -> Self
where
R: Into<Router<S, B>>,
{
const PANIC_MSG: &str =
"Failed to merge fallbacks. This is a bug in axum. Please file an issue";
let Router {
path_router,
fallback_router: mut other_fallback,
default_fallback,
catch_all_fallback,
} = other.into();
panic_on_err!(self.path_router.merge(path_router));
match (self.default_fallback, default_fallback) {
(true, true) => {
self.fallback_router.merge(other_fallback).expect(PANIC_MSG);
}
(true, false) => {
self.fallback_router.merge(other_fallback).expect(PANIC_MSG);
self.default_fallback = false;
}
(false, true) => {
let fallback_router = std::mem::take(&mut self.fallback_router);
other_fallback.merge(fallback_router).expect(PANIC_MSG);
self.fallback_router = other_fallback;
}
(false, false) => {
panic!("Cannot merge two `Router`s that both have a fallback")
}
};
self.catch_all_fallback = self
.catch_all_fallback
.merge(catch_all_fallback)
.unwrap_or_else(|| panic!("Cannot merge two `Router`s that both have a fallback"));
self
}
#[doc = include_str!("../docs/routing/layer.md")]
pub fn layer<L, NewReqBody>(self, layer: L) -> Router<S, NewReqBody>
where
L: Layer<Route<B>> + Clone + Send + 'static,
L::Service: Service<Request<NewReqBody>> + Clone + Send + 'static,
<L::Service as Service<Request<NewReqBody>>>::Response: IntoResponse + 'static,
<L::Service as Service<Request<NewReqBody>>>::Error: Into<Infallible> + 'static,
<L::Service as Service<Request<NewReqBody>>>::Future: Send + 'static,
NewReqBody: HttpBody + 'static,
{
Router {
path_router: self.path_router.layer(layer.clone()),
fallback_router: self.fallback_router.layer(layer.clone()),
default_fallback: self.default_fallback,
catch_all_fallback: self.catch_all_fallback.map(|route| route.layer(layer)),
}
}
#[doc = include_str!("../docs/routing/route_layer.md")]
#[track_caller]
pub fn route_layer<L>(self, layer: L) -> Self
where
L: Layer<Route<B>> + Clone + Send + 'static,
L::Service: Service<Request<B>> + Clone + Send + 'static,
<L::Service as Service<Request<B>>>::Response: IntoResponse + 'static,
<L::Service as Service<Request<B>>>::Error: Into<Infallible> + 'static,
<L::Service as Service<Request<B>>>::Future: Send + 'static,
{
Router {
path_router: self.path_router.route_layer(layer),
fallback_router: self.fallback_router,
default_fallback: self.default_fallback,
catch_all_fallback: self.catch_all_fallback,
}
}
#[track_caller]
#[doc = include_str!("../docs/routing/fallback.md")]
pub fn fallback<H, T>(mut self, handler: H) -> Self
where
H: Handler<T, S, B>,
T: 'static,
{
self.catch_all_fallback =
Fallback::BoxedHandler(BoxedIntoRoute::from_handler(handler.clone()));
self.fallback_endpoint(Endpoint::MethodRouter(any(handler)))
}
pub fn fallback_service<T>(mut self, service: T) -> Self
where
T: Service<Request<B>, Error = Infallible> + Clone + Send + 'static,
T::Response: IntoResponse,
T::Future: Send + 'static,
{
let route = Route::new(service);
self.catch_all_fallback = Fallback::Service(route.clone());
self.fallback_endpoint(Endpoint::Route(route))
}
fn fallback_endpoint(mut self, endpoint: Endpoint<S, B>) -> Self {
self.fallback_router.set_fallback(endpoint);
self.default_fallback = false;
self
}
#[doc = include_str!("../docs/routing/with_state.md")]
pub fn with_state<S2>(self, state: S) -> Router<S2, B> {
Router {
path_router: self.path_router.with_state(state.clone()),
fallback_router: self.fallback_router.with_state(state.clone()),
default_fallback: self.default_fallback,
catch_all_fallback: self.catch_all_fallback.with_state(state),
}
}
pub(crate) fn call_with_state(
&mut self,
mut req: Request<B>,
state: S,
) -> RouteFuture<B, Infallible> {
if !self.default_fallback {
req.extensions_mut().insert(SuperFallback(SyncWrapper::new(
self.fallback_router.clone(),
)));
}
match self.path_router.call_with_state(req, state) {
Ok(future) => future,
Err((mut req, state)) => {
let super_fallback = req
.extensions_mut()
.remove::<SuperFallback<S, B>>()
.map(|SuperFallback(path_router)| path_router.into_inner());
if let Some(mut super_fallback) = super_fallback {
match super_fallback.call_with_state(req, state) {
Ok(future) => return future,
Err((req, state)) => {
return self.catch_all_fallback.call_with_state(req, state);
}
}
}
match self.fallback_router.call_with_state(req, state) {
Ok(future) => future,
Err((req, state)) => self.catch_all_fallback.call_with_state(req, state),
}
}
}
}
}
impl<B> Router<(), B>
where
B: HttpBody + Send + 'static,
{
pub fn into_make_service(self) -> IntoMakeService<Self> {
IntoMakeService::new(self.with_state(()))
}
#[doc = include_str!("../docs/routing/into_make_service_with_connect_info.md")]
#[cfg(feature = "tokio")]
pub fn into_make_service_with_connect_info<C>(self) -> IntoMakeServiceWithConnectInfo<Self, C> {
IntoMakeServiceWithConnectInfo::new(self.with_state(()))
}
}
impl<B> Service<Request<B>> for Router<(), B>
where
B: HttpBody + Send + 'static,
{
type Response = Response;
type Error = Infallible;
type Future = RouteFuture<B, Infallible>;
#[inline]
fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
Poll::Ready(Ok(()))
}
#[inline]
fn call(&mut self, req: Request<B>) -> Self::Future {
self.call_with_state(req, ())
}
}
enum Fallback<S, B, E = Infallible> {
Default(Route<B, E>),
Service(Route<B, E>),
BoxedHandler(BoxedIntoRoute<S, B, E>),
}
impl<S, B, E> Fallback<S, B, E>
where
S: Clone,
{
fn merge(self, other: Self) -> Option<Self> {
match (self, other) {
(Self::Default(_), pick @ Self::Default(_)) => Some(pick),
(Self::Default(_), pick) | (pick, Self::Default(_)) => Some(pick),
_ => None,
}
}
fn map<F, B2, E2>(self, f: F) -> Fallback<S, B2, E2>
where
S: 'static,
B: 'static,
E: 'static,
F: FnOnce(Route<B, E>) -> Route<B2, E2> + Clone + Send + 'static,
B2: HttpBody + 'static,
E2: 'static,
{
match self {
Self::Default(route) => Fallback::Default(f(route)),
Self::Service(route) => Fallback::Service(f(route)),
Self::BoxedHandler(handler) => Fallback::BoxedHandler(handler.map(f)),
}
}
fn with_state<S2>(self, state: S) -> Fallback<S2, B, E> {
match self {
Fallback::Default(route) => Fallback::Default(route),
Fallback::Service(route) => Fallback::Service(route),
Fallback::BoxedHandler(handler) => Fallback::Service(handler.into_route(state)),
}
}
fn call_with_state(&mut self, req: Request<B>, state: S) -> RouteFuture<B, E> {
match self {
Fallback::Default(route) | Fallback::Service(route) => {
RouteFuture::from_future(route.oneshot_inner(req))
}
Fallback::BoxedHandler(handler) => {
let mut route = handler.clone().into_route(state);
RouteFuture::from_future(route.oneshot_inner(req))
}
}
}
}
impl<S, B, E> Clone for Fallback<S, B, E> {
fn clone(&self) -> Self {
match self {
Self::Default(inner) => Self::Default(inner.clone()),
Self::Service(inner) => Self::Service(inner.clone()),
Self::BoxedHandler(inner) => Self::BoxedHandler(inner.clone()),
}
}
}
impl<S, B, E> fmt::Debug for Fallback<S, B, E> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Self::Default(inner) => f.debug_tuple("Default").field(inner).finish(),
Self::Service(inner) => f.debug_tuple("Service").field(inner).finish(),
Self::BoxedHandler(_) => f.debug_tuple("BoxedHandler").finish(),
}
}
}
#[allow(clippy::large_enum_variant)]
enum Endpoint<S, B> {
MethodRouter(MethodRouter<S, B>),
Route(Route<B>),
}
impl<S, B> Endpoint<S, B>
where
B: HttpBody + Send + 'static,
S: Clone + Send + Sync + 'static,
{
fn layer<L, NewReqBody>(self, layer: L) -> Endpoint<S, NewReqBody>
where
L: Layer<Route<B>> + Clone + Send + 'static,
L::Service: Service<Request<NewReqBody>> + Clone + Send + 'static,
<L::Service as Service<Request<NewReqBody>>>::Response: IntoResponse + 'static,
<L::Service as Service<Request<NewReqBody>>>::Error: Into<Infallible> + 'static,
<L::Service as Service<Request<NewReqBody>>>::Future: Send + 'static,
NewReqBody: HttpBody + 'static,
{
match self {
Endpoint::MethodRouter(method_router) => {
Endpoint::MethodRouter(method_router.layer(layer))
}
Endpoint::Route(route) => Endpoint::Route(route.layer(layer)),
}
}
}
impl<S, B> Clone for Endpoint<S, B> {
fn clone(&self) -> Self {
match self {
Self::MethodRouter(inner) => Self::MethodRouter(inner.clone()),
Self::Route(inner) => Self::Route(inner.clone()),
}
}
}
impl<S, B> fmt::Debug for Endpoint<S, B> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Self::MethodRouter(method_router) => {
f.debug_tuple("MethodRouter").field(method_router).finish()
}
Self::Route(route) => f.debug_tuple("Route").field(route).finish(),
}
}
}
struct SuperFallback<S, B>(SyncWrapper<PathRouter<S, B, true>>);
#[test]
#[allow(warnings)]
fn traits() {
use crate::test_helpers::*;
assert_send::<Router<(), ()>>();
}