Unverified Commit 46c19554 authored by Bastian Köcher's avatar Bastian Köcher Committed by GitHub
Browse files

Adds support for custom where bounds (#263)



* Adds support for custom where bounds

The user can now specify a custom where bound when using the derive
macros:

- `#[codec(encode_bound(T: Encode))]` for `Encode`
- `#[codec(decode_bound(T: Encode))]` for `Decode`

If nothing is specified (`encode_bound()`) the where bounds will be empty.

* Apply suggestions from code review

* Apply suggestions from code review

Co-authored-by: default avatarGavin Wood <gavin@parity.io>
parent ab18de24
Pipeline #132931 passed with stages
in 21 minutes and 41 seconds
......@@ -4,6 +4,13 @@ All notable changes to this crate are documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this crate adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
## [2.1.0] - 2021-04-06
### Fix
- Add support for custom where bounds `codec(encode_bound(T: Encode))` and `codec(decode_bound(T: Decode))` when
deriving the traits. Pr #262
- Switch to const generics for array implementations. Pr #261
## [2.0.1] - 2021-02-26
### Fix
......
[package]
name = "parity-scale-codec"
description = "SCALE - Simple Concatenating Aggregated Little Endians"
version = "2.0.1"
version = "2.1.0"
authors = ["Parity Technologies <admin@parity.io>"]
license = "Apache-2.0"
repository = "https://github.com/paritytech/parity-scale-codec"
......
[package]
name = "parity-scale-codec-derive"
description = "Serialization and deserialization derive macro for Parity SCALE Codec"
version = "2.0.1"
version = "2.1.0"
authors = ["Parity Technologies <admin@parity.io>"]
license = "Apache-2.0"
edition = "2018"
......
......@@ -80,6 +80,8 @@ fn wrap_with_dummy_const(impl_block: proc_macro2::TokenStream) -> proc_macro::To
/// type must implement `parity_scale_codec::EncodeAsRef<'_, $FieldType>` with $FieldType the
/// type of the field with the attribute. This is intended to be used for types implementing
/// `HasCompact` as shown in the example.
/// * `#[codec(encode_bound(T: Encode))]`: a custom where bound that will be used when deriving the `Encode` trait.
/// * `#[codec(decode_bound(T: Encode))]`: a custom where bound that will be used when deriving the `Decode` trait.
///
/// ```
/// # use parity_scale_codec_derive::Encode;
......@@ -140,7 +142,9 @@ pub fn encode_derive(input: proc_macro::TokenStream) -> proc_macro::TokenStream
return e.to_compile_error().into();
}
if let Err(e) = trait_bounds::add(
if let Some(custom_bound) = utils::custom_encode_trait_bound(&input.attrs) {
input.generics.make_where_clause().predicates.extend(custom_bound);
} else if let Err(e) = trait_bounds::add(
&input.ident,
&mut input.generics,
&input.data,
......@@ -181,7 +185,9 @@ pub fn decode_derive(input: proc_macro::TokenStream) -> proc_macro::TokenStream
return e.to_compile_error().into();
}
if let Err(e) = trait_bounds::add(
if let Some(custom_bound) = utils::custom_decode_trait_bound(&input.attrs) {
input.generics.make_where_clause().predicates.extend(custom_bound);
} else if let Err(e) = trait_bounds::add(
&input.ident,
&mut input.generics,
&input.data,
......
......@@ -23,24 +23,15 @@ use proc_macro2::TokenStream;
use syn::{
spanned::Spanned,
Meta, NestedMeta, Lit, Attribute, Variant, Field, DeriveInput, Fields, Data, FieldsUnnamed,
FieldsNamed, MetaNameValue
FieldsNamed, MetaNameValue, punctuated::Punctuated, token, parse::Parse,
};
fn find_meta_item<'a, F, R, I>(itr: I, pred: F) -> Option<R> where
F: FnMut(&NestedMeta) -> Option<R> + Clone,
I: Iterator<Item=&'a Attribute>
fn find_meta_item<'a, F, R, I, M>(mut itr: I, mut pred: F) -> Option<R> where
F: FnMut(M) -> Option<R> + Clone,
I: Iterator<Item=&'a Attribute>,
M: Parse,
{
itr.filter_map(|attr| {
if attr.path.is_ident("codec") {
if let Meta::List(ref meta_list) = attr.parse_meta()
.expect("Internal error, parse_meta must have been checked")
{
return meta_list.nested.iter().filter_map(pred.clone()).next();
}
}
None
}).next()
itr.find_map(|attr| attr.path.is_ident("codec").then(|| pred(attr.parse_args().ok()?)).flatten())
}
/// Look for a `#[scale(index = $int)]` attribute on a variant. If no attribute
......@@ -128,6 +119,48 @@ pub fn has_dumb_trait_bound(attrs: &[Attribute]) -> bool {
}).is_some()
}
/// Trait bounds.
pub type TraitBounds = Punctuated<syn::WherePredicate, token::Comma>;
/// Parse `name(T: Bound, N: Bound)` as a custom trait bound.
struct CustomTraitBound<N> {
_name: N,
_paren_token: token::Paren,
bounds: TraitBounds,
}
impl<N: Parse> Parse for CustomTraitBound<N> {
fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
let content;
Ok(Self {
_name: input.parse()?,
_paren_token: syn::parenthesized!(content in input),
bounds: content.parse_terminated(syn::WherePredicate::parse)?,
})
}
}
syn::custom_keyword!(encode_bound);
syn::custom_keyword!(decode_bound);
/// Look for a `#[codec(decode_bound(T: Decode))]`in the given attributes.
///
/// If found, it should be used as trait bounds when deriving the `Decode` trait.
pub fn custom_decode_trait_bound(attrs: &[Attribute]) -> Option<TraitBounds> {
find_meta_item(attrs.iter(), |meta: CustomTraitBound<decode_bound>| {
Some(meta.bounds)
})
}
/// Look for a `#[codec(encode_bound(T: Encode))]`in the given attributes.
///
/// If found, it should be used as trait bounds when deriving the `Encode` trait.
pub fn custom_encode_trait_bound(attrs: &[Attribute]) -> Option<TraitBounds> {
find_meta_item(attrs.iter(), |meta: CustomTraitBound<encode_bound>| {
Some(meta.bounds)
})
}
/// Given a set of named fields, return an iterator of `Field` where all fields
/// marked `#[codec(skip)]` are filtered out.
pub fn filter_skip_named<'a>(fields: &'a syn::FieldsNamed) -> impl Iterator<Item=&Field> + 'a {
......@@ -252,9 +285,15 @@ fn check_variant_attribute(attr: &Attribute) -> syn::Result<()> {
// Only `#[codec(dumb_trait_bound)]` is accepted as top attribute
fn check_top_attribute(attr: &Attribute) -> syn::Result<()> {
let top_error = "Invalid attribute only `#[codec(dumb_trait_bound)]` is accepted as top \
attribute";
let top_error =
"Invalid attribute only `#[codec(dumb_trait_bound)]`, `#[codec(encode_bound(T: Encode))]` or \
`#[codec(decode_bound(T: Decode))]` are accepted as top attribute";
if attr.path.is_ident("codec") {
if attr.parse_args::<CustomTraitBound<encode_bound>>().is_ok() {
return Ok(())
} else if attr.parse_args::<CustomTraitBound<decode_bound>>().is_ok() {
return Ok(())
} else {
match attr.parse_meta()? {
Meta::List(ref meta_list) if meta_list.nested.len() == 1 => {
match meta_list.nested.first().expect("Just checked that there is one item; qed") {
......@@ -264,7 +303,8 @@ fn check_top_attribute(attr: &Attribute) -> syn::Result<()> {
elt @ _ => Err(syn::Error::new(elt.span(), top_error)),
}
},
meta @ _ => Err(syn::Error::new(meta.span(), top_error)),
_ => Err(syn::Error::new(attr.span(), top_error)),
}
}
} else {
Ok(())
......
......@@ -558,3 +558,31 @@ fn weird_derive() {
fn output_trait_object() {
let _: Box<dyn Output>;
}
#[test]
fn custom_trait_bound() {
#[derive(Encode, Decode)]
#[codec(encode_bound(N: Encode, T: Default))]
#[codec(decode_bound(N: Decode, T: Default))]
struct Something<T, N> {
hello: Hello<T>,
val: N,
}
#[derive(Encode, Decode)]
#[codec(encode_bound())]
#[codec(decode_bound())]
struct Hello<T> {
_phantom: std::marker::PhantomData<T>,
}
#[derive(Default)]
struct NotEncode;
let encoded = Something::<NotEncode, u32> {
hello: Hello { _phantom: Default::default() },
val: 32u32,
}.encode();
Something::<NotEncode, u32>::decode(&mut &encoded[..]).unwrap();
}
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment