Unverified Commit 6e21fa00 authored by Niklas Adolfsson's avatar Niklas Adolfsson
Browse files

different traits for WS and HTTP middleware

parent acdef050
Pipeline #200557 failed with stages
in 5 minutes and 40 seconds
......@@ -26,65 +26,144 @@
//! Middleware for `jsonrpsee` servers.
use http::HeaderMap;
use jsonrpsee_types::Params;
use std::net::SocketAddr;
/// Defines a middleware with callbacks during the RPC request life-cycle. The primary use case for
/// this is to collect timings for a larger metrics collection solution but the only constraints on
/// the associated type is that it be [`Send`] and [`Copy`], giving users some freedom to do what
/// they need to do.
/// Defines a middleware specifically for HTTP requests with callbacks during the RPC request life-cycle.
/// The primary use case for this is to collect timings for a larger metrics collection solution.
///
/// See [`HttpServerBuilder::set_middleware`](../../jsonrpsee_http_server/struct.HttpServerBuilder.html#method.set_middleware) method
/// for examples.
pub trait HttpMiddleware: Send + Sync + Clone + 'static {
/// Intended to carry timestamp of a request, for example `std::time::Instant`. How the middleware
/// measures time, if at all, is entirely up to the implementation.
type Instant: std::fmt::Debug + Send + Sync + Copy;
/// Called when a new JSON-RPC request comes to the server.
fn on_request(&self, remote_addr: SocketAddr, headers: &HeaderMap) -> Self::Instant;
/// Called on each JSON-RPC method call, batch requests will trigger `on_call` multiple times.
fn on_call(&self, method_name: &str, params: Params);
/// Called on each JSON-RPC method completion, batch requests will trigger `on_result` multiple times.
fn on_result(&self, method_name: &str, success: bool, started_at: Self::Instant);
/// Called once the JSON-RPC request is finished and response is sent to the output buffer.
fn on_response(&self, result: &str, _started_at: Self::Instant);
}
/// Defines a middleware specifically for WebSocket connections with callbacks during the RPC request life-cycle.
/// The primary use case for this is to collect timings for a larger metrics collection solution.
///
/// See the [`WsServerBuilder::set_middleware`](../../jsonrpsee_ws_server/struct.WsServerBuilder.html#method.set_middleware)
/// or the [`HttpServerBuilder::set_middleware`](../../jsonrpsee_http_server/struct.HttpServerBuilder.html#method.set_middleware) method
/// for examples.
pub trait Middleware: Send + Sync + Clone + 'static {
pub trait WsMiddleware: Send + Sync + Clone + 'static {
/// Intended to carry timestamp of a request, for example `std::time::Instant`. How the middleware
/// measures time, if at all, is entirely up to the implementation.
type Instant: std::fmt::Debug + Send + Sync + Copy;
/// Called when a new client connects (WebSocket only)
fn on_connect(&self) {}
/// Called when a new client connects
fn on_connect(&self, remote_addr: SocketAddr, headers: &http::HeaderMap);
/// Called when a new JSON-RPC request comes to the server.
fn on_request(&self, remote_addr: std::net::SocketAddr, headers: &http::HeaderMap) -> Self::Instant;
fn on_request(&self) -> Self::Instant;
/// Called on each JSON-RPC method call, batch requests will trigger `on_call` multiple times.
fn on_call(&self, _name: &str, _params: Params) {}
fn on_call(&self, method_name: &str, params: Params);
/// Called on each JSON-RPC method completion, batch requests will trigger `on_result` multiple times.
fn on_result(&self, _name: &str, _success: bool, _started_at: Self::Instant) {}
fn on_result(&self, method_name: &str, success: bool, started_at: Self::Instant);
/// Called once the JSON-RPC request is finished and response is sent to the output buffer.
fn on_response(&self, _result: &str, _started_at: Self::Instant) {}
fn on_response(&self, result: &str, started_at: Self::Instant);
/// Called when a client disconnects
fn on_disconnect(&self, remote_addr: std::net::SocketAddr);
}
impl HttpMiddleware for () {
type Instant = ();
fn on_request(&self, _: std::net::SocketAddr, _: &http::HeaderMap) -> Self::Instant {}
fn on_call(&self, _: &str, _: Params) {}
/// Called when a client disconnects (WebSocket only)
fn on_disconnect(&self) {}
fn on_result(&self, _: &str, _: bool, _: Self::Instant) {}
fn on_response(&self, _: &str, _: Self::Instant) {}
}
impl Middleware for () {
impl WsMiddleware for () {
type Instant = ();
fn on_request(&self, _ip_addr: std::net::SocketAddr, _headers: &http::HeaderMap) -> Self::Instant {}
fn on_connect(&self, _: std::net::SocketAddr, _: &http::HeaderMap) {}
fn on_request(&self) -> Self::Instant {}
fn on_call(&self, _: &str, _: Params) {}
fn on_result(&self, _: &str, _: bool, _: Self::Instant) {}
fn on_response(&self, _: &str, _: Self::Instant) {}
fn on_disconnect(&self, _: std::net::SocketAddr) {}
}
impl<A, B> WsMiddleware for (A, B)
where
A: WsMiddleware,
B: WsMiddleware,
{
type Instant = (A::Instant, B::Instant);
fn on_connect(&self, remote_addr: std::net::SocketAddr, headers: &http::HeaderMap) {
(self.0.on_connect(remote_addr, headers), self.1.on_connect(remote_addr, headers));
}
fn on_request(&self) -> Self::Instant {
(self.0.on_request(), self.1.on_request())
}
fn on_call(&self, method_name: &str, params: Params) {
self.0.on_call(method_name, params.clone());
self.1.on_call(method_name, params);
}
fn on_result(&self, method_name: &str, success: bool, started_at: Self::Instant) {
self.0.on_result(method_name, success, started_at.0);
self.1.on_result(method_name, success, started_at.1);
}
fn on_response(&self, result: &str, started_at: Self::Instant) {
self.0.on_response(result, started_at.0);
self.1.on_response(result, started_at.1);
}
fn on_disconnect(&self, remote_addr: std::net::SocketAddr) {
(self.0.on_disconnect(remote_addr), self.1.on_disconnect(remote_addr));
}
}
impl<A, B> Middleware for (A, B)
impl<A, B> HttpMiddleware for (A, B)
where
A: Middleware,
B: Middleware,
A: HttpMiddleware,
B: HttpMiddleware,
{
type Instant = (A::Instant, B::Instant);
fn on_request(&self, ip_addr: std::net::SocketAddr, headers: &http::HeaderMap) -> Self::Instant {
(self.0.on_request(ip_addr, headers), self.1.on_request(ip_addr, headers))
fn on_request(&self, remote_addr: std::net::SocketAddr, headers: &HeaderMap) -> Self::Instant {
(self.0.on_request(remote_addr, headers), self.1.on_request(remote_addr, headers))
}
fn on_call(&self, name: &str, params: Params) {
self.0.on_call(name, params.clone());
self.1.on_call(name, params);
fn on_call(&self, method_name: &str, params: Params) {
self.0.on_call(method_name, params.clone());
self.1.on_call(method_name, params);
}
fn on_result(&self, name: &str, success: bool, started_at: Self::Instant) {
self.0.on_result(name, success, started_at.0);
self.1.on_result(name, success, started_at.1);
fn on_result(&self, method_name: &str, success: bool, started_at: Self::Instant) {
self.0.on_result(method_name, success, started_at.0);
self.1.on_result(method_name, success, started_at.1);
}
fn on_response(&self, result: &str, started_at: Self::Instant) {
......
......@@ -36,7 +36,7 @@ use jsonrpsee::types::Params;
#[derive(Clone)]
struct Timings;
impl middleware::Middleware for Timings {
impl middleware::HttpMiddleware for Timings {
type Instant = Instant;
fn on_request(&self, remote_addr: SocketAddr, headers: &HeaderMap) -> Self::Instant {
......
......@@ -36,11 +36,15 @@ use jsonrpsee::ws_server::{RpcModule, WsServerBuilder};
#[derive(Clone)]
struct Timings;
impl middleware::Middleware for Timings {
impl middleware::WsMiddleware for Timings {
type Instant = Instant;
fn on_request(&self, remote_addr: SocketAddr, headers: &HeaderMap) -> Self::Instant {
println!("[Middleware::on_request] remote_addr {}, headers: {:?}", remote_addr, headers);
fn on_connect(&self, remote_addr: SocketAddr, headers: &HeaderMap) {
println!("[Middleware::on_connect] remote_addr {}, headers: {:?}", remote_addr, headers);
}
fn on_request(&self) -> Self::Instant {
println!("[Middleware::on_request]");
Instant::now()
}
......@@ -55,6 +59,10 @@ impl middleware::Middleware for Timings {
fn on_response(&self, result: &str, started_at: Self::Instant) {
println!("[Middleware::on_response] result: {}, time elapsed {:?}", result, started_at.elapsed());
}
fn on_disconnect(&self, remote_addr: SocketAddr) {
println!("[Middleware::on_disconnect] remote_addr: {}", remote_addr);
}
}
#[tokio::main]
......
......@@ -40,10 +40,14 @@ use jsonrpsee::ws_server::{RpcModule, WsServerBuilder};
#[derive(Clone)]
struct Timings;
impl middleware::Middleware for Timings {
impl middleware::WsMiddleware for Timings {
type Instant = Instant;
fn on_request(&self, _remote_addr: SocketAddr, _headers: &HeaderMap) -> Self::Instant {
fn on_connect(&self, remote_addr: SocketAddr, headers: &HeaderMap) {
println!("[Timings::on_connect] remote_addr {}, headers: {:?}", remote_addr, headers);
}
fn on_request(&self) -> Self::Instant {
Instant::now()
}
......@@ -58,6 +62,10 @@ impl middleware::Middleware for Timings {
fn on_response(&self, _result: &str, started_at: Self::Instant) {
println!("[Timings] Response duration {:?}", started_at.elapsed());
}
fn on_disconnect(&self, remote_addr: SocketAddr) {
println!("[Timings::on_disconnect] remote_addr: {}", remote_addr);
}
}
/// Example middleware to keep a watch on the number of total threads started in the system.
......@@ -79,18 +87,36 @@ impl ThreadWatcher {
}
}
impl middleware::Middleware for ThreadWatcher {
impl middleware::WsMiddleware for ThreadWatcher {
type Instant = isize;
fn on_request(&self, _remote_addr: SocketAddr, _headers: &HeaderMap) -> Self::Instant {
fn on_connect(&self, remote_addr: SocketAddr, headers: &HeaderMap) {
println!("[ThreadWatcher::on_connect] remote_addr {}, headers: {:?}", remote_addr, headers);
}
fn on_call(&self, _method: &str, _params: Params) {
let threads = Self::count_threads();
println!("[ThreadWatcher::on_call] Threads running on the machine at the start of a call: {}", threads);
}
fn on_request(&self) -> Self::Instant {
let threads = Self::count_threads();
println!("[ThreadWatcher] Threads running on the machine at the start of a call: {}", threads);
println!("[ThreadWatcher::on_request] Threads running on the machine at the start of a call: {}", threads);
threads as isize
}
fn on_result(&self, _name: &str, _succees: bool, started_at: Self::Instant) {
let current_nr_threads = Self::count_threads() as isize;
println!("[ThreadWatcher::on_result] {} threads", current_nr_threads - started_at);
}
fn on_response(&self, _result: &str, started_at: Self::Instant) {
let current_nr_threads = Self::count_threads() as isize;
println!("[ThreadWatcher] Request started {} threads", current_nr_threads - started_at);
println!("[ThreadWatcher::on_response] {} threads", current_nr_threads - started_at);
}
fn on_disconnect(&self, remote_addr: SocketAddr) {
println!("[ThreadWatcher::on_disconnect] remote_addr: {}", remote_addr);
}
}
......
......@@ -40,7 +40,7 @@ use hyper::service::{make_service_fn, service_fn};
use hyper::{Error as HyperError, Method};
use jsonrpsee_core::error::{Error, GenericTransportError};
use jsonrpsee_core::http_helpers::{self, read_body};
use jsonrpsee_core::middleware::Middleware;
use jsonrpsee_core::middleware::HttpMiddleware as Middleware;
use jsonrpsee_core::server::access_control::AccessControl;
use jsonrpsee_core::server::helpers::{prepare_error, MethodResponse};
use jsonrpsee_core::server::helpers::{BatchResponse, BatchResponseBuilder};
......@@ -101,22 +101,38 @@ impl<M> Builder<M> {
/// ```
/// use std::{time::Instant, net::SocketAddr};
///
/// use jsonrpsee_core::middleware::Middleware;
/// use jsonrpsee_core::middleware::HttpMiddleware;
/// use jsonrpsee_core::HeaderMap;
/// use jsonrpsee_types::Params;
/// use jsonrpsee_http_server::HttpServerBuilder;
///
/// #[derive(Clone)]
/// struct MyMiddleware;
///
/// impl Middleware for MyMiddleware {
/// impl HttpMiddleware for MyMiddleware {
/// type Instant = Instant;
///
/// // Called once the HTTP request is received, it may be a single JSON-RPC call
/// // or batch.
/// fn on_request(&self, _remote_addr: SocketAddr, _headers: &HeaderMap) -> Instant {
/// Instant::now()
/// }
///
/// // Called once a single JSON-RPC method call is processed, it may be called multiple times
/// // on batches.
/// fn on_call(&self, method_name: &str, params: Params) {
/// println!("Call to method: '{}' params: {:?}", method_name, params);
/// }
///
/// // Called once a single JSON-RPC call is completed, it may be called multiple times
/// // on batches.
/// fn on_result(&self, method_name: &str, success: bool, started_at: Instant) {
/// println!("Call to '{}' took {:?}", method_name, started_at.elapsed());
/// }
///
/// fn on_result(&self, name: &str, success: bool, started_at: Instant) {
/// println!("Call to '{}' took {:?}", name, started_at.elapsed());
/// // Called the entire JSON-RPC is completed, called on once for both single calls or batches.
/// fn on_response(&self, result: &str, started_at: Instant) {
/// println!("complete JSON-RPC response: {}, took: {:?}", result, started_at.elapsed());
/// }
/// }
///
......
......@@ -30,7 +30,8 @@ use std::sync::{Arc, Mutex};
use std::time::Duration;
use hyper::HeaderMap;
use jsonrpsee::core::{client::ClientT, middleware::Middleware, Error};
use jsonrpsee::core::middleware::{HttpMiddleware, WsMiddleware};
use jsonrpsee::core::{client::ClientT, Error};
use jsonrpsee::http_client::HttpClientBuilder;
use jsonrpsee::http_server::{HttpServerBuilder, HttpServerHandle};
use jsonrpsee::proc_macros::rpc;
......@@ -55,15 +56,15 @@ struct CounterInner {
calls: HashMap<String, (u32, Vec<u32>)>,
}
impl Middleware for Counter {
impl WsMiddleware for Counter {
/// Auto-incremented id of the call
type Instant = u32;
fn on_connect(&self) {
fn on_connect(&self, _remote_addr: SocketAddr, _headers: &HeaderMap) {
self.inner.lock().unwrap().connections.0 += 1;
}
fn on_request(&self, _remote_addr: SocketAddr, _headers: &HeaderMap) -> u32 {
fn on_request(&self) -> u32 {
let mut inner = self.inner.lock().unwrap();
let n = inner.requests.0;
......@@ -89,11 +90,42 @@ impl Middleware for Counter {
self.inner.lock().unwrap().requests.1 += 1;
}
fn on_disconnect(&self) {
fn on_disconnect(&self, _remote_addr: SocketAddr) {
self.inner.lock().unwrap().connections.1 += 1;
}
}
impl HttpMiddleware for Counter {
/// Auto-incremented id of the call
type Instant = u32;
fn on_request(&self, _remote_addr: SocketAddr, _headers: &HeaderMap) -> u32 {
let mut inner = self.inner.lock().unwrap();
let n = inner.requests.0;
inner.requests.0 += 1;
n
}
fn on_call(&self, name: &str, _params: Params) {
let mut inner = self.inner.lock().unwrap();
let entry = inner.calls.entry(name.into()).or_insert((0, Vec::new()));
entry.0 += 1;
}
fn on_result(&self, name: &str, success: bool, n: u32) {
if success {
self.inner.lock().unwrap().calls.get_mut(name).unwrap().1.push(n);
}
}
fn on_response(&self, _result: &str, _: u32) {
self.inner.lock().unwrap().requests.1 += 1;
}
}
fn test_module() -> RpcModule<()> {
#[rpc(server)]
pub trait Rpc {
......
......@@ -41,7 +41,7 @@ use futures_util::stream::StreamExt;
use http::header::{HOST, ORIGIN};
use http::{HeaderMap, HeaderValue};
use jsonrpsee_core::id_providers::RandomIntegerIdProvider;
use jsonrpsee_core::middleware::Middleware;
use jsonrpsee_core::middleware::WsMiddleware as Middleware;
use jsonrpsee_core::server::access_control::AccessControl;
use jsonrpsee_core::server::helpers::{
prepare_error, BatchResponse, BatchResponseBuilder, BoundedSubscriptions, MethodResponse, MethodSink,
......@@ -252,7 +252,7 @@ async fn handshake<M: Middleware>(socket: tokio::net::TcpStream, mode: Handshake
}
HandshakeResponse::Accept { conn_id, methods, resources, cfg, stop_monitor, middleware, id_provider } => {
tracing::debug!("Accepting new connection: {}", conn_id);
let key_and_headers = {
let key = {
let req = server.receive_request().await?;
let host = std::str::from_utf8(req.headers().host)
......@@ -268,10 +268,10 @@ async fn handshake<M: Middleware>(socket: tokio::net::TcpStream, mode: Handshake
let host_check = cfg.access_control.verify_host(host);
let origin_check = cfg.access_control.verify_origin(origin, host);
host_check.and(origin_check).map(|()| {
let key = req.key();
let mut headers = HeaderMap::new();
let mut headers = HeaderMap::new();
let key = host_check.and(origin_check).map(|()| {
let key = req.key();
if let Ok(val) = HeaderValue::from_str(host) {
headers.insert(HOST, val);
......@@ -281,15 +281,18 @@ async fn handshake<M: Middleware>(socket: tokio::net::TcpStream, mode: Handshake
headers.insert(ORIGIN, val);
}
(key, headers)
})
key
});
middleware.on_connect(remote_addr, &headers);
key
};
let headers = match key_and_headers {
Ok((key, headers)) => {
match key {
Ok(key) => {
let accept = Response::Accept { key, protocol: None };
server.send_response(&accept).await?;
headers
}
Err(err) => {
tracing::warn!("Rejected connection: {:?}", err);
......@@ -315,7 +318,6 @@ async fn handshake<M: Middleware>(socket: tokio::net::TcpStream, mode: Handshake
id_provider,
ping_interval: cfg.ping_interval,
remote_addr,
headers,
}))
.await;
......@@ -342,7 +344,6 @@ struct BackgroundTask<'a, M> {
id_provider: Arc<dyn IdProvider>,
ping_interval: Duration,
remote_addr: SocketAddr,
headers: HeaderMap,
}
async fn background_task<M: Middleware>(input: BackgroundTask<'_, M>) -> Result<(), Error> {
......@@ -361,7 +362,6 @@ async fn background_task<M: Middleware>(input: BackgroundTask<'_, M>) -> Result<
id_provider,
ping_interval,
remote_addr,
headers,
} = input;
// And we can finally transition to a websocket background_task.
......@@ -374,8 +374,6 @@ async fn background_task<M: Middleware>(input: BackgroundTask<'_, M>) -> Result<
let stop_server2 = stop_server.clone();
let sink = MethodSink::new_with_limit(tx, max_response_body_size, max_log_length);
middleware.on_connect();
// Send results back to the client.
tokio::spawn(async move {
// Received messages from the WebSocket.
......@@ -471,7 +469,7 @@ async fn background_task<M: Middleware>(input: BackgroundTask<'_, M>) -> Result<
};
};
let request_start = middleware.on_request(remote_addr, &headers);
let request_start = middleware.on_request();
let first_non_whitespace = data.iter().find(|byte| !byte.is_ascii_whitespace());
match first_non_whitespace {
......@@ -559,7 +557,7 @@ async fn background_task<M: Middleware>(input: BackgroundTask<'_, M>) -> Result<
}
};
middleware.on_disconnect();
middleware.on_disconnect(remote_addr);
// Drive all running methods to completion.
// **NOTE** Do not return early in this function. This `await` needs to run to guarantee
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment