From 81aeb0c7b3961179e8d954b245bb1af9eefade8c Mon Sep 17 00:00:00 2001 From: Liam Aharon <liam.aharon@hotmail.com> Date: Wed, 3 Apr 2024 13:34:29 +0400 Subject: [PATCH] pure internal functions --- substrate/frame/asset-rewards/src/lib.rs | 138 ++++++++++++++--------- 1 file changed, 83 insertions(+), 55 deletions(-) diff --git a/substrate/frame/asset-rewards/src/lib.rs b/substrate/frame/asset-rewards/src/lib.rs index ea4430f7b82..d9c1688a53b 100644 --- a/substrate/frame/asset-rewards/src/lib.rs +++ b/substrate/frame/asset-rewards/src/lib.rs @@ -43,6 +43,15 @@ //! //! ## Implementation Notes //! +//! Internal logic functions such as `update_pool_and_staker_rewards` where deliberately written +//! without any side-effects like storage interaction. +//! +//! Storage interaction such as reads and writes are instead all performed in the top level +//! pallet Call method, which while slightly more verbose, makes it much easier to understand the +//! code and reason about where side-effects occur in the pallet. +//! +//! ## Implementation Notes +//! //! The implementation is based on the [AccumulatedRewardsPerShare](https://dev.to/heymarkkop/understanding-sushiswaps-masterchef-staking-rewards-1m6f) algorithm. //! //! Rewards are calculated JIT (just-in-time), when a staker claims their rewards. @@ -79,6 +88,14 @@ pub type PoolId = u32; /// Multiplier to maintain precision when calculating rewards. pub(crate) const PRECISION_SCALING_FACTOR: u32 = u32::MAX; +/// Convenience type alias for `PoolInfo`. +pub type PoolInfoFor<T> = PoolInfo< + <T as frame_system::Config>::AccountId, + <T as Config>::AssetId, + <T as Config>::Balance, + BlockNumberFor<T>, +>; + /// A pool staker. #[derive(Debug, Default, Clone, Decode, Encode, MaxEncodedLen, TypeInfo)] pub struct PoolStakerInfo<Balance> { @@ -320,7 +337,7 @@ pub mod pallet { }; // Create the pool. - let pool = PoolInfo::<T::AccountId, T::AssetId, T::Balance, BlockNumberFor<T>> { + let pool = PoolInfoFor::<T> { staking_asset_id: *staked_asset_id.clone(), reward_asset_id: *reward_asset_id.clone(), reward_rate_per_block, @@ -361,21 +378,22 @@ pub mod pallet { pub fn stake(origin: OriginFor<T>, pool_id: PoolId, amount: T::Balance) -> DispatchResult { let caller = ensure_signed(origin)?; - // Always start by updating the pool rewards. - Self::update_pool_rewards(&pool_id, Some(&caller))?; + // Always start by updating staker and pool rewards. + let pool_info = Pools::<T>::get(pool_id).ok_or(Error::<T>::NonExistentPool)?; + let staker_info = PoolStakers::<T>::get(pool_id, &caller).unwrap_or_default(); + let (mut pool_info, mut staker_info) = + Self::update_pool_and_staker_rewards(pool_info, staker_info)?; // Try to freeze the staker assets. // TODO: (blocked https://github.com/paritytech/polkadot-sdk/issues/3342) // Update Pools. - let mut pool = Pools::<T>::get(pool_id).ok_or(Error::<T>::NonExistentPool)?; - pool.total_tokens_staked.saturating_accrue(amount); - Pools::<T>::insert(pool_id, pool); + pool_info.total_tokens_staked.saturating_accrue(amount); + Pools::<T>::insert(pool_id, pool_info); // Update PoolStakers. - let mut staker = PoolStakers::<T>::get(pool_id, &caller).unwrap_or_default(); - staker.amount.saturating_accrue(amount); - PoolStakers::<T>::insert(pool_id, &caller, staker); + staker_info.amount.saturating_accrue(amount); + PoolStakers::<T>::insert(pool_id, &caller, staker_info); // Emit event. Self::deposit_event(Event::Staked { who: caller, pool_id, amount }); @@ -392,23 +410,24 @@ pub mod pallet { let caller = ensure_signed(origin)?; // Always start by updating the pool rewards. - Self::update_pool_rewards(&pool_id, Some(&caller))?; + let pool_info = Pools::<T>::get(pool_id).ok_or(Error::<T>::NonExistentPool)?; + let staker_info = PoolStakers::<T>::get(pool_id, &caller).unwrap_or_default(); + let (mut pool_info, mut staker_info) = + Self::update_pool_and_staker_rewards(pool_info, staker_info)?; // Check the staker has enough staked tokens. - let mut staker = PoolStakers::<T>::get(pool_id, &caller).unwrap_or_default(); - ensure!(staker.amount >= amount, Error::<T>::NotEnoughTokens); + ensure!(staker_info.amount >= amount, Error::<T>::NotEnoughTokens); // Unfreeze staker assets. // TODO: (blocked https://github.com/paritytech/polkadot-sdk/issues/3342) // Update Pools. - let mut pool = Pools::<T>::get(pool_id).ok_or(Error::<T>::NonExistentPool)?; - pool.total_tokens_staked.saturating_reduce(amount); - Pools::<T>::insert(pool_id, pool); + pool_info.total_tokens_staked.saturating_reduce(amount); + Pools::<T>::insert(pool_id, pool_info); // Update PoolStakers. - staker.amount.saturating_reduce(amount); - PoolStakers::<T>::insert(pool_id, &caller, staker); + staker_info.amount.saturating_reduce(amount); + PoolStakers::<T>::insert(pool_id, &caller, staker_info); // Emit event. Self::deposit_event(Event::Unstaked { who: caller, pool_id, amount }); @@ -429,12 +448,13 @@ pub mod pallet { None => caller.clone(), }; - // Always start by updating the pool rewards. - Self::update_pool_rewards(&pool_id, Some(&staker))?; + // Always start by updating the pool and staker rewards. + let pool_info = Pools::<T>::get(pool_id).ok_or(Error::<T>::NonExistentPool)?; + let staker_info = PoolStakers::<T>::get(pool_id, &staker).unwrap_or_default(); + let (pool_info, mut staker_info) = + Self::update_pool_and_staker_rewards(pool_info, staker_info)?; // Transfer unclaimed rewards from the pool to the staker. - let mut staker_info = PoolStakers::<T>::get(pool_id, &caller).unwrap_or_default(); - let pool_info = Pools::<T>::get(pool_id).ok_or(Error::<T>::NonExistentPool)?; let pool_account_id = Self::pool_account_id(&pool_id)?; T::Assets::transfer( pool_info.reward_asset_id, @@ -469,9 +489,9 @@ pub mod pallet { let pool_info = Pools::<T>::get(pool_id).ok_or(Error::<T>::NonExistentPool)?; ensure!(pool_info.admin == caller, BadOrigin); - Self::update_pool_rewards(&pool_id, None)?; + // Always start by updating the pool rewards. + let mut pool_info = Self::update_pool_rewards(pool_info)?; - let mut pool_info = Pools::<T>::get(pool_id).ok_or(Error::<T>::NonExistentPool)?; pool_info.reward_rate_per_block = new_reward_rate_per_block; Pools::<T>::insert(pool_id, pool_info); @@ -514,9 +534,10 @@ pub mod pallet { Error::<T>::ExpiryBlockMustBeInTheFuture ); - Self::update_pool_rewards(&pool_id, None)?; + // Always start by updating the pool rewards. + let pool_info = Pools::<T>::get(pool_id).ok_or(Error::<T>::NonExistentPool)?; + let mut pool_info = Self::update_pool_rewards(pool_info)?; - let mut pool_info = Pools::<T>::get(pool_id).ok_or(Error::<T>::NonExistentPool)?; ensure!(pool_info.admin == caller, BadOrigin); pool_info.expiry_block = new_expiry_block; Pools::<T>::insert(pool_id, pool_info); @@ -560,43 +581,51 @@ pub mod pallet { } } - /// Update pool reward state, and optionally also a staker's rewards. + /// Computes update pool and staker reward state. + /// + /// Should be called prior to any operation involving a staker. + /// + /// Returns the updated pool and staker info. + /// + /// NOTE: this is a pure function without side effects. It does not modify any state + /// directly, that is the responsibility of the caller. + pub fn update_pool_and_staker_rewards( + mut pool_info: PoolInfoFor<T>, + mut staker_info: PoolStakerInfo<T::Balance>, + ) -> Result<(PoolInfoFor<T>, PoolStakerInfo<T::Balance>), DispatchError> { + let reward_per_token = Self::reward_per_token(&pool_info)?; + + pool_info.last_update_block = frame_system::Pallet::<T>::block_number(); + pool_info.reward_per_token_stored = reward_per_token; + + staker_info.rewards = Self::derive_rewards(&pool_info, &staker_info)?; + staker_info.reward_per_token_paid = reward_per_token; + return Ok((pool_info, staker_info)); + } + + /// Computes update pool reward state. + /// + /// Should be called every time the pool is adjusted, and a staker is not involved. + /// + /// Returns the updated pool and staker info. /// - /// Returns the updated pool info and optional staker info. + /// NOTE: this is a pure function without side effects. It does not modify any state + /// directly, that is the responsibility of the caller. pub fn update_pool_rewards( - pool_id: &PoolId, - staker: Option<&T::AccountId>, - ) -> Result< - ( - PoolInfo<T::AccountId, T::AssetId, T::Balance, BlockNumberFor<T>>, - Option<PoolStakerInfo<T::Balance>>, - ), - DispatchError, - > { - let reward_per_token = Self::reward_per_token(pool_id)?; + mut pool_info: PoolInfoFor<T>, + ) -> Result<PoolInfoFor<T>, DispatchError> { + let reward_per_token = Self::reward_per_token(&pool_info)?; - let mut pool_info = Pools::<T>::get(pool_id).ok_or(Error::<T>::NonExistentPool)?; pool_info.last_update_block = frame_system::Pallet::<T>::block_number(); pool_info.reward_per_token_stored = reward_per_token; - Pools::<T>::insert(pool_id, pool_info.clone()); - - if let Some(staker) = staker { - let mut staker_info = PoolStakers::<T>::get(pool_id, staker).unwrap_or_default(); - staker_info.rewards = Self::derive_rewards(pool_id, staker)?; - staker_info.reward_per_token_paid = reward_per_token; - PoolStakers::<T>::insert(pool_id, staker, staker_info.clone()); - return Ok((pool_info, Some(staker_info))); - } - Ok((pool_info, None)) + Ok(pool_info) } /// Derives the current reward per token for this pool. /// /// Helper function for update_pool_rewards. Should not be called directly. - fn reward_per_token(pool_id: &PoolId) -> Result<T::Balance, DispatchError> { - let pool_info = Pools::<T>::get(pool_id).ok_or(Error::<T>::NonExistentPool)?; - + fn reward_per_token(pool_info: &PoolInfoFor<T>) -> Result<T::Balance, DispatchError> { if pool_info.total_tokens_staked.eq(&0u32.into()) { return Ok(pool_info.reward_per_token_stored) } @@ -623,11 +652,10 @@ pub mod pallet { /// /// Helper function for update_pool_rewards. Should not be called directly. fn derive_rewards( - pool_id: &PoolId, - staker: &T::AccountId, + pool_info: &PoolInfoFor<T>, + staker_info: &PoolStakerInfo<T::Balance>, ) -> Result<T::Balance, DispatchError> { - let reward_per_token = Self::reward_per_token(pool_id)?; - let staker_info = PoolStakers::<T>::get(pool_id, staker).unwrap_or_default(); + let reward_per_token = Self::reward_per_token(&pool_info)?; Ok(staker_info .amount -- GitLab