diff --git a/substrate/client/finality-grandpa/src/until_imported.rs b/substrate/client/finality-grandpa/src/until_imported.rs index c3a52fcf56fe197602ba844e61500686820f42cf..c3804e1272ffa62ea5f3154aad6ccb520e980a90 100644 --- a/substrate/client/finality-grandpa/src/until_imported.rs +++ b/substrate/client/finality-grandpa/src/until_imported.rs @@ -40,7 +40,7 @@ use sp_runtime::traits::{Block as BlockT, Header as HeaderT, NumberFor}; use std::collections::{HashMap, VecDeque}; use std::pin::Pin; -use std::sync::{atomic::{AtomicUsize, Ordering}, Arc}; +use std::sync::Arc; use std::task::{Context, Poll}; use std::time::{Duration, Instant}; use sp_finality_grandpa::AuthorityId; @@ -307,8 +307,12 @@ pub(crate) type UntilVoteTargetImported<Block, BlockStatus, BlockSyncRequester, /// /// This is used for compact commits and catch up messages which have already /// been checked for structural soundness (e.g. valid signatures). +/// +/// We use the `Arc`'s reference count to implicitly count the number of outstanding blocks that we +/// are waiting on for the same message (i.e. other `BlockGlobalMessage` instances with the same +/// `inner`). pub(crate) struct BlockGlobalMessage<Block: BlockT> { - inner: Arc<(AtomicUsize, Mutex<Option<CommunicationIn<Block>>>)>, + inner: Arc<Mutex<Option<CommunicationIn<Block>>>>, target_number: NumberFor<Block>, } @@ -416,7 +420,7 @@ impl<Block: BlockT> BlockUntilImported<Block> for BlockGlobalMessage<Block> { return Ok(()) } - let locked_global = Arc::new((AtomicUsize::new(unknown_count), Mutex::new(Some(input)))); + let locked_global = Arc::new(Mutex::new(Some(input))); // schedule waits for all unknown messages. // when the last one of these has `wait_completed` called on it, @@ -438,30 +442,20 @@ impl<Block: BlockT> BlockUntilImported<Block> for BlockGlobalMessage<Block> { fn wait_completed(self, canon_number: NumberFor<Block>) -> Option<Self::Blocked> { if self.target_number != canon_number { - // if we return without deducting the counter, then none of the other - // handles can return the commit message. + // Delete the inner message so it won't ever be forwarded. Future calls to + // `wait_completed` on the same `inner` will ignore it. + *self.inner.lock() = None; return None; } - let mut last_count = self.inner.0.load(Ordering::Acquire); - - // CAS loop to ensure that we always have a last reader. - loop { - if last_count == 1 { // we are the last one left. - return self.inner.1.lock().take(); - } - - let prev_value = self.inner.0.compare_and_swap( - last_count, - last_count - 1, - Ordering::SeqCst, - ); - - if prev_value == last_count { - return None; - } else { - last_count = prev_value; - } + match Arc::try_unwrap(self.inner) { + // This is the last reference and thus the last outstanding block to be awaited. `inner` + // is either `Some(_)` or `None`. The latter implies that a previous `wait_completed` + // call witnessed a block number mismatch (see above). + Ok(inner) => Mutex::into_inner(inner), + // There are still other strong references to this `Arc`, thus the message is blocked on + // other blocks to be imported. + Err(_) => None, } } } @@ -941,4 +935,67 @@ mod tests { futures::executor::block_on(test); } + + fn test_catch_up() -> Arc<Mutex<Option<CommunicationIn<Block>>>> { + let header = make_header(5); + + let unknown_catch_up = finality_grandpa::CatchUp { + round_number: 1, + precommits: vec![], + prevotes: vec![], + base_hash: header.hash(), + base_number: *header.number(), + }; + + let catch_up = voter::CommunicationIn::CatchUp( + unknown_catch_up.clone(), + voter::Callback::Blank, + ); + + Arc::new(Mutex::new(Some(catch_up))) + } + + #[test] + fn block_global_message_wait_completed_return_when_all_awaited() { + let msg_inner = test_catch_up(); + + let waiting_block_1 = BlockGlobalMessage::<Block> { + inner: msg_inner.clone(), + target_number: 1, + }; + + let waiting_block_2 = BlockGlobalMessage::<Block> { + inner: msg_inner, + target_number: 2, + }; + + // waiting_block_2 is still waiting for block 2, thus this should return `None`. + assert!(waiting_block_1.wait_completed(1).is_none()); + + // Message only depended on block 1 and 2. Both have been imported, thus this should yield + // the message. + assert!(waiting_block_2.wait_completed(2).is_some()); + } + + #[test] + fn block_global_message_wait_completed_return_none_on_block_number_missmatch() { + let msg_inner = test_catch_up(); + + let waiting_block_1 = BlockGlobalMessage::<Block> { + inner: msg_inner.clone(), + target_number: 1, + }; + + let waiting_block_2 = BlockGlobalMessage::<Block> { + inner: msg_inner, + target_number: 2, + }; + + // Calling wait_completed with wrong block number should yield None. + assert!(waiting_block_1.wait_completed(1234).is_none()); + + // All blocks, that the message depended on, have been imported. Still, given the above + // block number mismatch this should return None. + assert!(waiting_block_2.wait_completed(2).is_none()); + } }