Unverified Commit 69441f34 authored by Alexandru Vasile's avatar Alexandru Vasile
Browse files

http-server: Add tower middleware on the HttpBuilder



Signed-off-by: default avatarAlexandru Vasile <alexandru.vasile@parity.io>
parent d64d1af4
Pipeline #207027 passed with stages
in 4 minutes and 56 seconds
...@@ -24,48 +24,20 @@ ...@@ -24,48 +24,20 @@
// IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER // IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
// DEALINGS IN THE SOFTWARE. // DEALINGS IN THE SOFTWARE.
//! This example sets a custom tower service middleware to the RPC implementation.
use hyper::body::Bytes; use hyper::body::Bytes;
use hyper::Body;
use std::iter::once; use std::iter::once;
use std::net::SocketAddr; use std::net::SocketAddr;
use std::time::{Duration, Instant}; use std::time::Duration;
use tower_http::sensitive_headers::SetSensitiveRequestHeadersLayer; use tower_http::sensitive_headers::SetSensitiveRequestHeadersLayer;
use tower_http::trace::{DefaultMakeSpan, DefaultOnResponse, TraceLayer}; use tower_http::trace::{DefaultMakeSpan, DefaultOnResponse, TraceLayer};
use tower_http::LatencyUnit; use tower_http::LatencyUnit;
use jsonrpsee::core::client::ClientT; use jsonrpsee::core::client::ClientT;
use jsonrpsee::core::logger::{self, Params, Request};
use jsonrpsee::http_client::HttpClientBuilder; use jsonrpsee::http_client::HttpClientBuilder;
use jsonrpsee::http_server::{HttpServerBuilder, HttpServerHandle, RpcModule}; use jsonrpsee::http_server::{HttpServerBuilder, HttpServerHandle, RpcModule};
/// Define a custom logging mechanism to detect the time passed
/// between receiving the request and proving the response.
///
/// The implementation relies upon [logger::HttpLogger].
#[derive(Clone)]
struct Timings;
impl logger::HttpLogger for Timings {
type Instant = Instant;
fn on_request(&self, remote_addr: SocketAddr, request: &Request<Body>) -> Self::Instant {
println!("[Logger::on_request] remote_addr {}, request: {:?}", remote_addr, request);
Instant::now()
}
fn on_call(&self, name: &str, params: Params, kind: logger::MethodKind) {
println!("[Logger::on_call] method: '{}', params: {:?}, kind: {}", name, params, kind);
}
fn on_result(&self, name: &str, success: bool, started_at: Self::Instant) {
println!("[Logger::on_result] '{}', worked? {}, time elapsed {:?}", name, success, started_at.elapsed());
}
fn on_response(&self, result: &str, started_at: Self::Instant) {
println!("[Logger::on_response] result: {}, time elapsed {:?}", result, started_at.elapsed());
}
}
#[tokio::main] #[tokio::main]
async fn main() -> anyhow::Result<()> { async fn main() -> anyhow::Result<()> {
tracing_subscriber::FmtSubscriber::builder() tracing_subscriber::FmtSubscriber::builder()
...@@ -87,7 +59,7 @@ async fn main() -> anyhow::Result<()> { ...@@ -87,7 +59,7 @@ async fn main() -> anyhow::Result<()> {
async fn run_server() -> anyhow::Result<(SocketAddr, HttpServerHandle)> { async fn run_server() -> anyhow::Result<(SocketAddr, HttpServerHandle)> {
// Custom tower service to handle the RPC requests // Custom tower service to handle the RPC requests
let builder = tower::ServiceBuilder::new() let service_builder = tower::ServiceBuilder::new()
// Add high level tracing/logging to all requests // Add high level tracing/logging to all requests
.layer( .layer(
TraceLayer::new_for_http() TraceLayer::new_for_http()
...@@ -101,13 +73,15 @@ async fn run_server() -> anyhow::Result<(SocketAddr, HttpServerHandle)> { ...@@ -101,13 +73,15 @@ async fn run_server() -> anyhow::Result<(SocketAddr, HttpServerHandle)> {
.layer(SetSensitiveRequestHeadersLayer::new(once(hyper::header::AUTHORIZATION))) .layer(SetSensitiveRequestHeadersLayer::new(once(hyper::header::AUTHORIZATION)))
.timeout(Duration::from_secs(2)); .timeout(Duration::from_secs(2));
let server = HttpServerBuilder::new().set_logger(Timings).build("127.0.0.1:0".parse::<SocketAddr>()?).await?; let server =
HttpServerBuilder::new().set_middleware(service_builder).build("127.0.0.1:0".parse::<SocketAddr>()?).await?;
let addr = server.local_addr()?; let addr = server.local_addr()?;
let mut module = RpcModule::new(()); let mut module = RpcModule::new(());
module.register_method("say_hello", |_, _| Ok("lo")).unwrap(); module.register_method("say_hello", |_, _| Ok("lo")).unwrap();
let handler = server.start_with_builder(module, builder)?; let handler = server.start(module)?;
Ok((addr, handler)) Ok((addr, handler))
} }
...@@ -24,6 +24,7 @@ ...@@ -24,6 +24,7 @@
// IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER // IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
// DEALINGS IN THE SOFTWARE. // DEALINGS IN THE SOFTWARE.
use std::convert::Infallible;
use std::future::Future; use std::future::Future;
use std::net::{SocketAddr, TcpListener as StdTcpListener}; use std::net::{SocketAddr, TcpListener as StdTcpListener};
use std::pin::Pin; use std::pin::Pin;
...@@ -34,6 +35,7 @@ use crate::response::{internal_error, malformed}; ...@@ -34,6 +35,7 @@ use crate::response::{internal_error, malformed};
use futures_channel::mpsc; use futures_channel::mpsc;
use futures_util::future::FutureExt; use futures_util::future::FutureExt;
use futures_util::stream::{StreamExt, TryStreamExt}; use futures_util::stream::{StreamExt, TryStreamExt};
use futures_util::TryFutureExt;
use hyper::body::HttpBody; use hyper::body::HttpBody;
use hyper::header::{HeaderMap, HeaderValue}; use hyper::header::{HeaderMap, HeaderValue};
use hyper::server::conn::AddrStream; use hyper::server::conn::AddrStream;
...@@ -55,6 +57,7 @@ use jsonrpsee_types::{Id, Notification, Params, Request}; ...@@ -55,6 +57,7 @@ use jsonrpsee_types::{Id, Notification, Params, Request};
use serde::de::StdError; use serde::de::StdError;
use serde_json::value::RawValue; use serde_json::value::RawValue;
use tokio::net::{TcpListener, ToSocketAddrs}; use tokio::net::{TcpListener, ToSocketAddrs};
use tower::layer::util::Identity;
use tower::Layer; use tower::Layer;
use tracing_futures::Instrument; use tracing_futures::Instrument;
...@@ -62,7 +65,7 @@ type Notif<'a> = Notification<'a, Option<&'a RawValue>>; ...@@ -62,7 +65,7 @@ type Notif<'a> = Notification<'a, Option<&'a RawValue>>;
/// Builder to create JSON-RPC HTTP server. /// Builder to create JSON-RPC HTTP server.
#[derive(Debug)] #[derive(Debug)]
pub struct Builder<L = ()> { pub struct Builder<B = Identity, L = ()> {
/// Access control based on HTTP headers. /// Access control based on HTTP headers.
access_control: AccessControl, access_control: AccessControl,
resources: Resources, resources: Resources,
...@@ -74,6 +77,7 @@ pub struct Builder<L = ()> { ...@@ -74,6 +77,7 @@ pub struct Builder<L = ()> {
logger: L, logger: L,
max_log_length: u32, max_log_length: u32,
health_api: Option<HealthApi>, health_api: Option<HealthApi>,
service_builder: tower::ServiceBuilder<B>,
} }
impl Default for Builder { impl Default for Builder {
...@@ -88,6 +92,7 @@ impl Default for Builder { ...@@ -88,6 +92,7 @@ impl Default for Builder {
logger: (), logger: (),
max_log_length: 4096, max_log_length: 4096,
health_api: None, health_api: None,
service_builder: tower::ServiceBuilder::new(),
} }
} }
} }
...@@ -99,9 +104,11 @@ impl Builder { ...@@ -99,9 +104,11 @@ impl Builder {
} }
} }
impl<L> Builder<L> { impl<B, L> Builder<B, L> {
/// Add a logger to the builder [`Logger`](../jsonrpsee_core/logger/trait.Logger.html). /// Add a logger to the builder [`Logger`](../jsonrpsee_core/logger/trait.Logger.html).
/// ///
/// # Examples
///
/// ``` /// ```
/// use std::{time::Instant, net::SocketAddr}; /// use std::{time::Instant, net::SocketAddr};
/// use hyper::Request; /// use hyper::Request;
...@@ -141,7 +148,7 @@ impl<L> Builder<L> { ...@@ -141,7 +148,7 @@ impl<L> Builder<L> {
/// ///
/// let builder = HttpServerBuilder::new().set_logger(MyLogger); /// let builder = HttpServerBuilder::new().set_logger(MyLogger);
/// ``` /// ```
pub fn set_logger<T: Logger>(self, logger: T) -> Builder<T> { pub fn set_logger<T: Logger>(self, logger: T) -> Builder<B, T> {
Builder { Builder {
access_control: self.access_control, access_control: self.access_control,
max_request_body_size: self.max_request_body_size, max_request_body_size: self.max_request_body_size,
...@@ -152,6 +159,7 @@ impl<L> Builder<L> { ...@@ -152,6 +159,7 @@ impl<L> Builder<L> {
logger, logger,
max_log_length: self.max_log_length, max_log_length: self.max_log_length,
health_api: self.health_api, health_api: self.health_api,
service_builder: self.service_builder,
} }
} }
...@@ -216,8 +224,49 @@ impl<L> Builder<L> { ...@@ -216,8 +224,49 @@ impl<L> Builder<L> {
Ok(self) Ok(self)
} }
/// Configure a custom [`tower::ServiceBuilder`] middleware for composing layers to be applied to the RPC service.
///
/// Default: No tower layers are applied to the RPC service.
///
/// # Examples
///
/// ```rust
///
/// use std::time::Duration;
/// use std::net::SocketAddr;
/// use jsonrpsee_http_server::HttpServerBuilder;
///
/// #[tokio::main]
/// async fn main() {
/// let builder = tower::ServiceBuilder::new()
/// .timeout(Duration::from_secs(2));
///
/// let server = HttpServerBuilder::new()
/// .set_middleware(builder)
/// .build("127.0.0.1:0".parse::<SocketAddr>().unwrap())
/// .await
/// .unwrap();
/// }
/// ```
pub fn set_middleware<T>(self, service_builder: tower::ServiceBuilder<T>) -> Builder<T, L> {
Builder {
access_control: self.access_control,
max_request_body_size: self.max_request_body_size,
max_response_body_size: self.max_response_body_size,
batch_requests_supported: self.batch_requests_supported,
resources: self.resources,
tokio_runtime: self.tokio_runtime,
logger: self.logger,
max_log_length: self.max_log_length,
health_api: self.health_api,
service_builder,
}
}
/// Finalizes the configuration of the server with customized TCP settings on the socket and on hyper. /// Finalizes the configuration of the server with customized TCP settings on the socket and on hyper.
/// ///
/// # Examples
///
/// ```rust /// ```rust
/// use jsonrpsee_http_server::HttpServerBuilder; /// use jsonrpsee_http_server::HttpServerBuilder;
/// use socket2::{Domain, Socket, Type}; /// use socket2::{Domain, Socket, Type};
...@@ -252,7 +301,7 @@ impl<L> Builder<L> { ...@@ -252,7 +301,7 @@ impl<L> Builder<L> {
self, self,
listener: hyper::server::Builder<AddrIncoming>, listener: hyper::server::Builder<AddrIncoming>,
local_addr: SocketAddr, local_addr: SocketAddr,
) -> Result<Server<L>, Error> { ) -> Result<Server<B, L>, Error> {
Ok(Server { Ok(Server {
access_control: self.access_control, access_control: self.access_control,
listener, listener,
...@@ -265,6 +314,7 @@ impl<L> Builder<L> { ...@@ -265,6 +314,7 @@ impl<L> Builder<L> {
logger: self.logger, logger: self.logger,
max_log_length: self.max_log_length, max_log_length: self.max_log_length,
health_api: self.health_api, health_api: self.health_api,
service_builder: self.service_builder,
}) })
} }
...@@ -292,7 +342,7 @@ impl<L> Builder<L> { ...@@ -292,7 +342,7 @@ impl<L> Builder<L> {
/// let server = HttpServerBuilder::new().build_from_tcp(socket).unwrap(); /// let server = HttpServerBuilder::new().build_from_tcp(socket).unwrap();
/// } /// }
/// ``` /// ```
pub fn build_from_tcp(self, listener: impl Into<StdTcpListener>) -> Result<Server<L>, Error> { pub fn build_from_tcp(self, listener: impl Into<StdTcpListener>) -> Result<Server<B, L>, Error> {
let listener = listener.into(); let listener = listener.into();
let local_addr = listener.local_addr().ok(); let local_addr = listener.local_addr().ok();
...@@ -310,6 +360,7 @@ impl<L> Builder<L> { ...@@ -310,6 +360,7 @@ impl<L> Builder<L> {
logger: self.logger, logger: self.logger,
max_log_length: self.max_log_length, max_log_length: self.max_log_length,
health_api: self.health_api, health_api: self.health_api,
service_builder: self.service_builder,
}) })
} }
...@@ -328,7 +379,7 @@ impl<L> Builder<L> { ...@@ -328,7 +379,7 @@ impl<L> Builder<L> {
/// assert!(jsonrpsee_http_server::HttpServerBuilder::default().build(addrs).await.is_ok()); /// assert!(jsonrpsee_http_server::HttpServerBuilder::default().build(addrs).await.is_ok());
/// } /// }
/// ``` /// ```
pub async fn build(self, addrs: impl ToSocketAddrs) -> Result<Server<L>, Error> { pub async fn build(self, addrs: impl ToSocketAddrs) -> Result<Server<B, L>, Error> {
let listener = TcpListener::bind(addrs).await?.into_std()?; let listener = TcpListener::bind(addrs).await?.into_std()?;
let local_addr = listener.local_addr().ok(); let local_addr = listener.local_addr().ok();
...@@ -346,6 +397,7 @@ impl<L> Builder<L> { ...@@ -346,6 +397,7 @@ impl<L> Builder<L> {
logger: self.logger, logger: self.logger,
max_log_length: self.max_log_length, max_log_length: self.max_log_length,
health_api: self.health_api, health_api: self.health_api,
service_builder: self.service_builder,
}) })
} }
} }
...@@ -419,7 +471,7 @@ impl<L: Logger> ServiceData<L> { ...@@ -419,7 +471,7 @@ impl<L: Logger> ServiceData<L> {
async fn handle_request( async fn handle_request(
self, self,
request: hyper::Request<hyper::Body>, request: hyper::Request<hyper::Body>,
) -> Result<hyper::Response<hyper::Body>, HyperError> { ) -> Result<hyper::Response<hyper::Body>, Infallible> {
let ServiceData { let ServiceData {
remote_addr, remote_addr,
methods, methods,
...@@ -539,7 +591,14 @@ pub struct TowerService<L> { ...@@ -539,7 +591,14 @@ pub struct TowerService<L> {
impl<L: Logger> hyper::service::Service<hyper::Request<hyper::Body>> for TowerService<L> { impl<L: Logger> hyper::service::Service<hyper::Request<hyper::Body>> for TowerService<L> {
type Response = hyper::Response<hyper::Body>; type Response = hyper::Response<hyper::Body>;
type Error = hyper::Error;
// NOTE(lexnv): The `handle_request` method returns `Result<_, Infallible>`.
// This is because the RPC service will always return a valid HTTP response (ie return `Ok(_)`).
//
// The following associated type is required by the `impl<B, U, L: Logger> Server<B, L>` bounds.
// It satisfies the server's bounds when the `tower::ServiceBuilder<B>` is not set (ie `B: Identity`).
type Error = Box<(dyn StdError + Send + Sync + 'static)>;
type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>; type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
/// Opens door for back pressure implementation. /// Opens door for back pressure implementation.
...@@ -549,13 +608,16 @@ impl<L: Logger> hyper::service::Service<hyper::Request<hyper::Body>> for TowerSe ...@@ -549,13 +608,16 @@ impl<L: Logger> hyper::service::Service<hyper::Request<hyper::Body>> for TowerSe
fn call(&mut self, request: hyper::Request<hyper::Body>) -> Self::Future { fn call(&mut self, request: hyper::Request<hyper::Body>) -> Self::Future {
let data = self.inner.clone(); let data = self.inner.clone();
Box::pin(data.handle_request(request)) // Note that `handle_request` will never return error.
// The dummy error is set in place to satisfy the server's trait bounds regarding the
// `tower::ServiceBuilder` and the error will never be mapped.
Box::pin(data.handle_request(request).map_err(|_| "".into()))
} }
} }
/// An HTTP JSON RPC server. /// An HTTP JSON RPC server.
#[derive(Debug)] #[derive(Debug)]
pub struct Server<L = ()> { pub struct Server<B = Identity, L = ()> {
/// Hyper server. /// Hyper server.
listener: HyperBuilder<AddrIncoming>, listener: HyperBuilder<AddrIncoming>,
/// Local address /// Local address
...@@ -578,80 +640,28 @@ pub struct Server<L = ()> { ...@@ -578,80 +640,28 @@ pub struct Server<L = ()> {
tokio_runtime: Option<tokio::runtime::Handle>, tokio_runtime: Option<tokio::runtime::Handle>,
logger: L, logger: L,
health_api: Option<HealthApi>, health_api: Option<HealthApi>,
service_builder: tower::ServiceBuilder<B>,
} }
impl<L: Logger> Server<L> { impl<B, U, L: Logger> Server<B, L>
where
B: Layer<TowerService<L>> + Send + 'static,
<B as Layer<TowerService<L>>>::Service: Send
+ Service<
hyper::Request<Body>,
Response = hyper::Response<U>,
Error = Box<(dyn StdError + Send + Sync + 'static)>,
>,
<<B as Layer<TowerService<L>>>::Service as Service<hyper::Request<Body>>>::Future: Send,
U: HttpBody + Send + 'static,
<U as HttpBody>::Error: Send + Sync + StdError,
<U as HttpBody>::Data: Send,
{
/// Returns socket address to which the server is bound. /// Returns socket address to which the server is bound.
pub fn local_addr(&self) -> Result<SocketAddr, Error> { pub fn local_addr(&self) -> Result<SocketAddr, Error> {
self.local_addr.ok_or_else(|| Error::Custom("Local address not found".into())) self.local_addr.ok_or_else(|| Error::Custom("Local address not found".into()))
} }
/// Start the server with a custom tower builder.
pub fn start_with_builder<T, U>(
mut self,
methods: impl Into<Methods>,
builder: tower::ServiceBuilder<T>,
) -> Result<ServerHandle, Error>
where
T: Layer<TowerService<L>> + Send + 'static,
<T as Layer<TowerService<L>>>::Service: Send
+ Service<
hyper::Request<Body>,
Response = hyper::Response<U>,
Error = Box<(dyn StdError + Send + Sync + 'static)>,
>,
<<T as Layer<TowerService<L>>>::Service as Service<hyper::Request<Body>>>::Future: Send,
U: HttpBody + Send + 'static,
<U as HttpBody>::Error: Send + Sync + StdError,
<U as HttpBody>::Data: Send,
{
let max_request_body_size = self.max_request_body_size;
let max_response_body_size = self.max_response_body_size;
let max_log_length = self.max_log_length;
let acl = self.access_control;
let (tx, mut rx) = mpsc::channel(1);
let listener = self.listener;
let resources = self.resources;
let logger = self.logger;
let batch_requests_supported = self.batch_requests_supported;
let methods = methods.into().initialize_resources(&resources)?;
let health_api = self.health_api;
let make_service = make_service_fn(move |conn: &AddrStream| {
let service = TowerService {
inner: ServiceData {
remote_addr: conn.remote_addr(),
methods: methods.clone(),
acl: acl.clone(),
resources: resources.clone(),
logger: logger.clone(),
health_api: health_api.clone(),
max_request_body_size,
max_response_body_size,
max_log_length,
batch_requests_supported,
},
};
let server = builder.service(service);
// For every request the `TowerService` is calling into `ServiceData::handle_request`
// where the RPSee bare implementation resides.
async move { Ok::<_, HyperError>(server) }
});
let rt = match self.tokio_runtime.take() {
Some(rt) => rt,
None => tokio::runtime::Handle::current(),
};
let handle = rt.spawn(async move {
let server = listener.serve(make_service);
let _ = server.with_graceful_shutdown(async move { rx.next().await.map_or((), |_| ()) }).await;
});
Ok(ServerHandle { handle: Some(handle), stop_sender: tx })
}
/// Start the server. /// Start the server.
pub fn start(mut self, methods: impl Into<Methods>) -> Result<ServerHandle, Error> { pub fn start(mut self, methods: impl Into<Methods>) -> Result<ServerHandle, Error> {
let max_request_body_size = self.max_request_body_size; let max_request_body_size = self.max_request_body_size;
...@@ -682,9 +692,11 @@ impl<L: Logger> Server<L> { ...@@ -682,9 +692,11 @@ impl<L: Logger> Server<L> {
}, },
}; };
let server = self.service_builder.service(service);
// For every request the `TowerService` is calling into `ServiceData::handle_request` // For every request the `TowerService` is calling into `ServiceData::handle_request`
// where the RPSee bare implementation resides. // where the RPSee bare implementation resides.
async move { Ok::<_, HyperError>(service) } async move { Ok::<_, HyperError>(server) }
}); });
let rt = match self.tokio_runtime.take() { let rt = match self.tokio_runtime.take() {
...@@ -749,7 +761,7 @@ struct ProcessValidatedRequest<L: Logger> { ...@@ -749,7 +761,7 @@ struct ProcessValidatedRequest<L: Logger> {
/// Process a verified request, it implies a POST request with content type JSON. /// Process a verified request, it implies a POST request with content type JSON.
async fn process_validated_request<L: Logger>( async fn process_validated_request<L: Logger>(
input: ProcessValidatedRequest<L>, input: ProcessValidatedRequest<L>,
) -> Result<hyper::Response<hyper::Body>, HyperError> { ) -> Result<hyper::Response<hyper::Body>, Infallible> {
let ProcessValidatedRequest { let ProcessValidatedRequest {
request, request,
logger, logger,
...@@ -825,7 +837,7 @@ async fn process_health_request<L: Logger>( ...@@ -825,7 +837,7 @@ async fn process_health_request<L: Logger>(
max_response_body_size: u32, max_response_body_size: u32,
request_start: L::Instant, request_start: L::Instant,
max_log_length: u32, max_log_length: u32,
) -> Result<hyper::Response<hyper::Body>, HyperError> { ) -> Result<hyper::Response<hyper::Body>, Infallible> {
let trace = RpcTracing::method_call(&health_api.method); let trace = RpcTracing::method_call(&health_api.method);
async { async {
tx_log_from_str("HTTP health API", max_log_length); tx_log_from_str("HTTP health API", max_log_length);
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment