Unverified Commit 9bd21274 authored by Niklas Adolfsson's avatar Niklas Adolfsson Committed by GitHub
Browse files

feat(rpc module): `stream API` for SubscriptionSink (#639)



* feat(rpc module): add_stream to subscription sink

* fix some nits

* unify parameters to rpc methods

* Update core/src/server/rpc_module.rs

* Update tests/tests/integration_tests.rs

Co-authored-by: David's avatarDavid <dvdplm@gmail.com>

* address grumbles

* fix subscription tests

* new type for `SubscriptionCallback` and glue code

* remove unsed code

* remove todo

* add missing feature tokio/macros

* make `add_stream` cancel-safe

* rename add_stream and return status

* fix nits

* rename stream API -> streamify

* Update core/src/server/rpc_module.rs

* provide proper close reason

* spelling

* consume_and_streamify + docs

* fmt

* rename API pipe_from_stream

* improve logging; indicate which subscription method that failed

Co-authored-by: David's avatarDavid <dvdplm@gmail.com>
parent 429c196d
......@@ -11,6 +11,7 @@ anyhow = "1"
arrayvec = "0.7.1"
async-trait = "0.1"
beef = { version = "0.5.1", features = ["impl_serde"] }
async-channel = { version = "1.6", optional = true }
thiserror = "1"
futures-channel = { version = "0.3.14", default-features = false }
futures-util = { version = "0.3.14", default-features = false, optional = true }
......@@ -29,6 +30,7 @@ tokio = { version = "1.8", features = ["rt"], optional = true }
default = []
http-helpers = ["futures-util"]
server = [
"async-channel",
"futures-util",
"rustc-hash",
"tracing",
......
......@@ -28,7 +28,7 @@ use std::io;
use crate::{to_json_raw_value, Error};
use futures_channel::mpsc;
use futures_util::stream::StreamExt;
use futures_util::StreamExt;
use jsonrpsee_types::error::{
CallError, ErrorCode, ErrorObject, ErrorResponse, CALL_EXECUTION_FAILED_CODE, OVERSIZED_RESPONSE_CODE,
OVERSIZED_RESPONSE_MSG, UNKNOWN_ERROR_CODE,
......
......@@ -36,9 +36,9 @@ use crate::server::helpers::MethodSink;
use crate::server::resource_limiting::{ResourceGuard, ResourceTable, ResourceVec, Resources};
use crate::to_json_raw_value;
use crate::traits::{IdProvider, ToRpcParams};
use beef::Cow;
use futures_channel::{mpsc, oneshot};
use futures_util::{future::BoxFuture, FutureExt, StreamExt};
use futures_util::future::Either;
use futures_util::{future::BoxFuture, FutureExt, Stream, StreamExt};
use jsonrpsee_types::error::{invalid_subscription_err, ErrorCode, CALL_EXECUTION_FAILED_CODE};
use jsonrpsee_types::{
Id, Params, Request, Response, SubscriptionId as RpcSubscriptionId, SubscriptionPayload, SubscriptionResponse,
......@@ -51,16 +51,35 @@ use serde::{de::DeserializeOwned, Serialize};
/// implemented as a function pointer to a `Fn` function taking four arguments:
/// the `id`, `params`, a channel the function uses to communicate the result (or error)
/// back to `jsonrpsee`, and the connection ID (useful for the websocket transport).
pub type SyncMethod = Arc<dyn Send + Sync + Fn(Id, Params, &MethodSink, ConnectionId, &dyn IdProvider) -> bool>;
pub type SyncMethod = Arc<dyn Send + Sync + Fn(Id, Params, &MethodSink) -> bool>;
/// Similar to [`SyncMethod`], but represents an asynchronous handler and takes an additional argument containing a [`ResourceGuard`] if configured.
pub type AsyncMethod<'a> = Arc<
dyn Send + Sync + Fn(Id<'a>, Params<'a>, MethodSink, Option<ResourceGuard>, &dyn IdProvider) -> BoxFuture<'a, bool>,
dyn Send + Sync + Fn(Id<'a>, Params<'a>, MethodSink, ConnectionId, Option<ResourceGuard>) -> BoxFuture<'a, bool>,
>;
/// Method callback for subscriptions.
pub type SubscriptionMethod = Arc<dyn Send + Sync + Fn(Id, Params, &MethodSink, ConnState) -> bool>;
/// Connection ID, used for stateful protocol such as WebSockets.
/// For stateless protocols such as http it's unused, so feel free to set it some hardcoded value.
pub type ConnectionId = usize;
/// Raw RPC response.
pub type RawRpcResponse = (String, mpsc::UnboundedReceiver<String>, mpsc::UnboundedSender<String>);
pub type RawRpcResponse = (String, mpsc::UnboundedReceiver<String>, async_channel::Sender<()>);
/// Data for stateful connections.
pub struct ConnState<'a> {
/// Connection ID
pub conn_id: ConnectionId,
/// Channel to know whether the connection is closed or not.
pub close: async_channel::Receiver<()>,
/// ID provider.
pub id_provider: &'a dyn IdProvider,
}
impl<'a> std::fmt::Debug for ConnState<'a> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("ConnState").field("conn_id", &self.conn_id).field("close", &self.close).finish()
}
}
type Subscribers = Arc<Mutex<FxHashMap<SubscriptionKey, (MethodSink, oneshot::Receiver<()>)>>>;
......@@ -73,11 +92,13 @@ struct SubscriptionKey {
/// Callback wrapper that can be either sync or async.
#[derive(Clone)]
enum MethodKind {
pub enum MethodKind {
/// Synchronous method handler.
Sync(SyncMethod),
/// Asynchronous method handler.
Async(AsyncMethod<'static>),
/// Subscription method handler
Subscription(SubscriptionMethod),
}
/// Information about resources the method uses during its execution. Initialized when the the server starts.
......@@ -144,6 +165,13 @@ impl MethodCallback {
MethodCallback { callback: MethodKind::Async(callback), resources: MethodResources::Uninitialized([].into()) }
}
fn new_subscription(callback: SubscriptionMethod) -> Self {
MethodCallback {
callback: MethodKind::Subscription(callback),
resources: MethodResources::Uninitialized([].into()),
}
}
/// Attempt to claim resources prior to executing a method. On success returns a guard that releases
/// claimed resources when dropped.
pub fn claim(&self, name: &str, resources: &Resources) -> Result<ResourceGuard, Error> {
......@@ -153,50 +181,9 @@ impl MethodCallback {
}
}
/// Execute the callback, sending the resulting JSON (success or error) to the specified sink.
pub fn execute(
&self,
sink: &MethodSink,
req: Request<'_>,
conn_id: ConnectionId,
claimed: Option<ResourceGuard>,
id_gen: &dyn IdProvider,
) -> MethodResult<bool> {
let id = req.id.clone();
let params = Params::new(req.params.map(|params| params.get()));
let result = match &self.callback {
MethodKind::Sync(callback) => {
tracing::trace!(
"[MethodCallback::execute] Executing sync callback, params={:?}, req.id={:?}, conn_id={:?}",
params,
id,
conn_id
);
let result = (callback)(id, params, sink, conn_id, id_gen);
// Release claimed resources
drop(claimed);
MethodResult::Sync(result)
}
MethodKind::Async(callback) => {
let sink = sink.clone();
let params = params.into_owned();
let id = id.into_owned();
tracing::trace!(
"[MethodCallback::execute] Executing async callback, params={:?}, req.id={:?}, conn_id={:?}",
params,
id,
conn_id
);
MethodResult::Async((callback)(id, params, sink, claimed, id_gen))
}
};
result
/// Get handle to the callback.
pub fn inner(&self) -> &MethodKind {
&self.callback
}
}
......@@ -205,6 +192,7 @@ impl Debug for MethodKind {
match self {
Self::Async(_) => write!(f, "Async"),
Self::Sync(_) => write!(f, "Sync"),
Self::Subscription(_) => write!(f, "Subscription"),
}
}
}
......@@ -306,51 +294,6 @@ impl Methods {
self.callbacks.get_key_value(method_name).map(|(k, v)| (*k, v))
}
/// Attempt to execute a callback, sending the resulting JSON (success or error) to the specified sink.
pub fn execute(
&self,
sink: &MethodSink,
req: Request,
conn_id: ConnectionId,
id_gen: &dyn IdProvider,
) -> MethodResult<bool> {
tracing::trace!("[Methods::execute] Executing request: {:?}", req);
match self.callbacks.get(&*req.method) {
Some(callback) => callback.execute(sink, req, conn_id, None, id_gen),
None => {
sink.send_error(req.id, ErrorCode::MethodNotFound.into());
MethodResult::Sync(false)
}
}
}
/// Attempt to execute a callback while checking that the call does not exhaust the available resources,
// sending the resulting JSON (success or error) to the specified sink.
pub fn execute_with_resources<'r>(
&self,
sink: &MethodSink,
req: Request<'r>,
conn_id: ConnectionId,
resources: &Resources,
id_gen: &dyn IdProvider,
) -> Result<(&'static str, MethodResult<bool>), Cow<'r, str>> {
tracing::trace!("[Methods::execute_with_resources] Executing request: {:?}", req);
match self.callbacks.get_key_value(&*req.method) {
Some((&name, callback)) => match callback.claim(&req.method, resources) {
Ok(guard) => Ok((name, callback.execute(sink, req, conn_id, Some(guard), id_gen))),
Err(err) => {
tracing::error!("[Methods::execute_with_resources] failed to lock resources: {:?}", err);
sink.send_error(req.id, ErrorCode::ServerIsBusy.into());
Ok((name, MethodResult::Sync(false)))
}
},
None => {
sink.send_error(req.id, ErrorCode::MethodNotFound.into());
Err(req.method)
}
}
}
/// Helper to call a method on the `RPC module` without having to spin up a server.
///
/// The params must be serializable as JSON array, see [`ToRpcParams`] for further documentation.
......@@ -422,17 +365,27 @@ impl Methods {
Ok((resp, rx))
}
/// Wrapper over [`Methods::execute`] to execute a callback.
/// Execute a callback.
async fn inner_call(&self, req: Request<'_>) -> RawRpcResponse {
let (tx, mut rx) = mpsc::unbounded();
let sink = MethodSink::new(tx.clone());
let sink = MethodSink::new(tx);
let (close_tx, close_rx) = async_channel::unbounded();
if let MethodResult::Async(fut) = self.execute(&sink, req, 0, &RandomIntegerIdProvider) {
fut.await;
}
let id = req.id.clone();
let params = Params::new(req.params.map(|params| params.get()));
let _result = match self.method(&req.method).map(|c| &c.callback) {
None => sink.send_error(req.id, ErrorCode::MethodNotFound.into()),
Some(MethodKind::Sync(cb)) => (cb)(id, params, &sink),
Some(MethodKind::Async(cb)) => (cb)(id.into_owned(), params.into_owned(), sink, 0, None).await,
Some(MethodKind::Subscription(cb)) => {
let conn_state = ConnState { conn_id: 0, close: close_rx, id_provider: &RandomIntegerIdProvider };
(cb)(id, params, &sink, conn_state)
}
};
let resp = rx.next().await.expect("tx and rx still alive; qed");
(resp, rx, tx)
(resp, rx, close_tx)
}
/// Helper to create a subscription on the `RPC module` without having to spin up a server.
......@@ -527,7 +480,7 @@ impl<Context: Send + Sync + 'static> RpcModule<Context> {
let ctx = self.ctx.clone();
let callback = self.methods.verify_and_insert(
method_name,
MethodCallback::new_sync(Arc::new(move |id, params, sink, _, _| match callback(params, &*ctx) {
MethodCallback::new_sync(Arc::new(move |id, params, sink| match callback(params, &*ctx) {
Ok(res) => sink.send_response(id, res),
Err(err) => sink.send_call_error(id, err),
})),
......@@ -550,7 +503,7 @@ impl<Context: Send + Sync + 'static> RpcModule<Context> {
let ctx = self.ctx.clone();
let callback = self.methods.verify_and_insert(
method_name,
MethodCallback::new_async(Arc::new(move |id, params, sink, claimed, _| {
MethodCallback::new_async(Arc::new(move |id, params, sink, _, claimed| {
let ctx = ctx.clone();
let future = async move {
let result = match callback(params, ctx).await {
......@@ -585,7 +538,7 @@ impl<Context: Send + Sync + 'static> RpcModule<Context> {
let ctx = self.ctx.clone();
let callback = self.methods.verify_and_insert(
method_name,
MethodCallback::new_async(Arc::new(move |id, params, sink, claimed, _| {
MethodCallback::new_async(Arc::new(move |id, params, sink, _, claimed| {
let ctx = ctx.clone();
tokio::task::spawn_blocking(move || {
......@@ -671,12 +624,12 @@ impl<Context: Send + Sync + 'static> RpcModule<Context> {
let subscribers = subscribers.clone();
self.methods.mut_callbacks().insert(
subscribe_method_name,
MethodCallback::new_sync(Arc::new(move |id, params, method_sink, conn_id, id_provider| {
MethodCallback::new_subscription(Arc::new(move |id, params, method_sink, conn| {
let (conn_tx, conn_rx) = oneshot::channel::<()>();
let sub_id = {
let sub_id: RpcSubscriptionId = id_provider.next_id().into_owned();
let uniq_sub = SubscriptionKey { conn_id, sub_id: sub_id.clone() };
let sub_id: RpcSubscriptionId = conn.id_provider.next_id().into_owned();
let uniq_sub = SubscriptionKey { conn_id: conn.conn_id, sub_id: sub_id.clone() };
subscribers.lock().insert(uniq_sub, (method_sink.clone(), conn_rx));
......@@ -687,9 +640,10 @@ impl<Context: Send + Sync + 'static> RpcModule<Context> {
let sink = SubscriptionSink {
inner: method_sink.clone(),
close: conn.close,
method: notif_method_name,
subscribers: subscribers.clone(),
uniq_sub: SubscriptionKey { conn_id, sub_id },
uniq_sub: SubscriptionKey { conn_id: conn.conn_id, sub_id },
is_connected: Some(conn_tx),
};
if let Err(err) = callback(params, sink, ctx.clone()) {
......@@ -710,7 +664,7 @@ impl<Context: Send + Sync + 'static> RpcModule<Context> {
{
self.methods.mut_callbacks().insert(
unsubscribe_method_name,
MethodCallback::new_sync(Arc::new(move |id, params, sink, conn_id, _| {
MethodCallback::new_subscription(Arc::new(move |id, params, sink, conn| {
let sub_id = match params.one::<RpcSubscriptionId>() {
Ok(sub_id) => sub_id,
Err(_) => {
......@@ -727,7 +681,11 @@ impl<Context: Send + Sync + 'static> RpcModule<Context> {
};
let sub_id = sub_id.into_owned();
if subscribers.lock().remove(&SubscriptionKey { conn_id, sub_id: sub_id.clone() }).is_some() {
if subscribers
.lock()
.remove(&SubscriptionKey { conn_id: conn.conn_id, sub_id: sub_id.clone() })
.is_some()
{
sink.send_response(id, "Unsubscribed")
} else {
let err = to_json_raw_value(&format!(
......@@ -764,6 +722,8 @@ impl<Context: Send + Sync + 'static> RpcModule<Context> {
pub struct SubscriptionSink {
/// Sink.
inner: MethodSink,
/// Close
close: async_channel::Receiver<()>,
/// MethodCallback.
method: &'static str,
/// Unique subscription.
......@@ -786,9 +746,71 @@ impl SubscriptionSink {
self.inner_send(msg).map_err(Into::into)
}
/// Consumes the `SubscriptionSink` and reads data from the `stream` and sends back data on the subscription
/// when items gets produced by the stream.
///
/// Returns `Ok(())` if the stream or connection was terminated.
/// Returns `Err(_)` if one of the items couldn't be serialized.
///
/// # Examples
///
/// ```no_run
///
/// use jsonrpsee_core::server::rpc_module::RpcModule;
///
/// let mut m = RpcModule::new(());
/// m.register_subscription("sub", "_", "unsub", |params, mut sink, _| {
/// let stream = futures_util::stream::iter(vec![1_u32, 2, 3]);
/// tokio::spawn(sink.pipe_from_stream(stream));
/// Ok(())
/// });
/// ```
pub async fn pipe_from_stream<S, T>(mut self, mut stream: S) -> Result<(), Error>
where
S: Stream<Item = T> + Unpin,
T: Serialize,
{
let mut close_stream = self.close.clone();
let mut item = stream.next();
let mut close = close_stream.next();
loop {
match futures_util::future::select(item, close).await {
Either::Left((Some(result), c)) => {
match self.send(&result) {
Ok(_) => (),
Err(Error::SubscriptionClosed(close_reason)) => {
self.close(&close_reason);
break Ok(());
}
Err(err) => {
tracing::error!("subscription `{}` failed to send item got error: {:?}", self.method, err);
break Err(err);
}
};
close = c;
item = stream.next();
}
// No messages should be sent over this channel
// if that occurred just ignore and continue.
Either::Right((Some(_), i)) => {
item = i;
close = close_stream.next();
}
// Connection terminated.
Either::Right((None, _)) => {
self.close(&SubscriptionClosed::new(SubscriptionClosedReason::ConnectionReset));
break Ok(());
}
// Stream terminated.
Either::Left((None, _)) => break Ok(()),
}
}
}
/// Returns whether this channel is closed without needing a context.
pub fn is_closed(&self) -> bool {
self.inner.is_closed()
self.inner.is_closed() || self.close.is_closed()
}
fn build_message<T: Serialize>(&self, result: &T) -> Result<String, Error> {
......@@ -806,7 +828,7 @@ impl SubscriptionSink {
self.inner.send_raw(msg).map_err(|_| Some(SubscriptionClosedReason::ConnectionReset))
}
Some(_) => Err(Some(SubscriptionClosedReason::Unsubscribed)),
// NOTE(niklasad1): this should be unreachble, after the first error is detected the subscription is closed.
// NOTE(niklasad1): this should be unreachable, after the first error is detected the subscription is closed.
None => Err(None),
};
......@@ -823,11 +845,16 @@ impl SubscriptionSink {
}
/// Close the subscription sink with a customized error message.
pub fn close(&mut self, msg: &str) {
pub fn close_with_custom_message(&mut self, msg: &str) {
let close_reason = SubscriptionClosedReason::Server(msg.to_string()).into();
self.inner_close(Some(&close_reason));
}
/// Provide close from `SubscriptionClosed`.
pub fn close(&mut self, close_reason: &SubscriptionClosed) {
self.inner_close(Some(close_reason));
}
fn inner_close(&mut self, close_reason: Option<&SubscriptionClosed>) {
self.is_connected.take();
if let Some((sink, _)) = self.subscribers.lock().remove(&self.uniq_sub) {
......@@ -850,7 +877,7 @@ impl Drop for SubscriptionSink {
/// Wrapper struct that maintains a subscription "mainly" for testing.
#[derive(Debug)]
pub struct Subscription {
tx: mpsc::UnboundedSender<String>,
tx: async_channel::Sender<()>,
rx: mpsc::UnboundedReceiver<String>,
sub_id: RpcSubscriptionId<'static>,
}
......@@ -858,7 +885,7 @@ pub struct Subscription {
impl Subscription {
/// Close the subscription channel.
pub fn close(&mut self) {
self.tx.close_channel();
self.tx.close();
}
/// Get the subscription ID
......
......@@ -40,14 +40,13 @@ use hyper::service::{make_service_fn, service_fn};
use hyper::{Error as HyperError, Method};
use jsonrpsee_core::error::{Error, GenericTransportError};
use jsonrpsee_core::http_helpers::{self, read_body};
use jsonrpsee_core::id_providers::NoopIdProvider;
use jsonrpsee_core::middleware::Middleware;
use jsonrpsee_core::server::helpers::{collect_batch_response, prepare_error, MethodSink};
use jsonrpsee_core::server::resource_limiting::Resources;
use jsonrpsee_core::server::rpc_module::{MethodResult, Methods};
use jsonrpsee_core::server::rpc_module::{MethodKind, Methods};
use jsonrpsee_core::TEN_MB_SIZE_BYTES;
use jsonrpsee_types::error::ErrorCode;
use jsonrpsee_types::{Id, Notification, Request};
use jsonrpsee_types::{Id, Notification, Params, Request};
use serde_json::value::RawValue;
use socket2::{Domain, Socket, Type};
......@@ -454,48 +453,117 @@ async fn process_validated_request(
// Single request or notification
if is_single {
if let Ok(req) = serde_json::from_slice::<Request>(&body) {
middleware.on_call(req.method.as_ref());
let method = req.method.as_ref();
middleware.on_call(method);
// NOTE: we don't need to track connection id on HTTP, so using hardcoded 0 here.
match methods.execute_with_resources(&sink, req, 0, &resources, &NoopIdProvider) {
Ok((name, MethodResult::Sync(success))) => {
middleware.on_result(name, success, request_start);
}
Ok((name, MethodResult::Async(fut))) => {
let success = fut.await;
let id = req.id.clone();
let params = Params::new(req.params.map(|params| params.get()));
middleware.on_result(name, success, request_start);
}
Err(name) => {
middleware.on_result(name.as_ref(), false, request_start);
let result = match methods.method_with_name(method) {
None => {
sink.send_error(req.id, ErrorCode::MethodNotFound.into());
false
}
}
Some((name, method_callback)) => match method_callback.inner() {
MethodKind::Sync(callback) => match method_callback.claim(&req.method, &resources) {
Ok(guard) => {
let result = (callback)(id, params, &sink);
drop(guard);
result
}
Err(err) => {
tracing::error!("[Methods::execute_with_resources] failed to lock resources: {:?}", err);
sink.send_error(req.id, ErrorCode::ServerIsBusy.into());
false
}
},
MethodKind::Async(callback) => match method_callback.claim(name, &resources) {
Ok(guard) => {
let result =
(callback)(id.into_owned(), params.into_owned(), sink.clone(), 0, Some(guard)).await;
result
}
Err(err) => {
tracing::error!("[Methods::execute_with_resources] failed to lock resources: {:?}", err);
sink.send_error(req.id, ErrorCode::ServerIsBusy.into());
false
}
},
MethodKind::Subscription(_) => {
tracing::error!("Subscriptions not supported on HTTP");
sink.send_error(req.id, ErrorCode::InternalError.into());
false
}
},
};
middleware.on_result(&req.method, result, request_start);
} else if let Ok(_req) = serde_json::from_slice::<Notif>(&body) {
return Ok::<_, HyperError>(response::ok_response("".into()));
} else {
let (id, code) = prepare_error(&body);
sink.send_error(id, code.into());
}
// Batch of requests or notifications
} else if let Ok(batch) = serde_json::from_slice::<Vec<Request>>(&body) {
if !batch.is_empty() {
let middleware = &middleware;
join_all(batch.into_iter().filter_map(move |req| {
match methods.execute_with_resources(&sink, req, 0, &resources, &NoopIdProvider) {
Ok((name, MethodResult::Sync(success))) => {
middleware.on_result(name, success, request_start);
None
}
Ok((name, MethodResult::Async(fut))) => Some(async move {
let success = fut.await;
middleware.on_result(name, success, request_start);
}),
Err(name) => {
middleware.on_result(name.as_ref(), false, request_start);
let id = req.id.clone();
let params = Params::new(req.params.map(|params| params.get()));
match methods.method_with_name(&req.method) {
None => {
sink.send_error(req.id, ErrorCode::MethodNotFound.into());
None
}
Some((name, method_callback)) => match method_callback.inner() {
MethodKind::Sync(callback) => match method_callback.claim(name, &resources) {
Ok(guard) => {
let result = (callback)(id, params, &sink);
middleware.on_result(name, result, request_start);
drop(guard);
None
}
Err(err) => {
tracing::error!(
"[Methods::execute_with_resources] failed to lock resources: {:?}",
err
);
sink.send_error(req.id, ErrorCode::ServerIsBusy.into());
middleware.on_result(name, false, request_start);
None
}
},
MethodKind::Async(callback) => match method_callback.claim(name, &resources) {
Ok(guard) => {
let sink = sink.clone();
let id = id.into_owned();
let params = params.into_owned();
let callback = callback.clone();
Some(async move {
let result = (callback)(id, params, sink, 0, Some(guard)).await;
middleware.on_result(name, result, request_start);
})
}
Err(err) => {
tracing::error!(
"[Methods::execute_with_resources] failed to lock resources: {:?}",
err
);