From 4a353f1e81c351ccd99db676fdab0df8cb9088ff Mon Sep 17 00:00:00 2001
From: Kian Paimani <5588131+kianenigma@users.noreply.github.com>
Date: Tue, 14 Jan 2020 08:41:18 +0100
Subject: [PATCH] custom weight function wrapper (#4158)

* custom weight function wrapper

* dox

* Better tests.

* remove u8 encoding

* Update frame/support/src/weights.rs

Co-Authored-By: thiolliere <gui.thiolliere@gmail.com>

* fix pays fee

Co-authored-by: thiolliere <gui.thiolliere@gmail.com>
---
 substrate/frame/staking/src/lib.rs      |  1 -
 substrate/frame/staking/src/slashing.rs |  3 -
 substrate/frame/support/src/weights.rs  | 73 +++++++++++++++++++++++++
 3 files changed, 73 insertions(+), 4 deletions(-)

diff --git a/substrate/frame/staking/src/lib.rs b/substrate/frame/staking/src/lib.rs
index a59aa5d0c8e..94618c93498 100644
--- a/substrate/frame/staking/src/lib.rs
+++ b/substrate/frame/staking/src/lib.rs
@@ -1628,7 +1628,6 @@ impl<T: Trait> Module<T> {
 	/// For each element in the iterator the given number of points in u32 is added to the
 	/// validator, thus duplicates are handled.
 	pub fn reward_by_indices(validators_points: impl IntoIterator<Item = (u32, u32)>) {
-		// TODO: This can be optimised once #3302 is implemented.
 		let current_elected_len = <Module<T>>::current_elected().len() as u32;
 
 		CurrentEraPointsEarned::mutate(|rewards| {
diff --git a/substrate/frame/staking/src/slashing.rs b/substrate/frame/staking/src/slashing.rs
index 6d591603fdb..7322b9a1d31 100644
--- a/substrate/frame/staking/src/slashing.rs
+++ b/substrate/frame/staking/src/slashing.rs
@@ -649,9 +649,6 @@ fn pay_reporters<T: Trait>(
 	T::Slash::on_unbalanced(value_slashed);
 }
 
-// TODO: function for undoing a slash.
-//
-
 #[cfg(test)]
 mod tests {
 	use super::*;
diff --git a/substrate/frame/support/src/weights.rs b/substrate/frame/support/src/weights.rs
index f1092b50023..e44ab164588 100644
--- a/substrate/frame/support/src/weights.rs
+++ b/substrate/frame/support/src/weights.rs
@@ -236,6 +236,36 @@ impl SimpleDispatchInfo {
 	}
 }
 
+/// A struct to represent a weight which is a function of the input arguments. The given items have
+/// the following types:
+///
+/// - `F`: a closure with the same argument list as the dispatched, wrapped in a tuple.
+/// - `DispatchClass`: class of the dispatch.
+/// - `bool`: whether this dispatch pays fee or not.
+pub struct FunctionOf<F>(pub F, pub DispatchClass, pub bool);
+
+impl<Args, F> WeighData<Args> for FunctionOf<F>
+where
+	F : Fn(Args) -> Weight
+{
+	fn weigh_data(&self, args: Args) -> Weight {
+		(self.0)(args)
+	}
+}
+
+impl<Args, F> ClassifyDispatch<Args> for FunctionOf<F> {
+	fn classify_dispatch(&self, _: Args) -> DispatchClass {
+		self.1.clone()
+	}
+}
+
+impl<T, F> PaysFee<T> for FunctionOf<F> {
+	fn pays_fee(&self, _: T) -> bool {
+		self.2
+	}
+}
+
+
 /// Implementation for unchecked extrinsic.
 impl<Address, Call, Signature, Extra> GetDispatchInfo
 	for UncheckedExtrinsic<Address, Call, Signature, Extra>
@@ -271,3 +301,46 @@ impl<Call: Encode, Extra: Encode> GetDispatchInfo for sp_runtime::testing::TestX
 		}
 	}
 }
+
+#[cfg(test)]
+#[allow(dead_code)]
+mod tests {
+	use crate::decl_module;
+	use super::*;
+
+	pub trait Trait {
+		type Origin;
+		type Balance;
+		type BlockNumber;
+	}
+
+	pub struct TraitImpl {}
+
+	impl Trait for TraitImpl {
+		type Origin = u32;
+		type BlockNumber = u32;
+		type Balance = u32;
+	}
+
+	decl_module! {
+		pub struct Module<T: Trait> for enum Call where origin: T::Origin {
+			// no arguments, fixed weight
+			#[weight = SimpleDispatchInfo::FixedNormal(1000)]
+			fn f0(_origin) { unimplemented!(); }
+
+			// weight = a x 10 + b
+			#[weight = FunctionOf(|args: (&u32, &u32)| args.0 * 10 + args.1, DispatchClass::Normal, true)]
+			fn f11(_origin, _a: u32, _eb: u32) { unimplemented!(); }
+
+			#[weight = FunctionOf(|_: (&u32, &u32)| 0, DispatchClass::Operational, true)]
+			fn f12(_origin, _a: u32, _eb: u32) { unimplemented!(); }
+		}
+	}
+
+	#[test]
+	fn weights_are_correct() {
+		assert_eq!(Call::<TraitImpl>::f11(10, 20).get_dispatch_info().weight, 120);
+		assert_eq!(Call::<TraitImpl>::f11(10, 20).get_dispatch_info().class, DispatchClass::Normal);
+		assert_eq!(Call::<TraitImpl>::f0().get_dispatch_info().weight, 1000);
+	}
+}
-- 
GitLab