Skip to content

Commit

Permalink
[WIP][derive] Support custom TryFromBytes validator
Browse files Browse the repository at this point in the history
TODO:
- Cleaner way to pass name of validator to `impl_block`?
- Cleaner way to parse validator attribute?
- Safety comment in emitted call to `validate`
- Tests (especially including tests for the error message resulting from
  passing a validator with the wrong type signature - test for both
  argument types and return types)
  - Also test for invalid zerocopy attributes
  - Also test for hygiene - that the validator can't access variables
    from the `is_bit_valid` function body scope
- Other misc TODO comments in code
  • Loading branch information
joshlf committed Sep 19, 2023
1 parent f262f98 commit 5c45ea8
Show file tree
Hide file tree
Showing 3 changed files with 125 additions and 54 deletions.
14 changes: 9 additions & 5 deletions src/derive_util.rs
Original file line number Diff line number Diff line change
Expand Up @@ -128,12 +128,16 @@ mod tests {

#[test]
fn foo() {
#[derive(TryFromBytes)]
struct Foo {
f: u8,
b: bool,
}
#[derive(TryFromBytes, Eq, PartialEq, Debug)]
#[zerocopy(validator = |f| f.0 < 128)]
#[repr(C)]
struct Foo(u8, bool);

impl_known_layout!(Foo);

assert_eq!(try_transmute!([0u8, 0]), Some(Foo(0, false)));
assert_eq!(try_transmute!([1u8, 1]), Some(Foo(1, true)));
assert_eq!(try_transmute!([2u8, 2]), None::<Foo>);
assert_eq!(try_transmute!([128u8, 1]), None::<Foo>);
}
}
2 changes: 1 addition & 1 deletion zerocopy-derive/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ proc-macro = true
[dependencies]
proc-macro2 = "1.0.1"
quote = "1.0.10"
syn = "2.0.31"
syn = { version = "2.0.31", features = ["full", "parsing"] }

[dev-dependencies]
rustversion = "1.0"
Expand Down
163 changes: 115 additions & 48 deletions zerocopy-derive/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,11 @@ mod ext;
mod repr;

use {
proc_macro2::Span,
proc_macro2::{Span, TokenStream},
quote::quote,
syn::{
parse_quote, Data, DataEnum, DataStruct, DataUnion, DeriveInput, Error, Expr, ExprLit,
GenericParam, Ident, Lit,
parse_quote, spanned::Spanned as _, Data, DataEnum, DataStruct, DataUnion, DeriveInput,
Error, Expr, ExprLit, GenericParam, Ident, Lit,
},
};

Expand All @@ -56,11 +56,39 @@ use {crate::ext::*, crate::repr::*};
// (https://doc.rust-lang.org/nightly/proc_macro/struct.Span.html#method.error),
// which is currently unstable. Revisit this once it's stable.

#[proc_macro_derive(TryFromBytes)]
// Unwraps a `Result<_, Vec<Error>>`, converting any `Err` value into a
// `TokenStream` and returning it.
macro_rules! try_or_print {
($e:expr) => {
match $e {
Ok(x) => x,
Err(errors) => return print_all_errors(errors).into(),
}
};
}

enum IsBitValidOpts {
WithValidator(ZcValidatorAttr),
NoValidator,
}

impl From<IsBitValidOpts> for Option<ZcValidatorAttr> {
fn from(opts: IsBitValidOpts) -> Option<ZcValidatorAttr> {
match opts {
IsBitValidOpts::WithValidator(validator) => Some(validator),
IsBitValidOpts::NoValidator => None,
}
}
}

