From bf1f5c803e298452c9406b72a7eec28b024919c0 Mon Sep 17 00:00:00 2001
From: Liam Aharon <liam.aharon@hotmail.com>
Date: Mon, 1 Apr 2024 16:54:21 +0400
Subject: [PATCH] wip logic

---
 substrate/frame/staking-rewards/src/lib.rs   | 201 +++++++++--
 substrate/frame/staking-rewards/src/tests.rs | 341 ++++++++++++-------
 2 files changed, 377 insertions(+), 165 deletions(-)

diff --git a/substrate/frame/staking-rewards/src/lib.rs b/substrate/frame/staking-rewards/src/lib.rs
index 9504eda3478..b74d4b50eb4 100644
--- a/substrate/frame/staking-rewards/src/lib.rs
+++ b/substrate/frame/staking-rewards/src/lib.rs
@@ -76,12 +76,18 @@ mod tests;
 /// The type of the unique id for each pool.
 pub type PoolId = u32;
 
+/// Multiplier to maintain precision when calculating rewards.
+pub(crate) const PRECISION_SCALING_FACTOR: u32 = u32::MAX;
+
 /// A pool staker.
-#[derive(Decode, Encode, MaxEncodedLen, TypeInfo)]
+#[derive(Default, Decode, Encode, MaxEncodedLen, TypeInfo)]
 pub struct PoolStakerInfo<Balance> {
+	/// Amount of tokens staked.
 	amount: Balance,
+	/// Accumulated, unpaid rewards.
 	rewards: Balance,
-	reward_debt: Balance,
+	/// Reward per token value at the time of the staker's last interaction with the contract.
+	reward_per_token_paid: Balance,
 }
 
 /// A staking pool.
