From d3244b728aa8109c76bad589d80d6248854f6a17 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Bastian=20K=C3=B6cher?= <bkchr@users.noreply.github.com>
Date: Mon, 2 Mar 2020 18:20:04 +0100
Subject: [PATCH] Make sure we remove a peer on disconnect in gossip (#5104)

* Make sure we remove peers on disconnect in gossip state machine

* Clear up the code

* Add a comment
---
 .../network-gossip/src/state_machine.rs       | 49 +++++++++++++++++++
 substrate/client/network/src/protocol/sync.rs |  3 +-
 .../network/src/protocol/sync/blocks.rs       | 27 +++++-----
 3 files changed, 63 insertions(+), 16 deletions(-)

diff --git a/substrate/client/network-gossip/src/state_machine.rs b/substrate/client/network-gossip/src/state_machine.rs
index 26433e63ec3..db5ea3603dc 100644
--- a/substrate/client/network-gossip/src/state_machine.rs
+++ b/substrate/client/network-gossip/src/state_machine.rs
@@ -258,6 +258,7 @@ impl<B: BlockT> ConsensusGossip<B> {
 			let mut context = NetworkContext { gossip: self, network, engine_id: engine_id.clone() };
 			v.peer_disconnected(&mut context, &who);
 		}
+		self.peers.remove(&who);
 	}
 
 	/// Perform periodic maintenance
@@ -644,4 +645,52 @@ mod tests {
 		let _ = consensus.live_message_sinks.remove(&([0, 0, 0, 0], topic));
 		assert_eq!(stream.next(), None);
 	}
+
+	#[test]
+	fn peer_is_removed_on_disconnect() {
+		struct TestNetwork;
+		impl Network<Block> for TestNetwork {
+			fn event_stream(
+				&self,
+			) -> std::pin::Pin<Box<dyn futures::Stream<Item = crate::Event> + Send>> {
+				unimplemented!("Not required in tests")
+			}
+
+			fn report_peer(&self, _: PeerId, _: crate::ReputationChange) {
+				unimplemented!("Not required in tests")
+			}
+
+			fn disconnect_peer(&self, _: PeerId) {
+				unimplemented!("Not required in tests")
+			}
+
+			fn write_notification(&self, _: PeerId, _: crate::ConsensusEngineId, _: Vec<u8>) {
+				unimplemented!("Not required in tests")
+			}
+
+			fn register_notifications_protocol(
+				&self,
+				_: ConsensusEngineId,
+				_: std::borrow::Cow<'static, [u8]>,
+			) {
+				unimplemented!("Not required in tests")
+			}
+
+			fn announce(&self, _: H256, _: Vec<u8>) {
+				unimplemented!("Not required in tests")
+			}
+		}
+
+		let mut consensus = ConsensusGossip::<Block>::new();
+		consensus.register_validator_internal([0, 0, 0, 0], Arc::new(AllowAll));
+
+		let mut network = TestNetwork;
+
+		let peer_id = PeerId::random();
+		consensus.new_peer(&mut network, peer_id.clone(), Roles::FULL);
+		assert!(consensus.peers.contains_key(&peer_id));
+
+		consensus.peer_disconnected(&mut network, peer_id.clone());
+		assert!(!consensus.peers.contains_key(&peer_id));
+	}
 }
diff --git a/substrate/client/network/src/protocol/sync.rs b/substrate/client/network/src/protocol/sync.rs
index b1cd89155ef..d0427e61a81 100644
--- a/substrate/client/network/src/protocol/sync.rs
+++ b/substrate/client/network/src/protocol/sync.rs
@@ -1167,8 +1167,7 @@ impl<B: BlockT> ChainSync<B> {
 	}
 
 	/// Restart the sync process.
-	fn restart<'a>(&'a mut self) -> impl Iterator<Item = Result<(PeerId, BlockRequest<B>), BadPeer>> + 'a
-	{
+	fn restart<'a>(&'a mut self) -> impl Iterator<Item = Result<(PeerId, BlockRequest<B>), BadPeer>> + 'a {
 		self.queue_blocks.clear();
 		self.blocks.clear();
 		let info = self.client.info();
diff --git a/substrate/client/network/src/protocol/sync/blocks.rs b/substrate/client/network/src/protocol/sync/blocks.rs
index d4e4581c670..279150a2255 100644
--- a/substrate/client/network/src/protocol/sync/blocks.rs
+++ b/substrate/client/network/src/protocol/sync/blocks.rs
@@ -104,8 +104,7 @@ impl<B: BlockT> BlockCollection<B> {
 		common: NumberFor<B>,
 		max_parallel: u32,
 		max_ahead: u32,
-	) -> Option<Range<NumberFor<B>>>
-	{
+	) -> Option<Range<NumberFor<B>>> {
 		if peer_best <= common {
 			// Bail out early
 			return None;
@@ -165,20 +164,20 @@ impl<B: BlockT> BlockCollection<B> {
 	pub fn drain(&mut self, from: NumberFor<B>) -> Vec<BlockData<B>> {
 		let mut drained = Vec::new();
 		let mut ranges = Vec::new();
-		{
-			let mut prev = from;
-			for (start, range_data) in &mut self.blocks {
-				match range_data {
-					&mut BlockRangeState::Complete(ref mut blocks) if *start <= prev => {
-							prev = *start + (blocks.len() as u32).into();
-							let mut blocks = mem::replace(blocks, Vec::new());
-							drained.append(&mut blocks);
-							ranges.push(*start);
-					},
-					_ => break,
-				}
+
+		let mut prev = from;
+		for (start, range_data) in &mut self.blocks {
+			match range_data {
+				&mut BlockRangeState::Complete(ref mut blocks) if *start <= prev => {
+					prev = *start + (blocks.len() as u32).into();
+					// Remove all elements from `blocks` and add them to `drained`
+					drained.append(blocks);
+					ranges.push(*start);
+				},
+				_ => break,
 			}
 		}
+
 		for r in ranges {
 			self.blocks.remove(&r);
 		}
-- 
GitLab