Unverified Commit 600cad08 authored by Niklas Adolfsson's avatar Niklas Adolfsson Committed by GitHub
Browse files

fix(ws server): support `*` in host and origin filtering (#781)



* initial rewrite to re-use HTTP access control

* clean things up

* Update core/src/error.rs

* Update core/src/error.rs

* allow origin: add back removed Display impl

* cleanup again

* Update http-server/src/lib.rs

* Update examples/examples/cors_server.rs

* Update core/src/server/access_control/mod.rs
Co-authored-by: default avatarTarik Gul <47201679+TarikGul@users.noreply.github.com>

* Update http-server/src/server.rs
Co-authored-by: default avatarTarik Gul <47201679+TarikGul@users.noreply.github.com>

* fix bad comment

* remove todo

* fix grumbles

* more grumbles

* rename and document a bit

* remove `Access-Control-Allow-Origin` in whitelist

* fix nit: pub(super)

* fix bad naming
Co-authored-by: default avatarTarik Gul <47201679+TarikGul@users.noreply.github.com>
parent 6888804f
Pipeline #198342 passed with stages
in 6 minutes
......@@ -30,6 +30,9 @@ parking_lot = { version = "0.12", optional = true }
tokio = { version = "1.16", optional = true }
wasm-bindgen-futures = { version = "0.4.19", optional = true }
futures-timer = { version = "3", optional = true }
globset = { version = "0.4", optional = true }
lazy_static = { version = "1", optional = true }
unicase = { version = "2.6.0", optional = true }
[features]
default = []
......@@ -37,12 +40,15 @@ http-helpers = ["hyper", "futures-util"]
server = [
"arrayvec",
"futures-util/alloc",
"globset",
"rustc-hash/std",
"tracing",
"parking_lot",
"rand",
"tokio/rt",
"tokio/sync",
"lazy_static",
"unicase",
]
client = ["futures-util/sink", "futures-channel/sink", "futures-channel/std"]
async-client = [
......
......@@ -105,9 +105,12 @@ pub enum Error {
/// Attempted to stop server that is already stopped.
#[error("Attempted to stop server that is already stopped")]
AlreadyStopped,
/// List passed into `set_allowed_origins` was empty
/// List passed into access control based on HTTP header verification.
#[error("Must set at least one allowed value for the {0} header")]
EmptyAllowList(&'static str),
/// Access control verification of HTTP headers failed.
#[error("HTTP header: `{0}` value: `{1}` verification failed")]
HttpHeaderRejected(&'static str, String),
/// Failed to execute a method because a resource was already at capacity
#[error("Resource at capacity: {0}")]
ResourceAtCapacity(&'static str),
......
......@@ -101,13 +101,25 @@ pub fn read_header_value<'a>(headers: &'a hyper::header::HeaderMap, header_name:
pub fn read_header_values<'a>(
headers: &'a hyper::header::HeaderMap,
header_name: &str,
) -> hyper::header::ValueIter<'a, hyper::header::HeaderValue> {
headers.get_all(header_name).iter()
) -> hyper::header::GetAll<'a, hyper::header::HeaderValue> {
headers.get_all(header_name)
}
/// Get the header values from the `access-control-request-headers` header.
pub fn get_cors_request_headers<'a>(headers: &'a hyper::header::HeaderMap) -> impl Iterator<Item = &str> {
const ACCESS_CONTROL_REQUEST_HEADERS: &str = "access-control-request-headers";
read_header_values(headers, ACCESS_CONTROL_REQUEST_HEADERS)
.iter()
.filter_map(|val| val.to_str().ok())
.flat_map(|val| val.split(","))
// The strings themselves might contain leading and trailing whitespaces
.map(|s| s.trim())
}
#[cfg(test)]
mod tests {
use super::{read_body, read_header_content_length};
use super::{get_cors_request_headers, read_body, read_header_content_length};
#[tokio::test]
async fn body_to_bytes_size_limit_works() {
......@@ -132,4 +144,23 @@ mod tests {
headers.insert(hyper::header::CONTENT_LENGTH, "18446744073709551616".parse().unwrap());
assert_eq!(read_header_content_length(&headers), None);
}
#[test]
fn get_cors_headers_works() {
let mut headers = hyper::header::HeaderMap::new();
// access-control-request-headers
headers.insert(hyper::header::ACCESS_CONTROL_REQUEST_HEADERS, "Content-Type,x-requested-with".parse().unwrap());
let values: Vec<&str> = get_cors_request_headers(&headers).collect();
assert_eq!(values, vec!["Content-Type", "x-requested-with"]);
headers.insert(
hyper::header::ACCESS_CONTROL_REQUEST_HEADERS,
"Content-Type, x-requested-with ".parse().unwrap(),
);
let values: Vec<&str> = get_cors_request_headers(&headers).collect();
assert_eq!(values, vec!["Content-Type", "x-requested-with"]);
}
}
......@@ -29,9 +29,9 @@
use std::collections::HashSet;
use std::{fmt, ops};
use crate::access_control::hosts::{Host, Port};
use crate::access_control::matcher::{Matcher, Pattern};
use jsonrpsee_core::Cow;
use crate::server::access_control::host::{Host, Port};
use crate::server::access_control::matcher::{Matcher, Pattern};
use crate::Cow;
use lazy_static::lazy_static;
use unicase::Ascii;
......@@ -128,54 +128,54 @@ impl ops::Deref for Origin {
/// Origins allowed to access
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum AccessControlAllowOrigin {
/// Specific hostname
Value(Origin),
pub enum AllowOrigin {
/// Specific origin.
Origin(Origin),
/// null-origin (file:///, sandboxed iframe)
Null,
/// Any non-null origin
Any,
}
impl fmt::Display for AccessControlAllowOrigin {
impl fmt::Display for AllowOrigin {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(
f,
"{}",
match *self {
AccessControlAllowOrigin::Any => "*",
AccessControlAllowOrigin::Null => "null",
AccessControlAllowOrigin::Value(ref val) => val,
Self::Any => "*",
Self::Null => "null",
Self::Origin(ref val) => val,
}
)
}
}
impl<T: Into<String>> From<T> for AccessControlAllowOrigin {
fn from(s: T) -> AccessControlAllowOrigin {
impl<T: Into<String>> From<T> for AllowOrigin {
fn from(s: T) -> Self {
match s.into().as_str() {
"all" | "*" | "any" => AccessControlAllowOrigin::Any,
"null" => AccessControlAllowOrigin::Null,
origin => AccessControlAllowOrigin::Value(origin.into()),
"all" | "*" | "any" => Self::Any,
"null" => Self::Null,
origin => Self::Origin(origin.into()),
}
}
}
/// Headers allowed to access
#[derive(Debug, Clone, PartialEq)]
pub enum AccessControlAllowHeaders {
pub enum AllowHeaders {
/// Specific headers
Only(Vec<String>),
/// Any header
Any,
}
impl AccessControlAllowHeaders {
impl AllowHeaders {
/// Return an appropriate value for the CORS header "Access-Control-Allow-Headers".
pub fn to_cors_header_value(&self) -> Cow<'_, str> {
match self {
AccessControlAllowHeaders::Any => "*".into(),
AccessControlAllowHeaders::Only(headers) => headers.join(", ").into(),
AllowHeaders::Any => "*".into(),
AllowHeaders::Only(headers) => headers.join(", ").into(),
}
}
}
......@@ -219,11 +219,11 @@ impl<T> From<AllowCors<T>> for Option<T> {
}
/// Returns correct CORS header (if any) given list of allowed origins and current origin.
pub(crate) fn get_cors_allow_origin(
pub(super) fn get_cors_allow_origin(
origin: Option<&str>,
allowed: &Option<Vec<AllowOrigin>>,
host: Option<&str>,
allowed: &Option<Vec<AccessControlAllowOrigin>>,
) -> AllowCors<AccessControlAllowOrigin> {
) -> AllowCors<AllowOrigin> {
match origin {
None => AllowCors::NotRequired,
Some(ref origin) => {
......@@ -239,22 +239,22 @@ pub(crate) fn get_cors_allow_origin(
}
match allowed.as_ref() {
None if *origin == "null" => AllowCors::Ok(AccessControlAllowOrigin::Null),
None => AllowCors::Ok(AccessControlAllowOrigin::Value(Origin::parse(origin))),
None if *origin == "null" => AllowCors::Ok(AllowOrigin::Null),
None => AllowCors::Ok(AllowOrigin::Origin(Origin::parse(origin))),
Some(allowed) if *origin == "null" => allowed
.iter()
.find(|cors| **cors == AccessControlAllowOrigin::Null)
.find(|cors| **cors == AllowOrigin::Null)
.cloned()
.map(AllowCors::Ok)
.unwrap_or(AllowCors::Invalid),
Some(allowed) => allowed
.iter()
.find(|cors| match **cors {
AccessControlAllowOrigin::Any => true,
AccessControlAllowOrigin::Value(ref val) if val.matches(origin) => true,
AllowOrigin::Any => true,
AllowOrigin::Origin(ref val) if val.matches(origin) => true,
_ => false,
})
.map(|_| AccessControlAllowOrigin::Value(Origin::parse(origin)))
.map(|_| AllowOrigin::Origin(Origin::parse(origin)))
.map(AllowCors::Ok)
.unwrap_or(AllowCors::Invalid),
}
......@@ -262,15 +262,19 @@ pub(crate) fn get_cors_allow_origin(
}
}
/// Validates if the `AccessControlAllowedHeaders` in the request are allowed.
/// Validates if the headers in the request are allowed.
///
/// headers: all the headers in the request.
/// cors_request_headers: `values` in the `access-control-request-headers` header.
/// cors_allow_headers: whitelisted headers by the user.
pub(crate) fn get_cors_allow_headers<T: AsRef<str>, O, F: Fn(T) -> O>(
mut headers: impl Iterator<Item = T>,
requested_headers: impl Iterator<Item = T>,
cors_allow_headers: &AccessControlAllowHeaders,
cors_request_headers: impl Iterator<Item = T>,
cors_allow_headers: &AllowHeaders,
to_result: F,
) -> AllowCors<Vec<O>> {
// Check if the header fields which were sent in the request are allowed
if let AccessControlAllowHeaders::Only(only) = cors_allow_headers {
if let AllowHeaders::Only(only) = cors_allow_headers {
let are_all_allowed = headers.all(|header| {
let name = &Ascii::new(header.as_ref());
only.iter().any(|h| Ascii::new(&*h) == name) || ALWAYS_ALLOWED_HEADERS.contains(name)
......@@ -283,13 +287,13 @@ pub(crate) fn get_cors_allow_headers<T: AsRef<str>, O, F: Fn(T) -> O>(
// Check if `AccessControlRequestHeaders` contains fields which were allowed
let (filtered, headers) = match cors_allow_headers {
AccessControlAllowHeaders::Any => {
let headers = requested_headers.map(to_result).collect();
AllowHeaders::Any => {
let headers = cors_request_headers.map(to_result).collect();
(false, headers)
}
AccessControlAllowHeaders::Only(only) => {
AllowHeaders::Only(only) => {
let mut filtered = false;
let headers: Vec<_> = requested_headers
let headers: Vec<_> = cors_request_headers
.filter(|header| {
let name = &Ascii::new(header.as_ref());
filtered = true;
......@@ -319,7 +323,6 @@ lazy_static! {
let mut hs = HashSet::new();
hs.insert(Ascii::new("Accept"));
hs.insert(Ascii::new("Accept-Language"));
hs.insert(Ascii::new("Access-Control-Allow-Origin"));
hs.insert(Ascii::new("Access-Control-Request-Headers"));
hs.insert(Ascii::new("Content-Language"));
hs.insert(Ascii::new("Content-Type"));
......@@ -337,7 +340,7 @@ mod tests {
use std::iter;
use super::*;
use crate::access_control::hosts::Host;
use crate::server::access_control::host::Host;
#[test]
fn should_parse_origin() {
......@@ -365,8 +368,8 @@ mod tests {
let host = Some(&*host);
// when
let res1 = get_cors_allow_origin(origin1, host, &Some(vec![]));
let res2 = get_cors_allow_origin(origin2, host, &Some(vec![]));
let res1 = get_cors_allow_origin(origin1, &Some(vec![]), host);
let res2 = get_cors_allow_origin(origin2, &Some(vec![]), host);
// then
assert_eq!(res1, AllowCors::Invalid);
......@@ -383,7 +386,7 @@ mod tests {
let host = Some(&*host);
// when
let res = get_cors_allow_origin(origin, host, &None);
let res = get_cors_allow_origin(origin, &None, host);
// then
assert_eq!(res, AllowCors::NotRequired);
......@@ -396,7 +399,7 @@ mod tests {
let host = None;
// when
let res = get_cors_allow_origin(origin, host, &None);
let res = get_cors_allow_origin(origin, &None, host);
// then
assert_eq!(res, AllowCors::NotRequired);
......@@ -409,7 +412,7 @@ mod tests {
let host = None;
// when
let res = get_cors_allow_origin(origin, host, &None);
let res = get_cors_allow_origin(origin, &None, host);
// then
assert_eq!(res, AllowCors::Ok("parity.io".into()));
......@@ -422,11 +425,7 @@ mod tests {
let host = None;
// when
let res = get_cors_allow_origin(
origin,
host,
&Some(vec![AccessControlAllowOrigin::Value("http://ethereum.org".into())]),
);
let res = get_cors_allow_origin(origin, &Some(vec![AllowOrigin::Origin("http://ethereum.org".into())]), host);
// then
assert_eq!(res, AllowCors::NotRequired);
......@@ -439,7 +438,7 @@ mod tests {
let host = None;
// when
let res = get_cors_allow_origin(origin, host, &Some(Vec::new()));
let res = get_cors_allow_origin(origin, &Some(Vec::new()), host);
// then
assert_eq!(res, AllowCors::NotRequired);
......@@ -452,11 +451,7 @@ mod tests {
let host = None;
// when
let res = get_cors_allow_origin(
origin,
host,
&Some(vec![AccessControlAllowOrigin::Value("http://ethereum.org".into())]),
);
let res = get_cors_allow_origin(origin, &Some(vec![AllowOrigin::Origin("http://ethereum.org".into())]), host);
// then
assert_eq!(res, AllowCors::Invalid);
......@@ -469,10 +464,10 @@ mod tests {
let host = None;
// when
let res = get_cors_allow_origin(origin, host, &Some(vec![AccessControlAllowOrigin::Any]));
let res = get_cors_allow_origin(origin, &Some(vec![AllowOrigin::Any]), host);
// then
assert_eq!(res, AllowCors::Ok(AccessControlAllowOrigin::Value("http://parity.io".into())));
assert_eq!(res, AllowCors::Ok(AllowOrigin::Origin("http://parity.io".into())));
}
#[test]
......@@ -482,7 +477,7 @@ mod tests {
let host = None;
// when
let res = get_cors_allow_origin(origin, host, &Some(vec![AccessControlAllowOrigin::Null]));
let res = get_cors_allow_origin(origin, &Some(vec![AllowOrigin::Null]), host);
// then
assert_eq!(res, AllowCors::NotRequired);
......@@ -495,10 +490,10 @@ mod tests {
let host = None;
// when
let res = get_cors_allow_origin(origin, host, &Some(vec![AccessControlAllowOrigin::Null]));
let res = get_cors_allow_origin(origin, &Some(vec![AllowOrigin::Null]), host);
// then
assert_eq!(res, AllowCors::Ok(AccessControlAllowOrigin::Null));
assert_eq!(res, AllowCors::Ok(AllowOrigin::Null));
}
#[test]
......@@ -510,15 +505,15 @@ mod tests {
// when
let res = get_cors_allow_origin(
origin,
host,
&Some(vec![
AccessControlAllowOrigin::Value("http://ethereum.org".into()),
AccessControlAllowOrigin::Value("http://parity.io".into()),
AllowOrigin::Origin("http://ethereum.org".into()),
AllowOrigin::Origin("http://parity.io".into()),
]),
host,
);
// then
assert_eq!(res, AllowCors::Ok(AccessControlAllowOrigin::Value("http://parity.io".into())));
assert_eq!(res, AllowCors::Ok(AllowOrigin::Origin("http://parity.io".into())));
}
#[test]
......@@ -528,26 +523,24 @@ mod tests {
let origin2 = Some("http://parity.iot");
let origin3 = Some("chrome-extension://test");
let host = None;
let allowed = Some(vec![
AccessControlAllowOrigin::Value("http://*.io".into()),
AccessControlAllowOrigin::Value("chrome-extension://*".into()),
]);
let allowed =
Some(vec![AllowOrigin::Origin("http://*.io".into()), AllowOrigin::Origin("chrome-extension://*".into())]);
// when
let res1 = get_cors_allow_origin(origin1, host, &allowed);
let res2 = get_cors_allow_origin(origin2, host, &allowed);
let res3 = get_cors_allow_origin(origin3, host, &allowed);
let res1 = get_cors_allow_origin(origin1, &allowed, host);
let res2 = get_cors_allow_origin(origin2, &allowed, host);
let res3 = get_cors_allow_origin(origin3, &allowed, host);
// then
assert_eq!(res1, AllowCors::Ok(AccessControlAllowOrigin::Value("http://parity.io".into())));
assert_eq!(res1, AllowCors::Ok(AllowOrigin::Origin("http://parity.io".into())));
assert_eq!(res2, AllowCors::Invalid);
assert_eq!(res3, AllowCors::Ok(AccessControlAllowOrigin::Value("chrome-extension://test".into())));
assert_eq!(res3, AllowCors::Ok(AllowOrigin::Origin("chrome-extension://test".into())));
}
#[test]
fn should_return_invalid_if_header_not_allowed() {
// given
let cors_allow_headers = AccessControlAllowHeaders::Only(vec!["x-allowed".to_owned()]);
let cors_allow_headers = AllowHeaders::Only(vec!["x-allowed".to_owned()]);
let headers = vec!["Access-Control-Request-Headers"];
let requested = vec!["x-not-allowed"];
......@@ -562,7 +555,7 @@ mod tests {
fn should_return_valid_if_header_allowed() {
// given
let allowed = vec!["x-allowed".to_owned()];
let cors_allow_headers = AccessControlAllowHeaders::Only(allowed);
let cors_allow_headers = AllowHeaders::Only(allowed);
let headers = vec!["Access-Control-Request-Headers"];
let requested = vec!["x-allowed"];
......@@ -578,7 +571,7 @@ mod tests {
fn should_return_no_allowed_headers_if_none_in_request() {
// given
let allowed = vec!["x-allowed".to_owned()];
let cors_allow_headers = AccessControlAllowHeaders::Only(allowed);
let cors_allow_headers = AllowHeaders::Only(allowed);
let headers: Vec<String> = vec![];
// when
......@@ -591,7 +584,7 @@ mod tests {
#[test]
fn should_return_not_required_if_any_header_allowed() {
// given
let cors_allow_headers = AccessControlAllowHeaders::Any;
let cors_allow_headers = AllowHeaders::Any;
let headers: Vec<String> = vec![];
// when
......
......@@ -26,7 +26,8 @@
//! Host header validation.
use crate::access_control::matcher::{Matcher, Pattern};
use crate::server::access_control::matcher::{Matcher, Pattern};
use crate::Error;
const SPLIT_PROOF: &str = "split always returns non-empty iterator.";
......@@ -139,47 +140,31 @@ impl std::ops::Deref for Host {
}
}
/// Specifies if domains should be validated.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum DomainsValidation<T> {
/// Allow only domains on the list.
AllowOnly(Vec<T>),
/// Disable domains validation completely.
Disabled,
/// Policy for validating the `HTTP host header`.
#[derive(Debug, Clone)]
pub enum AllowHosts {
/// Allow all hosts (no filter).
Any,
/// Allow only specified hosts.
Only(Vec<Host>),
}
impl<T> From<Option<Vec<T>>> for DomainsValidation<T> {
fn from(other: Option<Vec<T>>) -> Self {
match other {
Some(list) => DomainsValidation::AllowOnly(list),
None => DomainsValidation::Disabled,
impl AllowHosts {
/// Verify a host.
pub fn verify(&self, value: &str) -> Result<(), Error> {
if let AllowHosts::Only(list) = self {
if !list.iter().any(|o| o.matches(value)) {
return Err(Error::HttpHeaderRejected("host", value.into()));
}
}
}
}
/// Returns `true` when `Host` header is whitelisted in `allow_hosts`.
pub(crate) fn is_host_valid(host: Option<&str>, allow_hosts: &AllowHosts) -> bool {
match host {
None => false,
Some(ref host) => match allow_hosts {
AllowHosts::Any => true,
AllowHosts::Only(allow_hosts) => allow_hosts.iter().any(|h| h.matches(host)),
},
Ok(())
}
}
/// Allowed hosts for http header 'host'
#[derive(Clone, Debug)]
pub enum AllowHosts {
/// Allow requests from any host
Any,
/// Allow only a selection of specific hosts
Only(Vec<Host>),
}
#[cfg(test)]
mod tests {
use super::{is_host_valid, AllowHosts, Host};
use super::{AllowHosts, Host, Port};
#[test]
fn should_parse_host() {
......@@ -188,43 +173,35 @@ mod tests {
assert_eq!(Host::parse("chrome-extension://124.0.0.1"), Host::new("124.0.0.1", None));
assert_eq!(Host::parse("parity.io/somepath"), Host::new("parity.io", None));
assert_eq!(Host::parse("127.0.0.1:8545/somepath"), Host::new("127.0.0.1", Some(8545)));
}
#[test]
fn should_reject_when_there_is_no_header() {
let valid = is_host_valid(None, &AllowHosts::Any);
assert!(!valid);
let valid = is_host_valid(None, &AllowHosts::Only(vec![]));
assert!(!valid);
let host = Host::parse("*.domain:*/somepath");
assert_eq!(host.port, Port::Pattern("*".into()));
assert_eq!(host.hostname.as_str(), "*.domain");
}
#[test]
fn should_reject_when_validation_is_disabled() {
let valid = is_host_valid(Some("any"), &AllowHosts::Any);
assert!(valid);
fn should_allow_when_validation_is_disabled() {
assert!((AllowHosts::Any).verify("any").is_ok());
}
#[test]
fn should_reject_if_header_not_on_the_list() {
let valid = is_host_valid(Some("parity.io"), &AllowHosts::Only(vec![]));
assert!(!valid);
assert!((AllowHosts::Only(vec![].into())).verify("parity.io").is_err());
}
#[test]
fn should_accept_if_on_the_list() {
let valid = is_host_valid(Some("parity.io"), &AllowHosts::Only(vec!["parity.io".into()]));
assert!(valid);
assert!((AllowHosts::Only(vec!["parity.io".into()].into())).verify("parity.io").is_ok());
}
#[test]
fn should_accept_if_on_the_list_with_port() {
let valid = is_host_valid(Some("parity.io:443"), &AllowHosts::Only(vec!["parity.io:443".into()]));
assert!(valid);
assert!((AllowHosts::Only(vec!["parity.io:443".into()].into())).verify("parity.io:443").is_ok());
assert!((AllowHosts::Only(vec!["parity.io".into()].into())).verify("parity.io:443").is_err());
}
#[test]
fn should_support_wildcards() {
let valid = is_host_valid(Some("parity.web3.site:8180"), &AllowHosts::Only(vec!["*.web3.site:*".into()]));
assert!(valid);
assert!((AllowHosts::Only(vec!["*.web3.site:*".into()].into())).verify("parity.web3.site:8180").is_ok());
}
}