From 52472238bda687401bdaff219baab5fb0acd8734 Mon Sep 17 00:00:00 2001
From: Pierre Krieger <pierre.krieger1708@gmail.com>
Date: Thu, 2 May 2019 21:03:04 +0200
Subject: [PATCH] Drop connections when the handler gets disabled (#2439)

* Drop connections when the handler gets disabled

* Add test
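
The change in a nutshell: when the handler gets disabled while a connection is
opening or fully working, it no longer tries to politely shut the substreams
down. Instead it switches to a new `ProtocolState::KillAsap` state and returns
a `ConnectionKillError` from `poll()`, which tears the whole connection down
and forces both sides to reset their state.

A stripped-down sketch of that flow. `ProtocolState::KillAsap` and
`ConnectionKillError` are the names used in this patch; everything else (the
reduced state enum, the free-standing `disable()` and `poll()` functions) is an
illustrative stand-in, not the handler's real API:

	use std::{error, fmt};

	// Simplified stand-in for the handler's state machine.
	enum ProtocolState {
		Opening,
		Normal,
		Disabled,
		KillAsap,
	}

	// Error used to tear the connection down once the handler is disabled.
	#[derive(Debug)]
	struct ConnectionKillError;

	impl fmt::Display for ConnectionKillError {
		fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
			write!(f, "Connection killed when switching from normal to disabled")
		}
	}

	impl error::Error for ConnectionKillError {}

	// Disabling while opening or working kills the connection outright;
	// other states are left untouched.
	fn disable(state: ProtocolState) -> ProtocolState {
		match state {
			ProtocolState::Opening | ProtocolState::Normal => ProtocolState::KillAsap,
			other => other,
		}
	}

	// Returning an error from poll() is what actually closes the connection.
	fn poll(state: &ProtocolState) -> Result<(), ConnectionKillError> {
		match state {
			ProtocolState::KillAsap => Err(ConnectionKillError),
			_ => Ok(()),
		}
	}

	fn main() {
		// Already-disabled handlers stay disabled; opening or working ones get killed.
		assert!(poll(&disable(ProtocolState::Disabled)).is_ok());
		assert!(poll(&disable(ProtocolState::Opening)).is_err());
		assert!(poll(&disable(ProtocolState::Normal)).is_err());
	}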
---
 .../src/custom_proto/handler.rs               |  60 ++++++----
 substrate/core/network-libp2p/tests/test.rs   | 106 +++++++++++++++++-
 2 files changed, 144 insertions(+), 22 deletions(-)

diff --git a/substrate/core/network-libp2p/src/custom_proto/handler.rs b/substrate/core/network-libp2p/src/custom_proto/handler.rs
index 25eac86ffad..2cfc1107a96 100644
--- a/substrate/core/network-libp2p/src/custom_proto/handler.rs
+++ b/substrate/core/network-libp2p/src/custom_proto/handler.rs
@@ -30,7 +30,6 @@ use smallvec::{smallvec, SmallVec};
 use std::{borrow::Cow, error, fmt, io, marker::PhantomData, mem, time::Duration};
 use tokio_io::{AsyncRead, AsyncWrite};
 use tokio_timer::{Delay, clock::Clock};
-use void::Void;
 
 /// Implements the `IntoProtocolsHandler` trait of libp2p.
 ///
@@ -193,6 +192,10 @@ enum ProtocolState<TMessage, TSubstream> {
 		reenable: bool,
 	},
 
+	/// In this state, we don't care about anything anymore and need to kill the connection as soon
+	/// as possible.
+	KillAsap,
+
 	/// We sometimes temporarily switch to this state during processing. If we are in this state
 	/// at the beginning of a method, that means something bad happened in the source code.
 	Poisoned,
@@ -290,6 +293,7 @@ where
 				}
 			}
 
+			st @ ProtocolState::KillAsap => st,
 			st @ ProtocolState::Opening { .. } => st,
 			st @ ProtocolState::Normal { .. } => st,
 			ProtocolState::Disabled { shutdown, .. } => {
@@ -314,27 +318,18 @@ where
 				ProtocolState::Disabled { shutdown, reenable: false }
 			}
 
-			ProtocolState::Opening { .. } => {
-				ProtocolState::Disabled { shutdown: SmallVec::new(), reenable: false }
-			}
-
-			ProtocolState::Normal { substreams, mut shutdown } => {
-				for mut substream in substreams {
-					substream.shutdown();
-					shutdown.push(substream);
-				}
-				let event = CustomProtoHandlerOut::CustomProtocolClosed {
-					reason: "Disabled on purpose on our side".into()
-				};
-				self.events_queue.push(ProtocolsHandlerEvent::Custom(event));
-				ProtocolState::Disabled {
-					shutdown: shutdown.into_iter().collect(),
-					reenable: false
-				}
-			}
+			ProtocolState::Opening { .. } | ProtocolState::Normal { .. } =>
+				// At the moment, if we get disabled while things were working, we kill the entire
+				// connection in order to force a reset of the state.
+				// This is obviously an extremely shameful way to do things, but at the time of
+				// writing this comment, the networking works very poorly and a solution
+				// needs to be found.
+				ProtocolState::KillAsap,
 
 			ProtocolState::Disabled { shutdown, .. } =>
 				ProtocolState::Disabled { shutdown, reenable: false },
+
+			ProtocolState::KillAsap => ProtocolState::KillAsap,
 		};
 	}
 
@@ -462,6 +457,8 @@ where
 					None
 				}
 			}
+
+			ProtocolState::KillAsap => None,
 		}
 	}
 
@@ -507,6 +504,8 @@ where
 				shutdown.push(substream);
 				ProtocolState::Disabled { shutdown, reenable: false }
 			}
+
+			ProtocolState::KillAsap => ProtocolState::KillAsap,
 		};
 	}
 
@@ -527,7 +526,7 @@ where TSubstream: AsyncRead + AsyncWrite, TMessage: CustomMessage {
 	type InEvent = CustomProtoHandlerIn<TMessage>;
 	type OutEvent = CustomProtoHandlerOut<TMessage>;
 	type Substream = TSubstream;
-	type Error = Void;
+	type Error = ConnectionKillError;
 	type InboundProtocol = RegisteredProtocol<TMessage>;
 	type OutboundProtocol = RegisteredProtocol<TMessage>;
 	type OutboundOpenInfo = ();
@@ -577,7 +576,8 @@ where TSubstream: AsyncRead + AsyncWrite, TMessage: CustomMessage {
 		match self.state {
 			ProtocolState::Init { .. } | ProtocolState::Opening { .. } |
 			ProtocolState::Normal { .. } => KeepAlive::Yes,
-			ProtocolState::Disabled { .. } | ProtocolState::Poisoned => KeepAlive::No,
+			ProtocolState::Disabled { .. } | ProtocolState::Poisoned |
+			ProtocolState::KillAsap => KeepAlive::No,
 		}
 	}
 
@@ -593,6 +593,11 @@ where TSubstream: AsyncRead + AsyncWrite, TMessage: CustomMessage {
 			return Ok(Async::Ready(event))
 		}
 
+		// Kill the connection if needed.
+		if let ProtocolState::KillAsap = self.state {
+			return Err(ConnectionKillError);
+		}
+
 		// Process all the substreams.
 		if let Some(event) = self.poll_state() {
 			return Ok(Async::Ready(event))
@@ -629,3 +634,16 @@ where TSubstream: AsyncRead + AsyncWrite, TMessage: CustomMessage {
 		list.push(substream);
 	}
 }
+
+/// Error returned when the connection is killed after the handler has been disabled.
+#[derive(Debug)]
+pub struct ConnectionKillError;
+
+impl error::Error for ConnectionKillError {
+}
+
+impl fmt::Display for ConnectionKillError {
+	fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+		write!(f, "Connection killed when switching from normal to disabled")
+	}
+}
diff --git a/substrate/core/network-libp2p/tests/test.rs b/substrate/core/network-libp2p/tests/test.rs
index ff4d5824e02..04206e66d84 100644
--- a/substrate/core/network-libp2p/tests/test.rs
+++ b/substrate/core/network-libp2p/tests/test.rs
@@ -16,7 +16,7 @@
 
 use futures::{future, stream, prelude::*, try_ready};
 use rand::seq::SliceRandom;
-use std::io;
+use std::{io, time::{Duration, Instant}};
 use substrate_network_libp2p::{CustomMessage, Multiaddr, multiaddr::Protocol, ServiceEvent, build_multiaddr};
 
 /// Builds two services. The second one and further have the first one as its bootstrap node.
@@ -253,3 +253,107 @@ fn basic_two_nodes_requests_in_parallel() {
 	let combined = fut1.select(fut2).map_err(|(err, _)| err);
 	tokio::runtime::Runtime::new().unwrap().block_on_all(combined).unwrap();
 }
+
+#[test]
+fn reconnect_after_disconnect() {
+	// We connect two nodes together, then force a disconnect (through the API of the `Service`),
+	// check that the disconnect worked, and finally check whether they successfully reconnect.
+
+	let (mut service1, mut service2) = {
+		let mut l = build_nodes::<Vec<u8>>(2, 50350).into_iter();
+		let a = l.next().unwrap();
+		let b = l.next().unwrap();
+		(a, b)
+	};
+
+	// We use the `current_thread` runtime because it doesn't require us to have `'static` futures.
+	let mut runtime = tokio::runtime::current_thread::Runtime::new().unwrap();
+
+	// For this test, the services can be in the following states.
+	#[derive(Debug, Copy, Clone, PartialEq, Eq)]
+	enum ServiceState { NotConnected, FirstConnec, Disconnected, ConnectedAgain }
+	let mut service1_state = ServiceState::NotConnected;
+	let mut service2_state = ServiceState::NotConnected;
+
+	// Run the event loops.
+	runtime.block_on(future::poll_fn(|| -> Result<_, io::Error> {
+		loop {
+			let mut service1_not_ready = false;
+
+			match service1.poll().unwrap() {
+				Async::Ready(Some(ServiceEvent::OpenedCustomProtocol { .. })) => {
+					match service1_state {
+						ServiceState::NotConnected => {
+							service1_state = ServiceState::FirstConnec;
+							if service2_state == ServiceState::FirstConnec {
+								service1.drop_node(service2.peer_id());
+							}
+						},
+						ServiceState::Disconnected => service1_state = ServiceState::ConnectedAgain,
+						ServiceState::FirstConnec | ServiceState::ConnectedAgain => panic!(),
+					}
+				},
+				Async::Ready(Some(ServiceEvent::ClosedCustomProtocol { .. })) => {
+					match service1_state {
+						ServiceState::FirstConnec => service1_state = ServiceState::Disconnected,
+						ServiceState::ConnectedAgain | ServiceState::NotConnected |
+						ServiceState::Disconnected => panic!(),
+					}
+				},
+				Async::NotReady => service1_not_ready = true,
+				_ => panic!()
+			}
+
+			match service2.poll().unwrap() {
+				Async::Ready(Some(ServiceEvent::OpenedCustomProtocol { .. })) => {
+					match service2_state {
+						ServiceState::NotConnected => {
+							service2_state = ServiceState::FirstConnec;
+							if service1_state == ServiceState::FirstConnec {
+								service1.drop_node(service2.peer_id());
+							}
+						},
+						ServiceState::Disconnected => service2_state = ServiceState::ConnectedAgain,
+						ServiceState::FirstConnec | ServiceState::ConnectedAgain => panic!(),
+					}
+				},
+				Async::Ready(Some(ServiceEvent::ClosedCustomProtocol { .. })) => {
+					match service2_state {
+						ServiceState::FirstConnec => service2_state = ServiceState::Disconnected,
+						ServiceState::ConnectedAgain | ServiceState::NotConnected |
+						ServiceState::Disconnected => panic!(),
+					}
+				},
+				Async::NotReady if service1_not_ready => break,
+				Async::NotReady => {}
+				_ => panic!()
+			}
+		}
+
+		if service1_state == ServiceState::ConnectedAgain && service2_state == ServiceState::ConnectedAgain {
+			Ok(Async::Ready(()))
+		} else {
+			Ok(Async::NotReady)
+		}
+	})).unwrap();
+
+	// Do a second 3-second run to make sure we don't get disconnected immediately again.
+	let mut delay = tokio::timer::Delay::new(Instant::now() + Duration::from_secs(3));
+	runtime.block_on(future::poll_fn(|| -> Result<_, io::Error> {
+		match service1.poll().unwrap() {
+			Async::NotReady => {},
+			_ => panic!()
+		}
+
+		match service2.poll().unwrap() {
+			Async::NotReady => {},
+			_ => panic!()
+		}
+
+		if let Async::Ready(()) = delay.poll().unwrap() {
+			Ok(Async::Ready(()))
+		} else {
+			Ok(Async::NotReady)
+		}
+	})).unwrap();
+}
-- 
GitLab