Unverified Commit baf0e6bc authored by Niklas Adolfsson's avatar Niklas Adolfsson
Browse files

ws server; expose headers in middleware

parent 11392d39
Pipeline #199363 passed with stages
in 6 minutes and 2 seconds
......@@ -38,6 +38,8 @@ use futures_channel::{mpsc, oneshot};
use futures_util::future::{Either, FutureExt};
use futures_util::io::{BufReader, BufWriter};
use futures_util::stream::StreamExt;
use http::header::{HOST, ORIGIN};
use http::{HeaderMap, HeaderValue};
use jsonrpsee_core::id_providers::RandomIntegerIdProvider;
use jsonrpsee_core::middleware::Middleware;
use jsonrpsee_core::server::access_control::AccessControl;
......@@ -248,7 +250,7 @@ async fn handshake<M: Middleware>(socket: tokio::net::TcpStream, mode: Handshake
}
HandshakeResponse::Accept { conn_id, methods, resources, cfg, stop_monitor, middleware, id_provider } => {
tracing::debug!("Accepting new connection: {}", conn_id);
let key = {
let key_and_headers = {
let req = server.receive_request().await?;
let host = std::str::from_utf8(req.headers().host)
......@@ -264,13 +266,28 @@ async fn handshake<M: Middleware>(socket: tokio::net::TcpStream, mode: Handshake
let host_check = cfg.access_control.verify_host(host);
let origin_check = cfg.access_control.verify_origin(origin, host);
host_check.and(origin_check).map(|()| req.key())
host_check.and(origin_check).map(|()| {
let key = req.key();
let mut headers = HeaderMap::new();
if let Ok(val) = HeaderValue::from_str(host) {
headers.insert(HOST, val);
}
if let Some(Ok(val)) = origin.map(HeaderValue::from_str) {
headers.insert(ORIGIN, val);
}
(key, headers)
})
};
match key {
Ok(key) => {
let headers = match key_and_headers {
Ok((key, headers)) => {
let accept = Response::Accept { key, protocol: None };
server.send_response(&accept).await?;
headers
}
Err(err) => {
tracing::warn!("Rejected connection: {:?}", err);
......@@ -279,23 +296,24 @@ async fn handshake<M: Middleware>(socket: tokio::net::TcpStream, mode: Handshake
return Err(err);
}
}
};
let join_result = tokio::spawn(background_task(
let join_result = tokio::spawn(background_task(BackgroundTask {
server,
conn_id,
methods.clone(),
resources.clone(),
cfg.max_request_body_size,
cfg.max_response_body_size,
cfg.batch_requests_supported,
BoundedSubscriptions::new(cfg.max_subscriptions_per_connection),
stop_monitor.clone(),
methods: methods.clone(),
resources: resources.clone(),
max_request_body_size: cfg.max_request_body_size,
max_response_body_size: cfg.max_response_body_size,
batch_requests_supported: cfg.batch_requests_supported,
bounded_subscriptions: BoundedSubscriptions::new(cfg.max_subscriptions_per_connection),
stop_server: stop_monitor.clone(),
middleware,
id_provider,
cfg.ping_interval,
ping_interval: cfg.ping_interval,
remote_addr,
))
headers,
}))
.await;
match join_result {
......@@ -306,8 +324,8 @@ async fn handshake<M: Middleware>(socket: tokio::net::TcpStream, mode: Handshake
}
}
async fn background_task(
server: SokettoServer<'_, BufReader<BufWriter<Compat<tokio::net::TcpStream>>>>,
struct BackgroundTask<'a, M> {
server: SokettoServer<'a, BufReader<BufWriter<Compat<tokio::net::TcpStream>>>>,
conn_id: ConnectionId,
methods: Methods,
resources: Resources,
......@@ -316,11 +334,31 @@ async fn background_task(
batch_requests_supported: bool,
bounded_subscriptions: BoundedSubscriptions,
stop_server: StopMonitor,
middleware: impl Middleware,
middleware: M,
id_provider: Arc<dyn IdProvider>,
ping_interval: Duration,
remote_addr: SocketAddr,
) -> Result<(), Error> {
headers: HeaderMap,
}
async fn background_task<M: Middleware>(input: BackgroundTask<'_, M>) -> Result<(), Error> {
let BackgroundTask {
server,
conn_id,
methods,
resources,
max_request_body_size,
max_response_body_size,
batch_requests_supported,
bounded_subscriptions,
stop_server,
middleware,
id_provider,
ping_interval,
remote_addr,
headers,
} = input;
// And we can finally transition to a websocket background_task.
let mut builder = server.into_builder();
builder.set_max_message_size(max_request_body_size as usize);
......@@ -430,7 +468,6 @@ async fn background_task(
tracing::debug!("recv {} bytes", data.len());
let headers = http::HeaderMap::new();
let request_start = middleware.on_request(remote_addr, &headers);
let first_non_whitespace = data.iter().find(|byte| !byte.is_ascii_whitespace());
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment