Unverified Commit 150824b0 authored by Niklas Adolfsson's avatar Niklas Adolfsson
Browse files

add limit to batch responses as well

parent 3e885043
Pipeline #200745 passed with stages
in 5 minutes and 44 seconds
......@@ -142,7 +142,7 @@ pub async fn ws_server(handle: tokio::runtime::Handle) -> (String, jsonrpsee::ws
module
.register_subscription(SUB_METHOD_NAME, SUB_METHOD_NAME, UNSUB_METHOD_NAME, |_params, pending, _ctx| {
let mut sink = match pending.accept() {
let sink = match pending.accept() {
Some(sink) => sink,
_ => return,
};
......
......@@ -263,35 +263,43 @@ impl MethodResponse {
pub struct BatchResponseBuilder {
/// Serialized JSON-RPC response,
result: String,
/// Indicates whether the call was successful or not.
success: bool,
/// Max limit for the batch
max_response_size: usize,
}
impl BatchResponseBuilder {
/// Create a new batch response builder.
pub fn new() -> Self {
Self { result: String::from("["), success: true }
/// Create a new batch response builder with limit.
pub fn new_with_limit(limit: usize) -> Self {
Self { result: String::from("["), max_response_size: limit }
}
/// Append a result from an individual method to the batch response.
pub fn append(&mut self, response: &MethodResponse) {
if !response.success {
self.success = false;
///
/// Fails if the max limit is exceeded and returns to error response to
/// return early in order to not process method call responses which are thrown away anyway.
pub fn append(mut self, response: &MethodResponse) -> Result<Self, BatchResponse> {
// `,` will occupy one extra byte for each entry
// on the last item the `,` is replaced by `]`.
let len = response.result.len() + self.result.len() + 1;
if len > self.max_response_size {
Err(BatchResponse::error(Id::Null, ErrorObject::from(ErrorCode::InvalidRequest)))
} else {
self.result.push_str(&response.result);
self.result.push(',');
Ok(self)
}
self.result.push_str(&response.result);
self.result.push(',');
}
/// Finish the batch response
pub fn finish(mut self) -> BatchResponse {
if self.result.len() == 1 {
panic!("Batch response needs at least one item");
BatchResponse::error(Id::Null, ErrorObject::from(ErrorCode::InvalidRequest))
} else {
self.result.pop();
self.result.push(']');
BatchResponse { result: self.result, success: true }
}
self.result.pop();
self.result.push(']');
BatchResponse { result: self.result, success: self.success }
}
}
......@@ -316,7 +324,7 @@ impl BatchResponse {
mod tests {
use crate::server::helpers::BoundedSubscriptions;
use super::{BoundedWriter, Id, Response};
use super::{BatchResponseBuilder, BoundedWriter, Id, MethodResponse, Response};
#[test]
fn bounded_serializer_work() {
......@@ -347,4 +355,54 @@ mod tests {
handles.swap_remove(0);
assert!(subs.acquire().is_some());
}
#[test]
fn batch_with_single_works() {
let method = MethodResponse::response(Id::Number(1), "a", usize::MAX);
assert_eq!(method.result.len(), 37);
// Recall a batch appends two bytes for the `[]`.
let batch = BatchResponseBuilder::new_with_limit(39).append(&method).unwrap().finish();
assert!(batch.success);
assert_eq!(batch.result, r#"[{"jsonrpc":"2.0","result":"a","id":1}]"#.to_string())
}
#[test]
fn batch_with_multiple_works() {
let m1 = MethodResponse::response(Id::Number(1), "a", usize::MAX);
assert_eq!(m1.result.len(), 37);
// Recall a batch appends two bytes for the `[]` and one byte for `,` to append a method call.
// so it should be 2 + (37 * n) + (n-1)
let limit = 2 + (37 * 2) + 1;
let batch = BatchResponseBuilder::new_with_limit(limit).append(&m1).unwrap().append(&m1).unwrap().finish();
assert!(batch.success);
assert_eq!(
batch.result,
r#"[{"jsonrpc":"2.0","result":"a","id":1},{"jsonrpc":"2.0","result":"a","id":1}]"#.to_string()
)
}
#[test]
fn batch_empty_err() {
let batch = BatchResponseBuilder::new_with_limit(1024).finish();
assert!(!batch.success);
let exp_err = r#"{"jsonrpc":"2.0","error":{"code":-32600,"message":"Invalid request"},"id":null}"#;
assert_eq!(batch.result, exp_err);
}
#[test]
fn batch_too_big() {
let method = MethodResponse::response(Id::Number(1), "a".repeat(28), 128);
assert_eq!(method.result.len(), 64);
let batch = BatchResponseBuilder::new_with_limit(63).append(&method).unwrap_err();
assert!(!batch.success);
let exp_err = r#"{"jsonrpc":"2.0","error":{"code":-32600,"message":"Invalid request"},"id":null}"#;
assert_eq!(batch.result, exp_err);
}
}
......@@ -32,7 +32,8 @@ use std::task::{Context, Poll};
use crate::response;
use crate::response::{internal_error, malformed};
use futures_channel::mpsc;
use futures_util::{stream::StreamExt, FutureExt};
use futures_util::future::FutureExt;
use futures_util::stream::{StreamExt, TryStreamExt};
use hyper::header::{HeaderMap, HeaderValue};
use hyper::server::conn::AddrStream;
use hyper::server::{conn::AddrIncoming, Builder as HyperBuilder};
......@@ -759,30 +760,32 @@ where
let Batch { data, call } = b;
if let Ok(batch) = serde_json::from_slice::<Vec<Request>>(&data) {
return if !batch.is_empty() {
let batch = batch.into_iter().map(|req| (req, call.clone()));
let max_response_size = call.max_response_body_size;
let batch = batch.into_iter().map(|req| Ok((req, call.clone())));
let batch_stream = futures_util::stream::iter(batch);
let batch_stream = futures_util::stream::iter(batch);
let trace = RpcTracing::batch();
let _enter = trace.span().enter();
let trace = RpcTracing::batch();
let _enter = trace.span().enter();
let batch_response = batch_stream
.fold(BatchResponseBuilder::new(), |mut batch_response, (req, call)| async move {
let batch_response = batch_stream
.try_fold(
BatchResponseBuilder::new_with_limit(max_response_size as usize),
|batch_response, (req, call)| async move {
let params = Params::new(req.params.map(|params| params.get()));
let response = execute_call(Call { name: &req.method, params, id: req.id, call }).await;
batch_response.append(&response);
batch_response
})
.in_current_span()
.await;
let batch = batch_response.append(&response);
batch
},
)
.in_current_span()
.await;
batch_response.finish()
} else {
BatchResponse::error(Id::Null, ErrorObject::from(ErrorCode::InvalidRequest))
return match batch_response {
Ok(batch) => batch.finish(),
Err(batch_err) => batch_err,
};
}
......
......@@ -476,6 +476,25 @@ async fn can_set_the_max_response_size() {
handle.stop().unwrap();
}
#[tokio::test]
async fn can_set_the_max_response_size_to_batch() {
let addr = "127.0.0.1:0";
// Set the max response size to 100 bytes
let server = HttpServerBuilder::default().max_response_body_size(100).build(addr).await.unwrap();
let mut module = RpcModule::new(());
module.register_method("anything", |_p, _cx| Ok("a".repeat(51))).unwrap();
let addr = server.local_addr().unwrap();
let uri = to_http_uri(addr);
let handle = server.start(module).unwrap();
// Two response will end up in a response of 102 bytes which is too big.
let req = r#"[{"jsonrpc":"2.0", "method":"anything", "id":1},{"jsonrpc":"2.0", "method":"anything", "id":2}]"#;
let response = http_request(req.into(), uri.clone()).with_default_timeout().await.unwrap().unwrap();
assert_eq!(response.body, invalid_request(Id::Null));
handle.stop().unwrap();
}
#[tokio::test]
async fn disabled_batches() {
let addr = "127.0.0.1:0";
......
......@@ -38,6 +38,7 @@ use futures_channel::mpsc;
use futures_util::future::{Either, FutureExt};
use futures_util::io::{BufReader, BufWriter};
use futures_util::stream::StreamExt;
use futures_util::TryStreamExt;
use http::header::{HOST, ORIGIN};
use http::{HeaderMap, HeaderValue};
use jsonrpsee_core::id_providers::RandomIntegerIdProvider;
......@@ -884,26 +885,31 @@ where
if let Ok(batch) = serde_json::from_slice::<Vec<Request>>(&data) {
return if !batch.is_empty() {
let batch = batch.into_iter().map(|req| (req, call.clone()));
let batch = batch.into_iter().map(|req| Ok((req, call.clone())));
let batch_stream = futures_util::stream::iter(batch);
let trace = RpcTracing::batch();
let _enter = trace.span().enter();
let max_response_size = call.max_response_body_size;
let batch_response = batch_stream
.fold(BatchResponseBuilder::new(), |mut batch_response, (req, call)| async move {
let params = Params::new(req.params.map(|params| params.get()));
.try_fold(
BatchResponseBuilder::new_with_limit(max_response_size as usize),
|batch_response, (req, call)| async move {
let params = Params::new(req.params.map(|params| params.get()));
let response =
execute_call(Call { name: &req.method, params, id: req.id, call }).in_current_span().await;
batch_response.append(response.as_inner());
let response =
execute_call(Call { name: &req.method, params, id: req.id, call }).in_current_span().await;
batch_response
})
batch_response.append(response.as_inner())
},
)
.await;
return batch_response.finish();
return match batch_response {
Ok(batch) => batch.finish(),
Err(batch_err) => batch_err,
};
} else {
BatchResponse::error(Id::Null, ErrorObject::from(ErrorCode::InvalidRequest))
};
......
......@@ -233,6 +233,28 @@ async fn can_set_the_max_response_body_size() {
handle.stop().unwrap();
}
#[tokio::test]
async fn can_set_the_max_response_size_to_batch() {
init_logger();
let addr = "127.0.0.1:0";
// Set the max response body size to 100 bytes
let server = WsServerBuilder::default().max_response_body_size(100).build(addr).await.unwrap();
let mut module = RpcModule::new(());
module.register_method("anything", |_p, _cx| Ok("a".repeat(51))).unwrap();
let addr = server.local_addr().unwrap();
let handle = server.start(module).unwrap();
let mut client = WebSocketTestClient::new(addr).await.unwrap();
// Two response will end up in a response bigger than 100 bytes.
let req = r#"[{"jsonrpc":"2.0", "method":"anything", "id":1},{"jsonrpc":"2.0", "method":"anything", "id":2}]"#;
let response = client.send_request_text(req).await.unwrap();
assert_eq!(response, invalid_request(Id::Null));
handle.stop().unwrap();
}
#[tokio::test]
async fn can_set_max_connections() {
let addr = "127.0.0.1:0";
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment