use super::{
rejection::{FailedToResolveHost, HostRejection},
FromRequestParts,
};
use async_trait::async_trait;
use http::{
header::{HeaderMap, FORWARDED},
request::Parts,
};
/// Name of the `X-Forwarded-Host` request header, consulted by [`Host`] when no usable
/// `Forwarded` header is present.
const X_FORWARDED_HOST_HEADER_KEY: &str = "X-Forwarded-Host";

/// Extractor that resolves the host a request was addressed to.
///
/// The host is taken, in order of precedence, from:
///
/// 1. the `host` parameter of the `Forwarded` header,
/// 2. the `X-Forwarded-Host` header,
/// 3. the `Host` header,
/// 4. the host component of the request URI.
///
/// Extraction is rejected with [`HostRejection::FailedToResolveHost`] when none of these
/// yields a value.
///
/// The value is copied verbatim from the request and is not validated. NOTE(review): the
/// forwarding headers can be set by clients unless a trusted proxy overwrites them, so confirm
/// the deployment before relying on this for security decisions.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Host(pub String);
#[async_trait]
impl<S> FromRequestParts<S> for Host
where
    S: Send + Sync,
{
    type Rejection = HostRejection;

    /// Resolves the request's host by probing each source in priority order:
    ///
    /// 1. the `host` parameter of the `Forwarded` header,
    /// 2. the `X-Forwarded-Host` header,
    /// 3. the `Host` header,
    /// 4. the host component of the request URI.
    ///
    /// A header whose value cannot be viewed as a `&str` is skipped, and the next source is
    /// tried.
    ///
    /// # Errors
    ///
    /// Returns [`HostRejection::FailedToResolveHost`] when no source yields a host.
    async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
        let headers = &parts.headers;
        let uri = &parts.uri;

        // Each fallback is lazy, so later sources are only inspected if earlier ones are absent.
        let resolved = parse_forwarded(headers)
            .or_else(|| {
                headers
                    .get(X_FORWARDED_HOST_HEADER_KEY)
                    .and_then(|value| value.to_str().ok())
            })
            .or_else(|| {
                headers
                    .get(http::header::HOST)
                    .and_then(|value| value.to_str().ok())
            })
            .or_else(|| uri.host());

        match resolved {
            Some(host) => Ok(Host(host.to_owned())),
            None => Err(HostRejection::FailedToResolveHost(FailedToResolveHost)),
        }
    }
}
/// Extracts the `host` parameter from a `Forwarded` header (RFC 7239 syntax).
///
/// Parsing rules:
///
/// - Only the first `Forwarded` header is read; `HeaderMap::get` ignores any later ones.
/// - Only the first comma-separated element of that header is considered.
/// - Within the element, `;`-separated `key=value` pairs are scanned for a key equal to
///   `host`, compared ASCII case-insensitively.
/// - The returned value has surrounding whitespace and double quotes trimmed.
///
/// Returns `None` in three cases:
///
/// - the header is absent;
/// - its value is not representable as `&str` (`HeaderValue::to_str` fails);
/// - the first element has no `host` pair.
fn parse_forwarded(headers: &HeaderMap) -> Option<&str> {
    let forwarded_values = headers.get(FORWARDED)?.to_str().ok()?;

    // Elements are comma-separated; only the first one is used.
    let first_value = forwarded_values.split(',').next()?;

    // Within an element, `key=value` pairs are separated by `;`.
    first_value.split(';').find_map(|pair| {
        let (key, value) = pair.split_once('=')?;
        key.trim()
            .eq_ignore_ascii_case("host")
            // Values may be sent as quoted strings, e.g. `host="[2001:db8::17]:4711"`.
            .then(|| value.trim().trim_matches('"'))
    })
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{routing::get, test_helpers::TestClient, Router};
    use http::header::HeaderName;

    /// Builds a client for a router whose `/` handler echoes the extracted [`Host`].
    fn test_client() -> TestClient {
        async fn host_as_body(Host(host): Host) -> String {
            host
        }

        TestClient::new(Router::new().route("/", get(host_as_body)))
    }

    #[crate::test]
    async fn host_header() {
        let original_host = "some-domain:123";
        let host = test_client()
            .get("/")
            .header(http::header::HOST, original_host)
            .send()
            .await
            .text()
            .await;
        assert_eq!(host, original_host);
    }

    #[crate::test]
    async fn x_forwarded_host_header() {
        let original_host = "some-domain:456";
        let host = test_client()
            .get("/")
            .header(X_FORWARDED_HOST_HEADER_KEY, original_host)
            .send()
            .await
            .text()
            .await;
        assert_eq!(host, original_host);
    }

    #[crate::test]
    async fn x_forwarded_host_precedence_over_host_header() {
        let x_forwarded_host_header = "some-domain:456";
        let host_header = "some-domain:123";
        let host = test_client()
            .get("/")
            .header(X_FORWARDED_HOST_HEADER_KEY, x_forwarded_host_header)
            .header(http::header::HOST, host_header)
            .send()
            .await
            .text()
            .await;
        assert_eq!(host, x_forwarded_host_header);
    }

    #[crate::test]
    async fn forwarded_precedence_over_x_forwarded_host() {
        let host = test_client()
            .get("/")
            .header(FORWARDED, "host=forwarded-domain:789")
            .header(X_FORWARDED_HOST_HEADER_KEY, "x-forwarded-domain:456")
            .send()
            .await
            .text()
            .await;
        assert_eq!(host, "forwarded-domain:789");
    }

    #[crate::test]
    async fn uri_host() {
        let host = test_client().get("/").send().await.text().await;
        assert!(host.contains("127.0.0.1"));
    }

    #[test]
    fn forwarded_parsing() {
        // the basic case
        let headers = header_map(&[(FORWARDED, "host=192.0.2.60;proto=http;by=203.0.113.43")]);
        let value = parse_forwarded(&headers).unwrap();
        assert_eq!(value, "192.0.2.60");

        // the key is matched case-insensitively
        let headers = header_map(&[(FORWARDED, "HOST=192.0.2.60;proto=http;by=203.0.113.43")]);
        let value = parse_forwarded(&headers).unwrap();
        assert_eq!(value, "192.0.2.60");

        // `host` need not be the first pair, and whitespace around pairs is ignored
        let headers = header_map(&[(FORWARDED, "for=192.0.2.43; proto=https; host=example.com")]);
        let value = parse_forwarded(&headers).unwrap();
        assert_eq!(value, "example.com");

        // quoted values are unquoted
        let headers = header_map(&[(FORWARDED, "host=\"[2001:db8:cafe::17]:4711\"")]);
        let value = parse_forwarded(&headers).unwrap();
        assert_eq!(value, "[2001:db8:cafe::17]:4711");

        // only the first element of a comma-separated list is used
        let headers = header_map(&[(FORWARDED, "host=192.0.2.60, host=127.0.0.1")]);
        let value = parse_forwarded(&headers).unwrap();
        assert_eq!(value, "192.0.2.60");

        // only the first header is used when several are present
        let headers = header_map(&[
            (FORWARDED, "host=192.0.2.60"),
            (FORWARDED, "host=127.0.0.1"),
        ]);
        let value = parse_forwarded(&headers).unwrap();
        assert_eq!(value, "192.0.2.60");

        // a `host` in a later element does not count
        let headers = header_map(&[(FORWARDED, "for=192.0.2.43, host=127.0.0.1")]);
        assert_eq!(parse_forwarded(&headers), None);

        // no `Forwarded` header at all
        assert_eq!(parse_forwarded(&HeaderMap::new()), None);
    }

    /// Builds a `HeaderMap` by appending each `(name, value)` pair in order, so repeated names
    /// produce multiple header values.
    fn header_map(values: &[(HeaderName, &str)]) -> HeaderMap {
        let mut headers = HeaderMap::new();
        for (key, value) in values {
            headers.append(key, value.parse().unwrap());
        }
        headers
    }
}