Unverified commit 3477414e, authored by asynchronous rob, committed by GitHub
Browse files

Use background tasks properly in candidate-validation (#4002)

* refactor: candidate-validation background tasks

* fix tests

* fmt
parent 6cb31e50
Pipeline #160455 canceled with stages
in 11 minutes and 33 seconds
...@@ -35,7 +35,7 @@ use polkadot_node_subsystem::{ ...@@ -35,7 +35,7 @@ use polkadot_node_subsystem::{
CandidateValidationMessage, RuntimeApiMessage, RuntimeApiRequest, ValidationFailed, CandidateValidationMessage, RuntimeApiMessage, RuntimeApiRequest, ValidationFailed,
}, },
overseer, FromOverseer, OverseerSignal, SpawnedSubsystem, SubsystemContext, SubsystemError, overseer, FromOverseer, OverseerSignal, SpawnedSubsystem, SubsystemContext, SubsystemError,
SubsystemResult, SubsystemResult, SubsystemSender,
}; };
use polkadot_node_subsystem_util::metrics::{self, prometheus}; use polkadot_node_subsystem_util::metrics::{self, prometheus};
use polkadot_parachain::primitives::{ValidationParams, ValidationResult as WasmValidationResult}; use polkadot_parachain::primitives::{ValidationParams, ValidationResult as WasmValidationResult};
...@@ -120,7 +120,7 @@ where ...@@ -120,7 +120,7 @@ where
Context: SubsystemContext<Message = CandidateValidationMessage>, Context: SubsystemContext<Message = CandidateValidationMessage>,
Context: overseer::SubsystemContext<Message = CandidateValidationMessage>, Context: overseer::SubsystemContext<Message = CandidateValidationMessage>,
{ {
let (mut validation_host, task) = polkadot_node_core_pvf::start( let (validation_host, task) = polkadot_node_core_pvf::start(
polkadot_node_core_pvf::Config::new(cache_path, program_path), polkadot_node_core_pvf::Config::new(cache_path, program_path),
pvf_metrics, pvf_metrics,
); );
...@@ -137,24 +137,28 @@ where ...@@ -137,24 +137,28 @@ where
pov, pov,
response_sender, response_sender,
) => { ) => {
let _timer = metrics.time_validate_from_chain_state(); let bg = {
let mut sender = ctx.sender().clone();
let res = spawn_validate_from_chain_state( let metrics = metrics.clone();
&mut ctx, let validation_host = validation_host.clone();
&mut validation_host,
descriptor, async move {
pov, let _timer = metrics.time_validate_from_chain_state();
&metrics, let res = validate_from_chain_state(
) &mut sender,
.await; validation_host,
descriptor,
match res { pov,
Ok(x) => { &metrics,
metrics.on_validation_event(&x); )
let _ = response_sender.send(x); .await;
},
Err(e) => return Err(e), metrics.on_validation_event(&res);
} let _ = response_sender.send(res);
}
};
ctx.spawn("validate-from-chain-state", bg.boxed())?;
}, },
CandidateValidationMessage::ValidateFromExhaustive( CandidateValidationMessage::ValidateFromExhaustive(
persisted_validation_data, persisted_validation_data,
...@@ -163,50 +167,68 @@ where ...@@ -163,50 +167,68 @@ where
pov, pov,
response_sender, response_sender,
) => { ) => {
let _timer = metrics.time_validate_from_exhaustive(); let bg = {
let metrics = metrics.clone();
let res = validate_candidate_exhaustive( let validation_host = validation_host.clone();
&mut validation_host,
persisted_validation_data, async move {
validation_code, let _timer = metrics.time_validate_from_exhaustive();
descriptor, let res = validate_candidate_exhaustive(
pov, validation_host,
&metrics, persisted_validation_data,
) validation_code,
.await; descriptor,
pov,
match res { &metrics,
Ok(x) => { )
metrics.on_validation_event(&x); .await;
if let Err(_e) = response_sender.send(x) { metrics.on_validation_event(&res);
tracing::warn!( let _ = response_sender.send(res);
target: LOG_TARGET, }
"Requester of candidate validation dropped", };
)
} ctx.spawn("validate-from-exhaustive", bg.boxed())?;
},
Err(e) => return Err(e),
}
}, },
}, },
} }
} }
} }
async fn runtime_api_request<T, Context>( struct RuntimeRequestFailed;
ctx: &mut Context,
async fn runtime_api_request<T, Sender>(
sender: &mut Sender,
relay_parent: Hash, relay_parent: Hash,
request: RuntimeApiRequest, request: RuntimeApiRequest,
receiver: oneshot::Receiver<Result<T, RuntimeApiError>>, receiver: oneshot::Receiver<Result<T, RuntimeApiError>>,
) -> SubsystemResult<Result<T, RuntimeApiError>> ) -> Result<T, RuntimeRequestFailed>
where where
Context: SubsystemContext<Message = CandidateValidationMessage>, Sender: SubsystemSender,
Context: overseer::SubsystemContext<Message = CandidateValidationMessage>,
{ {
ctx.send_message(RuntimeApiMessage::Request(relay_parent, request)).await; sender
.send_message(RuntimeApiMessage::Request(relay_parent, request).into())
.await;
receiver.await.map_err(Into::into) receiver
.await
.map_err(|_| {
tracing::debug!(target: LOG_TARGET, ?relay_parent, "Runtime API request dropped");
RuntimeRequestFailed
})
.and_then(|res| {
res.map_err(|e| {
tracing::debug!(
target: LOG_TARGET,
?relay_parent,
err = ?e,
"Runtime API request internal error"
);
RuntimeRequestFailed
})
})
} }
#[derive(Debug)] #[derive(Debug)]
...@@ -216,61 +238,57 @@ enum AssumptionCheckOutcome { ...@@ -216,61 +238,57 @@ enum AssumptionCheckOutcome {
BadRequest, BadRequest,
} }
async fn check_assumption_validation_data<Context>( async fn check_assumption_validation_data<Sender>(
ctx: &mut Context, sender: &mut Sender,
descriptor: &CandidateDescriptor, descriptor: &CandidateDescriptor,
assumption: OccupiedCoreAssumption, assumption: OccupiedCoreAssumption,
) -> SubsystemResult<AssumptionCheckOutcome> ) -> AssumptionCheckOutcome
where where
Context: SubsystemContext<Message = CandidateValidationMessage>, Sender: SubsystemSender,
Context: overseer::SubsystemContext<Message = CandidateValidationMessage>,
{ {
let validation_data = { let validation_data = {
let (tx, rx) = oneshot::channel(); let (tx, rx) = oneshot::channel();
let d = runtime_api_request( let d = runtime_api_request(
ctx, sender,
descriptor.relay_parent, descriptor.relay_parent,
RuntimeApiRequest::PersistedValidationData(descriptor.para_id, assumption, tx), RuntimeApiRequest::PersistedValidationData(descriptor.para_id, assumption, tx),
rx, rx,
) )
.await?; .await;
match d { match d {
Ok(None) | Err(_) => return Ok(AssumptionCheckOutcome::BadRequest), Ok(None) | Err(RuntimeRequestFailed) => return AssumptionCheckOutcome::BadRequest,
Ok(Some(d)) => d, Ok(Some(d)) => d,
} }
}; };
let persisted_validation_data_hash = validation_data.hash(); let persisted_validation_data_hash = validation_data.hash();
SubsystemResult::Ok( if descriptor.persisted_validation_data_hash == persisted_validation_data_hash {
if descriptor.persisted_validation_data_hash == persisted_validation_data_hash { let (code_tx, code_rx) = oneshot::channel();
let (code_tx, code_rx) = oneshot::channel(); let validation_code = runtime_api_request(
let validation_code = runtime_api_request( sender,
ctx, descriptor.relay_parent,
descriptor.relay_parent, RuntimeApiRequest::ValidationCode(descriptor.para_id, assumption, code_tx),
RuntimeApiRequest::ValidationCode(descriptor.para_id, assumption, code_tx), code_rx,
code_rx, )
) .await;
.await?;
match validation_code { match validation_code {
Ok(None) | Err(_) => AssumptionCheckOutcome::BadRequest, Ok(None) | Err(RuntimeRequestFailed) => AssumptionCheckOutcome::BadRequest,
Ok(Some(v)) => AssumptionCheckOutcome::Matches(validation_data, v), Ok(Some(v)) => AssumptionCheckOutcome::Matches(validation_data, v),
} }
} else { } else {
AssumptionCheckOutcome::DoesNotMatch AssumptionCheckOutcome::DoesNotMatch
}, }
)
} }
async fn find_assumed_validation_data<Context>( async fn find_assumed_validation_data<Sender>(
ctx: &mut Context, sender: &mut Sender,
descriptor: &CandidateDescriptor, descriptor: &CandidateDescriptor,
) -> SubsystemResult<AssumptionCheckOutcome> ) -> AssumptionCheckOutcome
where where
Context: SubsystemContext<Message = CandidateValidationMessage>, Sender: SubsystemSender,
Context: overseer::SubsystemContext<Message = CandidateValidationMessage>,
{ {
// The candidate descriptor has a `persisted_validation_data_hash` which corresponds to // The candidate descriptor has a `persisted_validation_data_hash` which corresponds to
// one of up to two possible values that we can derive from the state of the // one of up to two possible values that we can derive from the state of the
...@@ -287,41 +305,40 @@ where ...@@ -287,41 +305,40 @@ where
// Consider running these checks in parallel to reduce validation latency. // Consider running these checks in parallel to reduce validation latency.
for assumption in ASSUMPTIONS { for assumption in ASSUMPTIONS {
let outcome = check_assumption_validation_data(ctx, descriptor, *assumption).await?; let outcome = check_assumption_validation_data(sender, descriptor, *assumption).await;
match outcome { match outcome {
AssumptionCheckOutcome::Matches(_, _) => return Ok(outcome), AssumptionCheckOutcome::Matches(_, _) => return outcome,
AssumptionCheckOutcome::BadRequest => return Ok(outcome), AssumptionCheckOutcome::BadRequest => return outcome,
AssumptionCheckOutcome::DoesNotMatch => continue, AssumptionCheckOutcome::DoesNotMatch => continue,
} }
} }
Ok(AssumptionCheckOutcome::DoesNotMatch) AssumptionCheckOutcome::DoesNotMatch
} }
async fn spawn_validate_from_chain_state<Context>( async fn validate_from_chain_state<Sender>(
ctx: &mut Context, sender: &mut Sender,
validation_host: &mut ValidationHost, validation_host: ValidationHost,
descriptor: CandidateDescriptor, descriptor: CandidateDescriptor,
pov: Arc<PoV>, pov: Arc<PoV>,
metrics: &Metrics, metrics: &Metrics,
) -> SubsystemResult<Result<ValidationResult, ValidationFailed>> ) -> Result<ValidationResult, ValidationFailed>
where where
Context: SubsystemContext<Message = CandidateValidationMessage>, Sender: SubsystemSender,
Context: overseer::SubsystemContext<Message = CandidateValidationMessage>,
{ {
let (validation_data, validation_code) = let (validation_data, validation_code) =
match find_assumed_validation_data(ctx, &descriptor).await? { match find_assumed_validation_data(sender, &descriptor).await {
AssumptionCheckOutcome::Matches(validation_data, validation_code) => AssumptionCheckOutcome::Matches(validation_data, validation_code) =>
(validation_data, validation_code), (validation_data, validation_code),
AssumptionCheckOutcome::DoesNotMatch => { AssumptionCheckOutcome::DoesNotMatch => {
// If neither the assumption of the occupied core having the para included nor the assumption // If neither the assumption of the occupied core having the para included nor the assumption
// of the occupied core timing out are valid, then the persisted_validation_data_hash in the descriptor // of the occupied core timing out are valid, then the persisted_validation_data_hash in the descriptor
// is not based on the relay parent and is thus invalid. // is not based on the relay parent and is thus invalid.
return Ok(Ok(ValidationResult::Invalid(InvalidCandidate::BadParent))) return Ok(ValidationResult::Invalid(InvalidCandidate::BadParent))
}, },
AssumptionCheckOutcome::BadRequest => AssumptionCheckOutcome::BadRequest =>
return Ok(Err(ValidationFailed("Assumption Check: Bad request".into()))), return Err(ValidationFailed("Assumption Check: Bad request".into())),
}; };
let validation_result = validate_candidate_exhaustive( let validation_result = validate_candidate_exhaustive(
...@@ -334,20 +351,20 @@ where ...@@ -334,20 +351,20 @@ where
) )
.await; .await;
if let Ok(Ok(ValidationResult::Valid(ref outputs, _))) = validation_result { if let Ok(ValidationResult::Valid(ref outputs, _)) = validation_result {
let (tx, rx) = oneshot::channel(); let (tx, rx) = oneshot::channel();
match runtime_api_request( match runtime_api_request(
ctx, sender,
descriptor.relay_parent, descriptor.relay_parent,
RuntimeApiRequest::CheckValidationOutputs(descriptor.para_id, outputs.clone(), tx), RuntimeApiRequest::CheckValidationOutputs(descriptor.para_id, outputs.clone(), tx),
rx, rx,
) )
.await? .await
{ {
Ok(true) => {}, Ok(true) => {},
Ok(false) => return Ok(Ok(ValidationResult::Invalid(InvalidCandidate::InvalidOutputs))), Ok(false) => return Ok(ValidationResult::Invalid(InvalidCandidate::InvalidOutputs)),
Err(_) => Err(RuntimeRequestFailed) =>
return Ok(Err(ValidationFailed("Check Validation Outputs: Bad request".into()))), return Err(ValidationFailed("Check Validation Outputs: Bad request".into())),
} }
} }
...@@ -361,7 +378,7 @@ async fn validate_candidate_exhaustive( ...@@ -361,7 +378,7 @@ async fn validate_candidate_exhaustive(
descriptor: CandidateDescriptor, descriptor: CandidateDescriptor,
pov: Arc<PoV>, pov: Arc<PoV>,
metrics: &Metrics, metrics: &Metrics,
) -> SubsystemResult<Result<ValidationResult, ValidationFailed>> { ) -> Result<ValidationResult, ValidationFailed> {
let _timer = metrics.time_validate_candidate_exhaustive(); let _timer = metrics.time_validate_candidate_exhaustive();
let validation_code_hash = validation_code.hash(); let validation_code_hash = validation_code.hash();
...@@ -378,7 +395,7 @@ async fn validate_candidate_exhaustive( ...@@ -378,7 +395,7 @@ async fn validate_candidate_exhaustive(
&*pov, &*pov,
&validation_code_hash, &validation_code_hash,
) { ) {
return Ok(Ok(ValidationResult::Invalid(e))) return Ok(ValidationResult::Invalid(e))
} }
let raw_validation_code = match sp_maybe_compressed_blob::decompress( let raw_validation_code = match sp_maybe_compressed_blob::decompress(
...@@ -390,7 +407,7 @@ async fn validate_candidate_exhaustive( ...@@ -390,7 +407,7 @@ async fn validate_candidate_exhaustive(
tracing::debug!(target: LOG_TARGET, err=?e, "Invalid validation code"); tracing::debug!(target: LOG_TARGET, err=?e, "Invalid validation code");
// If the validation code is invalid, the candidate certainly is. // If the validation code is invalid, the candidate certainly is.
return Ok(Ok(ValidationResult::Invalid(InvalidCandidate::CodeDecompressionFailure))) return Ok(ValidationResult::Invalid(InvalidCandidate::CodeDecompressionFailure))
}, },
}; };
...@@ -401,7 +418,7 @@ async fn validate_candidate_exhaustive( ...@@ -401,7 +418,7 @@ async fn validate_candidate_exhaustive(
tracing::debug!(target: LOG_TARGET, err=?e, "Invalid PoV code"); tracing::debug!(target: LOG_TARGET, err=?e, "Invalid PoV code");
// If the PoV is invalid, the candidate certainly is. // If the PoV is invalid, the candidate certainly is.
return Ok(Ok(ValidationResult::Invalid(InvalidCandidate::PoVDecompressionFailure))) return Ok(ValidationResult::Invalid(InvalidCandidate::PoVDecompressionFailure))
}, },
}; };
...@@ -424,7 +441,7 @@ async fn validate_candidate_exhaustive( ...@@ -424,7 +441,7 @@ async fn validate_candidate_exhaustive(
); );
} }
let result = match result { match result {
Err(ValidationError::InternalError(e)) => Err(ValidationFailed(e)), Err(ValidationError::InternalError(e)) => Err(ValidationFailed(e)),
Err(ValidationError::InvalidCandidate(WasmInvalidCandidate::HardTimeout)) => Err(ValidationError::InvalidCandidate(WasmInvalidCandidate::HardTimeout)) =>
...@@ -450,9 +467,7 @@ async fn validate_candidate_exhaustive( ...@@ -450,9 +467,7 @@ async fn validate_candidate_exhaustive(
}; };
Ok(ValidationResult::Valid(outputs, persisted_validation_data)) Ok(ValidationResult::Valid(outputs, persisted_validation_data))
}, },
}; }
Ok(result)
} }
#[async_trait] #[async_trait]
...@@ -465,7 +480,7 @@ trait ValidationBackend { ...@@ -465,7 +480,7 @@ trait ValidationBackend {
} }
#[async_trait] #[async_trait]
impl ValidationBackend for &'_ mut ValidationHost { impl ValidationBackend for ValidationHost {
async fn validate_candidate( async fn validate_candidate(
&mut self, &mut self,
raw_validation_code: Vec<u8>, raw_validation_code: Vec<u8>,
......
...@@ -19,6 +19,7 @@ use assert_matches::assert_matches; ...@@ -19,6 +19,7 @@ use assert_matches::assert_matches;
use futures::executor; use futures::executor;
use polkadot_node_subsystem::messages::AllMessages; use polkadot_node_subsystem::messages::AllMessages;
use polkadot_node_subsystem_test_helpers as test_helpers; use polkadot_node_subsystem_test_helpers as test_helpers;
use polkadot_node_subsystem_util::reexports::SubsystemContext;
use polkadot_primitives::v1::{HeadData, UpwardMessage}; use polkadot_primitives::v1::{HeadData, UpwardMessage};
use sp_core::testing::TaskExecutor; use sp_core::testing::TaskExecutor;
use sp_keyring::Sr25519Keyring; use sp_keyring::Sr25519Keyring;
...@@ -52,11 +53,15 @@ fn correctly_checks_included_assumption() { ...@@ -52,11 +53,15 @@ fn correctly_checks_included_assumption() {
candidate.para_id = para_id; candidate.para_id = para_id;
let pool = TaskExecutor::new(); let pool = TaskExecutor::new();
let (mut ctx, mut ctx_handle) = test_helpers::make_subsystem_context(pool.clone()); let (mut ctx, mut ctx_handle) =
test_helpers::make_subsystem_context::<AllMessages, _>(pool.clone());
let (check_fut, check_result) = let (check_fut, check_result) = check_assumption_validation_data(
check_assumption_validation_data(&mut ctx, &candidate, OccupiedCoreAssumption::Included) ctx.sender(),
.remote_handle(); &candidate,
OccupiedCoreAssumption::Included,
)
.remote_handle();
let test_fut = async move { let test_fut = async move {
assert_matches!( assert_matches!(
...@@ -89,7 +94,7 @@ fn correctly_checks_included_assumption() { ...@@ -89,7 +94,7 @@ fn correctly_checks_included_assumption() {
} }
); );
assert_matches!(check_result.await.unwrap(), AssumptionCheckOutcome::Matches(o, v) => { assert_matches!(check_result.await, AssumptionCheckOutcome::Matches(o, v) => {
assert_eq!(o, validation_data); assert_eq!(o, validation_data);
assert_eq!(v, validation_code); assert_eq!(v, validation_code);
}); });
...@@ -114,11 +119,15 @@ fn correctly_checks_timed_out_assumption() { ...@@ -114,11 +119,15 @@ fn correctly_checks_timed_out_assumption() {
candidate.para_id = para_id; candidate.para_id = para_id;
let pool = TaskExecutor::new(); let pool = TaskExecutor::new();
let (mut ctx, mut ctx_handle) = test_helpers::make_subsystem_context(pool.clone()); let (mut ctx, mut ctx_handle) =
test_helpers::make_subsystem_context::<AllMessages, _>(pool.clone());
let (check_fut, check_result) = let (check_fut, check_result) = check_assumption_validation_data(
check_assumption_validation_data(&mut ctx, &candidate, OccupiedCoreAssumption::TimedOut) ctx.sender(),
.remote_handle(); &candidate,
OccupiedCoreAssumption::TimedOut,
)
.remote_handle();
let test_fut = async move { let test_fut = async move {
assert_matches!( assert_matches!(
...@@ -151,7 +160,7 @@ fn correctly_checks_timed_out_assumption() { ...@@ -151,7 +160,7 @@ fn correctly_checks_timed_out_assumption() {
} }
); );
assert_matches!(check_result.await.unwrap(), AssumptionCheckOutcome::Matches(o, v) => { assert_matches!(check_result.await, AssumptionCheckOutcome::Matches(o, v) => {
assert_eq!(o, validation_data); assert_eq!(o, validation_data);
assert_eq!(v, validation_code); assert_eq!(v, validation_code);
}); });
...@@ -174,11 +183,15 @@ fn check_is_bad_request_if_no_validation_data() { ...@@ -174,11 +183,15 @@ fn check_is_bad_request_if_no_validation_data() {
candidate.para_id = para_id; candidate.para_id = para_id;
let pool = TaskExecutor::new(); let pool = TaskExecutor::new();
let (mut ctx, mut ctx_handle) = test_helpers::make_subsystem_context(pool.clone()); let (mut ctx, mut ctx_handle) =
test_helpers::make_subsystem_context::<AllMessages, _>(pool.clone());
let (check_fut, check_result) = let (check_fut, check_result) = check_assumption_validation_data(
check_assumption_validation_data(&mut ctx, &candidate, OccupiedCoreAssumption::Included) ctx.sender(),
.remote_handle(); &candidate,
OccupiedCoreAssumption::Included,
)