From c672cce4bd3a97f423780fac32ba6ad204daefbf Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Alexander=20Thei=C3=9Fen?= <alex.theissen@me.com>
Date: Thu, 28 May 2020 14:33:10 +0200
Subject: [PATCH] Make post dispatch fee consistent with a direct calculation
 (#6165)

* Make post dispatch fee consistent with a direct calculation

* Remove unnecessary `saturated_into` calls

* Add test with negative multipliers

* Added regression test

* Test improvements
---
 substrate/frame/support/src/weights.rs        |  19 +--
 .../frame/transaction-payment/src/lib.rs      | 142 +++++++++++++++---
 substrate/primitives/arithmetic/src/lib.rs    |   2 +-
 substrate/primitives/runtime/src/lib.rs       |   2 +-
 4 files changed, 136 insertions(+), 29 deletions(-)

diff --git a/substrate/frame/support/src/weights.rs b/substrate/frame/support/src/weights.rs
index 771f908ecf7..dd80f8d8a8e 100644
--- a/substrate/frame/support/src/weights.rs
+++ b/substrate/frame/support/src/weights.rs
@@ -272,14 +272,15 @@ pub struct PostDispatchInfo {
 impl PostDispatchInfo {
 	/// Calculate how much (if any) weight was not used by the `Dispatchable`.
 	pub fn calc_unspent(&self, info: &DispatchInfo) -> Weight {
+		info.weight - self.calc_actual_weight(info)
+	}
+
+	/// Calculate how much weight was actually spent by the `Dispatchable`.
+	pub fn calc_actual_weight(&self, info: &DispatchInfo) -> Weight {
 		if let Some(actual_weight) = self.actual_weight {
-			if actual_weight >= info.weight {
-				0
-			} else {
-				info.weight - actual_weight
-			}
+			actual_weight.min(info.weight)
 		} else {
-			0
+			info.weight
 		}
 	}
 }
@@ -287,9 +288,9 @@ impl PostDispatchInfo {
 /// Extract the actual weight from a dispatch result if any or fall back to the default weight.
 pub fn extract_actual_weight(result: &DispatchResultWithPostInfo, info: &DispatchInfo) -> Weight {
 	match result {
-		Ok(post_info) => &post_info.actual_weight,
-		Err(err) => &err.post_info.actual_weight,
-	}.unwrap_or_else(|| info.weight).min(info.weight)
+		Ok(post_info) => &post_info,
+		Err(err) => &err.post_info,
+	}.calc_actual_weight(info)
 }
 
 impl From<Option<Weight>> for PostDispatchInfo {
diff --git a/substrate/frame/transaction-payment/src/lib.rs b/substrate/frame/transaction-payment/src/lib.rs
index 55d448bce34..78c638b4844 100644
--- a/substrate/frame/transaction-payment/src/lib.rs
+++ b/substrate/frame/transaction-payment/src/lib.rs
@@ -44,7 +44,7 @@ use frame_support::{
 	dispatch::DispatchResult,
 };
 use sp_runtime::{
-	Fixed128, FixedPointNumber,
+	Fixed128, FixedPointNumber, FixedPointOperand,
 	transaction_validity::{
 		TransactionPriority, ValidTransaction, InvalidTransaction, TransactionValidityError,
 		TransactionValidity,
@@ -104,7 +104,9 @@ decl_module! {
 	}
 }
 
-impl<T: Trait> Module<T> {
+impl<T: Trait> Module<T> where
+	BalanceOf<T>: FixedPointOperand
+{
 	/// Query the data that we know about the fee of a given `call`.
 	///
 	/// This module is not and cannot be aware of the internals of a signed extension, for example
@@ -163,35 +165,63 @@ impl<T: Trait> Module<T> {
 	) -> BalanceOf<T> where
 		T::Call: Dispatchable<Info=DispatchInfo>,
 	{
-		if info.pays_fee == Pays::Yes {
+		Self::compute_fee_raw(len, info.weight, tip, info.pays_fee)
+	}
+
+	/// Compute the actual post dispatch fee for a particular transaction.
+	///
+	/// Identical to `compute_fee` with the only difference that the post dispatch corrected
+	/// weight is used for the weight fee calculation.
+	pub fn compute_actual_fee(
+		len: u32,
+		info: &DispatchInfoOf<T::Call>,
+		post_info: &PostDispatchInfoOf<T::Call>,
+		tip: BalanceOf<T>,
+	) -> BalanceOf<T> where
+		T::Call: Dispatchable<Info=DispatchInfo,PostInfo=PostDispatchInfo>,
+	{
+		Self::compute_fee_raw(len, post_info.calc_actual_weight(info), tip, info.pays_fee)
+	}
+
+	fn compute_fee_raw(
+		len: u32,
+		weight: Weight,
+		tip: BalanceOf<T>,
+		pays_fee: Pays,
+	) -> BalanceOf<T> {
+		if pays_fee == Pays::Yes {
 			let len = <BalanceOf<T>>::from(len);
 			let per_byte = T::TransactionByteFee::get();
 			let len_fee = per_byte.saturating_mul(len);
-			let unadjusted_weight_fee = Self::weight_to_fee(info.weight);
+			let unadjusted_weight_fee = Self::weight_to_fee(weight);
 
 			// the adjustable part of the fee
 			let adjustable_fee = len_fee.saturating_add(unadjusted_weight_fee);
 			let targeted_fee_adjustment = NextFeeMultiplier::get();
-			let adjusted_fee = targeted_fee_adjustment.saturating_mul_acc_int(adjustable_fee.saturated_into());
+			let adjusted_fee = targeted_fee_adjustment.saturating_mul_acc_int(adjustable_fee);
 
 			let base_fee = Self::weight_to_fee(T::ExtrinsicBaseWeight::get());
-			base_fee.saturating_add(adjusted_fee.saturated_into()).saturating_add(tip)
+			base_fee.saturating_add(adjusted_fee).saturating_add(tip)
 		} else {
 			tip
 		}
 	}
+}
 
+impl<T: Trait> Module<T> {
 	/// Compute the fee for the specified weight.
 	///
 	/// This fee is already adjusted by the per block fee adjustment factor and is therefore
 	/// the share that the weight contributes to the overall fee of a transaction.
+	///
+	/// This function is generic in order to supply the contracts module with a way
+	/// to calculate the gas price. The contracts module is not able to put the necessary
+	/// `BalanceOf<T>` contraints on its trait. This function is not to be used by this module.
 	pub fn weight_to_fee_with_adjustment<Balance>(weight: Weight) -> Balance where
 		Balance: UniqueSaturatedFrom<u128>
 	{
-		let fee = UniqueSaturatedInto::<u128>::unique_saturated_into(Self::weight_to_fee(weight));
-		UniqueSaturatedFrom::unique_saturated_from(
-			NextFeeMultiplier::get().saturating_mul_acc_int(fee)
-		)
+		let fee: u128 = Self::weight_to_fee(weight).unique_saturated_into();
+		Balance::unique_saturated_from(NextFeeMultiplier::get().saturating_mul_acc_int(fee))
 	}
 
 	fn weight_to_fee(weight: Weight) -> BalanceOf<T> {
@@ -209,7 +239,7 @@ pub struct ChargeTransactionPayment<T: Trait + Send + Sync>(#[codec(compact)] Ba
 
 impl<T: Trait + Send + Sync> ChargeTransactionPayment<T> where
 	T::Call: Dispatchable<Info=DispatchInfo, PostInfo=PostDispatchInfo>,
-	BalanceOf<T>: Send + Sync,
+	BalanceOf<T>: Send + Sync + FixedPointOperand,
 {
 	/// utility constructor. Used only in client/factory code.
 	pub fn from(fee: BalanceOf<T>) -> Self {
@@ -258,14 +288,14 @@ impl<T: Trait + Send + Sync> sp_std::fmt::Debug for ChargeTransactionPayment<T>
 }
 
 impl<T: Trait + Send + Sync> SignedExtension for ChargeTransactionPayment<T> where
-	BalanceOf<T>: Send + Sync + From<u64>,
+	BalanceOf<T>: Send + Sync + From<u64> + FixedPointOperand,
 	T::Call: Dispatchable<Info=DispatchInfo, PostInfo=PostDispatchInfo>,
 {
 	const IDENTIFIER: &'static str = "ChargeTransactionPayment";
 	type AccountId = T::AccountId;
 	type Call = T::Call;
 	type AdditionalSigned = ();
-	type Pre = (BalanceOf<T>, Self::AccountId, Option<NegativeImbalanceOf<T>>);
+	type Pre = (BalanceOf<T>, Self::AccountId, Option<NegativeImbalanceOf<T>>, BalanceOf<T>);
 	fn additional_signed(&self) -> sp_std::result::Result<(), TransactionValidityError> { Ok(()) }
 
 	fn validate(
@@ -291,20 +321,26 @@ impl<T: Trait + Send + Sync> SignedExtension for ChargeTransactionPayment<T> whe
 		info: &DispatchInfoOf<Self::Call>,
 		len: usize
 	) -> Result<Self::Pre, TransactionValidityError> {
-		let (_, imbalance) = self.withdraw_fee(who, info, len)?;
-		Ok((self.0, who.clone(), imbalance))
+		let (fee, imbalance) = self.withdraw_fee(who, info, len)?;
+		Ok((self.0, who.clone(), imbalance, fee))
 	}
 
 	fn post_dispatch(
 		pre: Self::Pre,
 		info: &DispatchInfoOf<Self::Call>,
 		post_info: &PostDispatchInfoOf<Self::Call>,
-		_len: usize,
+		len: usize,
 		_result: &DispatchResult,
 	) -> Result<(), TransactionValidityError> {
-		let (tip, who, imbalance) = pre;
+		let (tip, who, imbalance, fee) = pre;
 		if let Some(payed) = imbalance {
-			let refund = Module::<T>::weight_to_fee_with_adjustment(post_info.calc_unspent(info));
+			let actual_fee = Module::<T>::compute_actual_fee(
+				len as u32,
+				info,
+				post_info,
+				tip,
+			);
+			let refund = fee.saturating_sub(actual_fee);
 			let actual_payment = match T::Currency::deposit_into_existing(&who, refund) {
 				Ok(refund_imbalance) => {
 					// The refund cannot be larger than the up front payed max weight.
@@ -789,6 +825,39 @@ mod tests {
 		});
 	}
 
+	#[test]
+	fn compute_fee_works_with_negative_multiplier() {
+		ExtBuilder::default()
+			.base_weight(100)
+			.byte_fee(10)
+			.balance_factor(0)
+			.build()
+			.execute_with(||
+		{
+			// Add a next fee multiplier
+			NextFeeMultiplier::put(Fixed128::saturating_from_rational(-1, 2)); // = -1/2 = -.5
+			// Base fee is unaffected by multiplier
+			let dispatch_info = DispatchInfo {
+				weight: 0,
+				class: DispatchClass::Operational,
+				pays_fee: Pays::Yes,
+			};
+			assert_eq!(Module::<Runtime>::compute_fee(0, &dispatch_info, 0), 100);
+
+			// Everything works together :)
+			let dispatch_info = DispatchInfo {
+				weight: 123,
+				class: DispatchClass::Operational,
+				pays_fee: Pays::Yes,
+			};
+			// 123 weight, 456 length, 100 base
+			// adjustable fee = (123 * 1) + (456 * 10) = 4683
+			// adjusted fee = 4683 - (4683 * -.5)  = 4683 - 2341.5 = 4683 - 2341 = 2342
+			// final fee = 100 + 2342 + 789 tip = 3231
+			assert_eq!(Module::<Runtime>::compute_fee(456, &dispatch_info, 789), 3231);
+		});
+	}
+
 	#[test]
 	fn compute_fee_does_not_overflow() {
 		ExtBuilder::default()
@@ -906,4 +975,41 @@ mod tests {
 			assert_eq!(System::events().len(), 0);
 		});
 	}
+
+	#[test]
+	fn refund_consistent_with_actual_weight() {
+		ExtBuilder::default()
+			.balance_factor(10)
+			.base_weight(7)
+			.build()
+			.execute_with(||
+		{
+			let info = info_from_weight(100);
+			let post_info = post_info_from_weight(33);
+			let prev_balance = Balances::free_balance(2);
+			let len = 10;
+			let tip = 5;
+
+			NextFeeMultiplier::put(Fixed128::saturating_from_rational(1, 4));
+
+			let pre = ChargeTransactionPayment::<Runtime>::from(tip)
+				.pre_dispatch(&2, CALL, &info, len)
+				.unwrap();
+
+			ChargeTransactionPayment::<Runtime>
+				::post_dispatch(pre, &info, &post_info, len, &Ok(()))
+				.unwrap();
+
+			let refund_based_fee = prev_balance - Balances::free_balance(2);
+			let actual_fee = Module::<Runtime>
+				::compute_actual_fee(len as u32, &info, &post_info, tip);
+
+			// 33 weight, 10 length, 7 base
+			// adjustable fee = (33 * 1) + (10 * 1) = 43
+			// adjusted fee = 43 + (43 * .25)  = 43 + 10.75 = 43 + 10 = 53
+			// final fee = 7 + 53 + 5 tip = 65
+			assert_eq!(actual_fee, 65);
+			assert_eq!(refund_based_fee, actual_fee);
+		});
+	}
 }
diff --git a/substrate/primitives/arithmetic/src/lib.rs b/substrate/primitives/arithmetic/src/lib.rs
index 0ac58b12fe0..c4f95af6463 100644
--- a/substrate/primitives/arithmetic/src/lib.rs
+++ b/substrate/primitives/arithmetic/src/lib.rs
@@ -40,7 +40,7 @@ mod per_things;
 mod fixed;
 mod rational128;
 
-pub use fixed::{FixedPointNumber, Fixed64, Fixed128};
+pub use fixed::{FixedPointNumber, Fixed64, Fixed128, FixedPointOperand};
 pub use per_things::{PerThing, Percent, PerU16, Permill, Perbill, Perquintill};
 pub use rational128::Rational128;
 
diff --git a/substrate/primitives/runtime/src/lib.rs b/substrate/primitives/runtime/src/lib.rs
index 79b91424598..52ae46c6624 100644
--- a/substrate/primitives/runtime/src/lib.rs
+++ b/substrate/primitives/runtime/src/lib.rs
@@ -72,7 +72,7 @@ pub use sp_core::RuntimeDebug;
 /// Re-export top-level arithmetic stuff.
 pub use sp_arithmetic::{
 	Perquintill, Perbill, Permill, Percent, PerU16, Rational128, Fixed64, Fixed128,
-	PerThing, traits::SaturatedConversion, FixedPointNumber,
+	PerThing, traits::SaturatedConversion, FixedPointNumber, FixedPointOperand,
 };
 /// Re-export 128 bit helpers.
 pub use sp_arithmetic::helpers_128bit;
-- 
GitLab