Unverified Commit 0ccfbd7e authored by Alexandru Vasile's avatar Alexandru Vasile Committed by GitHub
Browse files

Uniform API for custom headers between clients (#814)



* ws-client: Replace `httparse::Header` with `http::HeaderMap`

Signed-off-by: default avatarAlexandru Vasile <alexandru.vasile@parity.io>

* ws-client: Make headers optional

Signed-off-by: default avatarAlexandru Vasile <alexandru.vasile@parity.io>

* http-client: Expose custom header injection

Signed-off-by: default avatarAlexandru Vasile <alexandru.vasile@parity.io>

* http-client: Adjust testing for custom headers

Signed-off-by: default avatarAlexandru Vasile <alexandru.vasile@parity.io>

* Make `http::HeaderMap` non-optional

Signed-off-by: default avatarAlexandru Vasile <alexandru.vasile@parity.io>

* http-client: Cache request headers

Signed-off-by: default avatarAlexandru Vasile <alexandru.vasile@parity.io>

* Fix doc tests

Signed-off-by: default avatarAlexandru Vasile <alexandru.vasile@parity.io>

* http-client: Use `into_iter` for headers

Signed-off-by: default avatarAlexandru Vasile <alexandru.vasile@parity.io>

* docs: Improve custom headers documentation

Signed-off-by: default avatarAlexandru Vasile <alexandru.vasile@parity.io>

* http: Use `hyper::http` instead of `http` directly

Signed-off-by: default avatarAlexandru Vasile <alexandru.vasile@parity.io>

* http-client: Adjust testing

Signed-off-by: default avatarAlexandru Vasile <alexandru.vasile@parity.io>

* Fix doc tests

Signed-off-by: default avatarAlexandru Vasile <alexandru.vasile@parity.io>

* client: Expose `http::HeaderMap` and `http::HeaderValue`

Signed-off-by: default avatarAlexandru Vasile <alexandru.vasile@parity.io>
parent a26f1fb7
Pipeline #202760 passed with stages
in 4 minutes and 46 seconds
......@@ -30,6 +30,7 @@ use std::time::Duration;
use crate::transport::HttpTransportClient;
use crate::types::{ErrorResponse, Id, NotificationSer, ParamsSer, RequestSer, Response};
use async_trait::async_trait;
use hyper::http::HeaderMap;
use jsonrpsee_core::client::{CertificateStore, ClientT, IdKind, RequestIdManager, Subscription, SubscriptionClientT};
use jsonrpsee_core::tracing::RpcTracing;
use jsonrpsee_core::{Error, TEN_MB_SIZE_BYTES};
......@@ -39,6 +40,29 @@ use serde::de::DeserializeOwned;
use tracing_futures::Instrument;
/// Http Client Builder.
///
/// # Examples
///
/// ```no_run
///
/// use jsonrpsee_http_client::{HttpClientBuilder, HeaderMap, HeaderValue};
///
/// #[tokio::main]
/// async fn main() {
/// // Build custom headers used for every submitted request.
/// let mut headers = HeaderMap::new();
/// headers.insert("Any-Header-You-Like", HeaderValue::from_static("42"));
///
/// // Build client
/// let client = HttpClientBuilder::default()
/// .set_headers(headers)
/// .build("wss://localhost:443")
/// .unwrap();
///
/// // use client....
/// }
///
/// ```
#[derive(Debug)]
pub struct HttpClientBuilder {
max_request_body_size: u32,
......@@ -47,6 +71,7 @@ pub struct HttpClientBuilder {
certificate_store: CertificateStore,
id_kind: IdKind,
max_log_length: u32,
headers: HeaderMap,
}
impl HttpClientBuilder {
......@@ -88,11 +113,24 @@ impl HttpClientBuilder {
self
}
/// Set a custom header passed to the server with every request (default is none).
///
/// The caller is responsible for checking that the headers do not conflict or are duplicated.
pub fn set_headers(mut self, headers: HeaderMap) -> Self {
self.headers = headers;
self
}
/// Build the HTTP client with target to connect to.
pub fn build(self, target: impl AsRef<str>) -> Result<HttpClient, Error> {
let transport =
HttpTransportClient::new(target, self.max_request_body_size, self.certificate_store, self.max_log_length)
.map_err(|e| Error::Transport(e.into()))?;
let transport = HttpTransportClient::new(
target,
self.max_request_body_size,
self.certificate_store,
self.max_log_length,
self.headers,
)
.map_err(|e| Error::Transport(e.into()))?;
Ok(HttpClient {
transport,
id_manager: Arc::new(RequestIdManager::new(self.max_concurrent_requests, self.id_kind)),
......@@ -110,6 +148,7 @@ impl Default for HttpClientBuilder {
certificate_store: CertificateStore::Native,
id_kind: IdKind::Number,
max_log_length: 4096,
headers: HeaderMap::new(),
}
}
}
......
......@@ -43,4 +43,5 @@ pub mod transport;
mod tests;
pub use client::{HttpClient, HttpClientBuilder};
pub use hyper::http::{HeaderMap, HeaderValue};
pub use jsonrpsee_types as types;
......@@ -7,6 +7,7 @@
// the JSON-RPC request id to a value that might have already been used.
use hyper::client::{Client, HttpConnector};
use hyper::http::{HeaderMap, HeaderValue};
use hyper::Uri;
use jsonrpsee_core::client::CertificateStore;
use jsonrpsee_core::error::GenericTransportError;
......@@ -48,6 +49,8 @@ pub struct HttpTransportClient {
///
/// Logs bigger than this limit will be truncated.
max_log_length: u32,
/// Custom headers to pass with every request.
headers: HeaderMap,
}
impl HttpTransportClient {
......@@ -57,6 +60,7 @@ impl HttpTransportClient {
max_request_body_size: u32,
cert_store: CertificateStore,
max_log_length: u32,
headers: HeaderMap,
) -> Result<Self, Error> {
let target: Uri = target.as_ref().parse().map_err(|e| Error::Url(format!("Invalid URL: {}", e)))?;
if target.port_u16().is_none() {
......@@ -90,7 +94,20 @@ impl HttpTransportClient {
return Err(Error::Url(err.into()));
}
};
Ok(Self { target, client, max_request_body_size, max_log_length })
// Cache request headers: 2 default headers, followed by user custom headers.
// Maintain order for headers in case of duplicate keys:
// https://datatracker.ietf.org/doc/html/rfc7230#section-3.2.2
let mut cached_headers = HeaderMap::with_capacity(2 + headers.len());
cached_headers.insert(hyper::header::CONTENT_TYPE, HeaderValue::from_static(CONTENT_TYPE_JSON));
cached_headers.insert(hyper::header::ACCEPT, HeaderValue::from_static(CONTENT_TYPE_JSON));
for (key, value) in headers.into_iter() {
if let Some(key) = key {
cached_headers.insert(key, value);
}
}
Ok(Self { target, client, max_request_body_size, max_log_length, headers: cached_headers })
}
async fn inner_send(&self, body: String) -> Result<hyper::Response<hyper::Body>, Error> {
......@@ -100,11 +117,9 @@ impl HttpTransportClient {
return Err(Error::RequestTooLarge);
}
let req = hyper::Request::post(&self.target)
.header(hyper::header::CONTENT_TYPE, hyper::header::HeaderValue::from_static(CONTENT_TYPE_JSON))
.header(hyper::header::ACCEPT, hyper::header::HeaderValue::from_static(CONTENT_TYPE_JSON))
.body(From::from(body))
.expect("URI and request headers are valid; qed");
let mut req = hyper::Request::post(&self.target);
req.headers_mut().map(|headers| *headers = self.headers.clone());
let req = req.body(From::from(body)).expect("URI and request headers are valid; qed");
let response = self.client.request(req).await.map_err(|e| Error::Http(Box::new(e)))?;
if response.status().is_success() {
......@@ -179,7 +194,7 @@ where
#[cfg(test)]
mod tests {
use super::{CertificateStore, Error, HttpTransportClient};
use super::*;
fn assert_target(
client: &HttpTransportClient,
......@@ -198,37 +213,50 @@ mod tests {
#[test]
fn invalid_http_url_rejected() {
let err = HttpTransportClient::new("ws://localhost:9933", 80, CertificateStore::Native, 80).unwrap_err();
let err = HttpTransportClient::new("ws://localhost:9933", 80, CertificateStore::Native, 80, HeaderMap::new())
.unwrap_err();
assert!(matches!(err, Error::Url(_)));
}
#[cfg(feature = "tls")]
#[test]
fn https_works() {
let client = HttpTransportClient::new("https://localhost:9933", 80, CertificateStore::Native, 80).unwrap();
let client =
HttpTransportClient::new("https://localhost:9933", 80, CertificateStore::Native, 80, HeaderMap::new())
.unwrap();
assert_target(&client, "localhost", "https", "/", 9933, 80);
}
#[cfg(not(feature = "tls"))]
#[test]
fn https_fails_without_tls_feature() {
let err = HttpTransportClient::new("https://localhost:9933", 80, CertificateStore::Native, 80).unwrap_err();
let err =
HttpTransportClient::new("https://localhost:9933", 80, CertificateStore::Native, 80, HeaderMap::new())
.unwrap_err();
assert!(matches!(err, Error::Url(_)));
}
#[test]
fn faulty_port() {
let err = HttpTransportClient::new("http://localhost:-43", 80, CertificateStore::Native, 80).unwrap_err();
let err = HttpTransportClient::new("http://localhost:-43", 80, CertificateStore::Native, 80, HeaderMap::new())
.unwrap_err();
assert!(matches!(err, Error::Url(_)));
let err = HttpTransportClient::new("http://localhost:-99999", 80, CertificateStore::Native, 80).unwrap_err();
let err =
HttpTransportClient::new("http://localhost:-99999", 80, CertificateStore::Native, 80, HeaderMap::new())
.unwrap_err();
assert!(matches!(err, Error::Url(_)));
}
#[test]
fn url_with_path_works() {
let client =
HttpTransportClient::new("http://localhost:9944/my-special-path", 1337, CertificateStore::Native, 80)
.unwrap();
let client = HttpTransportClient::new(
"http://localhost:9944/my-special-path",
1337,
CertificateStore::Native,
80,
HeaderMap::new(),
)
.unwrap();
assert_target(&client, "localhost", "http", "/my-special-path", 9944, 1337);
}
......@@ -239,6 +267,7 @@ mod tests {
u32::MAX,
CertificateStore::WebPki,
80,
HeaderMap::new(),
)
.unwrap();
assert_target(&client, "127.0.0.1", "http", "/my?name1=value1&name2=value2", 9999, u32::MAX);
......@@ -246,15 +275,23 @@ mod tests {
#[test]
fn url_with_fragment_is_ignored() {
let client =
HttpTransportClient::new("http://127.0.0.1:9944/my.htm#ignore", 999, CertificateStore::Native, 80).unwrap();
let client = HttpTransportClient::new(
"http://127.0.0.1:9944/my.htm#ignore",
999,
CertificateStore::Native,
80,
HeaderMap::new(),
)
.unwrap();
assert_target(&client, "127.0.0.1", "http", "/my.htm", 9944, 999);
}
#[tokio::test]
async fn request_limit_works() {
let eighty_bytes_limit = 80;
let client = HttpTransportClient::new("http://localhost:9933", 80, CertificateStore::WebPki, 99).unwrap();
let client =
HttpTransportClient::new("http://localhost:9933", 80, CertificateStore::WebPki, 99, HeaderMap::new())
.unwrap();
assert_eq!(client.max_request_body_size, eighty_bytes_limit);
let body = "a".repeat(81);
......
......@@ -42,7 +42,7 @@ use stream::EitherStream;
use thiserror::Error;
use tokio::net::TcpStream;
pub use http::{uri::InvalidUri, Uri};
pub use http::{uri::InvalidUri, HeaderMap, HeaderValue, Uri};
pub use soketto::handshake::client::Header;
/// Sending end of WebSocket transport.
......@@ -59,33 +59,32 @@ pub struct Receiver {
/// Builder for a WebSocket transport [`Sender`] and ['Receiver`] pair.
#[derive(Debug)]
pub struct WsTransportClientBuilder<'a> {
pub struct WsTransportClientBuilder {
/// What certificate store to use
pub certificate_store: CertificateStore,
/// Timeout for the connection.
pub connection_timeout: Duration,
/// Custom headers to pass during the HTTP handshake. If `None`, no
/// custom header is passed.
pub headers: Vec<Header<'a>>,
/// Custom headers to pass during the HTTP handshake.
pub headers: http::HeaderMap,
/// Max payload size
pub max_request_body_size: u32,
/// Max number of redirections.
pub max_redirections: usize,
}
impl<'a> Default for WsTransportClientBuilder<'a> {
impl Default for WsTransportClientBuilder {
fn default() -> Self {
Self {
certificate_store: CertificateStore::Native,
max_request_body_size: TEN_MB_SIZE_BYTES,
connection_timeout: Duration::from_secs(10),
headers: Vec::new(),
headers: http::HeaderMap::new(),
max_redirections: 5,
}
}
}
impl<'a> WsTransportClientBuilder<'a> {
impl WsTransportClientBuilder {
/// Set whether to use system certificates (default is native).
pub fn certificate_store(mut self, certificate_store: CertificateStore) -> Self {
self.certificate_store = certificate_store;
......@@ -107,8 +106,8 @@ impl<'a> WsTransportClientBuilder<'a> {
/// Set a custom header passed to the server during the handshake (default is none).
///
/// The caller is responsible for checking that the headers do not conflict or are duplicated.
pub fn add_header(mut self, name: &'a str, value: &'a str) -> Self {
self.headers.push(Header { name, value: value.as_bytes() });
pub fn set_headers(mut self, headers: http::HeaderMap) -> Self {
self.headers = headers;
self
}
......@@ -240,7 +239,7 @@ impl TransportReceiverT for Receiver {
}
}
impl<'a> WsTransportClientBuilder<'a> {
impl WsTransportClientBuilder {
/// Try to establish the connection.
pub async fn build(self, uri: Uri) -> Result<(Sender, Receiver), WsHandshakeError> {
let target: Target = uri.try_into()?;
......@@ -289,7 +288,12 @@ impl<'a> WsTransportClientBuilder<'a> {
&target.path_and_query,
);
client.set_headers(&self.headers);
let headers: Vec<_> = self
.headers
.iter()
.map(|(key, value)| Header { name: key.as_str(), value: value.as_bytes() })
.collect();
client.set_headers(&headers);
// Perform the initial handshake.
match client.handshake().await {
......
......@@ -13,6 +13,7 @@ documentation = "https://docs.rs/jsonrpsee-ws-client"
jsonrpsee-types = { path = "../../types", version = "0.14.0" }
jsonrpsee-client-transport = { path = "../transport", version = "0.14.0", features = ["ws"] }
jsonrpsee-core = { path = "../../core", version = "0.14.0", features = ["async-client"] }
http = "0.2.0"
[dev-dependencies]
tracing-subscriber = { version = "0.3.3", features = ["env-filter"] }
......
......@@ -40,9 +40,10 @@ mod tests;
pub use jsonrpsee_core::client::Client as WsClient;
pub use jsonrpsee_types as types;
pub use http::{HeaderMap, HeaderValue};
use std::time::Duration;
use jsonrpsee_client_transport::ws::{Header, InvalidUri, Uri, WsTransportClientBuilder};
use jsonrpsee_client_transport::ws::{InvalidUri, Uri, WsTransportClientBuilder};
use jsonrpsee_core::client::{CertificateStore, ClientBuilder, IdKind};
use jsonrpsee_core::{Error, TEN_MB_SIZE_BYTES};
......@@ -52,13 +53,17 @@ use jsonrpsee_core::{Error, TEN_MB_SIZE_BYTES};
///
/// ```no_run
///
/// use jsonrpsee_ws_client::WsClientBuilder;
/// use jsonrpsee_ws_client::{WsClientBuilder, HeaderMap, HeaderValue};
///
/// #[tokio::main]
/// async fn main() {
/// // build client
/// // Build custom headers used during the handshake process.
/// let mut headers = HeaderMap::new();
/// headers.insert("Any-Header-You-Like", HeaderValue::from_static("42"));
///
/// // Build client
/// let client = WsClientBuilder::default()
/// .add_header("Any-Header-You-Like", "42")
/// .set_headers(headers)
/// .build("wss://localhost:443")
/// .await
/// .unwrap();
......@@ -68,20 +73,20 @@ use jsonrpsee_core::{Error, TEN_MB_SIZE_BYTES};
///
/// ```
#[derive(Clone, Debug)]
pub struct WsClientBuilder<'a> {
pub struct WsClientBuilder {
certificate_store: CertificateStore,
max_request_body_size: u32,
request_timeout: Duration,
connection_timeout: Duration,
ping_interval: Option<Duration>,
headers: Vec<Header<'a>>,
headers: http::HeaderMap,
max_concurrent_requests: usize,
max_notifs_per_subscription: usize,
max_redirections: usize,
id_kind: IdKind,
}
impl<'a> Default for WsClientBuilder<'a> {
impl Default for WsClientBuilder {
fn default() -> Self {
Self {
certificate_store: CertificateStore::Native,
......@@ -89,7 +94,7 @@ impl<'a> Default for WsClientBuilder<'a> {
request_timeout: Duration::from_secs(60),
connection_timeout: Duration::from_secs(10),
ping_interval: None,
headers: Vec::new(),
headers: HeaderMap::new(),
max_concurrent_requests: 256,
max_notifs_per_subscription: 1024,
max_redirections: 5,
......@@ -98,7 +103,7 @@ impl<'a> Default for WsClientBuilder<'a> {
}
}
impl<'a> WsClientBuilder<'a> {
impl WsClientBuilder {
/// See documentation [`WsTransportClientBuilder::certificate_store`] (default is native).
pub fn certificate_store(mut self, certificate_store: CertificateStore) -> Self {
self.certificate_store = certificate_store;
......@@ -129,9 +134,9 @@ impl<'a> WsClientBuilder<'a> {
self
}
/// See documentation [`WsTransportClientBuilder::add_header`] (default is none).
pub fn add_header(mut self, name: &'a str, value: &'a str) -> Self {
self.headers.push(Header { name, value: value.as_bytes() });
/// See documentation [`WsTransportClientBuilder::set_headers`] (default is none).
pub fn set_headers(mut self, headers: http::HeaderMap) -> Self {
self.headers = headers;
self
}
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment