Unverified Commit 538854bc authored by Niklas Adolfsson's avatar Niklas Adolfsson Committed by GitHub
Browse files

fix(ws server): reply HTTP 403 on all failed conns (#819)

parent 96863bc2
Pipeline #203475 passed with stages
in 4 minutes and 40 seconds
......@@ -56,6 +56,7 @@ use jsonrpsee_types::error::{reject_too_big_request, reject_too_many_subscriptio
use jsonrpsee_types::Params;
use soketto::connection::Error as SokettoError;
use soketto::data::ByteSlice125;
use soketto::handshake::WebSocketKey;
use soketto::handshake::{server::Response, Server as SokettoServer};
use soketto::Sender;
use tokio::net::{TcpListener, TcpStream, ToSocketAddrs};
......@@ -253,45 +254,12 @@ async fn handshake<M: Middleware>(socket: tokio::net::TcpStream, mode: Handshake
}
HandshakeResponse::Accept { conn_id, methods, resources, cfg, stop_monitor, middleware, id_provider } => {
tracing::debug!("Accepting new connection: {}", conn_id);
let key = {
let req = server.receive_request().await?;
let host = std::str::from_utf8(req.headers().host)
.map_err(|_e| Error::HttpHeaderRejected("Host", "Invalid UTF-8".to_string()))?;
let origin = req.headers().origin.and_then(|h| {
let res = std::str::from_utf8(h).ok();
if res.is_none() {
tracing::warn!("Origin header invalid UTF-8; treated as no Origin header");
}
res
});
let host_check = cfg.access_control.verify_host(host);
let origin_check = cfg.access_control.verify_origin(origin, host);
let mut headers = HeaderMap::new();
let key = host_check.and(origin_check).map(|()| {
let key = req.key();
if let Ok(val) = HeaderValue::from_str(host) {
headers.insert(HOST, val);
}
if let Some(Ok(val)) = origin.map(HeaderValue::from_str) {
headers.insert(ORIGIN, val);
}
let key_and_headers = get_key_and_headers(&mut server, cfg).await;
key
});
middleware.on_connect(remote_addr, &headers);
key
};
match key {
Ok(key) => {
match key_and_headers {
Ok((key, headers)) => {
middleware.on_connect(remote_addr, &headers);
let accept = Response::Accept { key, protocol: None };
server.send_response(&accept).await?;
}
......@@ -1027,3 +995,42 @@ async fn execute_call<M: Middleware>(c: Call<'_, M>) -> MethodResult {
middleware.on_result(name, r.success, request_start);
response
}
/// Helper to fetch the `WebSocketKey` and `Headers` from the WebSocket handshake.
async fn get_key_and_headers(
server: &mut SokettoServer<'_, BufReader<BufWriter<Compat<TcpStream>>>>,
cfg: &Settings,
) -> Result<(WebSocketKey, HeaderMap), Error> {
let req = server.receive_request().await?;
tracing::trace!("Connection request: {:?}", req);
let host = std::str::from_utf8(req.headers().host).map_err(|e| Error::HttpHeaderRejected("Host", e.to_string()))?;
let origin = req.headers().origin.and_then(|h| {
let res = std::str::from_utf8(h).ok();
if res.is_none() {
tracing::warn!("Origin header invalid UTF-8; treated as no Origin header");
}
res
});
let host_check = cfg.access_control.verify_host(host);
let origin_check = cfg.access_control.verify_origin(origin, host);
let mut headers = HeaderMap::new();
host_check.and(origin_check).map(|()| {
let key = req.key();
if let Ok(val) = HeaderValue::from_str(host) {
headers.insert(HOST, val);
}
if let Some(Ok(val)) = origin.map(HeaderValue::from_str) {
headers.insert(ORIGIN, val);
}
(key, headers)
})
}
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment