Merge pull request #117 from ralexstokes/derive-ser-transparent
extend `transparent` proc macro attr behavior to Serialize and Deserialize
ralexstokes authored Nov 8, 2023
2 parents 5f1ec83 + 5df033e commit db3bca5
Showing 3 changed files with 94 additions and 41 deletions.
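
Note (illustrative, not part of the diff): based on the `Foo`/`Bar` types from the updated test below, the intended behavior of `#[ssz(transparent)]` after this change can be sketched roughly as follows — the wrapper enum encodes as its inner value alone, with no selector byte, and round-trips through the variant-probing deserializer.

use ssz_rs::prelude::*;
use ssz_rs_derive::SimpleSerialize;

#[derive(Debug, SimpleSerialize, PartialEq, Eq)]
struct Foo {
    a: u8,
    b: u32,
}

#[derive(Debug, SimpleSerialize, PartialEq, Eq)]
#[ssz(transparent)]
enum Bar {
    A(u8),
    B(Foo),
}

fn main() {
    let foo = Foo { a: 23, b: 445 };

    // Serialize the inner value on its own...
    let mut inner = vec![];
    foo.serialize(&mut inner).unwrap();

    // ...and the transparent wrapper around it.
    let bar = Bar::B(foo);
    let mut wrapped = vec![];
    bar.serialize(&mut wrapped).unwrap();

    // With `#[ssz(transparent)]`, no selector byte is prepended,
    // so the two encodings should match.
    assert_eq!(inner, wrapped);

    // Deserialization probes the variants (in reverse declaration order)
    // until one parses successfully.
    let recovered = Bar::deserialize(&wrapped).unwrap();
    assert_eq!(recovered, bar);
}
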
121 changes: 82 additions & 39 deletions ssz-rs-derive/src/lib.rs
@@ -14,7 +14,7 @@ const BYTES_PER_CHUNK: usize = 32;

const SSZ_HELPER_ATTRIBUTE: &str = "ssz";

fn derive_serialize_impl(data: &Data) -> TokenStream {
fn derive_serialize_impl(data: &Data, helper_attr: Option<&HelperAttr>) -> TokenStream {
match data {
Data::Struct(ref data) => {
let fields = match data.fields {
@@ -55,12 +55,18 @@ fn derive_serialize_impl(data: &Data) -> TokenStream {
let variant_name = &variant.ident;
match &variant.fields {
Fields::Unnamed(..) => {
quote_spanned! { variant.span() =>
Self::#variant_name(value) => {
let selector = #i as u8;
let selector_bytes = selector.serialize(buffer)?;
let value_bytes = value.serialize(buffer)?;
Ok(selector_bytes + value_bytes)
if matches!(helper_attr, Some(&HelperAttr::Transparent)) {
quote_spanned! { variant.span() =>
Self::#variant_name(value) => value.serialize(buffer),
}
} else {
quote_spanned! { variant.span() =>
Self::#variant_name(value) => {
let selector = #i as u8;
let selector_bytes = selector.serialize(buffer)?;
let value_bytes = value.serialize(buffer)?;
Ok(selector_bytes + value_bytes)
}
}
}
}
@@ -87,7 +93,7 @@ fn derive_serialize_impl(data: &Data) -> TokenStream {
}
}

fn derive_deserialize_impl(data: &Data) -> TokenStream {
fn derive_deserialize_impl(data: &Data, helper_attr: Option<&HelperAttr>) -> TokenStream {
match data {
Data::Struct(ref data) => {
let fields = match data.fields {
@@ -145,40 +151,71 @@ fn derive_deserialize_impl(data: &Data) -> TokenStream {
}
}
Data::Enum(ref data) => {
let deserialization_by_variant =
data.variants.iter().enumerate().map(|(i, variant)| {
let body = if matches!(helper_attr, Some(&HelperAttr::Transparent)) {
let deserialization_by_variant = data.variants.iter().rev().map(|variant| {
// NOTE: this is "safe" as the number of legal variants fits into `u8`
let i = i as u8;
let variant_name = &variant.ident;
match &variant.fields {
Fields::Unnamed(inner) => {
// SAFETY: index is safe because Punctuated always has a first element;
// qed
// SAFETY: index is safe because Punctuated always has a first
// element; qed
let variant_type = &inner.unnamed[0];
quote_spanned! { variant.span() =>
#i => {
// SAFETY: index is safe because encoding isn't empty; qed
let value = <#variant_type>::deserialize(&encoding[1..])?;
Ok(Self::#variant_name(value))
if let Ok(value) = <#variant_type>::deserialize(encoding) {
return Ok(Self::#variant_name(value))
}
}
}
Fields::Unit => {
quote_spanned! { variant.span() =>
0 => {
if encoding.len() != 1 {
return Err(DeserializeError::AdditionalInput {
provided: encoding.len(),
expected: 1,
})
_ => unreachable!("validated to exclude this condition"),
}
});
quote! {
#(#deserialization_by_variant)*
Err(ssz_rs::DeserializeError::NoMatchingVariant)
}
} else {
let deserialization_by_variant =
data.variants.iter().enumerate().map(|(i, variant)| {
// NOTE: this is "safe" as the number of legal variants fits into `u8`
let i = i as u8;
let variant_name = &variant.ident;
match &variant.fields {
Fields::Unnamed(inner) => {
// SAFETY: index is safe because Punctuated always has a first
// element; qed
let variant_type = &inner.unnamed[0];
quote_spanned! { variant.span() =>
#i => {
// SAFETY: index is safe because encoding isn't empty; qed
let value = <#variant_type>::deserialize(&encoding[1..])?;
Ok(Self::#variant_name(value))
}
Ok(Self::None)
},
}
}
Fields::Unit => {
quote_spanned! { variant.span() =>
0 => {
if encoding.len() != 1 {
return Err(DeserializeError::AdditionalInput {
provided: encoding.len(),
expected: 1,
})
}
Ok(Self::None)
},
}
}
_ => unreachable!(),
}
_ => unreachable!(),
});
quote! {
// SAFETY: index is safe because encoding isn't empty; qed
match encoding[0] {
#(#deserialization_by_variant)*
b => Err(ssz_rs::DeserializeError::InvalidByte(b)),
}
});
}
};

quote! {
fn deserialize(encoding: &[u8]) -> Result<Self, ssz_rs::DeserializeError> {
Expand All @@ -189,11 +226,7 @@ fn derive_deserialize_impl(data: &Data) -> TokenStream {
});
}

// SAFETY: index is safe because encoding isn't empty; qed
match encoding[0] {
#(#deserialization_by_variant)*
b => Err(ssz_rs::DeserializeError::InvalidByte(b)),
}
#body
}
}
}
@@ -223,6 +256,10 @@ fn derive_variable_size_impl(data: &Data) -> TokenStream {
}
}
Data::Enum(..) => {
// NOTE: interaction with `transparent` attribute:
// no code in this repo should ever directly call this generated method
// on the "wrapping enum" used with `transparent`
// thus, we can simply provide the existing implementation
quote! { true }
}
Data::Union(..) => unreachable!("data was already validated to exclude union types"),
@@ -255,6 +292,10 @@ fn derive_size_hint_impl(data: &Data) -> TokenStream {
}
}
Data::Enum(..) => {
// NOTE: interaction with `transparent` attribute:
// no code in this repo should ever directly call this generated method
// on the "wrapping enum" used with `transparent`
// thus, we can simply provide the existing implementation
quote! { 0 }
}
Data::Union(..) => unreachable!("data was already validated to exclude union types"),
@@ -466,9 +507,10 @@ fn derive_serializable_impl(
data: &Data,
name: &Ident,
generics: &Generics,
helper_attr: Option<&HelperAttr>,
) -> proc_macro2::TokenStream {
let serialize_impl = derive_serialize_impl(data);
let deserialize_impl = derive_deserialize_impl(data);
let serialize_impl = derive_serialize_impl(data, helper_attr);
let deserialize_impl = derive_deserialize_impl(data, helper_attr);
let is_variable_size_impl = derive_variable_size_impl(data);
let size_hint_impl = derive_size_hint_impl(data);

@@ -502,7 +544,7 @@ fn derive_simple_serialize_impl(name: &Ident, generics: &Generics) -> proc_macro
}
}

#[derive(Debug)]
#[derive(Debug, Clone, Copy)]
enum HelperAttr {
Transparent,
}
@@ -552,11 +594,12 @@ pub fn derive_serializable(input: proc_macro::TokenStream) -> proc_macro::TokenS
let data = &input.data;
let helper_attrs = extract_helper_attrs(&input);
validate_derive_input(data, &helper_attrs);
let helper_attr = helper_attrs.first();

let name = &input.ident;
let generics = &input.generics;

let expansion = derive_serializable_impl(data, name, generics);
let expansion = derive_serializable_impl(data, name, generics, helper_attr);
proc_macro::TokenStream::from(expansion)
}

@@ -589,7 +632,7 @@ pub fn derive(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
let generics = &input.generics;
let merkleization_impl = derive_merkleization_impl(data, name, generics, helper_attr);

let serializable_impl = derive_serializable_impl(data, name, generics);
let serializable_impl = derive_serializable_impl(data, name, generics, helper_attr);

let simple_serialize_impl = derive_simple_serialize_impl(name, generics);

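
Note (illustrative sketch, not macro output copied from this commit): for a transparent enum such as the `Bar { A(u8), B(Foo) }` used in the tests, the deserialize body produced by the new `transparent` branch above amounts to probing each variant's inner type in reverse declaration order and falling back to the new error:

// Roughly what the transparent deserialize branch expands to for `Bar` (illustrative only).
fn deserialize(encoding: &[u8]) -> Result<Bar, ssz_rs::DeserializeError> {
    // Variants are tried in reverse declaration order: `B(Foo)` first, then `A(u8)`.
    if let Ok(value) = <Foo>::deserialize(encoding) {
        return Ok(Bar::B(value));
    }
    if let Ok(value) = <u8>::deserialize(encoding) {
        return Ok(Bar::A(value));
    }
    Err(ssz_rs::DeserializeError::NoMatchingVariant)
}
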
10 changes: 8 additions & 2 deletions ssz-rs-derive/tests/mod.rs
@@ -1,13 +1,13 @@
use ssz_rs::prelude::*;
use ssz_rs_derive::SimpleSerialize;

#[derive(Debug, SimpleSerialize)]
#[derive(Debug, SimpleSerialize, PartialEq, Eq)]
struct Foo {
a: u8,
b: u32,
}

#[derive(Debug, SimpleSerialize)]
#[derive(Debug, SimpleSerialize, PartialEq, Eq)]
#[ssz(transparent)]
enum Bar {
A(u8),
@@ -22,6 +22,12 @@ fn test_transparent_helper() {
let mut f = Foo { a: 23, b: 445 };
let f_root = f.hash_tree_root().unwrap();
let mut bar = Bar::B(f);

let mut buf = vec![];
let _ = bar.serialize(&mut buf).unwrap();
let recovered_bar = Bar::deserialize(&buf).unwrap();
assert_eq!(bar, recovered_bar);

let bar_root = bar.hash_tree_root().unwrap();
assert_eq!(f_root, bar_root);
}
4 changes: 4 additions & 0 deletions ssz-rs/src/de.rs
@@ -24,6 +24,9 @@ pub enum DeserializeError {
OffsetNotIncreasing { start: usize, end: usize },
/// An offset was absent when expected.
MissingOffset,
/// No corresponding variant of the requested enum was present. (refer to `transparent`
/// attribute of `ssz-rs-derive` macro)
NoMatchingVariant,
}

impl From<InstanceError> for DeserializeError {
@@ -52,6 +55,7 @@ impl Display for DeserializeError {
DeserializeError::InvalidOffsetsLength(len) => write!(f, "the offsets length provided {len} is not a multiple of the size per length offset {BYTES_PER_LENGTH_OFFSET} bytes"),
DeserializeError::OffsetNotIncreasing { start, end } => write!(f, "invalid offset points to byte {end} before byte {start}"),
DeserializeError::MissingOffset => write!(f, "an offset was missing when deserializing a variable-sized type"),
DeserializeError::NoMatchingVariant => write!(f, "no corresponding variant of the requested enum was present"),
}
}
}
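
Note (hedged example, reusing the hypothetical `Bar { A(u8), B(Foo) }` wrapper sketched above): the new `NoMatchingVariant` error is returned when an encoding parses as none of a transparent enum's inner types.

// 3 bytes are too long for `u8` and too short for `Foo` (1 + 4 = 5 bytes),
// so neither variant parses and the new error is surfaced.
let err = Bar::deserialize(&[1u8, 2, 3]).unwrap_err();
assert!(matches!(err, ssz_rs::DeserializeError::NoMatchingVariant));
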
