diff --git a/substrate/frame/contracts/src/exec.rs b/substrate/frame/contracts/src/exec.rs index e73b29e54378b3229a71c3b11bf0f06ef3cb0fce..7aa5c0b731fad432704ab8095e493bd6f956f538 100644 --- a/substrate/frame/contracts/src/exec.rs +++ b/substrate/frame/contracts/src/exec.rs @@ -774,14 +774,27 @@ where // All changes performed by the contract are executed under a storage transaction. // This allows for roll back on error. Changes to the cached contract_info are - // comitted or rolled back when popping the frame. - let (success, output) = with_transaction(|| { - let output = do_transaction(); - match &output { - Ok(result) if !result.did_revert() => TransactionOutcome::Commit((true, output)), - _ => TransactionOutcome::Rollback((false, output)), - } - }); + // committed or rolled back when popping the frame. + // + // `with_transactional` may return an error caused by a limit in the + // transactional storage depth. + let transaction_outcome = + with_transaction(|| -> TransactionOutcome<Result<_, DispatchError>> { + let output = do_transaction(); + match &output { + Ok(result) if !result.did_revert() => + TransactionOutcome::Commit(Ok((true, output))), + _ => TransactionOutcome::Rollback(Ok((false, output))), + } + }); + + let (success, output) = match transaction_outcome { + // `with_transactional` executed successfully, and we have the expected output. + Ok((success, output)) => (success, output), + // `with_transactional` returned an error, and we propagate that error and note no state + // has changed. + Err(error) => (false, Err(error.into())), + }; self.pop_frame(success); output } diff --git a/substrate/frame/support/procedural/src/transactional.rs b/substrate/frame/support/procedural/src/transactional.rs index 66a8d083fb562f4f953cb3ffa4a0fc838bda2f7b..ba75fbc9737aa353a9fc3722d5c76443ab4bc04d 100644 --- a/substrate/frame/support/procedural/src/transactional.rs +++ b/substrate/frame/support/procedural/src/transactional.rs @@ -49,7 +49,9 @@ pub fn require_transactional(_attr: TokenStream, input: TokenStream) -> Result<T let output = quote! { #(#attrs)* #vis #sig { - #crate_::storage::require_transaction(); + if !#crate_::storage::is_transactional() { + return Err(#crate_::sp_runtime::TransactionalError::NoLayer.into()); + } #block } }; diff --git a/substrate/frame/support/src/storage/mod.rs b/substrate/frame/support/src/storage/mod.rs index 4a0eebf5679931e8a565041a8ecbaf97fb16d4f3..c9814e28a7ae401ac9afa7b2c3ce4836ab9bc23b 100644 --- a/substrate/frame/support/src/storage/mod.rs +++ b/substrate/frame/support/src/storage/mod.rs @@ -27,8 +27,11 @@ use crate::{ }; use codec::{Decode, Encode, EncodeLike, FullCodec, FullEncode}; use sp_core::storage::ChildInfo; -use sp_runtime::generic::{Digest, DigestItem}; pub use sp_runtime::TransactionOutcome; +use sp_runtime::{ + generic::{Digest, DigestItem}, + DispatchError, TransactionalError, +}; use sp_std::prelude::*; pub use types::Key; @@ -44,55 +47,65 @@ pub mod types; pub mod unhashed; pub mod weak_bounded_vec; -#[cfg(all(feature = "std", any(test, debug_assertions)))] -mod debug_helper { - use std::cell::RefCell; +mod transaction_level_tracker { + use core::sync::atomic::{AtomicU32, Ordering}; + + type Layer = u32; + static NUM_LEVELS: AtomicU32 = AtomicU32::new(0); + const TRANSACTIONAL_LIMIT: Layer = 255; - thread_local! { - static TRANSACTION_LEVEL: RefCell<u32> = RefCell::new(0); + pub fn get_transaction_level() -> Layer { + NUM_LEVELS.load(Ordering::SeqCst) } - pub fn require_transaction() { - let level = TRANSACTION_LEVEL.with(|v| *v.borrow()); - if level == 0 { - panic!("Require transaction not called within with_transaction"); - } + /// Increments the transaction level. Returns an error if levels go past the limit. + /// + /// Returns a guard that when dropped decrements the transaction level automatically. + pub fn inc_transaction_level() -> Result<StorageLayerGuard, ()> { + NUM_LEVELS + .fetch_update(Ordering::SeqCst, Ordering::SeqCst, |existing_levels| { + if existing_levels >= TRANSACTIONAL_LIMIT { + return None + } + // Cannot overflow because of check above. + Some(existing_levels + 1) + }) + .map_err(|_| ())?; + Ok(StorageLayerGuard) } - pub struct TransactionLevelGuard; + fn dec_transaction_level() { + NUM_LEVELS + .fetch_update(Ordering::SeqCst, Ordering::SeqCst, |existing_levels| { + if existing_levels == 0 { + log::warn!( + "We are underflowing with calculating transactional levels. Not great, but let's not panic...", + ); + None + } else { + // Cannot underflow because of checks above. + Some(existing_levels - 1) + } + }) + .ok(); + } - impl Drop for TransactionLevelGuard { - fn drop(&mut self) { - TRANSACTION_LEVEL.with(|v| *v.borrow_mut() -= 1); - } + pub fn is_transactional() -> bool { + get_transaction_level() > 0 } - /// Increments the transaction level. - /// - /// Returns a guard that when dropped decrements the transaction level automatically. - pub fn inc_transaction_level() -> TransactionLevelGuard { - TRANSACTION_LEVEL.with(|v| { - let mut val = v.borrow_mut(); - *val += 1; - if *val > 10 { - log::warn!( - "Detected with_transaction with nest level {}. Nested usage of with_transaction is not recommended.", - *val - ); - } - }); + pub struct StorageLayerGuard; - TransactionLevelGuard + impl Drop for StorageLayerGuard { + fn drop(&mut self) { + dec_transaction_level() + } } } -/// Assert this method is called within a storage transaction. -/// This will **panic** if is not called within a storage transaction. -/// -/// This assertion is enabled for native execution and when `debug_assertions` are enabled. -pub fn require_transaction() { - #[cfg(all(feature = "std", any(test, debug_assertions)))] - debug_helper::require_transaction(); +/// Check if the current call is within a transactional layer. +pub fn is_transactional() -> bool { + transaction_level_tracker::is_transactional() } /// Execute the supplied function in a new storage transaction. @@ -100,15 +113,55 @@ pub fn require_transaction() { /// All changes to storage performed by the supplied function are discarded if the returned /// outcome is `TransactionOutcome::Rollback`. /// -/// Transactions can be nested to any depth. Commits happen to the parent transaction. -pub fn with_transaction<R>(f: impl FnOnce() -> TransactionOutcome<R>) -> R { +/// Transactions can be nested up to `TRANSACTIONAL_LIMIT` times; more than that will result in an +/// error. +/// +/// Commits happen to the parent transaction. +pub fn with_transaction<T, E>(f: impl FnOnce() -> TransactionOutcome<Result<T, E>>) -> Result<T, E> +where + E: From<DispatchError>, +{ use sp_io::storage::{commit_transaction, rollback_transaction, start_transaction}; use TransactionOutcome::*; + let _guard = transaction_level_tracker::inc_transaction_level() + .map_err(|()| TransactionalError::LimitReached.into())?; + start_transaction(); - #[cfg(all(feature = "std", any(test, debug_assertions)))] - let _guard = debug_helper::inc_transaction_level(); + match f() { + Commit(res) => { + commit_transaction(); + res + }, + Rollback(res) => { + rollback_transaction(); + res + }, + } +} + +/// Same as [`with_transaction`] but without a limit check on nested transactional layers. +/// +/// This is mostly for backwards compatibility before there was a transactional layer limit. +/// It is recommended to only use [`with_transaction`] to avoid users from generating too many +/// transactional layers. +pub fn with_transaction_unchecked<R>(f: impl FnOnce() -> TransactionOutcome<R>) -> R { + use sp_io::storage::{commit_transaction, rollback_transaction, start_transaction}; + use TransactionOutcome::*; + + let maybe_guard = transaction_level_tracker::inc_transaction_level(); + + if maybe_guard.is_err() { + log::warn!( + "The transactional layer limit has been reached, and new transactional layers are being + spawned with `with_transaction_unchecked`. This could be caused by someone trying to + attack your chain, and you should investigate usage of `with_transaction_unchecked` and + potentially migrate to `with_transaction`, which enforces a transactional limit.", + ); + } + + start_transaction(); match f() { Commit(res) => { @@ -1418,12 +1471,13 @@ pub fn storage_prefix(pallet_name: &[u8], storage_name: &[u8]) -> [u8; 32] { #[cfg(test)] mod test { use super::*; - use crate::{assert_ok, hash::Identity, Twox128}; + use crate::{assert_noop, assert_ok, hash::Identity, Twox128}; use bounded_vec::BoundedVec; use frame_support::traits::ConstU32; use generator::StorageValue as _; use sp_core::hashing::twox_128; use sp_io::TestExternalities; + use sp_runtime::DispatchResult; use weak_bounded_vec::WeakBoundedVec; #[test] @@ -1535,25 +1589,67 @@ mod test { } #[test] - #[should_panic(expected = "Require transaction not called within with_transaction")] - fn require_transaction_should_panic() { + fn is_transactional_should_return_false() { TestExternalities::default().execute_with(|| { - require_transaction(); + assert!(!is_transactional()); }); } #[test] - fn require_transaction_should_not_panic_in_with_transaction() { + fn is_transactional_should_not_error_in_with_transaction() { TestExternalities::default().execute_with(|| { - with_transaction(|| { - require_transaction(); - TransactionOutcome::Commit(()) - }); - - with_transaction(|| { - require_transaction(); - TransactionOutcome::Rollback(()) - }); + assert_ok!(with_transaction(|| -> TransactionOutcome<DispatchResult> { + assert!(is_transactional()); + TransactionOutcome::Commit(Ok(())) + })); + + assert_noop!( + with_transaction(|| -> TransactionOutcome<DispatchResult> { + assert!(is_transactional()); + TransactionOutcome::Rollback(Err("revert".into())) + }), + "revert" + ); + }); + } + + fn recursive_transactional(num: u32) -> DispatchResult { + if num == 0 { + return Ok(()) + } + + with_transaction(|| -> TransactionOutcome<DispatchResult> { + let res = recursive_transactional(num - 1); + TransactionOutcome::Commit(res) + }) + } + + #[test] + fn transaction_limit_should_work() { + TestExternalities::default().execute_with(|| { + assert_eq!(transaction_level_tracker::get_transaction_level(), 0); + + assert_ok!(with_transaction(|| -> TransactionOutcome<DispatchResult> { + assert_eq!(transaction_level_tracker::get_transaction_level(), 1); + TransactionOutcome::Commit(Ok(())) + })); + + assert_ok!(with_transaction(|| -> TransactionOutcome<DispatchResult> { + assert_eq!(transaction_level_tracker::get_transaction_level(), 1); + let res = with_transaction(|| -> TransactionOutcome<DispatchResult> { + assert_eq!(transaction_level_tracker::get_transaction_level(), 2); + TransactionOutcome::Commit(Ok(())) + }); + TransactionOutcome::Commit(res) + })); + + assert_ok!(recursive_transactional(255)); + assert_noop!( + recursive_transactional(256), + sp_runtime::TransactionalError::LimitReached + ); + + assert_eq!(transaction_level_tracker::get_transaction_level(), 0); }); } diff --git a/substrate/frame/support/test/tests/storage_transaction.rs b/substrate/frame/support/test/tests/storage_transaction.rs index 0f1c3a2e0c536e1657b65183c965810828f3b6f5..848a91a7f5a868ef97e346129b7ecb2f5e8220f2 100644 --- a/substrate/frame/support/test/tests/storage_transaction.rs +++ b/substrate/frame/support/test/tests/storage_transaction.rs @@ -16,12 +16,13 @@ // limitations under the License. use frame_support::{ - assert_noop, assert_ok, + assert_noop, assert_ok, assert_storage_noop, dispatch::{DispatchError, DispatchResult}, storage::{with_transaction, TransactionOutcome::*}, transactional, StorageMap, StorageValue, }; use sp_io::TestExternalities; +use sp_runtime::TransactionOutcome; use sp_std::result; pub trait Config: frame_support_test::Config {} @@ -67,13 +68,13 @@ fn storage_transaction_basic_commit() { assert_eq!(Value::get(), 0); assert!(!Map::contains_key("val0")); - with_transaction(|| { + assert_ok!(with_transaction(|| -> TransactionOutcome<DispatchResult> { Value::set(99); Map::insert("val0", 99); assert_eq!(Value::get(), 99); assert_eq!(Map::get("val0"), 99); - Commit(()) - }); + Commit(Ok(())) + })); assert_eq!(Value::get(), 99); assert_eq!(Map::get("val0"), 99); @@ -86,13 +87,26 @@ fn storage_transaction_basic_rollback() { assert_eq!(Value::get(), 0); assert_eq!(Map::get("val0"), 0); - with_transaction(|| { - Value::set(99); - Map::insert("val0", 99); - assert_eq!(Value::get(), 99); - assert_eq!(Map::get("val0"), 99); - Rollback(()) - }); + assert_noop!( + with_transaction(|| -> TransactionOutcome<DispatchResult> { + Value::set(99); + Map::insert("val0", 99); + assert_eq!(Value::get(), 99); + assert_eq!(Map::get("val0"), 99); + Rollback(Err("revert".into())) + }), + "revert" + ); + + assert_storage_noop!(assert_ok!(with_transaction( + || -> TransactionOutcome<DispatchResult> { + Value::set(99); + Map::insert("val0", 99); + assert_eq!(Value::get(), 99); + assert_eq!(Map::get("val0"), 99); + Rollback(Ok(())) + } + ))); assert_eq!(Value::get(), 0); assert_eq!(Map::get("val0"), 0); @@ -105,32 +119,35 @@ fn storage_transaction_rollback_then_commit() { Value::set(1); Map::insert("val1", 1); - with_transaction(|| { + assert_ok!(with_transaction(|| -> TransactionOutcome<DispatchResult> { Value::set(2); Map::insert("val1", 2); Map::insert("val2", 2); - with_transaction(|| { - Value::set(3); - Map::insert("val1", 3); - Map::insert("val2", 3); - Map::insert("val3", 3); + assert_noop!( + with_transaction(|| -> TransactionOutcome<DispatchResult> { + Value::set(3); + Map::insert("val1", 3); + Map::insert("val2", 3); + Map::insert("val3", 3); - assert_eq!(Value::get(), 3); - assert_eq!(Map::get("val1"), 3); - assert_eq!(Map::get("val2"), 3); - assert_eq!(Map::get("val3"), 3); + assert_eq!(Value::get(), 3); + assert_eq!(Map::get("val1"), 3); + assert_eq!(Map::get("val2"), 3); + assert_eq!(Map::get("val3"), 3); - Rollback(()) - }); + Rollback(Err("revert".into())) + }), + "revert" + ); assert_eq!(Value::get(), 2); assert_eq!(Map::get("val1"), 2); assert_eq!(Map::get("val2"), 2); assert_eq!(Map::get("val3"), 0); - Commit(()) - }); + Commit(Ok(())) + })); assert_eq!(Value::get(), 2); assert_eq!(Map::get("val1"), 2); @@ -145,32 +162,35 @@ fn storage_transaction_commit_then_rollback() { Value::set(1); Map::insert("val1", 1); - with_transaction(|| { - Value::set(2); - Map::insert("val1", 2); - Map::insert("val2", 2); + assert_noop!( + with_transaction(|| -> TransactionOutcome<DispatchResult> { + Value::set(2); + Map::insert("val1", 2); + Map::insert("val2", 2); + + assert_ok!(with_transaction(|| -> TransactionOutcome<DispatchResult> { + Value::set(3); + Map::insert("val1", 3); + Map::insert("val2", 3); + Map::insert("val3", 3); - with_transaction(|| { - Value::set(3); - Map::insert("val1", 3); - Map::insert("val2", 3); - Map::insert("val3", 3); + assert_eq!(Value::get(), 3); + assert_eq!(Map::get("val1"), 3); + assert_eq!(Map::get("val2"), 3); + assert_eq!(Map::get("val3"), 3); + + Commit(Ok(())) + })); assert_eq!(Value::get(), 3); assert_eq!(Map::get("val1"), 3); assert_eq!(Map::get("val2"), 3); assert_eq!(Map::get("val3"), 3); - Commit(()) - }); - - assert_eq!(Value::get(), 3); - assert_eq!(Map::get("val1"), 3); - assert_eq!(Map::get("val2"), 3); - assert_eq!(Map::get("val3"), 3); - - Rollback(()) - }); + Rollback(Err("revert".into())) + }), + "revert" + ); assert_eq!(Value::get(), 1); assert_eq!(Map::get("val1"), 1); diff --git a/substrate/primitives/runtime/src/lib.rs b/substrate/primitives/runtime/src/lib.rs index 337fac5812aed564d2a2cc9429318739f16bbf98..c09db5124cc1f136c39e0f700bd264efaba50e1a 100644 --- a/substrate/primitives/runtime/src/lib.rs +++ b/substrate/primitives/runtime/src/lib.rs @@ -486,6 +486,31 @@ impl PartialEq for ModuleError { } } +/// Errors related to transactional storage layers. +#[derive(Eq, PartialEq, Clone, Copy, Encode, Decode, Debug, TypeInfo)] +#[cfg_attr(feature = "std", derive(Serialize, Deserialize))] +pub enum TransactionalError { + /// Too many transactional layers have been spawned. + LimitReached, + /// A transactional layer was expected, but does not exist. + NoLayer, +} + +impl From<TransactionalError> for &'static str { + fn from(e: TransactionalError) -> &'static str { + match e { + TransactionalError::LimitReached => "Too many transactional layers have been spawned", + TransactionalError::NoLayer => "A transactional layer was expected, but does not exist", + } + } +} + +impl From<TransactionalError> for DispatchError { + fn from(e: TransactionalError) -> DispatchError { + Self::Transactional(e) + } +} + /// Reason why a dispatch call failed. #[derive(Eq, Clone, Copy, Encode, Decode, Debug, TypeInfo, PartialEq)] #[cfg_attr(feature = "std", derive(Serialize, Deserialize))] @@ -512,6 +537,9 @@ pub enum DispatchError { Token(TokenError), /// An arithmetic error. Arithmetic(ArithmeticError), + /// The number of transactional layers has been reached, or we are not in a transactional + /// layer. + Transactional(TransactionalError), } /// Result of a `Dispatchable` which contains the `DispatchResult` and additional information about @@ -647,6 +675,7 @@ impl From<DispatchError> for &'static str { DispatchError::TooManyConsumers => "Too many consumers", DispatchError::Token(e) => e.into(), DispatchError::Arithmetic(e) => e.into(), + DispatchError::Transactional(e) => e.into(), } } } @@ -685,6 +714,10 @@ impl traits::Printable for DispatchError { "Arithmetic error: ".print(); <&'static str>::from(*e).print(); }, + Self::Transactional(e) => { + "Transactional error: ".print(); + <&'static str>::from(*e).print(); + }, } } }