use futures_channel::mpsc::{self, Receiver, Sender}; use futures_util::{ future::FutureExt, io::{BufReader, BufWriter}, pin_mut, select, sink::SinkExt, stream::{self, StreamExt}, }; use serde::{Deserialize, Serialize}; use soketto::handshake; use soketto::handshake::{server::Response, Server}; use std::net::SocketAddr; use std::time::Duration; use tokio::net::TcpStream; use tokio_util::compat::{Compat, TokioAsyncReadCompatExt}; pub use hyper::{Body, HeaderMap, StatusCode, Uri}; type Error = Box; pub struct TestContext; impl TestContext { pub fn ok(&self) -> Result<(), anyhow::Error> { Ok(()) } pub fn err(&self) -> Result<(), anyhow::Error> { Err(anyhow::anyhow!("RPC context failed")) } } /// Request Id #[derive(Debug, PartialEq, Clone, Hash, Eq, Deserialize, Serialize)] #[serde(deny_unknown_fields)] #[serde(untagged)] pub enum Id { /// No id (notification) Null, /// Numeric id Num(u64), /// String id Str(String), } #[derive(Debug)] pub struct HttpResponse { pub status: StatusCode, pub header: HeaderMap, pub body: String, } /// WebSocket client to construct with arbitrary payload to construct bad payloads. pub struct WebSocketTestClient { tx: soketto::Sender>>>, rx: soketto::Receiver>>>, } impl WebSocketTestClient { pub async fn new(url: SocketAddr) -> Result { let socket = TcpStream::connect(url).await?; let mut client = handshake::Client::new(BufReader::new(BufWriter::new(socket.compat())), "test-client", "/"); match client.handshake().await { Ok(handshake::ServerResponse::Accepted { .. }) => { let (tx, rx) = client.into_builder().finish(); Ok(Self { tx, rx }) } r => Err(format!("WebSocketHandshake failed: {:?}", r).into()), } } pub async fn send_request_text(&mut self, msg: impl AsRef) -> Result { self.tx.send_text(msg).await?; self.tx.flush().await?; let mut data = Vec::new(); self.rx.receive_data(&mut data).await?; String::from_utf8(data).map_err(Into::into) } pub async fn send_request_binary(&mut self, msg: &[u8]) -> Result { self.tx.send_binary(msg).await?; self.tx.flush().await?; let mut data = Vec::new(); self.rx.receive_data(&mut data).await?; String::from_utf8(data).map_err(Into::into) } pub async fn close(&mut self) -> Result<(), Error> { self.tx.close().await.map_err(Into::into) } } #[derive(Debug, Clone)] pub enum ServerMode { // Send out a hardcoded response on every connection. Response(String), // Send out a subscription ID on a request and continuously send out data on the subscription. Subscription { subscription_id: String, subscription_response: String }, } /// JSONRPC v2 dummy WebSocket server that sends a hardcoded response. pub struct WebSocketTestServer { local_addr: SocketAddr, exit: Sender<()>, } impl WebSocketTestServer { // Spawns a dummy `JSONRPC v2` WebSocket server that sends out a pre-configured `hardcoded response` for every connection. pub async fn with_hardcoded_response(sockaddr: SocketAddr, response: String) -> Self { let listener = async_std::net::TcpListener::bind(sockaddr).await.unwrap(); let local_addr = listener.local_addr().unwrap(); let (tx, rx) = mpsc::channel::<()>(4); tokio::spawn(server_backend(listener, rx, ServerMode::Response(response))); Self { local_addr, exit: tx } } // Spawns a dummy `JSONRPC v2` WebSocket server that sends out a pre-configured subscription ID and subscription response. // // NOTE: ignores the actual subscription and unsubscription method. pub async fn with_hardcoded_subscription( sockaddr: SocketAddr, subscription_id: String, subscription_response: String, ) -> Self { let listener = async_std::net::TcpListener::bind(sockaddr).await.unwrap(); let local_addr = listener.local_addr().unwrap(); let (tx, rx) = mpsc::channel::<()>(4); tokio::spawn(server_backend(listener, rx, ServerMode::Subscription { subscription_id, subscription_response })); Self { local_addr, exit: tx } } pub fn local_addr(&self) -> SocketAddr { self.local_addr } pub async fn close(&mut self) { self.exit.send(()).await.unwrap(); } } async fn server_backend(listener: async_std::net::TcpListener, mut exit: Receiver<()>, mode: ServerMode) { let mut connections = Vec::new(); loop { let conn_fut = listener.accept().fuse(); let exit_fut = exit.next(); pin_mut!(exit_fut, conn_fut); select! { _ = exit_fut => break, conn = conn_fut => { if let Ok((stream, _)) = conn { let (tx, rx) = mpsc::channel::<()>(4); let handle = tokio::spawn(connection_task(stream, mode.clone(), rx)); connections.push((handle, tx)); } } } } // close connections for (handle, mut exit) in connections { // If the actual connection was never established i.e., returned early // It will most likely be caught on the client-side but just to be explicit. exit.send(()).await.expect("WebSocket connection was never established"); handle.await.unwrap(); } } async fn connection_task(socket: async_std::net::TcpStream, mode: ServerMode, mut exit: Receiver<()>) { let mut server = Server::new(socket); let websocket_key = match server.receive_request().await { Ok(req) => req.into_key(), Err(_) => return, }; let accept = server.send_response(&Response::Accept { key: &websocket_key, protocol: None }).await; if accept.is_err() { return; } let (mut sender, receiver) = server.into_builder().finish(); let ws_stream = stream::unfold(receiver, move |mut receiver| async { let mut buf = Vec::new(); let ret = match receiver.receive_data(&mut buf).await { Ok(_) => Ok(buf), Err(err) => Err(err), }; Some((ret, receiver)) }); pin_mut!(ws_stream); loop { let next_ws = ws_stream.next().fuse(); let next_exit = exit.next().fuse(); let time_out = tokio::time::sleep(Duration::from_secs(1)).fuse(); pin_mut!(time_out, next_exit, next_ws); select! { _ = time_out => { if let ServerMode::Subscription { subscription_response, .. } = &mode { if let Err(e) = sender.send_text(&subscription_response).await { log::warn!("send response to subscription: {:?}", e); } } } ws = next_ws => { // Got a request on the connection but don't care about the contents. // Just send out the pre-configured hardcoded responses. if let Some(Ok(_)) = ws { match &mode { ServerMode::Response(r) => { if let Err(e) = sender.send_text(&r).await { log::warn!("send response to request error: {:?}", e); } } ServerMode::Subscription { subscription_id, .. } => { if let Err(e) = sender.send_text(&subscription_id).await { log::warn!("send subscription id error: {:?}", e); } } } } } _ = next_exit => break, } } }