From c2db8b10bc86653eaf71272a494f3590a0957980 Mon Sep 17 00:00:00 2001 From: Joshua Liebow-Feeser Date: Sat, 9 Sep 2023 23:10:30 -0700 Subject: [PATCH] [WIP][derive] Support custom TryFromBytes validator TODO: - Cleaner way to pass name of validator to `impl_block`? - Cleaner way to parse validator attribute? - Safety comment in emitted call to `validate` - Tests (especially including tests for the error message resulting from passing a validator with the wrong type signature - test for both argument types and return types) - Also test for invalid zerocopy attributes - Other misc TODO comments in code --- src/derive_util.rs | 20 +++-- zerocopy-derive/Cargo.toml | 2 +- zerocopy-derive/src/lib.rs | 167 ++++++++++++++++++++++++++----------- 3 files changed, 132 insertions(+), 57 deletions(-) diff --git a/src/derive_util.rs b/src/derive_util.rs index 4fd76b15ca..edd999e5b0 100644 --- a/src/derive_util.rs +++ b/src/derive_util.rs @@ -128,12 +128,22 @@ mod tests { #[test] fn foo() { - #[derive(TryFromBytes)] - struct Foo { - f: u8, - b: bool, - } + #[derive(TryFromBytes, Eq, PartialEq, Debug)] + #[zerocopy(validator = "validate")] + #[repr(C)] + struct Foo(u8, bool); impl_known_layout!(Foo); + + impl Foo { + fn validate(&self) -> bool { + self.0 < 128 + } + } + + assert_eq!(try_transmute!([0u8, 0]), Some(Foo(0, false))); + assert_eq!(try_transmute!([1u8, 1]), Some(Foo(1, true))); + assert_eq!(try_transmute!([2u8, 2]), None::); + assert_eq!(try_transmute!([128u8, 1]), None::); } } diff --git a/zerocopy-derive/Cargo.toml b/zerocopy-derive/Cargo.toml index 4eae69f393..10dfb5f3d9 100644 --- a/zerocopy-derive/Cargo.toml +++ b/zerocopy-derive/Cargo.toml @@ -20,7 +20,7 @@ proc-macro = true [dependencies] proc-macro2 = "1.0.1" quote = "1.0.10" -syn = "2.0.31" +syn = { version = "2.0.31", features = ["full", "parsing"] } [dev-dependencies] rustversion = "1.0" diff --git a/zerocopy-derive/src/lib.rs b/zerocopy-derive/src/lib.rs index f29a125a12..01ce07e62c 100644 --- a/zerocopy-derive/src/lib.rs +++ b/zerocopy-derive/src/lib.rs @@ -28,7 +28,7 @@ mod ext; mod repr; use { - proc_macro2::Span, + proc_macro2::{Span, TokenStream}, quote::quote, syn::{ parse_quote, Data, DataEnum, DataStruct, DataUnion, DeriveInput, Error, Expr, ExprLit, @@ -56,15 +56,43 @@ use {crate::ext::*, crate::repr::*}; // (https://doc.rust-lang.org/nightly/proc_macro/struct.Span.html#method.error), // which is currently unstable. Revisit this once it's stable. -#[proc_macro_derive(TryFromBytes)] +// Unwraps a `Result<_, Vec>`, converting any `Err` value into a +// `TokenStream` and returning it. +macro_rules! try_or_print { + ($e:expr) => { + match $e { + Ok(x) => x, + Err(errors) => return print_all_errors(errors).into(), + } + }; +} + +enum IsBitValidOpts { + WithValidator(ZcValidatorAttr), + NoValidator, +} + +impl From for Option { + fn from(opts: IsBitValidOpts) -> Option { + match opts { + IsBitValidOpts::WithValidator(validator) => Some(validator), + IsBitValidOpts::NoValidator => None, + } + } +} + +#[proc_macro_derive(TryFromBytes, attributes(zerocopy))] pub fn derive_try_from_bytes(ts: proc_macro::TokenStream) -> proc_macro::TokenStream { let ast = syn::parse_macro_input!(ts as DeriveInput); + + let validator = try_or_print!(parse_zerocopy_attrs(&ast.attrs).map_err(|e| vec![e])); + let opts = validator.map(IsBitValidOpts::WithValidator).unwrap_or(IsBitValidOpts::NoValidator); match &ast.data { - Data::Struct(strct) => derive_try_from_bytes_struct(&ast, strct), + Data::Struct(strct) => impl_block(&ast, strct, "TryFromBytes", true, None, Some(opts)), Data::Enum(_) => { Error::new_spanned(&ast, "TryFromBytes not supported on enum types").to_compile_error() } - Data::Union(unn) => derive_try_from_bytes_union(&ast, unn), + Data::Union(unn) => impl_block(&ast, unn, "TryFromBytes", true, None, Some(opts)), } .into() } @@ -113,25 +141,6 @@ pub fn derive_unaligned(ts: proc_macro::TokenStream) -> proc_macro::TokenStream .into() } -// Unwraps a `Result<_, Vec>`, converting any `Err` value into a -// `TokenStream` and returning it. -macro_rules! try_or_print { - ($e:expr) => { - match $e { - Ok(x) => x, - Err(errors) => return print_all_errors(errors), - } - }; -} - -fn derive_try_from_bytes_struct(ast: &DeriveInput, strct: &DataStruct) -> proc_macro2::TokenStream { - impl_block(ast, strct, "TryFromBytes", true, None, true) -} - -fn derive_try_from_bytes_union(ast: &DeriveInput, unn: &DataUnion) -> proc_macro2::TokenStream { - impl_block(ast, unn, "TryFromBytes", true, None, true) -} - const STRUCT_UNION_ALLOWED_REPR_COMBINATIONS: &[&[StructRepr]] = &[ &[StructRepr::C], &[StructRepr::Transparent], @@ -142,15 +151,15 @@ const STRUCT_UNION_ALLOWED_REPR_COMBINATIONS: &[&[StructRepr]] = &[ // A struct is `FromZeroes` if: // - all fields are `FromZeroes` -fn derive_from_zeroes_struct(ast: &DeriveInput, strct: &DataStruct) -> proc_macro2::TokenStream { - impl_block(ast, strct, "FromZeroes", true, None, false) +fn derive_from_zeroes_struct(ast: &DeriveInput, strct: &DataStruct) -> TokenStream { + impl_block(ast, strct, "FromZeroes", true, None, None) } // An enum is `FromZeroes` if: // - all of its variants are fieldless // - one of the variants has a discriminant of `0` -fn derive_from_zeroes_enum(ast: &DeriveInput, enm: &DataEnum) -> proc_macro2::TokenStream { +fn derive_from_zeroes_enum(ast: &DeriveInput, enm: &DataEnum) -> TokenStream { if !enm.is_c_like() { return Error::new_spanned(ast, "only C-like enums can implement FromZeroes") .to_compile_error(); @@ -177,21 +186,21 @@ fn derive_from_zeroes_enum(ast: &DeriveInput, enm: &DataEnum) -> proc_macro2::To .to_compile_error(); } - impl_block(ast, enm, "FromZeroes", true, None, false) + impl_block(ast, enm, "FromZeroes", true, None, None) } // Like structs, unions are `FromZeroes` if // - all fields are `FromZeroes` -fn derive_from_zeroes_union(ast: &DeriveInput, unn: &DataUnion) -> proc_macro2::TokenStream { - impl_block(ast, unn, "FromZeroes", true, None, false) +fn derive_from_zeroes_union(ast: &DeriveInput, unn: &DataUnion) -> TokenStream { + impl_block(ast, unn, "FromZeroes", true, None, None) } // A struct is `FromBytes` if: // - all fields are `FromBytes` -fn derive_from_bytes_struct(ast: &DeriveInput, strct: &DataStruct) -> proc_macro2::TokenStream { - impl_block(ast, strct, "FromBytes", true, None, false) +fn derive_from_bytes_struct(ast: &DeriveInput, strct: &DataStruct) -> TokenStream { + impl_block(ast, strct, "FromBytes", true, None, None) } // An enum is `FromBytes` if: @@ -208,7 +217,7 @@ fn derive_from_bytes_struct(ast: &DeriveInput, strct: &DataStruct) -> proc_macro // platform-specific and, b) even on Rust's smallest bit width platform (32), // this would require ~4 billion enum variants, which obviously isn't a thing. -fn derive_from_bytes_enum(ast: &DeriveInput, enm: &DataEnum) -> proc_macro2::TokenStream { +fn derive_from_bytes_enum(ast: &DeriveInput, enm: &DataEnum) -> TokenStream { if !enm.is_c_like() { return Error::new_spanned(ast, "only C-like enums can implement FromBytes") .to_compile_error(); @@ -234,7 +243,7 @@ fn derive_from_bytes_enum(ast: &DeriveInput, enm: &DataEnum) -> proc_macro2::Tok .to_compile_error(); } - impl_block(ast, enm, "FromBytes", true, None, false) + impl_block(ast, enm, "FromBytes", true, None, None) } #[rustfmt::skip] @@ -264,8 +273,8 @@ const ENUM_FROM_BYTES_CFG: Config = { // Like structs, unions are `FromBytes` if // - all fields are `FromBytes` -fn derive_from_bytes_union(ast: &DeriveInput, unn: &DataUnion) -> proc_macro2::TokenStream { - impl_block(ast, unn, "FromBytes", true, None, false) +fn derive_from_bytes_union(ast: &DeriveInput, unn: &DataUnion) -> TokenStream { + impl_block(ast, unn, "FromBytes", true, None, None) } // A struct is `AsBytes` if: @@ -274,7 +283,7 @@ fn derive_from_bytes_union(ast: &DeriveInput, unn: &DataUnion) -> proc_macro2::T // - no padding (size of struct equals sum of size of field types) // - `repr(packed)` -fn derive_as_bytes_struct(ast: &DeriveInput, strct: &DataStruct) -> proc_macro2::TokenStream { +fn derive_as_bytes_struct(ast: &DeriveInput, strct: &DataStruct) -> TokenStream { let reprs = try_or_print!(STRUCT_UNION_AS_BYTES_CFG.validate_reprs(ast)); let is_transparent = reprs.contains(&StructRepr::Transparent); let is_packed = reprs.contains(&StructRepr::Packed); @@ -299,7 +308,7 @@ fn derive_as_bytes_struct(ast: &DeriveInput, strct: &DataStruct) -> proc_macro2: // any padding bytes would need to come from the fields, all of which // we require to be `AsBytes` (meaning they don't have any padding). let padding_check = if is_transparent || is_packed { None } else { Some(PaddingCheck::Struct) }; - impl_block(ast, strct, "AsBytes", true, padding_check, false) + impl_block(ast, strct, "AsBytes", true, padding_check, None) } const STRUCT_UNION_AS_BYTES_CFG: Config = Config { @@ -313,7 +322,7 @@ const STRUCT_UNION_AS_BYTES_CFG: Config = Config { // An enum is `AsBytes` if it is C-like and has a defined repr. -fn derive_as_bytes_enum(ast: &DeriveInput, enm: &DataEnum) -> proc_macro2::TokenStream { +fn derive_as_bytes_enum(ast: &DeriveInput, enm: &DataEnum) -> TokenStream { if !enm.is_c_like() { return Error::new_spanned(ast, "only C-like enums can implement AsBytes") .to_compile_error(); @@ -322,7 +331,7 @@ fn derive_as_bytes_enum(ast: &DeriveInput, enm: &DataEnum) -> proc_macro2::Token // We don't care what the repr is; we only care that it is one of the // allowed ones. let _: Vec = try_or_print!(ENUM_AS_BYTES_CFG.validate_reprs(ast)); - impl_block(ast, enm, "AsBytes", false, None, false) + impl_block(ast, enm, "AsBytes", false, None, None) } #[rustfmt::skip] @@ -355,7 +364,7 @@ const ENUM_AS_BYTES_CFG: Config = { // - `repr(C)`, `repr(transparent)`, or `repr(packed)` // - no padding (size of union equals size of each field type) -fn derive_as_bytes_union(ast: &DeriveInput, unn: &DataUnion) -> proc_macro2::TokenStream { +fn derive_as_bytes_union(ast: &DeriveInput, unn: &DataUnion) -> TokenStream { // TODO(#10): Support type parameters. if !ast.generics.params.is_empty() { return Error::new(Span::call_site(), "unsupported on types with type parameters") @@ -364,7 +373,7 @@ fn derive_as_bytes_union(ast: &DeriveInput, unn: &DataUnion) -> proc_macro2::Tok try_or_print!(STRUCT_UNION_AS_BYTES_CFG.validate_reprs(ast)); - impl_block(ast, unn, "AsBytes", true, Some(PaddingCheck::Union), false) + impl_block(ast, unn, "AsBytes", true, Some(PaddingCheck::Union), None) } // A struct is `Unaligned` if: @@ -373,11 +382,11 @@ fn derive_as_bytes_union(ast: &DeriveInput, unn: &DataUnion) -> proc_macro2::Tok // - all fields `Unaligned` // - `repr(packed)` -fn derive_unaligned_struct(ast: &DeriveInput, strct: &DataStruct) -> proc_macro2::TokenStream { +fn derive_unaligned_struct(ast: &DeriveInput, strct: &DataStruct) -> TokenStream { let reprs = try_or_print!(STRUCT_UNION_UNALIGNED_CFG.validate_reprs(ast)); let require_trait_bound = !reprs.contains(&StructRepr::Packed); - impl_block(ast, strct, "Unaligned", require_trait_bound, None, false) + impl_block(ast, strct, "Unaligned", require_trait_bound, None, None) } const STRUCT_UNION_UNALIGNED_CFG: Config = Config { @@ -393,7 +402,7 @@ const STRUCT_UNION_UNALIGNED_CFG: Config = Config { // - No `repr(align(N > 1))` // - `repr(u8)` or `repr(i8)` -fn derive_unaligned_enum(ast: &DeriveInput, enm: &DataEnum) -> proc_macro2::TokenStream { +fn derive_unaligned_enum(ast: &DeriveInput, enm: &DataEnum) -> TokenStream { if !enm.is_c_like() { return Error::new_spanned(ast, "only C-like enums can implement Unaligned") .to_compile_error(); @@ -408,7 +417,7 @@ fn derive_unaligned_enum(ast: &DeriveInput, enm: &DataEnum) -> proc_macro2::Toke // for `require_trait_bounds` doesn't really do anything. But it's // marginally more future-proof in case that restriction is lifted in the // future. - impl_block(ast, enm, "Unaligned", true, None, false) + impl_block(ast, enm, "Unaligned", true, None, None) } #[rustfmt::skip] @@ -442,11 +451,11 @@ const ENUM_UNALIGNED_CFG: Config = { // - all fields `Unaligned` // - `repr(packed)` -fn derive_unaligned_union(ast: &DeriveInput, unn: &DataUnion) -> proc_macro2::TokenStream { +fn derive_unaligned_union(ast: &DeriveInput, unn: &DataUnion) -> TokenStream { let reprs = try_or_print!(STRUCT_UNION_UNALIGNED_CFG.validate_reprs(ast)); let require_trait_bound = !reprs.contains(&StructRepr::Packed); - impl_block(ast, unn, "Unaligned", require_trait_bound, None, false) + impl_block(ast, unn, "Unaligned", require_trait_bound, None, None) } // This enum describes what kind of padding check needs to be generated for the @@ -477,8 +486,8 @@ fn impl_block( trait_name: &str, require_trait_bound: bool, padding_check: Option, - emit_is_bit_valid: bool, -) -> proc_macro2::TokenStream { + is_bit_valid_opts: Option, +) -> TokenStream { // In this documentation, we will refer to this hypothetical struct: // // #[derive(FromBytes)] @@ -587,7 +596,33 @@ fn impl_block( GenericParam::Const(cnst) => quote!(#cnst), }); - let is_bit_valid = emit_is_bit_valid.then(|| { + let is_bit_valid = is_bit_valid_opts.map(|opts| { + let validator = Option::from(opts) + .map(|ZcValidatorAttr(validator)| { + quote::quote_spanned!(validator.span()=> && { + // This assignment helps make the compiler error more precise if + // `#validator` doesn't return a `bool`. We use `quote_spanned!` + // so that the error is attributed to the annotation rather than + // just to the `TryFromBytes` token in + // `#[derive(TryFromBytes)]`: + // + // #[zerocopy(validator = "validate")] + // ^^^^^^^^^^ + // | + // | + // (type errors attributed to this span) + // + // TODO: Validate that assigning to this function pointer has no + // performance impact. Presumably it's trivially + // constant-foldable, but since it's a function pointer, maybe + // the compiler is worse at optimizing it? + let validate: fn(&Self) -> bool = Self::#validator; + // SAFETY: TODO + validate(unsafe { candidate.as_ref() }) + }) + }) + .into_iter(); + let field_names = fields.iter().map(|(name, _ty)| name); let field_tys = fields.iter().map(|(_name, ty)| ty); quote!( @@ -600,6 +635,8 @@ fn impl_block( // SAFETY: TODO <#field_tys as zerocopy::TryFromBytes>::is_bit_valid(f) })* + + #(#validator)* } ) }); @@ -616,10 +653,38 @@ fn impl_block( } } -fn print_all_errors(errors: Vec) -> proc_macro2::TokenStream { +fn print_all_errors(errors: Vec) -> TokenStream { errors.iter().map(Error::to_compile_error).collect() } +struct ZcValidatorAttr(Ident); + +fn parse_zerocopy_attrs(attrs: &[syn::Attribute]) -> Result, Error> { + let mut attrs = attrs.iter().filter(|attr| attr.path().is_ident("zerocopy")); + // TODO: Any way to do this unwrap-or-return-Ok(None) using a combinator? + let attr = if let Some(attr) = attrs.next() { + attr + } else { + return Ok(None); + }; + if let Some(attr) = attrs.next() { + return Err(Error::new_spanned(attr, "multiple attributes are unsupported")); + } + + // Attempt to parse as an assignment expression (ie, `validator = + // "validator"`). + let assign: syn::ExprAssign = attr.parse_args().map_err(|_| { + Error::new_spanned(attr, "expected syntax: #[zerocopy(validator = \"validate\")]") + })?; + let validator = if let Expr::Lit(ExprLit { lit: Lit::Str(validator), .. }) = *assign.right { + validator + } else { + return Err(Error::new_spanned(attr, "expected syntax: validator = \"validate\"")); + }; + + Ok(Some(ZcValidatorAttr(validator.parse()?))) +} + // A polyfill for `Option::then_some`, which was added after our MSRV. // // TODO(#67): Remove this once our MSRV is >= 1.62.