From 7afe69066f0c64ad754d44495d4183a7f903c56a Mon Sep 17 00:00:00 2001
From: Pierre Krieger <pierre.krieger1708@gmail.com>
Date: Thu, 7 Feb 2019 17:16:41 +0100
Subject: [PATCH] Fix panic in custom protocol handler (#1723)

---
 .../src/custom_proto/behaviour.rs             |  6 ++
 .../src/custom_proto/handler.rs               | 62 ++++++++++++++-----
 2 files changed, 54 insertions(+), 14 deletions(-)

diff --git a/substrate/core/network-libp2p/src/custom_proto/behaviour.rs b/substrate/core/network-libp2p/src/custom_proto/behaviour.rs
index b7575524c1e..66af5e048c5 100644
--- a/substrate/core/network-libp2p/src/custom_proto/behaviour.rs
+++ b/substrate/core/network-libp2p/src/custom_proto/behaviour.rs
@@ -542,6 +542,9 @@ where
 				}
 			}
 			CustomProtosHandlerOut::CustomMessage { protocol_id, data } => {
+				debug_assert!(self.open_protocols.iter().any(|(s, p)|
+					s == &source && p == &protocol_id
+				));
 				let event = CustomProtosOut::CustomMessage {
 					peer_id: source,
 					protocol_id,
@@ -551,6 +554,9 @@ where
 				self.events.push(NetworkBehaviourAction::GenerateEvent(event));
 			}
 			CustomProtosHandlerOut::Clogged { protocol_id, messages } => {
+				debug_assert!(self.open_protocols.iter().any(|(s, p)|
+					s == &source && p == &protocol_id
+				));
 				warn!(target: "sub-libp2p", "Queue of packets to send to {:?} (protocol: {:?}) is \
 					pretty large", source, protocol_id);
 				self.events.push(NetworkBehaviourAction::GenerateEvent(CustomProtosOut::Clogged {
diff --git a/substrate/core/network-libp2p/src/custom_proto/handler.rs b/substrate/core/network-libp2p/src/custom_proto/handler.rs
index 198e51d7d06..51ff2408ad8 100644
--- a/substrate/core/network-libp2p/src/custom_proto/handler.rs
+++ b/substrate/core/network-libp2p/src/custom_proto/handler.rs
@@ -41,7 +41,7 @@ pub struct CustomProtosHandler<TSubstream> {
 	protocols: RegisteredProtocols,
 
 	/// See the documentation of `State`.
-	state: State,
+	state: State<TSubstream>,
 
 	/// Value to be returned by `connection_keep_alive()`.
 	keep_alive: KeepAlive,
@@ -54,8 +54,12 @@ pub struct CustomProtosHandler<TSubstream> {
 }
 
 /// State of the handler.
-#[derive(Debug, Copy, Clone, PartialEq, Eq)]
-enum State {
+enum State<TSubstream> {
+	/// Waiting for the behaviour to tell the handler whether it is enabled or disabled.
+	/// Contains a list of substreams opened by the remote and that we will integrate to
+	/// `substreams` only if we get enabled.
+	Init(SmallVec<[RegisteredProtocolSubstream<TSubstream>; 6]>),
+
 	/// Normal functionning.
 	Normal,
 
@@ -136,7 +140,7 @@ where
 			protocols,
 			// We keep the connection alive for at least 5 seconds, waiting for what happens.
 			keep_alive: KeepAlive::Until(Instant::now() + Duration::from_secs(5)),
-			state: State::Normal,
+			state: State::Init(SmallVec::new()),
 			substreams: SmallVec::new(),
 			events_queue: SmallVec::new(),
 		}
@@ -147,23 +151,31 @@ where
 		&mut self,
 		proto: RegisteredProtocolSubstream<TSubstream>,
 	) {
+		if self.substreams.iter().any(|p| p.protocol_id() == proto.protocol_id()) {
+			// Skipping protocol that's already open.
+			return
+		}
+
 		match self.state {
+			State::Init(ref mut pending) => {
+				if pending.iter().all(|p| p.protocol_id() != proto.protocol_id()) {
+					pending.push(proto);
+				}
+				return
+			},
 			// TODO: we should shut down refused substreams gracefully; this should be fixed
 			// at the same time as https://github.com/paritytech/substrate/issues/1517
 			State::Disabled | State::ShuttingDown => return,
 			State::Normal => ()
 		}
 
-		if self.substreams.iter().any(|p| p.protocol_id() == proto.protocol_id()) {
-			// Skipping protocol that's already open.
-			return
-		}
-
 		let event = CustomProtosHandlerOut::CustomProtocolOpen {
 			protocol_id: proto.protocol_id(),
 			version: proto.protocol_version(),
 		};
 
+		self.keep_alive = KeepAlive::Forever;
+
 		self.substreams.push(proto);
 		self.events_queue.push(ProtocolsHandlerEvent::Custom(event));
 	}
@@ -206,7 +218,7 @@ where
 		match message {
 			CustomProtosHandlerIn::Disable => {
 				match self.state {
-					State::Normal => self.state = State::Disabled,
+					State::Init(_) | State::Normal => self.state = State::Disabled,
 					State::Disabled | State::ShuttingDown => (),
 				}
 
@@ -217,6 +229,19 @@ where
 			},
 			CustomProtosHandlerIn::EnableActive | CustomProtosHandlerIn::EnablePassive => {
 				match self.state {
+					State::Init(ref mut list) => {
+						for proto in list.drain() {
+							let event = CustomProtosHandlerOut::CustomProtocolOpen {
+								protocol_id: proto.protocol_id(),
+								version: proto.protocol_version(),
+							};
+
+							self.substreams.push(proto);
+							self.events_queue.push(ProtocolsHandlerEvent::Custom(event));
+						}
+
+						self.state = State::Normal;
+					}
 					State::Disabled => self.state = State::Normal,
 					State::Normal | State::ShuttingDown => (),
 				}
@@ -277,7 +302,7 @@ where
 
 	fn shutdown(&mut self) {
 		match self.state {
-			State::Normal | State::Disabled => self.state = State::ShuttingDown,
+			State::Init(_) | State::Normal | State::Disabled => self.state = State::ShuttingDown,
 			State::ShuttingDown => (),
 		}
 
@@ -297,8 +322,10 @@ where
 			return Ok(Async::Ready(event))
 		}
 
-		if self.state == State::ShuttingDown && self.substreams.is_empty() {
-			return Ok(Async::Ready(ProtocolsHandlerEvent::Shutdown))
+		if let State::ShuttingDown = self.state {
+			if self.substreams.is_empty() {
+				return Ok(Async::Ready(ProtocolsHandlerEvent::Shutdown))
+			}
 		}
 
 		for n in (0..self.substreams.len()).rev() {
@@ -323,6 +350,10 @@ where
 				Ok(Async::NotReady) =>
 					self.substreams.push(substream),
 				Ok(Async::Ready(None)) => {
+					// Close the connection as soon as possible.
+					if self.substreams.is_empty() {
+						self.keep_alive = KeepAlive::Now;
+					}
 					let event = CustomProtosHandlerOut::CustomProtocolClosed {
 						protocol_id: substream.protocol_id(),
 						result: Ok(())
@@ -330,6 +361,10 @@ where
 					return Ok(Async::Ready(ProtocolsHandlerEvent::Custom(event)))
 				},
 				Err(err) => {
+					// Close the connection as soon as possible.
+					if self.substreams.is_empty() {
+						self.keep_alive = KeepAlive::Now;
+					}
 					let event = CustomProtosHandlerOut::CustomProtocolClosed {
 						protocol_id: substream.protocol_id(),
 						result: Err(err)
@@ -350,7 +385,6 @@ where
 	fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> {
 		f.debug_struct("CustomProtosHandler")
 			.field("protocols", &self.protocols.len())
-			.field("state", &self.state)
 			.field("substreams", &self.substreams.len())
 			.finish()
 	}
-- 
GitLab