Unverified Commit 32d29259 authored by Niklas Adolfsson's avatar Niklas Adolfsson Committed by GitHub
Browse files

clients: request ID as RAII guard (#543)

* request ID as RAII guard

* clippify

* fmt

* address grumbles: naming

`RequestIdGuard` -> `RequestIdManager`
`RequestId` -> `RequestIdGuard`
parent a8796c61
......@@ -28,7 +28,7 @@ use crate::transport::HttpTransportClient;
use crate::types::{
traits::Client,
v2::{Id, NotificationSer, ParamsSer, RequestSer, Response, RpcError},
CertificateStore, Error, RequestIdGuard, TEN_MB_SIZE_BYTES,
CertificateStore, Error, RequestIdManager, TEN_MB_SIZE_BYTES,
};
use async_trait::async_trait;
use fnv::FnvHashMap;
......@@ -75,7 +75,7 @@ impl HttpClientBuilder {
.map_err(|e| Error::Transport(e.into()))?;
Ok(HttpClient {
transport,
id_guard: RequestIdGuard::new(self.max_concurrent_requests),
id_manager: RequestIdManager::new(self.max_concurrent_requests),
request_timeout: self.request_timeout,
})
}
......@@ -100,7 +100,7 @@ pub struct HttpClient {
/// Request timeout. Defaults to 60sec.
request_timeout: Duration,
/// Request ID manager.
id_guard: RequestIdGuard,
id_manager: RequestIdManager,
}
#[async_trait]
......@@ -120,27 +120,20 @@ impl Client for HttpClient {
where
R: DeserializeOwned,
{
// NOTE: the IDs wrap on overflow which is intended.
let id = self.id_guard.next_request_id()?;
let request = RequestSer::new(Id::Number(id), method, params);
let fut = self.transport.send_and_read_body(serde_json::to_string(&request).map_err(|e| {
self.id_guard.reclaim_request_id();
Error::ParseError(e)
})?);
let id = self.id_manager.next_request_id()?;
let request = RequestSer::new(Id::Number(*id.inner()), method, params);
let fut = self.transport.send_and_read_body(serde_json::to_string(&request).map_err(Error::ParseError)?);
let body = match tokio::time::timeout(self.request_timeout, fut).await {
Ok(Ok(body)) => body,
Err(_e) => {
self.id_guard.reclaim_request_id();
return Err(Error::RequestTimeout);
}
Ok(Err(e)) => {
self.id_guard.reclaim_request_id();
return Err(Error::Transport(e.into()));
}
};
self.id_guard.reclaim_request_id();
let response: Response<_> = match serde_json::from_slice(&body) {
Ok(response) => response,
Err(_) => {
......@@ -151,7 +144,7 @@ impl Client for HttpClient {
let response_id = response.id.as_number().copied().ok_or(Error::InvalidRequestId)?;
if response_id == id {
if response_id == *id.inner() {
Ok(response.result)
} else {
Err(Error::InvalidRequestId)
......@@ -167,17 +160,14 @@ impl Client for HttpClient {
let mut ordered_requests = Vec::with_capacity(batch.len());
let mut request_set = FnvHashMap::with_capacity_and_hasher(batch.len(), Default::default());
let ids = self.id_guard.next_request_ids(batch.len())?;
let ids = self.id_manager.next_request_ids(batch.len())?;
for (pos, (method, params)) in batch.into_iter().enumerate() {
batch_request.push(RequestSer::new(Id::Number(ids[pos]), method, params));
ordered_requests.push(ids[pos]);
request_set.insert(ids[pos], pos);
batch_request.push(RequestSer::new(Id::Number(ids.inner()[pos]), method, params));
ordered_requests.push(ids.inner()[pos]);
request_set.insert(ids.inner()[pos], pos);
}
let fut = self.transport.send_and_read_body(serde_json::to_string(&batch_request).map_err(|e| {
self.id_guard.reclaim_request_id();
Error::ParseError(e)
})?);
let fut = self.transport.send_and_read_body(serde_json::to_string(&batch_request).map_err(Error::ParseError)?);
let body = match tokio::time::timeout(self.request_timeout, fut).await {
Ok(Ok(body)) => body,
......@@ -185,16 +175,11 @@ impl Client for HttpClient {
Ok(Err(e)) => return Err(Error::Transport(e.into())),
};
let rps: Vec<Response<_>> = match serde_json::from_slice(&body) {
Ok(response) => response,
Err(_) => {
let err: RpcError = serde_json::from_slice(&body).map_err(|e| {
self.id_guard.reclaim_request_id();
Error::ParseError(e)
})?;
return Err(Error::Request(err.to_string()));
}
};
let rps: Vec<Response<_>> =
serde_json::from_slice(&body).map_err(|_| match serde_json::from_slice::<RpcError>(&body) {
Ok(e) => Error::Request(e.to_string()),
Err(e) => Error::ParseError(e),
})?;
// NOTE: `R::default` is placeholder and will be replaced in loop below.
let mut responses = vec![R::default(); ordered_requests.len()];
......
......@@ -30,7 +30,8 @@ use futures_channel::{mpsc, oneshot};
use futures_util::{future::FutureExt, sink::SinkExt, stream::StreamExt};
use serde::{de::DeserializeOwned, Deserialize, Serialize};
use serde_json::Value as JsonValue;
use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
/// Subscription kind
#[derive(Debug)]
......@@ -188,65 +189,83 @@ impl<Notif> Drop for Subscription<Notif> {
#[derive(Debug)]
/// Keep track of request IDs.
pub struct RequestIdGuard {
pub struct RequestIdManager {
// Current pending requests.
current_pending: AtomicUsize,
current_pending: Arc<()>,
/// Max concurrent pending requests allowed.
max_concurrent_requests: usize,
/// Get the next request ID.
current_id: AtomicU64,
}
impl RequestIdGuard {
impl RequestIdManager {
/// Create a new `RequestIdGuard` with the provided concurrency limit.
pub fn new(limit: usize) -> Self {
Self { current_pending: AtomicUsize::new(0), max_concurrent_requests: limit, current_id: AtomicU64::new(0) }
Self { current_pending: Arc::new(()), max_concurrent_requests: limit, current_id: AtomicU64::new(0) }
}
fn get_slot(&self) -> Result<(), Error> {
self.current_pending
.fetch_update(Ordering::SeqCst, Ordering::SeqCst, |val| {
if val >= self.max_concurrent_requests {
None
} else {
Some(val + 1)
}
})
.map(|_| ())
.map_err(|_| Error::MaxSlotsExceeded)
fn get_slot(&self) -> Result<Arc<()>, Error> {
// Strong count is 1 at start, so that's why we use `>` and not `>=`.
if Arc::strong_count(&self.current_pending) > self.max_concurrent_requests {
Err(Error::MaxSlotsExceeded)
} else {
Ok(self.current_pending.clone())
}
}
/// Attempts to get the next request ID.
///
/// Fails if request limit has been exceeded.
pub fn next_request_id(&self) -> Result<u64, Error> {
self.get_slot()?;
pub fn next_request_id(&self) -> Result<RequestIdGuard<u64>, Error> {
let rc = self.get_slot()?;
let id = self.current_id.fetch_add(1, Ordering::SeqCst);
Ok(id)
Ok(RequestIdGuard { _rc: rc, id })
}
/// Attempts to get the `n` number next IDs that only counts as one request.
///
/// Fails if request limit has been exceeded.
pub fn next_request_ids(&self, len: usize) -> Result<Vec<u64>, Error> {
self.get_slot()?;
let mut batch = Vec::with_capacity(len);
pub fn next_request_ids(&self, len: usize) -> Result<RequestIdGuard<Vec<u64>>, Error> {
let rc = self.get_slot()?;
let mut ids = Vec::with_capacity(len);
for _ in 0..len {
batch.push(self.current_id.fetch_add(1, Ordering::SeqCst));
ids.push(self.current_id.fetch_add(1, Ordering::SeqCst));
}
Ok(batch)
Ok(RequestIdGuard { _rc: rc, id: ids })
}
}
/// Reference counted request ID.
#[derive(Debug)]
pub struct RequestIdGuard<T> {
id: T,
/// Reference count decreased when dropped.
_rc: Arc<()>,
}
impl<T> RequestIdGuard<T> {
/// Get the actual ID.
pub fn inner(&self) -> &T {
&self.id
}
}
#[cfg(test)]
mod tests {
use super::RequestIdManager;
#[test]
fn request_id_guard_works() {
let manager = RequestIdManager::new(2);
let _first = manager.next_request_id().unwrap();
{
let _second = manager.next_request_ids(13).unwrap();
assert!(manager.next_request_id().is_err());
// second dropped here.
}
/// Decrease the currently pending counter by one (saturated at 0).
pub fn reclaim_request_id(&self) {
// NOTE we ignore the error here, since we are simply saturating at 0
let _ = self.current_pending.fetch_update(Ordering::SeqCst, Ordering::SeqCst, |val| {
if val > 0 {
Some(val - 1)
} else {
None
}
});
assert!(manager.next_request_id().is_ok());
}
}
......
......@@ -28,7 +28,7 @@ use crate::transport::{Receiver as WsReceiver, Sender as WsSender, WsHandshakeEr
use crate::types::{
traits::{Client, SubscriptionClient},
v2::{Id, Notification, NotificationSer, ParamsSer, RequestSer, Response, RpcError, SubscriptionResponse},
BatchMessage, CertificateStore, Error, FrontToBack, RegisterNotificationMessage, RequestIdGuard, RequestMessage,
BatchMessage, CertificateStore, Error, FrontToBack, RegisterNotificationMessage, RequestIdManager, RequestMessage,
Subscription, SubscriptionKind, SubscriptionMessage, TEN_MB_SIZE_BYTES,
};
use crate::{
......@@ -98,7 +98,7 @@ pub struct WsClient {
/// Request timeout. Defaults to 60sec.
request_timeout: Duration,
/// Request ID manager.
id_guard: RequestIdGuard,
id_manager: RequestIdManager,
}
/// Builder for [`WsClient`].
......@@ -242,7 +242,7 @@ impl<'a> WsClientBuilder<'a> {
to_back,
request_timeout,
error: Mutex::new(ErrorFromBack::Unread(err_rx)),
id_guard: RequestIdGuard::new(max_concurrent_requests),
id_manager: RequestIdManager::new(max_concurrent_requests),
})
}
}
......@@ -273,12 +273,9 @@ impl Drop for WsClient {
impl Client for WsClient {
async fn notification<'a>(&self, method: &'a str, params: Option<ParamsSer<'a>>) -> Result<(), Error> {
// NOTE: we use this to guard against max number of concurrent requests.
let _req_id = self.id_guard.next_request_id()?;
let _req_id = self.id_manager.next_request_id()?;
let notif = NotificationSer::new(method, params);
let raw = serde_json::to_string(&notif).map_err(|e| {
self.id_guard.reclaim_request_id();
Error::ParseError(e)
})?;
let raw = serde_json::to_string(&notif).map_err(Error::ParseError)?;
tracing::trace!("[frontend]: send notification: {:?}", raw);
let mut sender = self.to_back.clone();
......@@ -291,7 +288,6 @@ impl Client for WsClient {
_ = timeout => return Err(Error::RequestTimeout)
};
self.id_guard.reclaim_request_id();
match res {
Ok(()) => Ok(()),
Err(_) => Err(self.read_error_from_backend().await),
......@@ -303,27 +299,22 @@ impl Client for WsClient {
R: DeserializeOwned,
{
let (send_back_tx, send_back_rx) = oneshot::channel();
let req_id = self.id_guard.next_request_id()?;
let raw = serde_json::to_string(&RequestSer::new(Id::Number(req_id), method, params)).map_err(|e| {
self.id_guard.reclaim_request_id();
Error::ParseError(e)
})?;
let req_id = self.id_manager.next_request_id()?;
let id = *req_id.inner();
let raw = serde_json::to_string(&RequestSer::new(Id::Number(id), method, params)).map_err(Error::ParseError)?;
tracing::trace!("[frontend]: send request: {:?}", raw);
if self
.to_back
.clone()
.send(FrontToBack::Request(RequestMessage { raw, id: req_id, send_back: Some(send_back_tx) }))
.send(FrontToBack::Request(RequestMessage { raw, id, send_back: Some(send_back_tx) }))
.await
.is_err()
{
self.id_guard.reclaim_request_id();
return Err(self.read_error_from_backend().await);
}
let res = call_with_timeout(self.request_timeout, send_back_rx).await;
self.id_guard.reclaim_request_id();
let json_value = match res {
Ok(Ok(v)) => v,
Ok(Err(err)) => return Err(err),
......@@ -336,34 +327,28 @@ impl Client for WsClient {
where
R: DeserializeOwned + Default + Clone,
{
let batch_ids = self.id_guard.next_request_ids(batch.len())?;
let batch_ids = self.id_manager.next_request_ids(batch.len())?;
let mut batches = Vec::with_capacity(batch.len());
for (idx, (method, params)) in batch.into_iter().enumerate() {
batches.push(RequestSer::new(Id::Number(batch_ids[idx]), method, params));
batches.push(RequestSer::new(Id::Number(batch_ids.inner()[idx]), method, params));
}
let (send_back_tx, send_back_rx) = oneshot::channel();
let raw = serde_json::to_string(&batches).map_err(|e| {
self.id_guard.reclaim_request_id();
Error::ParseError(e)
})?;
let raw = serde_json::to_string(&batches).map_err(Error::ParseError)?;
tracing::trace!("[frontend]: send batch request: {:?}", raw);
if self
.to_back
.clone()
.send(FrontToBack::Batch(BatchMessage { raw, ids: batch_ids, send_back: send_back_tx }))
.send(FrontToBack::Batch(BatchMessage { raw, ids: batch_ids.inner().clone(), send_back: send_back_tx }))
.await
.is_err()
{
self.id_guard.reclaim_request_id();
return Err(self.read_error_from_backend().await);
}
let res = call_with_timeout(self.request_timeout, send_back_rx).await;
self.id_guard.reclaim_request_id();
let json_values = match res {
Ok(Ok(v)) => v,
Ok(Err(err)) => return Err(err),
......@@ -397,12 +382,9 @@ impl SubscriptionClient for WsClient {
return Err(Error::SubscriptionNameConflict(unsubscribe_method.to_owned()));
}
let ids = self.id_guard.next_request_ids(2)?;
let raw =
serde_json::to_string(&RequestSer::new(Id::Number(ids[0]), subscribe_method, params)).map_err(|e| {
self.id_guard.reclaim_request_id();
Error::ParseError(e)
})?;
let ids = self.id_manager.next_request_ids(2)?;
let raw = serde_json::to_string(&RequestSer::new(Id::Number(ids.inner()[0]), subscribe_method, params))
.map_err(Error::ParseError)?;
let (send_back_tx, send_back_rx) = oneshot::channel();
if self
......@@ -410,21 +392,19 @@ impl SubscriptionClient for WsClient {
.clone()
.send(FrontToBack::Subscribe(SubscriptionMessage {
raw,
subscribe_id: ids[0],
unsubscribe_id: ids[1],
subscribe_id: ids.inner()[0],
unsubscribe_id: ids.inner()[1],
unsubscribe_method: unsubscribe_method.to_owned(),
send_back: send_back_tx,
}))
.await
.is_err()
{
self.id_guard.reclaim_request_id();
return Err(self.read_error_from_backend().await);
}
let res = call_with_timeout(self.request_timeout, send_back_rx).await;
self.id_guard.reclaim_request_id();
let (notifs_rx, id) = match res {
Ok(Ok(val)) => val,
Ok(Err(err)) => return Err(err),
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment