//! Access control based on HTTP headers pub mod cors; pub mod host; mod matcher; pub use cors::{AllowHeaders, AllowOrigin, Origin}; pub use host::{AllowHosts, Host}; use crate::Error; use self::cors::get_cors_allow_origin; /// Define access on control on HTTP layer. #[derive(Clone, Debug)] pub struct AccessControl { allowed_hosts: AllowHosts, allowed_origins: Option>, allowed_headers: AllowHeaders, } impl AccessControl { /// Validate incoming request by host. /// /// `host` is the return value from the `host header` pub fn verify_host(&self, host: &str) -> Result<(), Error> { self.allowed_hosts.verify(host) } /// Validate incoming request by origin. /// /// `host` is the return value from the `host header` /// `origin` is the value from the `origin header`. pub fn verify_origin(&self, origin: Option<&str>, host: &str) -> Result<(), Error> { if let cors::AllowCors::Invalid = get_cors_allow_origin(origin, &self.allowed_origins, Some(host)) { Err(Error::HttpHeaderRejected("origin", origin.unwrap_or("").into())) } else { Ok(()) } } /// Validate incoming request by CORS(`access-control-request-headers`). /// /// header_name: all keys of the header in the request /// cors_request_headers: values of `access-control-request-headers` headers. /// pub fn verify_headers(&self, header_names: I, cors_request_headers: II) -> Result<(), Error> where T: AsRef, I: Iterator, II: Iterator, { let header = cors::get_cors_allow_headers(header_names, cors_request_headers, &self.allowed_headers, |name| name); if let cors::AllowCors::Invalid = header { Err(Error::HttpHeaderRejected( "access-control-request-headers", "".into(), )) } else { Ok(()) } } /// Return the allowed headers we've set pub fn allowed_headers(&self) -> &AllowHeaders { &self.allowed_headers } } impl Default for AccessControl { fn default() -> Self { Self { allowed_hosts: AllowHosts::Any, allowed_origins: None, allowed_headers: AllowHeaders::Any } } } /// Convenience builder pattern #[derive(Debug)] pub struct AccessControlBuilder { allowed_hosts: AllowHosts, allowed_origins: Option>, allowed_headers: AllowHeaders, } impl Default for AccessControlBuilder { fn default() -> Self { Self { allowed_hosts: AllowHosts::Any, allowed_origins: None, allowed_headers: AllowHeaders::Any } } } impl AccessControlBuilder { /// Create a new builder for `AccessControl`. pub fn new() -> Self { Self::default() } /// Allow all hosts. pub fn allow_all_hosts(mut self) -> Self { self.allowed_hosts = AllowHosts::Any; self } /// Allow all origins. pub fn allow_all_origins(mut self) -> Self { self.allowed_origins = None; self } /// Allow all headers. pub fn allow_all_headers(mut self) -> Self { self.allowed_headers = AllowHeaders::Any; self } /// Configure allowed hosts. /// /// Default - allow all. pub fn set_allowed_hosts(mut self, list: List) -> Result where List: IntoIterator, H: Into, { let allowed_hosts: Vec<_> = list.into_iter().map(|s| Host::parse(&s.into())).map(Into::into).collect(); if allowed_hosts.is_empty() { return Err(Error::EmptyAllowList("Host")); } self.allowed_hosts = AllowHosts::Only(allowed_hosts); Ok(self) } /// Configure allowed origins. /// /// Default - allow all. pub fn set_allowed_origins(mut self, list: List) -> Result where List: IntoIterator, Origin: Into, { let allowed_origins: Vec = list.into_iter().map(Into::into).map(Into::into).collect(); if allowed_origins.is_empty() { return Err(Error::EmptyAllowList("Origin")); } self.allowed_origins = Some(allowed_origins); Ok(self) } /// Configure allowed CORS headers. /// /// Default - allow all. pub fn set_allowed_headers(mut self, list: List) -> Result where List: IntoIterator, Header: Into, { let allowed_headers: Vec = list.into_iter().map(Into::into).collect(); if allowed_headers.is_empty() { return Err(Error::EmptyAllowList("Header")); } self.allowed_headers = AllowHeaders::Only(allowed_headers); Ok(self) } /// Finalize the `AccessControl` settings. pub fn build(self) -> AccessControl { AccessControl { allowed_hosts: self.allowed_hosts, allowed_origins: self.allowed_origins, allowed_headers: self.allowed_headers, } } }