diff --git a/src/derive_util.rs b/src/derive_util.rs index 84f8278786..1bfd9c1cc4 100644 --- a/src/derive_util.rs +++ b/src/derive_util.rs @@ -139,5 +139,37 @@ mod tests { assert_eq!(try_transmute!([1u8, 1]), Some(Foo(1, true))); assert_eq!(try_transmute!([2u8, 2]), None::); assert_eq!(try_transmute!([128u8, 1]), None::); + + #[derive(TryFromBytes, Eq, PartialEq, Debug)] + #[zerocopy(validator = Baz::validate)] + #[repr(u8)] + enum Baz { + A, // 0 + B = 5, // 5 + C, // 6 + D = 1 + 1, // 2 + E, // 3 + F, // 4 + } + + impl Baz { + fn validate(&self) -> bool { + use Baz::*; + matches!(self, A | B | C | D | E) + } + } + + impl_known_layout!(Baz); + + assert_eq!(try_transmute!(0u8), Some(Baz::A)); + assert_eq!(try_transmute!(1u8), None::); + assert_eq!(try_transmute!(2u8), Some(Baz::D)); + assert_eq!(try_transmute!(3u8), Some(Baz::E)); + + assert_eq!(try_transmute!(4u8), None::); + + assert_eq!(try_transmute!(5u8), Some(Baz::B)); + assert_eq!(try_transmute!(6u8), Some(Baz::C)); + assert_eq!(try_transmute!(7u8), None::); } } diff --git a/zerocopy-derive/src/ext.rs b/zerocopy-derive/src/ext.rs index 8c482a8499..c775fc0daa 100644 --- a/zerocopy-derive/src/ext.rs +++ b/zerocopy-derive/src/ext.rs @@ -4,7 +4,7 @@ use proc_macro2::TokenStream; use quote::ToTokens; -use syn::{Data, DataEnum, DataStruct, DataUnion, Field, Index, Type}; +use syn::{Data, DataEnum, DataStruct, DataUnion, Field, Fields, Index, Type}; pub trait DataExt { /// Extract the names and types of all fields. For enums, extract the names @@ -46,6 +46,12 @@ impl DataExt for DataUnion { } } +impl DataExt for Fields { + fn fields(&self) -> Vec<(TokenStream, &Type)> { + map_fields(self) + } +} + fn map_fields<'a>( fields: impl 'a + IntoIterator, ) -> Vec<(TokenStream, &'a Type)> { diff --git a/zerocopy-derive/src/lib.rs b/zerocopy-derive/src/lib.rs index daa0580a1e..e84d74ff9c 100644 --- a/zerocopy-derive/src/lib.rs +++ b/zerocopy-derive/src/lib.rs @@ -31,8 +31,8 @@ use { proc_macro2::{Span, TokenStream}, quote::quote, syn::{ - parse_quote, spanned::Spanned as _, Data, DataEnum, DataStruct, DataUnion, DeriveInput, - Error, Expr, ExprLit, GenericParam, Ident, Lit, + parse_quote, spanned::Spanned, Data, DataEnum, DataStruct, DataUnion, DeriveInput, Error, + Expr, ExprLit, GenericParam, Ident, Index, Lit, Type, }, }; @@ -88,9 +88,16 @@ pub fn derive_try_from_bytes(ts: proc_macro::TokenStream) -> proc_macro::TokenSt let validator = try_or_print!(parse_zerocopy_attrs(&ast.attrs).map_err(|e| vec![e])); let opts = validator.map(IsBitValidOpts::WithValidator).unwrap_or(IsBitValidOpts::NoValidator); match &ast.data { - Data::Struct(strct) => impl_block(&ast, strct, "TryFromBytes", true, None, Some(opts)), - Data::Enum(_) => { - Error::new_spanned(&ast, "TryFromBytes not supported on enum types").to_compile_error() + Data::Struct(strct) => { + impl_block(&ast, strct, "TryFromBytes", true, None, Some(opts), None) + } + // TODO: REQUIRE REPRS FOR ENUMS!!! Need to know where the discriminant + // lives in order for this to be sound. + Data::Enum(enm) => { + let reprs = try_or_print!(ENUM_FROM_BYTES_CFG.validate_reprs(&ast)); + let repr = if let [repr] = reprs.as_slice() { repr } else { unreachable!() }; + + impl_block(&ast, enm, "TryFromBytes", true, None, Some(opts), repr.type_ident()) } Data::Union(_) => { Error::new_spanned(&ast, "TryFromBytes not supported on union types").to_compile_error() @@ -153,8 +160,8 @@ const STRUCT_UNION_ALLOWED_REPR_COMBINATIONS: &[&[StructRepr]] = &[ // A struct is `FromZeroes` if: // - all fields are `FromZeroes` -fn derive_from_zeroes_struct(ast: &DeriveInput, strct: &DataStruct) -> TokenStream { - impl_block(ast, strct, "FromZeroes", true, None, None) +fn derive_from_zeroes_struct(ast: &DeriveInput, strct: &DataStruct) -> proc_macro2::TokenStream { + impl_block(ast, strct, "FromZeroes", true, None, None, None) } // An enum is `FromZeroes` if: @@ -188,21 +195,21 @@ fn derive_from_zeroes_enum(ast: &DeriveInput, enm: &DataEnum) -> TokenStream { .to_compile_error(); } - impl_block(ast, enm, "FromZeroes", true, None, None) + impl_block(ast, enm, "FromZeroes", true, None, None, None) } // Like structs, unions are `FromZeroes` if // - all fields are `FromZeroes` -fn derive_from_zeroes_union(ast: &DeriveInput, unn: &DataUnion) -> TokenStream { - impl_block(ast, unn, "FromZeroes", true, None, None) +fn derive_from_zeroes_union(ast: &DeriveInput, unn: &DataUnion) -> proc_macro2::TokenStream { + impl_block(ast, unn, "FromZeroes", true, None, None, None) } // A struct is `FromBytes` if: // - all fields are `FromBytes` -fn derive_from_bytes_struct(ast: &DeriveInput, strct: &DataStruct) -> TokenStream { - impl_block(ast, strct, "FromBytes", true, None, None) +fn derive_from_bytes_struct(ast: &DeriveInput, strct: &DataStruct) -> proc_macro2::TokenStream { + impl_block(ast, strct, "FromBytes", true, None, None, None) } // An enum is `FromBytes` if: @@ -245,7 +252,7 @@ fn derive_from_bytes_enum(ast: &DeriveInput, enm: &DataEnum) -> TokenStream { .to_compile_error(); } - impl_block(ast, enm, "FromBytes", true, None, None) + impl_block(ast, enm, "FromBytes", true, None, None, None) } #[rustfmt::skip] @@ -275,8 +282,8 @@ const ENUM_FROM_BYTES_CFG: Config = { // Like structs, unions are `FromBytes` if // - all fields are `FromBytes` -fn derive_from_bytes_union(ast: &DeriveInput, unn: &DataUnion) -> TokenStream { - impl_block(ast, unn, "FromBytes", true, None, None) +fn derive_from_bytes_union(ast: &DeriveInput, unn: &DataUnion) -> proc_macro2::TokenStream { + impl_block(ast, unn, "FromBytes", true, None, None, None) } // A struct is `AsBytes` if: @@ -310,7 +317,7 @@ fn derive_as_bytes_struct(ast: &DeriveInput, strct: &DataStruct) -> TokenStream // any padding bytes would need to come from the fields, all of which // we require to be `AsBytes` (meaning they don't have any padding). let padding_check = if is_transparent || is_packed { None } else { Some(PaddingCheck::Struct) }; - impl_block(ast, strct, "AsBytes", true, padding_check, None) + impl_block(ast, strct, "AsBytes", true, padding_check, None, None) } const STRUCT_UNION_AS_BYTES_CFG: Config = Config { @@ -333,7 +340,7 @@ fn derive_as_bytes_enum(ast: &DeriveInput, enm: &DataEnum) -> TokenStream { // We don't care what the repr is; we only care that it is one of the // allowed ones. let _: Vec = try_or_print!(ENUM_AS_BYTES_CFG.validate_reprs(ast)); - impl_block(ast, enm, "AsBytes", false, None, None) + impl_block(ast, enm, "AsBytes", false, None, None, None) } #[rustfmt::skip] @@ -375,7 +382,7 @@ fn derive_as_bytes_union(ast: &DeriveInput, unn: &DataUnion) -> TokenStream { try_or_print!(STRUCT_UNION_AS_BYTES_CFG.validate_reprs(ast)); - impl_block(ast, unn, "AsBytes", true, Some(PaddingCheck::Union), None) + impl_block(ast, unn, "AsBytes", true, Some(PaddingCheck::Union), None, None) } // A struct is `Unaligned` if: @@ -388,7 +395,7 @@ fn derive_unaligned_struct(ast: &DeriveInput, strct: &DataStruct) -> TokenStream let reprs = try_or_print!(STRUCT_UNION_UNALIGNED_CFG.validate_reprs(ast)); let require_trait_bound = !reprs.contains(&StructRepr::Packed); - impl_block(ast, strct, "Unaligned", require_trait_bound, None, None) + impl_block(ast, strct, "Unaligned", require_trait_bound, None, None, None) } const STRUCT_UNION_UNALIGNED_CFG: Config = Config { @@ -419,7 +426,7 @@ fn derive_unaligned_enum(ast: &DeriveInput, enm: &DataEnum) -> TokenStream { // for `require_trait_bounds` doesn't really do anything. But it's // marginally more future-proof in case that restriction is lifted in the // future. - impl_block(ast, enm, "Unaligned", true, None, None) + impl_block(ast, enm, "Unaligned", true, None, None, None) } #[rustfmt::skip] @@ -457,7 +464,7 @@ fn derive_unaligned_union(ast: &DeriveInput, unn: &DataUnion) -> TokenStream { let reprs = try_or_print!(STRUCT_UNION_UNALIGNED_CFG.validate_reprs(ast)); let require_trait_bound = !reprs.contains(&StructRepr::Packed); - impl_block(ast, unn, "Unaligned", require_trait_bound, None, None) + impl_block(ast, unn, "Unaligned", require_trait_bound, None, None, None) } // This enum describes what kind of padding check needs to be generated for the @@ -482,6 +489,8 @@ impl PaddingCheck { } } +// TODO: Passing `repr_type_ident` as an `Option` is a hack. Figure out a way to +// do it infallibly. fn impl_block( input: &DeriveInput, data: &D, @@ -489,6 +498,7 @@ fn impl_block( require_trait_bound: bool, padding_check: Option, is_bit_valid_opts: Option, + repr_type_ident: Option, ) -> TokenStream { // In this documentation, we will refer to this hypothetical struct: // @@ -599,6 +609,101 @@ fn impl_block( }); let is_bit_valid = is_bit_valid_opts.map(|opts| { + use core::borrow::Borrow; + fn foo<'a>(name_and_type: impl Borrow<(TokenStream, &'a Type)>) -> TokenStream { + let (name, ty) = name_and_type.borrow(); + quote!(&& { + let field_candidate = ::core::ptr::addr_of_mut!((*_c).#name); + // TODO: Update this safety comment. It was just copied verbatim + // from when we only supported struct fields with this logic. + // + // SAFETY: Caller has promised that `candidate` points to a + // single allocation, and allocations cannot wrap around the + // address space. `field_candidate` is at some offset within + // this allocation, so it also cannot wrap around. Since + // `candidate` is a non-null pointer, `field_candidate`'s + // smallest possible value is non-null, and its largest possible + // value doesn't wrap around, and is thus also non-null. + let f = unsafe { ::core::ptr::NonNull::new_unchecked(field_candidate) }; + // TODO: Update this safety comment. It was just copied verbatim + // from when we only supported struct fields with this logic. + // + // SAFETY: + // - `f` is properly aligned for `#field_tys` because + // `candidate` is properly aligned for `Self`. + // - `f` is valid for reads because `candidate` is. + // - Total length encoded by `f` doesn't overflow `isize` + // because it's no greater than the size encoded by + // `candidate`, whose size doesn't overflow `isize`. + // - `f` addresses a range which falls inside a single + // allocation because that range is a subset of the range + // addressed by `candidate`, and that latter range falls + // inside a single allocation. + // - The bit validity property of `is_bit_valid` is trivially + // compositional for structs. In particular, in a struct, + // there is no data dependency between bit validity in any two + // byte offsets (this is notably not true of enums). Since we + // know that the bit validity property holds for all of + // `candidate`, we also know that it holds for `f` regardless + // of the contents of any other region of `candidate`. + // + // Note that it's possible that this call will panic - + // `is_bit_valid` does not promise that it doesn't panic, and in + // practice, we support user-defined validators, which could + // panic. This is sound because we haven't violated any safety + // invariants which we would need to fix before returning. + <#ty as zerocopy::TryFromBytes>::is_bit_valid(f) + }) + } + + let (preamble, default, cases): (_, _, Vec<_>) = match &input.data { + Data::Struct(_) | Data::Union(_) => (None, true, fields.iter().map(foo).collect()), + Data::Enum(enm) => { + if !input.generics.params.is_empty() { + return Error::new_spanned( + &input.generics.params, + "TryFromBytes is not supported on enums with type parameters", + ) + .to_compile_error(); + } + + let preamble = quote!( + let _discriminant_base = candidate.cast::<#repr_type_ident>().as_ptr(); + // SAFETY: TODO + let _discriminant = unsafe { ptr::read(_discriminant_base) }; + + // Note that `.add`'s argument is a count of + // `size_of::()` bytes, so this results in a pointer + // which points one byte past the end of the discriminant - + // in other words, to the first byte of the variants. + // + // SAFETY: TODO + let _variant_base = unsafe { _discriminant_base.add(1) }; + ); + let cases = enm + .variants + .iter() + .scan(Discriminant::default(), |disc, var| { + let disc_expr = disc.update_and_generate_expr(&var.discriminant); + + let variant_fields = + var.fields.fields().into_iter().map(|(name, ty)| quote!(#name: #ty)); + + let cases = var.fields.fields().into_iter().map(foo); + Some(quote!(|| { + struct Variant { + #(#variant_fields,)* + } + + let _c = unsafe { NonNull::new_unchecked(_variant_base).cast::() }; + _discriminant == (#disc_expr) #(#cases)* + })) + }) + .collect(); + (Some(preamble), false, cases) + } + }; + let validator = Option::from(opts) .map(|ZcValidatorAttr(validator)| { quote::quote_spanned!(validator.span()=> && { @@ -629,57 +734,19 @@ fn impl_block( }) .into_iter(); - let field_names = fields.iter().map(|(name, _ty)| name); - let field_tys = fields.iter().map(|(_name, ty)| ty); quote!( + // TODO: Update this safety comment. It was just copied verbatim + // from when we only supported struct fields with this logic. + // // SAFETY: We use `is_bit_valid` to validate that each field is // bit-valid, and only return `true` if all of them are. The bit // validity of a struct is just the composition of the bit // validities of its fields, so this is a sound implementation of // `is_bit_valid`. unsafe fn is_bit_valid(candidate: ::core::ptr::NonNull) -> bool { + #preamble let _c = candidate.as_ptr(); - true #(&& { - let field_candidate = ::core::ptr::addr_of_mut!((*_c).#field_names); - // SAFETY: Caller has promised that `candidate` points to a - // single allocation, and allocations cannot wrap around the - // address space. `field_candidate` is at some offset within - // this allocation, so it also cannot wrap around. Since - // `candidate` is a non-null pointer, `field_candidate`'s - // smallest possible value is non-null, and its largest - // possible value doesn't wrap around, and is thus also - // non-null. - let f = unsafe { ::core::ptr::NonNull::new_unchecked(field_candidate) }; - // SAFETY: - // - `f` is properly aligned for `#field_tys` because - // `candidate` is properly aligned for `Self`. - // - `f` is valid for reads because `candidate` is. - // - Total length encoded by `f` doesn't overflow `isize` - // because it's no greater than the size encoded by - // `candidate`, whose size doesn't overflow `isize`. - // - `f` addresses a range which falls inside a single - // allocation because that range is a subset of the range - // addressed by `candidate`, and that latter range falls - // inside a single allocation. - // - The bit validity property of `is_bit_valid` is - // trivially compositional for structs. In particular, in - // a struct, there is no data dependency between bit - // validity in any two byte offsets (this is notably not - // true of enums). Since we know that the bit validity - // property holds for all of `candidate`, we also know - // that it holds for `f` regardless of the contents of any - // other region of `candidate`. - // - // Note that it's possible that this call will panic - - // `is_bit_valid` does not promise that it doesn't panic, - // and in practice, we support user-defined validators, - // which could panic. This is sound because we haven't - // violated any safety invariants which we would need to fix - // before returning. - <#field_tys as zerocopy::TryFromBytes>::is_bit_valid(f) - })* - - #(#validator)* + #default #(#cases)* #(#validator)* } ) }); @@ -697,6 +764,61 @@ fn impl_block( } } +// Enum variant discriminants can be manually set not only as literal values, +// but as arbitrary const expressions. In order to handle this, we keep track of +// the most-recently-seen expression and a count of how many variants have been +// encountered since then. +// +// #[repr(u8)] +// enum Foo { +// A, // 0 +// B = 5, // 5 +// C, // 6 +// D = 1 + 1, // 2 +// E, // 3 +// } +// +// Note: Default::default does the right thing (initializes to { None, 0 }). +#[derive(Default, Copy, Clone)] +struct Discriminant<'a> { + // The most-recently-set explicit discriminant. + previous: Option<&'a Expr>, + // When the next variant is encountered, what offset should be used compared + // to `previous` to determine the variant's discriminant? + next_offset: usize, +} + +impl<'a> Discriminant<'a> { + /// Called when encountering a variant with discriminant set to `ast`. + /// Updates `self` in preparation for the next variant and generates an + /// expression which will evaluate to the numeric value this variant's + /// discriminant. + fn update_and_generate_expr( + &mut self, + ast: &'a Option<(syn::token::Eq, Expr)>, + ) -> proc_macro2::TokenStream { + match ast.as_ref().map(|(_eq, expr)| expr) { + Some(expr) => { + self.previous = Some(expr); + self.next_offset = 1; + quote!(#expr) + } + None => { + let previous = self.previous.iter(); + // Use `Index` instead of `usize` so that the number is + // formatted just as `0` rather than as `0usize`; the latter + // syntax is only valid if the repr is `usize`; otherwise, + // comparison will result in a type mismatch. + let offset = Index::from(self.next_offset); + let tokens = quote!(#(#previous +)* #offset); + + self.next_offset += 1; + tokens + } + } + } +} + fn print_all_errors(errors: Vec) -> TokenStream { errors.iter().map(Error::to_compile_error).collect() } diff --git a/zerocopy-derive/src/repr.rs b/zerocopy-derive/src/repr.rs index 5997ad2f2f..01525e1b17 100644 --- a/zerocopy-derive/src/repr.rs +++ b/zerocopy-derive/src/repr.rs @@ -9,7 +9,7 @@ use { syn::punctuated::Punctuated, syn::spanned::Spanned, syn::token::Comma, - syn::{Attribute, DeriveInput, Error, LitInt, Meta}, + syn::{Attribute, DeriveInput, Error, Ident, LitInt, Meta}, }; pub struct Config { @@ -149,10 +149,17 @@ macro_rules! define_kind_specific_repr { impl core::fmt::Display for $repr_name { fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { - match self { + let r: Repr = (*self).into(); + r.fmt(f) + } + } + + impl From<$repr_name> for Repr { + fn from(repr: $repr_name) -> Repr { + match repr { $($repr_name::$repr_variant => Repr::$repr_variant,)* - $repr_name::Align(u) => Repr::Align(*u), - }.fmt(f) + $repr_name::Align(u) => Repr::Align(u), + } } } } @@ -163,6 +170,17 @@ define_kind_specific_repr!( "an enum", EnumRepr, C, U8, U16, U32, U64, Usize, I8, I16, I32, I64, Isize ); +impl EnumRepr { + // TODO: Having this exist and return `Option` is a hack. We should + // restructure the types so this is infallible. + pub fn type_ident(&self) -> Option { + use EnumRepr::*; + let r: Repr = (*self).into(); + matches!(self, U8 | U16 | U32 | U64 | Usize | I8 | I16 | I32 | I64 | Isize) + .then(|| Ident::new(r.foo(), Span::call_site())) + } +} + // All representations known to Rust. #[derive(Copy, Clone, Eq, PartialEq)] pub enum Repr { @@ -214,6 +232,25 @@ impl Repr { Err(Error::new_spanned(meta, "unrecognized representation hint")) } + + fn foo(&self) -> &str { + match self { + Repr::U8 => "u8", + Repr::U16 => "u16", + Repr::U32 => "u32", + Repr::U64 => "u64", + Repr::Usize => "usize", + Repr::I8 => "i8", + Repr::I16 => "i16", + Repr::I32 => "i32", + Repr::I64 => "i64", + Repr::Isize => "isize", + Repr::C => "C", + Repr::Transparent => "transparent", + Repr::Packed => "packed", + _ => unreachable!(), + } + } } impl Display for Repr { @@ -221,26 +258,7 @@ impl Display for Repr { if let Repr::Align(n) = self { return write!(f, "repr(align({}))", n); } - write!( - f, - "repr({})", - match self { - Repr::U8 => "u8", - Repr::U16 => "u16", - Repr::U32 => "u32", - Repr::U64 => "u64", - Repr::Usize => "usize", - Repr::I8 => "i8", - Repr::I16 => "i16", - Repr::I32 => "i32", - Repr::I64 => "i64", - Repr::Isize => "isize", - Repr::C => "C", - Repr::Transparent => "transparent", - Repr::Packed => "packed", - _ => unreachable!(), - } - ) + write!(f, "repr({})", self.foo()) } }