@@ -95,20 +101,24 @@ pub struct PoolInfo<AccountId, AssetId, Balance, BlockNumber> {
 	reward_rate_per_block: Balance,
 	/// The total amount of tokens staked in this pool.
 	total_tokens_staked: Balance,
-	/// Total accumulated rewards per share. Used when calculating payouts.
-	accumulated_rewards_per_share: Balance,
+	/// Total rewards accumulated per token, up to the last time the rewards were updated.
+	reward_per_token_stored: Balance,
 	/// Last block number the pool was updated. Used when calculating payouts.
-	last_rewarded_block: BlockNumber,
+	last_update_block: BlockNumber,
 	/// Permissioned account that can manage this pool.
 	admin: AccountId,
 }
 
 #[frame_support::pallet(dev_mode)]
 pub mod pallet {
+
 	use super::*;
-	use frame_support::{pallet_prelude::*, traits::tokens::AssetId};
+	use frame_support::{
+		pallet_prelude::*,
+		traits::tokens::{AssetId, Preservation},
+	};
 	use frame_system::pallet_prelude::*;
-	use sp_runtime::traits::{AccountIdConversion, Saturating};
+	use sp_runtime::traits::{AccountIdConversion, EnsureDiv, Saturating};
 
 	#[pallet::pallet]
 	pub struct Pallet<T>(_);
@@ -232,6 +242,8 @@ pub mod pallet {
 		NonExistentPool,
 		/// An operation was attempted using a non-existent asset.
 		NonExistentAsset,
+		/// There was an error converting a block number.
+		BlockNumberConversionError,
 	}
 
 	#[pallet::hooks]
@@ -268,7 +280,7 @@ pub mod pallet {
 				Error::<T>::NonExistentAsset
 			);
 
-			// Get the admin, or try to use the origin as admin.
+			// Get the admin, defaulting to the origin.
 			let origin_acc_id = ensure_signed(origin)?;
 			let admin = match admin {
 				Some(admin) => admin,
@@ -281,17 +293,17 @@ pub mod pallet {
 				reward_asset_id: *reward_asset_id.clone(),
 				reward_rate_per_block,
 				total_tokens_staked: 0u32.into(),
-				accumulated_rewards_per_share: 0u32.into(),
-				last_rewarded_block: 0u32.into(),
+				reward_per_token_stored: 0u32.into(),
+				last_update_block: 0u32.into(),
 				admin: admin.clone(),
 			};
 
-			// Insert the pool into storage.
+			// Insert it into storage.
 			let pool_id = NextPoolId::<T>::get();
 			Pools::<T>::insert(pool_id, pool);
 			NextPoolId::<T>::put(pool_id.saturating_add(1));
 
-			// Emit the event.
+			// Emit created event.
 			Self::deposit_event(Event::PoolCreated {
 				creator: origin_acc_id,
 				pool_id,
@@ -312,30 +324,86 @@ pub mod pallet {
 		}
 
 		/// Stake tokens in a pool.
-		pub fn stake(
-			_origin: OriginFor<T>,
-			_pool_id: PoolId,
-			_amount: T::Balance,
-		) -> DispatchResult {
-			todo!()
+		pub fn stake(origin: OriginFor<T>, pool_id: PoolId, amount: T::Balance) -> DispatchResult {
+			let caller = ensure_signed(origin)?;
+
+			// Always start by updating the pool rewards.
+			Self::update_pool_rewards(&pool_id, &caller)?;
+
+			// Try to freeze the staker assets.
+			// TODO: (blocked https://github.com/paritytech/polkadot-sdk/issues/3342)
+
+			// Update Pools.
+			let mut pool = Pools::<T>::get(pool_id).ok_or(Error::<T>::NonExistentPool)?;
+			pool.total_tokens_staked.saturating_accrue(amount);
+			Pools::<T>::insert(pool_id, pool);
+
+			// Update PoolStakers.
+			let mut staker = PoolStakers::<T>::get(pool_id, &caller).unwrap_or_default();
+			staker.amount.saturating_accrue(amount);
+			PoolStakers::<T>::insert(pool_id, &caller, staker);
+
+			Ok(())
 		}
 
 		/// Unstake tokens from a pool.
 		pub fn unstake(
-			_origin: OriginFor<T>,
-			_pool_id: PoolId,
-			_amount: T::Balance,
+			origin: OriginFor<T>,
+			pool_id: PoolId,
+			amount: T::Balance,
 		) -> DispatchResult {
-			todo!()
+			let caller = ensure_signed(origin)?;
+
+			// Always start by updating the pool rewards.
+			Self::update_pool_rewards(&pool_id, &caller)?;
+
+			// Unfreeze staker assets.
+			// TODO: (blocked https://github.com/paritytech/polkadot-sdk/issues/3342)
+
+			// Update Pools.
+			let mut pool = Pools::<T>::get(pool_id).ok_or(Error::<T>::NonExistentPool)?;
+			pool.total_tokens_staked.saturating_reduce(amount);
+
+			// Update PoolStakers.
+			let mut staker = PoolStakers::<T>::get(pool_id, &caller).unwrap_or_default();
+			staker.amount.saturating_reduce(amount);
+
+			Ok(())
 		}
 
 		/// Harvest unclaimed pool rewards for a staker.
 		pub fn harvest_rewards(
-			_origin: OriginFor<T>,
-			_staker: T::AccountId,
-			_pool_id: PoolId,
+			origin: OriginFor<T>,
+			pool_id: PoolId,
+			staker: Option<T::AccountId>,
 		) -> DispatchResult {
-			todo!()
+			let caller = ensure_signed(origin)?;
+
+			let staker = match staker {
+				Some(staker) => staker,
+				None => caller.clone(),
+			};
+
+			// Always start by updating the pool rewards.
+			Self::update_pool_rewards(&pool_id, &staker)?;
+
+			// Transfer unclaimed rewards from the pool to the staker.
+			let mut staker_info = PoolStakers::<T>::get(pool_id, &caller).unwrap_or_default();
+			let pool_info = Pools::<T>::get(pool_id).ok_or(Error::<T>::NonExistentPool)?;
+			let pool_account_id = Self::pool_account_id(&pool_id)?;
+
+			T::Assets::transfer(
+				pool_info.reward_asset_id,
+				&pool_account_id,
+				&staker,
+				staker_info.rewards,
+				Preservation::Preserve,
+			)?;
+
+			// Reset staker unclaimed rewards.
+			staker_info.rewards = 0u32.into();
+
+			Ok(())
 		}
 
 		/// Modify the reward rate of a pool.
@@ -353,11 +421,21 @@ pub mod pallet {
 		/// pool pot address), but is provided for convenience so manual derivation of the
 		/// account id is not required.
 		pub fn deposit_reward_tokens(
-			_origin: OriginFor<T>,
-			_pool_id: PoolId,
-			_amount: T::Balance,
+			origin: OriginFor<T>,
+			pool_id: PoolId,
+			amount: T::Balance,
 		) -> DispatchResult {
-			todo!()
+			let caller = ensure_signed(origin)?;
+			let pool_info = Pools::<T>::get(pool_id).ok_or(Error::<T>::NonExistentPool)?;
+			let pool_account_id = Self::pool_account_id(&pool_id)?;
+			T::Assets::transfer(
+				pool_info.reward_asset_id,
+				&caller,
+				&pool_account_id,
+				amount,
+				Preservation::Preserve,
+			)?;
+			Ok(())
 		}
 	}
 
@@ -371,9 +449,66 @@ pub mod pallet {
 			}
 		}
 
-		/// Update pool state in preparation for reward harvesting.
-		fn update_pool_rewards(_staked_asset_id: T::AssetId, _reward_asset_id: T::AssetId) {
-			todo!()
+		/// Update pool reward state.
+		fn update_pool_rewards(pool_id: &PoolId, staker: &T::AccountId) -> DispatchResult {
+			let reward_per_token = Self::reward_per_token(pool_id)?;
+
+			let mut pool_info = Pools::<T>::get(pool_id).ok_or(Error::<T>::NonExistentPool)?;
+			pool_info.last_update_block = frame_system::Pallet::<T>::block_number();
+			Pools::<T>::insert(pool_id, pool_info);
+
+			let mut staker_info = PoolStakers::<T>::get(pool_id, staker).unwrap_or_default();
+			staker_info.rewards = Self::derive_rewards(pool_id, staker)?;
+			staker_info.reward_per_token_paid = reward_per_token;
+			PoolStakers::<T>::insert(pool_id, staker, staker_info);
+
+			Ok(())
+		}
+
+		/// Derives the current reward per token for this pool.
+		///
+		/// Helper function for update_pool_rewards. Should not be called directly.
+		fn reward_per_token(pool_id: &PoolId) -> Result<T::Balance, DispatchError> {
+			let pool_info = Pools::<T>::get(pool_id).ok_or(Error::<T>::NonExistentPool)?;
+
+			if pool_info.total_tokens_staked.eq(&0u32.into()) {
+				return Ok(0u32.into());
+			}
+
+			let blocks_elapsed: u32 = match frame_system::Pallet::<T>::block_number()
+				.saturating_sub(pool_info.last_update_block)
+				.try_into()
+			{
+				Ok(b) => b,
+				Err(_) => return Err(Error::<T>::BlockNumberConversionError.into()),
+			};
+
+			Ok(pool_info
+				.reward_per_token_stored
+				.saturating_add(
+					pool_info
+						.reward_rate_per_block
+						.saturating_mul(blocks_elapsed.into())
+						.saturating_mul(PRECISION_SCALING_FACTOR.into()),
+				)
+				.ensure_div(pool_info.total_tokens_staked)?)
+		}
+
+		/// Derives the amount of rewards earned by a staker.
+		///
+		/// Helper function for update_pool_rewards. Should not be called directly.
+		fn derive_rewards(
+			pool_id: &PoolId,
+			staker: &T::AccountId,
+		) -> Result<T::Balance, DispatchError> {
+			let reward_per_token = Self::reward_per_token(pool_id)?;
+			let staker_info = PoolStakers::<T>::get(pool_id, staker).unwrap_or_default();
+
+			Ok(staker_info
+				.amount
+				.saturating_mul(reward_per_token.saturating_sub(staker_info.reward_per_token_paid))
+				.ensure_div(PRECISION_SCALING_FACTOR.into())?
+				.saturating_add(staker_info.rewards))
 		}
 	}
 }
diff --git a/substrate/frame/staking-rewards/src/tests.rs b/substrate/frame/staking-rewards/src/tests.rs
index bf2d66940eb..1494f1c776b 100644
--- a/substrate/frame/staking-rewards/src/tests.rs
+++ b/substrate/frame/staking-rewards/src/tests.rs
@@ -54,88 +54,49 @@ fn pools() -> Vec<(u32, PoolInfo<u128, NativeOrWithId<u32>, u128, u64>)> {
 	Pools::<MockRuntime>::iter().collect()
 }
 
-#[test]
-fn create_pool_works() {
-	new_test_ext().execute_with(|| {
-		// Setup
-		let user = 1;
-		let staking_asset_id = NativeOrWithId::<u32>::Native;
-		let reward_asset_id = NativeOrWithId::<u32>::WithId(1);
-		let reward_rate_per_block = 100;
-
-		create_tokens(user, vec![reward_asset_id.clone()]);
-		assert_ok!(Balances::force_set_balance(RuntimeOrigin::root(), user, 1000));
-
-		// Create a pool with default admin.
-		assert_eq!(NextPoolId::<MockRuntime>::get(), 0);
-		assert_ok!(StakingRewards::create_pool(
-			RuntimeOrigin::signed(user),
-			Box::new(staking_asset_id.clone()),
-			Box::new(reward_asset_id.clone()),
-			reward_rate_per_block,
-			None
-		));
-
-		// Event is emitted.
-		assert_eq!(
-			events(),
-			[Event::<MockRuntime>::PoolCreated {
-				creator: user,
-				pool_id: 0,
-				staking_asset_id: staking_asset_id.clone(),
-				reward_asset_id: reward_asset_id.clone(),
+mod create_pool {
+	use super::*;
+
+	#[test]
+	fn success() {
+		new_test_ext().execute_with(|| {
+			// Setup
+			let user = 1;
+			let staking_asset_id = NativeOrWithId::<u32>::Native;
+			let reward_asset_id = NativeOrWithId::<u32>::WithId(1);
+			let reward_rate_per_block = 100;
+
+			create_tokens(user, vec![reward_asset_id.clone()]);
+			assert_ok!(Balances::force_set_balance(RuntimeOrigin::root(), user, 1000));
+
+			// Create a pool with default admin.
+			assert_eq!(NextPoolId::<MockRuntime>::get(), 0);
+			assert_ok!(StakingRewards::create_pool(
+				RuntimeOrigin::signed(user),
+				Box::new(staking_asset_id.clone()),
+				Box::new(reward_asset_id.clone()),
 				reward_rate_per_block,
-				admin: user,
-			}]
-		);
-
-		// State is updated correctly.
-		assert_eq!(NextPoolId::<MockRuntime>::get(), 1);
-		assert_eq!(
-			pools(),
-			vec![(
-				0,
-				PoolInfo {
+				None
+			));
+
+			// Event is emitted.
+			assert_eq!(
+				events(),
+				[Event::<MockRuntime>::PoolCreated {
+					creator: user,
+					pool_id: 0,
 					staking_asset_id: staking_asset_id.clone(),
 					reward_asset_id: reward_asset_id.clone(),
 					reward_rate_per_block,
 					admin: user,
-					total_tokens_staked: 0,
-					accumulated_rewards_per_share: 0,
-					last_rewarded_block: 0
-				}
-			)]
-		);
-
-		// Create another pool with explicit admin.
-		let admin = 2;
-		assert_ok!(StakingRewards::create_pool(
-			RuntimeOrigin::signed(user),
-			Box::new(staking_asset_id.clone()),
-			Box::new(reward_asset_id.clone()),
-			reward_rate_per_block,
-			Some(admin)
-		));
-
-		// Event is emitted.
-		assert_eq!(
-			events(),
-			[Event::<MockRuntime>::PoolCreated {
-				creator: user,
-				pool_id: 1,
-				staking_asset_id: staking_asset_id.clone(),
-				reward_asset_id: reward_asset_id.clone(),
-				reward_rate_per_block,
-				admin,
-			}]
-		);
-
-		// State is updated correctly.
-		assert_eq!(NextPoolId::<MockRuntime>::get(), 2);
-		assert_eq!(
-			pools(),
-			vec![
-				(
+				}]
+			);
+
+			// State is updated correctly.
+			assert_eq!(NextPoolId::<MockRuntime>::get(), 1);
+			assert_eq!(
+				pools(),
+				vec![(
 					0,
 					PoolInfo {
 						staking_asset_id: staking_asset_id.clone(),
@@ -143,64 +104,180 @@ fn create_pool_works() {
 						reward_rate_per_block,
 						admin: user,
 						total_tokens_staked: 0,
-						accumulated_rewards_per_share: 0,
-						last_rewarded_block: 0
+						reward_per_token_stored: 0,
+						last_update_block: 0
 					}
+				)]
+			);
+
+			// Create another pool with explicit admin.
+			let admin = 2;
+			assert_ok!(StakingRewards::create_pool(
+				RuntimeOrigin::signed(user),
+				Box::new(staking_asset_id.clone()),
+				Box::new(reward_asset_id.clone()),
+				reward_rate_per_block,
+				Some(admin)
+			));
+
+			// Event is emitted.
+			assert_eq!(
+				events(),
+				[Event::<MockRuntime>::PoolCreated {
+					creator: user,
+					pool_id: 1,
+					staking_asset_id: staking_asset_id.clone(),
+					reward_asset_id: reward_asset_id.clone(),
+					reward_rate_per_block,
+					admin,
+				}]
+			);
+
+			// State is updated correctly.
+			assert_eq!(NextPoolId::<MockRuntime>::get(), 2);
+			assert_eq!(
+				pools(),
+				vec![
+					(
+						0,
+						PoolInfo {
+							staking_asset_id: staking_asset_id.clone(),
+							reward_asset_id: reward_asset_id.clone(),
+							reward_rate_per_block,
+							admin: user,
+							total_tokens_staked: 0,
+							reward_per_token_stored: 0,
+							last_update_block: 0
+						}
+					),
+					(
+						1,
+						PoolInfo {
+							staking_asset_id,
+							reward_asset_id,
+							reward_rate_per_block,
+							admin,
+							total_tokens_staked: 0,
+							reward_per_token_stored: 0,
+							last_update_block: 0
+						}
+					)
+				]
+			);
+		});
+	}
+
+	#[test]
+	fn non_existent_asset_fails() {
+		new_test_ext().execute_with(|| {
+			let valid_asset = NativeOrWithId::<u32>::WithId(1);
+			let invalid_asset = NativeOrWithId::<u32>::WithId(200);
+
+			assert_err!(
+				StakingRewards::create_pool(
+					RuntimeOrigin::signed(1),
+					Box::new(valid_asset.clone()),
+					Box::new(invalid_asset.clone()),
+					10,
+					None
 				),
-				(
-					1,
-					PoolInfo {
-						staking_asset_id,
-						reward_asset_id,
-						reward_rate_per_block,
-						admin,
-						total_tokens_staked: 0,
-						accumulated_rewards_per_share: 0,
-						last_rewarded_block: 0
-					}
-				)
-			]
-		);
-	});
+				Error::<MockRuntime>::NonExistentAsset
+			);
+
+			assert_err!(
+				StakingRewards::create_pool(
+					RuntimeOrigin::signed(1),
+					Box::new(invalid_asset.clone()),
+					Box::new(valid_asset.clone()),
+					10,
+					None
+				),
+				Error::<MockRuntime>::NonExistentAsset
+			);
+
+			assert_err!(
+				StakingRewards::create_pool(
+					RuntimeOrigin::signed(1),
+					Box::new(invalid_asset.clone()),
+					Box::new(invalid_asset.clone()),
+					10,
+					None
+				),
+				Error::<MockRuntime>::NonExistentAsset
+			);
+		})
+	}
 }
 
-#[test]
-fn create_pool_with_non_existent_asset_fails() {
-	new_test_ext().execute_with(|| {
-		let valid_asset = NativeOrWithId::<u32>::WithId(1);
-		let invalid_asset = NativeOrWithId::<u32>::WithId(200);
-
-		assert_err!(
-			StakingRewards::create_pool(
-				RuntimeOrigin::signed(1),
-				Box::new(valid_asset.clone()),
-				Box::new(invalid_asset.clone()),
-				10,
-				None
-			),
-			Error::<MockRuntime>::NonExistentAsset
-		);
-
-		assert_err!(
-			StakingRewards::create_pool(
-				RuntimeOrigin::signed(1),
-				Box::new(invalid_asset.clone()),
-				Box::new(valid_asset.clone()),
-				10,
-				None
-			),
-			Error::<MockRuntime>::NonExistentAsset
-		);
-
-		assert_err!(
-			StakingRewards::create_pool(
-				RuntimeOrigin::signed(1),
-				Box::new(invalid_asset.clone()),
-				Box::new(invalid_asset.clone()),
-				10,
+mod stake {
+	use super::*;
+
+	#[test]
+	fn success() {
+		new_test_ext().execute_with(|| {
+			// Setup
+			let user = 1;
+			let staking_asset_id = NativeOrWithId::<u32>::WithId(1);
+			let reward_asset_id = NativeOrWithId::<u32>::Native;
+			let reward_rate_per_block = 100;
+
+			create_tokens(user, vec![staking_asset_id.clone()]);
+
+			assert_ok!(StakingRewards::create_pool(
+				RuntimeOrigin::signed(user),
+				Box::new(staking_asset_id.clone()),
+				Box::new(reward_asset_id.clone()),
+				reward_rate_per_block,
 				None
-			),
-			Error::<MockRuntime>::NonExistentAsset
-		);
-	})
+			));
+
+			let pool_id = 0;
+
+			// User stakes tokens
+			assert_ok!(StakingRewards::stake(RuntimeOrigin::signed(user), pool_id, 1000));
+
+			// Check that the user's staked amount is updated
+			assert_eq!(PoolStakers::<MockRuntime>::get(pool_id, user).unwrap().amount, 1000);
+
+			// Check that the pool's total tokens staked is updated
+			assert_eq!(Pools::<MockRuntime>::get(pool_id).unwrap().total_tokens_staked, 1000);
+
+			// TODO: Check user's frozen balance is updated
+
+			// User stakes more tokens
+			assert_ok!(StakingRewards::stake(RuntimeOrigin::signed(user), pool_id, 500));
+
+			// Check that the user's staked amount is updated
+			assert_eq!(PoolStakers::<MockRuntime>::get(pool_id, user).unwrap().amount, 1500);
+
+			// Check that the pool's total tokens staked is updated
+			assert_eq!(Pools::<MockRuntime>::get(pool_id).unwrap().total_tokens_staked, 1500);
+
+			// TODO: Check user's frozen balance is updated
+		});
+	}
+
+	#[test]
+	fn non_existent_pool() {
+		new_test_ext().execute_with(|| {
+			// Setup
+			let user = 1;
+			let staking_asset_id = NativeOrWithId::<u32>::WithId(1);
+
+			create_tokens(user, vec![staking_asset_id.clone()]);
+
+			let non_existent_pool_id = 999;
+
+			// User tries to stake tokens in a non-existent pool
+			assert_err!(
+				StakingRewards::stake(RuntimeOrigin::signed(user), non_existent_pool_id, 1000),
+				Error::<MockRuntime>::NonExistentPool
+			);
+		});
+	}
+
+	#[test]
+	fn insufficient_balance() {
+		// TODO: When we're able to freeze assets.
+	}
 }
-- 
GitLab