// Copyright 2019-2021 Parity Technologies (UK) Ltd.
//
// Permission is hereby granted, free of charge, to any
// person obtaining a copy of this software and associated
// documentation files (the "Software"), to deal in the
// Software without restriction, including without
// limitation the rights to use, copy, modify, merge,
// publish, distribute, sublicense, and/or sell copies of
// the Software, and to permit persons to whom the Software
// is furnished to do so, subject to the following
// conditions:
//
// The above copyright notice and this permission notice
// shall be included in all copies or substantial portions
// of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF
// ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED
// TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A
// PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT
// SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
// CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
// OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR
// IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
// DEALINGS IN THE SOFTWARE.
use futures_channel::mpsc::{self, Receiver, Sender};
use futures_channel::oneshot;
use futures_util::{
future::FutureExt,
io::{BufReader, BufWriter},
pin_mut, select,
sink::SinkExt,
stream::{self, StreamExt},
};
use soketto::handshake::{self, http::is_upgrade_request, server::Response, Error as SokettoError, Server};
use std::{io, net::SocketAddr, time::Duration};
use tokio::net::TcpStream;
use tokio_util::compat::{Compat, TokioAsyncReadCompatExt};
pub use hyper::{Body, HeaderMap, StatusCode, Uri};
/// Boxed catch-all error type returned by the fallible `WebSocketTestClient` helpers in this file.
type Error = Box<dyn std::error::Error>;
/// Stateless dummy context whose methods return a fixed success or a fixed failure.
pub struct TestContext;

impl TestContext {
    /// Always returns `Ok(())`.
    pub fn ok(&self) -> Result<(), anyhow::Error> {
        Ok(())
    }

