Commit 007f878d authored by thiolliere's avatar thiolliere Committed by Bastian Köcher
Browse files

Bound preallocation to input len (#139)

* impl bound to input_len

* typo

* fix + doc

* fix

* hide alloc

* address issues
parent 648e34f2
Pipeline #44657 passed with stages
in 16 minutes and 55 seconds
......@@ -75,7 +75,7 @@ fn encode_single_field(
_parity_scale_codec::Encode::encode_to(&#final_field_variable, dest)
}
fn encode(&self) -> Vec<u8> {
fn encode(&self) -> _parity_scale_codec::alloc::vec::Vec<u8> {
_parity_scale_codec::Encode::encode(&#final_field_variable)
}
......
......@@ -95,6 +95,9 @@ impl From<&'static str> for Error {
/// Trait that allows reading of data into a slice.
pub trait Input {
/// Return remaining length of input.
fn remaining_len(&mut self) -> Result<usize, Error>;
/// Read the exact number of bytes required to fill the given buffer.
///
/// Note that this function is similar to `std::io::Read::read_exact` and not
......@@ -109,8 +112,11 @@ pub trait Input {
}
}
#[cfg(not(feature = "std"))]
impl<'a> Input for &'a [u8] {
fn remaining_len(&mut self) -> Result<usize, Error> {
Ok(self.len())
}
fn read(&mut self, into: &mut [u8]) -> Result<(), Error> {
if into.len() > self.len() {
return Err("Not enough data to fill buffer".into());
......@@ -150,11 +156,32 @@ impl From<std::io::Error> for Error {
}
}
/// Wrapper that implements Input for any `Read` and `Seek` type.
#[cfg(feature = "std")]
pub struct IoReader<R: std::io::Read + std::io::Seek>(pub R);
#[cfg(feature = "std")]
impl<R: std::io::Read> Input for R {
impl<R: std::io::Read + std::io::Seek> Input for IoReader<R> {
fn remaining_len(&mut self) -> Result<usize, Error> {
use std::convert::TryInto;
use std::io::SeekFrom;
let old_pos = self.0.seek(SeekFrom::Current(0))?;
let len = self.0.seek(SeekFrom::End(0))?;
// Avoid seeking a third time when we were already at the end of the
// stream. The branch is usually way cheaper than a seek operation.
if old_pos != len {
self.0.seek(SeekFrom::Start(old_pos))?;
}
len.saturating_sub(old_pos)
.try_into()
.map_err(|_| "Input cannot fit into usize length".into())
}
fn read(&mut self, into: &mut [u8]) -> Result<(), Error> {
(self as &mut dyn std::io::Read).read_exact(into)?;
Ok(())
self.0.read_exact(into).map_err(Into::into)
}
}
......@@ -554,13 +581,18 @@ impl<T: Decode> Decode for Vec<T> {
<Compact<u32>>::decode(input).and_then(move |Compact(len)| {
let len = len as usize;
if let IsU8::Yes = <T as Decode>::IS_U8 {
let mut r = vec![0; len];
if len > input.remaining_len()? {
return Err("Not enough data to decode vector".into());
}
input.read(&mut r[..len])?;
let r = unsafe { core::mem::transmute::<Vec<u8>, Vec<T>>(r) };
let mut r = vec![0; len];
input.read(&mut r)?;
let r = unsafe { mem::transmute::<Vec<u8>, Vec<T>>(r) };
Ok(r)
} else {
let mut r = Vec::with_capacity(len);
let capacity = input.remaining_len()?.checked_div(mem::size_of::<T>())
.unwrap_or(0);
let mut r = Vec::with_capacity(capacity);
for _ in 0..len {
r.push(T::decode(input)?);
}
......@@ -1150,4 +1182,23 @@ mod tests {
);
assert_eq!(Decode::decode(&mut &t8.encode()[..]), Ok(t8));
}
#[test]
fn io_reader() {
use std::io::{Seek, SeekFrom};
let mut io_reader = IoReader(std::io::Cursor::new(&[1u8, 2, 3][..]));
assert_eq!(io_reader.0.seek(SeekFrom::Current(0)).unwrap(), 0);
assert_eq!(io_reader.remaining_len().unwrap(), 3);
assert_eq!(io_reader.read_byte().unwrap(), 1);
assert_eq!(io_reader.0.seek(SeekFrom::Current(0)).unwrap(), 1);
assert_eq!(io_reader.remaining_len().unwrap(), 2);
assert_eq!(io_reader.read_byte().unwrap(), 2);
assert_eq!(io_reader.read_byte().unwrap(), 3);
assert_eq!(io_reader.0.seek(SeekFrom::Current(0)).unwrap(), 3);
assert_eq!(io_reader.remaining_len().unwrap(), 0);
}
}
......@@ -46,12 +46,15 @@ struct PrefixInput<'a, T> {
}
impl<'a, T: 'a + Input> Input for PrefixInput<'a, T> {
fn remaining_len(&mut self) -> Result<usize, Error> {
Ok(self.input.remaining_len()?.saturating_add(self.prefix.iter().count()))
}
fn read(&mut self, buffer: &mut [u8]) -> Result<(), Error> {
match self.prefix.take() {
Some(v) if !buffer.is_empty() => {
buffer[0] = v;
self.input.read(&mut buffer[1..])?;
Ok(())
self.input.read(&mut buffer[1..])
}
_ => self.input.read(buffer)
}
......
......@@ -209,7 +209,8 @@
#[cfg(not(feature = "std"))]
#[macro_use]
extern crate alloc;
#[doc(hidden)]
pub extern crate alloc;
#[cfg(feature = "parity-scale-codec-derive")]
#[allow(unused_imports)]
......@@ -248,6 +249,8 @@ pub use self::codec::{
Input, Output, Error, Encode, Decode, Codec, EncodeAsRef, EncodeAppend, WrapperTypeEncode,
WrapperTypeDecode, OptionBool,
};
#[cfg(feature = "std")]
pub use self::codec::IoReader;
pub use self::compact::{Compact, HasCompact, CompactAs};
pub use self::joiner::Joiner;
pub use self::keyedvec::KeyedVec;
......
......@@ -491,3 +491,19 @@ fn recursive_type() {
}
}
#[test]
fn crafted_input_for_vec_u8() {
assert_eq!(
Vec::<u8>::decode(&mut &Compact(u32::max_value()).encode()[..]).err().unwrap().what(),
"Not enough data to decode vector"
);
}
#[test]
fn crafted_input_for_vec_t() {
assert_eq!(
Vec::<u32>::decode(&mut &Compact(u32::max_value()).encode()[..]).err().unwrap().what(),
"Not enough data to fill buffer"
);
}
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment