Newer
Older
// Copyright 2019-2021 Parity Technologies (UK) Ltd.
//
// Permission is hereby granted, free of charge, to any
// person obtaining a copy of this software and associated
// documentation files (the "Software"), to deal in the
// Software without restriction, including without
// limitation the rights to use, copy, modify, merge,
// publish, distribute, sublicense, and/or sell copies of
// the Software, and to permit persons to whom the Software
// is furnished to do so, subject to the following
// conditions:
//
// The above copyright notice and this permission notice
// shall be included in all copies or substantial portions
// of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF
// ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED
// TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A
// PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT
// SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
// CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
// OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR
// IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
// DEALINGS IN THE SOFTWARE.
use std::future::Future;
use std::pin::Pin;
use std::task::{Context, Poll};
use crate::future::{FutureDriver, StopHandle, StopMonitor};
use crate::types::{
error::Error,
v2::{ErrorCode, Id, Request},
TEN_MB_SIZE_BYTES,
};
use futures_channel::mpsc;
use soketto::handshake::{server::Response, Server as SokettoServer};
use tokio::net::{TcpListener, TcpStream, ToSocketAddrs};
use tokio_util::compat::{Compat, TokioAsyncReadCompatExt};
David
committed
use jsonrpsee_utils::server::helpers::{collect_batch_response, prepare_error, send_error};
use jsonrpsee_utils::server::rpc_module::{ConnectionId, Methods};
/// Default maximum number of connections the server keeps open at the same time.
///
/// Used as the initial value of `Settings::max_connections`; override it with
/// [`Builder::max_connections`]. Once this many connections are open, further
/// handshakes are rejected with HTTP 429.
const MAX_CONNECTIONS: u64 = 100;
/// A WebSocket JSON RPC server.
#[derive(Debug)]
pub struct Server {
listener: TcpListener,
/// Returns socket address to which the server is bound.
pub fn local_addr(&self) -> Result<SocketAddr, Error> {
/// Returns the handle to stop the running server.
pub fn stop_handle(&self) -> StopHandle {
/// Start responding to connections requests. This will block current thread until the server is stopped.
pub async fn start(self, methods: impl Into<Methods>) {
let mut connections = FutureDriver::default();
let mut incoming = Incoming::new(self.listener, &stop_monitor);
match connections.select_with(&mut incoming).await {
Ok((socket, _addr)) => {
if let Err(e) = socket.set_nodelay(true) {
log::error!("Could not set NODELAY on socket: {:?}", e);
continue;
if connections.count() >= self.cfg.max_connections as usize {
log::warn!("Too many connections. Try again in a while.");
connections.add(Box::pin(handshake(socket, HandshakeResponse::Reject { status_code: 429 })));
let methods = &methods;
let cfg = &self.cfg;
connections.add(Box::pin(handshake(
socket,
HandshakeResponse::Accept { conn_id: id, methods, cfg, stop_monitor: &stop_monitor },
)));
id = id.wrapping_add(1);
}
Err(IncomingError::Io(err)) => {
log::error!("Error while awaiting a new connection: {:?}", err);
}
Err(IncomingError::Shutdown) => break,
/// This is a glorified select listening to new connections, while also checking
/// for `stop_receiver` signal.
listener: TcpListener,
impl<'a> Incoming<'a> {
fn new(listener: TcpListener, stop_monitor: &'a StopMonitor) -> Self {
Incoming { listener, stop_monitor }
Shutdown,
Io(std::io::Error),
}
type Output = Result<(TcpStream, SocketAddr), IncomingError>;
fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
let this = Pin::into_inner(self);
return Poll::Ready(Err(IncomingError::Shutdown));
this.listener.poll_accept(cx).map_err(IncomingError::Io)
/// Tells `handshake` how to answer an incoming WebSocket upgrade request.
enum HandshakeResponse<'a> {
	/// Refuse the connection with this HTTP status code without reading the request,
	/// then close the socket.
	Reject { status_code: u16 },
	/// Check the request's `Host` and `Origin` headers against `cfg` (answering 403 on a
	/// mismatch) and, if they pass, serve the connection with `methods` on a spawned task.
	Accept { conn_id: ConnectionId, methods: &'a Methods, cfg: &'a Settings, stop_monitor: &'a StopMonitor },
}
async fn handshake(socket: tokio::net::TcpStream, mode: HandshakeResponse<'_>) -> Result<(), Error> {
// For each incoming connection we perform a handshake.
let mut server = SokettoServer::new(BufReader::new(BufWriter::new(socket.compat())));
match mode {
HandshakeResponse::Reject { status_code } => {
// Forced rejection, don't need to read anything from the socket
let reject = Response::Reject { status_code };
server.send_response(&reject).await?;
let (mut sender, _) = server.into_builder().finish();
// Gracefully shut down the connection
sender.close().await?;
HandshakeResponse::Accept { conn_id, methods, cfg, stop_monitor } => {
let key = {
let req = server.receive_request().await?;
let host_check = cfg.allowed_hosts.verify("Host", Some(req.headers().host));
let origin_check = cfg.allowed_origins.verify("Origin", req.headers().origin);
host_check.and(origin_check).map(|()| req.key())
};
match key {
Ok(key) => {
let accept = Response::Accept { key, protocol: None };
server.send_response(&accept).await?;
}
Err(error) => {
let reject = Response::Reject { status_code: 403 };
server.send_response(&reject).await?;
return Err(error);
}
}
let join_result = tokio::spawn(background_task(
server,
conn_id,
methods.clone(),
cfg.max_request_body_size,
stop_monitor.clone(),
))
.await;
match join_result {
Err(_) => Err(Error::Custom("Background task was aborted".into())),
Ok(result) => result,
}
}
}
}
async fn background_task(
server: SokettoServer<'_, BufReader<BufWriter<Compat<tokio::net::TcpStream>>>>,
conn_id: ConnectionId,
methods: Methods,
max_request_body_size: u32,
) -> Result<(), Error> {
// And we can finally transition to a websocket connection.
let (mut sender, mut receiver) = server.into_builder().finish();
let (tx, mut rx) = mpsc::unbounded::<String>();
Niklas Adolfsson
committed
match rx.next().await {
Some(response) => {
log::debug!("send: {}", response);
let _ = sender.send_text(response).await;
let _ = sender.flush().await;
}
None => break,
};
Niklas Adolfsson
committed
// terminate connection.
let _ = sender.close().await;
let mut method_executors = FutureDriver::default();
method_executors.select_with(receiver.receive_data(&mut data)).await?;
if data.len() > max_request_body_size as usize {
log::warn!("Request is too big ({} bytes, max is {})", data.len(), max_request_body_size);
send_error(Id::Null, &tx, ErrorCode::OversizedRequest.into());
David
committed
match data.get(0) {
Some(b'{') => {
if let Ok(req) = serde_json::from_slice::<Request>(&data) {
David
committed
log::debug!("recv: {:?}", req);
if let Some(fut) = methods.execute(&tx, req, conn_id) {
method_executors.add(fut);
}
David
committed
} else {
let (id, code) = prepare_error(&data);
send_error(id, &tx, code.into());
David
committed
Some(b'[') => {
if let Ok(batch) = serde_json::from_slice::<Vec<Request>>(&data) {
David
committed
if !batch.is_empty() {
// Batch responses must be sent back as a single message so we read the results from each
// request in the batch and read the results off of a new channel, `rx_batch`, and then send the
// complete batch response back to the client over `tx`.
David
committed
let (tx_batch, mut rx_batch) = mpsc::unbounded::<String>();
for fut in batch.into_iter().filter_map(|req| methods.execute(&tx_batch, req, conn_id)) {
method_executors.add(fut);
}
David
committed
// Closes the receiving half of a channel without dropping it. This prevents any further
// messages from being sent on the channel.
David
committed
rx_batch.close();
let results = collect_batch_response(rx_batch).await;
if let Err(err) = tx.unbounded_send(results) {
log::error!("Error sending batch response to the client: {:?}", err)
}
} else {
send_error(Id::Null, &tx, ErrorCode::InvalidRequest.into());
David
committed
}
} else {
let (id, code) = prepare_error(&data);
send_error(id, &tx, code.into());
}
}
_ => send_error(Id::Null, &tx, ErrorCode::ParseError.into()),
// Drive all running methods to completion
method_executors.await;
// Drop the monitor for this task since we are shutting down
drop(stop_monitor);
Niklas Adolfsson
committed
Ok(())
impl AllowedValue {
fn verify(&self, header: &str, value: Option<&[u8]>) -> Result<(), Error> {
if let (AllowedValue::OneOf(list), Some(value)) = (self, value) {
if !list.iter().any(|o| o.as_bytes() == value) {
let error = format!("{} denied: {}", header, String::from_utf8_lossy(value));
log::warn!("{}", error);
return Err(Error::Request(error));
}
}
Ok(())
}
}
/// Server configuration, filled in through [`Builder`] and handed to [`Server`] on build.
///
/// NOTE(review): `Builder` and `Server` both `#[derive(Debug)]` and hold a `Settings`, so this
/// type needs a `Debug` impl (which in turn needs `AllowedValue: Debug`); none is visible
/// here — confirm one exists.
struct Settings {
	/// Maximum size in bytes of a request. Larger requests are answered with an
	/// `OversizedRequest` error instead of being parsed.
	max_request_body_size: u32,
	/// Maximum number of incoming connections allowed. Once this many are open, new
	/// handshakes are rejected with HTTP 429.
	max_connections: u64,
	/// Policy by which to accept or deny incoming requests based on the `Origin` header
	/// (a denied handshake is answered with HTTP 403).
	allowed_origins: AllowedValue,
	/// Policy by which to accept or deny incoming requests based on the `Host` header
	/// (a denied handshake is answered with HTTP 403).
	allowed_hosts: AllowedValue,
}
impl Default for Settings {
fn default() -> Self {
Self {
max_request_body_size: TEN_MB_SIZE_BYTES,
max_connections: MAX_CONNECTIONS,
allowed_origins: AllowedValue::Any,
allowed_hosts: AllowedValue::Any,
}
}
/// Builder to configure and create a JSON-RPC WebSocket server.
///
/// Setters take the builder by value and hand it back (the allow-list setters wrap it in a
/// `Result`), so calls can be chained. Finish with [`Builder::build`], which binds the
/// listening socket.
#[derive(Debug)]
pub struct Builder {
	/// Configuration that [`Builder::build`] moves into the resulting [`Server`].
	settings: Settings,
}
impl Builder {
/// Set the maximum size of a request body in bytes. Default is 10 MiB.
pub fn max_request_body_size(mut self, size: u32) -> Self {
self.settings.max_request_body_size = size;
self
}
/// Set the maximum number of connections allowed. Default is 100.
pub fn max_connections(mut self, max: u64) -> Self {
self.settings.max_connections = max;
self
}
/// Set a list of allowed origins. During the handshake, the `Origin` header will be
/// checked against the list, connections without a matching origin will be denied.
/// Values should be hostnames with protocol.
///
/// ```rust
/// # let mut builder = jsonrpsee_ws_server::WsServerBuilder::default();
/// builder.set_allowed_origins(["https://example.com"]);
/// ```
///
/// By default allows any `Origin`.
///
/// Will return an error if `list` is empty. Use [`allow_all_origins`](Builder::allow_all_origins) to restore the
/// default.
pub fn set_allowed_origins<Origin, List>(mut self, list: List) -> Result<Self, Error>
where
List: IntoIterator<Item = Origin>,
Origin: Into<String>,
{
let list: Box<_> = list.into_iter().map(Into::into).collect();
return Err(Error::EmptyAllowList("Origin"));
self.settings.allowed_origins = AllowedValue::OneOf(list);
Ok(self)
}
/// Restores the default behavior of allowing connections with `Origin` header
/// containing any value. This will undo any list set by [`set_allowed_origins`](Builder::set_allowed_origins).
pub fn allow_all_origins(mut self) -> Self {
self.settings.allowed_origins = AllowedValue::Any;
self
}
/// Set a list of allowed hosts. During the handshake, the `Host` header will be
/// checked against the list. Connections without a matching host will be denied.
/// Values should be hostnames without protocol.
///
/// ```rust
/// # let mut builder = jsonrpsee_ws_server::WsServerBuilder::default();
/// builder.set_allowed_hosts(["example.com"]);
/// ```
///
/// By default allows any `Host`.
///
/// Will return an error if `list` is empty. Use [`allow_all_hosts`](Builder::allow_all_hosts) to restore the
/// default.
pub fn set_allowed_hosts<Host, List>(mut self, list: List) -> Result<Self, Error>
where
List: IntoIterator<Item = Host>,
Host: Into<String>,
{
let list: Box<_> = list.into_iter().map(Into::into).collect();
if list.len() == 0 {
return Err(Error::EmptyAllowList("Host"));
}
self.settings.allowed_hosts = AllowedValue::OneOf(list);
Ok(self)
}
/// Restores the default behavior of allowing connections with `Host` header
/// containing any value. This will undo any list set by [`set_allowed_hosts`](Builder::set_allowed_hosts).
pub fn allow_all_hosts(mut self) -> Self {
self.settings.allowed_hosts = AllowedValue::Any;
/// Finalize the configuration of the server. Consumes the [`Builder`].
pub async fn build(self, addr: impl ToSocketAddrs) -> Result<Server, Error> {
let listener = TcpListener::bind(addr).await?;
Ok(Server { listener, cfg: self.settings, stop_monitor })