Unverified Commit bc688cc2 authored by Niklas Adolfsson's avatar Niklas Adolfsson Committed by GitHub
Browse files

feat: server configurable subscriptionID (#604)



* PoC

* introduce IdProvider trait

* revert Cow stuff

* Update ws-server/src/server.rs

* fix tests

* cargo fmt

* use 'static lifetime in SubscriptionId

* fix tests

* fmt

* make tests compile again

* fix tests

* Fix tests

* Move IdProvider impls to own module

* move `sub-id gen` types to `core`

* fix doc links

* make rand non-optional dep

* feature gate: id provider

Co-authored-by: David's avatarDavid Palm <dvdplm@gmail.com>
parent 7aa8b012
......@@ -40,9 +40,9 @@ client = ["futures-util"]
async-client = [
"client",
"rustc-hash",
"tokio/sync",
"tokio/macros",
"tokio/time",
"tokio/sync",
"tokio/macros",
"tokio/time",
"tracing"
]
......
......@@ -149,7 +149,7 @@ pub(crate) fn process_single_response(
let sub_id: Result<SubscriptionId, _> = response.result.try_into();
let sub_id = match sub_id {
Ok(sub_id) => sub_id.into_owned(),
Ok(sub_id) => sub_id,
Err(_) => {
let _ = send_back_oneshot.send(Err(Error::InvalidSubscriptionId));
return Ok(None);
......
......@@ -167,7 +167,7 @@ pub enum SubscriptionKind {
#[derive(Debug, Deserialize, Serialize)]
#[serde(untagged)]
pub enum NotifResponse<Notif> {
/// Successful response
/// Successful response.
Ok(Notif),
/// Subscription was closed.
Err(SubscriptionClosed),
......
// Copyright 2019-2021 Parity Technologies (UK) Ltd.
//
// Permission is hereby granted, free of charge, to any
// person obtaining a copy of this software and associated
// documentation files (the "Software"), to deal in the
// Software without restriction, including without
// limitation the rights to use, copy, modify, merge,
// publish, distribute, sublicense, and/or sell copies of
// the Software, and to permit persons to whom the Software
// is furnished to do so, subject to the following
// conditions:
//
// The above copyright notice and this permission notice
// shall be included in all copies or substantial portions
// of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF
// ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED
// TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A
// PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT
// SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
// CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
// OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR
// IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
// DEALINGS IN THE SOFTWARE.
use rand::distributions::Alphanumeric;
use rand::Rng;
use crate::traits::IdProvider;
use jsonrpsee_types::SubscriptionId;
/// Generates random integers as subscription ID.
#[derive(Debug)]
pub struct RandomIntegerIdProvider;
impl IdProvider for RandomIntegerIdProvider {
fn next_id(&self) -> SubscriptionId<'static> {
const JS_NUM_MASK: u64 = !0 >> 11;
(rand::random::<u64>() & JS_NUM_MASK).into()
}
}
/// Generates random strings of length `len` as subscription ID.
#[derive(Debug)]
pub struct RandomStringIdProvider {
len: usize,
}
impl RandomStringIdProvider {
/// Create a new random string provider.
pub fn new(len: usize) -> Self {
Self { len }
}
}
impl IdProvider for RandomStringIdProvider {
fn next_id(&self) -> SubscriptionId<'static> {
let mut rng = rand::thread_rng();
(&mut rng).sample_iter(Alphanumeric).take(self.len).map(char::from).collect::<String>().into()
}
}
/// No-op implementation to be used for servers that don't support subscriptions.
#[derive(Debug)]
pub struct NoopIdProvider;
impl IdProvider for NoopIdProvider {
fn next_id(&self) -> SubscriptionId<'static> {
0.into()
}
}
......@@ -41,6 +41,10 @@ pub mod middleware;
#[cfg(feature = "http-helpers")]
pub mod http_helpers;
/// Different ways of setting the "id" in JSON-RPC responses and results.
#[cfg(feature = "server")]
pub mod id_providers;
/// Shared code for JSON-RPC servers.
#[cfg(feature = "server")]
pub mod server;
......
......@@ -31,10 +31,11 @@ use std::ops::{Deref, DerefMut};
use std::sync::Arc;
use crate::error::{Error, SubscriptionClosed, SubscriptionClosedReason};
use crate::id_providers::RandomIntegerIdProvider;
use crate::server::helpers::MethodSink;
use crate::server::resource_limiting::{ResourceGuard, ResourceTable, ResourceVec, Resources};
use crate::to_json_raw_value;
use crate::traits::ToRpcParams;
use crate::traits::{IdProvider, ToRpcParams};
use beef::Cow;
use futures_channel::{mpsc, oneshot};
use futures_util::{future::BoxFuture, FutureExt, StreamExt};
......@@ -50,25 +51,24 @@ use serde::{de::DeserializeOwned, Serialize};
/// implemented as a function pointer to a `Fn` function taking four arguments:
/// the `id`, `params`, a channel the function uses to communicate the result (or error)
/// back to `jsonrpsee`, and the connection ID (useful for the websocket transport).
pub type SyncMethod = Arc<dyn Send + Sync + Fn(Id, Params, &MethodSink, ConnectionId) -> bool>;
pub type SyncMethod = Arc<dyn Send + Sync + Fn(Id, Params, &MethodSink, ConnectionId, &dyn IdProvider) -> bool>;
/// Similar to [`SyncMethod`], but represents an asynchronous handler and takes an additional argument containing a [`ResourceGuard`] if configured.
pub type AsyncMethod<'a> =
Arc<dyn Send + Sync + Fn(Id<'a>, Params<'a>, MethodSink, Option<ResourceGuard>) -> BoxFuture<'a, bool>>;
pub type AsyncMethod<'a> = Arc<
dyn Send + Sync + Fn(Id<'a>, Params<'a>, MethodSink, Option<ResourceGuard>, &dyn IdProvider) -> BoxFuture<'a, bool>,
>;
/// Connection ID, used for stateful protocol such as WebSockets.
/// For stateless protocols such as http it's unused, so feel free to set it some hardcoded value.
pub type ConnectionId = usize;
/// Subscription ID.
pub type SubscriptionId = u64;
/// Raw RPC response.
pub type RawRpcResponse = (String, mpsc::UnboundedReceiver<String>, mpsc::UnboundedSender<String>);
type Subscribers = Arc<Mutex<FxHashMap<SubscriptionKey, (MethodSink, oneshot::Receiver<()>)>>>;
/// Represent a unique subscription entry based on [`SubscriptionId`] and [`ConnectionId`].
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
/// Represent a unique subscription entry based on [`RpcSubscriptionId`] and [`ConnectionId`].
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
struct SubscriptionKey {
conn_id: ConnectionId,
sub_id: SubscriptionId,
sub_id: RpcSubscriptionId<'static>,
}
/// Callback wrapper that can be either sync or async.
......@@ -160,6 +160,7 @@ impl MethodCallback {
req: Request<'_>,
conn_id: ConnectionId,
claimed: Option<ResourceGuard>,
id_gen: &dyn IdProvider,
) -> MethodResult<bool> {
let id = req.id.clone();
let params = Params::new(req.params.map(|params| params.get()));
......@@ -173,7 +174,7 @@ impl MethodCallback {
conn_id
);
let result = (callback)(id, params, sink, conn_id);
let result = (callback)(id, params, sink, conn_id, id_gen);
// Release claimed resources
drop(claimed);
......@@ -191,7 +192,7 @@ impl MethodCallback {
conn_id
);
MethodResult::Async((callback)(id, params, sink, claimed))
MethodResult::Async((callback)(id, params, sink, claimed, id_gen))
}
};
......@@ -306,10 +307,16 @@ impl Methods {
}
/// Attempt to execute a callback, sending the resulting JSON (success or error) to the specified sink.
pub fn execute(&self, sink: &MethodSink, req: Request, conn_id: ConnectionId) -> MethodResult<bool> {
pub fn execute(
&self,
sink: &MethodSink,
req: Request,
conn_id: ConnectionId,
id_gen: &dyn IdProvider,
) -> MethodResult<bool> {
tracing::trace!("[Methods::execute] Executing request: {:?}", req);
match self.callbacks.get(&*req.method) {
Some(callback) => callback.execute(sink, req, conn_id, None),
Some(callback) => callback.execute(sink, req, conn_id, None, id_gen),
None => {
sink.send_error(req.id, ErrorCode::MethodNotFound.into());
MethodResult::Sync(false)
......@@ -325,11 +332,12 @@ impl Methods {
req: Request<'r>,
conn_id: ConnectionId,
resources: &Resources,
id_gen: &dyn IdProvider,
) -> Result<(&'static str, MethodResult<bool>), Cow<'r, str>> {
tracing::trace!("[Methods::execute_with_resources] Executing request: {:?}", req);
match self.callbacks.get_key_value(&*req.method) {
Some((&name, callback)) => match callback.claim(&req.method, resources) {
Ok(guard) => Ok((name, callback.execute(sink, req, conn_id, Some(guard)))),
Ok(guard) => Ok((name, callback.execute(sink, req, conn_id, Some(guard), id_gen))),
Err(err) => {
tracing::error!("[Methods::execute_with_resources] failed to lock resources: {:?}", err);
sink.send_error(req.id, ErrorCode::ServerIsBusy.into());
......@@ -419,7 +427,7 @@ impl Methods {
let (tx, mut rx) = mpsc::unbounded();
let sink = MethodSink::new(tx.clone());
if let MethodResult::Async(fut) = self.execute(&sink, req, 0) {
if let MethodResult::Async(fut) = self.execute(&sink, req, 0, &RandomIntegerIdProvider) {
fut.await;
}
......@@ -457,8 +465,8 @@ impl Methods {
let req = Request::new(sub_method.into(), Some(&params), Id::Number(0));
tracing::trace!("[Methods::subscribe] Calling subscription method: {:?}, params: {:?}", sub_method, params);
let (response, rx, tx) = self.inner_call(req).await;
let subscription_response = serde_json::from_str::<Response<SubscriptionId>>(&response)?;
let sub_id = subscription_response.result;
let subscription_response = serde_json::from_str::<Response<RpcSubscriptionId>>(&response)?;
let sub_id = subscription_response.result.into_owned();
Ok(Subscription { sub_id, rx, tx })
}
......@@ -519,7 +527,7 @@ impl<Context: Send + Sync + 'static> RpcModule<Context> {
let ctx = self.ctx.clone();
let callback = self.methods.verify_and_insert(
method_name,
MethodCallback::new_sync(Arc::new(move |id, params, sink, _| match callback(params, &*ctx) {
MethodCallback::new_sync(Arc::new(move |id, params, sink, _, _| match callback(params, &*ctx) {
Ok(res) => sink.send_response(id, res),
Err(err) => sink.send_call_error(id, err),
})),
......@@ -542,7 +550,7 @@ impl<Context: Send + Sync + 'static> RpcModule<Context> {
let ctx = self.ctx.clone();
let callback = self.methods.verify_and_insert(
method_name,
MethodCallback::new_async(Arc::new(move |id, params, sink, claimed| {
MethodCallback::new_async(Arc::new(move |id, params, sink, claimed, _| {
let ctx = ctx.clone();
let future = async move {
let result = match callback(params, ctx).await {
......@@ -577,7 +585,7 @@ impl<Context: Send + Sync + 'static> RpcModule<Context> {
let ctx = self.ctx.clone();
let callback = self.methods.verify_and_insert(
method_name,
MethodCallback::new_async(Arc::new(move |id, params, sink, claimed| {
MethodCallback::new_async(Arc::new(move |id, params, sink, claimed, _| {
let ctx = ctx.clone();
tokio::task::spawn_blocking(move || {
......@@ -663,19 +671,19 @@ impl<Context: Send + Sync + 'static> RpcModule<Context> {
let subscribers = subscribers.clone();
self.methods.mut_callbacks().insert(
subscribe_method_name,
MethodCallback::new_sync(Arc::new(move |id, params, method_sink, conn_id| {
MethodCallback::new_sync(Arc::new(move |id, params, method_sink, conn_id, id_provider| {
let (conn_tx, conn_rx) = oneshot::channel::<()>();
let sub_id = {
const JS_NUM_MASK: SubscriptionId = !0 >> 11;
let sub_id = rand::random::<SubscriptionId>() & JS_NUM_MASK;
let uniq_sub = SubscriptionKey { conn_id, sub_id };
let sub_id: RpcSubscriptionId = id_provider.next_id().into_owned();
let uniq_sub = SubscriptionKey { conn_id, sub_id: sub_id.clone() };
subscribers.lock().insert(uniq_sub, (method_sink.clone(), conn_rx));
sub_id
};
method_sink.send_response(id.clone(), sub_id);
method_sink.send_response(id.clone(), &sub_id);
let sink = SubscriptionSink {
inner: method_sink.clone(),
......@@ -702,8 +710,8 @@ impl<Context: Send + Sync + 'static> RpcModule<Context> {
{
self.methods.mut_callbacks().insert(
unsubscribe_method_name,
MethodCallback::new_sync(Arc::new(move |id, params, sink, conn_id| {
let sub_id = match params.one() {
MethodCallback::new_sync(Arc::new(move |id, params, sink, conn_id, _| {
let sub_id = match params.one::<RpcSubscriptionId>() {
Ok(sub_id) => sub_id,
Err(_) => {
tracing::error!(
......@@ -712,15 +720,21 @@ impl<Context: Send + Sync + 'static> RpcModule<Context> {
params,
id
);
let err = to_json_raw_value(&"Invalid subscription ID type, must be integer").ok();
let err =
to_json_raw_value(&"Invalid subscription ID type, must be Integer or String").ok();
return sink.send_error(id, invalid_subscription_err(err.as_deref()));
}
};
let sub_id = sub_id.into_owned();
if subscribers.lock().remove(&SubscriptionKey { conn_id, sub_id }).is_some() {
if subscribers.lock().remove(&SubscriptionKey { conn_id, sub_id: sub_id.clone() }).is_some() {
sink.send_response(id, "Unsubscribed")
} else {
let err = to_json_raw_value(&format!("Invalid subscription ID={}", sub_id)).ok();
let err = to_json_raw_value(&format!(
"Invalid subscription ID={}",
serde_json::to_string(&sub_id).expect("valid JSON; qed")
))
.ok();
sink.send_error(id, invalid_subscription_err(err.as_deref()))
}
})),
......@@ -780,7 +794,7 @@ impl SubscriptionSink {
fn build_message<T: Serialize>(&self, result: &T) -> Result<String, Error> {
serde_json::to_string(&SubscriptionResponse::new(
self.method.into(),
SubscriptionPayload { subscription: RpcSubscriptionId::Num(self.uniq_sub.sub_id), result },
SubscriptionPayload { subscription: self.uniq_sub.sub_id.clone(), result },
))
.map_err(Into::into)
}
......@@ -838,7 +852,7 @@ impl Drop for SubscriptionSink {
pub struct Subscription {
tx: mpsc::UnboundedSender<String>,
rx: mpsc::UnboundedReceiver<String>,
sub_id: u64,
sub_id: RpcSubscriptionId<'static>,
}
impl Subscription {
......@@ -848,8 +862,8 @@ impl Subscription {
}
/// Get the subscription ID
pub fn subscription_id(&self) -> u64 {
self.sub_id
pub fn subscription_id(&self) -> &RpcSubscriptionId {
&self.sub_id
}
/// Returns `Some((val, sub_id))` for the next element of type T from the underlying stream,
......
......@@ -24,6 +24,7 @@
// IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
// DEALINGS IN THE SOFTWARE.
use jsonrpsee_types::SubscriptionId;
use serde::Serialize;
use serde_json::value::RawValue;
......@@ -68,3 +69,9 @@ tuple_impls! {
15 => (0 T0 1 T1 2 T2 3 T3 4 T4 5 T5 6 T6 7 T7 8 T8 9 T9 10 T10 11 T11 12 T12 13 T13 14 T14)
16 => (0 T0 1 T1 2 T2 3 T3 4 T4 5 T5 6 T6 7 T7 8 T8 9 T9 10 T10 11 T11 12 T12 13 T13 14 T14 15 T15)
}
/// Trait to generate subscription IDs.
pub trait IdProvider: Send + Sync {
/// Returns the next ID for the subscription.
fn next_id(&self) -> SubscriptionId<'static>;
}
......@@ -38,6 +38,7 @@ use hyper::service::{make_service_fn, service_fn};
use hyper::Error as HyperError;
use jsonrpsee_core::error::{Error, GenericTransportError};
use jsonrpsee_core::http_helpers::read_body;
use jsonrpsee_core::id_providers::NoopIdProvider;
use jsonrpsee_core::middleware::Middleware;
use jsonrpsee_core::server::helpers::{collect_batch_response, prepare_error, MethodSink};
use jsonrpsee_core::server::resource_limiting::Resources;
......@@ -334,7 +335,7 @@ impl<M: Middleware> Server<M> {
middleware.on_call(req.method.as_ref());
// NOTE: we don't need to track connection id on HTTP, so using hardcoded 0 here.
match methods.execute_with_resources(&sink, req, 0, &resources) {
match methods.execute_with_resources(&sink, req, 0, &resources, &NoopIdProvider) {
Ok((name, MethodResult::Sync(success))) => {
middleware.on_result(name, success, request_start);
}
......@@ -359,8 +360,14 @@ impl<M: Middleware> Server<M> {
if !batch.is_empty() {
let middleware = &middleware;
join_all(batch.into_iter().filter_map(move |req| {
match methods.execute_with_resources(&sink, req, 0, &resources) {
join_all(batch.into_iter().filter_map(
move |req| match methods.execute_with_resources(
&sink,
req,
0,
&resources,
&NoopIdProvider,
) {
Ok((name, MethodResult::Sync(success))) => {
middleware.on_result(name, success, request_start);
None
......@@ -373,8 +380,8 @@ impl<M: Middleware> Server<M> {
middleware.on_result(name.as_ref(), false, request_start);
None
}
}
}))
},
))
.await;
} else {
// "If the batch rpc call itself fails to be recognized as an valid JSON or as an
......
......@@ -28,7 +28,7 @@ use std::collections::HashMap;
use jsonrpsee::core::server::rpc_module::*;
use jsonrpsee::core::Error;
use jsonrpsee::types::{EmptyParams, Params, SubscriptionId as RpcSubscriptionId};
use jsonrpsee::types::{EmptyParams, Params};
use serde::{Deserialize, Serialize};
#[test]
......@@ -201,7 +201,7 @@ async fn subscribing_without_server() {
for i in (0..=2).rev() {
let (val, id) = my_sub.next::<char>().await.unwrap().unwrap();
assert_eq!(val, std::char::from_digit(i, 10).unwrap());
assert_eq!(id, RpcSubscriptionId::Num(my_sub.subscription_id()));
assert_eq!(&id, my_sub.subscription_id());
}
let sub_err = my_sub.next::<char>().await.unwrap().unwrap_err();
......@@ -228,7 +228,7 @@ async fn close_test_subscribing_without_server() {
let mut my_sub = module.subscribe("my_sub", EmptyParams::new()).await.unwrap();
let (val, id) = my_sub.next::<String>().await.unwrap().unwrap();
assert_eq!(&val, "lo");
assert_eq!(id, RpcSubscriptionId::Num(my_sub.subscription_id()));
assert_eq!(&id, my_sub.subscription_id());
// close the subscription to ensure it doesn't return any items.
my_sub.close();
......
......@@ -20,3 +20,4 @@ serde_json = { version = "1", default-features = false, features = ["alloc", "ra
thiserror = "1.0"
soketto = "0.7.1"
hyper = "0.14.10"
rand = "0.8"
......@@ -308,6 +308,18 @@ impl<'a> From<SubscriptionId<'a>> for JsonValue {
}
}
impl<'a> From<u64> for SubscriptionId<'a> {
fn from(sub_id: u64) -> Self {
Self::Num(sub_id)
}
}
impl<'a> From<String> for SubscriptionId<'a> {
fn from(sub_id: String) -> Self {
Self::Str(sub_id.into())
}
}
impl<'a> TryFrom<JsonValue> for SubscriptionId<'a> {
type Error = ();
......
......@@ -24,4 +24,5 @@ tokio-util = { version = "0.6", features = ["compat"] }
anyhow = "1"
jsonrpsee-test-utils = { path = "../test-utils" }
jsonrpsee = { path = "../jsonrpsee", features = ["full"] }
rand = "0.8"
tracing-subscriber = { version = "0.3.3", features = ["env-filter"] }
......@@ -40,6 +40,7 @@ mod tests;
pub use future::{ServerHandle as WsServerHandle, ShutdownWaiter as WsShutdownWaiter};
pub use jsonrpsee_core::server::rpc_module::{RpcModule, SubscriptionSink};
pub use jsonrpsee_core::{id_providers::*, traits::IdProvider};
pub use jsonrpsee_types as types;
pub use server::{Builder as WsServerBuilder, Server as WsServer};
pub use tracing;
......@@ -27,6 +27,7 @@
use std::future::Future;
use std::net::SocketAddr;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use crate::future::{FutureDriver, ServerHandle, StopMonitor};
......@@ -36,10 +37,12 @@ use futures_channel::mpsc;
use futures_util::future::{join_all, FutureExt};
use futures_util::io::{BufReader, BufWriter};
use futures_util::stream::StreamExt;
use jsonrpsee_core::id_providers::RandomIntegerIdProvider;
use jsonrpsee_core::middleware::Middleware;
use jsonrpsee_core::server::helpers::{collect_batch_response, prepare_error, MethodSink};
use jsonrpsee_core::server::resource_limiting::Resources;
use jsonrpsee_core::server::rpc_module::{ConnectionId, MethodResult, Methods};
use jsonrpsee_core::traits::IdProvider;
use jsonrpsee_core::{Error, TEN_MB_SIZE_BYTES};
use soketto::connection::Error as SokettoError;
use soketto::handshake::{server::Response, Server as SokettoServer};
......@@ -51,13 +54,23 @@ use tokio_util::compat::{Compat, TokioAsyncReadCompatExt};
const MAX_CONNECTIONS: u64 = 100;
/// A WebSocket JSON RPC server.
#[derive(Debug)]
pub struct Server<M> {
listener: TcpListener,
cfg: Settings,
stop_monitor: StopMonitor,
resources: Resources,
middleware: M,
id_provider: Arc<dyn IdProvider>,
}
impl<M> std::fmt::Debug for Server<M> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("Server")
.field("listener", &self.listener)
.field("cfg", &self.cfg)
.field("stop_monitor", &self.stop_monitor)
.finish()
}
}
impl<M: Middleware> Server<M> {
......@@ -109,6 +122,7 @@ impl<M: Middleware> Server<M> {
let methods = &methods;
let cfg = &self.cfg;
let id_provider = self.id_provider.clone();
connections.add(Box::pin(handshake(
socket,
......@@ -119,6 +133,7 @@ impl<M: Middleware> Server<M> {
cfg,
stop_monitor: &stop_monitor,
middleware: middleware.clone(),
id_provider,
},
)));
......@@ -196,6 +211,7 @@ enum HandshakeResponse<'a, M> {
cfg: &'a Settings,
stop_monitor: &'a StopMonitor,
middleware: M,
id_provider: Arc<dyn IdProvider>,
},
}
......@@ -219,7 +235,7 @@ where
Ok(())
}
HandshakeResponse::Accept { conn_id, methods, resources, cfg, stop_monitor, middleware } => {
HandshakeResponse::Accept { conn_id, methods, resources, cfg, stop_monitor, middleware, id_provider } => {
tracing::debug!("Accepting new connection: {}", conn_id);
let key = {
let req = server.receive_request().await?;
......@@ -250,6 +266,7 @@ where
cfg.max_request_body_size,
stop_monitor.clone(),
middleware,
id_provider,
))
.await;
......@@ -269,6 +286,7 @@ async fn background_task(
max_request_body_size: u32,
stop_server: StopMonitor,
middleware: impl Middleware,
id_provider: Arc<dyn IdProvider>,
) -> Result<(), Error> {
// And we can finally transition to a websocket background_task.
let mut builder = server.into_builder();
......@@ -352,7 +370,7 @@ async fn background_task(
tracing::debug!("recv method call={}", req.method);
tracing::trace!("recv: req={:?}", req);
match methods.execute_with_resources(&sink, req, conn_id, &resources) {
match methods.execute_with_resources(&sink, req, conn_id, &resources, &*id_provider) {
Ok((name, MethodResult::Sync(success))) => {
middleware.on_result(name, success, request_start);
middleware.on_response(request_start);
......@@ -385,6 +403,7 @@ async fn background_task(
let resources = &resources;
let methods = &methods;
let sink = sink.clone();
let id_provider = id_provider.clone();
let fut = async move {
// Batch responses must be sent back as a single message so we read the results from each
......@@ -397,7 +416,13 @@ async fn background_task(
tracing::trace!("recv: batch={:?}", batch);
if !batch.is_empty() {
join_all(batch.into_iter().filter_map(move |req| {
match methods.execute_with_resources(&sink_batch, req, conn_id, resources) {
match methods.execute_with_resources(
&sink_batch,
req,
conn_id,
resources,
&*id_provider,
) {