From 9ca96d6fe7fedec888b3fa683f19303877d6f807 Mon Sep 17 00:00:00 2001
From: Arkadiy Paronyan <arkady.paronyan@gmail.com>
Date: Tue, 13 Aug 2019 09:19:54 +0200
Subject: [PATCH] More robust state pinning (#3355)

* Better state pinning

* Fixed pinning race

* Update core/state-db/src/noncanonical.rs

Co-Authored-By: Robert Habermeier <rphmeier@gmail.com>
---
 substrate/core/client/db/src/lib.rs         |   5 +-
 substrate/core/state-db/src/lib.rs          |  65 ++++++-----
 substrate/core/state-db/src/noncanonical.rs | 115 ++++++++++++++------
 3 files changed, 115 insertions(+), 70 deletions(-)

diff --git a/substrate/core/client/db/src/lib.rs b/substrate/core/client/db/src/lib.rs
index 967dcffc869..927359ecdf7 100644
--- a/substrate/core/client/db/src/lib.rs
+++ b/substrate/core/client/db/src/lib.rs
@@ -92,9 +92,6 @@ pub struct RefTrackingState<Block: BlockT> {
 
 impl<B: BlockT> RefTrackingState<B> {
 	fn new(state: DbState, storage: Arc<StorageDb<B>>, parent_hash: Option<B::Hash>) -> RefTrackingState<B> {
-		if let Some(hash) = &parent_hash {
-			storage.state_db.pin(hash);
-		}
 		RefTrackingState {
 			state,
 			parent_hash,
@@ -1401,7 +1398,7 @@ impl<Block> client::backend::Backend<Block, Blake2Hasher> for Backend<Block> whe
 		match self.blockchain.header(block) {
 			Ok(Some(ref hdr)) => {
 				let hash = hdr.hash();
-				if !self.storage.state_db.is_pruned(&hash, (*hdr.number()).saturated_into::<u64>()) {
+				if let Ok(()) = self.storage.state_db.pin(&hash) {
 					let root = H256::from_slice(hdr.state_root().as_ref());
 					let db_state = DbState::new(self.storage.clone(), root);
 					let state = RefTrackingState::new(db_state, self.storage.clone(), Some(hash.clone()));
diff --git a/substrate/core/state-db/src/lib.rs b/substrate/core/state-db/src/lib.rs
index 43820529d4e..81772e554bc 100644
--- a/substrate/core/state-db/src/lib.rs
+++ b/substrate/core/state-db/src/lib.rs
@@ -36,7 +36,7 @@ mod pruning;
 use std::fmt;
 use parking_lot::RwLock;
 use codec::Codec;
-use std::collections::{VecDeque, HashMap, hash_map::Entry};
+use std::collections::{HashMap, hash_map::Entry};
 use noncanonical::NonCanonicalOverlay;
 use pruning::RefWindow;
 use log::trace;
@@ -77,8 +77,12 @@ pub enum Error<E: fmt::Debug> {
 	InvalidBlockNumber,
 	/// Trying to insert block with unknown parent.
 	InvalidParent,
-	/// Canonicalization would discard pinned state.
-	DiscardingPinned,
+}
+
+/// Pinning error type.
+pub enum PinError {
+	/// Trying to pin invalid block.
+	InvalidBlock,
 }
 
 impl<E: fmt::Debug> From<codec::Error> for Error<E> {
@@ -95,7 +99,6 @@ impl<E: fmt::Debug> fmt::Debug for Error<E> {
 			Error::InvalidBlock => write!(f, "Trying to canonicalize invalid block"),
 			Error::InvalidBlockNumber => write!(f, "Trying to insert block with invalid number"),
 			Error::InvalidParent => write!(f, "Trying to insert block with unknown parent"),
-			Error::DiscardingPinned => write!(f, "Trying to discard pinned state"),
 		}
 	}
 }
@@ -173,7 +176,6 @@ fn to_meta_key<S: Codec>(suffix: &[u8], data: &S) -> Vec<u8> {
 struct StateDbSync<BlockHash: Hash, Key: Hash> {
 	mode: PruningMode,
 	non_canonical: NonCanonicalOverlay<BlockHash, Key>,
-	canonicalization_queue: VecDeque<BlockHash>,
 	pruning: Option<RefWindow<BlockHash, Key>>,
 	pinned: HashMap<BlockHash, u32>,
 }
@@ -195,7 +197,6 @@ impl<BlockHash: Hash, Key: Hash> StateDbSync<BlockHash, Key> {
 			non_canonical,
 			pruning,
 			pinned: Default::default(),
-			canonicalization_queue: Default::default(),
 		})
 	}
 
@@ -220,26 +221,16 @@ impl<BlockHash: Hash, Key: Hash> StateDbSync<BlockHash, Key> {
 		if self.mode == PruningMode::ArchiveAll {
 			return Ok(commit)
 		}
-		self.canonicalization_queue.push_back(hash.clone());
-		while let Some(hash) = self.canonicalization_queue.front().cloned() {
-			if self.pinned.contains_key(&hash) {
-				break;
-			}
-			match self.non_canonical.canonicalize(&hash, &self.pinned, &mut commit) {
-				Ok(()) => {
-					self.canonicalization_queue.pop_front();
-					if self.mode == PruningMode::ArchiveCanonical {
-						commit.data.deleted.clear();
-					}
-				}
-				Err(Error::DiscardingPinned) => {
-					break;
+		match self.non_canonical.canonicalize(&hash, &mut commit) {
+			Ok(()) => {
+				if self.mode == PruningMode::ArchiveCanonical {
+					commit.data.deleted.clear();
 				}
-				Err(e) => return Err(e),
-			};
-			if let Some(ref mut pruning) = self.pruning {
-				pruning.note_canonical(&hash, &mut commit);
 			}
+			Err(e) => return Err(e),
+		};
+		if let Some(ref mut pruning) = self.pruning {
+			pruning.note_canonical(&hash, &mut commit);
 		}
 		self.prune(&mut commit);
 		Ok(commit)
@@ -296,12 +287,25 @@ impl<BlockHash: Hash, Key: Hash> StateDbSync<BlockHash, Key> {
 		}
 	}
 
-	pub fn pin(&mut self, hash: &BlockHash) {
-		let refs = self.pinned.entry(hash.clone()).or_default();
-		if *refs == 0 {
-			trace!(target: "state-db", "Pinned block: {:?}", hash);
+	pub fn pin(&mut self, hash: &BlockHash) -> Result<(), PinError> {
+		match self.mode {
+			PruningMode::ArchiveAll => Ok(()),
+			PruningMode::ArchiveCanonical | PruningMode::Constrained(_) => {
+				if self.non_canonical.have_block(hash) ||
+					self.pruning.as_ref().map_or(false, |pruning| pruning.have_block(hash))
+				{
+					let refs = self.pinned.entry(hash.clone()).or_default();
+					if *refs == 0 {
+						trace!(target: "state-db", "Pinned block: {:?}", hash);
+						self.non_canonical.pin(hash);
+					}
+					*refs += 1;
+					Ok(())
+				} else {
+					Err(PinError::InvalidBlock)
+				}
+			}
 		}
-		*refs += 1
 	}
 
 	pub fn unpin(&mut self, hash: &BlockHash) {
@@ -311,6 +315,7 @@ impl<BlockHash: Hash, Key: Hash> StateDbSync<BlockHash, Key> {
 				if *entry.get() == 0 {
 					trace!(target: "state-db", "Unpinned block: {:?}", hash);
 					entry.remove();
+					self.non_canonical.unpin(hash);
 				} else {
 					trace!(target: "state-db", "Releasing reference for {:?}", hash);
 				}
@@ -375,7 +380,7 @@ impl<BlockHash: Hash, Key: Hash> StateDb<BlockHash, Key> {
 	}
 
 	/// Prevents pruning of specified block and its descendants.
-	pub fn pin(&self, hash: &BlockHash) {
+	pub fn pin(&self, hash: &BlockHash) -> Result<(), PinError> {
 		self.db.write().pin(hash)
 	}
 
diff --git a/substrate/core/state-db/src/noncanonical.rs b/substrate/core/state-db/src/noncanonical.rs
index 1017b8a2118..58715715ccd 100644
--- a/substrate/core/state-db/src/noncanonical.rs
+++ b/substrate/core/state-db/src/noncanonical.rs
@@ -37,6 +37,7 @@ pub struct NonCanonicalOverlay<BlockHash: Hash, Key: Hash> {
 	pending_canonicalizations: Vec<BlockHash>,
 	pending_insertions: Vec<BlockHash>,
 	values: HashMap<Key, (u32, DBValue)>, //ref counted
+	pinned: HashMap<BlockHash, HashMap<Key, DBValue>>, //would be deleted but kept around because block is pinned
 }
 
 #[derive(Encode, Decode)]
@@ -67,14 +68,21 @@ fn insert_values<Key: Hash>(values: &mut HashMap<Key, (u32, DBValue)>, inserted:
 	}
 }
 
-fn discard_values<Key: Hash>(values: &mut HashMap<Key, (u32, DBValue)>, inserted: Vec<Key>) {
+fn discard_values<Key: Hash>(
+	values: &mut HashMap<Key, (u32, DBValue)>,
+	inserted: Vec<Key>,
+	mut into: Option<&mut HashMap<Key, DBValue>>,
+) {
 	for k in inserted {
 		match values.entry(k) {
 			Entry::Occupied(mut e) => {
 				let (ref mut counter, _) = e.get_mut();
 				*counter -= 1;
 				if *counter == 0 {
-					e.remove();
+					let (key, (_, value)) = e.remove_entry();
+					if let Some(ref mut into) = into {
+						into.insert(key, value);
+					}
 				}
 			},
 			Entry::Vacant(_) => {
@@ -89,8 +97,9 @@ fn discard_descendants<BlockHash: Hash, Key: Hash>(
 	mut values: &mut HashMap<Key, (u32, DBValue)>,
 	index: usize,
 	parents: &mut HashMap<BlockHash, BlockHash>,
+	pinned: &mut HashMap<BlockHash, HashMap<Key, DBValue>>,
 	hash: &BlockHash,
-	) {
+) {
 	let mut discarded = Vec::new();
 	if let Some(level) = levels.get_mut(index) {
 		*level = level.drain(..).filter_map(|overlay| {
@@ -98,7 +107,7 @@ fn discard_descendants<BlockHash: Hash, Key: Hash>(
 			if parent == *hash {
 				parents.remove(&overlay.hash);
 				discarded.push(overlay.hash);
-				discard_values(&mut values, overlay.inserted);
+				discard_values(&mut values, overlay.inserted, pinned.get_mut(hash));
 				None
 			} else {
 				Some(overlay)
@@ -106,7 +115,7 @@ fn discard_descendants<BlockHash: Hash, Key: Hash>(
 		}).collect();
 	}
 	for hash in discarded {
-		discard_descendants(levels, values, index + 1, parents, &hash);
+		discard_descendants(levels, values, index + 1, parents, pinned, &hash);
 	}
 }
 
@@ -166,6 +175,7 @@ impl<BlockHash: Hash, Key: Hash> NonCanonicalOverlay<BlockHash, Key> {
 			parents,
 			pending_canonicalizations: Default::default(),
 			pending_insertions: Default::default(),
+			pinned: Default::default(),
 			values: values,
 		})
 	}
@@ -278,7 +288,6 @@ impl<BlockHash: Hash, Key: Hash> NonCanonicalOverlay<BlockHash, Key> {
 	pub fn canonicalize<E: fmt::Debug>(
 		&mut self,
 		hash: &BlockHash,
-		pinned: &HashMap<BlockHash, u32>,
 		commit: &mut CommitSet<Key>,
 	) -> Result<(), Error<E>> {
 		trace!(target: "state-db", "Canonicalizing {:?}", hash);
@@ -303,13 +312,6 @@ impl<BlockHash: Hash, Key: Hash> NonCanonicalOverlay<BlockHash, Key> {
 			discarded_blocks.push(overlay.hash.clone());
 		}
 
-		for hash in discarded_blocks.into_iter() {
-			if pinned.contains_key(&hash) {
-				trace!(target: "state-db", "Refusing to discard pinned state {:?}", hash);
-				return Err(Error::DiscardingPinned)
-			}
-		}
-
 		// get the one we need to canonicalize
 		let overlay = &level[index];
 		commit.data.inserted.extend(overlay.inserted.iter()
@@ -339,9 +341,16 @@ impl<BlockHash: Hash, Key: Hash> NonCanonicalOverlay<BlockHash, Key> {
 			for (i, overlay) in level.into_iter().enumerate() {
 				self.parents.remove(&overlay.hash);
 				if i != index {
-					discard_descendants(&mut self.levels, &mut self.values, 0, &mut self.parents, &overlay.hash);
+					discard_descendants(
+						&mut self.levels,
+						&mut self.values,
+						0,
+						&mut self.parents,
+						&mut self.pinned,
+						&overlay.hash,
+					);
 				}
-				discard_values(&mut self.values, overlay.inserted);
+				discard_values(&mut self.values, overlay.inserted, self.pinned.get_mut(&overlay.hash));
 			}
 		}
 		if let Some(hash) = last {
@@ -355,6 +364,11 @@ impl<BlockHash: Hash, Key: Hash> NonCanonicalOverlay<BlockHash, Key> {
 		if let Some((_, value)) = self.values.get(&key) {
 			return Some(value.clone());
 		}
+		for pinned in self.pinned.values() {
+			if let Some(value) = pinned.get(&key) {
+				return Some(value.clone());
+			}
+		}
 		None
 	}
 
@@ -371,7 +385,7 @@ impl<BlockHash: Hash, Key: Hash> NonCanonicalOverlay<BlockHash, Key> {
 			for overlay in level.into_iter() {
 				commit.meta.deleted.push(overlay.journal_key);
 				self.parents.remove(&overlay.hash);
-				discard_values(&mut self.values, overlay.inserted);
+				discard_values(&mut self.values, overlay.inserted, None);
 			}
 			commit
 		})
@@ -388,7 +402,7 @@ impl<BlockHash: Hash, Key: Hash> NonCanonicalOverlay<BlockHash, Key> {
 				.expect("Hash is added in insert");
 
 			let	overlay = self.levels[level_index].pop().expect("Empty levels are not allowed in self.levels");
-			discard_values(&mut self.values, overlay.inserted);
+			discard_values(&mut self.values, overlay.inserted, None);
 			if self.levels[level_index].is_empty() {
 				debug_assert_eq!(level_index, self.levels.len() - 1);
 				self.levels.pop_back();
@@ -407,11 +421,21 @@ impl<BlockHash: Hash, Key: Hash> NonCanonicalOverlay<BlockHash, Key> {
 		self.pending_canonicalizations.clear();
 		self.revert_insertions();
 	}
+
+	/// Pin state values in memory
+	pub fn pin(&mut self, hash: &BlockHash) {
+		self.pinned.insert(hash.clone(), HashMap::default());
+	}
+
+	/// Discard pinned state
+	pub fn unpin(&mut self, hash: &BlockHash) {
+		self.pinned.remove(hash);
+	}
 }
 
 #[cfg(test)]
 mod tests {
-	use std::{collections::HashMap, io};
+	use std::io;
 	use primitives::H256;
 	use super::{NonCanonicalOverlay, to_journal_key};
 	use crate::{ChangeSet, CommitSet};
@@ -436,7 +460,7 @@ mod tests {
 		let db = make_db(&[]);
 		let mut overlay = NonCanonicalOverlay::<H256, H256>::new(&db).unwrap();
 		let mut commit = CommitSet::default();
-		overlay.canonicalize::<io::Error>(&H256::default(), &HashMap::default(), &mut commit).unwrap();
+		overlay.canonicalize::<io::Error>(&H256::default(), &mut commit).unwrap();
 	}
 
 	#[test]
@@ -481,7 +505,7 @@ mod tests {
 		let mut overlay = NonCanonicalOverlay::<H256, H256>::new(&db).unwrap();
 		overlay.insert::<io::Error>(&h1, 1, &H256::default(), ChangeSet::default()).unwrap();
 		let mut commit = CommitSet::default();
-		overlay.canonicalize::<io::Error>(&h2, &HashMap::default(), &mut commit).unwrap();
+		overlay.canonicalize::<io::Error>(&h2, &mut commit).unwrap();
 	}
 
 	#[test]
@@ -497,7 +521,7 @@ mod tests {
 		assert_eq!(insertion.meta.deleted.len(), 0);
 		db.commit(&insertion);
 		let mut finalization = CommitSet::default();
-		overlay.canonicalize::<io::Error>(&h1, &HashMap::default(), &mut finalization).unwrap();
+		overlay.canonicalize::<io::Error>(&h1, &mut finalization).unwrap();
 		assert_eq!(finalization.data.inserted.len(), changeset.inserted.len());
 		assert_eq!(finalization.data.deleted.len(), changeset.deleted.len());
 		assert_eq!(finalization.meta.inserted.len(), 1);
@@ -531,7 +555,7 @@ mod tests {
 		db.commit(&overlay.insert::<io::Error>(&h1, 10, &H256::default(), make_changeset(&[3, 4], &[2])).unwrap());
 		db.commit(&overlay.insert::<io::Error>(&h2, 11, &h1, make_changeset(&[5], &[3])).unwrap());
 		let mut commit = CommitSet::default();
-		overlay.canonicalize::<io::Error>(&h1, &HashMap::default(), &mut commit).unwrap();
+		overlay.canonicalize::<io::Error>(&h1, &mut commit).unwrap();
 		db.commit(&commit);
 		overlay.apply_pending();
 		assert_eq!(overlay.levels.len(), 1);
@@ -558,7 +582,7 @@ mod tests {
 		assert_eq!(overlay.levels.len(), 2);
 		assert_eq!(overlay.parents.len(), 2);
 		let mut commit = CommitSet::default();
-		overlay.canonicalize::<io::Error>(&h1, &HashMap::default(), &mut commit).unwrap();
+		overlay.canonicalize::<io::Error>(&h1, &mut commit).unwrap();
 		db.commit(&commit);
 		assert!(contains(&overlay, 5));
 		assert_eq!(overlay.levels.len(), 2);
@@ -569,7 +593,7 @@ mod tests {
 		assert!(!contains(&overlay, 5));
 		assert!(contains(&overlay, 7));
 		let mut commit = CommitSet::default();
-		overlay.canonicalize::<io::Error>(&h2, &HashMap::default(), &mut commit).unwrap();
+		overlay.canonicalize::<io::Error>(&h2, &mut commit).unwrap();
 		db.commit(&commit);
 		overlay.apply_pending();
 		assert_eq!(overlay.levels.len(), 0);
@@ -588,7 +612,7 @@ mod tests {
 		db.commit(&overlay.insert::<io::Error>(&h_2, 1, &H256::default(), c_2).unwrap());
 		assert!(contains(&overlay, 1));
 		let mut commit = CommitSet::default();
-		overlay.canonicalize::<io::Error>(&h_1, &HashMap::default(), &mut commit).unwrap();
+		overlay.canonicalize::<io::Error>(&h_1, &mut commit).unwrap();
 		db.commit(&commit);
 		assert!(contains(&overlay, 1));
 		overlay.apply_pending();
@@ -607,8 +631,8 @@ mod tests {
 		db.commit(&overlay.insert::<io::Error>(&h2, 2, &h1, changeset.clone()).unwrap());
 		overlay.apply_pending();
 		let mut commit = CommitSet::default();
-		overlay.canonicalize::<io::Error>(&h1, &HashMap::default(), &mut commit).unwrap();
-		overlay.canonicalize::<io::Error>(&h2, &HashMap::default(), &mut commit).unwrap();
+		overlay.canonicalize::<io::Error>(&h1, &mut commit).unwrap();
+		overlay.canonicalize::<io::Error>(&h2, &mut commit).unwrap();
 		db.commit(&commit);
 		db.commit(&overlay.insert::<io::Error>(&h3, 3, &h2, changeset.clone()).unwrap());
 		overlay.apply_pending();
@@ -679,7 +703,7 @@ mod tests {
 
 		// canonicalize 1. 2 and all its children should be discarded
 		let mut commit = CommitSet::default();
-		overlay.canonicalize::<io::Error>(&h_1, &HashMap::default(), &mut commit).unwrap();
+		overlay.canonicalize::<io::Error>(&h_1, &mut commit).unwrap();
 		db.commit(&commit);
 		overlay.apply_pending();
 		assert_eq!(overlay.levels.len(), 2);
@@ -698,14 +722,9 @@ mod tests {
 		assert!(db.get_meta(&to_journal_key(2, 2)).unwrap().is_none());
 		assert!(db.get_meta(&to_journal_key(2, 3)).unwrap().is_none());
 
-		// check that discarding pinned state produces an error.
-		let mut commit = CommitSet::default();
-		let pinned = vec![(h_1_1_1, 1)].into_iter().collect();
-		assert!(overlay.canonicalize::<io::Error>(&h_1_2, &pinned, &mut commit).is_err());
-
 		// canonicalize 1_2. 1_1 and all its children should be discarded
 		let mut commit = CommitSet::default();
-		overlay.canonicalize::<io::Error>(&h_1_2, &HashMap::default(), &mut commit).unwrap();
+		overlay.canonicalize::<io::Error>(&h_1_2, &mut commit).unwrap();
 		db.commit(&commit);
 		overlay.apply_pending();
 		assert_eq!(overlay.levels.len(), 1);
@@ -722,7 +741,7 @@ mod tests {
 
 		// canonicalize 1_2_2
 		let mut commit = CommitSet::default();
-		overlay.canonicalize::<io::Error>(&h_1_2_2, &HashMap::default(), &mut commit).unwrap();
+		overlay.canonicalize::<io::Error>(&h_1_2_2, &mut commit).unwrap();
 		db.commit(&commit);
 		overlay.apply_pending();
 		assert_eq!(overlay.levels.len(), 0);
@@ -777,5 +796,29 @@ mod tests {
 		assert_eq!(overlay.levels.len(), 0);
 		assert_eq!(overlay.parents.len(), 0);
 	}
-}
 
+	#[test]
+	fn keeps_pinned() {
+		let mut db = make_db(&[]);
+
+		// - 1 - 1_1
+		//     \ 1_2
+
+		let (h_1, c_1) = (H256::random(), make_changeset(&[1], &[]));
+		let (h_2, c_2) = (H256::random(), make_changeset(&[2], &[]));
+
+		let mut overlay = NonCanonicalOverlay::<H256, H256>::new(&db).unwrap();
+		db.commit(&overlay.insert::<io::Error>(&h_1, 1, &H256::default(), c_1).unwrap());
+		db.commit(&overlay.insert::<io::Error>(&h_2, 1, &H256::default(), c_2).unwrap());
+
+		overlay.pin(&h_1);
+
+		let mut commit = CommitSet::default();
+		overlay.canonicalize::<io::Error>(&h_2, &mut commit).unwrap();
+		db.commit(&commit);
+		overlay.apply_pending();
+		assert!(contains(&overlay, 1));
+		overlay.unpin(&h_1);
+		assert!(!contains(&overlay, 1));
+	}
+}
-- 
GitLab