depth_limit.rs 2.92 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
// Copyright 2017, 2018 Parity Technologies
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use crate::codec::{Error, Decode, Input};

/// Error message returned by [`DepthTrackingInput::descend_ref`] when entering one more
/// nested reference would exceed the configured maximum recursion depth.
const DECODE_MAX_DEPTH_MSG: &str = "Maximum recursion depth reached when decoding";

Bastian Köcher's avatar
Bastian Köcher committed
20
/// Extension trait to [`Decode`] for decoding with a maximum recursion depth.
21
pub trait DecodeLimit: Sized {
Bastian Köcher's avatar
Bastian Köcher committed
22
23
24
	/// Decode `Self` with the given maximum recursion depth.
	///
	/// If `limit` is hit, an error is returned.
25
	fn decode_with_depth_limit(limit: u32, input: &[u8]) -> Result<Self, Error>;
Bastian Köcher's avatar
Bastian Köcher committed
26
27
28
29

	/// Decode `Self` and consume all of the given input data.
	///
	/// If not all data is consumed or `limit` is hit, an error is returned.
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
	fn decode_all_with_depth_limit(limit: u32, input: &[u8]) -> Result<Self, Error>;
}


/// An [`Input`] adapter that forwards every read to the wrapped input while counting
/// how many nested references (`descend_ref` / `ascend_ref` pairs) are currently open.
///
/// `descend_ref` fails with [`DECODE_MAX_DEPTH_MSG`] once opening another reference
/// would make more than `max_depth` of them open at the same time.
struct DepthTrackingInput<'a, I> {
	/// The wrapped input that actually provides the bytes.
	input: &'a mut I,
	/// Number of nested references currently open through this adapter.
	depth: u32,
	/// Maximum value `depth` may reach; invariant: `depth <= max_depth`.
	max_depth: u32,
}

impl<'a, I: Input> Input for DepthTrackingInput<'a, I> {
	fn remaining_len(&mut self) -> Result<Option<usize>, Error> {
		self.input.remaining_len()
	}

	fn read(&mut self, into: &mut [u8]) -> Result<(), Error> {
		self.input.read(into)
	}

	fn read_byte(&mut self) -> Result<u8, Error> {
		self.input.read_byte()
	}

	/// Opens one more level of nesting.
	///
	/// The wrapped input is asked first, so its own error (if any) takes precedence.
	/// A failed call leaves both `self` and the wrapped input unchanged.
	/// NOTE(review): this assumes decoders propagate a failed `descend_ref` with `?`
	/// and never call the matching `ascend_ref`. Any state changed here would
	/// otherwise stay unbalanced.
	fn descend_ref(&mut self) -> Result<(), Error> {
		self.input.descend_ref()?;
		// Check before incrementing. This keeps `depth <= max_depth`, so the increment
		// below can never overflow, even when `max_depth == u32::MAX`.
		if self.depth >= self.max_depth {
			// Undo the inner descent that just succeeded. Our caller will not
			// ascend after receiving an error.
			self.input.ascend_ref();
			return Err(DECODE_MAX_DEPTH_MSG.into());
		}
		self.depth += 1;
		Ok(())
	}

	/// Closes one level of nesting opened by a successful `descend_ref`.
	fn ascend_ref(&mut self) {
		self.input.ascend_ref();
		self.depth -= 1;
	}
}

/// Blanket implementation: every [`Decode`] type gains depth-limited decoding by running
/// its ordinary `decode` over a [`DepthTrackingInput`] that starts at depth zero.
impl<T: Decode> DecodeLimit for T {
	fn decode_with_depth_limit(limit: u32, input: &[u8]) -> Result<Self, Error> {
		let mut remaining = input;
		let mut tracker = DepthTrackingInput {
			input: &mut remaining,
			depth: 0,
			max_depth: limit,
		};
		T::decode(&mut tracker)
	}

	fn decode_all_with_depth_limit(limit: u32, input: &[u8]) -> Result<Self, Error> {
		let mut remaining = input;
		let decoded = T::decode(&mut DepthTrackingInput {
			input: &mut remaining,
			depth: 0,
			max_depth: limit,
		})?;

		// `remaining` now holds whatever `decode` left unread; any leftover is an error.
		if !remaining.is_empty() {
			return Err(crate::decode_all::DECODE_ALL_ERR_MSG.into());
		}
		Ok(decoded)
	}
}

#[cfg(test)]
mod tests {
	use super::*;
	use crate::Encode;

	type NestedVec = Vec<Vec<Vec<Vec<u8>>>>;

	/// Builds the nested test value together with its SCALE encoding.
	fn nested_and_encoded() -> (NestedVec, Vec<u8>) {
		let nested: NestedVec = vec![vec![vec![vec![1]]]];
		let encoded = nested.encode();
		(nested, encoded)
	}

	#[test]
	fn decode_limit_works() {
		let (nested, encoded) = nested_and_encoded();

		let decoded = NestedVec::decode_with_depth_limit(3, &encoded).unwrap();
		assert_eq!(decoded, nested);
		assert!(NestedVec::decode_with_depth_limit(2, &encoded).is_err());
	}

	#[test]
	fn decode_limit_ignores_trailing_bytes() {
		let (nested, mut encoded) = nested_and_encoded();
		encoded.push(0);

		let decoded = NestedVec::decode_with_depth_limit(3, &encoded).unwrap();
		assert_eq!(decoded, nested);
	}

	#[test]
	fn decode_all_limit_works() {
		let (nested, encoded) = nested_and_encoded();

		let decoded = NestedVec::decode_all_with_depth_limit(3, &encoded).unwrap();
		assert_eq!(decoded, nested);
		assert!(NestedVec::decode_all_with_depth_limit(2, &encoded).is_err());

		// Trailing data must be rejected even when the depth limit is satisfied.
		let mut with_trailing = encoded.clone();
		with_trailing.push(0);
		assert!(NestedVec::decode_all_with_depth_limit(3, &with_trailing).is_err());
	}
}