Unverified Commit 3f1c7fcf authored by Niklas Adolfsson's avatar Niklas Adolfsson Committed by GitHub
Browse files

clients: feature gate `tls` (#545)

* clients: introduce tls feature flag

* Update tests/tests/integration_tests.rs

* fix: don't rebuild tls connector of every connect

* fix tests + remove url dep

* fix tests again
parent 3cb5eda9
......@@ -13,7 +13,7 @@ documentation = "https://docs.rs/jsonrpsee-http-client"
async-trait = "0.1"
fnv = "1"
hyper = { version = "0.14.10", features = ["client", "http1", "http2", "tcp"] }
hyper-rustls = { version = "0.23", features = ["webpki-tokio"] }
hyper-rustls = { version = "0.23", optional = true }
jsonrpsee-types = { path = "../types", version = "0.6.0" }
jsonrpsee-utils = { path = "../utils", version = "0.6.0", features = ["client", "http-helpers"] }
serde = { version = "1.0", default-features = false, features = ["derive"] }
......@@ -21,8 +21,11 @@ serde_json = "1.0"
thiserror = "1.0"
tokio = { version = "1.8", features = ["time"] }
tracing = "0.1"
url = "2.2"
[dev-dependencies]
jsonrpsee-test-utils = { path = "../test-utils" }
tokio = { version = "1.8", features = ["net", "rt-multi-thread", "macros"] }
[features]
default = ["tls"]
tls = ["hyper-rustls/webpki-tokio"]
......@@ -8,20 +8,39 @@
use crate::types::error::GenericTransportError;
use hyper::client::{Client, HttpConnector};
use hyper_rustls::{HttpsConnector, HttpsConnectorBuilder};
use hyper::Uri;
use jsonrpsee_types::CertificateStore;
use jsonrpsee_utils::http_helpers;
use thiserror::Error;
const CONTENT_TYPE_JSON: &str = "application/json";
#[derive(Debug, Clone)]
enum HyperClient {
/// Hyper client with https connector.
#[cfg(feature = "tls")]
Https(Client<hyper_rustls::HttpsConnector<HttpConnector>>),
/// Hyper client with http connector.
Http(Client<HttpConnector>),
}
impl HyperClient {
fn request(&self, req: hyper::Request<hyper::Body>) -> hyper::client::ResponseFuture {
match self {
Self::Http(client) => client.request(req),
#[cfg(feature = "tls")]
Self::Https(client) => client.request(req),
}
}
}
/// HTTP Transport Client.
#[derive(Debug, Clone)]
pub(crate) struct HttpTransportClient {
/// Target to connect to.
target: url::Url,
target: Uri,
/// HTTP client
client: Client<HttpsConnector<HttpConnector>>,
client: HyperClient,
/// Configurable max request body size
max_request_body_size: u32,
}
......@@ -33,22 +52,40 @@ impl HttpTransportClient {
max_request_body_size: u32,
cert_store: CertificateStore,
) -> Result<Self, Error> {
let target = url::Url::parse(target.as_ref()).map_err(|e| Error::Url(format!("Invalid URL: {}", e)))?;
if target.scheme() == "http" || target.scheme() == "https" {
let connector = match cert_store {
CertificateStore::Native => {
HttpsConnectorBuilder::new().with_native_roots().https_or_http().enable_http1()
}
CertificateStore::WebPki => {
HttpsConnectorBuilder::new().with_webpki_roots().https_or_http().enable_http1()
}
_ => return Err(Error::InvalidCertficateStore),
};
let client = Client::builder().build::<_, hyper::Body>(connector.build());
Ok(HttpTransportClient { target, client, max_request_body_size })
} else {
Err(Error::Url("URL scheme not supported, expects 'http' or 'https'".into()))
let target: Uri = target.as_ref().parse().map_err(|e| Error::Url(format!("Invalid URL: {}", e)))?;
if target.port_u16().is_none() {
return Err(Error::Url("Port number is missing in the URL".into()));
}
let client = match target.scheme_str() {
Some("http") => {
let connector = HttpConnector::new();
let client = Client::builder().build::<_, hyper::Body>(connector);
HyperClient::Http(client)
}
#[cfg(feature = "tls")]
Some("https") => {
let connector = match cert_store {
CertificateStore::Native => {
hyper_rustls::HttpsConnectorBuilder::new().with_native_roots().https_or_http().enable_http1()
}
CertificateStore::WebPki => {
hyper_rustls::HttpsConnectorBuilder::new().with_webpki_roots().https_or_http().enable_http1()
}
_ => return Err(Error::InvalidCertficateStore),
};
let client = Client::builder().build::<_, hyper::Body>(connector.build());
HyperClient::Https(client)
}
_ => {
#[cfg(feature = "tls")]
let err = "URL scheme not supported, expects 'http' or 'https'";
#[cfg(not(feature = "tls"))]
let err = "URL scheme not supported, expects 'http'";
return Err(Error::Url(err.into()));
}
};
Ok(Self { target, client, max_request_body_size })
}
async fn inner_send(&self, body: String) -> Result<hyper::Response<hyper::Body>, Error> {
......@@ -58,7 +95,9 @@ impl HttpTransportClient {
return Err(Error::RequestTooLarge);
}
let req = hyper::Request::post(self.target.as_str())
// NOTE(niklasad1): this annoying we could just take `&str` here but more user-friendly to check
// that the URI is well-formed in the constructor.
let req = hyper::Request::post(self.target.clone())
.header(hyper::header::CONTENT_TYPE, hyper::header::HeaderValue::from_static(CONTENT_TYPE_JSON))
.header(hyper::header::ACCEPT, hyper::header::HeaderValue::from_static(CONTENT_TYPE_JSON))
.body(From::from(body))
......@@ -135,12 +174,74 @@ where
mod tests {
use super::{CertificateStore, Error, HttpTransportClient};
fn assert_target(
client: &HttpTransportClient,
host: &str,
scheme: &str,
path_and_query: &str,
port: u16,
max_request_size: u32,
) {
assert_eq!(client.target.scheme_str(), Some(scheme));
assert_eq!(client.target.path_and_query().map(|pq| pq.as_str()), Some(path_and_query));
assert_eq!(client.target.host(), Some(host));
assert_eq!(client.target.port_u16(), Some(port));
assert_eq!(client.max_request_body_size, max_request_size);
}
#[test]
fn invalid_http_url_rejected() {
let err = HttpTransportClient::new("ws://localhost:9933", 80, CertificateStore::Native).unwrap_err();
assert!(matches!(err, Error::Url(_)));
}
#[cfg(feature = "tls")]
#[test]
fn https_works() {
let client = HttpTransportClient::new("https://localhost:9933", 80, CertificateStore::Native).unwrap();
assert_target(&client, "localhost", "https", "/", 9933, 80);
}
#[cfg(not(feature = "tls"))]
#[test]
fn https_fails_without_tls_feature() {
let err = HttpTransportClient::new("https://localhost:9933", 80, CertificateStore::Native).unwrap_err();
assert!(matches!(err, Error::Url(_)));
}
#[test]
fn faulty_port() {
let err = HttpTransportClient::new("http://localhost:-43", 80, CertificateStore::Native).unwrap_err();
assert!(matches!(err, Error::Url(_)));
let err = HttpTransportClient::new("http://localhost:-99999", 80, CertificateStore::Native).unwrap_err();
assert!(matches!(err, Error::Url(_)));
}
#[test]
fn url_with_path_works() {
let client =
HttpTransportClient::new("http://localhost:9944/my-special-path", 1337, CertificateStore::Native).unwrap();
assert_target(&client, "localhost", "http", "/my-special-path", 9944, 1337);
}
#[test]
fn url_with_query_works() {
let client = HttpTransportClient::new(
"http://127.0.0.1:9999/my?name1=value1&name2=value2",
u32::MAX,
CertificateStore::WebPki,
)
.unwrap();
assert_target(&client, "127.0.0.1", "http", "/my?name1=value1&name2=value2", 9999, u32::MAX);
}
#[test]
fn url_with_fragment_is_ignored() {
let client =
HttpTransportClient::new("http://127.0.0.1:9944/my.htm#ignore", 999, CertificateStore::Native).unwrap();
assert_target(&client, "127.0.0.1", "http", "/my.htm", 9944, 999);
}
#[tokio::test]
async fn request_limit_works() {
let eighty_bytes_limit = 80;
......
......@@ -258,8 +258,7 @@ async fn ws_with_non_ascii_url_doesnt_hang_or_panic() {
#[tokio::test]
async fn http_with_non_ascii_url_doesnt_hang_or_panic() {
let client = HttpClientBuilder::default().build("http://♥♥♥♥♥♥∀∂").unwrap();
let err: Result<(), Error> = client.request("system_chain", None).await;
let err = HttpClientBuilder::default().build("http://♥♥♥♥♥♥∀∂");
assert!(matches!(err, Err(Error::Transport(_))));
}
......
......@@ -22,7 +22,7 @@ serde_json = "1"
soketto = "0.7.1"
thiserror = "1"
tokio = { version = "1.8", features = ["net", "time", "rt-multi-thread", "macros"] }
tokio-rustls = "0.23"
tokio-rustls = { version = "0.23", optional = true }
tokio-util = { version = "0.6", features = ["compat"] }
tracing = "0.1"
webpki-roots = "0.22.0"
......@@ -32,3 +32,7 @@ env_logger = "0.9"
jsonrpsee-test-utils = { path = "../test-utils" }
jsonrpsee-utils = { path = "../utils", features = ["client"] }
tokio = { version = "1.8", features = ["macros"] }
[features]
default = ["tls"]
tls = ["tokio-rustls"]
......@@ -32,23 +32,21 @@ use futures::{
};
use pin_project::pin_project;
use std::{io::Error as IoError, pin::Pin, task::Context, task::Poll};
use tokio::net::TcpStream;
use tokio_util::compat::{TokioAsyncReadCompatExt, TokioAsyncWriteCompatExt};
/// Stream to represent either a unencrypted or encrypted socket stream.
#[pin_project(project = EitherStreamProj)]
#[derive(Debug, Copy, Clone)]
pub enum EitherStream<S, T> {
#[derive(Debug)]
pub enum EitherStream {
/// Unencrypted socket stream.
Plain(#[pin] S),
Plain(#[pin] TcpStream),
/// Encrypted socket stream.
Tls(#[pin] T),
#[cfg(feature = "tls")]
Tls(#[pin] tokio_rustls::client::TlsStream<TcpStream>),
}
impl<S, T> AsyncRead for EitherStream<S, T>
where
S: TokioAsyncReadCompatExt,
T: TokioAsyncReadCompatExt,
{
impl AsyncRead for EitherStream {
fn poll_read(self: Pin<&mut Self>, cx: &mut Context, buf: &mut [u8]) -> Poll<Result<usize, IoError>> {
match self.project() {
EitherStreamProj::Plain(s) => {
......@@ -56,6 +54,7 @@ where
futures::pin_mut!(compat);
AsyncRead::poll_read(compat, cx, buf)
}
#[cfg(feature = "tls")]
EitherStreamProj::Tls(t) => {
let compat = t.compat();
futures::pin_mut!(compat);
......@@ -75,6 +74,7 @@ where
futures::pin_mut!(compat);
AsyncRead::poll_read_vectored(compat, cx, bufs)
}
#[cfg(feature = "tls")]
EitherStreamProj::Tls(t) => {
let compat = t.compat();
futures::pin_mut!(compat);
......@@ -84,11 +84,7 @@ where
}
}
impl<S, T> AsyncWrite for EitherStream<S, T>
where
S: TokioAsyncWriteCompatExt,
T: TokioAsyncWriteCompatExt,
{
impl AsyncWrite for EitherStream {
fn poll_write(self: Pin<&mut Self>, cx: &mut Context, buf: &[u8]) -> Poll<Result<usize, IoError>> {
match self.project() {
EitherStreamProj::Plain(s) => {
......@@ -96,6 +92,7 @@ where
futures::pin_mut!(compat);
AsyncWrite::poll_write(compat, cx, buf)
}
#[cfg(feature = "tls")]
EitherStreamProj::Tls(t) => {
let compat = t.compat_write();
futures::pin_mut!(compat);
......@@ -111,6 +108,7 @@ where
futures::pin_mut!(compat);
AsyncWrite::poll_write_vectored(compat, cx, bufs)
}
#[cfg(feature = "tls")]
EitherStreamProj::Tls(t) => {
let compat = t.compat_write();
futures::pin_mut!(compat);
......@@ -126,6 +124,7 @@ where
futures::pin_mut!(compat);
AsyncWrite::poll_flush(compat, cx)
}
#[cfg(feature = "tls")]
EitherStreamProj::Tls(t) => {
let compat = t.compat_write();
futures::pin_mut!(compat);
......@@ -141,6 +140,7 @@ where
futures::pin_mut!(compat);
AsyncWrite::poll_close(compat, cx)
}
#[cfg(feature = "tls")]
EitherStreamProj::Tls(t) => {
let compat = t.compat_write();
futures::pin_mut!(compat);
......
......@@ -35,25 +35,21 @@ use std::{
convert::TryFrom,
io,
net::{SocketAddr, ToSocketAddrs},
sync::Arc,
time::Duration,
};
use thiserror::Error;
use tokio::net::TcpStream;
use tokio_rustls::{client::TlsStream, rustls, webpki::InvalidDnsNameError, TlsConnector};
type TlsOrPlain = EitherStream<TcpStream, TlsStream<TcpStream>>;
/// Sending end of WebSocket transport.
#[derive(Debug)]
pub struct Sender {
inner: connection::Sender<BufReader<BufWriter<TlsOrPlain>>>,
inner: connection::Sender<BufReader<BufWriter<EitherStream>>>,
}
/// Receiving end of WebSocket transport.
#[derive(Debug)]
pub struct Receiver {
inner: connection::Receiver<BufReader<BufWriter<TlsOrPlain>>>,
inner: connection::Receiver<BufReader<BufWriter<EitherStream>>>,
}
/// Builder for a WebSocket transport [`Sender`] and ['Receiver`] pair.
......@@ -106,8 +102,9 @@ pub enum WsHandshakeError {
Transport(#[source] soketto::handshake::Error),
/// Invalid DNS name error for TLS
#[cfg(feature = "tls")]
#[error("Invalid DNS name: {0}")]
InvalidDnsName(#[source] InvalidDnsNameError),
InvalidDnsName(#[source] tokio_rustls::webpki::InvalidDnsNameError),
/// Server rejected the handshake.
#[error("Connection rejected with status code: {status_code}")]
......@@ -169,31 +166,28 @@ impl Receiver {
impl<'a> WsTransportClientBuilder<'a> {
/// Try to establish the connection.
pub async fn build(self) -> Result<(Sender, Receiver), WsHandshakeError> {
let connector = match self.target.mode {
Mode::Tls => {
let tls_connector = build_tls_config(&self.certificate_store)?;
Some(tls_connector)
}
Mode::Plain => None,
};
self.try_connect(connector).await
self.try_connect().await
}
async fn try_connect(
self,
mut tls_connector: Option<TlsConnector>,
) -> Result<(Sender, Receiver), WsHandshakeError> {
async fn try_connect(self) -> Result<(Sender, Receiver), WsHandshakeError> {
let mut target = self.target;
let mut err = None;
// Only build TLS connector if `wss` in URL.
#[cfg(feature = "tls")]
let mut connector = match target.mode {
Mode::Tls => Some(build_tls_config(&self.certificate_store)?),
Mode::Plain => None,
};
for _ in 0..self.max_redirections {
tracing::debug!("Connecting to target: {:?}", target);
// The sockaddrs might get reused if the server replies with a relative URI.
let sockaddrs = std::mem::take(&mut target.sockaddrs);
for sockaddr in &sockaddrs {
let tcp_stream = match connect(*sockaddr, self.timeout, &target.host, &tls_connector).await {
#[cfg(feature = "tls")]
let tcp_stream = match connect(*sockaddr, self.timeout, &target.host, connector.as_ref()).await {
Ok(stream) => stream,
Err(e) => {
tracing::debug!("Failed to connect to sockaddr: {:?}", sockaddr);
......@@ -201,6 +195,17 @@ impl<'a> WsTransportClientBuilder<'a> {
continue;
}
};
#[cfg(not(feature = "tls"))]
let tcp_stream = match connect(*sockaddr, self.timeout).await {
Ok(stream) => stream,
Err(e) => {
tracing::debug!("Failed to connect to sockaddr: {:?}", sockaddr);
err = Some(Err(e));
continue;
}
};
let mut client = WsHandshakeClient::new(
BufReader::new(BufWriter::new(tcp_stream)),
&target.host_header,
......@@ -231,13 +236,17 @@ impl<'a> WsTransportClientBuilder<'a> {
// Absolute URI.
if uri.scheme().is_some() {
target = uri.try_into()?;
// Only build TLS connector if `wss` in redirection URL.
#[cfg(feature = "tls")]
match target.mode {
Mode::Tls if tls_connector.is_none() => {
tls_connector = Some(build_tls_config(&self.certificate_store)?);
Mode::Tls if connector.is_none() => {
connector = Some(build_tls_config(&self.certificate_store)?);
}
Mode::Tls => (),
// Drop connector if it was configured previously.
Mode::Plain => {
tls_connector = None;
connector = None;
}
};
}
......@@ -282,12 +291,13 @@ impl<'a> WsTransportClientBuilder<'a> {
}
}
#[cfg(feature = "tls")]
async fn connect(
sockaddr: SocketAddr,
timeout_dur: Duration,
host: &str,
tls_connector: &Option<TlsConnector>,
) -> Result<EitherStream<TcpStream, TlsStream<TcpStream>>, WsHandshakeError> {
tls_connector: Option<&tokio_rustls::TlsConnector>,
) -> Result<EitherStream, WsHandshakeError> {
let socket = TcpStream::connect(sockaddr);
let timeout = tokio::time::sleep(timeout_dur);
tokio::select! {
......@@ -297,11 +307,11 @@ async fn connect(
tracing::warn!("set nodelay failed: {:?}", err);
}
match tls_connector {
None => Ok(TlsOrPlain::Plain(socket)),
None => Ok(EitherStream::Plain(socket)),
Some(connector) => {
let server_name: rustls::ServerName = host.try_into().map_err(|e| WsHandshakeError::Url(format!("Invalid host: {} {:?}", host, e).into()))?;
let server_name: tokio_rustls::rustls::ServerName = host.try_into().map_err(|e| WsHandshakeError::Url(format!("Invalid host: {} {:?}", host, e).into()))?;
let tls_stream = connector.connect(server_name, socket).await?;
Ok(TlsOrPlain::Tls(tls_stream))
Ok(EitherStream::Tls(tls_stream))
}
}
}
......@@ -309,14 +319,31 @@ async fn connect(
}
}
#[cfg(not(feature = "tls"))]
async fn connect(sockaddr: SocketAddr, timeout_dur: Duration) -> Result<EitherStream, WsHandshakeError> {
let socket = TcpStream::connect(sockaddr);
let timeout = tokio::time::sleep(timeout_dur);
tokio::select! {
socket = socket => {
let socket = socket?;
if let Err(err) = socket.set_nodelay(true) {
tracing::warn!("set nodelay failed: {:?}", err);
}
Ok(EitherStream::Plain(socket))
}
_ = timeout => Err(WsHandshakeError::Timeout(timeout_dur))
}
}
impl From<io::Error> for WsHandshakeError {
fn from(err: io::Error) -> WsHandshakeError {
WsHandshakeError::Io(err)
}
}
impl From<InvalidDnsNameError> for WsHandshakeError {
fn from(err: InvalidDnsNameError) -> WsHandshakeError {
#[cfg(feature = "tls")]
impl From<tokio_rustls::webpki::InvalidDnsNameError> for WsHandshakeError {
fn from(err: tokio_rustls::webpki::InvalidDnsNameError) -> WsHandshakeError {
WsHandshakeError::InvalidDnsName(err)
}
}
......@@ -354,8 +381,15 @@ impl TryFrom<Uri> for Target {
fn try_from(uri: Uri) -> Result<Self, Self::Error> {
let mode = match uri.scheme_str() {
Some("ws") => Mode::Plain,
#[cfg(feature = "tls")]
Some("wss") => Mode::Tls,
_ => return Err(WsHandshakeError::Url("URL scheme not supported, expects 'ws' or 'wss'".into())),
_ => {
#[cfg(feature = "tls")]
let err = "URL scheme not supported, expects 'ws' or 'wss'";
#[cfg(not(feature = "tls"))]
let err = "URL scheme not supported, expects 'ws'";
return Err(WsHandshakeError::Url(err.into()));
}
};
let host = uri.host().map(ToOwned::to_owned).ok_or_else(|| WsHandshakeError::Url("No host in URL".into()))?;
let port = uri
......@@ -370,8 +404,11 @@ impl TryFrom<Uri> for Target {
}
// NOTE: this is slow and should be used sparingly.
fn build_tls_config(cert_store: &CertificateStore) -> Result<TlsConnector, WsHandshakeError> {
let mut roots = tokio_rustls::rustls::RootCertStore::empty();
#[cfg(feature = "tls")]
fn build_tls_config(cert_store: &CertificateStore) -> Result<tokio_rustls::TlsConnector, WsHandshakeError> {
use tokio_rustls::rustls;
let mut roots = rustls::RootCertStore::empty();
match cert_store {
CertificateStore::Native => {
......@@ -403,7 +440,7 @@ fn build_tls_config(cert_store: &CertificateStore) -> Result<TlsConnector, WsHan
let config =
rustls::ClientConfig::builder().with_safe_defaults().with_root_certificates(roots).with_no_client_auth();
Ok(Arc::new(config).into())
Ok(std::sync::Arc::new(config).into())
}
#[cfg(test)]
......@@ -429,12 +466,20 @@ mod tests {
assert_ws_target(target, "127.0.0.1", "127.0.0.1:9933", Mode::Plain, "/");
}
#[cfg(feature = "tls")]
#[test]
fn wss_works() {
let target = parse_target("wss://kusama-rpc.polkadot.io:443").unwrap();
assert_ws_target(target, "kusama-rpc.polkadot.io", "kusama-rpc.polkadot.io:443", Mode::Tls, "/");
}
#[cfg(not(feature = "tls"))]
#[test]
fn wss_fails_with_tls_feature() {
let err = parse_target("wss://kusama-rpc.polkadot.io:443").unwrap_err();
assert!(matches!(err, WsHandshakeError::Url(_)));
}
#[test]
fn faulty_url_scheme() {
let err = parse_target("http://kusama-rpc.polkadot.io:443").unwrap_err();
......@@ -451,19 +496,19 @@ mod tests {
#[test]
fn url_with_path_works() {
let target = parse_target("wss://127.0.0.1:443/my-special-path").unwrap();
assert_ws_target(target, "127.0.0.1", "127.0.0.1:443", Mode::Tls, "/my-special-path");
let target = parse_target("ws://127.0.0.1:443/my-special-path").unwrap();
assert_ws_target(target, "127.0.0.1", "127.0.0.1:443", Mode::Plain, "/my-special-path");
}
#[test]
fn url_with_query_works() {
let target = parse_target("wss://127.0.0.1:443/my?name1=value1&name2=value2").unwrap();
assert_ws_target(target, "127.0.0.1", "127.0.0.1:443", Mode::Tls, "/my?name1=value1&name2=value2");
let target = parse_target("ws://127.0.0.1:443/my?name1=value1&name2=value2").unwrap();
assert_ws_target(target, "127.0.0.1", "127.0.0.1:443", Mode::Plain, "/my?name1=value1&name2=value2");
}
#[test]
fn url_with_fragment_is_ignored() {
let target = parse_target("wss://127.0.0.1:443/my.htm#ignore").unwrap();
assert_ws_target(target, "127.0.0.1", "127.0.0.1:443", Mode::Tls, "/my.htm");
let target = parse_target("ws://127.0.0.1:443/my.htm#ignore").unwrap();
assert_ws_target(target, "127.0.0.1", "127.0.0.1:443", Mode::Plain, "/my.htm");
}
}
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment