Unverified Commit d1a2eff2 authored by Andronik Ordian's avatar Andronik Ordian Committed by GitHub
Browse files

validator_discovery: less flexible, but simpler design (#3052)

* validator_discovery: less flexible, but simpler design

* fix test

* remove unused struct

* smol optimization
parent 9fd0f5ba
Pipeline #138717 passed with stages
in 27 minutes and 57 seconds
......@@ -36,7 +36,6 @@ use requester::Requester;
/// Handling requests for PoVs during backing.
mod pov_requester;
use pov_requester::PoVRequester;
/// Responding to erasure chunk requests:
mod responder;
......@@ -90,7 +89,6 @@ impl AvailabilityDistributionSubsystem {
Context: SubsystemContext<Message = AvailabilityDistributionMessage> + Sync + Send,
{
let mut requester = Requester::new(self.metrics.clone()).fuse();
let mut pov_requester = PoVRequester::new();
loop {
let action = {
let mut subsystem_next = ctx.recv().fuse();
......@@ -113,18 +111,6 @@ impl AvailabilityDistributionSubsystem {
};
match message {
FromOverseer::Signal(OverseerSignal::ActiveLeaves(update)) => {
let result = pov_requester.update_connected_validators(
&mut ctx,
&mut self.runtime,
&update,
).await;
if let Err(error) = result {
tracing::debug!(
target: LOG_TARGET,
?error,
"PoVRequester::update_connected_validators",
);
}
log_error(
requester.get_mut().update_fetching_heads(&mut ctx, &mut self.runtime, update).await,
"Error in Requester::update_fetching_heads"
......@@ -154,7 +140,7 @@ impl AvailabilityDistributionSubsystem {
},
} => {
log_error(
pov_requester.fetch_pov(
pov_requester::fetch_pov(
&mut ctx,
&mut self.runtime,
relay_parent,
......@@ -163,7 +149,7 @@ impl AvailabilityDistributionSubsystem {
pov_hash,
tx,
).await,
"PoVRequester::fetch_pov"
"pov_requester::fetch_pov"
)?;
}
}
......
......@@ -17,118 +17,68 @@
//! PoV requester takes care of requesting PoVs from validators of a backing group.
use futures::{FutureExt, channel::oneshot, future::BoxFuture};
use lru::LruCache;
use polkadot_subsystem::jaeger;
use polkadot_node_network_protocol::{
peer_set::PeerSet,
request_response::{OutgoingRequest, Recipient, request::{RequestError, Requests},
v1::{PoVFetchingRequest, PoVFetchingResponse}}
};
use polkadot_primitives::v1::{
AuthorityDiscoveryId, CandidateHash, Hash, SessionIndex, ValidatorIndex
CandidateHash, Hash, ValidatorIndex,
};
use polkadot_node_primitives::PoV;
use polkadot_subsystem::{
ActiveLeavesUpdate, SubsystemContext, ActivatedLeaf,
SubsystemContext,
messages::{AllMessages, NetworkBridgeMessage, IfDisconnected}
};
use polkadot_node_subsystem_util::runtime::{RuntimeInfo, ValidatorInfo};
use polkadot_node_subsystem_util::runtime::RuntimeInfo;
use crate::error::{Fatal, NonFatal};
use crate::LOG_TARGET;
/// Number of sessions we want to keep in the LRU.
///
/// We only ever care about being connected to validators of at most two
/// sessions (see `PoVRequester::connected_validators`), hence two entries.
const NUM_SESSIONS: usize = 2;
/// Requester of PoVs, tracking validator connection requests per session.
pub struct PoVRequester {
/// We only ever care about being connected to validators of at most two sessions.
///
/// So we keep an LRU for managing connection requests of size 2.
/// Cache will contain `None` if we are not a validator in that session.
// NOTE(review): the stored `Sender` appears to act as a keep-alive handle for
// the `ConnectToValidators` request - dropping it (e.g. on LRU eviction)
// revokes the request on the network-bridge side; confirm against
// `NonRevokedConnectionRequestState::is_revoked`.
connected_validators: LruCache<SessionIndex, Option<oneshot::Sender<()>>>,
}
impl PoVRequester {
/// Construct a fresh `PoVRequester` with an empty session-connection cache.
pub fn new() -> Self {
	// The cache only ever needs to hold the sessions we currently care about.
	let connected_validators = LruCache::new(NUM_SESSIONS);
	Self { connected_validators }
}
/// Ensure we are connected to the proper set of validators.
///
/// For every leaf activated by the `ActiveLeavesUpdate`, look up its session
/// and - if we have not done so for that session yet - issue a connection
/// request to the relevant validators, remembering its keep-alive handle in
/// the LRU.
pub async fn update_connected_validators<Context>(
	&mut self,
	ctx: &mut Context,
	runtime: &mut RuntimeInfo,
	update: &ActiveLeavesUpdate,
) -> super::Result<()>
where
	Context: SubsystemContext,
{
	let new_heads = update.activated.iter().map(|ActivatedLeaf { hash, .. }| hash);
	for (parent, session_index) in get_activated_sessions(ctx, runtime, new_heads).await? {
		// Already requested connections for this session - nothing to do.
		if !self.connected_validators.contains(&session_index) {
			let tx = connect_to_relevant_validators(ctx, runtime, parent, session_index).await?;
			self.connected_validators.put(session_index, tx);
		}
	}
	Ok(())
}
/// Start background worker for taking care of fetching the requested `PoV` from the network.
///
/// NOTE(review): this method version appears superseded by the free function
/// `fetch_pov` below (diff artifact of this page); `&self` is not used in the
/// body.
///
/// # Errors
///
/// * `NonFatal::InvalidValidatorIndex` when `from_validator` is out of range
///   of this session's `discovery_keys`.
/// * `Fatal::SpawnTask` when the background task cannot be spawned.
pub async fn fetch_pov<Context>(
&self,
ctx: &mut Context,
runtime: &mut RuntimeInfo,
parent: Hash,
from_validator: ValidatorIndex,
candidate_hash: CandidateHash,
pov_hash: Hash,
tx: oneshot::Sender<PoV>
) -> super::Result<()>
where
Context: SubsystemContext,
{
// Resolve the authority discovery id of the validator we want to ask.
let info = &runtime.get_session_info(ctx, parent).await?.session_info;
let authority_id = info.discovery_keys.get(from_validator.0 as usize)
.ok_or(NonFatal::InvalidValidatorIndex)?
.clone();
let (req, pending_response) = OutgoingRequest::new(
Recipient::Authority(authority_id),
PoVFetchingRequest {
candidate_hash,
},
);
let full_req = Requests::PoVFetching(req);
ctx.send_message(
AllMessages::NetworkBridge(
NetworkBridgeMessage::SendRequests(
vec![full_req],
// We are supposed to be connected to validators of our group via `PeerSet`,
// but at session boundaries that is kind of racy, in case a connection takes
// longer to get established, so we try to connect in any case.
IfDisconnected::TryConnect
)
)).await;
// Trace the fetch via jaeger, tagged with validator index and relay parent.
let span = jaeger::Span::new(candidate_hash, "fetch-pov")
.with_validator_index(from_validator)
.with_relay_parent(parent);
// Response handling happens in the spawned background task; `pov_hash` and
// `tx` are forwarded there.
ctx.spawn("pov-fetcher", fetch_pov_job(pov_hash, pending_response.boxed(), span, tx).boxed())
.await
.map_err(|e| Fatal::SpawnTask(e))?;
Ok(())
}
/// Start a background worker that fetches the requested `PoV` from the network.
///
/// Resolves the authority discovery key of `from_validator` for the session of
/// `parent`, fires a `PoVFetching` network request at it, and spawns
/// `fetch_pov_job` to handle the response, delivering the result over `tx`.
pub async fn fetch_pov<Context>(
	ctx: &mut Context,
	runtime: &mut RuntimeInfo,
	parent: Hash,
	from_validator: ValidatorIndex,
	candidate_hash: CandidateHash,
	pov_hash: Hash,
	tx: oneshot::Sender<PoV>
) -> super::Result<()>
where
	Context: SubsystemContext,
{
	// Whom to ask: the discovery key of `from_validator` in this session.
	let session_info = &runtime.get_session_info(ctx, parent).await?.session_info;
	let authority_id = session_info.discovery_keys
		.get(from_validator.0 as usize)
		.cloned()
		.ok_or(NonFatal::InvalidValidatorIndex)?;

	let (req, pending_response) = OutgoingRequest::new(
		Recipient::Authority(authority_id),
		PoVFetchingRequest { candidate_hash },
	);
	ctx.send_message(AllMessages::NetworkBridge(
		NetworkBridgeMessage::SendRequests(
			vec![Requests::PoVFetching(req)],
			// We are supposed to be connected to validators of our group via `PeerSet`,
			// but at session boundaries that is kind of racy, in case a connection takes
			// longer to get established, so we try to connect in any case.
			IfDisconnected::TryConnect,
		),
	)).await;

	let span = jaeger::Span::new(candidate_hash, "fetch-pov")
		.with_validator_index(from_validator)
		.with_relay_parent(parent);
	ctx.spawn("pov-fetcher", fetch_pov_job(pov_hash, pending_response.boxed(), span, tx).boxed())
		.await
		.map_err(Fatal::SpawnTask)?;
	Ok(())
}
/// Future to be spawned for taking care of handling reception and sending of PoV.
......@@ -170,74 +120,6 @@ async fn do_fetch_pov(
}
}
/// Get the session indices for the given relay chain parents.
///
/// Queries the runtime for each head's session index and returns
/// `(parent, session_index)` pairs in input order.
async fn get_activated_sessions<Context>(ctx: &mut Context, runtime: &mut RuntimeInfo, new_heads: impl Iterator<Item = &Hash>)
-> super::Result<impl Iterator<Item = (Hash, SessionIndex)>>
where
Context: SubsystemContext,
{
let mut sessions = Vec::new();
for parent in new_heads {
sessions.push((*parent, runtime.get_session_index(ctx, *parent).await?));
}
Ok(sessions.into_iter())
}
/// Connect to validators of our validator group.
///
/// If we are a validator in the given session, issue a `ConnectToValidators`
/// request for our group's peers and return the keep-alive sender guarding it.
/// Returns `None` when we are not a validator in that session.
async fn connect_to_relevant_validators<Context>(
	ctx: &mut Context,
	runtime: &mut RuntimeInfo,
	parent: Hash,
	session: SessionIndex
)
	-> super::Result<Option<oneshot::Sender<()>>>
where
	Context: SubsystemContext,
{
	let validator_ids = match determine_relevant_validators(ctx, runtime, parent, session).await? {
		// Not a validator in this session - nothing to connect to.
		None => return Ok(None),
		Some(ids) => ids,
	};
	// The receiver end travels with the request; dropping `tx` revokes it.
	let (tx, keep_alive) = oneshot::channel();
	ctx.send_message(AllMessages::NetworkBridge(NetworkBridgeMessage::ConnectToValidators {
		validator_ids, peer_set: PeerSet::Validation, keep_alive
	})).await;
	Ok(Some(tx))
}
/// Get the validators in our validator group.
///
/// Return: `None` if not a validator.
async fn determine_relevant_validators<Context>(
	ctx: &mut Context,
	runtime: &mut RuntimeInfo,
	parent: Hash,
	session: SessionIndex,
)
	-> super::Result<Option<Vec<AuthorityDiscoveryId>>>
where
	Context: SubsystemContext,
{
	let info = runtime.get_session_info_by_index(ctx, parent, session).await?;
	// Only relevant if we are a validator and belong to a group in this session.
	let (our_index, our_group) = match &info.validator_info {
		ValidatorInfo { our_index: Some(our_index), our_group: Some(our_group) } =>
			(our_index, our_group),
		_ => return Ok(None),
	};
	let group = info.session_info.validator_groups
		.get(our_group.0 as usize)
		.expect("Our group got retrieved from that session info, it must exist. qed.")
		.clone();
	// Map the other group members to their authority discovery keys.
	let discovery_ids = group.into_iter()
		.filter(|i| *i != *our_index)
		.map(|i| info.session_info.discovery_keys[i.0 as usize].clone())
		.collect();
	Ok(Some(discovery_ids))
}
#[cfg(test)]
mod tests {
use assert_matches::assert_matches;
......@@ -274,7 +156,6 @@ mod tests {
}
fn test_run(pov_hash: Hash, pov: PoV) {
let requester = PoVRequester::new();
let pool = TaskExecutor::new();
let (mut context, mut virtual_overseer) =
test_helpers::make_subsystem_context::<AvailabilityDistributionMessage, TaskExecutor>(pool.clone());
......@@ -283,7 +164,7 @@ mod tests {
let (tx, rx) = oneshot::channel();
let testee = async {
requester.fetch_pov(
fetch_pov(
&mut context,
&mut runtime,
Hash::default(),
......
......@@ -509,7 +509,6 @@ where
NetworkBridgeMessage::ConnectToValidators {
validator_ids,
peer_set,
keep_alive,
} => {
tracing::trace!(
target: LOG_TARGET,
......@@ -522,7 +521,6 @@ where
let (ns, ads) = validator_discovery.on_request(
validator_ids,
peer_set,
keep_alive,
network_service,
authority_discovery_service,
).await;
......
......@@ -19,10 +19,9 @@
use crate::Network;
use core::marker::PhantomData;
use std::collections::{HashSet, HashMap, hash_map};
use std::collections::HashSet;
use async_trait::async_trait;
use futures::channel::oneshot;
use sc_network::multiaddr::Multiaddr;
use sc_authority_discovery::Service as AuthorityDiscoveryService;
......@@ -52,51 +51,6 @@ impl AuthorityDiscovery for AuthorityDiscoveryService {
}
}
/// This struct tracks the state for one `ConnectToValidators` request.
struct NonRevokedConnectionRequestState {
// The validators this request asked to stay connected to.
requested: Vec<AuthorityDiscoveryId>,
// Receiver whose sender side the requester holds on to; once the sender is
// dropped, `is_revoked` reports `true` and the request can be cleaned up.
keep_alive: oneshot::Receiver<()>,
}
impl NonRevokedConnectionRequestState {
	/// Create a new instance of `NonRevokedConnectionRequestState`.
	///
	/// `requested` are the validators this request asked to connect to;
	/// `keep_alive` is the channel whose sender side the requester keeps for
	/// as long as the request should stay alive.
	pub fn new(
		requested: Vec<AuthorityDiscoveryId>,
		keep_alive: oneshot::Receiver<()>,
	) -> Self {
		Self {
			requested,
			keep_alive,
		}
	}

	/// Returns `true` if the request is revoked.
	///
	/// A request counts as revoked once the sender end of `keep_alive` has
	/// been dropped: `try_recv` then returns an error.
	pub fn is_revoked(&mut self) -> bool {
		self.keep_alive.try_recv().is_err()
	}

	/// The validators requested by this request.
	pub fn requested(&self) -> &[AuthorityDiscoveryId] {
		self.requested.as_ref()
	}
}
/// Will be called by [`Service::on_request`] when a request was revoked.
///
/// Takes the `map` of requested validators and the `id` of the validator that should be revoked.
///
/// Returns `Some(id)` iff the request counter is `0`.
fn on_revoke(map: &mut HashMap<AuthorityDiscoveryId, u64>, id: AuthorityDiscoveryId) -> Option<AuthorityDiscoveryId> {
	match map.entry(id) {
		hash_map::Entry::Occupied(mut entry) => {
			// Decrement the pending-request counter, saturating at zero.
			let remaining = entry.get().saturating_sub(1);
			*entry.get_mut() = remaining;
			if remaining == 0 {
				// Last pending request gone - hand the id back for cleanup.
				Some(entry.remove_entry().0)
			} else {
				None
			}
		}
		// Unknown validator: nothing to revoke.
		hash_map::Entry::Vacant(_) => None,
	}
}
pub(super) struct Service<N, AD> {
state: PerPeerSet<StatePerPeerSet>,
// PhantomData used to make the struct generic instead of having generic methods
......@@ -105,111 +59,67 @@ pub(super) struct Service<N, AD> {
// NOTE(review): this struct appears to interleave the fields of two versions
// of the design (diff-page artifact): the keep-alive/revocation bookkeeping
// (`requested_validators`, `non_revoked_discovery_requests`) and the simpler
// "remember the last request" approach (`previously_requested`). Confirm
// against the upstream history which set is current.
#[derive(Default)]
struct StatePerPeerSet {
// The `u64` counts the number of pending non-revoked requests for this validator
// note: the validators in this map are not necessarily present
// in the `connected_validators` map.
// Invariant: the value > 0 for non-revoked requests.
requested_validators: HashMap<AuthorityDiscoveryId, u64>,
non_revoked_discovery_requests: Vec<NonRevokedConnectionRequestState>,
// Multiaddresses asked for by the most recent `ConnectToValidators` request.
previously_requested: HashSet<Multiaddr>,
}
// NOTE(review): this impl block interleaves two versions of the code (this is
// a rendered diff, not a single revision): `new` initializes `state` twice,
// `on_request` contains both the keep-alive/revocation flow and the
// "diff against the previous request" flow, and several call argument lists
// carry both the old and the new argument. Do not treat this as compilable
// source; reconcile against the upstream commit before editing.
impl<N: Network, AD: AuthorityDiscovery> Service<N, AD> {
pub fn new() -> Self {
Self {
// NOTE(review): duplicate initializer for `state` - old vs. new diff line.
state: PerPeerSet::default(),
state: Default::default(),
_phantom: PhantomData,
}
}
/// On a new connection request, a peer set update will be issued.
/// It will ask the network to connect to the validators and not disconnect
/// from them at least until all the pending requests containing them are revoked.
/// from them at least until the next request is issued for the same peer set.
///
/// This method will also clean up all previously revoked requests.
/// This method will also disconnect from previously connected validators not in the `validator_ids` set.
/// it takes `network_service` and `authority_discovery_service` by value
/// and returns them as a workaround for the Future: Send requirement imposed by async fn impl.
pub async fn on_request(
&mut self,
validator_ids: Vec<AuthorityDiscoveryId>,
peer_set: PeerSet,
keep_alive: oneshot::Receiver<()>,
mut network_service: N,
mut authority_discovery_service: AD,
) -> (N, AD) {
// Cap on how many addresses we track per authority (old flow only).
const MAX_ADDR_PER_PEER: usize = 3;
let state = &mut self.state[peer_set];
// Increment the counter of how many times the validators were requested.
validator_ids.iter().for_each(|id| *state.requested_validators.entry(id.clone()).or_default() += 1);
// collect multiaddress of validators
// NOTE(review): the two `for authority in validator_ids…` headers below are
// the old (`iter`) and new (`into_iter`) diff lines of the same loop.
let mut multiaddr_to_add = HashSet::new();
for authority in validator_ids.iter() {
let mut newly_requested = HashSet::new();
for authority in validator_ids.into_iter() {
let result = authority_discovery_service.get_addresses_by_authority_id(authority.clone()).await;
if let Some(addresses) = result {
// We might have several `PeerId`s per `AuthorityId`
multiaddr_to_add.extend(addresses.into_iter().take(MAX_ADDR_PER_PEER));
newly_requested.extend(addresses);
} else {
tracing::debug!(target: LOG_TARGET, "Authority Discovery couldn't resolve {:?}", authority);
}
}
let state = &mut self.state[peer_set];
// clean up revoked requests
let mut revoked_indices = Vec::new();
let mut revoked_validators = Vec::new();
for (i, maybe_revoked) in state.non_revoked_discovery_requests.iter_mut().enumerate() {
if maybe_revoked.is_revoked() {
for id in maybe_revoked.requested() {
if let Some(id) = on_revoke(&mut state.requested_validators, id.clone()) {
revoked_validators.push(id);
}
}
revoked_indices.push(i);
}
}
// clean up revoked requests states
//
// note that the `.rev()` here is important to guarantee `swap_remove`
// doesn't invalidate unprocessed `revoked_indices`
for to_revoke in revoked_indices.into_iter().rev() {
drop(state.non_revoked_discovery_requests.swap_remove(to_revoke));
}
// multiaddresses to remove
let mut multiaddr_to_remove = HashSet::new();
for id in revoked_validators.into_iter() {
let result = authority_discovery_service.get_addresses_by_authority_id(id.clone()).await;
if let Some(addresses) = result {
multiaddr_to_remove.extend(addresses.into_iter());
} else {
tracing::debug!(
target: LOG_TARGET,
"Authority Discovery couldn't resolve {:?} on cleanup, a leak is possible",
id,
);
}
}
// NOTE(review): from here the *new* flow - compute the add/remove sets as
// the symmetric difference against the previously requested addresses.
let multiaddr_to_remove = state.previously_requested
.difference(&newly_requested)
.cloned()
.collect();
let multiaddr_to_add = newly_requested.difference(&state.previously_requested)
.cloned()
.collect();
state.previously_requested = newly_requested;
// ask the network to connect to these nodes and not disconnect
// from them until removed from the set
if let Err(e) = network_service.add_to_peers_set(
peer_set.into_protocol_name(),
multiaddr_to_add.clone(),
multiaddr_to_add,
).await {
tracing::warn!(target: LOG_TARGET, err = ?e, "AuthorityDiscoveryService returned an invalid multiaddress");
}
// the addresses are known to be valid
let _ = network_service.remove_from_peers_set(
peer_set.into_protocol_name(),
multiaddr_to_remove.clone()
multiaddr_to_remove
).await;
state.non_revoked_discovery_requests.push(NonRevokedConnectionRequestState::new(
validator_ids,
keep_alive,
));
(network_service, authority_discovery_service)
}
}
......@@ -219,7 +129,7 @@ mod tests {
use super::*;
use crate::network::{Network, NetworkAction};
use std::{borrow::Cow, pin::Pin};
use std::{borrow::Cow, pin::Pin, collections::HashMap};
use futures::{sink::Sink, stream::BoxStream};
use sc_network::{Event as NetworkEvent, IfDisconnected};
use sp_keyring::Sr25519Keyring;
......@@ -317,26 +227,9 @@ mod tests {
"/ip4/127.0.0.1/tcp/1236".parse().unwrap(),
]
}
#[test]
fn request_is_revoked_when_the_receiver_is_dropped() {
let (keep_alive_handle, keep_alive) = oneshot::channel();
let mut request = NonRevokedConnectionRequestState::new(
Vec::new(),
keep_alive,
);
assert!(!request.is_revoked());
drop(keep_alive_handle);
assert!(request.is_revoked());
}
// Test cleanup works.
#[test]
fn requests_are_removed_on_revoke() {
fn old_multiaddrs_are_removed_on_new_request() {
let mut service = new_service();
let (ns, ads) = new_network();
......@@ -344,87 +237,22 @@ mod tests {
let authority_ids: Vec<_> = ads.by_peer_id.values().cloned().collect();
futures::executor::block_on(async move {
let (keep_alive_handle, keep_alive) = oneshot::channel();
let (ns, ads) = service.on_request(
vec![authority_ids[0].clone()],
PeerSet::Validation,
keep_alive,
ns,
ads,
).await;
// revoke the request
drop(keep_alive_handle);
let (_keep_alive_handle, keep_alive) = oneshot::channel();
let _ = service.on_request(
vec![authority_ids[1].clone()],
PeerSet::Validation,
keep_alive,
ns,
ads,
).await;
let state = &service.state[PeerSet::Validation];
assert_eq!(state.non_revoked_discovery_requests.len(), 1);
});
}
// More complex test with overlapping revoked requests
#[test]
fn revoking_requests_with_overlapping_validator_sets() {
let mut service = new_service();
let (ns, ads) = new_network();
let authority_ids: Vec<_> = ads.by_peer_id.values().cloned().collect();
futures::executor::block_on(async move {
let (keep_alive_handle, keep_alive) = oneshot::channel();