Unverified commit 5ea8cbe8 authored by David, committed by GitHub

Share the request id code between the http and websocket clients (#490)

* Move RequestIdGuard to types::client

* Use the RequestIdGuard type for the http client as well

* Fix batch request ID handling (one batch uses one slot)

* Add a few tests to check that max_concurrent_requests works for http clients

* Docs
parent 5ae280a3
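
Before the diff, a minimal usage sketch of the new knob. The crate names, URL, and RPC method are placeholder assumptions; the builder method, its default of 256, and `Error::MaxSlotsExceeded` come from the diff below:

```rust
// Sketch only: import paths are assumptions, not taken from this diff.
use jsonrpsee_http_client::HttpClientBuilder;
use jsonrpsee_types::{traits::Client, v2::ParamsSer, Error};

#[tokio::main]
async fn main() -> Result<(), Error> {
	// Cap in-flight requests at 2 (the builder defaults to 256 per this diff).
	let client = HttpClientBuilder::default()
		.max_concurrent_requests(2)
		.build("http://127.0.0.1:8080")?; // placeholder URL

	// A third concurrent call now fails fast with Error::MaxSlotsExceeded
	// instead of piling up unbounded in-flight requests.
	let hello: String = client.request("say_hello", ParamsSer::NoParams).await?;
	println!("{}", hello);
	Ok(())
}
```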
@@ -28,12 +28,11 @@ use crate::transport::HttpTransportClient;
use crate::types::{
traits::Client,
v2::{Id, NotificationSer, ParamsSer, RequestSer, Response, RpcError},
- Error, TEN_MB_SIZE_BYTES,
+ Error, RequestIdGuard, TEN_MB_SIZE_BYTES,
};
use async_trait::async_trait;
use fnv::FnvHashMap;
use serde::de::DeserializeOwned;
- use std::sync::atomic::{AtomicU64, Ordering};
use std::time::Duration;
/// Http Client Builder.
@@ -41,6 +40,7 @@ use std::time::Duration;
pub struct HttpClientBuilder {
max_request_body_size: u32,
request_timeout: Duration,
+ max_concurrent_requests: usize,
}
impl HttpClientBuilder {
@@ -56,17 +56,31 @@ impl HttpClientBuilder {
self
}
+ /// Set max concurrent requests.
+ pub fn max_concurrent_requests(mut self, max: usize) -> Self {
+ self.max_concurrent_requests = max;
+ self
+ }
/// Build the HTTP client with target to connect to.
pub fn build(self, target: impl AsRef<str>) -> Result<HttpClient, Error> {
let transport =
HttpTransportClient::new(target, self.max_request_body_size).map_err(|e| Error::Transport(e.into()))?;
- Ok(HttpClient { transport, request_id: AtomicU64::new(0), request_timeout: self.request_timeout })
+ Ok(HttpClient {
+ transport,
+ id_guard: RequestIdGuard::new(self.max_concurrent_requests),
+ request_timeout: self.request_timeout,
+ })
}
}
impl Default for HttpClientBuilder {
fn default() -> Self {
- Self { max_request_body_size: TEN_MB_SIZE_BYTES, request_timeout: Duration::from_secs(60) }
+ Self {
+ max_request_body_size: TEN_MB_SIZE_BYTES,
+ request_timeout: Duration::from_secs(60),
+ max_concurrent_requests: 256,
+ }
}
}
@@ -75,10 +89,10 @@ impl Default for HttpClientBuilder {
pub struct HttpClient {
/// HTTP transport client.
transport: HttpTransportClient,
- /// Request ID that wraps around when overflowing.
- request_id: AtomicU64,
/// Request timeout. Defaults to 60sec.
request_timeout: Duration,
+ /// Request ID manager.
+ id_guard: RequestIdGuard,
}
#[async_trait]
@@ -98,17 +112,27 @@ impl Client for HttpClient {
where
R: DeserializeOwned,
{
- // NOTE: `fetch_add` wraps on overflow which is intended.
- let id = self.request_id.fetch_add(1, Ordering::SeqCst);
+ // NOTE: the IDs wrap on overflow which is intended.
+ let id = self.id_guard.next_request_id()?;
let request = RequestSer::new(Id::Number(id), method, params);
- let fut = self.transport.send_and_read_body(serde_json::to_string(&request).map_err(Error::ParseError)?);
+ let fut = self.transport.send_and_read_body(serde_json::to_string(&request).map_err(|e| {
+ self.id_guard.reclaim_request_id();
+ Error::ParseError(e)
+ })?);
let body = match tokio::time::timeout(self.request_timeout, fut).await {
Ok(Ok(body)) => body,
- Err(_e) => return Err(Error::RequestTimeout),
- Ok(Err(e)) => return Err(Error::Transport(e.into())),
+ Err(_e) => {
+ self.id_guard.reclaim_request_id();
+ return Err(Error::RequestTimeout);
+ }
+ Ok(Err(e)) => {
+ self.id_guard.reclaim_request_id();
+ return Err(Error::Transport(e.into()));
+ }
};
+ self.id_guard.reclaim_request_id();
let response: Response<_> = match serde_json::from_slice(&body) {
Ok(response) => response,
Err(_) => {
@@ -135,14 +159,17 @@
let mut ordered_requests = Vec::with_capacity(batch.len());
let mut request_set = FnvHashMap::with_capacity_and_hasher(batch.len(), Default::default());
+ let ids = self.id_guard.next_request_ids(batch.len())?;
for (pos, (method, params)) in batch.into_iter().enumerate() {
- let id = self.request_id.fetch_add(1, Ordering::SeqCst);
- batch_request.push(RequestSer::new(Id::Number(id), method, params));
- ordered_requests.push(id);
- request_set.insert(id, pos);
+ batch_request.push(RequestSer::new(Id::Number(ids[pos]), method, params));
+ ordered_requests.push(ids[pos]);
+ request_set.insert(ids[pos], pos);
}
- let fut = self.transport.send_and_read_body(serde_json::to_string(&batch_request).map_err(Error::ParseError)?);
+ let fut = self.transport.send_and_read_body(serde_json::to_string(&batch_request).map_err(|e| {
+ self.id_guard.reclaim_request_id();
+ Error::ParseError(e)
+ })?);
let body = match tokio::time::timeout(self.request_timeout, fut).await {
Ok(Ok(body)) => body,
@@ -153,7 +180,10 @@
let rps: Vec<Response<_>> = match serde_json::from_slice(&body) {
Ok(response) => response,
Err(_) => {
- let err: RpcError = serde_json::from_slice(&body).map_err(Error::ParseError)?;
+ let err: RpcError = serde_json::from_slice(&body).map_err(|e| {
+ self.id_guard.reclaim_request_id();
+ Error::ParseError(e)
+ })?;
return Err(Error::Request(err.to_string()));
}
};
......
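
A design note on the hunks above: the slot is reserved before the request is sent, so every early-return path (serialization failure, timeout, transport error) must call `reclaim_request_id()` by hand. For contrast, here is a hypothetical sketch, not what this patch does, of an RAII guard that releases the slot in `Drop`, which would also cover a future being cancelled:

```rust
// Hypothetical alternative, NOT part of this patch: reclaim the slot in Drop
// so no exit path (or cancelled future) can leak it.
use std::sync::{
	atomic::{AtomicUsize, Ordering},
	Arc,
};

struct Slots {
	pending: AtomicUsize,
	max: usize,
}

struct SlotGuard(Arc<Slots>);

fn acquire(slots: &Arc<Slots>) -> Option<SlotGuard> {
	slots
		.pending
		.fetch_update(Ordering::SeqCst, Ordering::SeqCst, |n| {
			if n < slots.max {
				Some(n + 1)
			} else {
				None
			}
		})
		.ok()
		.map(|_| SlotGuard(slots.clone()))
}

impl Drop for SlotGuard {
	fn drop(&mut self) {
		// Saturating decrement, mirroring `reclaim_request_id` below.
		let _ = self.0.pending.fetch_update(Ordering::SeqCst, Ordering::SeqCst, |n| n.checked_sub(1));
	}
}
```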
@@ -93,6 +93,21 @@ async fn http_method_call_works() {
assert_eq!(&response, "hello");
}
+ #[tokio::test]
+ async fn http_concurrent_method_call_limits_works() {
+ let server_addr = http_server().await;
+ let uri = format!("http://{}", server_addr);
+ let client = HttpClientBuilder::default().max_concurrent_requests(1).build(&uri).unwrap();
+ let (first, second) = tokio::join!(
+ client.request::<String>("say_hello", ParamsSer::NoParams),
+ client.request::<String>("say_hello", ParamsSer::NoParams),
+ );
+ assert!(first.is_ok());
+ assert!(matches!(second, Err(Error::MaxSlotsExceeded)));
+ }
#[tokio::test]
async fn ws_subscription_several_clients() {
let (server_addr, _) = websocket_server_with_subscription().await;
@@ -187,7 +202,7 @@ async fn ws_subscription_without_polling_doesnt_make_client_unuseable() {
}
#[tokio::test]
- async fn ws_more_request_than_buffer_should_not_deadlock() {
+ async fn ws_making_more_requests_than_allowed_should_not_deadlock() {
let server_addr = websocket_server().await;
let server_url = format!("ws://{}", server_addr);
let client = Arc::new(WsClientBuilder::default().max_concurrent_requests(2).build(&server_url).await.unwrap());
@@ -204,6 +219,25 @@ async fn ws_more_request_than_buffer_should_not_deadlock() {
}
}
+ #[tokio::test]
+ async fn http_making_more_requests_than_allowed_should_not_deadlock() {
+ let server_addr = http_server().await;
+ let server_url = format!("http://{}", server_addr);
+ let client = HttpClientBuilder::default().max_concurrent_requests(2).build(&server_url).unwrap();
+ let client = Arc::new(client);
+ let mut requests = Vec::new();
+ for _ in 0..6 {
+ let c = client.clone();
+ requests.push(tokio::spawn(async move { c.request::<String>("say_hello", ParamsSer::NoParams).await }));
+ }
+ for req in requests {
+ let _ = req.await.unwrap();
+ }
+ }
#[tokio::test]
#[ignore]
async fn https_works() {
......
@@ -30,6 +30,7 @@ use futures_channel::{mpsc, oneshot};
use futures_util::{future::FutureExt, sink::SinkExt, stream::StreamExt};
use serde::{de::DeserializeOwned, Deserialize, Serialize};
use serde_json::Value as JsonValue;
+ use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};
/// Subscription kind
#[derive(Debug)]
@@ -184,3 +185,67 @@ impl<Notif> Drop for Subscription<Notif> {
let _ = self.to_back.send(msg).now_or_never();
}
}
+ /// Keep track of request IDs.
+ #[derive(Debug)]
+ pub struct RequestIdGuard {
+ /// Number of requests currently pending.
+ current_pending: AtomicUsize,
+ /// Max concurrent pending requests allowed.
+ max_concurrent_requests: usize,
+ /// The next request ID to hand out.
+ current_id: AtomicU64,
+ }
+ impl RequestIdGuard {
+ /// Create a new `RequestIdGuard` with the provided concurrency limit.
+ pub fn new(limit: usize) -> Self {
+ Self { current_pending: AtomicUsize::new(0), max_concurrent_requests: limit, current_id: AtomicU64::new(0) }
+ }
+ fn get_slot(&self) -> Result<(), Error> {
+ self.current_pending
+ .fetch_update(Ordering::SeqCst, Ordering::SeqCst, |val| {
+ if val >= self.max_concurrent_requests {
+ None
+ } else {
+ Some(val + 1)
+ }
+ })
+ .map(|_| ())
+ .map_err(|_| Error::MaxSlotsExceeded)
+ }
+ /// Attempts to get the next request ID.
+ ///
+ /// Fails if the request limit has been exceeded.
+ pub fn next_request_id(&self) -> Result<u64, Error> {
+ self.get_slot()?;
+ let id = self.current_id.fetch_add(1, Ordering::SeqCst);
+ Ok(id)
+ }
+ /// Attempts to get the next `len` request IDs; the entire batch counts as one pending request.
+ ///
+ /// Fails if the request limit has been exceeded.
+ pub fn next_request_ids(&self, len: usize) -> Result<Vec<u64>, Error> {
+ self.get_slot()?;
+ let mut batch = Vec::with_capacity(len);
+ for _ in 0..len {
+ batch.push(self.current_id.fetch_add(1, Ordering::SeqCst));
+ }
+ Ok(batch)
+ }
+ /// Decrease the pending-request counter by one (saturating at 0).
+ pub fn reclaim_request_id(&self) {
+ // NOTE: the error is ignored on purpose; the counter simply saturates at 0.
+ let _ = self.current_pending.fetch_update(Ordering::SeqCst, Ordering::SeqCst, |val| {
+ if val > 0 {
+ Some(val - 1)
+ } else {
+ None
+ }
+ });
+ }
+ }
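
The "one batch uses one slot" fix from the commit message falls out of `next_request_ids` above: a single `get_slot()` call reserves one slot regardless of how many IDs the batch needs. A small sketch using only the API defined above (the test name is made up, and `RequestIdGuard` is assumed to be in scope):

```rust
#[test]
fn one_batch_uses_one_slot() {
	// Assumes the RequestIdGuard defined above is in scope.
	let guard = RequestIdGuard::new(1);

	// A three-call batch takes three IDs but only one slot.
	let ids = guard.next_request_ids(3).unwrap();
	assert_eq!(ids, vec![0, 1, 2]);

	// The single slot is taken, so the next request is rejected...
	assert!(guard.next_request_id().is_err());

	// ...until the batch completes and its slot is reclaimed.
	guard.reclaim_request_id();
	assert_eq!(guard.next_request_id().unwrap(), 3);
}
```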
@@ -28,8 +28,8 @@ use crate::transport::{Receiver as WsReceiver, Sender as WsSender, Target, WsTransportClientBuilder};
use crate::types::{
traits::{Client, SubscriptionClient},
v2::{Id, Notification, NotificationSer, ParamsSer, RequestSer, Response, RpcError, SubscriptionResponse},
- BatchMessage, Error, FrontToBack, RegisterNotificationMessage, RequestMessage, Subscription, SubscriptionKind,
- SubscriptionMessage, TEN_MB_SIZE_BYTES,
+ BatchMessage, Error, FrontToBack, RegisterNotificationMessage, RequestIdGuard, RequestMessage, Subscription,
+ SubscriptionKind, SubscriptionMessage, TEN_MB_SIZE_BYTES,
};
use crate::{
helpers::{
@@ -49,11 +49,7 @@ use futures::{
use tokio::sync::Mutex;
use serde::de::DeserializeOwned;
- use std::{
- borrow::Cow,
- sync::atomic::{AtomicU64, AtomicUsize, Ordering},
- time::Duration,
- };
+ use std::{borrow::Cow, time::Duration};
/// Wrapper over a [`oneshot::Receiver`](futures::channel::oneshot::Receiver) that reads
/// the underlying channel once and then stores the result in String.
@@ -103,67 +99,6 @@ pub struct WsClient {
id_guard: RequestIdGuard,
}
- #[derive(Debug)]
- struct RequestIdGuard {
- // Current pending requests.
- current_pending: AtomicUsize,
- /// Max concurrent pending requests allowed.
- max_concurrent_requests: usize,
- /// Get the next request ID.
- current_id: AtomicU64,
- }
- impl RequestIdGuard {
- fn new(limit: usize) -> Self {
- Self { current_pending: AtomicUsize::new(0), max_concurrent_requests: limit, current_id: AtomicU64::new(0) }
- }
- fn get_slot(&self) -> Result<(), Error> {
- self.current_pending
- .fetch_update(Ordering::SeqCst, Ordering::SeqCst, |val| {
- if val >= self.max_concurrent_requests {
- None
- } else {
- Some(val + 1)
- }
- })
- .map(|_| ())
- .map_err(|_| Error::MaxSlotsExceeded)
- }
- /// Attempts to get the next request ID.
- ///
- /// Fails if request limit has been exceeded.
- fn next_request_id(&self) -> Result<u64, Error> {
- self.get_slot()?;
- let id = self.current_id.fetch_add(1, Ordering::SeqCst);
- Ok(id)
- }
- /// Attempts to get the `n` number next IDs that only counts as one request.
- ///
- /// Fails if request limit has been exceeded.
- fn next_request_ids(&self, len: usize) -> Result<Vec<u64>, Error> {
- self.get_slot()?;
- let mut batch = Vec::with_capacity(len);
- for _ in 0..len {
- batch.push(self.current_id.fetch_add(1, Ordering::SeqCst));
- }
- Ok(batch)
- }
- fn reclaim_request_id(&self) {
- // NOTE we ignore the error here, since we are simply saturating at 0
- let _ = self.current_pending.fetch_update(Ordering::SeqCst, Ordering::SeqCst, |val| {
- if val > 0 {
- Some(val - 1)
- } else {
- None
- }
- });
- }
- }
/// Configuration.
#[derive(Clone, Debug)]
pub struct WsClientBuilder<'a> {
......
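
Since `RequestIdGuard` now lives in `types`, the WebSocket client keeps its existing behavior and builder knob unchanged. A sketch mirroring the tests above, assuming a jsonrpsee-style import path and a placeholder address:

```rust
// Sketch only: the import path and URL are assumptions.
use jsonrpsee_ws_client::WsClientBuilder;

async fn connect() -> Result<(), Box<dyn std::error::Error>> {
	// Same knob as the HTTP builder, now backed by the shared RequestIdGuard.
	let client = WsClientBuilder::default()
		.max_concurrent_requests(2)
		.build("ws://127.0.0.1:9944") // placeholder address
		.await?;
	let _ = client;
	Ok(())
}
```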