#[proc_macro_derive(TryFromBytes, attributes(zerocopy))]
pub fn derive_try_from_bytes(ts: proc_macro::TokenStream) -> proc_macro::TokenStream {
let ast = syn::parse_macro_input!(ts as DeriveInput);

let validator = try_or_print!(parse_zerocopy_attrs(&ast.attrs).map_err(|e| vec![e]));
let opts = validator.map(IsBitValidOpts::WithValidator).unwrap_or(IsBitValidOpts::NoValidator);
match &ast.data {
Data::Struct(strct) => derive_try_from_bytes_struct(&ast, strct),
Data::Struct(strct) => impl_block(&ast, strct, "TryFromBytes", true, None, Some(opts)),
Data::Enum(_) => {
Error::new_spanned(&ast, "TryFromBytes not supported on enum types").to_compile_error()
}
Expand Down Expand Up @@ -115,21 +143,6 @@ pub fn derive_unaligned(ts: proc_macro::TokenStream) -> proc_macro::TokenStream
.into()
}

// Unwraps a `Result<_, Vec<Error>>`, converting any `Err` value into a
// `TokenStream` and returning it.
macro_rules! try_or_print {
($e:expr) => {
match $e {
Ok(x) => x,
Err(errors) => return print_all_errors(errors),
}
};
}

fn derive_try_from_bytes_struct(ast: &DeriveInput, strct: &DataStruct) -> proc_macro2::TokenStream {
impl_block(ast, strct, "TryFromBytes", true, None, true)
}

const STRUCT_UNION_ALLOWED_REPR_COMBINATIONS: &[&[StructRepr]] = &[
&[StructRepr::C],
&[StructRepr::Transparent],
Expand All @@ -140,15 +153,15 @@ const STRUCT_UNION_ALLOWED_REPR_COMBINATIONS: &[&[StructRepr]] = &[
// A struct is `FromZeroes` if:
// - all fields are `FromZeroes`

fn derive_from_zeroes_struct(ast: &DeriveInput, strct: &DataStruct) -> proc_macro2::TokenStream {
impl_block(ast, strct, "FromZeroes", true, None, false)
fn derive_from_zeroes_struct(ast: &DeriveInput, strct: &DataStruct) -> TokenStream {
impl_block(ast, strct, "FromZeroes", true, None, None)
}

// An enum is `FromZeroes` if:
// - all of its variants are fieldless
// - one of the variants has a discriminant of `0`

fn derive_from_zeroes_enum(ast: &DeriveInput, enm: &DataEnum) -> proc_macro2::TokenStream {
fn derive_from_zeroes_enum(ast: &DeriveInput, enm: &DataEnum) -> TokenStream {
if !enm.is_c_like() {
return Error::new_spanned(ast, "only C-like enums can implement FromZeroes")
.to_compile_error();
Expand All @@ -175,21 +188,21 @@ fn derive_from_zeroes_enum(ast: &DeriveInput, enm: &DataEnum) -> proc_macro2::To
.to_compile_error();
}

impl_block(ast, enm, "FromZeroes", true, None, false)
impl_block(ast, enm, "FromZeroes", true, None, None)
}

// Like structs, unions are `FromZeroes` if
// - all fields are `FromZeroes`

fn derive_from_zeroes_union(ast: &DeriveInput, unn: &DataUnion) -> proc_macro2::TokenStream {
impl_block(ast, unn, "FromZeroes", true, None, false)
fn derive_from_zeroes_union(ast: &DeriveInput, unn: &DataUnion) -> TokenStream {
impl_block(ast, unn, "FromZeroes", true, None, None)
}

// A struct is `FromBytes` if:
// - all fields are `FromBytes`

fn derive_from_bytes_struct(ast: &DeriveInput, strct: &DataStruct) -> proc_macro2::TokenStream {
impl_block(ast, strct, "FromBytes", true, None, false)
fn derive_from_bytes_struct(ast: &DeriveInput, strct: &DataStruct) -> TokenStream {
impl_block(ast, strct, "FromBytes", true, None, None)
}

// An enum is `FromBytes` if:
Expand All @@ -206,7 +219,7 @@ fn derive_from_bytes_struct(ast: &DeriveInput, strct: &DataStruct) -> proc_macro
// platform-specific and, b) even on Rust's smallest bit width platform (32),
// this would require ~4 billion enum variants, which obviously isn't a thing.

fn derive_from_bytes_enum(ast: &DeriveInput, enm: &DataEnum) -> proc_macro2::TokenStream {
fn derive_from_bytes_enum(ast: &DeriveInput, enm: &DataEnum) -> TokenStream {
if !enm.is_c_like() {
return Error::new_spanned(ast, "only C-like enums can implement FromBytes")
.to_compile_error();
Expand All @@ -232,7 +245,7 @@ fn derive_from_bytes_enum(ast: &DeriveInput, enm: &DataEnum) -> proc_macro2::Tok
.to_compile_error();
}

impl_block(ast, enm, "FromBytes", true, None, false)
impl_block(ast, enm, "FromBytes", true, None, None)
}

#[rustfmt::skip]
Expand Down Expand Up @@ -262,8 +275,8 @@ const ENUM_FROM_BYTES_CFG: Config<EnumRepr> = {
// Like structs, unions are `FromBytes` if
// - all fields are `FromBytes`

fn derive_from_bytes_union(ast: &DeriveInput, unn: &DataUnion) -> proc_macro2::TokenStream {
impl_block(ast, unn, "FromBytes", true, None, false)
fn derive_from_bytes_union(ast: &DeriveInput, unn: &DataUnion) -> TokenStream {
impl_block(ast, unn, "FromBytes", true, None, None)
}

// A struct is `AsBytes` if:
Expand All @@ -272,7 +285,7 @@ fn derive_from_bytes_union(ast: &DeriveInput, unn: &DataUnion) -> proc_macro2::T
// - no padding (size of struct equals sum of size of field types)
// - `repr(packed)`

fn derive_as_bytes_struct(ast: &DeriveInput, strct: &DataStruct) -> proc_macro2::TokenStream {
fn derive_as_bytes_struct(ast: &DeriveInput, strct: &DataStruct) -> TokenStream {
let reprs = try_or_print!(STRUCT_UNION_AS_BYTES_CFG.validate_reprs(ast));
let is_transparent = reprs.contains(&StructRepr::Transparent);
let is_packed = reprs.contains(&StructRepr::Packed);
Expand All @@ -297,7 +310,7 @@ fn derive_as_bytes_struct(ast: &DeriveInput, strct: &DataStruct) -> proc_macro2:
// any padding bytes would need to come from the fields, all of which
// we require to be `AsBytes` (meaning they don't have any padding).
let padding_check = if is_transparent || is_packed { None } else { Some(PaddingCheck::Struct) };
impl_block(ast, strct, "AsBytes", true, padding_check, false)
impl_block(ast, strct, "AsBytes", true, padding_check, None)
}

const STRUCT_UNION_AS_BYTES_CFG: Config<StructRepr> = Config {
Expand All @@ -311,7 +324,7 @@ const STRUCT_UNION_AS_BYTES_CFG: Config<StructRepr> = Config {

// An enum is `AsBytes` if it is C-like and has a defined repr.

fn derive_as_bytes_enum(ast: &DeriveInput, enm: &DataEnum) -> proc_macro2::TokenStream {
fn derive_as_bytes_enum(ast: &DeriveInput, enm: &DataEnum) -> TokenStream {
if !enm.is_c_like() {
return Error::new_spanned(ast, "only C-like enums can implement AsBytes")
.to_compile_error();
Expand All @@ -320,7 +333,7 @@ fn derive_as_bytes_enum(ast: &DeriveInput, enm: &DataEnum) -> proc_macro2::Token
// We don't care what the repr is; we only care that it is one of the
// allowed ones.
let _: Vec<repr::EnumRepr> = try_or_print!(ENUM_AS_BYTES_CFG.validate_reprs(ast));
impl_block(ast, enm, "AsBytes", false, None, false)
impl_block(ast, enm, "AsBytes", false, None, None)
}

#[rustfmt::skip]
Expand Down Expand Up @@ -353,7 +366,7 @@ const ENUM_AS_BYTES_CFG: Config<EnumRepr> = {
// - `repr(C)`, `repr(transparent)`, or `repr(packed)`
// - no padding (size of union equals size of each field type)

fn derive_as_bytes_union(ast: &DeriveInput, unn: &DataUnion) -> proc_macro2::TokenStream {
fn derive_as_bytes_union(ast: &DeriveInput, unn: &DataUnion) -> TokenStream {
// TODO(#10): Support type parameters.
if !ast.generics.params.is_empty() {
return Error::new(Span::call_site(), "unsupported on types with type parameters")
Expand All @@ -362,7 +375,7 @@ fn derive_as_bytes_union(ast: &DeriveInput, unn: &DataUnion) -> proc_macro2::Tok

try_or_print!(STRUCT_UNION_AS_BYTES_CFG.validate_reprs(ast));

impl_block(ast, unn, "AsBytes", true, Some(PaddingCheck::Union), false)
impl_block(ast, unn, "AsBytes", true, Some(PaddingCheck::Union), None)
}

// A struct is `Unaligned` if:
Expand All @@ -371,11 +384,11 @@ fn derive_as_bytes_union(ast: &DeriveInput, unn: &DataUnion) -> proc_macro2::Tok
// - all fields `Unaligned`
// - `repr(packed)`

fn derive_unaligned_struct(ast: &DeriveInput, strct: &DataStruct) -> proc_macro2::TokenStream {
fn derive_unaligned_struct(ast: &DeriveInput, strct: &DataStruct) -> TokenStream {
let reprs = try_or_print!(STRUCT_UNION_UNALIGNED_CFG.validate_reprs(ast));
let require_trait_bound = !reprs.contains(&StructRepr::Packed);

impl_block(ast, strct, "Unaligned", require_trait_bound, None, false)
impl_block(ast, strct, "Unaligned", require_trait_bound, None, None)
}

const STRUCT_UNION_UNALIGNED_CFG: Config<StructRepr> = Config {
Expand All @@ -391,7 +404,7 @@ const STRUCT_UNION_UNALIGNED_CFG: Config<StructRepr> = Config {
// - No `repr(align(N > 1))`
// - `repr(u8)` or `repr(i8)`

fn derive_unaligned_enum(ast: &DeriveInput, enm: &DataEnum) -> proc_macro2::TokenStream {
fn derive_unaligned_enum(ast: &DeriveInput, enm: &DataEnum) -> TokenStream {
if !enm.is_c_like() {
return Error::new_spanned(ast, "only C-like enums can implement Unaligned")
.to_compile_error();
Expand All @@ -406,7 +419,7 @@ fn derive_unaligned_enum(ast: &DeriveInput, enm: &DataEnum) -> proc_macro2::Toke
// for `require_trait_bounds` doesn't really do anything. But it's
// marginally more future-proof in case that restriction is lifted in the
// future.
impl_block(ast, enm, "Unaligned", true, None, false)
impl_block(ast, enm, "Unaligned", true, None, None)
}

#[rustfmt::skip]
Expand Down Expand Up @@ -440,11 +453,11 @@ const ENUM_UNALIGNED_CFG: Config<EnumRepr> = {
// - all fields `Unaligned`
// - `repr(packed)`

fn derive_unaligned_union(ast: &DeriveInput, unn: &DataUnion) -> proc_macro2::TokenStream {
fn derive_unaligned_union(ast: &DeriveInput, unn: &DataUnion) -> TokenStream {
let reprs = try_or_print!(STRUCT_UNION_UNALIGNED_CFG.validate_reprs(ast));
let require_trait_bound = !reprs.contains(&StructRepr::Packed);

impl_block(ast, unn, "Unaligned", require_trait_bound, None, false)
impl_block(ast, unn, "Unaligned", require_trait_bound, None, None)
}

// This enum describes what kind of padding check needs to be generated for the
Expand Down Expand Up @@ -475,8 +488,8 @@ fn impl_block<D: DataExt>(
trait_name: &str,
require_trait_bound: bool,
padding_check: Option<PaddingCheck>,
emit_is_bit_valid: bool,
) -> proc_macro2::TokenStream {
is_bit_valid_opts: Option<IsBitValidOpts>,
) -> TokenStream {
// In this documentation, we will refer to this hypothetical struct:
//
// #[derive(FromBytes)]
Expand Down Expand Up @@ -585,7 +598,37 @@ fn impl_block<D: DataExt>(
GenericParam::Const(cnst) => quote!(#cnst),
});

let is_bit_valid = emit_is_bit_valid.then(|| {
let is_bit_valid = is_bit_valid_opts.map(|opts| {
let validator = Option::from(opts)
.map(|ZcValidatorAttr(validator)| {
quote::quote_spanned!(validator.span()=> && {
// This assignment helps make the compiler error more
// precise if `#validator` has the wrong signature. We use
// `quote_spanned!` so that the error is attributed to the
// annotation rather than just to the `TryFromBytes` token
// in `#[derive(TryFromBytes)]`:
//
// #[zerocopy(validator = Self::validate)]
// ^^^^^^^^^^^^^^
// |
// |
// (type errors attributed to this span)
//
// TODO: Validate that assigning to this function pointer
// has no performance impact. Presumably it's trivially
// constant-foldable, but since it's a function pointer,
// maybe the compiler is worse at optimizing it?
let validate: fn(&Self) -> bool = #validator;
// SAFETY: TODO; remember to mention that we need to be
// sound in the face of `validate` panicking. Maybe also add
// a safety comment where this token stream is included
// below to make it clear that its location is critical to
// safety.
validate(unsafe { candidate.as_ref() })
})
})
.into_iter();

let field_names = fields.iter().map(|(name, _ty)| name);
let field_tys = fields.iter().map(|(_name, ty)| ty);
quote!(
Expand Down Expand Up @@ -635,6 +678,8 @@ fn impl_block<D: DataExt>(
// before returning.
<#field_tys as zerocopy::TryFromBytes>::is_bit_valid(f)
})*

#(#validator)*
}
)
});
Expand All @@ -652,10 +697,32 @@ fn impl_block<D: DataExt>(
}
}

fn print_all_errors(errors: Vec<Error>) -> proc_macro2::TokenStream {
fn print_all_errors(errors: Vec<Error>) -> TokenStream {
errors.iter().map(Error::to_compile_error).collect()
}

struct ZcValidatorAttr(Expr);

fn parse_zerocopy_attrs(attrs: &[syn::Attribute]) -> Result<Option<ZcValidatorAttr>, Error> {
let mut attrs = attrs.iter().filter(|attr| attr.path().is_ident("zerocopy"));
// TODO: Any way to do this unwrap-or-return-Ok(None) using a combinator?
let attr = if let Some(attr) = attrs.next() {
attr
} else {
return Ok(None);
};
if let Some(attr) = attrs.next() {
return Err(Error::new_spanned(attr, "multiple zerocopy attributes are unsupported"));
}

// Attempt to parse as an assignment expression (ie, `validator =
// "validator"`).
let assign: syn::ExprAssign = attr.parse_args().map_err(|_| {
Error::new_spanned(attr, "expected syntax: #[zerocopy(validator = \"validate\")]")
})?;
Ok(Some(ZcValidatorAttr(*assign.right)))
}

// A polyfill for `Option::then_some`, which was added after our MSRV.
//
// TODO(#67): Remove this once our MSRV is >= 1.62.
Expand Down

0 comments on commit 5c45ea8

Please sign in to comment.