From 2bd84151ede72c99d76a5ca3d4e9f015b4891837 Mon Sep 17 00:00:00 2001
From: Lldenaurois <ljdenaurois@gmail.com>
Date: Mon, 6 Sep 2021 13:24:04 +0200
Subject: [PATCH] Add tests and modify as_vec implementation (#3715)

* Add tests and modify as_vec implementation

* Address feedback

* fix typo in test
---
 polkadot/Cargo.lock                           |  1 +
 polkadot/erasure-coding/src/lib.rs            | 38 ++++++++++++-------
 .../src/requester/fetch_task/mod.rs           |  2 +-
 .../network/availability-recovery/src/lib.rs  |  8 ++--
 polkadot/node/primitives/Cargo.toml           |  3 ++
 polkadot/node/primitives/src/lib.rs           | 10 ++---
 6 files changed, 38 insertions(+), 24 deletions(-)

diff --git a/polkadot/Cargo.lock b/polkadot/Cargo.lock
index a0c7c2b0318..b2ec0aa33dd 100644
--- a/polkadot/Cargo.lock
+++ b/polkadot/Cargo.lock
@@ -6297,6 +6297,7 @@ dependencies = [
  "bounded-vec",
  "futures 0.3.17",
  "parity-scale-codec",
+ "polkadot-erasure-coding",
  "polkadot-parachain",
  "polkadot-primitives",
  "schnorrkel",
diff --git a/polkadot/erasure-coding/src/lib.rs b/polkadot/erasure-coding/src/lib.rs
index 3c5e7f10e73..92f05fce10a 100644
--- a/polkadot/erasure-coding/src/lib.rs
+++ b/polkadot/erasure-coding/src/lib.rs
@@ -298,10 +298,10 @@ where
 
 /// Verify a merkle branch, yielding the chunk hash meant to be present at that
 /// index.
-pub fn branch_hash(root: &H256, branch_nodes: &[Vec<u8>], index: usize) -> Result<H256, Error> {
+pub fn branch_hash(root: &H256, branch_nodes: &Proof, index: usize) -> Result<H256, Error> {
 	let mut trie_storage: MemoryDB<Blake2Hasher> = MemoryDB::default();
 	for node in branch_nodes.iter() {
-		(&mut trie_storage as &mut trie::HashDB<_>).insert(EMPTY_PREFIX, node.as_slice());
+		(&mut trie_storage as &mut trie::HashDB<_>).insert(EMPTY_PREFIX, node);
 	}
 
 	let trie = TrieDB::new(&trie_storage, &root).map_err(|_| Error::InvalidBranchProof)?;
@@ -372,6 +372,10 @@ mod tests {
 	use super::*;
 	use polkadot_primitives::v0::{AvailableData, BlockData, PoVBlock};
 
+	// In order to adequately compute the number of entries in the Merkle
+	// trie, we must account for the fixed 16-ary trie structure.
+	const KEY_INDEX_NIBBLE_SIZE: usize = 4;
+
 	#[test]
 	fn field_order_is_right_size() {
 		assert_eq!(MAX_VALIDATORS, 65536);
@@ -404,28 +408,36 @@ mod tests {
 		assert_eq!(reconstructed, Err(Error::NotEnoughValidators));
 	}
 
-	#[test]
-	fn construct_valid_branches() {
-		let pov_block = PoVBlock { block_data: BlockData(vec![2; 256]) };
+	fn generate_trie_and_generate_proofs(magnitude: u32) {
+		let n_validators = 2_u32.pow(magnitude) as usize;
+		let pov_block =
+			PoVBlock { block_data: BlockData(vec![2; n_validators / KEY_INDEX_NIBBLE_SIZE]) };
 
 		let available_data = AvailableData { pov_block, omitted_validation: Default::default() };
 
-		let chunks = obtain_chunks(10, &available_data).unwrap();
+		let chunks = obtain_chunks(magnitude as usize, &available_data).unwrap();
 
-		assert_eq!(chunks.len(), 10);
+		assert_eq!(chunks.len() as u32, magnitude);
 
 		let branches = branches(chunks.as_ref());
 		let root = branches.root();
 
 		let proofs: Vec<_> = branches.map(|(proof, _)| proof).collect();
+		assert_eq!(proofs.len() as u32, magnitude);
+		for (i, proof) in proofs.into_iter().enumerate() {
+			let encode = Encode::encode(&proof);
+			let decode = Decode::decode(&mut &encode[..]).unwrap();
+			assert_eq!(proof, decode);
+			assert_eq!(encode, Encode::encode(&decode));
 
-		assert_eq!(proofs.len(), 10);
+			assert_eq!(branch_hash(&root, &proof, i).unwrap(), BlakeTwo256::hash(&chunks[i]));
+		}
+	}
 
-		for (i, proof) in proofs.into_iter().enumerate() {
-			assert_eq!(
-				branch_hash(&root, &proof.as_vec(), i).unwrap(),
-				BlakeTwo256::hash(&chunks[i])
-			);
+	#[test]
+	fn roundtrip_proof_encoding() {
+		for i in 2..16 {
+			generate_trie_and_generate_proofs(i);
 		}
 	}
 }
diff --git a/polkadot/node/network/availability-distribution/src/requester/fetch_task/mod.rs b/polkadot/node/network/availability-distribution/src/requester/fetch_task/mod.rs
index 4eed9440952..f1615d1f33a 100644
--- a/polkadot/node/network/availability-distribution/src/requester/fetch_task/mod.rs
+++ b/polkadot/node/network/availability-distribution/src/requester/fetch_task/mod.rs
@@ -363,7 +363,7 @@ impl RunningTask {
 
 	fn validate_chunk(&self, validator: &AuthorityDiscoveryId, chunk: &ErasureChunk) -> bool {
 		let anticipated_hash =
-			match branch_hash(&self.erasure_root, &chunk.proof_as_vec(), chunk.index.0 as usize) {
+			match branch_hash(&self.erasure_root, chunk.proof(), chunk.index.0 as usize) {
 				Ok(hash) => hash,
 				Err(e) => {
 					tracing::warn!(
diff --git a/polkadot/node/network/availability-recovery/src/lib.rs b/polkadot/node/network/availability-recovery/src/lib.rs
index 1860f9a28ae..bd747c463fe 100644
--- a/polkadot/node/network/availability-recovery/src/lib.rs
+++ b/polkadot/node/network/availability-recovery/src/lib.rs
@@ -363,11 +363,9 @@ impl RequestChunksPhase {
 
 					let validator_index = chunk.index;
 
-					if let Ok(anticipated_hash) = branch_hash(
-						&params.erasure_root,
-						&chunk.proof_as_vec(),
-						chunk.index.0 as usize,
-					) {
+					if let Ok(anticipated_hash) =
+						branch_hash(&params.erasure_root, chunk.proof(), chunk.index.0 as usize)
+					{
 						let erasure_chunk_hash = BlakeTwo256::hash(&chunk.chunk);
 
 						if erasure_chunk_hash != anticipated_hash {
diff --git a/polkadot/node/primitives/Cargo.toml b/polkadot/node/primitives/Cargo.toml
index c5b9ab05f40..d8b15b80800 100644
--- a/polkadot/node/primitives/Cargo.toml
+++ b/polkadot/node/primitives/Cargo.toml
@@ -23,3 +23,6 @@ serde = { version = "1.0.130", features = ["derive"] }
 
 [target.'cfg(not(target_os = "unknown"))'.dependencies]
 zstd = "0.6.0"
+
+[dev-dependencies]
+polkadot-erasure-coding = { path = "../../erasure-coding" }
diff --git a/polkadot/node/primitives/src/lib.rs b/polkadot/node/primitives/src/lib.rs
index 7503874796d..c2c300fca74 100644
--- a/polkadot/node/primitives/src/lib.rs
+++ b/polkadot/node/primitives/src/lib.rs
@@ -301,8 +301,8 @@ pub struct Proof(BoundedVec<BoundedVec<u8, 1, MERKLE_NODE_MAX_SIZE>, 1, MERKLE_P
 
 impl Proof {
-	/// This function allows to convert back to the standard nested Vec format
-	pub fn as_vec(&self) -> Vec<Vec<u8>> {
-		self.0.as_vec().iter().map(|v| v.as_vec().clone()).collect()
+	/// Iterate over the proof's merkle nodes, yielding each as a byte slice.
+	pub fn iter(&self) -> impl Iterator<Item = &[u8]> {
+		self.0.iter().map(|v| v.as_slice())
 	}
 
 	/// Construct an invalid dummy proof
@@ -365,7 +365,7 @@ impl Encode for Proof {
 	}
 
 	fn using_encoded<R, F: FnOnce(&[u8]) -> R>(&self, f: F) -> R {
-		let temp = self.as_vec();
+		let temp = self.0.iter().map(|v| v.as_vec()).collect::<Vec<_>>();
 		temp.using_encoded(f)
 	}
 }
@@ -404,8 +404,8 @@ pub struct ErasureChunk {
 
 impl ErasureChunk {
-	/// Convert bounded Vec Proof to regular Vec<Vec<u8>>
-	pub fn proof_as_vec(&self) -> Vec<Vec<u8>> {
-		self.proof.as_vec()
+	/// Get a reference to this chunk's bounded merkle `Proof`.
+	pub fn proof(&self) -> &Proof {
+		&self.proof
 	}
 }
 
-- 
GitLab