    /// Always returns an error carrying the message `"RPC context failed"`.
    pub fn err(&self) -> Result<(), anyhow::Error> {
        let failure = anyhow::anyhow!("RPC context failed");
        Err(failure)
    }
}
/// Request identifier.
///
/// The enum is marked `#[serde(untagged)]`, so each variant is (de)serialized as its bare
/// value rather than wrapped in a variant tag.
#[derive(Debug, PartialEq, Clone, Hash, Eq, Deserialize, Serialize)]
#[serde(deny_unknown_fields)]
#[serde(untagged)]
pub enum Id {
/// No id (used for notifications).
Null,
/// Numeric id.
Num(u64),
/// String id.
Str(String),
}
/// Plain-data view of an HTTP response: status code, headers and body text.
#[derive(Debug)]
pub struct HttpResponse {
/// Status code of the response.
pub status: StatusCode,
/// Headers of the response.
pub header: HeaderMap,
/// Body of the response as a `String`.
pub body: String,
}
/// WebSocket test client that sends arbitrary text or binary payloads, which makes it
/// suitable for sending deliberately malformed requests.
pub struct WebSocketTestClient {
/// Sending half of the soketto connection.
tx: soketto::Sender<BufReader<BufWriter<Compat<TcpStream>>>>,
/// Receiving half of the soketto connection.
rx: soketto::Receiver<BufReader<BufWriter<Compat<TcpStream>>>>,
}
impl std::fmt::Debug for WebSocketTestClient {
    /// Writes only the type name; the connection halves are not included in the output.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("WebSocketTestClient")
    }
}
/// Errors that can occur while connecting a `WebSocketTestClient`.
#[derive(Debug)]
pub enum WebSocketTestError {
/// The server answered the handshake with a redirect.
Redirect,
/// The server rejected the handshake with the contained status code.
RejectedWithStatusCode(u16),
/// Error reported by soketto; `io::Error`s are also wrapped here through the `From` impl.
Soketto(SokettoError),
}
impl From<io::Error> for WebSocketTestError {
    /// Wraps an I/O error as `WebSocketTestError::Soketto(SokettoError::Io(..))`, which lets
    /// `?` be used on I/O results in functions returning `WebSocketTestError`.
    fn from(io_err: io::Error) -> Self {
        Self::Soketto(SokettoError::Io(io_err))
    }
}
pub async fn new(url: SocketAddr) -> Result<Self, WebSocketTestError> {
let socket = TcpStream::connect(url).await?;
let mut client = handshake::Client::new(BufReader::new(BufWriter::new(socket.compat())), "test-client", "/");
match client.handshake().await {
Ok(handshake::ServerResponse::Accepted { .. }) => {
let (tx, rx) = client.into_builder().finish();
Ok(Self { tx, rx })
}
Ok(handshake::ServerResponse::Redirect { .. }) => Err(WebSocketTestError::Redirect),
Ok(handshake::ServerResponse::Rejected { status_code }) => {
Err(WebSocketTestError::RejectedWithStatusCode(status_code))
Err(err) => Err(WebSocketTestError::Soketto(err)),
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
}
}
pub async fn send_request_text(&mut self, msg: impl AsRef<str>) -> Result<String, Error> {
self.tx.send_text(msg).await?;
self.tx.flush().await?;
let mut data = Vec::new();
self.rx.receive_data(&mut data).await?;
String::from_utf8(data).map_err(Into::into)
}
pub async fn send_request_binary(&mut self, msg: &[u8]) -> Result<String, Error> {
self.tx.send_binary(msg).await?;
self.tx.flush().await?;
let mut data = Vec::new();
self.rx.receive_data(&mut data).await?;
String::from_utf8(data).map_err(Into::into)
}
pub async fn close(&mut self) -> Result<(), Error> {
self.tx.close().await.map_err(Into::into)
}
}
/// What the dummy WebSocket server sends on each connection.
#[derive(Debug, Clone)]
pub enum ServerMode {
/// Reply with this hardcoded response to every message received on a connection.
Response(String),
/// Reply to a received message with `subscription_id` and keep sending
/// `subscription_response` on the subscription.
Subscription { subscription_id: String, subscription_response: String },
/// Send out this notification after a timeout.
Notification(String),
}
/// JSONRPC v2 dummy WebSocket server that sends pre-configured, hardcoded data.
pub struct WebSocketTestServer {
/// Address the server's listener is bound to.
local_addr: SocketAddr,
/// Sending half of the shutdown channel; `close` sends `()` on it.
exit: Sender<()>,
}
impl WebSocketTestServer {
// Spawns a dummy `JSONRPC v2` WebSocket server that sends out a pre-configured `hardcoded response` for every
// connection.
pub async fn with_hardcoded_response(sockaddr: SocketAddr, response: String) -> Self {
let listener = tokio::net::TcpListener::bind(sockaddr).await.unwrap();
let local_addr = listener.local_addr().unwrap();
let (tx, rx) = mpsc::channel::<()>(4);
tokio::spawn(server_backend(listener, rx, ServerMode::Response(response)));
Self { local_addr, exit: tx }
}
// Spawns a dummy `JSONRPC v2` WebSocket server that sends out a pre-configured `hardcoded notification` for every
// connection.
pub async fn with_hardcoded_notification(sockaddr: SocketAddr, notification: String) -> Self {
let (tx, rx) = mpsc::channel::<()>(1);
let (addr_tx, addr_rx) = oneshot::channel();
std::thread::spawn(move || {
let rt = tokio::runtime::Runtime::new().unwrap();
let listener = rt.block_on(tokio::net::TcpListener::bind(sockaddr)).unwrap();
let local_addr = listener.local_addr().unwrap();
addr_tx.send(local_addr).unwrap();
rt.block_on(server_backend(listener, rx, ServerMode::Notification(notification)));
});
let local_addr = addr_rx.await.unwrap();
Self { local_addr, exit: tx }
}
// Spawns a dummy `JSONRPC v2` WebSocket server that sends out a pre-configured subscription ID and subscription
// response.
//
// NOTE: ignores the actual subscription and unsubscription method.
pub async fn with_hardcoded_subscription(
sockaddr: SocketAddr,
subscription_id: String,
subscription_response: String,
) -> Self {
let listener = tokio::net::TcpListener::bind(sockaddr).await.unwrap();
let local_addr = listener.local_addr().unwrap();
let (tx, rx) = mpsc::channel::<()>(4);
tokio::spawn(server_backend(listener, rx, ServerMode::Subscription { subscription_id, subscription_response }));
Self { local_addr, exit: tx }
}
pub fn local_addr(&self) -> SocketAddr {
self.local_addr
}
pub async fn close(&mut self) {
self.exit.send(()).await.unwrap();
}
}
async fn server_backend(listener: tokio::net::TcpListener, mut exit: Receiver<()>, mode: ServerMode) {
let mut connections = Vec::new();
loop {
let conn_fut = listener.accept().fuse();
let exit_fut = exit.next();
pin_mut!(exit_fut, conn_fut);
_ = exit_fut => break,
conn = conn_fut => {
if let Ok((stream, _)) = conn {
let (tx, rx) = mpsc::channel::<()>(4);
let handle = tokio::spawn(connection_task(stream, mode.clone(), rx));
connections.push((handle, tx));
}
}
}
}
// close connections
for (handle, mut exit) in connections {
// If the actual connection was never established i.e., returned early
// It will most likely be caught on the client-side but just to be explicit.
exit.send(()).await.expect("WebSocket connection was never established");
handle.await.unwrap();
}
}
async fn connection_task(socket: tokio::net::TcpStream, mode: ServerMode, mut exit: Receiver<()>) {
let mut server = Server::new(socket.compat());
let key = match server.receive_request().await {
Ok(req) => req.key(),
let accept = server.send_response(&Response::Accept { key, protocol: None }).await;
if accept.is_err() {
return;
}
let (mut sender, receiver) = server.into_builder().finish();
let ws_stream = stream::unfold(receiver, move |mut receiver| async {
let mut buf = Vec::new();
let ret = match receiver.receive_data(&mut buf).await {
Ok(_) => Ok(buf),
Err(err) => Err(err),
};
Some((ret, receiver))
});
loop {
let next_ws = ws_stream.next().fuse();
let next_exit = exit.next().fuse();
let time_out = tokio::time::sleep(Duration::from_millis(200)).fuse();
pin_mut!(time_out, next_exit, next_ws);
match &mode {
ServerMode::Subscription { subscription_response, .. } => {
if let Err(e) = sender.send_text(&subscription_response).await {
log::warn!("send response to subscription: {:?}", e);
}
},
ServerMode::Notification(n) => {
if let Err(e) = sender.send_text(&n).await {
log::warn!("send notification: {:?}", e);
}
},
_ => {}
}
}
ws = next_ws => {
// Got a request on the connection but don't care about the contents.
// Just send out the pre-configured hardcoded responses.
if let Some(Ok(_)) = ws {
match &mode {
ServerMode::Response(r) => {
if let Err(e) = sender.send_text(&r).await {
log::warn!("send response to request error: {:?}", e);
}
ServerMode::Subscription { subscription_id, .. } => {
if let Err(e) = sender.send_text(&subscription_id).await {
log::warn!("send subscription id error: {:?}", e);
}
}
}
}
_ = next_exit => break,
}
}
}
// Run a WebSocket server running on localhost that redirects requests for testing.
// Requests to any url except for `/myblock/two` will redirect one or two times (HTTP 301) and eventually end up in `/myblock/two`.
pub fn ws_server_with_redirect(other_server: String) -> String {
let addr = ([127, 0, 0, 1], 0).into();
let service = hyper::service::make_service_fn(move |_| {
let other_server = other_server.clone();
async move {
Ok::<_, hyper::Error>(hyper::service::service_fn(move |req| {
let other_server = other_server.clone();
async move { handler(req, other_server).await }
}))
}
});
let server = hyper::Server::bind(&addr).serve(service);
let addr = server.local_addr();
tokio::spawn(async move { server.await });
format!("ws://{}", addr)
}
/// Answers an incoming WebSocket upgrade request with an HTTP 301 redirect.
///
/// The `Location` header depends on the request path: `/myblock/two` redirects to
/// `other_server`, `/myblock/one` redirects to the relative location `two`, and every other
/// path redirects to `/myblock/one`.
///
/// # Panics
///
/// Panics if `req` is not an upgrade request, or if the response cannot be built.
async fn handler(
    req: hyper::Request<Body>,
    other_server: String,
) -> Result<hyper::Response<Body>, soketto::BoxedError> {
    if !is_upgrade_request(&req) {
        panic!("expect upgrade to WS");
    }
    log::debug!("{:?}", req);

    let location = match req.uri().path() {
        "/myblock/two" => other_server.as_str(),
        "/myblock/one" => "two",
        _ => "/myblock/one",
    };

    let redirect = hyper::Response::builder().status(301).header("Location", location).body(Body::empty()).unwrap();
    Ok(redirect)
}