From 5c45ea83a3296c659a80e78ced77b7f6e4d82fad Mon Sep 17 00:00:00 2001 From: Joshua Liebow-Feeser Date: Sat, 9 Sep 2023 23:10:30 -0700 Subject: [PATCH] [WIP][derive] Support custom TryFromBytes validator TODO: - Cleaner way to pass name of validator to `impl_block`? - Cleaner way to parse validator attribute? - Safety comment in emitted call to `validate` - Tests (especially including tests for the error message resulting from passing a validator with the wrong type signature - test for both argument types and return types) - Also test for invalid zerocopy attributes - Also test for hygiene - that the validator can't access variables from the `is_bit_valid` function body scope - Other misc TODO comments in code --- src/derive_util.rs | 14 ++-- zerocopy-derive/Cargo.toml | 2 +- zerocopy-derive/src/lib.rs | 163 ++++++++++++++++++++++++++----------- 3 files changed, 125 insertions(+), 54 deletions(-) diff --git a/src/derive_util.rs b/src/derive_util.rs index 4fd76b15ca..84f8278786 100644 --- a/src/derive_util.rs +++ b/src/derive_util.rs @@ -128,12 +128,16 @@ mod tests { #[test] fn foo() { - #[derive(TryFromBytes)] - struct Foo { - f: u8, - b: bool, - } + #[derive(TryFromBytes, Eq, PartialEq, Debug)] + #[zerocopy(validator = |f| f.0 < 128)] + #[repr(C)] + struct Foo(u8, bool); impl_known_layout!(Foo); + + assert_eq!(try_transmute!([0u8, 0]), Some(Foo(0, false))); + assert_eq!(try_transmute!([1u8, 1]), Some(Foo(1, true))); + assert_eq!(try_transmute!([2u8, 2]), None::); + assert_eq!(try_transmute!([128u8, 1]), None::); } } diff --git a/zerocopy-derive/Cargo.toml b/zerocopy-derive/Cargo.toml index 4cb0f08f70..550e51bed6 100644 --- a/zerocopy-derive/Cargo.toml +++ b/zerocopy-derive/Cargo.toml @@ -20,7 +20,7 @@ proc-macro = true [dependencies] proc-macro2 = "1.0.1" quote = "1.0.10" -syn = "2.0.31" +syn = { version = "2.0.31", features = ["full", "parsing"] } [dev-dependencies] rustversion = "1.0" diff --git a/zerocopy-derive/src/lib.rs b/zerocopy-derive/src/lib.rs index 4c9466f805..daa0580a1e 100644 --- a/zerocopy-derive/src/lib.rs +++ b/zerocopy-derive/src/lib.rs @@ -28,11 +28,11 @@ mod ext; mod repr; use { - proc_macro2::Span, + proc_macro2::{Span, TokenStream}, quote::quote, syn::{ - parse_quote, Data, DataEnum, DataStruct, DataUnion, DeriveInput, Error, Expr, ExprLit, - GenericParam, Ident, Lit, + parse_quote, spanned::Spanned as _, Data, DataEnum, DataStruct, DataUnion, DeriveInput, + Error, Expr, ExprLit, GenericParam, Ident, Lit, }, }; @@ -56,11 +56,39 @@ use {crate::ext::*, crate::repr::*}; // (https://doc.rust-lang.org/nightly/proc_macro/struct.Span.html#method.error), // which is currently unstable. Revisit this once it's stable. -#[proc_macro_derive(TryFromBytes)] +// Unwraps a `Result<_, Vec>`, converting any `Err` value into a +// `TokenStream` and returning it. +macro_rules! try_or_print { + ($e:expr) => { + match $e { + Ok(x) => x, + Err(errors) => return print_all_errors(errors).into(), + } + }; +} + +enum IsBitValidOpts { + WithValidator(ZcValidatorAttr), + NoValidator, +} + +impl From for Option { + fn from(opts: IsBitValidOpts) -> Option { + match opts { + IsBitValidOpts::WithValidator(validator) => Some(validator), + IsBitValidOpts::NoValidator => None, + } + } +} + +#[proc_macro_derive(TryFromBytes, attributes(zerocopy))] pub fn derive_try_from_bytes(ts: proc_macro::TokenStream) -> proc_macro::TokenStream { let ast = syn::parse_macro_input!(ts as DeriveInput); + + let validator = try_or_print!(parse_zerocopy_attrs(&ast.attrs).map_err(|e| vec![e])); + let opts = validator.map(IsBitValidOpts::WithValidator).unwrap_or(IsBitValidOpts::NoValidator); match &ast.data { - Data::Struct(strct) => derive_try_from_bytes_struct(&ast, strct), + Data::Struct(strct) => impl_block(&ast, strct, "TryFromBytes", true, None, Some(opts)), Data::Enum(_) => { Error::new_spanned(&ast, "TryFromBytes not supported on enum types").to_compile_error() } @@ -115,21 +143,6 @@ pub fn derive_unaligned(ts: proc_macro::TokenStream) -> proc_macro::TokenStream .into() } -// Unwraps a `Result<_, Vec>`, converting any `Err` value into a -// `TokenStream` and returning it. -macro_rules! try_or_print { - ($e:expr) => { - match $e { - Ok(x) => x, - Err(errors) => return print_all_errors(errors), - } - }; -} - -fn derive_try_from_bytes_struct(ast: &DeriveInput, strct: &DataStruct) -> proc_macro2::TokenStream { - impl_block(ast, strct, "TryFromBytes", true, None, true) -} - const STRUCT_UNION_ALLOWED_REPR_COMBINATIONS: &[&[StructRepr]] = &[ &[StructRepr::C], &[StructRepr::Transparent], @@ -140,15 +153,15 @@ const STRUCT_UNION_ALLOWED_REPR_COMBINATIONS: &[&[StructRepr]] = &[ // A struct is `FromZeroes` if: // - all fields are `FromZeroes` -fn derive_from_zeroes_struct(ast: &DeriveInput, strct: &DataStruct) -> proc_macro2::TokenStream { - impl_block(ast, strct, "FromZeroes", true, None, false) +fn derive_from_zeroes_struct(ast: &DeriveInput, strct: &DataStruct) -> TokenStream { + impl_block(ast, strct, "FromZeroes", true, None, None) } // An enum is `FromZeroes` if: // - all of its variants are fieldless // - one of the variants has a discriminant of `0` -fn derive_from_zeroes_enum(ast: &DeriveInput, enm: &DataEnum) -> proc_macro2::TokenStream { +fn derive_from_zeroes_enum(ast: &DeriveInput, enm: &DataEnum) -> TokenStream { if !enm.is_c_like() { return Error::new_spanned(ast, "only C-like enums can implement FromZeroes") .to_compile_error(); @@ -175,21 +188,21 @@ fn derive_from_zeroes_enum(ast: &DeriveInput, enm: &DataEnum) -> proc_macro2::To .to_compile_error(); } - impl_block(ast, enm, "FromZeroes", true, None, false) + impl_block(ast, enm, "FromZeroes", true, None, None) } // Like structs, unions are `FromZeroes` if // - all fields are `FromZeroes` -fn derive_from_zeroes_union(ast: &DeriveInput, unn: &DataUnion) -> proc_macro2::TokenStream { - impl_block(ast, unn, "FromZeroes", true, None, false) +fn derive_from_zeroes_union(ast: &DeriveInput, unn: &DataUnion) -> TokenStream { + impl_block(ast, unn, "FromZeroes", true, None, None) } // A struct is `FromBytes` if: // - all fields are `FromBytes` -fn derive_from_bytes_struct(ast: &DeriveInput, strct: &DataStruct) -> proc_macro2::TokenStream { - impl_block(ast, strct, "FromBytes", true, None, false) +fn derive_from_bytes_struct(ast: &DeriveInput, strct: &DataStruct) -> TokenStream { + impl_block(ast, strct, "FromBytes", true, None, None) } // An enum is `FromBytes` if: @@ -206,7 +219,7 @@ fn derive_from_bytes_struct(ast: &DeriveInput, strct: &DataStruct) -> proc_macro // platform-specific and, b) even on Rust's smallest bit width platform (32), // this would require ~4 billion enum variants, which obviously isn't a thing. -fn derive_from_bytes_enum(ast: &DeriveInput, enm: &DataEnum) -> proc_macro2::TokenStream { +fn derive_from_bytes_enum(ast: &DeriveInput, enm: &DataEnum) -> TokenStream { if !enm.is_c_like() { return Error::new_spanned(ast, "only C-like enums can implement FromBytes") .to_compile_error(); @@ -232,7 +245,7 @@ fn derive_from_bytes_enum(ast: &DeriveInput, enm: &DataEnum) -> proc_macro2::Tok .to_compile_error(); } - impl_block(ast, enm, "FromBytes", true, None, false) + impl_block(ast, enm, "FromBytes", true, None, None) } #[rustfmt::skip] @@ -262,8 +275,8 @@ const ENUM_FROM_BYTES_CFG: Config = { // Like structs, unions are `FromBytes` if // - all fields are `FromBytes` -fn derive_from_bytes_union(ast: &DeriveInput, unn: &DataUnion) -> proc_macro2::TokenStream { - impl_block(ast, unn, "FromBytes", true, None, false) +fn derive_from_bytes_union(ast: &DeriveInput, unn: &DataUnion) -> TokenStream { + impl_block(ast, unn, "FromBytes", true, None, None) } // A struct is `AsBytes` if: @@ -272,7 +285,7 @@ fn derive_from_bytes_union(ast: &DeriveInput, unn: &DataUnion) -> proc_macro2::T // - no padding (size of struct equals sum of size of field types) // - `repr(packed)` -fn derive_as_bytes_struct(ast: &DeriveInput, strct: &DataStruct) -> proc_macro2::TokenStream { +fn derive_as_bytes_struct(ast: &DeriveInput, strct: &DataStruct) -> TokenStream { let reprs = try_or_print!(STRUCT_UNION_AS_BYTES_CFG.validate_reprs(ast)); let is_transparent = reprs.contains(&StructRepr::Transparent); let is_packed = reprs.contains(&StructRepr::Packed); @@ -297,7 +310,7 @@ fn derive_as_bytes_struct(ast: &DeriveInput, strct: &DataStruct) -> proc_macro2: // any padding bytes would need to come from the fields, all of which // we require to be `AsBytes` (meaning they don't have any padding). let padding_check = if is_transparent || is_packed { None } else { Some(PaddingCheck::Struct) }; - impl_block(ast, strct, "AsBytes", true, padding_check, false) + impl_block(ast, strct, "AsBytes", true, padding_check, None) } const STRUCT_UNION_AS_BYTES_CFG: Config = Config { @@ -311,7 +324,7 @@ const STRUCT_UNION_AS_BYTES_CFG: Config = Config { // An enum is `AsBytes` if it is C-like and has a defined repr. -fn derive_as_bytes_enum(ast: &DeriveInput, enm: &DataEnum) -> proc_macro2::TokenStream { +fn derive_as_bytes_enum(ast: &DeriveInput, enm: &DataEnum) -> TokenStream { if !enm.is_c_like() { return Error::new_spanned(ast, "only C-like enums can implement AsBytes") .to_compile_error(); @@ -320,7 +333,7 @@ fn derive_as_bytes_enum(ast: &DeriveInput, enm: &DataEnum) -> proc_macro2::Token // We don't care what the repr is; we only care that it is one of the // allowed ones. let _: Vec = try_or_print!(ENUM_AS_BYTES_CFG.validate_reprs(ast)); - impl_block(ast, enm, "AsBytes", false, None, false) + impl_block(ast, enm, "AsBytes", false, None, None) } #[rustfmt::skip] @@ -353,7 +366,7 @@ const ENUM_AS_BYTES_CFG: Config = { // - `repr(C)`, `repr(transparent)`, or `repr(packed)` // - no padding (size of union equals size of each field type) -fn derive_as_bytes_union(ast: &DeriveInput, unn: &DataUnion) -> proc_macro2::TokenStream { +fn derive_as_bytes_union(ast: &DeriveInput, unn: &DataUnion) -> TokenStream { // TODO(#10): Support type parameters. if !ast.generics.params.is_empty() { return Error::new(Span::call_site(), "unsupported on types with type parameters") @@ -362,7 +375,7 @@ fn derive_as_bytes_union(ast: &DeriveInput, unn: &DataUnion) -> proc_macro2::Tok try_or_print!(STRUCT_UNION_AS_BYTES_CFG.validate_reprs(ast)); - impl_block(ast, unn, "AsBytes", true, Some(PaddingCheck::Union), false) + impl_block(ast, unn, "AsBytes", true, Some(PaddingCheck::Union), None) } // A struct is `Unaligned` if: @@ -371,11 +384,11 @@ fn derive_as_bytes_union(ast: &DeriveInput, unn: &DataUnion) -> proc_macro2::Tok // - all fields `Unaligned` // - `repr(packed)` -fn derive_unaligned_struct(ast: &DeriveInput, strct: &DataStruct) -> proc_macro2::TokenStream { +fn derive_unaligned_struct(ast: &DeriveInput, strct: &DataStruct) -> TokenStream { let reprs = try_or_print!(STRUCT_UNION_UNALIGNED_CFG.validate_reprs(ast)); let require_trait_bound = !reprs.contains(&StructRepr::Packed); - impl_block(ast, strct, "Unaligned", require_trait_bound, None, false) + impl_block(ast, strct, "Unaligned", require_trait_bound, None, None) } const STRUCT_UNION_UNALIGNED_CFG: Config = Config { @@ -391,7 +404,7 @@ const STRUCT_UNION_UNALIGNED_CFG: Config = Config { // - No `repr(align(N > 1))` // - `repr(u8)` or `repr(i8)` -fn derive_unaligned_enum(ast: &DeriveInput, enm: &DataEnum) -> proc_macro2::TokenStream { +fn derive_unaligned_enum(ast: &DeriveInput, enm: &DataEnum) -> TokenStream { if !enm.is_c_like() { return Error::new_spanned(ast, "only C-like enums can implement Unaligned") .to_compile_error(); @@ -406,7 +419,7 @@ fn derive_unaligned_enum(ast: &DeriveInput, enm: &DataEnum) -> proc_macro2::Toke // for `require_trait_bounds` doesn't really do anything. But it's // marginally more future-proof in case that restriction is lifted in the // future. - impl_block(ast, enm, "Unaligned", true, None, false) + impl_block(ast, enm, "Unaligned", true, None, None) } #[rustfmt::skip] @@ -440,11 +453,11 @@ const ENUM_UNALIGNED_CFG: Config = { // - all fields `Unaligned` // - `repr(packed)` -fn derive_unaligned_union(ast: &DeriveInput, unn: &DataUnion) -> proc_macro2::TokenStream { +fn derive_unaligned_union(ast: &DeriveInput, unn: &DataUnion) -> TokenStream { let reprs = try_or_print!(STRUCT_UNION_UNALIGNED_CFG.validate_reprs(ast)); let require_trait_bound = !reprs.contains(&StructRepr::Packed); - impl_block(ast, unn, "Unaligned", require_trait_bound, None, false) + impl_block(ast, unn, "Unaligned", require_trait_bound, None, None) } // This enum describes what kind of padding check needs to be generated for the @@ -475,8 +488,8 @@ fn impl_block( trait_name: &str, require_trait_bound: bool, padding_check: Option, - emit_is_bit_valid: bool, -) -> proc_macro2::TokenStream { + is_bit_valid_opts: Option, +) -> TokenStream { // In this documentation, we will refer to this hypothetical struct: // // #[derive(FromBytes)] @@ -585,7 +598,37 @@ fn impl_block( GenericParam::Const(cnst) => quote!(#cnst), }); - let is_bit_valid = emit_is_bit_valid.then(|| { + let is_bit_valid = is_bit_valid_opts.map(|opts| { + let validator = Option::from(opts) + .map(|ZcValidatorAttr(validator)| { + quote::quote_spanned!(validator.span()=> && { + // This assignment helps make the compiler error more + // precise if `#validator` has the wrong signature. We use + // `quote_spanned!` so that the error is attributed to the + // annotation rather than just to the `TryFromBytes` token + // in `#[derive(TryFromBytes)]`: + // + // #[zerocopy(validator = Self::validate)] + // ^^^^^^^^^^^^^^ + // | + // | + // (type errors attributed to this span) + // + // TODO: Validate that assigning to this function pointer + // has no performance impact. Presumably it's trivially + // constant-foldable, but since it's a function pointer, + // maybe the compiler is worse at optimizing it? + let validate: fn(&Self) -> bool = #validator; + // SAFETY: TODO; remember to mention that we need to be + // sound in the face of `validate` panicking. Maybe also add + // a safety comment where this token stream is included + // below to make it clear that its location is critical to + // safety. + validate(unsafe { candidate.as_ref() }) + }) + }) + .into_iter(); + let field_names = fields.iter().map(|(name, _ty)| name); let field_tys = fields.iter().map(|(_name, ty)| ty); quote!( @@ -635,6 +678,8 @@ fn impl_block( // before returning. <#field_tys as zerocopy::TryFromBytes>::is_bit_valid(f) })* + + #(#validator)* } ) }); @@ -652,10 +697,32 @@ fn impl_block( } } -fn print_all_errors(errors: Vec) -> proc_macro2::TokenStream { +fn print_all_errors(errors: Vec) -> TokenStream { errors.iter().map(Error::to_compile_error).collect() } +struct ZcValidatorAttr(Expr); + +fn parse_zerocopy_attrs(attrs: &[syn::Attribute]) -> Result, Error> { + let mut attrs = attrs.iter().filter(|attr| attr.path().is_ident("zerocopy")); + // TODO: Any way to do this unwrap-or-return-Ok(None) using a combinator? + let attr = if let Some(attr) = attrs.next() { + attr + } else { + return Ok(None); + }; + if let Some(attr) = attrs.next() { + return Err(Error::new_spanned(attr, "multiple zerocopy attributes are unsupported")); + } + + // Attempt to parse as an assignment expression (ie, `validator = + // "validator"`). + let assign: syn::ExprAssign = attr.parse_args().map_err(|_| { + Error::new_spanned(attr, "expected syntax: #[zerocopy(validator = \"validate\")]") + })?; + Ok(Some(ZcValidatorAttr(*assign.right))) +} + // A polyfill for `Option::then_some`, which was added after our MSRV. // // TODO(#67): Remove this once our MSRV is >= 1.62.