From 5d34c24504bdd2c535abac210a2878d0e8f20a3a Mon Sep 17 00:00:00 2001
From: Gavin Wood <gavin@parity.io>
Date: Thu, 14 Oct 2021 16:45:49 +0200
Subject: [PATCH] More lenient mechanism for identifying stash accounts in
 purge_keys (#10004)

* More lenient StashOf finder

* Slightly safer version of the change

* Slightly safer version of the change

* Tests

* Formatting
---
 substrate/frame/session/src/lib.rs   | 38 ++++++++++++++++++----------
 substrate/frame/session/src/mock.rs  | 23 ++++++++++++++---
 substrate/frame/session/src/tests.rs | 29 ++++++++++++++++++++-
 3 files changed, 73 insertions(+), 17 deletions(-)

diff --git a/substrate/frame/session/src/lib.rs b/substrate/frame/session/src/lib.rs
index 10c7ea42b3e..7fe163e0dfe 100644
--- a/substrate/frame/session/src/lib.rs
+++ b/substrate/frame/session/src/lib.rs
@@ -114,17 +114,6 @@ mod mock;
 mod tests;
 pub mod weights;
 
-use sp_runtime::{
-	traits::{AtLeast32BitUnsigned, Convert, Member, One, OpaqueKeys, Zero},
-	ConsensusEngineId, KeyTypeId, Permill, RuntimeAppPublic,
-};
-use sp_staking::SessionIndex;
-use sp_std::{
-	marker::PhantomData,
-	ops::{Rem, Sub},
-	prelude::*,
-};
-
 use frame_support::{
 	codec::{Decode, MaxEncodedLen},
 	dispatch::{DispatchError, DispatchResult},
@@ -136,6 +125,17 @@ use frame_support::{
 	weights::Weight,
 	Parameter,
 };
+use sp_runtime::{
+	traits::{AtLeast32BitUnsigned, Convert, Member, One, OpaqueKeys, Zero},
+	ConsensusEngineId, KeyTypeId, Permill, RuntimeAppPublic,
+};
+use sp_staking::SessionIndex;
+use sp_std::{
+	convert::TryFrom,
+	marker::PhantomData,
+	ops::{Rem, Sub},
+	prelude::*,
+};
 
 pub use pallet::*;
 pub use weights::WeightInfo;
@@ -377,7 +377,11 @@ pub mod pallet {
 		type Event: From<Event> + IsType<<Self as frame_system::Config>::Event>;
 
 		/// A stable ID for a validator.
-		type ValidatorId: Member + Parameter + MaybeSerializeDeserialize + MaxEncodedLen;
+		type ValidatorId: Member
+			+ Parameter
+			+ MaybeSerializeDeserialize
+			+ MaxEncodedLen
+			+ TryFrom<Self::AccountId>;
 
 		/// A conversion from account ID to validator ID.
 		///
@@ -595,9 +599,13 @@ pub mod pallet {
 		}
 
 		/// Removes any session key(s) of the function caller.
+		///
 		/// This doesn't take effect until the next session.
 		///
-		/// The dispatch origin of this function must be signed.
+		/// The dispatch origin of this function must be Signed and the account must be either be
+		/// convertible to a validator ID using the chain's typical addressing system (this usually
+		/// means being a controller account) or directly convertible into a validator ID (which
+		/// usually means being a stash account).
 		///
 		/// # <weight>
 		/// - Complexity: `O(1)` in number of key types. Actual cost depends on the number of length
@@ -841,6 +849,10 @@ impl<T: Config> Pallet<T> {
 
 	fn do_purge_keys(account: &T::AccountId) -> DispatchResult {
 		let who = T::ValidatorIdOf::convert(account.clone())
+			// `purge_keys` may not have a controller-stash pair any more. If so then we expect the
+			// stash account to be passed in directly and convert that to a `ValidatorId` using the
+			// `TryFrom` trait if supported.
+			.or_else(|| T::ValidatorId::try_from(account.clone()).ok())
 			.ok_or(Error::<T>::NoAssociatedValidatorId)?;
 
 		let old_keys = Self::take_keys(&who).ok_or(Error::<T>::NoKeys)?;
diff --git a/substrate/frame/session/src/mock.rs b/substrate/frame/session/src/mock.rs
index 277dec61065..6db7727fa53 100644
--- a/substrate/frame/session/src/mock.rs
+++ b/substrate/frame/session/src/mock.rs
@@ -22,13 +22,13 @@ use crate as pallet_session;
 #[cfg(feature = "historical")]
 use crate::historical as pallet_session_historical;
 
-use std::cell::RefCell;
+use std::{cell::RefCell, collections::BTreeMap};
 
 use sp_core::{crypto::key_types::DUMMY, H256};
 use sp_runtime::{
 	impl_opaque_keys,
 	testing::{Header, UintAuthorityId},
-	traits::{BlakeTwo256, ConvertInto, IdentityLookup},
+	traits::{BlakeTwo256, IdentityLookup},
 };
 use sp_staking::SessionIndex;
 
@@ -111,6 +111,7 @@ thread_local! {
 	pub static DISABLED: RefCell<bool> = RefCell::new(false);
 	// Stores if `on_before_session_end` was called
 	pub static BEFORE_SESSION_END_CALLED: RefCell<bool> = RefCell::new(false);
+	pub static VALIDATOR_ACCOUNTS: RefCell<BTreeMap<u64, u64>> = RefCell::new(BTreeMap::new());
 }
 
 pub struct TestShouldEndSession;
@@ -225,6 +226,10 @@ pub fn new_test_ext() -> sp_io::TestExternalities {
 	pallet_session::GenesisConfig::<Test> { keys }
 		.assimilate_storage(&mut t)
 		.unwrap();
+	NEXT_VALIDATORS.with(|l| {
+		let v = l.borrow().iter().map(|&i| (i, i)).collect();
+		VALIDATOR_ACCOUNTS.with(|m| *m.borrow_mut() = v);
+	});
 	sp_io::TestExternalities::new(t)
 }
 
@@ -268,6 +273,18 @@ impl pallet_timestamp::Config for Test {
 	type WeightInfo = ();
 }
 
+pub struct TestValidatorIdOf;
+impl TestValidatorIdOf {
+	pub fn set(v: BTreeMap<u64, u64>) {
+		VALIDATOR_ACCOUNTS.with(|m| *m.borrow_mut() = v);
+	}
+}
+impl Convert<u64, Option<u64>> for TestValidatorIdOf {
+	fn convert(x: u64) -> Option<u64> {
+		VALIDATOR_ACCOUNTS.with(|m| m.borrow().get(&x).cloned())
+	}
+}
+
 impl Config for Test {
 	type ShouldEndSession = TestShouldEndSession;
 	#[cfg(feature = "historical")]
@@ -276,7 +293,7 @@ impl Config for Test {
 	type SessionManager = TestSessionManager;
 	type SessionHandler = TestSessionHandler;
 	type ValidatorId = u64;
-	type ValidatorIdOf = ConvertInto;
+	type ValidatorIdOf = TestValidatorIdOf;
 	type Keys = MockSessionKeys;
 	type Event = Event;
 	type NextSessionRotation = ();
diff --git a/substrate/frame/session/src/tests.rs b/substrate/frame/session/src/tests.rs
index 42a2dd74fd9..308ed7c5e54 100644
--- a/substrate/frame/session/src/tests.rs
+++ b/substrate/frame/session/src/tests.rs
@@ -21,7 +21,7 @@ use super::*;
 use crate::mock::{
 	authorities, before_session_end_called, force_new_session, new_test_ext,
 	reset_before_session_end_called, session_changed, set_next_validators, set_session_length,
-	Origin, PreUpgradeMockSessionKeys, Session, System, Test, SESSION_CHANGED,
+	Origin, PreUpgradeMockSessionKeys, Session, System, Test, TestValidatorIdOf, SESSION_CHANGED,
 	TEST_SESSION_CHANGED,
 };
 
@@ -72,11 +72,35 @@ fn keys_cleared_on_kill() {
 	})
 }
 
+#[test]
+fn purge_keys_works_for_stash_id() {
+	let mut ext = new_test_ext();
+	ext.execute_with(|| {
+		assert_eq!(Session::validators(), vec![1, 2, 3]);
+		TestValidatorIdOf::set(vec![(10, 1), (20, 2), (3, 3)].into_iter().collect());
+		assert_eq!(Session::load_keys(&1), Some(UintAuthorityId(1).into()));
+		assert_eq!(Session::load_keys(&2), Some(UintAuthorityId(2).into()));
+
+		let id = DUMMY;
+		assert_eq!(Session::key_owner(id, UintAuthorityId(1).get_raw(id)), Some(1));
+
+		assert_ok!(Session::purge_keys(Origin::signed(10)));
+		assert_ok!(Session::purge_keys(Origin::signed(2)));
+
+		assert_eq!(Session::load_keys(&10), None);
+		assert_eq!(Session::load_keys(&20), None);
+		assert_eq!(Session::key_owner(id, UintAuthorityId(10).get_raw(id)), None);
+		assert_eq!(Session::key_owner(id, UintAuthorityId(20).get_raw(id)), None);
+	})
+}
+
 #[test]
 fn authorities_should_track_validators() {
 	reset_before_session_end_called();
 
 	new_test_ext().execute_with(|| {
+		TestValidatorIdOf::set(vec![(1, 1), (2, 2), (3, 3), (4, 4)].into_iter().collect());
+
 		set_next_validators(vec![1, 2]);
 		force_new_session();
 		initialize_block(1);
@@ -187,6 +211,8 @@ fn session_change_should_work() {
 #[test]
 fn duplicates_are_not_allowed() {
 	new_test_ext().execute_with(|| {
+		TestValidatorIdOf::set(vec![(1, 1), (2, 2), (3, 3), (4, 4)].into_iter().collect());
+
 		System::set_block_number(1);
 		Session::on_initialize(1);
 		assert_noop!(
@@ -205,6 +231,7 @@ fn session_changed_flag_works() {
 	reset_before_session_end_called();
 
 	new_test_ext().execute_with(|| {
+		TestValidatorIdOf::set(vec![(1, 1), (2, 2), (3, 3), (69, 69)].into_iter().collect());
 		TEST_SESSION_CHANGED.with(|l| *l.borrow_mut() = true);
 
 		force_new_session();
-- 
GitLab