diff --git a/substrate/bin/node/cli/tests/benchmark_pallet_works.rs b/substrate/bin/node/cli/tests/benchmark_pallet_works.rs index 1a278f985da4fcb5649695f5f29f35fafa70ae53..053516b9521d18174af3adf0f979d5f6fd1d8415 100644 --- a/substrate/bin/node/cli/tests/benchmark_pallet_works.rs +++ b/substrate/bin/node/cli/tests/benchmark_pallet_works.rs @@ -16,6 +16,8 @@ // You should have received a copy of the GNU General Public License // along with this program. If not, see <https://www.gnu.org/licenses/>. +#![cfg(feature = "runtime-benchmarks")] + use assert_cmd::cargo::cargo_bin; use std::process::Command; diff --git a/substrate/bin/node/cli/tests/benchmark_storage_works.rs b/substrate/bin/node/cli/tests/benchmark_storage_works.rs index 1f1181218db03fe65ff6a6c5171b3b0d60b386c0..953c07ca7f0dbf20d923cfccd6e1a7badfe85382 100644 --- a/substrate/bin/node/cli/tests/benchmark_storage_works.rs +++ b/substrate/bin/node/cli/tests/benchmark_storage_works.rs @@ -16,6 +16,8 @@ // You should have received a copy of the GNU General Public License // along with this program. If not, see <https://www.gnu.org/licenses/>. +#![cfg(feature = "runtime-benchmarks")] + use assert_cmd::cargo::cargo_bin; use std::{ path::Path, diff --git a/substrate/bin/node/cli/tests/common.rs b/substrate/bin/node/cli/tests/common.rs index a0de7300fca90b30d18e836ca218d8c47bb1c779..4c4824391fdc17be26082cfecaa40d88a9632b5b 100644 --- a/substrate/bin/node/cli/tests/common.rs +++ b/substrate/bin/node/cli/tests/common.rs @@ -20,54 +20,26 @@ use assert_cmd::cargo::cargo_bin; use nix::{ - sys::signal::{kill, Signal::SIGINT}, + sys::signal::{kill, Signal, Signal::SIGINT}, unistd::Pid, }; use node_primitives::{Hash, Header}; +use regex::Regex; use std::{ io::{BufRead, BufReader, Read}, ops::{Deref, DerefMut}, - path::Path, - process::{self, Child, Command, ExitStatus}, + path::{Path, PathBuf}, + process::{self, Child, Command}, time::Duration, }; -use tokio::time::timeout; -/// Wait for the given `child` the given number of `secs`. -/// -/// Returns the `Some(exit status)` or `None` if the process did not finish in the given time. -pub fn wait_for(child: &mut Child, secs: u64) -> Result<ExitStatus, ()> { - let result = wait_timeout::ChildExt::wait_timeout(child, Duration::from_secs(5.min(secs))) - .map_err(|_| ())?; - if let Some(exit_status) = result { - Ok(exit_status) - } else { - if secs > 5 { - eprintln!("Child process taking over 5 seconds to exit gracefully"); - let result = wait_timeout::ChildExt::wait_timeout(child, Duration::from_secs(secs - 5)) - .map_err(|_| ())?; - if let Some(exit_status) = result { - return Ok(exit_status) - } - } - eprintln!("Took too long to exit (> {} seconds). Killing...", secs); - let _ = child.kill(); - child.wait().unwrap(); - Err(()) - } -} - -/// Wait for at least n blocks to be finalized within a specified time. -pub async fn wait_n_finalized_blocks( - n: usize, - timeout_secs: u64, - url: &str, -) -> Result<(), tokio::time::error::Elapsed> { - timeout(Duration::from_secs(timeout_secs), wait_n_finalized_blocks_from(n, url)).await +/// Run the given `future` and panic if the `timeout` is hit. +pub async fn run_with_timeout(timeout: Duration, future: impl futures::Future<Output = ()>) { + tokio::time::timeout(timeout, future).await.expect("Hit timeout"); } /// Wait for at least n blocks to be finalized from a specified node -pub async fn wait_n_finalized_blocks_from(n: usize, url: &str) { +pub async fn wait_n_finalized_blocks(n: usize, url: &str) { use substrate_rpc_client::{ws_client, ChainApi}; let mut built_blocks = std::collections::HashSet::new(); @@ -87,46 +59,54 @@ pub async fn wait_n_finalized_blocks_from(n: usize, url: &str) { /// Run the node for a while (3 blocks) pub async fn run_node_for_a_while(base_path: &Path, args: &[&str]) { - let mut cmd = Command::new(cargo_bin("substrate")) - .stdout(process::Stdio::piped()) - .stderr(process::Stdio::piped()) - .args(args) - .arg("-d") - .arg(base_path) - .spawn() - .unwrap(); + run_with_timeout(Duration::from_secs(60 * 10), async move { + let mut cmd = Command::new(cargo_bin("substrate")) + .stdout(process::Stdio::piped()) + .stderr(process::Stdio::piped()) + .args(args) + .arg("-d") + .arg(base_path) + .spawn() + .unwrap(); - let stderr = cmd.stderr.take().unwrap(); + let stderr = cmd.stderr.take().unwrap(); - let mut child = KillChildOnDrop(cmd); + let mut child = KillChildOnDrop(cmd); - let (ws_url, _) = find_ws_url_from_output(stderr); + let ws_url = extract_info_from_output(stderr).0.ws_url; - // Let it produce some blocks. - let _ = wait_n_finalized_blocks(3, 30, &ws_url).await; + // Let it produce some blocks. + wait_n_finalized_blocks(3, &ws_url).await; - assert!(child.try_wait().unwrap().is_none(), "the process should still be running"); + child.assert_still_running(); - // Stop the process - kill(Pid::from_raw(child.id().try_into().unwrap()), SIGINT).unwrap(); - assert!(wait_for(&mut child, 40).map(|x| x.success()).unwrap()); + // Stop the process + child.stop(); + }) + .await } -/// Run the node asserting that it fails with an error -pub fn run_node_assert_fail(base_path: &Path, args: &[&str]) { - let mut cmd = Command::new(cargo_bin("substrate")); +pub struct KillChildOnDrop(pub Child); - let mut child = KillChildOnDrop(cmd.args(args).arg("-d").arg(base_path).spawn().unwrap()); +impl KillChildOnDrop { + /// Stop the child and wait until it is finished. + /// + /// Asserts if the exit status isn't success. + pub fn stop(&mut self) { + self.stop_with_signal(SIGINT); + } - // Let it produce some blocks, but it should die within 10 seconds. - assert_ne!( - wait_timeout::ChildExt::wait_timeout(&mut *child, Duration::from_secs(10)).unwrap(), - None, - "the process should not be running anymore" - ); -} + /// Same as [`Self::stop`] but takes the `signal` that is sent to stop the child. + pub fn stop_with_signal(&mut self, signal: Signal) { + kill(Pid::from_raw(self.id().try_into().unwrap()), signal).unwrap(); + assert!(self.wait().unwrap().success()); + } -pub struct KillChildOnDrop(pub Child); + /// Asserts that the child is still running. + pub fn assert_still_running(&mut self) { + assert!(self.try_wait().unwrap().is_none(), "the process should still be running"); + } +} impl Drop for KillChildOnDrop { fn drop(&mut self) { @@ -148,18 +128,22 @@ impl DerefMut for KillChildOnDrop { } } -/// Read the WS address from the output. +/// Information extracted from a running node. +pub struct NodeInfo { + pub ws_url: String, + pub db_path: PathBuf, +} + +/// Extract [`NodeInfo`] from a running node by parsing its output. /// -/// This is hack to get the actual binded sockaddr because -/// substrate assigns a random port if the specified port was already binded. -pub fn find_ws_url_from_output(read: impl Read + Send) -> (String, String) { +/// Returns the [`NodeInfo`] and all the read data. +pub fn extract_info_from_output(read: impl Read + Send) -> (NodeInfo, String) { let mut data = String::new(); let ws_url = BufReader::new(read) .lines() .find_map(|line| { - let line = - line.expect("failed to obtain next line from stdout for WS address discovery"); + let line = line.expect("failed to obtain next line while extracting node info"); data.push_str(&line); data.push_str("\n"); @@ -176,5 +160,9 @@ pub fn find_ws_url_from_output(read: impl Read + Send) -> (String, String) { panic!("We should get a WebSocket address") }); - (ws_url, data) + // Database path is printed before the ws url! + let re = Regex::new(r"Database: .+ at (\S+)").unwrap(); + let db_path = PathBuf::from(re.captures(data.as_str()).unwrap().get(1).unwrap().as_str()); + + (NodeInfo { ws_url, db_path }, data) } diff --git a/substrate/bin/node/cli/tests/running_the_node_and_interrupt.rs b/substrate/bin/node/cli/tests/running_the_node_and_interrupt.rs index fc0bf69a099ba659876a4c890434892ef5a34fdb..3d5598f3fbe23584f19b9f0fb66686b08b58fdf2 100644 --- a/substrate/bin/node/cli/tests/running_the_node_and_interrupt.rs +++ b/substrate/bin/node/cli/tests/running_the_node_and_interrupt.rs @@ -18,94 +18,81 @@ #![cfg(unix)] use assert_cmd::cargo::cargo_bin; -use nix::{ - sys::signal::{ - kill, - Signal::{self, SIGINT, SIGTERM}, - }, - unistd::Pid, +use nix::sys::signal::Signal::{self, SIGINT, SIGTERM}; +use std::{ + process::{self, Child, Command}, + time::Duration, }; -use std::process::{self, Child, Command}; use tempfile::tempdir; pub mod common; #[tokio::test] async fn running_the_node_works_and_can_be_interrupted() { - async fn run_command_and_kill(signal: Signal) { - let base_path = tempdir().expect("could not create a temp dir"); - let mut cmd = common::KillChildOnDrop( + common::run_with_timeout(Duration::from_secs(60 * 10), async move { + async fn run_command_and_kill(signal: Signal) { + let base_path = tempdir().expect("could not create a temp dir"); + let mut cmd = common::KillChildOnDrop( + Command::new(cargo_bin("substrate")) + .stdout(process::Stdio::piped()) + .stderr(process::Stdio::piped()) + .args(&["--dev", "-d"]) + .arg(base_path.path()) + .arg("--db=paritydb") + .arg("--no-hardware-benchmarks") + .spawn() + .unwrap(), + ); + + let stderr = cmd.stderr.take().unwrap(); + + let ws_url = common::extract_info_from_output(stderr).0.ws_url; + + common::wait_n_finalized_blocks(3, &ws_url).await; + + cmd.assert_still_running(); + + cmd.stop_with_signal(signal); + + // Check if the database was closed gracefully. If it was not, + // there may exist a ref cycle that prevents the Client from being dropped properly. + // + // parity-db only writes the stats file on clean shutdown. + let stats_file = base_path.path().join("chains/dev/paritydb/full/stats.txt"); + assert!(std::path::Path::exists(&stats_file)); + } + + run_command_and_kill(SIGINT).await; + run_command_and_kill(SIGTERM).await; + }) + .await; +} + +#[tokio::test] +async fn running_two_nodes_with_the_same_ws_port_should_work() { + common::run_with_timeout(Duration::from_secs(60 * 10), async move { + fn start_node() -> Child { Command::new(cargo_bin("substrate")) .stdout(process::Stdio::piped()) .stderr(process::Stdio::piped()) - .args(&["--dev", "-d"]) - .arg(base_path.path()) - .arg("--db=paritydb") - .arg("--no-hardware-benchmarks") + .args(&["--dev", "--tmp", "--ws-port=45789", "--no-hardware-benchmarks"]) .spawn() - .unwrap(), - ); - - let stderr = cmd.stderr.take().unwrap(); - - let (ws_url, _) = common::find_ws_url_from_output(stderr); - - common::wait_n_finalized_blocks(3, 30, &ws_url) - .await - .expect("Blocks are produced in time"); - assert!(cmd.try_wait().unwrap().is_none(), "the process should still be running"); - kill(Pid::from_raw(cmd.id().try_into().unwrap()), signal).unwrap(); - assert_eq!( - common::wait_for(&mut cmd, 30).map(|x| x.success()), - Ok(true), - "the process must exit gracefully after signal {}", - signal, - ); - // Check if the database was closed gracefully. If it was not, - // there may exist a ref cycle that prevents the Client from being dropped properly. - // - // parity-db only writes the stats file on clean shutdown. - let stats_file = base_path.path().join("chains/dev/paritydb/full/stats.txt"); - assert!(std::path::Path::exists(&stats_file)); - } - - run_command_and_kill(SIGINT).await; - run_command_and_kill(SIGTERM).await; -} + .unwrap() + } -#[tokio::test] -async fn running_two_nodes_with_the_same_ws_port_should_work() { - fn start_node() -> Child { - Command::new(cargo_bin("substrate")) - .stdout(process::Stdio::piped()) - .stderr(process::Stdio::piped()) - .args(&["--dev", "--tmp", "--ws-port=45789", "--no-hardware-benchmarks"]) - .spawn() - .unwrap() - } - - let mut first_node = common::KillChildOnDrop(start_node()); - let mut second_node = common::KillChildOnDrop(start_node()); - - let stderr = first_node.stderr.take().unwrap(); - let (ws_url, _) = common::find_ws_url_from_output(stderr); - - common::wait_n_finalized_blocks(3, 30, &ws_url).await.unwrap(); - - assert!(first_node.try_wait().unwrap().is_none(), "The first node should still be running"); - assert!(second_node.try_wait().unwrap().is_none(), "The second node should still be running"); - - kill(Pid::from_raw(first_node.id().try_into().unwrap()), SIGINT).unwrap(); - kill(Pid::from_raw(second_node.id().try_into().unwrap()), SIGINT).unwrap(); - - assert_eq!( - common::wait_for(&mut first_node, 30).map(|x| x.success()), - Ok(true), - "The first node must exit gracefully", - ); - assert_eq!( - common::wait_for(&mut second_node, 30).map(|x| x.success()), - Ok(true), - "The second node must exit gracefully", - ); + let mut first_node = common::KillChildOnDrop(start_node()); + let mut second_node = common::KillChildOnDrop(start_node()); + + let stderr = first_node.stderr.take().unwrap(); + let ws_url = common::extract_info_from_output(stderr).0.ws_url; + + common::wait_n_finalized_blocks(3, &ws_url).await; + + first_node.assert_still_running(); + second_node.assert_still_running(); + + first_node.stop(); + second_node.stop(); + }) + .await; } diff --git a/substrate/bin/node/cli/tests/telemetry.rs b/substrate/bin/node/cli/tests/telemetry.rs index 633cc996ca615cb1c8b64bce44ff6b0dcbd90752..a68746a2c0011028656fbcd6354f3226c990b016 100644 --- a/substrate/bin/node/cli/tests/telemetry.rs +++ b/substrate/bin/node/cli/tests/telemetry.rs @@ -17,78 +17,75 @@ // along with this program. If not, see <https://www.gnu.org/licenses/>. use assert_cmd::cargo::cargo_bin; -use nix::{ - sys::signal::{kill, Signal::SIGINT}, - unistd::Pid, -}; -use std::process; +use std::{process, time::Duration}; + +use crate::common::KillChildOnDrop; pub mod common; pub mod websocket_server; #[tokio::test] async fn telemetry_works() { - let config = websocket_server::Config { - capacity: 1, - max_frame_size: 1048 * 1024, - send_buffer_len: 32, - bind_address: "127.0.0.1:0".parse().unwrap(), - }; - let mut server = websocket_server::WsServer::new(config).await.unwrap(); - - let addr = server.local_addr().unwrap(); - - let server_task = tokio::spawn(async move { - loop { - use websocket_server::Event; - match server.next_event().await { - // New connection on the listener. - Event::ConnectionOpen { address } => { - println!("New connection from {:?}", address); - server.accept(); - }, - - // Received a message from a connection. - Event::BinaryFrame { message, .. } => { - let json: serde_json::Value = serde_json::from_slice(&message).unwrap(); - let object = - json.as_object().unwrap().get("payload").unwrap().as_object().unwrap(); - if matches!(object.get("best"), Some(serde_json::Value::String(_))) { - break - } - }, - - Event::TextFrame { .. } => panic!("Got a TextFrame over the socket, this is a bug"), - - // Connection has been closed. - Event::ConnectionError { .. } => {}, + common::run_with_timeout(Duration::from_secs(60 * 10), async move { + let config = websocket_server::Config { + capacity: 1, + max_frame_size: 1048 * 1024, + send_buffer_len: 32, + bind_address: "127.0.0.1:0".parse().unwrap(), + }; + let mut server = websocket_server::WsServer::new(config).await.unwrap(); + + let addr = server.local_addr().unwrap(); + + let server_task = tokio::spawn(async move { + loop { + use websocket_server::Event; + match server.next_event().await { + // New connection on the listener. + Event::ConnectionOpen { address } => { + println!("New connection from {:?}", address); + server.accept(); + }, + + // Received a message from a connection. + Event::BinaryFrame { message, .. } => { + let json: serde_json::Value = serde_json::from_slice(&message).unwrap(); + let object = + json.as_object().unwrap().get("payload").unwrap().as_object().unwrap(); + if matches!(object.get("best"), Some(serde_json::Value::String(_))) { + break + } + }, + + Event::TextFrame { .. } => + panic!("Got a TextFrame over the socket, this is a bug"), + + // Connection has been closed. + Event::ConnectionError { .. } => {}, + } } - } - }); - - let mut substrate = process::Command::new(cargo_bin("substrate")); - - let mut substrate = substrate - .args(&["--dev", "--tmp", "--telemetry-url"]) - .arg(format!("ws://{} 10", addr)) - .arg("--no-hardware-benchmarks") - .stdout(process::Stdio::piped()) - .stderr(process::Stdio::piped()) - .stdin(process::Stdio::null()) - .spawn() - .unwrap(); - - server_task.await.expect("server task panicked"); - - assert!(substrate.try_wait().unwrap().is_none(), "the process should still be running"); - - // Stop the process - kill(Pid::from_raw(substrate.id().try_into().unwrap()), SIGINT).unwrap(); - assert!(common::wait_for(&mut substrate, 40).map(|x| x.success()).unwrap_or_default()); - - let output = substrate.wait_with_output().unwrap(); - - println!("{}", String::from_utf8(output.stdout).unwrap()); - eprintln!("{}", String::from_utf8(output.stderr).unwrap()); - assert!(output.status.success()); + }); + + let mut substrate = process::Command::new(cargo_bin("substrate")); + + let mut substrate = KillChildOnDrop( + substrate + .args(&["--dev", "--tmp", "--telemetry-url"]) + .arg(format!("ws://{} 10", addr)) + .arg("--no-hardware-benchmarks") + .stdout(process::Stdio::piped()) + .stderr(process::Stdio::piped()) + .stdin(process::Stdio::null()) + .spawn() + .unwrap(), + ); + + server_task.await.expect("server task panicked"); + + substrate.assert_still_running(); + + // Stop the process + substrate.stop(); + }) + .await; } diff --git a/substrate/bin/node/cli/tests/temp_base_path_works.rs b/substrate/bin/node/cli/tests/temp_base_path_works.rs index 4e743f2d3abd443b9d4309615fcbf9fae3880fd8..89e901c00e118a0f60bfddf80bb9580deb523670 100644 --- a/substrate/bin/node/cli/tests/temp_base_path_works.rs +++ b/substrate/bin/node/cli/tests/temp_base_path_works.rs @@ -19,45 +19,42 @@ #![cfg(unix)] use assert_cmd::cargo::cargo_bin; -use nix::{ - sys::signal::{kill, Signal::SIGINT}, - unistd::Pid, -}; -use regex::Regex; use std::{ - io::Read, - path::PathBuf, process::{Command, Stdio}, + time::Duration, }; pub mod common; #[tokio::test] async fn temp_base_path_works() { - let mut cmd = Command::new(cargo_bin("substrate")); - let mut child = common::KillChildOnDrop( - cmd.args(&["--dev", "--tmp", "--no-hardware-benchmarks"]) - .stdout(Stdio::piped()) - .stderr(Stdio::piped()) - .spawn() - .unwrap(), - ); - - let mut stderr = child.stderr.take().unwrap(); - let (ws_url, mut data) = common::find_ws_url_from_output(&mut stderr); - - // Let it produce some blocks. - common::wait_n_finalized_blocks(3, 30, &ws_url).await.unwrap(); - assert!(child.try_wait().unwrap().is_none(), "the process should still be running"); - - // Stop the process - kill(Pid::from_raw(child.id().try_into().unwrap()), SIGINT).unwrap(); - assert!(common::wait_for(&mut child, 40).map(|x| x.success()).unwrap_or_default()); - - // Ensure the database has been deleted - stderr.read_to_string(&mut data).unwrap(); - let re = Regex::new(r"Database: .+ at (\S+)").unwrap(); - let db_path = PathBuf::from(re.captures(data.as_str()).unwrap().get(1).unwrap().as_str()); - - assert!(!db_path.exists()); + common::run_with_timeout(Duration::from_secs(60 * 10), async move { + let mut cmd = Command::new(cargo_bin("substrate")); + let mut child = common::KillChildOnDrop( + cmd.args(&["--dev", "--tmp", "--no-hardware-benchmarks"]) + .stdout(Stdio::piped()) + .stderr(Stdio::piped()) + .spawn() + .unwrap(), + ); + + let mut stderr = child.stderr.take().unwrap(); + let node_info = common::extract_info_from_output(&mut stderr).0; + + // Let it produce some blocks. + common::wait_n_finalized_blocks(3, &node_info.ws_url).await; + + // Ensure the db path exists while the node is running + assert!(node_info.db_path.exists()); + + child.assert_still_running(); + + // Stop the process + child.stop(); + + if node_info.db_path.exists() { + panic!("Database path `{}` wasn't deleted!", node_info.db_path.display()); + } + }) + .await; } diff --git a/substrate/client/cli/src/lib.rs b/substrate/client/cli/src/lib.rs index adbb299317ac9c1182435cba5e61d401d7629956..e73321ecce5b380a8f9d73e1b822554ff222cec0 100644 --- a/substrate/client/cli/src/lib.rs +++ b/substrate/client/cli/src/lib.rs @@ -197,10 +197,15 @@ pub trait SubstrateCli: Sized { command: &T, ) -> error::Result<Runner<Self>> { let tokio_runtime = build_runtime()?; + + // `capture` needs to be called in a tokio context. + // Also capture them as early as possible. + let signals = tokio_runtime.block_on(async { Signals::capture() })?; + let config = command.create_configuration(self, tokio_runtime.handle().clone())?; command.init(&Self::support_url(), &Self::impl_version(), |_, _| {}, &config)?; - Runner::new(config, tokio_runtime) + Runner::new(config, tokio_runtime, signals) } /// Create a runner for the command provided in argument. The `logger_hook` can be used to setup @@ -231,10 +236,15 @@ pub trait SubstrateCli: Sized { F: FnOnce(&mut LoggerBuilder, &Configuration), { let tokio_runtime = build_runtime()?; + + // `capture` needs to be called in a tokio context. + // Also capture them as early as possible. + let signals = tokio_runtime.block_on(async { Signals::capture() })?; + let config = command.create_configuration(self, tokio_runtime.handle().clone())?; command.init(&Self::support_url(), &Self::impl_version(), logger_hook, &config)?; - Runner::new(config, tokio_runtime) + Runner::new(config, tokio_runtime, signals) } /// Native runtime version. fn native_runtime_version(chain_spec: &Box<dyn ChainSpec>) -> &'static RuntimeVersion; diff --git a/substrate/client/cli/src/runner.rs b/substrate/client/cli/src/runner.rs index 8917a37c499c06ee8fde40cde0267d28a6ad9c49..47adfcf89fe99685016a8697513d0ddf310062d1 100644 --- a/substrate/client/cli/src/runner.rs +++ b/substrate/client/cli/src/runner.rs @@ -18,54 +18,72 @@ use crate::{error::Error as CliError, Result, SubstrateCli}; use chrono::prelude::*; -use futures::{future, future::FutureExt, pin_mut, select, Future}; +use futures::{ + future::{self, BoxFuture, FutureExt}, + pin_mut, select, Future, +}; use log::info; use sc_service::{Configuration, Error as ServiceError, TaskManager}; use sc_utils::metrics::{TOKIO_THREADS_ALIVE, TOKIO_THREADS_TOTAL}; use std::{marker::PhantomData, time::Duration}; -#[cfg(target_family = "unix")] -async fn main<F, E>(func: F) -> std::result::Result<(), E> -where - F: Future<Output = std::result::Result<(), E>> + future::FusedFuture, - E: std::error::Error + Send + Sync + 'static + From<ServiceError>, -{ - use tokio::signal::unix::{signal, SignalKind}; +/// Abstraction over OS signals to handle the shutdown of the node smoothly. +/// +/// On `unix` this represents `SigInt` and `SigTerm`. +pub struct Signals(BoxFuture<'static, ()>); - let mut stream_int = signal(SignalKind::interrupt()).map_err(ServiceError::Io)?; - let mut stream_term = signal(SignalKind::terminate()).map_err(ServiceError::Io)?; - - let t1 = stream_int.recv().fuse(); - let t2 = stream_term.recv().fuse(); - let t3 = func; - - pin_mut!(t1, t2, t3); +impl Signals { + /// Capture the relevant signals to handle shutdown of the node smoothly. + /// + /// Needs to be called in a Tokio context to have access to the tokio reactor. + #[cfg(target_family = "unix")] + pub fn capture() -> std::result::Result<Self, ServiceError> { + use tokio::signal::unix::{signal, SignalKind}; + + let mut stream_int = signal(SignalKind::interrupt()).map_err(ServiceError::Io)?; + let mut stream_term = signal(SignalKind::terminate()).map_err(ServiceError::Io)?; + + Ok(Signals( + async move { + future::select(stream_int.recv().boxed(), stream_term.recv().boxed()).await; + } + .boxed(), + )) + } - select! { - _ = t1 => {}, - _ = t2 => {}, - res = t3 => res?, + /// Capture the relevant signals to handle shutdown of the node smoothly. + /// + /// Needs to be called in a Tokio context to have access to the tokio reactor. + #[cfg(not(unix))] + pub fn capture() -> std::result::Result<Self, ServiceError> { + use tokio::signal::ctrl_c; + + Ok(Signals( + async move { + let _ = ctrl_c().await; + } + .boxed(), + )) } - Ok(()) + /// A dummy signal that never returns. + pub fn dummy() -> Self { + Self(future::pending().boxed()) + } } -#[cfg(not(unix))] -async fn main<F, E>(func: F) -> std::result::Result<(), E> +async fn main<F, E>(func: F, signals: impl Future<Output = ()>) -> std::result::Result<(), E> where F: Future<Output = std::result::Result<(), E>> + future::FusedFuture, - E: std::error::Error + Send + Sync + 'static + From<ServiceError>, + E: std::error::Error + Send + Sync + 'static, { - use tokio::signal::ctrl_c; - - let t1 = ctrl_c().fuse(); - let t2 = func; + let signals = signals.fuse(); - pin_mut!(t1, t2); + pin_mut!(func, signals); select! { - _ = t1 => {}, - res = t2 => res?, + _ = signals => {}, + res = func => res?, } Ok(()) @@ -89,6 +107,7 @@ fn run_until_exit<F, E>( tokio_runtime: tokio::runtime::Runtime, future: F, task_manager: TaskManager, + signals: impl Future<Output = ()>, ) -> std::result::Result<(), E> where F: Future<Output = std::result::Result<(), E>> + future::Future, @@ -97,7 +116,7 @@ where let f = future.fuse(); pin_mut!(f); - tokio_runtime.block_on(main(f))?; + tokio_runtime.block_on(main(f, signals))?; drop(task_manager); Ok(()) @@ -107,13 +126,18 @@ where pub struct Runner<C: SubstrateCli> { config: Configuration, tokio_runtime: tokio::runtime::Runtime, + signals: Signals, phantom: PhantomData<C>, } impl<C: SubstrateCli> Runner<C> { /// Create a new runtime with the command provided in argument - pub fn new(config: Configuration, tokio_runtime: tokio::runtime::Runtime) -> Result<Runner<C>> { - Ok(Runner { config, tokio_runtime, phantom: PhantomData }) + pub fn new( + config: Configuration, + tokio_runtime: tokio::runtime::Runtime, + signals: Signals, + ) -> Result<Runner<C>> { + Ok(Runner { config, tokio_runtime, signals, phantom: PhantomData }) } /// Log information about the node itself. @@ -147,7 +171,7 @@ impl<C: SubstrateCli> Runner<C> { self.print_node_infos(); let mut task_manager = self.tokio_runtime.block_on(initialize(self.config))?; - let res = self.tokio_runtime.block_on(main(task_manager.future().fuse())); + let res = self.tokio_runtime.block_on(main(task_manager.future().fuse(), self.signals.0)); // We need to drop the task manager here to inform all tasks that they should shut down. // // This is important to be done before we instruct the tokio runtime to shutdown. Otherwise @@ -210,7 +234,7 @@ impl<C: SubstrateCli> Runner<C> { E: std::error::Error + Send + Sync + 'static + From<ServiceError> + From<CliError>, { let (future, task_manager) = runner(self.config)?; - run_until_exit::<_, E>(self.tokio_runtime, future, task_manager) + run_until_exit::<_, E>(self.tokio_runtime, future, task_manager, self.signals.0) } /// Get an immutable reference to the node Configuration @@ -369,6 +393,7 @@ mod tests { runtime_cache_size: 2, }, runtime, + Signals::dummy(), ) .unwrap();