Commit 215cda34 authored by Andrew Jones's avatar Andrew Jones Committed by Bastian Köcher
Browse files

Derive `CompactAs` for single field structs (#97)

* Basic CompactAs derive

* Support single field named structs

* Extract function to wrap impl in dummy const

* Support single non-skipped field

* Restore original include_parity_scale_codec_crate function

* Add test for Compact<T> codec
parent 23a292df
Pipeline #40249 passed with stages
in 12 minutes and 14 seconds
......@@ -133,22 +133,11 @@ fn encode_fields<F>(
fn try_impl_encode_single_field_optimisation(data: &Data) -> Option<TokenStream> {
let closure = &quote!(f);
fn filter_skip_named<'a>(fields: &'a syn::FieldsNamed) -> impl Iterator<Item=&Field> + 'a {
fields.named.iter()
.filter(|f| utils::get_skip(&f.attrs).is_none())
}
fn filter_skip_unnamed<'a>(fields: &'a syn::FieldsUnnamed) -> impl Iterator<Item=(usize, &Field)> + 'a {
fields.unnamed.iter()
.enumerate()
.filter(|(_, f)| utils::get_skip(&f.attrs).is_none())
}
let optimisation = match *data {
Data::Struct(ref data) => {
match data.fields {
Fields::Named(ref fields) if filter_skip_named(fields).count() == 1 => {
let field = filter_skip_named(fields).next().unwrap();
Fields::Named(ref fields) if utils::filter_skip_named(fields).count() == 1 => {
let field = utils::filter_skip_named(fields).next().unwrap();
let name = &field.ident;
Some(encode_single_field(
closure,
......@@ -156,8 +145,8 @@ fn try_impl_encode_single_field_optimisation(data: &Data) -> Option<TokenStream>
quote!(&self.#name)
))
},
Fields::Unnamed(ref fields) if filter_skip_unnamed(fields).count() == 1 => {
let (id, field) = filter_skip_unnamed(fields).next().unwrap();
Fields::Unnamed(ref fields) if utils::filter_skip_unnamed(fields).count() == 1 => {
let (id, field) = utils::filter_skip_unnamed(fields).next().unwrap();
let id = syn::Index::from(id);
Some(encode_single_field(
......
......@@ -26,7 +26,7 @@ extern crate quote;
use proc_macro::TokenStream;
use proc_macro2::Span;
use syn::{DeriveInput, Ident, parse::Error};
use syn::{Data, Field, Fields, DeriveInput, Ident, parse::Error, spanned::Spanned};
use proc_macro_crate::crate_name;
use std::env;
......@@ -52,6 +52,26 @@ fn include_parity_scale_codec_crate() -> proc_macro2::TokenStream {
}
}
/// Wraps the impl block in a "dummy const"
fn wrap_with_dummy_const(input: &DeriveInput, prefix: &str, impl_block: proc_macro2::TokenStream) -> TokenStream {
let parity_codec_crate = include_parity_scale_codec_crate();
let mut new_name = prefix.to_string();
new_name.push_str(input.ident.to_string().trim_start_matches("r#"));
let dummy_const = Ident::new(&new_name, Span::call_site());
let generated = quote! {
#[allow(non_upper_case_globals, unused_attributes, unused_qualifications)]
const #dummy_const: () = {
#[allow(unknown_lints)]
#[cfg_attr(feature = "cargo-clippy", allow(useless_attribute))]
#[allow(rust_2018_idioms)]
#parity_codec_crate
#impl_block
};
};
generated.into()
}
#[proc_macro_derive(Encode, attributes(codec))]
pub fn encode_derive(input: TokenStream) -> TokenStream {
let mut input: DeriveInput = match syn::parse(input) {
......@@ -84,23 +104,7 @@ pub fn encode_derive(input: TokenStream) -> TokenStream {
}
};
let mut new_name = "_IMPL_ENCODE_FOR_".to_string();
new_name.push_str(name.to_string().trim_start_matches("r#"));
let dummy_const = Ident::new(&new_name, Span::call_site());
let parity_codec_crate = include_parity_scale_codec_crate();
let generated = quote! {
#[allow(non_upper_case_globals, unused_attributes, unused_qualifications)]
const #dummy_const: () = {
#[allow(unknown_lints)]
#[cfg_attr(feature = "cargo-clippy", allow(useless_attribute))]
#[allow(rust_2018_idioms)]
#parity_codec_crate
#impl_block
};
};
generated.into()
wrap_with_dummy_const(&input, "_IMPL_ENCODE_FOR_", impl_block)
}
#[proc_macro_derive(Decode, attributes(codec))]
......@@ -140,21 +144,96 @@ pub fn decode_derive(input: TokenStream) -> TokenStream {
}
};
let mut new_name = "_IMPL_DECODE_FOR_".to_string();
new_name.push_str(name.to_string().trim_start_matches("r#"));
let dummy_const = Ident::new(&new_name, Span::call_site());
let parity_codec_crate = include_parity_scale_codec_crate();
wrap_with_dummy_const(&input, "_IMPL_DECODE_FOR_", impl_block)
}
let generated = quote! {
#[allow(non_upper_case_globals, unused_attributes, unused_qualifications)]
const #dummy_const: () = {
#[allow(unknown_lints)]
#[cfg_attr(feature = "cargo-clippy", allow(useless_attribute))]
#[allow(rust_2018_idioms)]
#parity_codec_crate
#impl_block
};
#[proc_macro_derive(CompactAs, attributes(codec))]
pub fn compact_as_derive(input: TokenStream) -> TokenStream {
let mut input: DeriveInput = match syn::parse(input) {
Ok(input) => input,
Err(e) => return e.to_compile_error().into(),
};
generated.into()
if let Err(e) = trait_bounds::add(
&input.ident,
&mut input.generics,
&input.data,
parse_quote!(_parity_scale_codec::CompactAs),
None,
) {
return e.to_compile_error().into();
}
let name = &input.ident;
let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
fn val_or_default(field: &Field) -> proc_macro2::TokenStream {
let skip = utils::get_skip(&field.attrs).is_some();
if skip {
quote_spanned!(field.span()=> Default::default())
} else {
quote_spanned!(field.span()=> x)
}
}
let call_site = Span::call_site();
let (inner_ty, inner_field, constructor) = match input.data {
Data::Struct(ref data) => {
match data.fields {
Fields::Named(ref fields) if utils::filter_skip_named(fields).count() == 1 => {
let recurse = fields.named.iter().map(|f| {
let name_ident = &f.ident;
let val_or_default = val_or_default(&f);
quote_spanned!(f.span()=> #name_ident: #val_or_default)
});
let field = utils::filter_skip_named(fields).next().expect("Exactly one field");
let field_name = &field.ident;
let constructor = quote_spanned!(call_site=> #name { #( #recurse, )* });
(&field.ty, quote!(&self.#field_name), constructor)
},
Fields::Unnamed(ref fields) if utils::filter_skip_unnamed(fields).count() == 1 => {
let recurse = fields.unnamed.iter().enumerate().map(|(_, f) | {
let val_or_default = val_or_default(&f);
quote_spanned!(f.span()=> #val_or_default)
});
let (id, field) = utils::filter_skip_unnamed(fields).next().expect("Exactly one field");
let id = syn::Index::from(id);
let constructor = quote_spanned!(call_site=> #name(#( #recurse, )*));
(&field.ty, quote!(&self.#id), constructor)
},
_ => {
return Error::new(
Span::call_site(),
"Only structs with a single non-skipped field can derive CompactAs"
).to_compile_error().into();
},
}
},
_ => {
return Error::new(
Span::call_site(),
"Only structs can derive CompactAs"
).to_compile_error().into();
},
};
let impl_block = quote! {
impl #impl_generics _parity_scale_codec::CompactAs for #name #ty_generics #where_clause {
type As = #inner_ty;
fn encode_as(&self) -> &#inner_ty {
#inner_field
}
fn decode_from(x: #inner_ty) -> #name #ty_generics {
#constructor
}
}
impl #impl_generics From<_parity_scale_codec::Compact<#name #ty_generics>> for #name #ty_generics #where_clause {
fn from(x: _parity_scale_codec::Compact<#name #ty_generics>) -> #name #ty_generics {
x.0
}
}
};
wrap_with_dummy_const(&input, "_IMPL_COMPACTAS_FOR_", impl_block)
}
......@@ -102,3 +102,14 @@ pub fn get_skip(attrs: &Vec<Attribute>) -> Option<Span> {
None
})
}
pub fn filter_skip_named<'a>(fields: &'a syn::FieldsNamed) -> impl Iterator<Item=&Field> + 'a {
fields.named.iter()
.filter(|f| get_skip(&f.attrs).is_none())
}
pub fn filter_skip_unnamed<'a>(fields: &'a syn::FieldsUnnamed) -> impl Iterator<Item=(usize, &Field)> + 'a {
fields.unnamed.iter()
.enumerate()
.filter(|(_, f)| get_skip(&f.attrs).is_none())
}
#[macro_use]
extern crate parity_scale_codec_derive;
use parity_scale_codec::{Encode, HasCompact, Decode};
use parity_scale_codec::{Compact, Encode, HasCompact, Decode};
use serde_derive::{Serialize, Deserialize};
#[derive(Debug, PartialEq, Encode, Decode)]
struct S {
x: u32,
}
#[derive(Debug, PartialEq, Encode, Decode)]
#[cfg_attr(feature = "std", derive(Serialize, Deserialize))]
#[derive(Debug, PartialEq, Eq, Clone, Copy, Encode, Decode, CompactAs)]
struct SSkip {
#[codec(skip)]
s1: u32,
......@@ -29,15 +31,30 @@ struct Sh<T: HasCompact> {
x: T,
}
#[derive(Debug, PartialEq, Encode, Decode)]
#[cfg_attr(feature = "std", derive(Serialize, Deserialize))]
#[derive(Debug, PartialEq, Eq, Clone, Copy, Encode, Decode, CompactAs)]
struct U(u32);
#[derive(Debug, PartialEq, Encode, Decode)]
#[cfg_attr(feature = "std", derive(Serialize, Deserialize))]
#[derive(Debug, PartialEq, Eq, Clone, Copy, Encode, Decode, CompactAs)]
struct U2 { a: u64 }
#[cfg_attr(feature = "std", derive(Serialize, Deserialize))]
#[derive(Debug, PartialEq, Eq, Clone, Copy, Encode, Decode, CompactAs)]
struct USkip(#[codec(skip)] u32, u32, #[codec(skip)] u32);
#[derive(Debug, PartialEq, Encode, Decode)]
struct Uc(#[codec(compact)] u32);
#[derive(Debug, PartialEq, Clone, Encode, Decode)]
struct Ucas(#[codec(compact)] U);
#[derive(Debug, PartialEq, Clone, Encode, Decode)]
struct USkipcas(#[codec(compact)] USkip);
#[derive(Debug, PartialEq, Clone, Encode, Decode)]
struct SSkipcas(#[codec(compact)] SSkip);
#[derive(Debug, PartialEq, Encode, Decode)]
struct Uh<T: HasCompact>(#[codec(encoded_as = "<T as HasCompact>::Type")] T);
......@@ -51,6 +68,10 @@ fn test_encoding() {
let u = U(x);
let u_skip = USkip(Default::default(), x, Default::default());
let uc = Uc(x);
let ucom = Compact(u);
let ucas = Ucas(u);
let u_skip_cas = USkipcas(u_skip);
let s_skip_cas = SSkipcas(s_skip);
let uh = Uh(x);
let mut s_encoded: &[u8] = &[3, 0, 0, 0];
......@@ -60,6 +81,10 @@ fn test_encoding() {
let mut u_encoded: &[u8] = &[3, 0, 0, 0];
let mut u_skip_encoded: &[u8] = &[3, 0, 0, 0];
let mut uc_encoded: &[u8] = &[12];
let mut ucom_encoded: &[u8] = &[12];
let mut ucas_encoded: &[u8] = &[12];
let mut u_skip_cas_encoded: &[u8] = &[12];
let mut s_skip_cas_encoded: &[u8] = &[12];
let mut uh_encoded: &[u8] = &[12];
assert_eq!(s.encode(), s_encoded);
......@@ -69,6 +94,10 @@ fn test_encoding() {
assert_eq!(u.encode(), u_encoded);
assert_eq!(u_skip.encode(), u_skip_encoded);
assert_eq!(uc.encode(), uc_encoded);
assert_eq!(ucom.encode(), ucom_encoded);
assert_eq!(ucas.encode(), ucas_encoded);
assert_eq!(u_skip_cas.encode(), u_skip_cas_encoded);
assert_eq!(s_skip_cas.encode(), s_skip_cas_encoded);
assert_eq!(uh.encode(), uh_encoded);
assert_eq!(s, S::decode(&mut s_encoded).unwrap());
......@@ -78,5 +107,9 @@ fn test_encoding() {
assert_eq!(u, U::decode(&mut u_encoded).unwrap());
assert_eq!(u_skip, USkip::decode(&mut u_skip_encoded).unwrap());
assert_eq!(uc, Uc::decode(&mut uc_encoded).unwrap());
assert_eq!(ucom, <Compact::<U>>::decode(&mut ucom_encoded).unwrap());
assert_eq!(ucas, Ucas::decode(&mut ucas_encoded).unwrap());
assert_eq!(u_skip_cas, USkipcas::decode(&mut u_skip_cas_encoded).unwrap());
assert_eq!(s_skip_cas, SSkipcas::decode(&mut s_skip_cas_encoded).unwrap());
assert_eq!(uh, Uh::decode(&mut uh_encoded).unwrap());
}
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment