Unverified Commit 69441f34 authored by Alexandru Vasile's avatar Alexandru Vasile
Browse files

http-server: Add tower middleware on the HttpBuilder


Signed-off-by: default avatarAlexandru Vasile <alexandru.vasile@parity.io>
parent d64d1af4
Pipeline #207027 passed with stages
in 4 minutes and 56 seconds
......@@ -24,48 +24,20 @@
// IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
// DEALINGS IN THE SOFTWARE.
//! This example sets a custom tower service middleware to the RPC implementation.
use hyper::body::Bytes;
use hyper::Body;
use std::iter::once;
use std::net::SocketAddr;
use std::time::{Duration, Instant};
use std::time::Duration;
use tower_http::sensitive_headers::SetSensitiveRequestHeadersLayer;
use tower_http::trace::{DefaultMakeSpan, DefaultOnResponse, TraceLayer};
use tower_http::LatencyUnit;
use jsonrpsee::core::client::ClientT;
use jsonrpsee::core::logger::{self, Params, Request};
use jsonrpsee::http_client::HttpClientBuilder;
use jsonrpsee::http_server::{HttpServerBuilder, HttpServerHandle, RpcModule};
/// Define a custom logging mechanism to detect the time passed
/// between receiving the request and proving the response.
///
/// The implementation relies upon [logger::HttpLogger].
#[derive(Clone)]
struct Timings;
impl logger::HttpLogger for Timings {
type Instant = Instant;
fn on_request(&self, remote_addr: SocketAddr, request: &Request<Body>) -> Self::Instant {
println!("[Logger::on_request] remote_addr {}, request: {:?}", remote_addr, request);
Instant::now()
}
fn on_call(&self, name: &str, params: Params, kind: logger::MethodKind) {
println!("[Logger::on_call] method: '{}', params: {:?}, kind: {}", name, params, kind);
}
fn on_result(&self, name: &str, success: bool, started_at: Self::Instant) {
println!("[Logger::on_result] '{}', worked? {}, time elapsed {:?}", name, success, started_at.elapsed());
}
fn on_response(&self, result: &str, started_at: Self::Instant) {
println!("[Logger::on_response] result: {}, time elapsed {:?}", result, started_at.elapsed());
}
}
#[tokio::main]
async fn main() -> anyhow::Result<()> {
tracing_subscriber::FmtSubscriber::builder()
......@@ -87,7 +59,7 @@ async fn main() -> anyhow::Result<()> {
async fn run_server() -> anyhow::Result<(SocketAddr, HttpServerHandle)> {
// Custom tower service to handle the RPC requests
let builder = tower::ServiceBuilder::new()
let service_builder = tower::ServiceBuilder::new()
// Add high level tracing/logging to all requests
.layer(
TraceLayer::new_for_http()
......@@ -101,13 +73,15 @@ async fn run_server() -> anyhow::Result<(SocketAddr, HttpServerHandle)> {
.layer(SetSensitiveRequestHeadersLayer::new(once(hyper::header::AUTHORIZATION)))
.timeout(Duration::from_secs(2));
let server = HttpServerBuilder::new().set_logger(Timings).build("127.0.0.1:0".parse::<SocketAddr>()?).await?;
let server =
HttpServerBuilder::new().set_middleware(service_builder).build("127.0.0.1:0".parse::<SocketAddr>()?).await?;
let addr = server.local_addr()?;
let mut module = RpcModule::new(());
module.register_method("say_hello", |_, _| Ok("lo")).unwrap();
let handler = server.start_with_builder(module, builder)?;
let handler = server.start(module)?;
Ok((addr, handler))
}
......@@ -24,6 +24,7 @@
// IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
// DEALINGS IN THE SOFTWARE.
use std::convert::Infallible;
use std::future::Future;
use std::net::{SocketAddr, TcpListener as StdTcpListener};
use std::pin::Pin;
......@@ -34,6 +35,7 @@ use crate::response::{internal_error, malformed};
use futures_channel::mpsc;
use futures_util::future::FutureExt;
use futures_util::stream::{StreamExt, TryStreamExt};
use futures_util::TryFutureExt;
use hyper::body::HttpBody;
use hyper::header::{HeaderMap, HeaderValue};
use hyper::server::conn::AddrStream;
......@@ -55,6 +57,7 @@ use jsonrpsee_types::{Id, Notification, Params, Request};
use serde::de::StdError;
use serde_json::value::RawValue;
use tokio::net::{TcpListener, ToSocketAddrs};
use tower::layer::util::Identity;
use tower::Layer;
use tracing_futures::Instrument;
......@@ -62,7 +65,7 @@ type Notif<'a> = Notification<'a, Option<&'a RawValue>>;
/// Builder to create JSON-RPC HTTP server.
#[derive(Debug)]
pub struct Builder<L = ()> {
pub struct Builder<B = Identity, L = ()> {
/// Access control based on HTTP headers.
access_control: AccessControl,
resources: Resources,
......@@ -74,6 +77,7 @@ pub struct Builder<L = ()> {
logger: L,
max_log_length: u32,
health_api: Option<HealthApi>,
service_builder: tower::ServiceBuilder<B>,
}
impl Default for Builder {
......@@ -88,6 +92,7 @@ impl Default for Builder {
logger: (),
max_log_length: 4096,
health_api: None,
service_builder: tower::ServiceBuilder::new(),
}
}
}
......@@ -99,9 +104,11 @@ impl Builder {
}
}
impl<L> Builder<L> {
impl<B, L> Builder<B, L> {
/// Add a logger to the builder [`Logger`](../jsonrpsee_core/logger/trait.Logger.html).
///
/// # Examples
///
/// ```
/// use std::{time::Instant, net::SocketAddr};
/// use hyper::Request;
......@@ -141,7 +148,7 @@ impl<L> Builder<L> {
///
/// let builder = HttpServerBuilder::new().set_logger(MyLogger);
/// ```
pub fn set_logger<T: Logger>(self, logger: T) -> Builder<T> {
pub fn set_logger<T: Logger>(self, logger: T) -> Builder<B, T> {
Builder {
access_control: self.access_control,
max_request_body_size: self.max_request_body_size,
......@@ -152,6 +159,7 @@ impl<L> Builder<L> {
logger,
max_log_length: self.max_log_length,
health_api: self.health_api,
service_builder: self.service_builder,
}
}
......@@ -216,8 +224,49 @@ impl<L> Builder<L> {
Ok(self)
}
/// Configure a custom [`tower::ServiceBuilder`] middleware for composing layers to be applied to the RPC service.
///
/// Default: No tower layers are applied to the RPC service.
///
/// # Examples
///
/// ```rust
///
/// use std::time::Duration;
/// use std::net::SocketAddr;
/// use jsonrpsee_http_server::HttpServerBuilder;
///
/// #[tokio::main]
/// async fn main() {
/// let builder = tower::ServiceBuilder::new()
/// .timeout(Duration::from_secs(2));
///
/// let server = HttpServerBuilder::new()
/// .set_middleware(builder)
/// .build("127.0.0.1:0".parse::<SocketAddr>().unwrap())
/// .await
/// .unwrap();
/// }
/// ```
pub fn set_middleware<T>(self, service_builder: tower::ServiceBuilder<T>) -> Builder<T, L> {
Builder {
access_control: self.access_control,
max_request_body_size: self.max_request_body_size,
max_response_body_size: self.max_response_body_size,
batch_requests_supported: self.batch_requests_supported,
resources: self.resources,
tokio_runtime: self.tokio_runtime,
logger: self.logger,
max_log_length: self.max_log_length,
health_api: self.health_api,
service_builder,
}
}
/// Finalizes the configuration of the server with customized TCP settings on the socket and on hyper.
///
/// # Examples
///
/// ```rust
/// use jsonrpsee_http_server::HttpServerBuilder;
/// use socket2::{Domain, Socket, Type};
......@@ -252,7 +301,7 @@ impl<L> Builder<L> {
self,
listener: hyper::server::Builder<AddrIncoming>,
local_addr: SocketAddr,
) -> Result<Server<L>, Error> {
) -> Result<Server<B, L>, Error> {
Ok(Server {
access_control: self.access_control,
listener,
......@@ -265,6 +314,7 @@ impl<L> Builder<L> {
logger: self.logger,
max_log_length: self.max_log_length,
health_api: self.health_api,
service_builder: self.service_builder,
})
}
......@@ -292,7 +342,7 @@ impl<L> Builder<L> {
/// let server = HttpServerBuilder::new().build_from_tcp(socket).unwrap();
/// }
/// ```
pub fn build_from_tcp(self, listener: impl Into<StdTcpListener>) -> Result<Server<L>, Error> {
pub fn build_from_tcp(self, listener: impl Into<StdTcpListener>) -> Result<Server<B, L>, Error> {
let listener = listener.into();
let local_addr = listener.local_addr().ok();
......@@ -310,6 +360,7 @@ impl<L> Builder<L> {
logger: self.logger,
max_log_length: self.max_log_length,
health_api: self.health_api,
service_builder: self.service_builder,
})
}
......@@ -328,7 +379,7 @@ impl<L> Builder<L> {
/// assert!(jsonrpsee_http_server::HttpServerBuilder::default().build(addrs).await.is_ok());
/// }
/// ```
pub async fn build(self, addrs: impl ToSocketAddrs) -> Result<Server<L>, Error> {
pub async fn build(self, addrs: impl ToSocketAddrs) -> Result<Server<B, L>, Error> {
let listener = TcpListener::bind(addrs).await?.into_std()?;
let local_addr = listener.local_addr().ok();
......@@ -346,6 +397,7 @@ impl<L> Builder<L> {
logger: self.logger,
max_log_length: self.max_log_length,
health_api: self.health_api,
service_builder: self.service_builder,
})
}
}
......@@ -419,7 +471,7 @@ impl<L: Logger> ServiceData<L> {
async fn handle_request(
self,
request: hyper::Request<hyper::Body>,
) -> Result<hyper::Response<hyper::Body>, HyperError> {
) -> Result<hyper::Response<hyper::Body>, Infallible> {
let ServiceData {
remote_addr,
methods,
......@@ -539,7 +591,14 @@ pub struct TowerService<L> {
impl<L: Logger> hyper::service::Service<hyper::Request<hyper::Body>> for TowerService<L> {
type Response = hyper::Response<hyper::Body>;
type Error = hyper::Error;
// NOTE(lexnv): The `handle_request` method returns `Result<_, Infallible>`.
// This is because the RPC service will always return a valid HTTP response (ie return `Ok(_)`).
//
// The following associated type is required by the `impl<B, U, L: Logger> Server<B, L>` bounds.
// It satisfies the server's bounds when the `tower::ServiceBuilder<B>` is not set (ie `B: Identity`).
type Error = Box<(dyn StdError + Send + Sync + 'static)>;
type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
/// Opens door for back pressure implementation.
......@@ -549,13 +608,16 @@ impl<L: Logger> hyper::service::Service<hyper::Request<hyper::Body>> for TowerSe
fn call(&mut self, request: hyper::Request<hyper::Body>) -> Self::Future {
let data = self.inner.clone();
Box::pin(data.handle_request(request))
// Note that `handle_request` will never return error.
// The dummy error is set in place to satisfy the server's trait bounds regarding the
// `tower::ServiceBuilder` and the error will never be mapped.
Box::pin(data.handle_request(request).map_err(|_| "".into()))
}
}
/// An HTTP JSON RPC server.
#[derive(Debug)]
pub struct Server<L = ()> {
pub struct Server<B = Identity, L = ()> {
/// Hyper server.
listener: HyperBuilder<AddrIncoming>,
/// Local address
......@@ -578,80 +640,28 @@ pub struct Server<L = ()> {
tokio_runtime: Option<tokio::runtime::Handle>,
logger: L,
health_api: Option<HealthApi>,
service_builder: tower::ServiceBuilder<B>,
}
impl<L: Logger> Server<L> {
impl<B, U, L: Logger> Server<B, L>
where
B: Layer<TowerService<L>> + Send + 'static,
<B as Layer<TowerService<L>>>::Service: Send
+ Service<
hyper::Request<Body>,
Response = hyper::Response<U>,
Error = Box<(dyn StdError + Send + Sync + 'static)>,
>,
<<B as Layer<TowerService<L>>>::Service as Service<hyper::Request<Body>>>::Future: Send,
U: HttpBody + Send + 'static,
<U as HttpBody>::Error: Send + Sync + StdError,
<U as HttpBody>::Data: Send,
{
/// Returns socket address to which the server is bound.
pub fn local_addr(&self) -> Result<SocketAddr, Error> {
self.local_addr.ok_or_else(|| Error::Custom("Local address not found".into()))
}
/// Start the server with a custom tower builder.
pub fn start_with_builder<T, U>(
mut self,
methods: impl Into<Methods>,
builder: tower::ServiceBuilder<T>,
) -> Result<ServerHandle, Error>
where
T: Layer<TowerService<L>> + Send + 'static,
<T as Layer<TowerService<L>>>::Service: Send
+ Service<
hyper::Request<Body>,
Response = hyper::Response<U>,
Error = Box<(dyn StdError + Send + Sync + 'static)>,
>,
<<T as Layer<TowerService<L>>>::Service as Service<hyper::Request<Body>>>::Future: Send,
U: HttpBody + Send + 'static,
<U as HttpBody>::Error: Send + Sync + StdError,
<U as HttpBody>::Data: Send,
{
let max_request_body_size = self.max_request_body_size;
let max_response_body_size = self.max_response_body_size;
let max_log_length = self.max_log_length;
let acl = self.access_control;
let (tx, mut rx) = mpsc::channel(1);
let listener = self.listener;
let resources = self.resources;
let logger = self.logger;
let batch_requests_supported = self.batch_requests_supported;
let methods = methods.into().initialize_resources(&resources)?;
let health_api = self.health_api;
let make_service = make_service_fn(move |conn: &AddrStream| {
let service = TowerService {
inner: ServiceData {
remote_addr: conn.remote_addr(),
methods: methods.clone(),
acl: acl.clone(),
resources: resources.clone(),
logger: logger.clone(),
health_api: health_api.clone(),
max_request_body_size,
max_response_body_size,
max_log_length,
batch_requests_supported,
},
};
let server = builder.service(service);
// For every request the `TowerService` is calling into `ServiceData::handle_request`
// where the RPSee bare implementation resides.
async move { Ok::<_, HyperError>(server) }
});
let rt = match self.tokio_runtime.take() {
Some(rt) => rt,
None => tokio::runtime::Handle::current(),
};
let handle = rt.spawn(async move {
let server = listener.serve(make_service);
let _ = server.with_graceful_shutdown(async move { rx.next().await.map_or((), |_| ()) }).await;
});
Ok(ServerHandle { handle: Some(handle), stop_sender: tx })
}
/// Start the server.
pub fn start(mut self, methods: impl Into<Methods>) -> Result<ServerHandle, Error> {
let max_request_body_size = self.max_request_body_size;
......@@ -682,9 +692,11 @@ impl<L: Logger> Server<L> {
},
};
let server = self.service_builder.service(service);
// For every request the `TowerService` is calling into `ServiceData::handle_request`
// where the RPSee bare implementation resides.
async move { Ok::<_, HyperError>(service) }
async move { Ok::<_, HyperError>(server) }
});
let rt = match self.tokio_runtime.take() {
......@@ -749,7 +761,7 @@ struct ProcessValidatedRequest<L: Logger> {
/// Process a verified request, it implies a POST request with content type JSON.
async fn process_validated_request<L: Logger>(
input: ProcessValidatedRequest<L>,
) -> Result<hyper::Response<hyper::Body>, HyperError> {
) -> Result<hyper::Response<hyper::Body>, Infallible> {
let ProcessValidatedRequest {
request,
logger,
......@@ -825,7 +837,7 @@ async fn process_health_request<L: Logger>(
max_response_body_size: u32,
request_start: L::Instant,
max_log_length: u32,
) -> Result<hyper::Response<hyper::Body>, HyperError> {
) -> Result<hyper::Response<hyper::Body>, Infallible> {
let trace = RpcTracing::method_call(&health_api.method);
async {
tx_log_from_str("HTTP health API", max_log_length);
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment