From 6981a1c366ca701505f0628c208f0fc942f5fe21 Mon Sep 17 00:00:00 2001
From: Andronik Ordian <write@reusable.software>
Date: Mon, 8 Feb 2021 08:57:59 +0100
Subject: [PATCH] validator_discovery: pass PeerSet to the request (#2372)

* validator_discovery: pass PeerSet to the request

* validator_discovery: track PeerSet of connected peers

* validator_discovery: fix tests

* validator_discovery: fix long line

* some fixes

* some validator_discovery logs

* log validator discovery request

* Also connect to validators on `DistributePoV`.

* validator_discovery: store the whole state per peer_set

* bump spec versions in kusama, polkadot and westend

* Correcting doc.

* validator_discovery: bump channel capacity

* pov-distribution: some cleanup

* this should fix the test, but it does not

* I just got some brain damage while fixing this

Why are you even reading this???

* wrap long line

* address some review nits

Co-authored-by: Robert Klotzner <robert.klotzner@gmx.at>
---
 polkadot/Cargo.lock                           |   1 +
 .../network/availability-recovery/src/lib.rs  |   3 +-
 polkadot/node/network/bridge/src/action.rs    |   3 +
 polkadot/node/network/bridge/src/lib.rs       |  21 ++-
 .../network/bridge/src/validator_discovery.rs | 146 +++++++++++-------
 .../collator-protocol/src/collator_side.rs    |   5 +-
 .../node/network/pov-distribution/src/lib.rs  |  82 ++++++----
 .../network/pov-distribution/src/tests.rs     |  41 +++--
 .../node/network/protocol/src/peer_set.rs     |   7 +-
 polkadot/node/subsystem-util/Cargo.toml       |   1 +
 .../subsystem-util/src/validator_discovery.rs |  29 +++-
 polkadot/node/subsystem/src/messages.rs       |   4 +-
 .../src/types/overseer-protocol.md            |   2 +
 13 files changed, 227 insertions(+), 118 deletions(-)

diff --git a/polkadot/Cargo.lock b/polkadot/Cargo.lock
index 3dbd02ccbcb..3103a5f322a 100644
--- a/polkadot/Cargo.lock
+++ b/polkadot/Cargo.lock
@@ -5521,6 +5521,7 @@ dependencies = [
  "parking_lot 0.11.1",
  "pin-project 1.0.4",
  "polkadot-node-jaeger",
+ "polkadot-node-network-protocol",
  "polkadot-node-primitives",
  "polkadot-node-subsystem",
  "polkadot-node-subsystem-test-helpers",
diff --git a/polkadot/node/network/availability-recovery/src/lib.rs b/polkadot/node/network/availability-recovery/src/lib.rs
index 6b28c7295f0..e4e7e13c1eb 100644
--- a/polkadot/node/network/availability-recovery/src/lib.rs
+++ b/polkadot/node/network/availability-recovery/src/lib.rs
@@ -43,7 +43,7 @@ use polkadot_subsystem::{
 	},
 };
 use polkadot_node_network_protocol::{
-	v1 as protocol_v1, PeerId, ReputationChange as Rep, RequestId,
+	peer_set::PeerSet, v1 as protocol_v1, PeerId, ReputationChange as Rep, RequestId,
 };
 use polkadot_node_subsystem_util::{
 	Timeout, TimeoutExt,
@@ -579,6 +579,7 @@ async fn handle_from_interaction(
 
 			let message = NetworkBridgeMessage::ConnectToValidators {
 				validator_ids: vec![id.clone()],
+				peer_set: PeerSet::Validation,
 				connected: tx,
 			};
 
diff --git a/polkadot/node/network/bridge/src/action.rs b/polkadot/node/network/bridge/src/action.rs
index d298a17543c..12846c566c3 100644
--- a/polkadot/node/network/bridge/src/action.rs
+++ b/polkadot/node/network/bridge/src/action.rs
@@ -50,6 +50,7 @@ pub(crate) enum Action {
 	/// Ask network to connect to validators.
 	ConnectToValidators {
 		validator_ids: Vec<AuthorityDiscoveryId>,
+		peer_set: PeerSet,
 		connected: mpsc::Sender<(AuthorityDiscoveryId, PeerId)>,
 	},
 
@@ -133,9 +134,11 @@ impl From<polkadot_subsystem::SubsystemResult<FromOverseer<NetworkBridgeMessage>
 				}
 				NetworkBridgeMessage::ConnectToValidators {
 					validator_ids,
+					peer_set,
 					connected,
 				} => Action::ConnectToValidators {
 					validator_ids,
+					peer_set,
 					connected,
 				},
 			},
diff --git a/polkadot/node/network/bridge/src/lib.rs b/polkadot/node/network/bridge/src/lib.rs
index 41fec833209..5b869f6392d 100644
--- a/polkadot/node/network/bridge/src/lib.rs
+++ b/polkadot/node/network/bridge/src/lib.rs
@@ -244,10 +244,18 @@ where
 
 			Action::ConnectToValidators {
 				validator_ids,
+				peer_set,
 				connected,
 			} => {
+				tracing::debug!(
+					target: LOG_TARGET,
+					peer_set = ?peer_set,
+					ids = ?validator_ids,
+					"Received a validator connection request",
+				);
 				let (ns, ads) = validator_discovery.on_request(
 					validator_ids,
+					peer_set,
 					connected,
 					bridge.network_service,
 					bridge.authority_discovery_service,
@@ -257,11 +265,6 @@ where
 			},
 
 			Action::ReportPeer(peer, rep) => {
-				tracing::debug!(
-					target: LOG_TARGET,
-					peer = ?peer,
-					"Peer sent us an invalid request",
-				);
 				bridge.network_service.report_peer(peer, rep).await?
 			}
 
@@ -296,7 +299,11 @@ where
 					PeerSet::Collation => &mut collation_peers,
 				};
 
-				validator_discovery.on_peer_connected(&peer, &mut bridge.authority_discovery_service).await;
+				validator_discovery.on_peer_connected(
+					peer.clone(),
+					peer_set,
+					&mut bridge.authority_discovery_service,
+				).await;
 
 				match peer_map.entry(peer.clone()) {
 					hash_map::Entry::Occupied(_) => continue,
@@ -358,7 +365,7 @@ where
 					PeerSet::Collation => &mut collation_peers,
 				};
 
-				validator_discovery.on_peer_disconnected(&peer);
+				validator_discovery.on_peer_disconnected(&peer, peer_set);
 
 				if peer_map.remove(&peer).is_some() {
 					match peer_set {
diff --git a/polkadot/node/network/bridge/src/validator_discovery.rs b/polkadot/node/network/bridge/src/validator_discovery.rs
index 926aa370664..762dc7d9eaa 100644
--- a/polkadot/node/network/bridge/src/validator_discovery.rs
+++ b/polkadot/node/network/bridge/src/validator_discovery.rs
@@ -23,6 +23,7 @@ use std::sync::Arc;
 
 use async_trait::async_trait;
 use futures::channel::mpsc;
+use strum::IntoEnumIterator as _;
 
 use sc_network::multiaddr::{Multiaddr, Protocol};
 use sc_authority_discovery::Service as AuthorityDiscoveryService;
@@ -93,7 +94,11 @@ impl NonRevokedConnectionRequestState {
 		}
 	}
 
-	pub fn on_authority_connected(&mut self, authority: &AuthorityDiscoveryId, peer_id: &PeerId) {
+	pub fn on_authority_connected(
+		&mut self,
+		authority: &AuthorityDiscoveryId,
+		peer_id: &PeerId,
+	) {
 		if self.pending.remove(authority) {
 			// an error may happen if the request was revoked or
 			// the channel's buffer is full, ignoring it is fine
@@ -118,7 +123,8 @@ impl NonRevokedConnectionRequestState {
 /// Returns `Some(id)` iff the request counter is `0`.
 fn on_revoke(map: &mut HashMap<AuthorityDiscoveryId, u64>, id: AuthorityDiscoveryId) -> Option<AuthorityDiscoveryId> {
 	if let hash_map::Entry::Occupied(mut entry) = map.entry(id) {
-		if entry.get_mut().saturating_sub(1) == 0 {
+		*entry.get_mut() = entry.get().saturating_sub(1);
+		if *entry.get() == 0 {
 			return Some(entry.remove_entry().0);
 		}
 	}
@@ -135,6 +141,14 @@ fn peer_id_from_multiaddr(addr: &Multiaddr) -> Option<PeerId> {
 }
 
 pub(super) struct Service<N, AD> {
+	// indexed by PeerSet as usize
+	state: Vec<StatePerPeerSet>,
+	// PhantomData used to make the struct generic instead of having generic methods
+	_phantom: PhantomData<(N, AD)>,
+}
+
+#[derive(Default)]
+struct StatePerPeerSet {
 	// Peers that are connected to us and authority ids associated to them.
 	connected_peers: HashMap<PeerId, HashSet<AuthorityDiscoveryId>>,
 	// The `u64` counts the number of pending non-revoked requests for this validator
@@ -143,20 +157,20 @@ pub(super) struct Service<N, AD> {
 	// Invariant: the value > 0 for non-revoked requests.
 	requested_validators: HashMap<AuthorityDiscoveryId, u64>,
 	non_revoked_discovery_requests: Vec<NonRevokedConnectionRequestState>,
-	// PhantomData used to make the struct generic instead of having generic methods
-	_phantom: PhantomData<(N, AD)>,
 }
 
 impl<N: Network, AD: AuthorityDiscovery> Service<N, AD> {
 	pub fn new() -> Self {
 		Self {
-			connected_peers: HashMap::new(),
-			requested_validators: HashMap::new(),
-			non_revoked_discovery_requests: Vec::new(),
+			state: PeerSet::iter().map(|_| Default::default()).collect(),
 			_phantom: PhantomData,
 		}
 	}
 
+	fn state_mut(&mut self, peer_set: PeerSet) -> &mut StatePerPeerSet {
+		&mut self.state[peer_set as usize]
+	}
+
 	/// Find connected validators using the given `validator_ids`.
 	///
 	/// Returns a [`HashMap`] that contains the found [`AuthorityDiscoveryId`]'s and their associated [`PeerId`]'s.
@@ -164,15 +178,24 @@ impl<N: Network, AD: AuthorityDiscovery> Service<N, AD> {
 	async fn find_connected_validators(
 		&mut self,
 		validator_ids: &[AuthorityDiscoveryId],
+		peer_set: PeerSet,
 		authority_discovery_service: &mut AD,
 	) -> HashMap<AuthorityDiscoveryId, PeerId> {
 		let mut result = HashMap::new();
+		let state = self.state_mut(peer_set);
 
 		for id in validator_ids {
 			// First check if we already cached the validator
-			if let Some(pid) = self.connected_peers
+			if let Some(pid) = state.connected_peers
 				.iter()
-				.find_map(|(pid, ids)| if ids.contains(&id) { Some(pid) } else { None }) {
+				.find_map(|(pid, ids)| {
+					if ids.contains(&id) {
+						Some(pid)
+					 } else {
+						None
+					}
+				})
+			{
 				result.insert(id.clone(), pid.clone());
 				continue;
 			}
@@ -180,9 +203,9 @@ impl<N: Network, AD: AuthorityDiscovery> Service<N, AD> {
 			// If not ask the authority discovery
 			if let Some(addresses) = authority_discovery_service.get_addresses_by_authority_id(id.clone()).await {
 				for peer_id in addresses.iter().filter_map(peer_id_from_multiaddr) {
-					if let Some(ids) = self.connected_peers.get_mut(&peer_id) {
+					if let Some(ids) = state.connected_peers.get_mut(&peer_id) {
 						ids.insert(id.clone());
-						result.insert(id.clone(), peer_id.clone());
+						result.insert(id.clone(), peer_id);
 					}
 				}
 			}
@@ -202,15 +225,22 @@ impl<N: Network, AD: AuthorityDiscovery> Service<N, AD> {
 	pub async fn on_request(
 		&mut self,
 		validator_ids: Vec<AuthorityDiscoveryId>,
+		peer_set: PeerSet,
 		mut connected: mpsc::Sender<(AuthorityDiscoveryId, PeerId)>,
 		mut network_service: N,
 		mut authority_discovery_service: AD,
 	) -> (N, AD) {
 		const MAX_ADDR_PER_PEER: usize = 3;
 
+		let already_connected = self.find_connected_validators(
+			&validator_ids,
+			peer_set,
+			&mut authority_discovery_service,
+		).await;
+
+		let state = self.state_mut(peer_set);
 		// Increment the counter of how many times the validators were requested.
-		validator_ids.iter().for_each(|id| *self.requested_validators.entry(id.clone()).or_default() += 1);
-		let already_connected = self.find_connected_validators(&validator_ids, &mut authority_discovery_service).await;
+		validator_ids.iter().for_each(|id| *state.requested_validators.entry(id.clone()).or_default() += 1);
 
 		// try to send already connected peers
 		for (id, peer) in already_connected.iter() {
@@ -218,7 +248,7 @@ impl<N: Network, AD: AuthorityDiscovery> Service<N, AD> {
 				Err(e) if e.is_disconnected() => {
 					// the request is already revoked
 					for peer_id in validator_ids {
-						let _ = on_revoke(&mut self.requested_validators, peer_id);
+						let _ = on_revoke(&mut state.requested_validators, peer_id);
 					}
 					return (network_service, authority_discovery_service);
 				}
@@ -238,10 +268,6 @@ impl<N: Network, AD: AuthorityDiscovery> Service<N, AD> {
 			let result = authority_discovery_service.get_addresses_by_authority_id(authority.clone()).await;
 			if let Some(addresses) = result {
 				// We might have several `PeerId`s per `AuthorityId`
-				// depending on the number of sentry nodes,
-				// so we limit the max number of sentries per node to connect to.
-				// They are going to be removed soon though:
-				// https://github.com/paritytech/substrate/issues/6845
 				multiaddr_to_add.extend(addresses.into_iter().take(MAX_ADDR_PER_PEER));
 			}
 		}
@@ -249,10 +275,10 @@ impl<N: Network, AD: AuthorityDiscovery> Service<N, AD> {
 		// clean up revoked requests
 		let mut revoked_indices = Vec::new();
 		let mut revoked_validators = Vec::new();
-		for (i, maybe_revoked) in self.non_revoked_discovery_requests.iter_mut().enumerate() {
+		for (i, maybe_revoked) in state.non_revoked_discovery_requests.iter_mut().enumerate() {
 			if maybe_revoked.is_revoked() {
 				for id in maybe_revoked.requested() {
-					if let Some(id) = on_revoke(&mut self.requested_validators, id.clone()) {
+					if let Some(id) = on_revoke(&mut state.requested_validators, id.clone()) {
 						revoked_validators.push(id);
 					}
 				}
@@ -262,7 +288,7 @@ impl<N: Network, AD: AuthorityDiscovery> Service<N, AD> {
 
 		// clean up revoked requests states
 		for to_revoke in revoked_indices.into_iter().rev() {
-			drop(self.non_revoked_discovery_requests.swap_remove(to_revoke));
+			drop(state.non_revoked_discovery_requests.swap_remove(to_revoke));
 		}
 
 		// multiaddresses to remove
@@ -277,33 +303,23 @@ impl<N: Network, AD: AuthorityDiscovery> Service<N, AD> {
 		// ask the network to connect to these nodes and not disconnect
 		// from them until removed from the set
 		if let Err(e) = network_service.add_peers_to_reserved_set(
-			PeerSet::Collation.into_protocol_name(),
+			peer_set.into_protocol_name(),
 			multiaddr_to_add.clone(),
 		).await {
 			tracing::warn!(target: LOG_TARGET, err = ?e, "AuthorityDiscoveryService returned an invalid multiaddress");
 		}
-		if let Err(e) = network_service.add_peers_to_reserved_set(
-			PeerSet::Validation.into_protocol_name(),
-			multiaddr_to_add,
-		).await {
-			tracing::warn!(target: LOG_TARGET, err = ?e, "AuthorityDiscoveryService returned an invalid multiaddress");
-		}
 		// the addresses are known to be valid
 		let _ = network_service.remove_peers_from_reserved_set(
-			PeerSet::Collation.into_protocol_name(),
+			peer_set.into_protocol_name(),
 			multiaddr_to_remove.clone()
 		).await;
-		let _ = network_service.remove_peers_from_reserved_set(
-			PeerSet::Validation.into_protocol_name(),
-			multiaddr_to_remove
-		).await;
 
 		let pending = validator_ids.iter()
 			.cloned()
 			.filter(|id| !already_connected.contains_key(id))
 			.collect::<HashSet<_>>();
 
-		self.non_revoked_discovery_requests.push(NonRevokedConnectionRequestState::new(
+		state.non_revoked_discovery_requests.push(NonRevokedConnectionRequestState::new(
 			validator_ids,
 			pending,
 			connected,
@@ -314,23 +330,30 @@ impl<N: Network, AD: AuthorityDiscovery> Service<N, AD> {
 
 	/// Should be called when a peer connected.
 	#[tracing::instrument(level = "trace", skip(self, authority_discovery_service), fields(subsystem = LOG_TARGET))]
-	pub async fn on_peer_connected(&mut self, peer_id: &PeerId, authority_discovery_service: &mut AD) {
+	pub async fn on_peer_connected(
+		&mut self,
+		peer_id: PeerId,
+		peer_set: PeerSet,
+		authority_discovery_service: &mut AD,
+	) {
+		let state = self.state_mut(peer_set);
 		// check if it's an authority we've been waiting for
 		let maybe_authority = authority_discovery_service.get_authority_id_by_peer_id(peer_id.clone()).await;
 		if let Some(authority) = maybe_authority {
-			for request in self.non_revoked_discovery_requests.iter_mut() {
-				let _ = request.on_authority_connected(&authority, peer_id);
+			for request in state.non_revoked_discovery_requests.iter_mut() {
+				let _ = request.on_authority_connected(&authority, &peer_id);
 			}
 
-			self.connected_peers.entry(peer_id.clone()).or_default().insert(authority);
+			state.connected_peers.entry(peer_id).or_default().insert(authority);
 		} else {
-			self.connected_peers.insert(peer_id.clone(), Default::default());
+			state.connected_peers.insert(peer_id, Default::default());
 		}
 	}
 
 	/// Should be called when a peer disconnected.
-	pub fn on_peer_disconnected(&mut self, peer_id: &PeerId) {
-		self.connected_peers.remove(peer_id);
+	pub fn on_peer_disconnected(&mut self, peer_id: &PeerId, peer_set: PeerSet) {
+		let state = self.state_mut(peer_set);
+		state.connected_peers.remove(peer_id);
 	}
 }
 
@@ -453,10 +476,11 @@ mod tests {
 			let req1 = vec![authority_ids[0].clone(), authority_ids[1].clone()];
 			let (sender, mut receiver) = mpsc::channel(2);
 
-			service.on_peer_connected(&peer_ids[0], &mut ads).await;
+			service.on_peer_connected(peer_ids[0].clone(), PeerSet::Validation, &mut ads).await;
 
 			let _ = service.on_request(
 				req1,
+				PeerSet::Validation,
 				sender,
 				ns,
 				ads,
@@ -485,18 +509,19 @@ mod tests {
 
 			let (_, mut ads) = service.on_request(
 				req1,
+				PeerSet::Validation,
 				sender,
 				ns,
 				ads,
 			).await;
 
 
-			service.on_peer_connected(&peer_ids[0], &mut ads).await;
+			service.on_peer_connected(peer_ids[0].clone(), PeerSet::Validation, &mut ads).await;
 			let reply1 = receiver.next().await.unwrap();
 			assert_eq!(reply1.0, authority_ids[0]);
 			assert_eq!(reply1.1, peer_ids[0]);
 
-			service.on_peer_connected(&peer_ids[1], &mut ads).await;
+			service.on_peer_connected(peer_ids[1].clone(), PeerSet::Validation, &mut ads).await;
 			let reply2 = receiver.next().await.unwrap();
 			assert_eq!(reply2.0, authority_ids[1]);
 			assert_eq!(reply2.1, peer_ids[1]);
@@ -516,11 +541,12 @@ mod tests {
 		futures::executor::block_on(async move {
 			let (sender, mut receiver) = mpsc::channel(1);
 
-			service.on_peer_connected(&peer_ids[0], &mut ads).await;
-			service.on_peer_connected(&peer_ids[1], &mut ads).await;
+			service.on_peer_connected(peer_ids[0].clone(), PeerSet::Validation, &mut ads).await;
+			service.on_peer_connected(peer_ids[1].clone(), PeerSet::Validation, &mut ads).await;
 
 			let (ns, ads) = service.on_request(
 				vec![authority_ids[0].clone()],
+				PeerSet::Validation,
 				sender,
 				ns,
 				ads,
@@ -534,6 +560,7 @@ mod tests {
 
 			let _ = service.on_request(
 				vec![authority_ids[1].clone()],
+				PeerSet::Validation,
 				sender,
 				ns,
 				ads,
@@ -542,7 +569,8 @@ mod tests {
 			let reply = receiver.next().await.unwrap();
 			assert_eq!(reply.0, authority_ids[1]);
 			assert_eq!(reply.1, peer_ids[1]);
-			assert_eq!(service.non_revoked_discovery_requests.len(), 1);
+			let state = service.state_mut(PeerSet::Validation);
+			assert_eq!(state.non_revoked_discovery_requests.len(), 1);
 		});
 	}
 
@@ -559,11 +587,12 @@ mod tests {
 		futures::executor::block_on(async move {
 			let (sender, mut receiver) = mpsc::channel(1);
 
-			service.on_peer_connected(&peer_ids[0], &mut ads).await;
-			service.on_peer_connected(&peer_ids[1], &mut ads).await;
+			service.on_peer_connected(peer_ids[0].clone(), PeerSet::Validation, &mut ads).await;
+			service.on_peer_connected(peer_ids[1].clone(), PeerSet::Validation, &mut ads).await;
 
 			let (ns, ads) = service.on_request(
 				vec![authority_ids[0].clone(), authority_ids[2].clone()],
+				PeerSet::Validation,
 				sender,
 				ns,
 				ads,
@@ -577,13 +606,15 @@ mod tests {
 
 			let (ns, ads) = service.on_request(
 				vec![authority_ids[0].clone(), authority_ids[1].clone()],
+				PeerSet::Validation,
 				sender,
 				ns,
 				ads,
 			).await;
 
 			let _ = receiver.next().await.unwrap();
-			assert_eq!(service.non_revoked_discovery_requests.len(), 1);
+			let state = service.state_mut(PeerSet::Validation);
+			assert_eq!(state.non_revoked_discovery_requests.len(), 1);
 			assert_eq!(ns.peers_set.len(), 2);
 
 			// revoke the second request
@@ -593,13 +624,15 @@ mod tests {
 
 			let (ns, _) = service.on_request(
 				vec![authority_ids[0].clone()],
+				PeerSet::Validation,
 				sender,
 				ns,
 				ads,
 			).await;
 
 			let _ = receiver.next().await.unwrap();
-			assert_eq!(service.non_revoked_discovery_requests.len(), 1);
+			let state = service.state_mut(PeerSet::Validation);
+			assert_eq!(state.non_revoked_discovery_requests.len(), 1);
 			assert_eq!(ns.peers_set.len(), 1);
 		});
 	}
@@ -619,7 +652,7 @@ mod tests {
 		futures::executor::block_on(async move {
 			let (sender, mut receiver) = mpsc::channel(1);
 
-			service.on_peer_connected(&validator_peer_id, &mut ads).await;
+			service.on_peer_connected(validator_peer_id.clone(), PeerSet::Validation, &mut ads).await;
 
 			let address = known_multiaddr()[0].clone().with(Protocol::P2p(validator_peer_id.clone().into()));
 			ads.by_peer_id.insert(validator_peer_id.clone(), validator_id.clone());
@@ -627,13 +660,20 @@ mod tests {
 
 			let _ = service.on_request(
 				vec![validator_id.clone()],
+				PeerSet::Validation,
 				sender,
 				ns,
 				ads,
 			).await;
 
 			assert_eq!((validator_id.clone(), validator_peer_id.clone()), receiver.next().await.unwrap());
-			assert!(service.connected_peers.get(&validator_peer_id).unwrap().contains(&validator_id));
+			let state = service.state_mut(PeerSet::Validation);
+			assert!(
+				state.connected_peers
+					.get(&validator_peer_id)
+					.unwrap()
+					.contains(&validator_id)
+			);
 		});
 	}
 }
diff --git a/polkadot/node/network/collator-protocol/src/collator_side.rs b/polkadot/node/network/collator-protocol/src/collator_side.rs
index afd3bc1a495..7365cda530c 100644
--- a/polkadot/node/network/collator-protocol/src/collator_side.rs
+++ b/polkadot/node/network/collator-protocol/src/collator_side.rs
@@ -28,7 +28,9 @@ use polkadot_subsystem::{
 	FromOverseer, OverseerSignal, SubsystemContext,
 	messages::{AllMessages, CollatorProtocolMessage, NetworkBridgeMessage, NetworkBridgeEvent},
 };
-use polkadot_node_network_protocol::{v1 as protocol_v1, View, PeerId, RequestId, OurView};
+use polkadot_node_network_protocol::{
+	peer_set::PeerSet, v1 as protocol_v1, View, PeerId, RequestId, OurView,
+};
 use polkadot_node_subsystem_util::{
 	validator_discovery,
 	request_validators_ctx,
@@ -365,6 +367,7 @@ async fn connect_to_validators(
 		ctx,
 		relay_parent,
 		validators,
+		PeerSet::Collation,
 	).await?;
 
 	state.connection_requests.put(relay_parent, request);
diff --git a/polkadot/node/network/pov-distribution/src/lib.rs b/polkadot/node/network/pov-distribution/src/lib.rs
index 6527be99c3d..d09641407c4 100644
--- a/polkadot/node/network/pov-distribution/src/lib.rs
+++ b/polkadot/node/network/pov-distribution/src/lib.rs
@@ -40,7 +40,7 @@ use polkadot_node_subsystem_util::{
 	metrics::{self, prometheus},
 };
 use polkadot_node_network_protocol::{
-	v1 as protocol_v1, ReputationChange as Rep, PeerId, OurView,
+	peer_set::PeerSet, v1 as protocol_v1, ReputationChange as Rep, PeerId, OurView,
 };
 
 use futures::prelude::*;
@@ -296,6 +296,42 @@ async fn distribute_to_awaiting(
 	metrics.on_pov_distributed();
 }
 
+/// Connect to relevant validators in case we are not already.
+async fn connect_to_relevant_validators(
+	connection_requests: &mut validator_discovery::ConnectionRequests,
+	ctx: &mut impl SubsystemContext<Message = PoVDistributionMessage>,
+	relay_parent: Hash,
+	descriptor: &CandidateDescriptor,
+) {
+	if let Ok(Some(relevant_validators)) =
+		determine_relevant_validators(ctx, relay_parent, descriptor.para_id).await
+	{
+		// We only need one connection request per (relay_parent, para_id)
+		// so here we take this shortcut to avoid calling `connect_to_validators`
+		// more than once.
+		if !connection_requests.contains_request(&relay_parent) {
+			tracing::debug!(target: LOG_TARGET, validators=?relevant_validators, "connecting to validators");
+			match validator_discovery::connect_to_validators(
+				ctx,
+				relay_parent,
+				relevant_validators,
+				PeerSet::Validation,
+			).await {
+				Ok(new_connection_request) => {
+					connection_requests.put(relay_parent, new_connection_request);
+				}
+				Err(e) => {
+					tracing::debug!(
+						target: LOG_TARGET,
+						"Failed to create a validator connection request {:?}",
+						e,
+					);
+				}
+			}
+		}
+	}
+}
+
 /// Get the Id of the Core that is assigned to the para being collated on if any
 /// and the total number of cores.
 async fn determine_core(
@@ -394,35 +430,8 @@ async fn handle_fetch(
 				return;
 			}
 			Entry::Vacant(e) => {
-				if let Ok(Some(relevant_validators)) = determine_relevant_validators(
-					ctx,
-					relay_parent,
-					descriptor.para_id,
-				).await {
-					// We only need one connection request per (relay_parent, para_id)
-					// so here we take this shortcut to avoid calling `connect_to_validators`
-					// more than once.
-					if !state.connection_requests.contains_request(&relay_parent) {
-						match validator_discovery::connect_to_validators(
-							ctx,
-							relay_parent,
-							relevant_validators.clone(),
-						).await {
-							Ok(new_connection_request) => {
-								state.connection_requests.put(relay_parent, new_connection_request);
-							}
-							Err(e) => {
-								tracing::debug!(
-									target: LOG_TARGET,
-									"Failed to create a validator connection request {:?}",
-									e,
-								);
-							}
-						}
-					}
-
-					e.insert(vec![response_sender]);
-				}
+				connect_to_relevant_validators(&mut state.connection_requests, ctx, relay_parent, &descriptor).await;
+				e.insert(vec![response_sender]);
 			}
 		}
 	}
@@ -460,6 +469,8 @@ async fn handle_distribute(
 		None => return,
 	};
 
+	connect_to_relevant_validators(&mut state.connection_requests, ctx, relay_parent, &descriptor).await;
+
 	if let Some(our_awaited) = relay_parent_state.fetching.get_mut(&descriptor.pov_hash) {
 		// Drain all the senders, but keep the entry in the map around intentionally.
 		//
@@ -640,7 +651,7 @@ async fn handle_incoming_pov(
 	relay_parent_state.known.insert(pov_hash, (pov, encoded_pov));
 }
 
-/// Handles a newly connected validator in the context of some relay leaf.
+/// Handles a newly or already connected validator in the context of some relay leaf.
 fn handle_validator_connected(state: &mut State, peer_id: PeerId) {
 	state.peer_state.entry(peer_id).or_default();
 }
@@ -718,10 +729,17 @@ impl PoVDistribution {
 
 	#[tracing::instrument(skip(self, ctx), fields(subsystem = LOG_TARGET))]
 	async fn run(
+		self,
+		ctx: impl SubsystemContext<Message = PoVDistributionMessage>,
+	) -> SubsystemResult<()> {
+		self.run_with_state(ctx, State::default()).await
+	}
+
+	async fn run_with_state(
 		self,
 		mut ctx: impl SubsystemContext<Message = PoVDistributionMessage>,
+		mut state: State,
 	) -> SubsystemResult<()> {
-		let mut state = State::default();
 		state.metrics = self.metrics;
 
 		loop {
diff --git a/polkadot/node/network/pov-distribution/src/tests.rs b/polkadot/node/network/pov-distribution/src/tests.rs
index afe83b6399e..96822e79462 100644
--- a/polkadot/node/network/pov-distribution/src/tests.rs
+++ b/polkadot/node/network/pov-distribution/src/tests.rs
@@ -60,6 +60,7 @@ struct TestHarness {
 }
 
 fn test_harness<T: Future<Output = ()>>(
+	state: State,
 	test: impl FnOnce(TestHarness) -> T,
 ) {
 	let _ = env_logger::builder()
@@ -80,7 +81,7 @@ fn test_harness<T: Future<Output = ()>>(
 
 	let subsystem = super::PoVDistribution::new(Metrics::default());
 
-	let subsystem = subsystem.run(context);
+	let subsystem = subsystem.run_with_state(context, state);
 
 	let test_fut = test(TestHarness { virtual_overseer });
 
@@ -257,7 +258,7 @@ async fn test_validator_discovery(
 fn ask_validators_for_povs() {
 	let test_state = TestState::default();
 
-	test_harness(|test_harness| async move {
+	test_harness(State::default(), |test_harness| async move {
 		let mut virtual_overseer = test_harness.virtual_overseer;
 
 		let pov_block = PoV {
@@ -566,7 +567,7 @@ fn distributes_to_those_awaiting_and_completes_local() {
 	let pov = make_pov(vec![1, 2, 3]);
 	let pov_hash = pov.hash();
 
-	let mut state = State {
+	let state = State {
 		relay_parent_state: {
 			let mut s = HashMap::new();
 			let mut b = BlockBasedState {
@@ -607,28 +608,38 @@ fn distributes_to_those_awaiting_and_completes_local() {
 		connection_requests: Default::default(),
 	};
 
-	let pool = sp_core::testing::TaskExecutor::new();
-	let (mut ctx, mut handle) = polkadot_node_subsystem_test_helpers::make_subsystem_context(pool);
 	let mut descriptor = CandidateDescriptor::default();
 	descriptor.pov_hash = pov_hash;
 
-	executor::block_on(async move {
-		handle_distribute(
-			&mut state,
-			&mut ctx,
-			hash_a,
-			descriptor,
-			Arc::new(pov.clone()),
+	test_harness(state, |test_harness| async move {
+		let mut virtual_overseer = test_harness.virtual_overseer;
+
+		overseer_send(
+			&mut virtual_overseer,
+			PoVDistributionMessage::DistributePoV(
+				hash_a,
+				descriptor,
+				Arc::new(pov.clone())
+			)
 		).await;
 
-		assert!(!state.peer_state[&peer_a].awaited[&hash_a].contains(&pov_hash));
-		assert!(state.peer_state[&peer_c].awaited[&hash_b].contains(&pov_hash));
+		// Let's assume runtime call failed and we're already connected to the peers.
+		assert_matches!(
+			virtual_overseer.recv().await,
+			AllMessages::RuntimeApi(RuntimeApiMessage::Request(
+				relay_parent,
+				RuntimeApiRequest::AvailabilityCores(tx)
+			)) => {
+				assert_eq!(relay_parent, hash_a);
+				tx.send(Err("nope".to_string().into())).unwrap();
+			}
+		);
 
 		// our local sender also completed
 		assert_eq!(&*pov_recv.await.unwrap(), &pov);
 
 		assert_matches!(
-			handle.recv().await,
+			virtual_overseer.recv().await,
 			AllMessages::NetworkBridge(
 				NetworkBridgeMessage::SendValidationMessage(peers, message)
 			) => {
diff --git a/polkadot/node/network/protocol/src/peer_set.rs b/polkadot/node/network/protocol/src/peer_set.rs
index 856d355785d..01d14067b90 100644
--- a/polkadot/node/network/protocol/src/peer_set.rs
+++ b/polkadot/node/network/protocol/src/peer_set.rs
@@ -21,13 +21,14 @@ use std::borrow::Cow;
 use strum::{EnumIter, IntoEnumIterator};
 
 /// The peer-sets and thus the protocols which are used for the network.
-#[derive(Debug, Clone, Copy, PartialEq, EnumIter)]
+#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, EnumIter)]
+#[repr(usize)]
 pub enum PeerSet {
 	/// The validation peer-set is responsible for all messages related to candidate validation and
 	/// communication among validators.
-	Validation,
+	Validation = 0,
 	/// The collation peer-set is used for validator<>collator communication.
-	Collation,
+	Collation = 1,
 }
 
 impl PeerSet {
diff --git a/polkadot/node/subsystem-util/Cargo.toml b/polkadot/node/subsystem-util/Cargo.toml
index 4b165bdf6d2..0267b5db570 100644
--- a/polkadot/node/subsystem-util/Cargo.toml
+++ b/polkadot/node/subsystem-util/Cargo.toml
@@ -20,6 +20,7 @@ tracing-futures = "0.2.4"
 polkadot-node-primitives = { path = "../primitives" }
 polkadot-node-subsystem = { path = "../subsystem" }
 polkadot-node-jaeger = { path = "../jaeger" }
+polkadot-node-network-protocol = { path = "../network/protocol" }
 polkadot-primitives = { path = "../../primitives" }
 metered-channel = { path = "../metered-channel"}
 
diff --git a/polkadot/node/subsystem-util/src/validator_discovery.rs b/polkadot/node/subsystem-util/src/validator_discovery.rs
index 9472d44d40c..a16af82b82f 100644
--- a/polkadot/node/subsystem-util/src/validator_discovery.rs
+++ b/polkadot/node/subsystem-util/src/validator_discovery.rs
@@ -33,6 +33,7 @@ use polkadot_node_subsystem::{
 	SubsystemContext,
 };
 use polkadot_primitives::v1::{Hash, ValidatorId, AuthorityDiscoveryId, SessionIndex};
+use polkadot_node_network_protocol::peer_set::PeerSet;
 use sc_network::PeerId;
 use crate::Error;
 
@@ -41,16 +42,24 @@ pub async fn connect_to_validators<Context: SubsystemContext>(
 	ctx: &mut Context,
 	relay_parent: Hash,
 	validators: Vec<ValidatorId>,
+	peer_set: PeerSet,
 ) -> Result<ConnectionRequest, Error> {
 	let current_index = crate::request_session_index_for_child_ctx(relay_parent, ctx).await?.await??;
-	connect_to_past_session_validators(ctx, relay_parent, validators, current_index).await
+	connect_to_validators_in_session(
+		ctx,
+		relay_parent,
+		validators,
+		peer_set,
+		current_index,
+	).await
 }
 
-/// Utility function to make it easier to connect to validators in the past sessions.
-pub async fn connect_to_past_session_validators<Context: SubsystemContext>(
+/// Utility function to make it easier to connect to validators in the given session.
+pub async fn connect_to_validators_in_session<Context: SubsystemContext>(
 	ctx: &mut Context,
 	relay_parent: Hash,
 	validators: Vec<ValidatorId>,
+	peer_set: PeerSet,
 	session_index: SessionIndex,
 ) -> Result<ConnectionRequest, Error> {
 	let session_info = crate::request_session_info_ctx(
@@ -66,6 +75,14 @@ pub async fn connect_to_past_session_validators<Context: SubsystemContext>(
 		).into()),
 	};
 
+	tracing::trace!(
+		target: "network_bridge",
+		validators = ?validators,
+		discovery_keys = ?discovery_keys,
+		session_index,
+		"Trying to serve the validator discovery request",
+	);
+
 	let id_to_index = session_validators.iter()
 		.zip(0usize..)
 		.collect::<HashMap<_, _>>();
@@ -88,7 +105,7 @@ pub async fn connect_to_past_session_validators<Context: SubsystemContext>(
 		.filter_map(|(k, v)| v.map(|v| (v, k)))
 		.collect::<HashMap<AuthorityDiscoveryId, ValidatorId>>();
 
-	let connections = connect_to_authorities(ctx, authorities).await;
+	let connections = connect_to_authorities(ctx, authorities, peer_set).await;
 
 	Ok(ConnectionRequest {
 		validator_map,
@@ -99,14 +116,16 @@ pub async fn connect_to_past_session_validators<Context: SubsystemContext>(
 async fn connect_to_authorities<Context: SubsystemContext>(
 	ctx: &mut Context,
 	validator_ids: Vec<AuthorityDiscoveryId>,
+	peer_set: PeerSet,
 ) -> mpsc::Receiver<(AuthorityDiscoveryId, PeerId)> {
-	const PEERS_CAPACITY: usize = 8;
+	const PEERS_CAPACITY: usize = 32;
 
 	let (connected, connected_rx) = mpsc::channel(PEERS_CAPACITY);
 
 	ctx.send_message(AllMessages::NetworkBridge(
 		NetworkBridgeMessage::ConnectToValidators {
 			validator_ids,
+			peer_set,
 			connected,
 		}
 	)).await;
diff --git a/polkadot/node/subsystem/src/messages.rs b/polkadot/node/subsystem/src/messages.rs
index 19bb56076e6..f6c91b6d630 100644
--- a/polkadot/node/subsystem/src/messages.rs
+++ b/polkadot/node/subsystem/src/messages.rs
@@ -25,7 +25,7 @@
 use futures::channel::{mpsc, oneshot};
 use thiserror::Error;
 use polkadot_node_network_protocol::{
-	v1 as protocol_v1, ReputationChange, PeerId,
+	peer_set::PeerSet, v1 as protocol_v1, ReputationChange, PeerId,
 	request_response::{Requests, request::IncomingRequest, v1 as req_res_v1},
 };
 use polkadot_node_primitives::{
@@ -229,6 +229,8 @@ pub enum NetworkBridgeMessage {
 	ConnectToValidators {
 		/// Ids of the validators to connect to.
 		validator_ids: Vec<AuthorityDiscoveryId>,
+		/// The underlying protocol to use for this request.
+		peer_set: PeerSet,
 		/// Response sender by which the issuer can learn the `PeerId`s of
 		/// the validators as they are connected.
 		/// The response is sent immediately for already connected peers.
diff --git a/polkadot/roadmap/implementers-guide/src/types/overseer-protocol.md b/polkadot/roadmap/implementers-guide/src/types/overseer-protocol.md
index 5b298472141..60994961a88 100644
--- a/polkadot/roadmap/implementers-guide/src/types/overseer-protocol.md
+++ b/polkadot/roadmap/implementers-guide/src/types/overseer-protocol.md
@@ -337,6 +337,8 @@ enum NetworkBridgeMessage {
     ConnectToValidators {
         /// Ids of the validators to connect to.
         validator_ids: Vec<AuthorityDiscoveryId>,
+        /// The underlying protocol to use for this request.
+        peer_set: PeerSet,
         /// Response sender by which the issuer can learn the `PeerId`s of
         /// the validators as they are connected.
         /// The response is sent immediately for already connected peers.
-- 
GitLab