Unverified commit e5bab572, authored by asynchronous rob, committed by GitHub
Browse files

Code, PoV compression and remove `CompressedPoV` struct (#2852)

* use compressed blob in candidate-validation

* add some tests for compressed code blobs

* remove CompressedPoV and apply compression in collation-generation

* decompress BlockData before executing

* don't produce oversized collations

* add test for PoV decompression failure

* fix tests and clean up

* fix test

* address review and fix CI

* take this )
parent c4dfad37
Pipeline #133452 failed with stages
in 19 minutes and 10 seconds
This diff is collapsed.
......@@ -13,7 +13,9 @@ polkadot-node-subsystem = { path = "../subsystem" }
polkadot-node-subsystem-util = { path = "../subsystem-util" }
polkadot-primitives = { path = "../../primitives" }
sp-core = { git = "https://github.com/paritytech/substrate", branch = "master" }
sp-maybe-compressed-blob = { git = "https://github.com/paritytech/substrate", branch = "master" }
thiserror = "1.0.23"
parity-scale-codec = { version = "2.0.0", default-features = false, features = ["bit-vec", "derive"] }
[dev-dependencies]
polkadot-node-subsystem-test-helpers = { path = "../subsystem-test-helpers" }
......@@ -26,7 +26,9 @@ use futures::{
sink::SinkExt,
stream::StreamExt,
};
use polkadot_node_primitives::{CollationGenerationConfig, AvailableData, PoV};
use polkadot_node_primitives::{
CollationGenerationConfig, AvailableData, PoV,
};
use polkadot_node_subsystem::{
messages::{AllMessages, CollationGenerationMessage, CollatorProtocolMessage},
FromOverseer, SpawnedSubsystem, Subsystem, SubsystemContext, SubsystemResult,
......@@ -41,6 +43,7 @@ use polkadot_primitives::v1::{
CandidateDescriptor, CandidateReceipt, CoreState, Hash, OccupiedCoreAssumption,
PersistedValidationData,
};
use parity_scale_codec::Encode;
use sp_core::crypto::Pair;
use std::sync::Arc;
......@@ -313,7 +316,32 @@ async fn handle_new_activations<Context: SubsystemContext>(
}
};
let pov_hash = collation.proof_of_validity.hash();
// Apply compression to the block data.
let pov = {
let pov = polkadot_node_primitives::maybe_compress_pov(collation.proof_of_validity);
let encoded_size = pov.encoded_size();
// As long as `POV_BOMB_LIMIT` is at least `max_pov_size`, this ensures
// that honest collators never produce a PoV which is uncompressed.
//
// As such, honest collators never produce an uncompressed PoV which starts with
// a compression magic number, which would lead validators to reject the collation.
if encoded_size > validation_data.max_pov_size as usize {
tracing::debug!(
target: LOG_TARGET,
para_id = %scheduled_core.para_id,
size = encoded_size,
max_size = validation_data.max_pov_size,
"PoV exceeded maximum size"
);
return
}
pov
};
let pov_hash = pov.hash();
let signature_payload = collator_signature_payload(
&relay_parent,
......@@ -326,7 +354,7 @@ async fn handle_new_activations<Context: SubsystemContext>(
let erasure_root = match erasure_root(
n_validators,
validation_data,
collation.proof_of_validity.clone(),
pov.clone(),
) {
Ok(erasure_root) => erasure_root,
Err(err) => {
......@@ -375,7 +403,7 @@ async fn handle_new_activations<Context: SubsystemContext>(
metrics.on_collation_generated();
if let Err(err) = task_sender.send(AllMessages::CollatorProtocol(
CollatorProtocolMessage::DistributeCollation(ccr, collation.proof_of_validity, result_sender)
CollatorProtocolMessage::DistributeCollation(ccr, pov, result_sender)
)).await {
tracing::warn!(
target: LOG_TARGET,
......@@ -492,7 +520,7 @@ mod tests {
task::{Context as FuturesContext, Poll},
Future,
};
use polkadot_node_primitives::{Collation, CollationResult, BlockData, PoV};
use polkadot_node_primitives::{Collation, CollationResult, BlockData, PoV, POV_BOMB_LIMIT};
use polkadot_node_subsystem::messages::{
AllMessages, RuntimeApiMessage, RuntimeApiRequest,
};
......@@ -500,8 +528,7 @@ mod tests {
subsystem_test_harness, TestSubsystemContextHandle,
};
use polkadot_primitives::v1::{
BlockNumber, CollatorPair, Id as ParaId,
PersistedValidationData, ScheduledCore, ValidationCode,
CollatorPair, Id as ParaId, PersistedValidationData, ScheduledCore, ValidationCode,
};
use std::pin::Pin;
......@@ -519,6 +546,24 @@ mod tests {
}
}
/// Like `test_collation`, but with the PoV block data run through
/// `sp_maybe_compressed_blob::compress` so it carries the
/// maybe-compressed-blob magic prefix.
fn test_collation_compressed() -> Collation {
    let mut collation = test_collation();

    // Compress the existing block data in place. `POV_BOMB_LIMIT` bounds the
    // *decompressed* size, so compressing the small fixture cannot fail.
    let compressed_block_data = sp_maybe_compressed_blob::compress(
        &collation.proof_of_validity.block_data.0,
        POV_BOMB_LIMIT,
    )
    .unwrap();

    collation.proof_of_validity = PoV {
        block_data: BlockData(compressed_block_data),
    };
    collation
}
fn test_validation_data() -> PersistedValidationData {
let mut persisted_validation_data: PersistedValidationData = Default::default();
persisted_validation_data.max_pov_size = 1024;
persisted_validation_data
}
// Box<dyn Future<Output = Collation> + Unpin + Send
struct TestCollator;
......@@ -715,7 +760,7 @@ mod tests {
tx,
),
))) => {
tx.send(Ok(Some(Default::default()))).unwrap();
tx.send(Ok(Some(test_validation_data()))).unwrap();
}
Some(AllMessages::RuntimeApi(RuntimeApiMessage::Request(
_hash,
......@@ -766,9 +811,8 @@ mod tests {
// we expect a single message to be sent, containing a candidate receipt.
// we don't care too much about the commitments_hash right now, but let's ensure that we've calculated the
// correct descriptor
let expect_pov_hash = test_collation().proof_of_validity.hash();
let expect_validation_data_hash
= PersistedValidationData::<Hash, BlockNumber>::default().hash();
let expect_pov_hash = test_collation_compressed().proof_of_validity.hash();
let expect_validation_data_hash = test_validation_data().hash();
let expect_relay_parent = Hash::repeat_byte(4);
let expect_validation_code_hash = ValidationCode(vec![1, 2, 3]).hash();
let expect_payload = collator_signature_payload(
......
......@@ -9,6 +9,7 @@ futures = "0.3.12"
tracing = "0.1.25"
sp-core = { package = "sp-core", git = "https://github.com/paritytech/substrate", branch = "master" }
sp-maybe-compressed-blob = { package = "sp-maybe-compressed-blob", git = "https://github.com/paritytech/substrate", branch = "master" }
parity-scale-codec = { version = "2.0.0", default-features = false, features = ["bit-vec", "derive"] }
polkadot-primitives = { path = "../../../primitives" }
......
......@@ -33,7 +33,9 @@ use polkadot_subsystem::{
};
use polkadot_node_subsystem_util::metrics::{self, prometheus};
use polkadot_subsystem::errors::RuntimeApiError;
use polkadot_node_primitives::{ValidationResult, InvalidCandidate, PoV};
use polkadot_node_primitives::{
VALIDATION_CODE_BOMB_LIMIT, POV_BOMB_LIMIT, ValidationResult, InvalidCandidate, PoV, BlockData,
};
use polkadot_primitives::v1::{
ValidationCode, CandidateDescriptor, PersistedValidationData,
OccupiedCoreAssumption, Hash, CandidateCommitments,
......@@ -368,12 +370,12 @@ fn perform_basic_checks(
pov: &PoV,
validation_code: &ValidationCode,
) -> Result<(), InvalidCandidate> {
let encoded_pov = pov.encode();
let pov_hash = pov.hash();
let validation_code_hash = validation_code.hash();
if encoded_pov.len() > max_pov_size as usize {
return Err(InvalidCandidate::ParamsTooLarge(encoded_pov.len() as u64));
let encoded_pov_size = pov.encoded_size();
if encoded_pov_size > max_pov_size as usize {
return Err(InvalidCandidate::ParamsTooLarge(encoded_pov_size as u64));
}
if pov_hash != candidate.pov_hash {
......@@ -396,7 +398,7 @@ trait ValidationBackend {
fn validate<S: SpawnNamed + 'static>(
arg: Self::Arg,
validation_code: &ValidationCode,
raw_validation_code: &[u8],
params: ValidationParams,
spawn: S,
) -> Result<WasmValidationResult, ValidationError>;
......@@ -409,12 +411,12 @@ impl ValidationBackend for RealValidationBackend {
fn validate<S: SpawnNamed + 'static>(
isolation_strategy: IsolationStrategy,
validation_code: &ValidationCode,
raw_validation_code: &[u8],
params: ValidationParams,
spawn: S,
) -> Result<WasmValidationResult, ValidationError> {
wasm_executor::validate_candidate(
&validation_code.0,
&raw_validation_code,
params,
&isolation_strategy,
spawn,
......@@ -446,14 +448,40 @@ fn validate_candidate_exhaustive<B: ValidationBackend, S: SpawnNamed + 'static>(
return Ok(ValidationResult::Invalid(e))
}
let raw_validation_code = match sp_maybe_compressed_blob::decompress(
&validation_code.0,
VALIDATION_CODE_BOMB_LIMIT,
) {
Ok(code) => code,
Err(e) => {
tracing::debug!(target: LOG_TARGET, err=?e, "Invalid validation code");
// If the validation code is invalid, the candidate certainly is.
return Ok(ValidationResult::Invalid(InvalidCandidate::CodeDecompressionFailure));
}
};
let raw_block_data = match sp_maybe_compressed_blob::decompress(
&pov.block_data.0,
POV_BOMB_LIMIT,
) {
Ok(block_data) => BlockData(block_data.to_vec()),
Err(e) => {
tracing::debug!(target: LOG_TARGET, err=?e, "Invalid PoV code");
// If the PoV is invalid, the candidate certainly is.
return Ok(ValidationResult::Invalid(InvalidCandidate::PoVDecompressionFailure));
}
};
let params = ValidationParams {
parent_head: persisted_validation_data.parent_head.clone(),
block_data: pov.block_data.clone(),
block_data: raw_block_data,
relay_parent_number: persisted_validation_data.relay_parent_number,
relay_parent_storage_root: persisted_validation_data.relay_parent_storage_root,
};
match B::validate(backend_arg, &validation_code, params, spawn) {
match B::validate(backend_arg, &raw_validation_code, params, spawn) {
Err(ValidationError::InvalidCandidate(WasmInvalidCandidate::Timeout)) =>
Ok(ValidationResult::Invalid(InvalidCandidate::Timeout)),
Err(ValidationError::InvalidCandidate(WasmInvalidCandidate::ParamsTooLarge(l, _))) =>
......@@ -580,7 +608,6 @@ mod tests {
use super::*;
use polkadot_node_subsystem_test_helpers as test_helpers;
use polkadot_primitives::v1::{HeadData, UpwardMessage};
use polkadot_node_primitives::BlockData;
use sp_core::testing::TaskExecutor;
use futures::executor;
use assert_matches::assert_matches;
......@@ -597,7 +624,7 @@ mod tests {
fn validate<S: SpawnNamed + 'static>(
arg: Self::Arg,
_validation_code: &ValidationCode,
_raw_validation_code: &[u8],
_params: ValidationParams,
_spawn: S,
) -> Result<WasmValidationResult, ValidationError> {
......@@ -1059,4 +1086,139 @@ mod tests {
assert_matches!(v, ValidationResult::Invalid(InvalidCandidate::CodeHashMismatch));
}
#[test]
fn compressed_code_works() {
    // Happy path: a compressed validation-code blob must be transparently
    // decompressed before execution and the candidate accepted.
    let validation_data = PersistedValidationData { max_pov_size: 1024, ..Default::default() };
    let pov = PoV { block_data: BlockData(vec![1; 32]) };
    let head_data = HeadData(vec![1, 1, 1]);

    // A small raw code blob, well under the bomb limit, so both compression
    // here and decompression during validation succeed.
    let uncompressed_code = vec![2u8; 16];
    let validation_code = ValidationCode(
        sp_maybe_compressed_blob::compress(&uncompressed_code, VALIDATION_CODE_BOMB_LIMIT)
            .unwrap(),
    );

    // Descriptor committing to the PoV, head data and code blob above.
    let mut descriptor = CandidateDescriptor::default();
    descriptor.pov_hash = pov.hash();
    descriptor.para_head = head_data.hash();
    descriptor.validation_code_hash = validation_code.hash();
    collator_sign(&mut descriptor, Sr25519Keyring::Alice);

    // The mocked backend reports a successful wasm execution.
    let wasm_result = WasmValidationResult {
        head_data,
        new_validation_code: None,
        upward_messages: Vec::new(),
        horizontal_messages: Vec::new(),
        processed_downward_messages: 0,
        hrmp_watermark: 0,
    };

    let outcome = validate_candidate_exhaustive::<MockValidationBackend, _>(
        MockValidationArg { result: Ok(wasm_result) },
        validation_data,
        validation_code,
        descriptor,
        Arc::new(pov),
        TaskExecutor::new(),
        &Default::default(),
    );

    assert_matches!(outcome, Ok(ValidationResult::Valid(_, _)));
}
#[test]
fn code_decompression_failure_is_invalid() {
    // The raw code is one byte over `VALIDATION_CODE_BOMB_LIMIT`. It is
    // compressed with a raised limit so the fixture can be built at all, but
    // decompression during validation must refuse to inflate it and the
    // candidate must be rejected as invalid.
    let validation_data = PersistedValidationData { max_pov_size: 1024, ..Default::default() };
    let pov = PoV { block_data: BlockData(vec![1; 32]) };
    let head_data = HeadData(vec![1, 1, 1]);

    let oversized_code = vec![2u8; VALIDATION_CODE_BOMB_LIMIT + 1];
    let validation_code = ValidationCode(
        sp_maybe_compressed_blob::compress(&oversized_code, VALIDATION_CODE_BOMB_LIMIT + 1)
            .unwrap(),
    );

    let mut descriptor = CandidateDescriptor::default();
    descriptor.pov_hash = pov.hash();
    descriptor.para_head = head_data.hash();
    descriptor.validation_code_hash = validation_code.hash();
    collator_sign(&mut descriptor, Sr25519Keyring::Alice);

    // The backend result is irrelevant: validation must bail out before ever
    // reaching wasm execution.
    let wasm_result = WasmValidationResult {
        head_data,
        new_validation_code: None,
        upward_messages: Vec::new(),
        horizontal_messages: Vec::new(),
        processed_downward_messages: 0,
        hrmp_watermark: 0,
    };

    let outcome = validate_candidate_exhaustive::<MockValidationBackend, _>(
        MockValidationArg { result: Ok(wasm_result) },
        validation_data,
        validation_code,
        descriptor,
        Arc::new(pov),
        TaskExecutor::new(),
        &Default::default(),
    );

    assert_matches!(
        outcome,
        Ok(ValidationResult::Invalid(InvalidCandidate::CodeDecompressionFailure))
    );
}
#[test]
fn pov_decompression_failure_is_invalid() {
    // The decompressed block data is one byte over `POV_BOMB_LIMIT`. It is
    // compressed with a raised limit so the fixture can be built, but
    // decompression during validation must refuse to inflate it and the
    // candidate must be rejected as invalid.
    let validation_data = PersistedValidationData {
        max_pov_size: POV_BOMB_LIMIT as u32,
        ..Default::default()
    };
    let head_data = HeadData(vec![1, 1, 1]);

    let oversized_block_data = vec![2u8; POV_BOMB_LIMIT + 1];
    let pov = PoV {
        block_data: BlockData(
            sp_maybe_compressed_blob::compress(&oversized_block_data, POV_BOMB_LIMIT + 1)
                .unwrap(),
        ),
    };

    let validation_code = ValidationCode(vec![2; 16]);

    let mut descriptor = CandidateDescriptor::default();
    descriptor.pov_hash = pov.hash();
    descriptor.para_head = head_data.hash();
    descriptor.validation_code_hash = validation_code.hash();
    collator_sign(&mut descriptor, Sr25519Keyring::Alice);

    // The backend result is irrelevant: validation must bail out before ever
    // reaching wasm execution.
    let wasm_result = WasmValidationResult {
        head_data,
        new_validation_code: None,
        upward_messages: Vec::new(),
        horizontal_messages: Vec::new(),
        processed_downward_messages: 0,
        hrmp_watermark: 0,
    };

    let outcome = validate_candidate_exhaustive::<MockValidationBackend, _>(
        MockValidationArg { result: Ok(wasm_result) },
        validation_data,
        validation_code,
        descriptor,
        Arc::new(pov),
        TaskExecutor::new(),
        &Default::default(),
    );

    assert_matches!(
        outcome,
        Ok(ValidationResult::Invalid(InvalidCandidate::PoVDecompressionFailure))
    );
}
}
......@@ -24,7 +24,6 @@ use futures::channel::oneshot;
use polkadot_node_subsystem_util::Error as UtilError;
use polkadot_primitives::v1::SessionIndex;
use polkadot_node_primitives::CompressedPoVError;
use polkadot_subsystem::{errors::RuntimeApiError, SubsystemError};
use crate::LOG_TARGET;
......@@ -79,10 +78,6 @@ pub enum Error {
#[error("There was no session with the given index")]
NoSuchSession(SessionIndex),
/// Decompressing PoV failed.
#[error("PoV could not be decompressed")]
PoVDecompression(CompressedPoVError),
/// Fetching PoV failed with `RequestError`.
#[error("FetchPoV request error")]
FetchPoV(#[source] RequestError),
......
......@@ -152,9 +152,7 @@ async fn do_fetch_pov(
{
let response = pending_response.await.map_err(Error::FetchPoV)?;
let pov = match response {
PoVFetchingResponse::PoV(compressed) => {
compressed.decompress().map_err(Error::PoVDecompression)?
}
PoVFetchingResponse::PoV(pov) => pov,
PoVFetchingResponse::NoSuchPoV => {
return Err(Error::NoSuchPoV)
}
......@@ -244,7 +242,7 @@ mod tests {
use sp_core::testing::TaskExecutor;
use polkadot_primitives::v1::{CandidateHash, Hash, ValidatorIndex};
use polkadot_node_primitives::{BlockData, CompressedPoV};
use polkadot_node_primitives::BlockData;
use polkadot_subsystem_testhelpers as test_helpers;
use polkadot_subsystem::messages::{AvailabilityDistributionMessage, RuntimeApiMessage, RuntimeApiRequest};
......@@ -315,9 +313,8 @@ mod tests {
reqs.pop(),
Some(Requests::PoVFetching(outgoing)) => {outgoing}
);
req.pending_response.send(Ok(PoVFetchingResponse::PoV(
CompressedPoV::compress(&pov).unwrap()).encode()
)).unwrap();
req.pending_response.send(Ok(PoVFetchingResponse::PoV(pov.clone()).encode()))
.unwrap();
break
},
msg => tracing::debug!(target: LOG_TARGET, msg = ?msg, "Received msg"),
......
......@@ -16,11 +16,13 @@
//! Answer requests for availability chunks.
use std::sync::Arc;
use futures::channel::oneshot;
use polkadot_node_network_protocol::request_response::{request::IncomingRequest, v1};
use polkadot_primitives::v1::{CandidateHash, ValidatorIndex};
use polkadot_node_primitives::{AvailableData, CompressedPoV, ErasureChunk};
use polkadot_node_primitives::{AvailableData, ErasureChunk};
use polkadot_subsystem::{
messages::{AllMessages, AvailabilityStoreMessage},
SubsystemContext, jaeger,
......@@ -100,18 +102,7 @@ where
let response = match av_data {
None => v1::PoVFetchingResponse::NoSuchPoV,
Some(av_data) => {
let pov = match CompressedPoV::compress(&av_data.pov) {
Ok(pov) => pov,
Err(error) => {
tracing::error!(
target: LOG_TARGET,
error = ?error,
"Failed to create `CompressedPov`",
);
// this should really not happen, let this request time out:
return Err(Error::PoVDecompression(error))
}
};
let pov = Arc::try_unwrap(av_data.pov).unwrap_or_else(|a| (&*a).clone());
v1::PoVFetchingResponse::PoV(pov)
}
};
......
......@@ -42,7 +42,7 @@ use polkadot_node_subsystem_util::{
request_availability_cores,
metrics::{self, prometheus},
};
use polkadot_node_primitives::{SignedFullStatement, Statement, PoV, CompressedPoV};
use polkadot_node_primitives::{SignedFullStatement, Statement, PoV};
const COST_UNEXPECTED_MESSAGE: Rep = Rep::CostMinor("An unexpected message");
......@@ -660,27 +660,6 @@ async fn send_collation(
receipt: CandidateReceipt,
pov: PoV,
) {
let pov = match CompressedPoV::compress(&pov) {
Ok(compressed) => {
tracing::trace!(
target: LOG_TARGET,
size = %pov.block_data.0.len(),
compressed = %compressed.len(),
peer_id = ?request.peer,
"Sending collation."
);
compressed
},
Err(error) => {
tracing::error!(
target: LOG_TARGET,
?error,
"Failed to create `CompressedPov`",
);
return
}
};
if let Err(_) = request.send_response(CollationFetchingResponse::Collation(receipt, pov)) {
tracing::warn!(
target: LOG_TARGET,
......@@ -1519,7 +1498,7 @@ mod tests {
)
.expect("Decoding should work");
assert_eq!(receipt, candidate);
assert_eq!(pov.decompress().unwrap(), pov_block);
assert_eq!(pov, pov_block);
}
);
......
......@@ -1158,49 +1158,33 @@ where
modify_reputation(ctx, *peer_id, COST_WRONG_PARA).await;
}
Ok(CollationFetchingResponse::Collation(receipt, compressed_pov)) => {
match compressed_pov.decompress() {
Ok(pov) => {
tracing::debug!(
target: LOG_TARGET,
para_id = %para_id,
hash = ?hash,
candidate_hash = ?receipt.hash(),
"Received collation",
);
// Actual sending:
let _span = jaeger::Span::new(&pov, "received-collation");
let (mut tx, _) = oneshot::channel();
std::mem::swap(&mut tx, &mut (per_req.to_requester));
let result = tx.send((receipt, pov));
if let Err(_) = result {
tracing::warn!(
target: LOG_TARGET,
hash = ?hash,
para_id = ?para_id,
peer_id = ?peer_id,
"Sending response back to requester failed (receiving side closed)"
);
} else {
metrics_result = Ok(());
success = "true";
}
Ok(CollationFetchingResponse::Collation(receipt, pov)) => {
tracing::debug!(
target: LOG_TARGET,
para_id = %para_id,
hash = ?hash,
candidate_hash = ?receipt.hash(),
"Received collation",
);
}
Err(error) => {
tracing::warn!(
target: LOG_TARGET,
hash = ?hash,
para_id = ?para_id,
peer_id = ?peer_id,
?error,
"Failed to extract PoV",
);
modify_reputation(ctx, *peer_id, COST_CORRUPTED_MESSAGE).await;
}
};
// Actual sending:
let _span = jaeger::Span::new(&pov, "received-collation");
let (mut tx, _) = oneshot::channel();
std::mem::swap(&mut tx, &mut (per_req.to_requester));
let result = tx.send((receipt, pov));