Unverified Commit 3ee635ff authored by Niklas Adolfsson's avatar Niklas Adolfsson Committed by GitHub
Browse files

middleware refactoring (#793)

* WIP: refactoring

* refactor http server

* fix tests

* Delete TODO.txt

* fix tests again

* add benches/src/lib.rs

* remove bench changes; fast less deps

* no more env_logger

* update examples

* ws server; expose headers in middleware

* add back uncommented code

* fix nits

* make the code more readable

* add back the tracing stuff

* simplify code but one extra clone

* fix tests again

* revert async accept API

* fix nits

* different traits for WS and HTTP middleware

* fix tests

* revert benchmark change

* Update core/src/server/helpers.rs

* Update ws-server/Cargo.toml

* add limit to batch responses as well

* pre-allocate string for batches

* small refactor
parent d974914f
Pipeline #201709 passed with stages
in 5 minutes
......@@ -15,7 +15,7 @@ jsonrpsee-client-transport = { path = "../transport", version = "0.14.0", featur
jsonrpsee-core = { path = "../../core", version = "0.14.0", features = ["async-wasm-client"] }
[dev-dependencies]
env_logger = "0.9"
tracing-subscriber = { version = "0.3.3", features = ["env-filter"] }
jsonrpsee-test-utils = { path = "../../test-utils" }
tokio = { version = "1", features = ["macros"] }
serde_json = "1"
......@@ -15,7 +15,7 @@ jsonrpsee-client-transport = { path = "../transport", version = "0.14.0", featur
jsonrpsee-core = { path = "../../core", version = "0.14.0", features = ["async-client"] }
[dev-dependencies]
env_logger = "0.9"
tracing-subscriber = { version = "0.3.3", features = ["env-filter"] }
jsonrpsee-test-utils = { path = "../../test-utils" }
tokio = { version = "1", features = ["macros"] }
serde_json = "1"
......
......@@ -205,7 +205,7 @@ async fn notification_without_polling_doesnt_make_client_unuseable() {
client.subscribe_to_method("test").with_default_timeout().await.unwrap().unwrap();
// don't poll the notification stream for 2 seconds, should be full now.
std::thread::sleep(std::time::Duration::from_secs(2));
tokio::time::sleep(std::time::Duration::from_secs(2)).await;
// Capacity is `num_sender` + `capacity`
for _ in 0..5 {
......@@ -244,6 +244,11 @@ async fn batch_request_out_of_order_response() {
#[tokio::test]
async fn is_connected_works() {
tracing_subscriber::FmtSubscriber::builder()
.with_env_filter(tracing_subscriber::EnvFilter::from_default_env())
.try_init()
.expect("setting default subscriber failed");
let server = WebSocketTestServer::with_hardcoded_response(
"127.0.0.1:0".parse().unwrap(),
ok_response(JsonValue::String("foo".into()), Id::Num(99_u64)),
......@@ -254,9 +259,11 @@ async fn is_connected_works() {
let uri = to_ws_uri_string(server.local_addr());
let client = WsClientBuilder::default().build(&uri).with_default_timeout().await.unwrap().unwrap();
assert!(client.is_connected());
client.request::<String>("say_hello", None).with_default_timeout().await.unwrap().unwrap_err();
// give the background thread some time to terminate.
std::thread::sleep(std::time::Duration::from_millis(100));
tokio::time::sleep(std::time::Duration::from_millis(100)).await;
assert!(!client.is_connected())
}
......@@ -295,7 +302,6 @@ fn assert_error_response(err: Error, exp: ErrorObjectOwned) {
#[tokio::test]
async fn redirections() {
let _ = env_logger::try_init();
let expected = "abc 123";
let server = WebSocketTestServer::with_hardcoded_response(
"127.0.0.1:0".parse().unwrap(),
......
......@@ -34,6 +34,7 @@ futures-timer = { version = "3", optional = true }
globset = { version = "0.4", optional = true }
lazy_static = { version = "1", optional = true }
unicase = { version = "2.6.0", optional = true }
http = { version = "0.2.7", optional = true }
[features]
default = []
......@@ -49,6 +50,7 @@ server = [
"tokio/sync",
"lazy_static",
"unicase",
"http",
]
client = ["futures-util/sink", "futures-channel/sink", "futures-channel/std"]
async-client = [
......
......@@ -39,9 +39,6 @@ pub mod error;
/// Traits
pub mod traits;
/// Middleware trait and implementation.
pub mod middleware;
cfg_http_helpers! {
pub mod http_helpers;
}
......@@ -49,6 +46,7 @@ cfg_http_helpers! {
cfg_server! {
pub mod id_providers;
pub mod server;
pub mod middleware;
}
cfg_client! {
......
......@@ -26,67 +26,149 @@
//! Middleware for `jsonrpsee` servers.
/// Defines a middleware with callbacks during the RPC request life-cycle. The primary use case for
/// this is to collect timings for a larger metrics collection solution but the only constraints on
/// the associated type is that it be [`Send`] and [`Copy`], giving users some freedom to do what
/// they need to do.
use std::net::SocketAddr;
pub use http::HeaderMap as Headers;
pub use jsonrpsee_types::Params;
/// Defines a middleware specifically for HTTP requests with callbacks during the RPC request life-cycle.
/// The primary use case for this is to collect timings for a larger metrics collection solution.
///
/// See [`HttpServerBuilder::set_middleware`](../../jsonrpsee_http_server/struct.HttpServerBuilder.html#method.set_middleware) method
/// for examples.
pub trait HttpMiddleware: Send + Sync + Clone + 'static {
/// Intended to carry timestamp of a request, for example `std::time::Instant`. How the middleware
/// measures time, if at all, is entirely up to the implementation.
type Instant: std::fmt::Debug + Send + Sync + Copy;
/// Called when a new JSON-RPC request comes to the server.
fn on_request(&self, remote_addr: SocketAddr, headers: &Headers) -> Self::Instant;
/// Called on each JSON-RPC method call, batch requests will trigger `on_call` multiple times.
fn on_call(&self, method_name: &str, params: Params);
/// Called on each JSON-RPC method completion, batch requests will trigger `on_result` multiple times.
fn on_result(&self, method_name: &str, success: bool, started_at: Self::Instant);
/// Called once the JSON-RPC request is finished and response is sent to the output buffer.
fn on_response(&self, result: &str, _started_at: Self::Instant);
}
/// Defines a middleware specifically for WebSocket connections with callbacks during the RPC request life-cycle.
/// The primary use case for this is to collect timings for a larger metrics collection solution.
///
/// See the [`WsServerBuilder::set_middleware`](../../jsonrpsee_ws_server/struct.WsServerBuilder.html#method.set_middleware)
/// or the [`HttpServerBuilder::set_middleware`](../../jsonrpsee_http_server/struct.HttpServerBuilder.html#method.set_middleware) method
/// for examples.
pub trait Middleware: Send + Sync + Clone + 'static {
pub trait WsMiddleware: Send + Sync + Clone + 'static {
/// Intended to carry timestamp of a request, for example `std::time::Instant`. How the middleware
/// measures time, if at all, is entirely up to the implementation.
type Instant: Send + Copy;
type Instant: std::fmt::Debug + Send + Sync + Copy;
/// Called when a new client connects (WebSocket only)
fn on_connect(&self) {}
/// Called when a new client connects
fn on_connect(&self, remote_addr: SocketAddr, headers: &Headers);
/// Called when a new JSON-RPC comes to the server.
/// Called when a new JSON-RPC request comes to the server.
fn on_request(&self) -> Self::Instant;
/// Called on each JSON-RPC method call, batch requests will trigger `on_call` multiple times.
fn on_call(&self, _name: &str) {}
fn on_call(&self, method_name: &str, params: Params);
/// Called on each JSON-RPC method completion, batch requests will trigger `on_result` multiple times.
fn on_result(&self, _name: &str, _success: bool, _started_at: Self::Instant) {}
fn on_result(&self, method_name: &str, success: bool, started_at: Self::Instant);
/// Called once the JSON-RPC request is finished and response is sent to the output buffer.
fn on_response(&self, _started_at: Self::Instant) {}
fn on_response(&self, result: &str, started_at: Self::Instant);
/// Called when a client disconnects
fn on_disconnect(&self, remote_addr: std::net::SocketAddr);
}
impl HttpMiddleware for () {
type Instant = ();
fn on_request(&self, _: std::net::SocketAddr, _: &Headers) -> Self::Instant {}
fn on_call(&self, _: &str, _: Params) {}
fn on_result(&self, _: &str, _: bool, _: Self::Instant) {}
/// Called when a client disconnects (WebSocket only)
fn on_disconnect(&self) {}
fn on_response(&self, _: &str, _: Self::Instant) {}
}
impl Middleware for () {
impl WsMiddleware for () {
type Instant = ();
fn on_connect(&self, _: std::net::SocketAddr, _: &Headers) {}
fn on_request(&self) -> Self::Instant {}
fn on_call(&self, _: &str, _: Params) {}
fn on_result(&self, _: &str, _: bool, _: Self::Instant) {}
fn on_response(&self, _: &str, _: Self::Instant) {}
fn on_disconnect(&self, _: std::net::SocketAddr) {}
}
impl<A, B> Middleware for (A, B)
impl<A, B> WsMiddleware for (A, B)
where
A: Middleware,
B: Middleware,
A: WsMiddleware,
B: WsMiddleware,
{
type Instant = (A::Instant, B::Instant);
fn on_connect(&self, remote_addr: std::net::SocketAddr, headers: &Headers) {
(self.0.on_connect(remote_addr, headers), self.1.on_connect(remote_addr, headers));
}
fn on_request(&self) -> Self::Instant {
(self.0.on_request(), self.1.on_request())
}
fn on_call(&self, name: &str) {
self.0.on_call(name);
self.1.on_call(name);
fn on_call(&self, method_name: &str, params: Params) {
self.0.on_call(method_name, params.clone());
self.1.on_call(method_name, params);
}
fn on_result(&self, method_name: &str, success: bool, started_at: Self::Instant) {
self.0.on_result(method_name, success, started_at.0);
self.1.on_result(method_name, success, started_at.1);
}
fn on_response(&self, result: &str, started_at: Self::Instant) {
self.0.on_response(result, started_at.0);
self.1.on_response(result, started_at.1);
}
fn on_disconnect(&self, remote_addr: std::net::SocketAddr) {
(self.0.on_disconnect(remote_addr), self.1.on_disconnect(remote_addr));
}
}
impl<A, B> HttpMiddleware for (A, B)
where
A: HttpMiddleware,
B: HttpMiddleware,
{
type Instant = (A::Instant, B::Instant);
fn on_request(&self, remote_addr: std::net::SocketAddr, headers: &Headers) -> Self::Instant {
(self.0.on_request(remote_addr, headers), self.1.on_request(remote_addr, headers))
}
fn on_call(&self, method_name: &str, params: Params) {
self.0.on_call(method_name, params.clone());
self.1.on_call(method_name, params);
}
fn on_result(&self, name: &str, success: bool, started_at: Self::Instant) {
self.0.on_result(name, success, started_at.0);
self.1.on_result(name, success, started_at.1);
fn on_result(&self, method_name: &str, success: bool, started_at: Self::Instant) {
self.0.on_result(method_name, success, started_at.0);
self.1.on_result(method_name, success, started_at.1);
}
fn on_response(&self, started_at: Self::Instant) {
self.0.on_response(started_at.0);
self.1.on_response(started_at.1);
fn on_response(&self, result: &str, started_at: Self::Instant) {
self.0.on_response(result, started_at.0);
self.1.on_response(result, started_at.1);
}
}
......@@ -28,9 +28,8 @@ use std::io;
use std::sync::Arc;
use crate::tracing::tx_log_from_str;
use crate::{Error};
use crate::Error;
use futures_channel::mpsc;
use futures_util::StreamExt;
use jsonrpsee_types::error::{ErrorCode, ErrorObject, ErrorResponse, OVERSIZED_RESPONSE_CODE, OVERSIZED_RESPONSE_MSG};
use jsonrpsee_types::{Id, InvalidRequest, Response};
use serde::Serialize;
......@@ -108,39 +107,6 @@ impl MethodSink {
self.tx.is_closed()
}
/// Send a JSON-RPC response to the client. If the serialization of `result` exceeds `max_response_size`,
/// an error will be sent instead.
pub fn send_response(&self, id: Id, result: impl Serialize) -> bool {
let mut writer = BoundedWriter::new(self.max_response_size as usize);
let json = match serde_json::to_writer(&mut writer, &Response::new(result, id.clone())) {
Ok(_) => {
// Safety - serde_json does not emit invalid UTF-8.
unsafe { String::from_utf8_unchecked(writer.into_bytes()) }
}
Err(err) => {
tracing::error!("Error serializing response: {:?}", err);
if err.is_io() {
let data = format!("Exceeded max limit of {}", self.max_response_size);
let err = ErrorObject::owned(OVERSIZED_RESPONSE_CODE, OVERSIZED_RESPONSE_MSG, Some(data));
return self.send_error(id, err);
} else {
return self.send_error(id, ErrorCode::InternalError.into());
}
}
};
tx_log_from_str(&json, self.max_log_length);
if let Err(err) = self.send_raw(json) {
tracing::warn!("Error sending response {:?}", err);
false
} else {
true
}
}
/// Send a JSON-RPC error to the client
pub fn send_error(&self, id: Id, error: ErrorObject) -> bool {
let json = match serde_json::to_string(&ErrorResponse::borrowed(error, id)) {
......@@ -169,15 +135,20 @@ impl MethodSink {
/// Send a raw JSON-RPC message to the client, `MethodSink` does not check verify the validity
/// of the JSON being sent.
pub fn send_raw(&self, raw_json: String) -> Result<(), mpsc::TrySendError<String>> {
tracing::trace!("send: {:?}", raw_json);
self.tx.unbounded_send(raw_json)
pub fn send_raw(&self, json: String) -> Result<(), mpsc::TrySendError<String>> {
tx_log_from_str(&json, self.max_log_length);
self.tx.unbounded_send(json)
}
/// Close the channel for any further messages.
pub fn close(&self) {
self.tx.close_channel();
}
/// Get the maximum number of permitted subscriptions.
pub const fn max_response_size(&self) -> u32 {
self.max_response_size
}
}
/// Figure out if this is a sufficiently complete request that we can extract an [`Id`] out of, or just plain
......@@ -189,24 +160,6 @@ pub fn prepare_error(data: &[u8]) -> (Id<'_>, ErrorCode) {
}
}
/// Read all the results of all method calls in a batch request from the ['Stream']. Format the result into a single
/// `String` appropriately wrapped in `[`/`]`.
pub async fn collect_batch_response(rx: mpsc::UnboundedReceiver<String>) -> String {
let mut buf = String::with_capacity(2048);
buf.push('[');
let mut buf = rx
.fold(buf, |mut acc, response| async move {
acc.push_str(&response);
acc.push(',');
acc
})
.await;
// Remove trailing comma
buf.pop();
buf.push(']');
buf
}
/// A permitted subscription.
#[derive(Debug)]
pub struct SubscriptionPermit {
......@@ -260,11 +213,122 @@ impl BoundedSubscriptions {
}
}
/// Represent the response to method call.
#[derive(Debug)]
pub struct MethodResponse {
/// Serialized JSON-RPC response,
pub result: String,
/// Indicates whether the call was successful or not.
pub success: bool,
}
impl MethodResponse {
/// Send a JSON-RPC response to the client. If the serialization of `result` exceeds `max_response_size`,
/// an error will be sent instead.
pub fn response(id: Id, result: impl Serialize, max_response_size: usize) -> Self {
let mut writer = BoundedWriter::new(max_response_size);
match serde_json::to_writer(&mut writer, &Response::new(result, id.clone())) {
Ok(_) => {
// Safety - serde_json does not emit invalid UTF-8.
let result = unsafe { String::from_utf8_unchecked(writer.into_bytes()) };
Self { result, success: true }
}
Err(err) => {
tracing::error!("Error serializing response: {:?}", err);
if err.is_io() {
let data = format!("Exceeded max limit of {}", max_response_size);
let err = ErrorObject::owned(OVERSIZED_RESPONSE_CODE, OVERSIZED_RESPONSE_MSG, Some(data));
let result = serde_json::to_string(&ErrorResponse::borrowed(err, id)).unwrap();
Self { result, success: false }
} else {
let result =
serde_json::to_string(&ErrorResponse::borrowed(ErrorCode::InternalError.into(), id)).unwrap();
Self { result, success: false }
}
}
}
}
/// Create a `MethodResponse` from an error.
pub fn error<'a>(id: Id, err: impl Into<ErrorObject<'a>>) -> Self {
let result = serde_json::to_string(&ErrorResponse::borrowed(err.into(), id)).expect("valid JSON; qed");
Self { result, success: false }
}
}
/// Builder to build a `BatchResponse`.
#[derive(Debug, Default)]
pub struct BatchResponseBuilder {
/// Serialized JSON-RPC response,
result: String,
/// Max limit for the batch
max_response_size: usize,
}
impl BatchResponseBuilder {
/// Create a new batch response builder with limit.
pub fn new_with_limit(limit: usize) -> Self {
let mut initial = String::with_capacity(2048);
initial.push('[');
Self { result: initial, max_response_size: limit }
}
/// Append a result from an individual method to the batch response.
///
/// Fails if the max limit is exceeded and returns to error response to
/// return early in order to not process method call responses which are thrown away anyway.
pub fn append(mut self, response: &MethodResponse) -> Result<Self, BatchResponse> {
// `,` will occupy one extra byte for each entry
// on the last item the `,` is replaced by `]`.
let len = response.result.len() + self.result.len() + 1;
if len > self.max_response_size {
Err(BatchResponse::error(Id::Null, ErrorObject::from(ErrorCode::InvalidRequest)))
} else {
self.result.push_str(&response.result);
self.result.push(',');
Ok(self)
}
}
/// Finish the batch response
pub fn finish(mut self) -> BatchResponse {
if self.result.len() == 1 {
BatchResponse::error(Id::Null, ErrorObject::from(ErrorCode::InvalidRequest))
} else {
self.result.pop();
self.result.push(']');
BatchResponse { result: self.result, success: true }
}
}
}
/// Response to a batch request.
#[derive(Debug)]
pub struct BatchResponse {
/// Formatted JSON-RPC response.
pub result: String,
/// Indicates whether the call was successful or not.
pub success: bool,
}
impl BatchResponse {
/// Create a `BatchResponse` from an error.
pub fn error(id: Id, err: impl Into<ErrorObject<'static>>) -> Self {
let result = serde_json::to_string(&ErrorResponse::borrowed(err.into(), id)).unwrap();
Self { result, success: false }
}
}
#[cfg(test)]
mod tests {
use crate::server::helpers::BoundedSubscriptions;
use super::{BoundedWriter, Id, Response};
use super::{BatchResponseBuilder, BoundedWriter, Id, MethodResponse, Response};
#[test]
fn bounded_serializer_work() {
......@@ -295,4 +359,54 @@ mod tests {
handles.swap_remove(0);
assert!(subs.acquire().is_some());
}
#[test]
fn batch_with_single_works() {
let method = MethodResponse::response(Id::Number(1), "a", usize::MAX);
assert_eq!(method.result.len(), 37);
// Recall a batch appends two bytes for the `[]`.
let batch = BatchResponseBuilder::new_with_limit(39).append(&method).unwrap().finish();
assert!(batch.success);
assert_eq!(batch.result, r#"[{"jsonrpc":"2.0","result":"a","id":1}]"#.to_string())
}
#[test]
fn batch_with_multiple_works() {
let m1 = MethodResponse::response(Id::Number(1), "a", usize::MAX);
assert_eq!(m1.result.len(), 37);
// Recall a batch appends two bytes for the `[]` and one byte for `,` to append a method call.
// so it should be 2 + (37 * n) + (n-1)
let limit = 2 + (37 * 2) + 1;
let batch = BatchResponseBuilder::new_with_limit(limit).append(&m1).unwrap().append(&m1).unwrap().finish();
assert!(batch.success);
assert_eq!(
batch.result,
r#"[{"jsonrpc":"2.0","result":"a","id":1},{"jsonrpc":"2.0","result":"a","id":1}]"#.to_string()
)
}
#[test]
fn batch_empty_err() {
let batch = BatchResponseBuilder::new_with_limit(1024).finish();
assert!(!batch.success);
let exp_err = r#"{"jsonrpc":"2.0","error":{"code":-32600,"message":"Invalid request"},"id":null}"#;
assert_eq!(batch.result, exp_err);
}
#[test]
fn batch_too_big() {
let method = MethodResponse::response(Id::Number(1), "a".repeat(28), 128);
assert_eq!(method.result.len(), 64);
let batch = BatchResponseBuilder::new_with_limit(63).append(&method).unwrap_err();
assert!(!batch.success);
let exp_err = r#"{"jsonrpc":"2.0","error":{"code":-32600,"message":"Invalid request"},"id":null}"#;
assert_eq!(batch.result, exp_err);
}
}
......@@ -35,48 +35,57 @@ use crate::id_providers::RandomIntegerIdProvider;
use crate::server::helpers::{BoundedSubscriptions, MethodSink, SubscriptionPermit};
use crate::server::resource_limiting::{ResourceGuard, ResourceTable, ResourceVec, Resources};
use crate::traits::{IdProvider, ToRpcParams};
use futures_channel::mpsc;
use futures_channel::{mpsc, oneshot};
use futures_util::future::Either;
use futures_util::pin_mut;
use futures_util::{future::BoxFuture, FutureExt, Stream, StreamExt, TryStream, TryStreamExt};
use jsonrpsee_types::error::{
CallError, ErrorCode, ErrorObject, ErrorObjectOwned, INTERNAL_ERROR_CODE,
SUBSCRIPTION_CLOSED_WITH_ERROR, SubscriptionAcceptRejectError
CallError, ErrorCode, ErrorObject, ErrorObjectOwned, SubscriptionAcceptRejectError, INTERNAL_ERROR_CODE,
SUBSCRIPTION_CLOSED_WITH_ERROR,
};
use jsonrpsee_types::response::{SubscriptionError, SubscriptionPayloadError};
use jsonrpsee_types::{
ErrorResponse, Id, Params, Request, Response, SubscriptionResult,
SubscriptionId as RpcSubscriptionId, SubscriptionPayload, SubscriptionResponse
ErrorResponse, Id, Params, Request, Response, SubscriptionId as RpcSubscriptionId, SubscriptionPayload,
SubscriptionResponse, SubscriptionResult,
};
use parking_lot::Mutex;
use rustc_hash::FxHashMap;
use serde::{de::DeserializeOwned, Serialize};
use tokio::sync::watch;
use super::helpers::MethodResponse;
/// A `MethodCallback` is an RPC endpoint, callable with a standard JSON-RPC request,
/// implemented as a function pointer to a `Fn` function taking four arguments:
/// the `id`, `params`, a channel the function uses to communicate the result (or error)
/// back to `jsonrpsee`, and the connection ID (useful for the websocket transport).
pub type SyncMethod = Arc<dyn Send + Sync + Fn(Id, Params, &MethodSink) -> bool>;
pub type SyncMethod = Arc<dyn Send + Sync + Fn(Id, Params, MaxResponseSize) -> MethodResponse>;
/// Similar to [`SyncMethod`], but represents an asynchronous handler and takes an additional argument containing a [`ResourceGuard`] if configured.
pub type AsyncMethod<'a> = Arc<
dyn Send + Sync + Fn(Id<'a>, Params<'a>, MethodSink, ConnectionId, Option<ResourceGuard>) -> BoxFuture<'a, bool>,
dyn Send
+ Sync
+ Fn(Id<'a>, Params<'a>, ConnectionId, MaxResponseSize, Option<ResourceGuard>) -> BoxFuture<'a, MethodResponse>,