Unverified Commit 1fbaf45f authored by Niklas Adolfsson's avatar Niklas Adolfsson Committed by GitHub
Browse files

fix(client): adjust TransportSenderT (#852)



* fix(client): adjust TransportSenderT

This trait contains `WebSocket`-specific details and it's difficult to fix it properly with an extension trait
in the current design.

So this PR documents and marks it clearly that these methods are optional to implement, kind of ugly but better.

* fix build

* Update core/src/client/mod.rs

Co-authored-by: default avatarAlexandru Vasile <60601340+lexnv@users.noreply.github.com>

* Update core/src/client/mod.rs

Co-authored-by: default avatarAlexandru Vasile <60601340+lexnv@users.noreply.github.com>

* increase margin for failing test

* Update core/src/client/mod.rs

Co-authored-by: James Wilson's avatarJames Wilson <james@jsdw.me>

* remove optional prefix in trait

Co-authored-by: default avatarAlexandru Vasile <60601340+lexnv@users.noreply.github.com>
Co-authored-by: James Wilson's avatarJames Wilson <james@jsdw.me>
parent 80bcef02
Pipeline #207825 passed with stages
in 4 minutes and 46 seconds
...@@ -54,15 +54,6 @@ impl TransportSenderT for Sender { ...@@ -54,15 +54,6 @@ impl TransportSenderT for Sender {
self.0.send(Message::Text(msg)).await.map_err(|e| Error::WebSocket(e))?; self.0.send(Message::Text(msg)).await.map_err(|e| Error::WebSocket(e))?;
Ok(()) Ok(())
} }
async fn send_ping(&mut self) -> Result<(), Self::Error> {
tracing::trace!("send ping - not implemented for wasm");
Err(Error::NotSupported)
}
async fn close(&mut self) -> Result<(), Error> {
Ok(())
}
} }
#[async_trait(?Send)] #[async_trait(?Send)]
......
...@@ -243,102 +243,108 @@ impl Drop for Client { ...@@ -243,102 +243,108 @@ impl Drop for Client {
#[async_trait] #[async_trait]
impl ClientT for Client { impl ClientT for Client {
async fn notification<'a>(&self, method: &'a str, params: Option<ParamsSer<'a>>) -> Result<(), Error> { async fn notification<'a>(&self, method: &'a str, params: Option<ParamsSer<'a>>) -> Result<(), Error> {
// NOTE: we use this to guard against max number of concurrent requests. // NOTE: we use this to guard against max number of concurrent requests.
let _req_id = self.id_manager.next_request_id()?; let _req_id = self.id_manager.next_request_id()?;
let notif = NotificationSer::new(method, params); let notif = NotificationSer::new(method, params);
let trace = RpcTracing::batch(); let trace = RpcTracing::batch();
async { async {
let raw = serde_json::to_string(&notif).map_err(Error::ParseError)?; let raw = serde_json::to_string(&notif).map_err(Error::ParseError)?;
tx_log_from_str(&raw, self.max_log_length); tx_log_from_str(&raw, self.max_log_length);
let mut sender = self.to_back.clone(); let mut sender = self.to_back.clone();
let fut = sender.send(FrontToBack::Notification(raw)); let fut = sender.send(FrontToBack::Notification(raw));
match future::select(fut, Delay::new(self.request_timeout)).await { match future::select(fut, Delay::new(self.request_timeout)).await {
Either::Left((Ok(()), _)) => Ok(()), Either::Left((Ok(()), _)) => Ok(()),
Either::Left((Err(_), _)) => Err(self.read_error_from_backend().await), Either::Left((Err(_), _)) => Err(self.read_error_from_backend().await),
Either::Right((_, _)) => Err(Error::RequestTimeout), Either::Right((_, _)) => Err(Error::RequestTimeout),
} }
}.instrument(trace.into_span()).await }
} .instrument(trace.into_span())
.await
}
async fn request<'a, R>(&self, method: &'a str, params: Option<ParamsSer<'a>>) -> Result<R, Error> async fn request<'a, R>(&self, method: &'a str, params: Option<ParamsSer<'a>>) -> Result<R, Error>
where where
R: DeserializeOwned, R: DeserializeOwned,
{ {
let (send_back_tx, send_back_rx) = oneshot::channel(); let (send_back_tx, send_back_rx) = oneshot::channel();
let guard = self.id_manager.next_request_id()?; let guard = self.id_manager.next_request_id()?;
let id = guard.inner(); let id = guard.inner();
let trace = RpcTracing::method_call(method); let trace = RpcTracing::method_call(method);
async { async {
let raw = serde_json::to_string(&RequestSer::new(&id, method, params)).map_err(Error::ParseError)?; let raw = serde_json::to_string(&RequestSer::new(&id, method, params)).map_err(Error::ParseError)?;
tx_log_from_str(&raw, self.max_log_length); tx_log_from_str(&raw, self.max_log_length);
if self if self
.to_back .to_back
.clone() .clone()
.send(FrontToBack::Request(RequestMessage { raw, id: id.clone(), send_back: Some(send_back_tx) })) .send(FrontToBack::Request(RequestMessage { raw, id: id.clone(), send_back: Some(send_back_tx) }))
.await .await
.is_err() .is_err()
{ {
return Err(self.read_error_from_backend().await); return Err(self.read_error_from_backend().await);
} }
let res = call_with_timeout(self.request_timeout, send_back_rx).await; let res = call_with_timeout(self.request_timeout, send_back_rx).await;
let json_value = match res { let json_value = match res {
Ok(Ok(v)) => v, Ok(Ok(v)) => v,
Ok(Err(err)) => return Err(err), Ok(Err(err)) => return Err(err),
Err(_) => return Err(self.read_error_from_backend().await), Err(_) => return Err(self.read_error_from_backend().await),
}; };
rx_log_from_json(&Response::new(&json_value, id), self.max_log_length); rx_log_from_json(&Response::new(&json_value, id), self.max_log_length);
serde_json::from_value(json_value).map_err(Error::ParseError) serde_json::from_value(json_value).map_err(Error::ParseError)
}.instrument(trace.into_span()).await }
} .instrument(trace.into_span())
.await
}
async fn batch_request<'a, R>(&self, batch: Vec<(&'a str, Option<ParamsSer<'a>>)>) -> Result<Vec<R>, Error> async fn batch_request<'a, R>(&self, batch: Vec<(&'a str, Option<ParamsSer<'a>>)>) -> Result<Vec<R>, Error>
where where
R: DeserializeOwned + Default + Clone, R: DeserializeOwned + Default + Clone,
{ {
let trace = RpcTracing::batch(); let trace = RpcTracing::batch();
async { async {
let guard = self.id_manager.next_request_ids(batch.len())?; let guard = self.id_manager.next_request_ids(batch.len())?;
let batch_ids: Vec<Id> = guard.inner(); let batch_ids: Vec<Id> = guard.inner();
let mut batches = Vec::with_capacity(batch.len()); let mut batches = Vec::with_capacity(batch.len());
for (idx, (method, params)) in batch.into_iter().enumerate() { for (idx, (method, params)) in batch.into_iter().enumerate() {
batches.push(RequestSer::new(&batch_ids[idx], method, params)); batches.push(RequestSer::new(&batch_ids[idx], method, params));
} }
let (send_back_tx, send_back_rx) = oneshot::channel(); let (send_back_tx, send_back_rx) = oneshot::channel();
let raw = serde_json::to_string(&batches).map_err(Error::ParseError)?; let raw = serde_json::to_string(&batches).map_err(Error::ParseError)?;
tx_log_from_str(&raw, self.max_log_length); tx_log_from_str(&raw, self.max_log_length);
if self if self
.to_back .to_back
.clone() .clone()
.send(FrontToBack::Batch(BatchMessage { raw, ids: batch_ids, send_back: send_back_tx })) .send(FrontToBack::Batch(BatchMessage { raw, ids: batch_ids, send_back: send_back_tx }))
.await .await
.is_err() .is_err()
{ {
return Err(self.read_error_from_backend().await); return Err(self.read_error_from_backend().await);
} }
let res = call_with_timeout(self.request_timeout, send_back_rx).await; let res = call_with_timeout(self.request_timeout, send_back_rx).await;
let json_values = match res { let json_values = match res {
Ok(Ok(v)) => v, Ok(Ok(v)) => v,
Ok(Err(err)) => return Err(err), Ok(Err(err)) => return Err(err),
Err(_) => return Err(self.read_error_from_backend().await), Err(_) => return Err(self.read_error_from_backend().await),
}; };
rx_log_from_json(&json_values, self.max_log_length); rx_log_from_json(&json_values, self.max_log_length);
json_values.into_iter().map(|val| serde_json::from_value(val).map_err(Error::ParseError)).collect() json_values.into_iter().map(|val| serde_json::from_value(val).map_err(Error::ParseError)).collect()
}.instrument(trace.into_span()).await }
.instrument(trace.into_span())
.await
} }
} }
...@@ -356,52 +362,55 @@ impl SubscriptionClientT for Client { ...@@ -356,52 +362,55 @@ impl SubscriptionClientT for Client {
) -> Result<Subscription<N>, Error> ) -> Result<Subscription<N>, Error>
where where
N: DeserializeOwned, N: DeserializeOwned,
{ {
if subscribe_method == unsubscribe_method { if subscribe_method == unsubscribe_method {
return Err(Error::SubscriptionNameConflict(unsubscribe_method.to_owned())); return Err(Error::SubscriptionNameConflict(unsubscribe_method.to_owned()));
} }
let guard = self.id_manager.next_request_ids(2)?; let guard = self.id_manager.next_request_ids(2)?;
let mut ids: Vec<Id> = guard.inner(); let mut ids: Vec<Id> = guard.inner();
let trace = RpcTracing::method_call(subscribe_method); let trace = RpcTracing::method_call(subscribe_method);
async { async {
let id = ids[0].clone(); let id = ids[0].clone();
let raw = serde_json::to_string(&RequestSer::new(&id, subscribe_method, params)).map_err(Error::ParseError)?; let raw =
serde_json::to_string(&RequestSer::new(&id, subscribe_method, params)).map_err(Error::ParseError)?;
tx_log_from_str(&raw, self.max_log_length);
tx_log_from_str(&raw, self.max_log_length);
let (send_back_tx, send_back_rx) = oneshot::channel();
if self let (send_back_tx, send_back_rx) = oneshot::channel();
.to_back if self
.clone() .to_back
.send(FrontToBack::Subscribe(SubscriptionMessage { .clone()
raw, .send(FrontToBack::Subscribe(SubscriptionMessage {
subscribe_id: ids.swap_remove(0), raw,
unsubscribe_id: ids.swap_remove(0), subscribe_id: ids.swap_remove(0),
unsubscribe_method: unsubscribe_method.to_owned(), unsubscribe_id: ids.swap_remove(0),
send_back: send_back_tx, unsubscribe_method: unsubscribe_method.to_owned(),
})) send_back: send_back_tx,
.await }))
.is_err() .await
{ .is_err()
return Err(self.read_error_from_backend().await); {
} return Err(self.read_error_from_backend().await);
}
let res = call_with_timeout(self.request_timeout, send_back_rx).await;
let res = call_with_timeout(self.request_timeout, send_back_rx).await;
let (notifs_rx, sub_id) = match res {
Ok(Ok(val)) => val, let (notifs_rx, sub_id) = match res {
Ok(Err(err)) => return Err(err), Ok(Ok(val)) => val,
Err(_) => return Err(self.read_error_from_backend().await), Ok(Err(err)) => return Err(err),
}; Err(_) => return Err(self.read_error_from_backend().await),
};
rx_log_from_json(&Response::new(&sub_id, id), self.max_log_length);
rx_log_from_json(&Response::new(&sub_id, id), self.max_log_length);
Ok(Subscription::new(self.to_back.clone(), notifs_rx, SubscriptionKind::Subscription(sub_id)))
}.instrument(trace.into_span()).await Ok(Subscription::new(self.to_back.clone(), notifs_rx, SubscriptionKind::Subscription(sub_id)))
} }
.instrument(trace.into_span())
.await
}
/// Subscribe to a specific method. /// Subscribe to a specific method.
async fn subscribe_to_method<'a, N>(&self, method: &'a str) -> Result<Subscription<N>, Error> async fn subscribe_to_method<'a, N>(&self, method: &'a str) -> Result<Subscription<N>, Error>
......
...@@ -132,10 +132,18 @@ pub trait TransportSenderT: MaybeSend + 'static { ...@@ -132,10 +132,18 @@ pub trait TransportSenderT: MaybeSend + 'static {
/// Send. /// Send.
async fn send(&mut self, msg: String) -> Result<(), Self::Error>; async fn send(&mut self, msg: String) -> Result<(), Self::Error>;
/// This is optional because it's most likely relevant for WebSocket transports only.
/// You should only implement this if your transport supports sending periodic pings.
///
/// Send ping frame (opcode of 0x9). /// Send ping frame (opcode of 0x9).
async fn send_ping(&mut self) -> Result<(), Self::Error>; async fn send_ping(&mut self) -> Result<(), Self::Error> {
Ok(())
}
/// If the transport supports sending customized close messages. /// This is optional because it's most likely relevant for WebSocket transports only.
/// You should only implement this if your transport supports being closed.
///
/// Send customized close message.
async fn close(&mut self) -> Result<(), Self::Error> { async fn close(&mut self) -> Result<(), Self::Error> {
Ok(()) Ok(())
} }
......
...@@ -314,8 +314,8 @@ async fn multiple_blocking_calls_overlap() { ...@@ -314,8 +314,8 @@ async fn multiple_blocking_calls_overlap() {
assert_eq!(result.unwrap(), 42); assert_eq!(result.unwrap(), 42);
} }
// Each request takes 50ms, added 10ms margin for scheduling // Each request takes 50ms, added 50ms margin for scheduling
assert!(elapsed < Duration::from_millis(60), "Expected less than 60ms, got {:?}", elapsed); assert!(elapsed < Duration::from_millis(100), "Expected less than 100ms, got {:?}", elapsed);
} }
#[tokio::test] #[tokio::test]
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment