diff --git a/ssz-rs-derive/src/lib.rs b/ssz-rs-derive/src/lib.rs
index 119ad86a..aecef697 100644
--- a/ssz-rs-derive/src/lib.rs
+++ b/ssz-rs-derive/src/lib.rs
@@ -14,7 +14,7 @@
 const BYTES_PER_CHUNK: usize = 32;
 const SSZ_HELPER_ATTRIBUTE: &str = "ssz";
 
-fn derive_serialize_impl(data: &Data) -> TokenStream {
+fn derive_serialize_impl(data: &Data, helper_attr: Option<&HelperAttr>) -> TokenStream {
     match data {
         Data::Struct(ref data) => {
             let fields = match data.fields {
@@ -55,12 +55,18 @@ fn derive_serialize_impl(data: &Data) -> TokenStream {
                 let variant_name = &variant.ident;
                 match &variant.fields {
                     Fields::Unnamed(..) => {
-                        quote_spanned! { variant.span() =>
-                            Self::#variant_name(value) => {
-                                let selector = #i as u8;
-                                let selector_bytes = selector.serialize(buffer)?;
-                                let value_bytes = value.serialize(buffer)?;
-                                Ok(selector_bytes + value_bytes)
+                        if matches!(helper_attr, Some(&HelperAttr::Transparent)) {
+                            quote_spanned! { variant.span() =>
+                                Self::#variant_name(value) => value.serialize(buffer),
+                            }
+                        } else {
+                            quote_spanned! { variant.span() =>
+                                Self::#variant_name(value) => {
+                                    let selector = #i as u8;
+                                    let selector_bytes = selector.serialize(buffer)?;
+                                    let value_bytes = value.serialize(buffer)?;
+                                    Ok(selector_bytes + value_bytes)
+                                }
                             }
                         }
                     }
@@ -87,7 +93,7 @@ fn derive_serialize_impl(data: &Data) -> TokenStream {
     }
 }
 
-fn derive_deserialize_impl(data: &Data) -> TokenStream {
+fn derive_deserialize_impl(data: &Data, helper_attr: Option<&HelperAttr>) -> TokenStream {
     match data {
         Data::Struct(ref data) => {
             let fields = match data.fields {
@@ -145,40 +151,71 @@ fn derive_deserialize_impl(data: &Data) -> TokenStream {
             }
         }
         Data::Enum(ref data) => {
-            let deserialization_by_variant =
-                data.variants.iter().enumerate().map(|(i, variant)| {
+            let body = if matches!(helper_attr, Some(&HelperAttr::Transparent)) {
+                let deserialization_by_variant = data.variants.iter().rev().map(|variant| {
                     // NOTE: this is "safe" as the number of legal variants fits into `u8`
-                    let i = i as u8;
                     let variant_name = &variant.ident;
                     match &variant.fields {
                         Fields::Unnamed(inner) => {
-                            // SAFETY: index is safe because Punctuated always has a first element;
-                            // qed
+                            // SAFETY: index is safe because Punctuated always has a first
+                            // element; qed
                            let variant_type = &inner.unnamed[0];
                             quote_spanned! { variant.span() =>
-                                #i => {
-                                    // SAFETY: index is safe because encoding isn't empty; qed
-                                    let value = <#variant_type>::deserialize(&encoding[1..])?;
-                                    Ok(Self::#variant_name(value))
+                                if let Ok(value) = <#variant_type>::deserialize(encoding) {
+                                    return Ok(Self::#variant_name(value))
                                 }
                             }
                         }
-                        Fields::Unit => {
-                            quote_spanned! { variant.span() =>
-                                0 => {
-                                    if encoding.len() != 1 {
-                                        return Err(DeserializeError::AdditionalInput {
-                                            provided: encoding.len(),
-                                            expected: 1,
-                                        })
+                        _ => unreachable!("validated to exclude this condition"),
+                    }
+                });
+                quote! {
+                    #(#deserialization_by_variant)*
+                    Err(ssz_rs::DeserializeError::NoMatchingVariant)
+                }
+            } else {
+                let deserialization_by_variant =
+                    data.variants.iter().enumerate().map(|(i, variant)| {
+                        // NOTE: this is "safe" as the number of legal variants fits into `u8`
+                        let i = i as u8;
+                        let variant_name = &variant.ident;
+                        match &variant.fields {
+                            Fields::Unnamed(inner) => {
+                                // SAFETY: index is safe because Punctuated always has a first
+                                // element; qed
+                                let variant_type = &inner.unnamed[0];
+                                quote_spanned! { variant.span() =>
+                                    #i => {
+                                        // SAFETY: index is safe because encoding isn't empty; qed
+                                        let value = <#variant_type>::deserialize(&encoding[1..])?;
+                                        Ok(Self::#variant_name(value))
                                     }
-                                    Ok(Self::None)
-                                },
+                                }
                             }
+                            Fields::Unit => {
+                                quote_spanned! { variant.span() =>
+                                    0 => {
+                                        if encoding.len() != 1 {
+                                            return Err(DeserializeError::AdditionalInput {
+                                                provided: encoding.len(),
+                                                expected: 1,
+                                            })
+                                        }
+                                        Ok(Self::None)
+                                    },
+                                }
+                            }
+                            _ => unreachable!(),
                         }
-                        _ => unreachable!(),
+                    });
+                quote! {
+                    // SAFETY: index is safe because encoding isn't empty; qed
+                    match encoding[0] {
+                        #(#deserialization_by_variant)*
+                        b => Err(ssz_rs::DeserializeError::InvalidByte(b)),
                     }
-                });
+                }
+            };
 
             quote! {
                 fn deserialize(encoding: &[u8]) -> Result<Self, ssz_rs::DeserializeError> {
@@ -189,11 +226,7 @@ fn derive_deserialize_impl(data: &Data) -> TokenStream {
                         });
                     }
 
-                    // SAFETY: index is safe because encoding isn't empty; qed
-                    match encoding[0] {
-                        #(#deserialization_by_variant)*
-                        b => Err(ssz_rs::DeserializeError::InvalidByte(b)),
-                    }
+                    #body
                 }
             }
         }
@@ -223,6 +256,10 @@ fn derive_variable_size_impl(data: &Data) -> TokenStream {
             }
         }
         Data::Enum(..) => {
+            // NOTE: interaction with `transparent` attribute:
+            // no code in this repo should ever directly call this generated method
+            // on the "wrapping enum" used with `transparent`
+            // thus, we can simply provide the existing implementation
             quote! { true }
         }
         Data::Union(..) => unreachable!("data was already validated to exclude union types"),
@@ -255,6 +292,10 @@ fn derive_size_hint_impl(data: &Data) -> TokenStream {
             }
         }
         Data::Enum(..) => {
+            // NOTE: interaction with `transparent` attribute:
+            // no code in this repo should ever directly call this generated method
+            // on the "wrapping enum" used with `transparent`
+            // thus, we can simply provide the existing implementation
             quote! { 0 }
         }
         Data::Union(..) => unreachable!("data was already validated to exclude union types"),
@@ -466,9 +507,10 @@ fn derive_serializable_impl(
     data: &Data,
     name: &Ident,
     generics: &Generics,
+    helper_attr: Option<&HelperAttr>,
 ) -> proc_macro2::TokenStream {
-    let serialize_impl = derive_serialize_impl(data);
-    let deserialize_impl = derive_deserialize_impl(data);
+    let serialize_impl = derive_serialize_impl(data, helper_attr);
+    let deserialize_impl = derive_deserialize_impl(data, helper_attr);
     let is_variable_size_impl = derive_variable_size_impl(data);
     let size_hint_impl = derive_size_hint_impl(data);
 
@@ -502,7 +544,7 @@ fn derive_simple_serialize_impl(name: &Ident, generics: &Generics) -> proc_macro
     }
 }
 
-#[derive(Debug)]
+#[derive(Debug, Clone, Copy)]
 enum HelperAttr {
     Transparent,
 }
@@ -552,11 +594,12 @@ pub fn derive_serializable(input: proc_macro::TokenStream) -> proc_macro::TokenS
     let data = &input.data;
     let helper_attrs = extract_helper_attrs(&input);
     validate_derive_input(data, &helper_attrs);
+    let helper_attr = helper_attrs.first();
 
     let name = &input.ident;
     let generics = &input.generics;
 
-    let expansion = derive_serializable_impl(data, name, generics);
+    let expansion = derive_serializable_impl(data, name, generics, helper_attr);
 
     proc_macro::TokenStream::from(expansion)
 }
@@ -589,7 +632,7 @@ pub fn derive(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
     let generics = &input.generics;
 
     let merkleization_impl = derive_merkleization_impl(data, name, generics, helper_attr);
-    let serializable_impl = derive_serializable_impl(data, name, generics);
+    let serializable_impl = derive_serializable_impl(data, name, generics, helper_attr);
     let simple_serialize_impl = derive_simple_serialize_impl(name, generics);
 
diff --git a/ssz-rs-derive/tests/mod.rs b/ssz-rs-derive/tests/mod.rs
index 94f5c86f..93adcd25 100644
--- a/ssz-rs-derive/tests/mod.rs
+++ b/ssz-rs-derive/tests/mod.rs
@@ -1,13 +1,13 @@
 use ssz_rs::prelude::*;
 use ssz_rs_derive::SimpleSerialize;
 
-#[derive(Debug, SimpleSerialize)]
+#[derive(Debug, SimpleSerialize, PartialEq, Eq)]
 struct Foo {
     a: u8,
     b: u32,
 }
 
-#[derive(Debug, SimpleSerialize)]
+#[derive(Debug, SimpleSerialize, PartialEq, Eq)]
 #[ssz(transparent)]
 enum Bar {
     A(u8),
@@ -22,6 +22,12 @@ fn test_transparent_helper() {
     let mut f = Foo { a: 23, b: 445 };
     let f_root = f.hash_tree_root().unwrap();
     let mut bar = Bar::B(f);
+
+    let mut buf = vec![];
+    let _ = bar.serialize(&mut buf).unwrap();
+    let recovered_bar = Bar::deserialize(&buf).unwrap();
+    assert_eq!(bar, recovered_bar);
+
     let bar_root = bar.hash_tree_root().unwrap();
     assert_eq!(f_root, bar_root);
 }
diff --git a/ssz-rs/src/de.rs b/ssz-rs/src/de.rs
index 3a72809d..6b721dc1 100644
--- a/ssz-rs/src/de.rs
+++ b/ssz-rs/src/de.rs
@@ -24,6 +24,9 @@ pub enum DeserializeError {
     OffsetNotIncreasing { start: usize, end: usize },
     /// An offset was absent when expected.
     MissingOffset,
+    /// No corresponding variant of the requested enum was present. (refer to `transparent`
+    /// attribute of `ssz-rs-derive` macro)
+    NoMatchingVariant,
 }
 
 impl From for DeserializeError {
@@ -52,6 +55,7 @@ impl Display for DeserializeError {
             DeserializeError::InvalidOffsetsLength(len) => write!(f, "the offsets length provided {len} is not a multiple of the size per length offset {BYTES_PER_LENGTH_OFFSET} bytes"),
             DeserializeError::OffsetNotIncreasing { start, end } => write!(f, "invalid offset points to byte {end} before byte {start}"),
             DeserializeError::MissingOffset => write!(f, "an offset was missing when deserializing a variable-sized type"),
+            DeserializeError::NoMatchingVariant => write!(f, "no corresponding variant of the requested enum was present"),
         }
     }
 }