Unverified Commit 8e945de4 authored by Niklas Adolfsson's avatar Niklas Adolfsson Committed by GitHub
Browse files

fix(rpc module): close subscription task when a subscription is `unsubscribed`...

fix(rpc module): close subscription task when a subscription is `unsubscribed` via the `unsubscribe call` (#743)

* refactor: remove SubscriptionSink::inner_send

* fix: close running task if unsubscribed

* Update core/src/server/rpc_module.rs

* Update core/src/server/rpc_module.rs

* fix nits

* Update core/src/server/rpc_module.rs

* add test for canceling subscriptions

* print subscription info; once per minute

* revert closure stuff

* Revert "print subscription info; once per minute"

This reverts commit 366176a8

.

* use tokio::sync::watch instead of oneshot

The receiver is clonable and it's possible to check whether the sender is still alive

* Update tests/tests/helpers.rs
Co-authored-by: David's avatarDavid <dvdplm@gmail.com>

* Update core/src/server/rpc_module.rs
Co-authored-by: David's avatarDavid <dvdplm@gmail.com>

* grumbles: use unwrap in tests

* add test for reuse pipe_from_stream
Co-authored-by: David's avatarDavid <dvdplm@gmail.com>
parent 9decd23c
Pipeline #191160 passed with stages
in 4 minutes and 57 seconds
......@@ -48,7 +48,7 @@ use jsonrpsee_types::{
use parking_lot::Mutex;
use rustc_hash::FxHashMap;
use serde::{de::DeserializeOwned, Serialize};
use tokio::sync::Notify;
use tokio::sync::{watch, Notify};
/// A `MethodCallback` is an RPC endpoint, callable with a standard JSON-RPC request,
/// implemented as a function pointer to a `Fn` function taking four arguments:
......@@ -98,7 +98,7 @@ impl<'a> std::fmt::Debug for ConnState<'a> {
}
}
type Subscribers = Arc<Mutex<FxHashMap<SubscriptionKey, (MethodSink, Arc<()>)>>>;
type Subscribers = Arc<Mutex<FxHashMap<SubscriptionKey, (MethodSink, watch::Sender<()>)>>>;
/// Represent a unique subscription entry based on [`RpcSubscriptionId`] and [`ConnectionId`].
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
......@@ -794,9 +794,9 @@ impl PendingSubscription {
let InnerPendingSubscription { sink, close_notify, method, uniq_sub, subscribers, id } = inner;
if sink.send_response(id, &uniq_sub.sub_id) {
let active_sub = Arc::new(());
subscribers.lock().insert(uniq_sub.clone(), (sink.clone(), active_sub.clone()));
Some(SubscriptionSink { inner: sink, close_notify, method, uniq_sub, subscribers, active_sub })
let (tx, rx) = watch::channel(());
subscribers.lock().insert(uniq_sub.clone(), (sink.clone(), tx));
Some(SubscriptionSink { inner: sink, close_notify, method, uniq_sub, subscribers, unsubscribe: rx })
} else {
None
}
......@@ -826,7 +826,8 @@ pub struct SubscriptionSink {
uniq_sub: SubscriptionKey,
/// Shared Mutex of subscriptions for this method.
subscribers: Subscribers,
active_sub: Arc<()>,
/// Future that returns when the unsubscribe method has been called.
unsubscribe: watch::Receiver<()>,
}
impl SubscriptionSink {
......@@ -843,7 +844,7 @@ impl SubscriptionSink {
}
let msg = self.build_message(result)?;
Ok(self.inner_send(msg))
Ok(self.inner.send_raw(msg).is_ok())
}
/// Reads data from the `stream` and sends back data on the subscription
......@@ -881,7 +882,7 @@ impl SubscriptionSink {
/// SubscriptionClosed::Failed(e) => {
/// sink.close(e);
/// }
/// };
/// }
/// });
/// });
/// ```
......@@ -891,14 +892,23 @@ impl SubscriptionSink {
T: Serialize,
E: std::fmt::Display,
{
let close_notify = match self.close_notify.clone() {
let conn_closed = match self.close_notify.clone() {
Some(close_notify) => close_notify,
None => return SubscriptionClosed::RemotePeerAborted,
None => {
return SubscriptionClosed::RemotePeerAborted;
}
};
let mut sub_closed = self.unsubscribe.clone();
let sub_closed_fut = sub_closed.changed();
let conn_closed_fut = conn_closed.notified();
pin_mut!(conn_closed_fut);
pin_mut!(sub_closed_fut);
let mut stream_item = stream.try_next();
let closed_fut = close_notify.notified();
pin_mut!(closed_fut);
let mut closed_fut = futures_util::future::select(conn_closed_fut, sub_closed_fut);
loop {
match futures_util::future::select(stream_item, closed_fut).await {
// The app sent us a value to send back to the subscribers
......@@ -922,7 +932,7 @@ impl SubscriptionSink {
break SubscriptionClosed::Failed(err);
}
Either::Left((Ok(None), _)) => break SubscriptionClosed::Success,
Either::Right(((), _)) => {
Either::Right((_, _)) => {
break SubscriptionClosed::RemotePeerAborted;
}
}
......@@ -956,13 +966,13 @@ impl SubscriptionSink {
self.pipe_from_try_stream::<_, _, Error>(stream.map(|item| Ok(item))).await
}
/// Returns whether this channel is closed without needing a context.
/// Returns whether the subscription is closed.
pub fn is_closed(&self) -> bool {
self.inner.is_closed() || self.close_notify.is_none()
self.inner.is_closed() || self.close_notify.is_none() || !self.is_active_subscription()
}
fn is_active_subscription(&self) -> bool {
Arc::strong_count(&self.active_sub) > 1
!self.unsubscribe.has_changed().is_err()
}
fn build_message<T: Serialize>(&self, result: &T) -> Result<String, serde_json::Error> {
......@@ -981,14 +991,6 @@ impl SubscriptionSink {
.map_err(Into::into)
}
fn inner_send(&mut self, msg: String) -> bool {
if self.is_active_subscription() {
self.inner.send_raw(msg).is_ok()
} else {
false
}
}
/// Close the subscription, sending a notification with a special `error` field containing the provided error.
///
/// This can be used to signal an actual error, or just to signal that the subscription has been closed,
......
......@@ -27,10 +27,14 @@
use std::net::SocketAddr;
use std::time::Duration;
use futures::{SinkExt, StreamExt};
use jsonrpsee::core::error::SubscriptionClosed;
use jsonrpsee::http_server::{AccessControl, HttpServerBuilder, HttpServerHandle};
use jsonrpsee::types::error::{ErrorObject, SUBSCRIPTION_CLOSED_WITH_ERROR};
use jsonrpsee::ws_server::{WsServerBuilder, WsServerHandle};
use jsonrpsee::RpcModule;
use tokio::time::interval;
use tokio_stream::wrappers::IntervalStream;
pub async fn websocket_server_with_subscription() -> (SocketAddr, WsServerHandle) {
let server = WsServerBuilder::default().build("127.0.0.1:0").await.unwrap();
......@@ -40,10 +44,7 @@ pub async fn websocket_server_with_subscription() -> (SocketAddr, WsServerHandle
module
.register_subscription("subscribe_hello", "subscribe_hello", "unsubscribe_hello", |_, pending, _| {
let mut sink = match pending.accept() {
Some(sink) => sink,
_ => return,
};
let mut sink = pending.accept().unwrap();
std::thread::spawn(move || loop {
if let Ok(false) = sink.send(&"hello from subscription") {
break;
......@@ -55,10 +56,7 @@ pub async fn websocket_server_with_subscription() -> (SocketAddr, WsServerHandle
module
.register_subscription("subscribe_foo", "subscribe_foo", "unsubscribe_foo", |_, pending, _| {
let mut sink = match pending.accept() {
Some(sink) => sink,
_ => return,
};
let mut sink = pending.accept().unwrap();
std::thread::spawn(move || loop {
if let Ok(false) = sink.send(&1337_usize) {
break;
......@@ -75,10 +73,7 @@ pub async fn websocket_server_with_subscription() -> (SocketAddr, WsServerHandle
_ => return,
};
let mut sink = match pending.accept() {
Some(sink) => sink,
_ => return,
};
let mut sink = pending.accept().unwrap();
std::thread::spawn(move || loop {
count = count.wrapping_add(1);
......@@ -92,10 +87,7 @@ pub async fn websocket_server_with_subscription() -> (SocketAddr, WsServerHandle
module
.register_subscription("subscribe_noop", "subscribe_noop", "unsubscribe_noop", |_, pending, _| {
let sink = match pending.accept() {
Some(sink) => sink,
_ => return,
};
let sink = pending.accept().unwrap();
std::thread::spawn(move || {
std::thread::sleep(Duration::from_secs(1));
let err = ErrorObject::owned(
......@@ -108,6 +100,73 @@ pub async fn websocket_server_with_subscription() -> (SocketAddr, WsServerHandle
})
.unwrap();
module
.register_subscription("subscribe_5_ints", "n", "unsubscribe_5_ints", |_, pending, _| {
let mut sink = pending.accept().unwrap();
tokio::spawn(async move {
let interval = interval(Duration::from_millis(50));
let stream = IntervalStream::new(interval).zip(futures::stream::iter(1..=5)).map(|(_, c)| c);
match sink.pipe_from_stream(stream).await {
SubscriptionClosed::Success => {
sink.close(SubscriptionClosed::Success);
}
_ => unreachable!(),
}
});
})
.unwrap();
module
.register_subscription("can_reuse_subscription", "n", "u_can_reuse_subscription", |_, pending, _| {
let mut sink = pending.accept().unwrap();
tokio::spawn(async move {
let stream1 = IntervalStream::new(interval(Duration::from_millis(50)))
.zip(futures::stream::iter(1..=5))
.map(|(_, c)| c);
let stream2 = IntervalStream::new(interval(Duration::from_millis(50)))
.zip(futures::stream::iter(6..=10))
.map(|(_, c)| c);
let result = sink.pipe_from_stream(stream1).await;
assert!(matches!(result, SubscriptionClosed::Success));
match sink.pipe_from_stream(stream2).await {
SubscriptionClosed::Success => {
sink.close(SubscriptionClosed::Success);
}
_ => unreachable!(),
}
});
})
.unwrap();
module
.register_subscription(
"subscribe_with_err_on_stream",
"n",
"unsubscribe_with_err_on_stream",
move |_, pending, _| {
let mut sink = pending.accept().unwrap();
let err: &'static str = "error on the stream";
// create stream that produce an error which will cancel the subscription.
let stream = futures::stream::iter(vec![Ok(1_u32), Err(err), Ok(2), Ok(3)]);
tokio::spawn(async move {
match sink.pipe_from_try_stream(stream).await {
SubscriptionClosed::Failed(e) => {
sink.close(e);
}
_ => unreachable!(),
}
});
},
)
.unwrap();
let addr = server.local_addr().unwrap();
let server_handle = server.start(module).unwrap();
......@@ -133,6 +192,31 @@ pub async fn websocket_server() -> SocketAddr {
addr
}
/// Yields one item then sleeps for an hour.
pub async fn websocket_server_with_sleeping_subscription(tx: futures::channel::mpsc::Sender<()>) -> SocketAddr {
let server = WsServerBuilder::default().build("127.0.0.1:0").await.unwrap();
let addr = server.local_addr().unwrap();
let mut module = RpcModule::new(tx);
module
.register_subscription("subscribe_sleep", "n", "unsubscribe_sleep", |_, pending, mut tx| {
let mut sink = pending.accept().unwrap();
tokio::spawn(async move {
let interval = interval(Duration::from_secs(60 * 60));
let stream = IntervalStream::new(interval).zip(futures::stream::iter(1..=5)).map(|(_, c)| c);
sink.pipe_from_stream(stream).await;
let send_back = std::sync::Arc::make_mut(&mut tx);
send_back.send(()).await.unwrap();
});
})
.unwrap();
server.start(module).unwrap();
addr
}
pub async fn http_server() -> (SocketAddr, HttpServerHandle) {
http_server_with_access_control(AccessControl::default()).await
}
......
......@@ -30,17 +30,14 @@
use std::sync::Arc;
use std::time::Duration;
use futures::TryStreamExt;
use futures::{channel::mpsc, StreamExt, TryStreamExt};
use helpers::{http_server, http_server_with_access_control, websocket_server, websocket_server_with_subscription};
use jsonrpsee::core::client::{ClientT, IdKind, Subscription, SubscriptionClientT};
use jsonrpsee::core::error::SubscriptionClosed;
use jsonrpsee::core::{Error, JsonValue};
use jsonrpsee::http_client::HttpClientBuilder;
use jsonrpsee::rpc_params;
use jsonrpsee::types::error::ErrorObject;
use jsonrpsee::ws_client::WsClientBuilder;
use tokio::time::interval;
use tokio_stream::wrappers::IntervalStream;
mod helpers;
......@@ -386,41 +383,14 @@ async fn ws_server_should_stop_subscription_after_client_drop() {
#[tokio::test]
async fn ws_server_cancels_subscriptions_on_reset_conn() {
use futures::{channel::mpsc, SinkExt, StreamExt};
use jsonrpsee::{ws_server::WsServerBuilder, RpcModule};
let server = WsServerBuilder::default().build("127.0.0.1:0").await.unwrap();
let server_url = format!("ws://{}", server.local_addr().unwrap());
let (tx, rx) = mpsc::channel(1);
let mut module = RpcModule::new(tx);
module
.register_subscription("subscribe_for_ever", "n", "unsubscribe_for_ever", |_, pending, mut tx| {
// Create stream that produce one item then sleeps for an hour.
let interval = interval(Duration::from_secs(60 * 60));
let stream = IntervalStream::new(interval).map(move |_| 0_usize);
let mut sink = match pending.accept() {
Some(sink) => sink,
_ => return,
};
tokio::spawn(async move {
sink.pipe_from_stream(stream).await;
let send_back = Arc::make_mut(&mut tx);
send_back.send(()).await.unwrap();
});
})
.unwrap();
server.start(module).unwrap();
let server_url = format!("ws://{}", helpers::websocket_server_with_sleeping_subscription(tx).await);
let client = WsClientBuilder::default().build(&server_url).await.unwrap();
let mut subs = Vec::new();
for _ in 0..10 {
subs.push(client.subscribe::<usize>("subscribe_for_ever", None, "unsubscribe_for_ever").await.unwrap());
subs.push(client.subscribe::<usize>("subscribe_sleep", None, "unsubscribe_sleep").await.unwrap());
}
// terminate connection.
......@@ -433,38 +403,8 @@ async fn ws_server_cancels_subscriptions_on_reset_conn() {
#[tokio::test]
async fn ws_server_cancels_sub_stream_after_err() {
use jsonrpsee::{ws_server::WsServerBuilder, RpcModule};
let err: &'static str = "error on the stream";
let server = WsServerBuilder::default().build("127.0.0.1:0").await.unwrap();
let server_url = format!("ws://{}", server.local_addr().unwrap());
let mut module = RpcModule::new(());
module
.register_subscription(
"subscribe_with_err_on_stream",
"n",
"unsubscribe_with_err_on_stream",
move |_, pending, _| {
let mut sink = match pending.accept() {
Some(sink) => sink,
_ => return,
};
// create stream that produce an error which will cancel the subscription.
let stream = futures::stream::iter(vec![Ok(1_u32), Err(err), Ok(2), Ok(3)]);
tokio::spawn(async move {
match sink.pipe_from_try_stream(stream).await {
SubscriptionClosed::Failed(e) => sink.close(e),
_ => unreachable!(),
};
});
},
)
.unwrap();
server.start(module).unwrap();
let (addr, _handle) = websocket_server_with_subscription().await;
let server_url = format!("ws://{}", addr);
let client = WsClientBuilder::default().build(&server_url).await.unwrap();
let mut sub: Subscription<serde_json::Value> =
......@@ -477,35 +417,8 @@ async fn ws_server_cancels_sub_stream_after_err() {
#[tokio::test]
async fn ws_server_subscribe_with_stream() {
use futures::StreamExt;
use jsonrpsee::{ws_server::WsServerBuilder, RpcModule};
let server = WsServerBuilder::default().build("127.0.0.1:0").await.unwrap();
let server_url = format!("ws://{}", server.local_addr().unwrap());
let mut module = RpcModule::new(());
module
.register_subscription("subscribe_5_ints", "n", "unsubscribe_5_ints", |_, pending, _| {
let mut sink = match pending.accept() {
Some(sink) => sink,
_ => return,
};
tokio::spawn(async move {
let interval = interval(Duration::from_millis(50));
let stream = IntervalStream::new(interval).zip(futures::stream::iter(1..=5)).map(|(_, c)| c);
match sink.pipe_from_stream(stream).await {
SubscriptionClosed::Success => {
sink.close(SubscriptionClosed::Success);
}
_ => unreachable!(),
};
});
})
.unwrap();
server.start(module).unwrap();
let (addr, _handle) = websocket_server_with_subscription().await;
let server_url = format!("ws://{}", addr);
let client = WsClientBuilder::default().build(&server_url).await.unwrap();
let mut sub1: Subscription<usize> = client.subscribe("subscribe_5_ints", None, "unsubscribe_5_ints").await.unwrap();
......@@ -530,6 +443,37 @@ async fn ws_server_subscribe_with_stream() {
assert!(sub1.next().await.is_none());
}
#[tokio::test]
async fn ws_server_pipe_from_stream_should_cancel_tasks_immediately() {
let (tx, rx) = mpsc::channel(1);
let server_url = format!("ws://{}", helpers::websocket_server_with_sleeping_subscription(tx).await);
let client = WsClientBuilder::default().build(&server_url).await.unwrap();
let mut subs = Vec::new();
for _ in 0..10 {
subs.push(client.subscribe::<i32>("subscribe_sleep", None, "unsubscribe_sleep").await.unwrap())
}
// This will call the `unsubscribe method`.
drop(subs);
let rx_len = rx.take(10).fold(0, |acc, _| async move { acc + 1 }).await;
assert_eq!(rx_len, 10);
}
#[tokio::test]
async fn ws_server_pipe_from_stream_can_be_reused() {
let (addr, _handle) = websocket_server_with_subscription().await;
let client = WsClientBuilder::default().build(&format!("ws://{}", addr)).await.unwrap();
let sub = client.subscribe::<i32>("can_reuse_subscription", None, "u_can_reuse_subscription").await.unwrap();
let items = sub.fold(0, |acc, _| async move { acc + 1 }).await;
assert_eq!(items, 10);
}
#[tokio::test]
async fn ws_batch_works() {
let server_addr = websocket_server().await;
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment