Unverified Commit a13ae7a2 authored by Niklas Adolfsson's avatar Niklas Adolfsson Committed by GitHub
Browse files

feat(clients): add explicit unsubscribe API (#789)

* feat(clients): add explicit unsubscribe API

* add tests for unsubscribe

* rephrase bad english
parent 6ca64b74
Pipeline #199594 passed with stages
in 5 minutes and 52 seconds
......@@ -84,7 +84,7 @@ pub(crate) fn process_subscription_response(
let request_id = match manager.get_request_id_by_subscription_id(&sub_id) {
Some(request_id) => request_id,
None => {
tracing::error!("Subscription ID: {:?} is not an active subscription", sub_id);
tracing::warn!("Subscription ID: {:?} is not an active subscription", sub_id);
return Err(None);
}
};
......@@ -100,7 +100,7 @@ pub(crate) fn process_subscription_response(
}
},
None => {
tracing::error!("Subscription ID: {:?} is not an active subscription", sub_id);
tracing::warn!("Subscription ID: {:?} is not an active subscription", sub_id);
Err(None)
}
}
......
......@@ -182,7 +182,7 @@ macro_rules! rpc_params {
}
/// Subscription kind
#[derive(Debug)]
#[derive(Debug, Clone)]
#[non_exhaustive]
pub enum SubscriptionKind {
/// Get notifications based on Subscription ID.
......@@ -202,7 +202,7 @@ pub struct Subscription<Notif> {
/// Channel from which we receive notifications from the server, as encoded `JsonValue`s.
notifs_rx: mpsc::Receiver<JsonValue>,
/// Callback kind.
kind: SubscriptionKind,
kind: Option<SubscriptionKind>,
/// Marker in order to pin the `Notif` parameter.
marker: PhantomData<Notif>,
}
......@@ -218,12 +218,25 @@ impl<Notif> Subscription<Notif> {
notifs_rx: mpsc::Receiver<JsonValue>,
kind: SubscriptionKind,
) -> Self {
Self { to_back, notifs_rx, kind, marker: PhantomData }
Self { to_back, notifs_rx, kind: Some(kind), marker: PhantomData }
}
/// Return the subscription type and, if applicable, ID.
pub fn kind(&self) -> &SubscriptionKind {
&self.kind
self.kind.as_ref().expect("only None after unsubscribe; qed")
}
/// Unsubscribe and consume the subscription.
pub async fn unsubscribe(mut self) -> Result<(), Error> {
let msg = match self.kind.take().expect("only None after unsubscribe; qed") {
SubscriptionKind::Method(notif) => FrontToBack::UnregisterNotification(notif),
SubscriptionKind::Subscription(sub_id) => FrontToBack::SubscriptionClosed(sub_id),
};
self.to_back.send(msg).await?;
// wait until notif channel is closed then the subscription was closed.
while self.notifs_rx.next().await.is_some() {}
Ok(())
}
}
......@@ -338,11 +351,11 @@ impl<Notif> Drop for Subscription<Notif> {
// the channel's buffer will be full.
// However, when a notification arrives, the background task will realize that the channel
// to the `Callback` has been closed.
let kind = std::mem::replace(&mut self.kind, SubscriptionKind::Subscription(SubscriptionId::Num(0)));
let msg = match kind {
SubscriptionKind::Method(notif) => FrontToBack::UnregisterNotification(notif),
SubscriptionKind::Subscription(sub_id) => FrontToBack::SubscriptionClosed(sub_id),
let msg = match self.kind.take() {
Some(SubscriptionKind::Method(notif)) => FrontToBack::UnregisterNotification(notif),
Some(SubscriptionKind::Subscription(sub_id)) => FrontToBack::SubscriptionClosed(sub_id),
None => return,
};
let _ = self.to_back.send(msg).now_or_never();
}
......
......@@ -45,8 +45,16 @@ use tokio_stream::wrappers::IntervalStream;
mod helpers;
fn init_logger() {
let _ = tracing_subscriber::FmtSubscriber::builder()
.with_env_filter(tracing_subscriber::EnvFilter::from_default_env())
.try_init();
}
#[tokio::test]
async fn ws_subscription_works() {
init_logger();
let (server_addr, _) = websocket_server_with_subscription().await;
let server_url = format!("ws://{}", server_addr);
let client = WsClientBuilder::default().build(&server_url).await.unwrap();
......@@ -62,8 +70,41 @@ async fn ws_subscription_works() {
}
}
#[tokio::test]
async fn ws_unsubscription_works() {
init_logger();
let (server_addr, _) = websocket_server_with_subscription().await;
let server_url = format!("ws://{}", server_addr);
let client = WsClientBuilder::default().max_concurrent_requests(1).build(&server_url).await.unwrap();
let mut sub: Subscription<usize> = client.subscribe("subscribe_foo", None, "unsubscribe_foo").await.unwrap();
// It's technically possible to have race-conditions between the notifications and the unsubscribe message.
// So let's wait for the first notification and then unsubscribe.
let _item = sub.next().await.unwrap().unwrap();
sub.unsubscribe().await.unwrap();
let mut success = false;
// Wait until a slot is available, as only one concurrent call is allowed.
// Then when this finishes we know that unsubscribe call has been finished.
for _ in 0..30 {
if client.request::<String>("say_hello", rpc_params![]).await.is_ok() {
success = true;
break;
}
tokio::time::sleep(std::time::Duration::from_millis(100)).await;
}
assert!(success);
}
#[tokio::test]
async fn ws_subscription_with_input_works() {
init_logger();
let (server_addr, _) = websocket_server_with_subscription().await;
let server_url = format!("ws://{}", server_addr);
let client = WsClientBuilder::default().build(&server_url).await.unwrap();
......@@ -78,6 +119,8 @@ async fn ws_subscription_with_input_works() {
#[tokio::test]
async fn ws_method_call_works() {
init_logger();
let server_addr = websocket_server().await;
let server_url = format!("ws://{}", server_addr);
let client = WsClientBuilder::default().build(&server_url).await.unwrap();
......@@ -87,6 +130,8 @@ async fn ws_method_call_works() {
#[tokio::test]
async fn ws_method_call_str_id_works() {
init_logger();
let server_addr = websocket_server().await;
let server_url = format!("ws://{}", server_addr);
let client = WsClientBuilder::default().id_format(IdKind::String).build(&server_url).await.unwrap();
......@@ -96,6 +141,8 @@ async fn ws_method_call_str_id_works() {
#[tokio::test]
async fn http_method_call_works() {
init_logger();
let (server_addr, _handle) = http_server().await;
let uri = format!("http://{}", server_addr);
let client = HttpClientBuilder::default().build(&uri).unwrap();
......@@ -105,6 +152,8 @@ async fn http_method_call_works() {
#[tokio::test]
async fn http_method_call_str_id_works() {
init_logger();
let (server_addr, _handle) = http_server().await;
let uri = format!("http://{}", server_addr);
let client = HttpClientBuilder::default().id_format(IdKind::String).build(&uri).unwrap();
......@@ -114,6 +163,8 @@ async fn http_method_call_str_id_works() {
#[tokio::test]
async fn http_concurrent_method_call_limits_works() {
init_logger();
let (server_addr, _handle) = http_server().await;
let uri = format!("http://{}", server_addr);
let client = HttpClientBuilder::default().max_concurrent_requests(1).build(&uri).unwrap();
......@@ -127,6 +178,8 @@ async fn http_concurrent_method_call_limits_works() {
#[tokio::test]
async fn ws_subscription_several_clients() {
init_logger();
let (server_addr, _) = websocket_server_with_subscription().await;
let server_url = format!("ws://{}", server_addr);
......@@ -143,6 +196,8 @@ async fn ws_subscription_several_clients() {
#[tokio::test]
async fn ws_subscription_several_clients_with_drop() {
init_logger();
let (server_addr, _) = websocket_server_with_subscription().await;
let server_url = format!("ws://{}", server_addr);
......@@ -189,6 +244,8 @@ async fn ws_subscription_several_clients_with_drop() {
#[tokio::test]
async fn ws_subscription_without_polling_doesnt_make_client_unuseable() {
init_logger();
let (server_addr, _) = websocket_server_with_subscription().await;
let server_url = format!("ws://{}", server_addr);
......@@ -219,6 +276,8 @@ async fn ws_subscription_without_polling_doesnt_make_client_unuseable() {
#[tokio::test]
async fn ws_making_more_requests_than_allowed_should_not_deadlock() {
init_logger();
let server_addr = websocket_server().await;
let server_url = format!("ws://{}", server_addr);
let client = Arc::new(WsClientBuilder::default().max_concurrent_requests(2).build(&server_url).await.unwrap());
......@@ -237,6 +296,8 @@ async fn ws_making_more_requests_than_allowed_should_not_deadlock() {
#[tokio::test]
async fn http_making_more_requests_than_allowed_should_not_deadlock() {
init_logger();
let (server_addr, _handle) = http_server().await;
let server_url = format!("http://{}", server_addr);
let client = HttpClientBuilder::default().max_concurrent_requests(2).build(&server_url).unwrap();
......@@ -256,6 +317,8 @@ async fn http_making_more_requests_than_allowed_should_not_deadlock() {
#[tokio::test]
async fn https_works() {
init_logger();
let client = HttpClientBuilder::default().build("https://kusama-rpc.polkadot.io:443").unwrap();
let response: String = client.request("system_chain", None).await.unwrap();
assert_eq!(&response, "Kusama");
......@@ -263,6 +326,8 @@ async fn https_works() {
#[tokio::test]
async fn wss_works() {
init_logger();
let client = WsClientBuilder::default().build("wss://kusama-rpc.polkadot.io:443").await.unwrap();
let response: String = client.request("system_chain", None).await.unwrap();
assert_eq!(&response, "Kusama");
......@@ -270,18 +335,24 @@ async fn wss_works() {
#[tokio::test]
async fn ws_with_non_ascii_url_doesnt_hang_or_panic() {
init_logger();
let err = WsClientBuilder::default().build("wss://♥♥♥♥♥♥∀∂").await;
assert!(matches!(err, Err(Error::Transport(_))));
}
#[tokio::test]
async fn http_with_non_ascii_url_doesnt_hang_or_panic() {
init_logger();
let err = HttpClientBuilder::default().build("http://♥♥♥♥♥♥∀∂");
assert!(matches!(err, Err(Error::Transport(_))));
}
#[tokio::test]
async fn ws_unsubscribe_releases_request_slots() {
init_logger();
let (server_addr, _) = websocket_server_with_subscription().await;
let server_url = format!("ws://{}", server_addr);
......@@ -294,10 +365,7 @@ async fn ws_unsubscribe_releases_request_slots() {
#[tokio::test]
async fn server_should_be_able_to_close_subscriptions() {
tracing_subscriber::FmtSubscriber::builder()
.with_env_filter(tracing_subscriber::EnvFilter::from_default_env())
.try_init()
.expect("setting default subscriber failed");
init_logger();
let (server_addr, _) = websocket_server_with_subscription().await;
let server_url = format!("ws://{}", server_addr);
......@@ -311,6 +379,8 @@ async fn server_should_be_able_to_close_subscriptions() {
#[tokio::test]
async fn ws_close_pending_subscription_when_server_terminated() {
init_logger();
let (server_addr, handle) = websocket_server_with_subscription().await;
let server_url = format!("ws://{}", server_addr);
......@@ -345,6 +415,8 @@ async fn ws_server_should_stop_subscription_after_client_drop() {
use futures::{channel::mpsc, SinkExt, StreamExt};
use jsonrpsee::{ws_server::WsServerBuilder, RpcModule};
init_logger();
let server = WsServerBuilder::default().build("127.0.0.1:0").await.unwrap();
let server_url = format!("ws://{}", server.local_addr().unwrap());
......@@ -385,6 +457,8 @@ async fn ws_server_should_stop_subscription_after_client_drop() {
#[tokio::test]
async fn ws_server_cancels_subscriptions_on_reset_conn() {
init_logger();
let (tx, rx) = mpsc::channel(1);
let server_url = format!("ws://{}", helpers::websocket_server_with_sleeping_subscription(tx).await);
......@@ -405,6 +479,8 @@ async fn ws_server_cancels_subscriptions_on_reset_conn() {
#[tokio::test]
async fn ws_server_cancels_sub_stream_after_err() {
init_logger();
let (addr, _handle) = websocket_server_with_subscription().await;
let server_url = format!("ws://{}", addr);
......@@ -419,6 +495,8 @@ async fn ws_server_cancels_sub_stream_after_err() {
#[tokio::test]
async fn ws_server_subscribe_with_stream() {
init_logger();
let (addr, _handle) = websocket_server_with_subscription().await;
let server_url = format!("ws://{}", addr);
......@@ -447,6 +525,8 @@ async fn ws_server_subscribe_with_stream() {
#[tokio::test]
async fn ws_server_pipe_from_stream_should_cancel_tasks_immediately() {
init_logger();
let (tx, rx) = mpsc::channel(1);
let server_url = format!("ws://{}", helpers::websocket_server_with_sleeping_subscription(tx).await);
......@@ -467,6 +547,8 @@ async fn ws_server_pipe_from_stream_should_cancel_tasks_immediately() {
#[tokio::test]
async fn ws_server_pipe_from_stream_can_be_reused() {
init_logger();
let (addr, _handle) = websocket_server_with_subscription().await;
let client = WsClientBuilder::default().build(&format!("ws://{}", addr)).await.unwrap();
let sub = client.subscribe::<i32>("can_reuse_subscription", None, "u_can_reuse_subscription").await.unwrap();
......@@ -478,6 +560,8 @@ async fn ws_server_pipe_from_stream_can_be_reused() {
#[tokio::test]
async fn ws_batch_works() {
init_logger();
let server_addr = websocket_server().await;
let server_url = format!("ws://{}", server_addr);
let client = WsClientBuilder::default().build(&server_url).await.unwrap();
......@@ -497,6 +581,8 @@ async fn ws_server_limit_subs_per_conn_works() {
use jsonrpsee::types::error::{CallError, TOO_MANY_SUBSCRIPTIONS_CODE, TOO_MANY_SUBSCRIPTIONS_MSG};
use jsonrpsee::{ws_server::WsServerBuilder, RpcModule};
init_logger();
let server = WsServerBuilder::default().max_subscriptions_per_connection(10).build("127.0.0.1:0").await.unwrap();
let server_url = format!("ws://{}", server.local_addr().unwrap());
......@@ -554,6 +640,8 @@ async fn ws_server_unsub_methods_should_ignore_sub_limit() {
use jsonrpsee::core::client::SubscriptionKind;
use jsonrpsee::{ws_server::WsServerBuilder, RpcModule};
init_logger();
let server = WsServerBuilder::default().max_subscriptions_per_connection(10).build("127.0.0.1:0").await.unwrap();
let server_url = format!("ws://{}", server.local_addr().unwrap());
......@@ -608,6 +696,8 @@ async fn ws_server_unsub_methods_should_ignore_sub_limit() {
async fn http_unsupported_methods_dont_work() {
use hyper::{Body, Client, Method, Request};
init_logger();
let (server_addr, _handle) = http_server().await;
let http_client = Client::new();
......@@ -637,6 +727,8 @@ async fn http_unsupported_methods_dont_work() {
async fn http_correct_content_type_required() {
use hyper::{Body, Client, Method, Request};
init_logger();
let (server_addr, _handle) = http_server().await;
let http_client = Client::new();
......@@ -680,6 +772,8 @@ async fn http_cors_preflight_works() {
use hyper::{Body, Client, Method, Request};
use jsonrpsee::http_server::AccessControlBuilder;
init_logger();
let acl = AccessControlBuilder::new().set_allowed_origins(vec!["https://foo.com"]).unwrap().build();
let (server_addr, _handle) = http_server_with_access_control(acl).await;
......@@ -743,6 +837,8 @@ fn comma_separated_header_values(headers: &hyper::HeaderMap, header: &str) -> Ve
#[tokio::test]
async fn ws_subscribe_with_bad_params() {
init_logger();
let (server_addr, _handle) = websocket_server_with_subscription().await;
let server_url = format!("ws://{}", server_addr);
let client = WsClientBuilder::default().build(&server_url).await.unwrap();
......@@ -758,6 +854,8 @@ async fn ws_subscribe_with_bad_params() {
async fn http_health_api_works() {
use hyper::{Body, Client, Request};
init_logger();
let (server_addr, _handle) = http_server().await;
let http_client = Client::new();
......@@ -777,6 +875,8 @@ async fn http_health_api_works() {
async fn ws_host_filtering_wildcard_works() {
use jsonrpsee::ws_server::*;
init_logger();
let acl = AccessControlBuilder::default()
.set_allowed_hosts(vec!["http://localhost:*", "http://127.0.0.1:*"])
.unwrap()
......@@ -799,6 +899,8 @@ async fn ws_host_filtering_wildcard_works() {
async fn http_host_filtering_wildcard_works() {
use jsonrpsee::http_server::*;
init_logger();
let acl = AccessControlBuilder::default()
.set_allowed_hosts(vec!["http://localhost:*", "http://127.0.0.1:*"])
.unwrap()
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment