Unverified Commit 2b66d2ce authored by thiolliere's avatar thiolliere Committed by GitHub
Browse files

Implement min_encoded_len and required_len (#119)

parent 9835ab32
Pipeline #43535 failed with stages
in 13 minutes and 23 seconds
......@@ -28,6 +28,7 @@ The `Encode` trait is used for encoding of data into the SCALE format. The `Enco
The `Decode` trait is used for deserialization/decoding of encoded data into the respective types.
* `fn min_encoded_len() -> usize`: The minimum length a valid encoded value can have.
* `fn decode<I: Input>(value: &mut I) -> Result<Self, Error>`: Tries to decode the value from SCALE format to the type it is called on. Returns an `Err` if the decoding fails.
### CompactAs
......@@ -156,4 +157,4 @@ This repository also contains an implementation of derive macros for the Parity
## License
This Rust implementation of Parity SCALE Codec is licensed under the [Apache 2 license](./LICENSE).
\ No newline at end of file
This Rust implementation of Parity SCALE Codec is licensed under the [Apache 2 license](./LICENSE).
......@@ -12,85 +12,122 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use proc_macro2::{Span, TokenStream, Ident};
use proc_macro2::{TokenStream, Ident};
use syn::{Data, Fields, Field, spanned::Spanned, Error};
use crate::utils;
use std::iter::FromIterator;
pub fn quote(data: &Data, type_name: &Ident, input: &TokenStream) -> TokenStream {
let call_site = Span::call_site();
// The Encode macro uses one byte to encode the index of the variant when encoding an enum.
const ENUM_VARIANT_INDEX_ENCODED_LEN: usize = 1;
pub struct Impl {
pub decode: TokenStream,
pub min_encoded_len: TokenStream,
}
pub fn quote(data: &Data, type_name: &Ident, input: &TokenStream) -> Result<Impl, Error> {
match *data {
Data::Struct(ref data) => match data.fields {
Fields::Named(_) | Fields::Unnamed(_) => create_instance(
call_site,
Fields::Named(_) | Fields::Unnamed(_) => fields_impl(
quote! { #type_name },
input,
&data.fields,
),
Fields::Unit => {
quote_spanned! {call_site =>
let decode = quote_spanned! { data.fields.span() =>
drop(#input);
Ok(#type_name)
}
};
let min_encoded_len = quote_spanned! { data.fields.span() =>
0
};
Ok(Impl { decode, min_encoded_len })
},
},
Data::Enum(ref data) => {
let data_variants = || data.variants.iter().filter(|variant| crate::utils::get_skip(&variant.attrs).is_none());
if data_variants().count() > 256 {
return Error::new(
Span::call_site(),
return Err(Error::new(
data.variants.span(),
"Currently only enums with at most 256 variants are encodable."
).to_compile_error();
));
}
let recurse = data_variants().enumerate().map(|(i, v)| {
let name = &v.ident;
let index = utils::index(v, i);
let create = create_instance(
call_site,
let impl_ = fields_impl(
quote! { #type_name :: #name },
input,
&v.fields,
);
)?;
let impl_decode = impl_.decode;
let impl_min_encoded_len = impl_.min_encoded_len;
quote_spanned! { v.span() =>
let decode = quote_spanned! { v.span() =>
x if x == #index as u8 => {
#create
#impl_decode
},
}
};
let min_encoded_len = quote_spanned! { v.span() =>
#ENUM_VARIANT_INDEX_ENCODED_LEN + #impl_min_encoded_len
};
Ok(Impl { decode, min_encoded_len })
});
let recurse: Vec<_> = Result::<_, Error>::from_iter(recurse)?;
let recurse_decode = recurse.iter().map(|i| &i.decode);
let recurse_min_encoded_len = recurse.iter().map(|i| &i.min_encoded_len);
let err_msg = format!("No such variant in enum {}", type_name);
quote! {
let decode = quote! {
match #input.read_byte()? {
#( #recurse )*
#( #recurse_decode )*
x => Err(#err_msg.into()),
}
}
};
let min_encoded_len = quote! {
let mut res = usize::max_value();
#( res = res.min( #recurse_min_encoded_len); )*
res
};
Ok(Impl { decode, min_encoded_len })
},
Data::Union(_) => Error::new(Span::call_site(), "Union types are not supported.").to_compile_error(),
Data::Union(ref data) => Err(Error::new(
data.union_token.span(),
"Union types are not supported."
)),
}
}
fn create_decode_expr(field: &Field, name: &String, input: &TokenStream) -> TokenStream {
fn field_impl(field: &Field, name: &String, input: &TokenStream) -> Result<Impl, Error> {
let encoded_as = utils::get_encoded_as_type(field);
let compact = utils::get_enable_compact(field);
let skip = utils::get_skip(&field.attrs).is_some();
if encoded_as.is_some() as u8 + compact as u8 + skip as u8 > 1 {
return Error::new(
Span::call_site(),
return Err(Error::new(
field.span(),
"`encoded_as`, `compact` and `skip` can only be used one at a time!"
).to_compile_error();
));
}
let err_msg = format!("Error decoding field {}", name);
if compact {
let field_type = &field.ty;
quote_spanned! { field.span() =>
let decode = quote_spanned! { field.span() =>
{
let res = <
<#field_type as _parity_scale_codec::HasCompact>::Type as _parity_scale_codec::Decode
......@@ -100,9 +137,17 @@ fn create_decode_expr(field: &Field, name: &String, input: &TokenStream) -> Toke
Ok(a) => a.into(),
}
}
}
};
let min_encoded_len = quote_spanned! { field.span() =>
<
<#field_type as _parity_scale_codec::HasCompact>::Type as _parity_scale_codec::Decode
>::min_encoded_len()
};
Ok(Impl { decode, min_encoded_len })
} else if let Some(encoded_as) = encoded_as {
quote_spanned! { field.span() =>
let decode = quote_spanned! { field.span() =>
{
let res = <#encoded_as as _parity_scale_codec::Decode>::decode(#input);
match res {
......@@ -110,11 +155,22 @@ fn create_decode_expr(field: &Field, name: &String, input: &TokenStream) -> Toke
Ok(a) => a.into(),
}
}
}
};
let min_encoded_len = quote_spanned! { field.span() =>
<#encoded_as as _parity_scale_codec::Decode>::min_encoded_len()
};
Ok(Impl { decode, min_encoded_len })
} else if skip {
quote_spanned! { field.span() => Default::default() }
let decode = quote_spanned! { field.span() => Default::default() };
let min_encoded_len = quote_spanned! { field.span() => 0 };
Ok(Impl { decode, min_encoded_len })
} else {
quote_spanned! { field.span() =>
let field_type = &field.ty;
let decode = quote_spanned! { field.span() =>
{
let res = _parity_scale_codec::Decode::decode(#input);
match res {
......@@ -122,16 +178,21 @@ fn create_decode_expr(field: &Field, name: &String, input: &TokenStream) -> Toke
Ok(a) => a,
}
}
}
};
let min_encoded_len = quote_spanned! { field.span() =>
<#field_type as _parity_scale_codec::Decode>::min_encoded_len()
};
Ok(Impl { decode, min_encoded_len })
}
}
fn create_instance(
call_site: Span,
fn fields_impl(
name: TokenStream,
input: &TokenStream,
fields: &Fields
) -> TokenStream {
) -> Result<Impl, Error> {
match *fields {
Fields::Named(ref fields) => {
let recurse = fields.named.iter().map(|f| {
......@@ -140,36 +201,73 @@ fn create_instance(
Some(a) => format!("{}.{}", name, a),
None => format!("{}", name),
};
let decode = create_decode_expr(f, &field, input);
let impl_ = field_impl(f, &field, input)?;
quote_spanned! { f.span() =>
#name_ident: #decode
}
let impl_decode = impl_.decode;
let decode = quote_spanned! { f.span() =>
#name_ident: #impl_decode
};
let impl_min_encoded_len = impl_.min_encoded_len;
let min_encoded_len = quote_spanned! { f.span() =>
#impl_min_encoded_len
};
Ok(Impl { decode, min_encoded_len })
});
quote_spanned! {call_site =>
let recurse: Vec<_> = Result::<_, Error>::from_iter(recurse)?;
let recurse_decode = recurse.iter().map(|i| &i.decode);
let recurse_min_encoded_len = recurse.iter().map(|i| &i.min_encoded_len);
let decode = quote_spanned! { fields.span() =>
Ok(#name {
#( #recurse, )*
#( #recurse_decode, )*
})
}
};
let min_encoded_len = quote_spanned! { fields.span() =>
0 #( + #recurse_min_encoded_len )*
};
Ok(Impl { decode, min_encoded_len })
},
Fields::Unnamed(ref fields) => {
let recurse = fields.unnamed.iter().enumerate().map(|(i, f) | {
let name = format!("{}.{}", name, i);
create_decode_expr(f, &name, input)
field_impl(f, &name, input)
});
quote_spanned! {call_site =>
let recurse: Vec<_> = Result::from_iter(recurse)?;
let recurse_decode = recurse.iter().map(|i| &i.decode);
let recurse_min_encoded_len = recurse.iter().map(|i| &i.min_encoded_len);
let decode = quote_spanned! { fields.span() =>
Ok(#name (
#( #recurse, )*
#( #recurse_decode, )*
))
}
};
let min_encoded_len = quote_spanned! { fields.span() =>
0 #( + #recurse_min_encoded_len )*
};
Ok(Impl { decode, min_encoded_len })
},
Fields::Unit => {
quote_spanned! {call_site =>
let decode = quote_spanned! { fields.span() =>
Ok(#name)
}
};
let min_encoded_len = quote_spanned! { fields.span() =>
0
};
Ok(Impl { decode, min_encoded_len })
},
}
}
......@@ -89,7 +89,7 @@ fn encode_fields<F>(
if encoded_as.is_some() as u8 + compact as u8 + skip as u8 > 1 {
return Error::new(
Span::call_site(),
f.span(),
"`encoded_as`, `compact` and `skip` can only be used one at a time!"
).to_compile_error();
}
......@@ -197,7 +197,7 @@ fn impl_encode(data: &Data, type_name: &Ident) -> TokenStream {
if data_variants().count() > 256 {
return Error::new(
Span::call_site(),
data.variants.span(),
"Currently only enums with at most 256 variants are encodable."
).to_compile_error();
}
......@@ -274,7 +274,10 @@ fn impl_encode(data: &Data, type_name: &Ident) -> TokenStream {
}
}
},
Data::Union(_) => Error::new(Span::call_site(), "Union types are not supported.").to_compile_error(),
Data::Union(ref data) => Error::new(
data.union_token.span(),
"Union types are not supported."
).to_compile_error(),
};
quote! {
......
......@@ -132,14 +132,24 @@ pub fn decode_derive(input: TokenStream) -> TokenStream {
let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
let input_ = quote!(input);
let decoding = decode::quote(&input.data, name, &input_);
let impl_ = match decode::quote(&input.data, name, &input_) {
Ok(impl_) => impl_,
Err(e) => return e.to_compile_error().into(),
};
let decode = impl_.decode;
let min_encoded_len = impl_.min_encoded_len;
let impl_block = quote! {
impl #impl_generics _parity_scale_codec::Decode for #name #ty_generics #where_clause {
fn min_encoded_len() -> usize {
#min_encoded_len
}
fn decode<DecIn: _parity_scale_codec::Input>(
#input_: &mut DecIn
) -> Result<Self, _parity_scale_codec::Error> {
#decoding
#decode
}
}
};
......@@ -176,7 +186,6 @@ pub fn compact_as_derive(input: TokenStream) -> TokenStream {
}
}
let call_site = Span::call_site();
let (inner_ty, inner_field, constructor) = match input.data {
Data::Struct(ref data) => {
match data.fields {
......@@ -188,7 +197,7 @@ pub fn compact_as_derive(input: TokenStream) -> TokenStream {
});
let field = utils::filter_skip_named(fields).next().expect("Exactly one field");
let field_name = &field.ident;
let constructor = quote_spanned!(call_site=> #name { #( #recurse, )* });
let constructor = quote!( #name { #( #recurse, )* } );
(&field.ty, quote!(&self.#field_name), constructor)
},
Fields::Unnamed(ref fields) if utils::filter_skip_unnamed(fields).count() == 1 => {
......@@ -198,22 +207,20 @@ pub fn compact_as_derive(input: TokenStream) -> TokenStream {
});
let (id, field) = utils::filter_skip_unnamed(fields).next().expect("Exactly one field");
let id = syn::Index::from(id);
let constructor = quote_spanned!(call_site=> #name(#( #recurse, )*));
let constructor = quote!( #name(#( #recurse, )*) );
(&field.ty, quote!(&self.#id), constructor)
},
_ => {
return Error::new(
Span::call_site(),
data.fields.span(),
"Only structs with a single non-skipped field can derive CompactAs"
).to_compile_error().into();
},
}
},
_ => {
return Error::new(
Span::call_site(),
"Only structs can derive CompactAs"
).to_compile_error().into();
Data::Enum(syn::DataEnum { enum_token: syn::token::Enum { span }, .. }) |
Data::Union(syn::DataUnion { union_token: syn::token::Union { span }, .. }) => {
return Error::new(span, "Only structs can derive CompactAs").to_compile_error().into();
},
};
......
......@@ -12,8 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use proc_macro2::Span;
use syn::{Generics, Ident, visit::{Visit, self}, Type, TypePath};
use syn::{Generics, Ident, visit::{Visit, self}, Type, TypePath, spanned::Spanned};
use std::iter;
/// Visits the ast and checks if one of the given idents is found.
......@@ -226,7 +225,10 @@ fn collect_types(
}
}).collect(),
Data::Union(_) => return Err(Error::new(Span::call_site(), "Union types are not supported.")),
Data::Union(ref data) => return Err(Error::new(
data.union_token.span(),
"Union types are not supported."
)),
};
Ok(types)
......
......@@ -47,6 +47,10 @@ impl<C: Cursor, T: Bits + ToByteSlice> Encode for BitVec<C, T> {
}
impl<C: Cursor, T: Bits + FromByteSlice> Decode for BitVec<C, T> {
fn min_encoded_len() -> usize {
<Compact<u32>>::min_encoded_len()
}
fn decode<I: Input>(input: &mut I) -> Result<Self, Error> {
<Compact<u32>>::decode(input).and_then(move |Compact(bits)| {
let bits = bits as usize;
......@@ -69,6 +73,10 @@ impl<C: Cursor, T: Bits + ToByteSlice> Encode for BitBox<C, T> {
}
impl<C: Cursor, T: Bits + FromByteSlice> Decode for BitBox<C, T> {
fn min_encoded_len() -> usize {
<BitVec<C, T>>::min_encoded_len()
}
fn decode<I: Input>(input: &mut I) -> Result<Self, Error> {
Ok(Self::from_bitslice(BitVec::<C, T>::decode(input)?.as_bitslice()))
}
......@@ -84,6 +92,7 @@ fn required_bytes<T>(bits: usize) -> usize {
mod tests {
use super::*;
use bitvec::{bitvec, cursor::BigEndian};
use crate::codec::DecodeM;
macro_rules! test_data {
($inner_type: ty) => (
......@@ -146,7 +155,7 @@ mod tests {
fn bitvec_u8() {
for v in &test_data!(u8) {
let encoded = v.encode();
assert_eq!(*v, BitVec::<BigEndian, u8>::decode(&mut &encoded[..]).unwrap());
assert_eq!(*v, BitVec::<BigEndian, u8>::decode_m(&mut &encoded[..]).unwrap());
}
}
......@@ -154,7 +163,7 @@ mod tests {
fn bitvec_u16() {
for v in &test_data!(u16) {
let encoded = v.encode();
assert_eq!(*v, BitVec::<BigEndian, u16>::decode(&mut &encoded[..]).unwrap());
assert_eq!(*v, BitVec::<BigEndian, u16>::decode_m(&mut &encoded[..]).unwrap());
}
}
......@@ -162,7 +171,7 @@ mod tests {
fn bitvec_u32() {
for v in &test_data!(u32) {
let encoded = v.encode();
assert_eq!(*v, BitVec::<BigEndian, u32>::decode(&mut &encoded[..]).unwrap());
assert_eq!(*v, BitVec::<BigEndian, u32>::decode_m(&mut &encoded[..]).unwrap());
}
}
......@@ -170,7 +179,7 @@ mod tests {
fn bitvec_u64() {
for v in &test_data!(u64) {
let encoded = v.encode();
assert_eq!(*v, BitVec::<BigEndian, u64>::decode(&mut &encoded[..]).unwrap());
assert_eq!(*v, BitVec::<BigEndian, u64>::decode_m(&mut &encoded[..]).unwrap());
}
}
......@@ -179,7 +188,7 @@ mod tests {
let data: &[u8] = &[0x69];
let slice: &BitSlice = data.into();
let encoded = slice.encode();
let decoded = BitVec::<BigEndian, u8>::decode(&mut &encoded[..]).unwrap();
let decoded = BitVec::<BigEndian, u8>::decode_m(&mut &encoded[..]).unwrap();
assert_eq!(slice, decoded.as_bitslice());
}
......@@ -188,7 +197,7 @@ mod tests {
let data: &[u8] = &[5, 10];
let bb: BitBox = data.into();
let encoded = bb.encode();
let decoded = BitBox::<BigEndian, u8>::decode(&mut &encoded[..]).unwrap();
let decoded = BitBox::<BigEndian, u8>::decode_m(&mut &encoded[..]).unwrap();
assert_eq!(bb, decoded);
}
}
......@@ -94,6 +94,10 @@ impl From<&'static str> for Error {
/// Trait that allows reading of data into a slice.
pub trait Input {
/// Require the input to contain at least `len` bytes. This allows ensuring that a valid value
/// can be constructed from the given input, making it safe to allocate memory upfront.
fn require_min_len(&mut self, len: usize) -> Result<(), Error>;
/// Read the exact number of bytes required to fill the given buffer.
///
/// Note that this function is similar to `std::io::Read::read_exact` and not
......@@ -108,8 +112,15 @@ pub trait Input {
}
}
#[cfg(not(feature = "std"))]
impl<'a> Input for &'a [u8] {
fn require_min_len(&mut self, len: usize) -> Result<(), Error> {
if self.len() < len {
return Err("Not enough data for required minimum length".into());
}
Ok(())
}
fn read(&mut self, into: &mut [u8]) -> Result<(), Error> {
if into.len() > self.len() {
return Err("Not enough data to fill buffer".into());
......@@ -129,13 +140,49 @@ impl From<std::io::Error> for Error {
}
#[cfg(feature = "std")]
impl<R: std::io::Read> Input for R {
pub struct IoReader<R: std::io::Read> {
buffer: Vec<u8>,
reader: R,
}
// TODO: either test this implementation or remove it.
#[cfg(feature = "std")]
impl<R: std::io::Read> Input for IoReader<R> {
fn require_min_len(&mut self, len: usize) -> Result<(), Error> {
if self.buffer.len() >= len {
return Ok(())
}
let filled_len = self.buffer.len();
self.buffer.resize(len, 0);
self.reader.read_exact(&mut self.buffer[filled_len..])?;
Ok(())
}
fn read(&mut self, into: &mut [u8]) -> Result<(), Error> {
(self as &mut dyn std::io::Read).read_exact(into)?;
let into_len = into.len();
let buffer_len = self.buffer.len();
into.copy_from_slice(&self.buffer[..into_len.min(buffer_len)]);
self.buffer.resize(buffer_len - into_len.min(buffer_len), 0);
if into_len > buffer_len {
self.reader.read_exact(&mut into[buffer_len..])?;
}
Ok(())
}
}
#[cfg(feature = "std")]
impl<R: std::io::Read> From<R> for IoReader<R> {
fn from(reader: R) -> Self {
IoReader {
buffer: vec![],
reader,
}
}
}
/// Trait that allows writing of data.
pub trait Output: Sized {
/// Write to the output.
......@@ -231,6 +278,10 @@ pub trait Decode: Sized {
#[doc(hidden)]
const IS_U8: IsU8 = IsU8::No;
/// The minimum length of an encoded value. This is used to avoid allocating memory based on
/// maliciously crafted input.
fn min_encoded_len() -> usize;
/// Attempt to deserialise the value from input.
fn decode<I: Input>(value: &mut I) -> Result<Self, Error>;
}
......@@ -304,6 +355,10 @@ impl<T, X> Decode for X where
T: Decode + Into<X>,