Unverified Commit 1ebaf626 authored by Alexandru Vasile's avatar Alexandru Vasile Committed by GitHub
Browse files

Remove CORS logic (#851)



* http: Add inner server data structure

Signed-off-by: default avatarAlexandru Vasile <alexandru.vasile@parity.io>

* http: Handle RPC messages

Signed-off-by: default avatarAlexandru Vasile <alexandru.vasile@parity.io>

* http: Implement equivalent of `service_fn`

Signed-off-by: default avatarAlexandru Vasile <alexandru.vasile@parity.io>

* http: Implement equivalent of `make_service_fn`

Signed-off-by: default avatarAlexandru Vasile <alexandru.vasile@parity.io>

* http: Expose `tower` compatible service

Signed-off-by: default avatarAlexandru Vasile <alexandru.vasile@parity.io>

* http: Prebuild http server with optional listener

Signed-off-by: default avatarAlexandru Vasile <alexandru.vasile@parity.io>

* examples: WIP tower service

Signed-off-by: default avatarAlexandru Vasile <alexandru.vasile@parity.io>

* http: Fix warnings

Signed-off-by: default avatarAlexandru Vasile <alexandru.vasile@parity.io>

* tower_http: Fix warnings

Signed-off-by: default avatarAlexandru Vasile <alexandru.vasile@parity.io>

* http: Ensure service works with tower

Signed-off-by: default avatarAlexandru Vasile <alexandru.vasile@parity.io>

* http: Remove `RPSeeServerMakeSvc` to allow further flexibility

Signed-off-by: default avatarAlexandru Vasile <alexandru.vasile@parity.io>

* tower_http: Fix warnings

Signed-off-by: default avatarAlexandru Vasile <alexandru.vasile@parity.io>

* tower_http: Resubmit the same request for testing

Signed-off-by: default avatarAlexandru Vasile <alexandru.vasile@parity.io>

* http: Transform builder into service directly

Signed-off-by: default avatarAlexandru Vasile <alexandru.vasile@parity.io>

* http: Rename `RPSeeServerSvc` into user friendly `TowerService`

Signed-off-by: default avatarAlexandru Vasile <alexandru.vasile@parity.io>

* http: Rely on internal TowerService to handle requests

Signed-off-by: default avatarAlexandru Vasile <alexandru.vasile@parity.io>

* Fix middleware typo

Signed-off-by: default avatarAlexandru Vasile <alexandru.vasile@parity.io>

* http-server: Improve API builder for tower service

Signed-off-by: default avatarAlexandru Vasile <alexandru.vasile@parity.io>

* Rename the inner service data and check comments

Signed-off-by: default avatarAlexandru Vasile <alexandru.vasile@parity.io>

* examples: Add comments

Signed-off-by: default avatarAlexandru Vasile <alexandru.vasile@parity.io>

* http-server: Receive tower service builder as param

Signed-off-by: default avatarAlexandru Vasile <alexandru.vasile@parity.io>

* examples: Adjust tower_http example

Signed-off-by: default avatarAlexandru Vasile <alexandru.vasile@parity.io>

* http-server: Add tower middleware on the HttpBuilder

Signed-off-by: default avatarAlexandru Vasile <alexandru.vasile@parity.io>

* http-server: Do not expose the internal `TowerService` for now

Signed-off-by: default avatarAlexandru Vasile <alexandru.vasile@parity.io>

* Update http-server/src/server.rs

Co-authored-by: Niklas Adolfsson's avatarNiklas Adolfsson <niklasadolfsson1@gmail.com>

* http-server: Use `std::error::Error`

Signed-off-by: default avatarAlexandru Vasile <alexandru.vasile@parity.io>

* Fix fmt

Signed-off-by: default avatarAlexandru Vasile <alexandru.vasile@parity.io>

* http-server: Remove header and CORS validation

Signed-off-by: default avatarAlexandru Vasile <alexandru.vasile@parity.io>

* core: Remove CORS logic

Signed-off-by: default avatarAlexandru Vasile <alexandru.vasile@parity.io>

* examples: Add custom CORS layer to the RPC

Signed-off-by: default avatarAlexandru Vasile <alexandru.vasile@parity.io>

* address some grumbles

* fix more grumbles: no more Infallible

* make clippy happy

* Rename tower http example

Signed-off-by: default avatarAlexandru Vasile <alexandru.vasile@parity.io>

* http-server: Remove handling of OPTIONS request

Signed-off-by: default avatarAlexandru Vasile <alexandru.vasile@parity.io>

* tests: Test CORS with external layers

Signed-off-by: default avatarAlexandru Vasile <alexandru.vasile@parity.io>

* examples: Document access control and external CORS layer

Signed-off-by: default avatarAlexandru Vasile <alexandru.vasile@parity.io>

* Remove unused deps

Signed-off-by: default avatarAlexandru Vasile <alexandru.vasile@parity.io>

* remove unused CORS code

* Remove extra lifetime param

Signed-off-by: default avatarAlexandru Vasile <alexandru.vasile@parity.io>

* Rename `invalid_allow_origin` to `origin_rejected`

Signed-off-by: default avatarAlexandru Vasile <alexandru.vasile@parity.io>

* Fix clippy

Signed-off-by: default avatarAlexandru Vasile <alexandru.vasile@parity.io>

* Update core/src/server/access_control/origin.rs

Co-authored-by: Niklas Adolfsson's avatarNiklas Adolfsson <niklasadolfsson1@gmail.com>

* Rename `AnyNonNull` to `Wildcard`

Signed-off-by: default avatarAlexandru Vasile <alexandru.vasile@parity.io>

* Rename `OriginType` to `Origin`

Signed-off-by: default avatarAlexandru Vasile <alexandru.vasile@parity.io>

Signed-off-by: default avatarAlexandru Vasile <alexandru.vasile@parity.io>
Co-authored-by: Niklas Adolfsson's avatarNiklas Adolfsson <niklasadolfsson1@gmail.com>
parent 04a695ac
Pipeline #208510 canceled with stages
in 1 minute and 5 seconds
......@@ -32,8 +32,6 @@ tokio = { version = "1.16", optional = true }
wasm-bindgen-futures = { version = "0.4.19", optional = true }
futures-timer = { version = "3", optional = true }
globset = { version = "0.4", optional = true }
lazy_static = { version = "1", optional = true }
unicase = { version = "2.6.0", optional = true }
http = { version = "0.2.7", optional = true }
[features]
......@@ -48,8 +46,6 @@ server = [
"rand",
"tokio/rt",
"tokio/sync",
"lazy_static",
"unicase",
"http",
"hyper",
]
......
......@@ -105,21 +105,9 @@ pub fn read_header_values<'a>(
headers.get_all(header_name)
}
/// Get the header values from the `access-control-request-headers` header.
pub fn get_cors_request_headers<'a>(headers: &'a hyper::header::HeaderMap) -> impl Iterator<Item = &str> {
const ACCESS_CONTROL_REQUEST_HEADERS: &str = "access-control-request-headers";
read_header_values(headers, ACCESS_CONTROL_REQUEST_HEADERS)
.iter()
.filter_map(|val| val.to_str().ok())
.flat_map(|val| val.split(','))
// The strings themselves might contain leading and trailing whitespaces
.map(|s| s.trim())
}
#[cfg(test)]
mod tests {
use super::{get_cors_request_headers, read_body, read_header_content_length};
use super::{read_body, read_header_content_length};
#[tokio::test]
async fn body_to_bytes_size_limit_works() {
......@@ -144,23 +132,4 @@ mod tests {
headers.insert(hyper::header::CONTENT_LENGTH, "18446744073709551616".parse().unwrap());
assert_eq!(read_header_content_length(&headers), None);
}
#[test]
fn get_cors_headers_works() {
let mut headers = hyper::header::HeaderMap::new();
// access-control-request-headers
headers.insert(hyper::header::ACCESS_CONTROL_REQUEST_HEADERS, "Content-Type,x-requested-with".parse().unwrap());
let values: Vec<&str> = get_cors_request_headers(&headers).collect();
assert_eq!(values, vec!["Content-Type", "x-requested-with"]);
headers.insert(
hyper::header::ACCESS_CONTROL_REQUEST_HEADERS,
"Content-Type, x-requested-with ".parse().unwrap(),
);
let values: Vec<&str> = get_cors_request_headers(&headers).collect();
assert_eq!(values, vec!["Content-Type", "x-requested-with"]);
}
}
//! Access control based on HTTP headers
pub mod cors;
pub mod origin;
pub mod host;
mod matcher;
pub use cors::{AllowHeaders, AllowOrigin, Origin};
pub use host::{AllowHosts, Host};
pub use origin::{AllowOrigins, Origin};
pub use host::{Host, AllowHosts};
use crate::Error;
use self::cors::get_cors_allow_origin;
/// Define access on control on HTTP layer.
#[derive(Clone, Debug)]
pub struct AccessControl {
allowed_hosts: AllowHosts,
allowed_origins: Option<Vec<AllowOrigin>>,
allowed_headers: AllowHeaders,
allowed_origins: AllowOrigins,
}
impl AccessControl {
......@@ -32,46 +29,13 @@ impl AccessControl {
/// `host` is the return value from the `host header`
/// `origin` is the value from the `origin header`.
pub fn verify_origin(&self, origin: Option<&str>, host: &str) -> Result<(), Error> {
if let cors::AllowCors::Invalid = get_cors_allow_origin(origin, &self.allowed_origins, Some(host)) {
Err(Error::HttpHeaderRejected("origin", origin.unwrap_or("<missing>").into()))
} else {
Ok(())
}
}
/// Validate incoming request by CORS(`access-control-request-headers`).
///
/// header_name: all keys of the header in the request
/// cors_request_headers: values of `access-control-request-headers` headers.
///
pub fn verify_headers<T, I, II>(&self, header_names: I, cors_request_headers: II) -> Result<(), Error>
where
T: AsRef<str>,
I: Iterator<Item = T>,
II: Iterator<Item = T>,
{
let header =
cors::get_cors_allow_headers(header_names, cors_request_headers, &self.allowed_headers, |name| name);
if let cors::AllowCors::Invalid = header {
Err(Error::HttpHeaderRejected(
"access-control-request-headers",
"<too inefficient to displayed; use wireshark or something similar to find the header values>".into(),
))
} else {
Ok(())
}
}
/// Return the allowed headers we've set
pub fn allowed_headers(&self) -> &AllowHeaders {
&self.allowed_headers
self.allowed_origins.verify(origin, host)
}
}
impl Default for AccessControl {
fn default() -> Self {
Self { allowed_hosts: AllowHosts::Any, allowed_origins: None, allowed_headers: AllowHeaders::Any }
Self { allowed_hosts: AllowHosts::Any, allowed_origins: AllowOrigins::Any }
}
}
......@@ -79,13 +43,12 @@ impl Default for AccessControl {
#[derive(Debug)]
pub struct AccessControlBuilder {
allowed_hosts: AllowHosts,
allowed_origins: Option<Vec<AllowOrigin>>,
allowed_headers: AllowHeaders,
allowed_origins: AllowOrigins,
}
impl Default for AccessControlBuilder {
fn default() -> Self {
Self { allowed_hosts: AllowHosts::Any, allowed_origins: None, allowed_headers: AllowHeaders::Any }
Self { allowed_hosts: AllowHosts::Any, allowed_origins: AllowOrigins::Any }
}
}
......@@ -103,13 +66,7 @@ impl AccessControlBuilder {
/// Allow all origins.
pub fn allow_all_origins(mut self) -> Self {
self.allowed_origins = None;
self
}
/// Allow all headers.
pub fn allow_all_headers(mut self) -> Self {
self.allowed_headers = AllowHeaders::Any;
self.allowed_origins = AllowOrigins::Any;
self
}
......@@ -132,32 +89,16 @@ impl AccessControlBuilder {
/// Configure allowed origins.
///
/// Default - allow all.
pub fn set_allowed_origins<Origin, List>(mut self, list: List) -> Result<Self, Error>
pub fn set_allowed_origins<O, List>(mut self, list: List) -> Result<Self, Error>
where
List: IntoIterator<Item = Origin>,
Origin: Into<String>,
List: IntoIterator<Item = O>,
O: Into<String>,
{
let allowed_origins: Vec<AllowOrigin> = list.into_iter().map(Into::into).map(Into::into).collect();
let allowed_origins: Vec<Origin> = list.into_iter().map(Into::into).map(Into::into).collect();
if allowed_origins.is_empty() {
return Err(Error::EmptyAllowList("Origin"));
}
self.allowed_origins = Some(allowed_origins);
Ok(self)
}
/// Configure allowed CORS headers.
///
/// Default - allow all.
pub fn set_allowed_headers<Header, List>(mut self, list: List) -> Result<Self, Error>
where
List: IntoIterator<Item = Header>,
Header: Into<String>,
{
let allowed_headers: Vec<String> = list.into_iter().map(Into::into).collect();
if allowed_headers.is_empty() {
return Err(Error::EmptyAllowList("Header"));
}
self.allowed_headers = AllowHeaders::Only(allowed_headers);
self.allowed_origins = AllowOrigins::Only(allowed_origins);
Ok(self)
}
......@@ -166,7 +107,6 @@ impl AccessControlBuilder {
AccessControl {
allowed_hosts: self.allowed_hosts,
allowed_origins: self.allowed_origins,
allowed_headers: self.allowed_headers,
}
}
}
......@@ -24,16 +24,13 @@
// IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
// DEALINGS IN THE SOFTWARE.
//! CORS handling utility functions
//! Origin filtering functions
use std::collections::HashSet;
use std::{fmt, ops};
use crate::server::access_control::host::{Host, Port};
use crate::server::access_control::matcher::{Matcher, Pattern};
use crate::Cow;
use lazy_static::lazy_static;
use unicase::Ascii;
use crate::Error;
/// Origin Protocol
#[derive(Clone, Hash, Debug, PartialEq, Eq)]
......@@ -48,25 +45,25 @@ pub enum OriginProtocol {
/// Request Origin
#[derive(Clone, PartialEq, Eq, Debug, Hash)]
pub struct Origin {
pub struct InnerOrigin {
protocol: OriginProtocol,
host: Host,
host_with_proto: String,
matcher: Matcher,
}
impl<T: AsRef<str>> From<T> for Origin {
impl<T: AsRef<str>> From<T> for InnerOrigin {
fn from(origin: T) -> Self {
Origin::parse(origin.as_ref())
InnerOrigin::parse(origin.as_ref())
}
}
impl Origin {
impl InnerOrigin {
fn with_host(protocol: OriginProtocol, host: Host) -> Self {
let host_with_proto = Self::host_with_proto(&protocol, &host);
let matcher = Matcher::new(&host_with_proto);
Origin { protocol, host, host_with_proto, matcher }
InnerOrigin { protocol, host, host_with_proto, matcher }
}
/// Creates new origin given protocol, hostname and port parts.
......@@ -97,7 +94,7 @@ impl Origin {
Some(other) => OriginProtocol::Custom(other),
};
Origin::with_host(protocol, hostname)
InnerOrigin::with_host(protocol, hostname)
}
fn host_with_proto(protocol: &OriginProtocol, host: &Host) -> String {
......@@ -113,38 +110,53 @@ impl Origin {
}
}
impl Pattern for Origin {
impl Pattern for InnerOrigin {
fn matches<T: AsRef<str>>(&self, other: T) -> bool {
self.matcher.matches(other)
}
}
impl ops::Deref for Origin {
impl ops::Deref for InnerOrigin {
type Target = str;
fn deref(&self) -> &Self::Target {
&self.host_with_proto
}
}
/// Origins allowed to access
#[derive(Debug, Clone, PartialEq, Eq)]
/// Origin type allowed to access.
#[allow(clippy::large_enum_variant)]
pub enum AllowOrigin {
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Origin {
/// Specific origin.
Origin(Origin),
/// null-origin (file:///, sandboxed iframe)
Origin(InnerOrigin),
/// Null origin (file:///, sandboxed iframe).
Null,
/// Any non-null origin
Any,
/// Allow all origins i.e, the literal value "*" which is regarded as a wildcard.
Wildcard,
}
impl fmt::Display for AllowOrigin {
impl Pattern for Origin {
fn matches<T: AsRef<str>>(&self, other: T) -> bool {
if other.as_ref() == "null" {
return *self == Origin::Null;
}
match self {
Origin::Wildcard => true,
Origin::Null => false,
Origin::Origin(ref origin) => origin.matches(other),
}
}
}
impl fmt::Display for Origin {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(
f,
"{}",
match *self {
Self::Any => "*",
Self::Wildcard => "*",
Self::Null => "null",
Self::Origin(ref val) => val,
}
......@@ -152,194 +164,58 @@ impl fmt::Display for AllowOrigin {
}
}
impl<T: Into<String>> From<T> for AllowOrigin {
impl<T: Into<String>> From<T> for Origin {
fn from(s: T) -> Self {
match s.into().as_str() {
"all" | "*" | "any" => Self::Any,
"all" | "*" | "any" => Self::Wildcard,
"null" => Self::Null,
origin => Self::Origin(origin.into()),
}
}
}
/// Headers allowed to access
#[derive(Debug, Clone, PartialEq)]
pub enum AllowHeaders {
/// Specific headers
Only(Vec<String>),
/// Any header
/// Policy for validating the `HTTP origin header`.
#[derive(Clone, Debug)]
pub enum AllowOrigins {
/// Allow all origins (no filter).
Any,
/// Allow only specified origins.
Only(Vec<Origin>),
}
impl AllowHeaders {
/// Return an appropriate value for the CORS header "Access-Control-Allow-Headers".
pub fn to_cors_header_value(&self) -> Cow<'_, str> {
match self {
AllowHeaders::Any => "*".into(),
AllowHeaders::Only(headers) => headers.join(", ").into(),
}
}
}
/// CORS response headers
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum AllowCors<T> {
/// CORS header was not required. Origin is not present in the request.
NotRequired,
/// CORS header is not returned, Origin is not allowed to access the resource.
Invalid,
/// CORS header to include in the response. Origin is allowed to access the resource.
Ok(T),
}
impl<T> AllowCors<T> {
/// Maps `Ok` variant of `AllowCors`.
pub fn map<F, O>(self, f: F) -> AllowCors<O>
where
F: FnOnce(T) -> O,
{
use self::AllowCors::*;
match self {
NotRequired => NotRequired,
Invalid => Invalid,
Ok(val) => Ok(f(val)),
}
}
}
impl<T> From<AllowCors<T>> for Option<T> {
fn from(cors: AllowCors<T>) -> Option<T> {
use self::AllowCors::*;
impl AllowOrigins {
/// Verify a origin.
pub fn verify(&self, origin: Option<&str>, host: &str) -> Result<(), Error> {
// Nothing to be checked if origin is not part of the request's headers.
let origin = match origin {
Some(ref origin) => origin,
None => return Ok(()),
};
match cors {
NotRequired | Invalid => None,
Ok(header) => Some(header),
// Requests initiated from the same server are allowed.
if origin.ends_with(host) {
// Additional check
let origin = InnerOrigin::parse(origin);
if &*origin.host == host {
return Ok(());
}
}
}
}
/// Returns correct CORS header (if any) given list of allowed origins and current origin.
pub(super) fn get_cors_allow_origin(
origin: Option<&str>,
allowed: &Option<Vec<AllowOrigin>>,
host: Option<&str>,
) -> AllowCors<AllowOrigin> {
match origin {
None => AllowCors::NotRequired,
Some(ref origin) => {
if let Some(host) = host {
// Request initiated from the same server.
if origin.ends_with(host) {
// Additional check
let origin = Origin::parse(origin);
if &*origin.host == host {
return AllowCors::NotRequired;
}
match self {
AllowOrigins::Any => return Ok(()),
AllowOrigins::Only(list) => {
if !list.iter().any(|allowed_origin| allowed_origin.matches(*origin)) {
return Err(Error::HttpHeaderRejected("origin", origin.to_string()));
}
}
match allowed.as_ref() {
None if *origin == "null" => AllowCors::Ok(AllowOrigin::Null),
None => AllowCors::Ok(AllowOrigin::Origin(Origin::parse(origin))),
Some(allowed) if *origin == "null" => allowed
.iter()
.find(|cors| **cors == AllowOrigin::Null)
.cloned()
.map(AllowCors::Ok)
.unwrap_or(AllowCors::Invalid),
Some(allowed) => allowed
.iter()
.find(|cors| match **cors {
AllowOrigin::Any => true,
AllowOrigin::Origin(ref val) if val.matches(origin) => true,
_ => false,
})
.map(|_| AllowOrigin::Origin(Origin::parse(origin)))
.map(AllowCors::Ok)
.unwrap_or(AllowCors::Invalid),
}
}
}
}
/// Validates if the headers in the request are allowed.
///
/// headers: all the headers in the request.
/// cors_request_headers: `values` in the `access-control-request-headers` header.
/// cors_allow_headers: whitelisted headers by the user.
pub(crate) fn get_cors_allow_headers<T: AsRef<str>, O, F: Fn(T) -> O>(
mut headers: impl Iterator<Item = T>,
cors_request_headers: impl Iterator<Item = T>,
cors_allow_headers: &AllowHeaders,
to_result: F,
) -> AllowCors<Vec<O>> {
// Check if the header fields which were sent in the request are allowed
if let AllowHeaders::Only(only) = cors_allow_headers {
let are_all_allowed = headers.all(|header| {
let name = &Ascii::new(header.as_ref());
only.iter().any(|h| Ascii::new(&*h) == name) || ALWAYS_ALLOWED_HEADERS.contains(name)
});
if !are_all_allowed {
return AllowCors::Invalid;
}
Ok(())
}
// Check if `AccessControlRequestHeaders` contains fields which were allowed
let (filtered, headers) = match cors_allow_headers {
AllowHeaders::Any => {
let headers = cors_request_headers.map(to_result).collect();
(false, headers)
}
AllowHeaders::Only(only) => {
let mut filtered = false;
let headers: Vec<_> = cors_request_headers
.filter(|header| {
let name = &Ascii::new(header.as_ref());
filtered = true;
only.iter().any(|h| Ascii::new(&*h) == name) || ALWAYS_ALLOWED_HEADERS.contains(name)
})
.map(to_result)
.collect();
(filtered, headers)
}
};
if headers.is_empty() {
if filtered {
AllowCors::Invalid
} else {
AllowCors::NotRequired
}
} else {
AllowCors::Ok(headers)
}
}
lazy_static! {
/// Returns headers which are always allowed.
static ref ALWAYS_ALLOWED_HEADERS: HashSet<Ascii<&'static str>> = {
let mut hs = HashSet::new();
hs.insert(Ascii::new("Accept"));
hs.insert(Ascii::new("Accept-Language"));
hs.insert(Ascii::new("Access-Control-Request-Headers"));
hs.insert(Ascii::new("Content-Language"));
hs.insert(Ascii::new("Content-Type"));
hs.insert(Ascii::new("Host"));
hs.insert(Ascii::new("Origin"));
hs.insert(Ascii::new("Content-Length"));
hs.insert(Ascii::new("Connection"));
hs.insert(Ascii::new("User-Agent"));
hs
};
}
#[cfg(test)]
mod tests {
use std::iter;
use super::*;
use crate::server::access_control::host::Host;
......@@ -347,251 +223,138 @@ mod tests {
fn should_parse_origin() {
use self::OriginProtocol::*;
assert_eq!(Origin::parse("http://parity.io"), Origin::new(Http, "parity.io", None));
assert_eq!(Origin::parse("https://parity.io:8443"), Origin::new(Https, "parity.io", Some(8443)));
assert_eq!(InnerOrigin::parse("http://parity.io"), InnerOrigin::new(Http, "parity.io", None));
assert_eq!(InnerOrigin::parse("https://parity.io:8443"), InnerOrigin::new(Https, "parity.io", Some(8443)));
assert_eq!(
Origin::parse("chrome-extension://124.0.0.1"),
Origin::new(Custom("chrome-extension".into()), "124.0.0.1", None)
InnerOrigin::parse("chrome-extension://124.0.0.1"),
InnerOrigin::new(Custom("chrome-extension".into()), "124.0.0.1", None)
);
assert_eq!(Origin::parse("parity.io/somepath"), Origin::new(Http, "parity.io", None));
assert_eq!(Origin::parse("127.0.0.1:8545/somepath"), Origin::new(Http, "127.0.0.1", Some(8545)));
assert_eq!(InnerOrigin::parse("parity.io/somepath"), InnerOrigin::new(Http, "parity.io", None));
assert_eq!(InnerOrigin::parse("127.0.0.1:8545/somepath"), InnerOrigin::new(Http, "127.0.0.1", Some(8545)));
}