Skip to content

Commit

Permalink
Encode infallible alignment errors in types (#1718)
Browse files Browse the repository at this point in the history
Permit callers to prove at compile time that alignment errors are
unreachable for unaligned destination types. This permits them to
infallibly ignore this error condition.
  • Loading branch information
joshlf authored Sep 21, 2024
1 parent 4426bb2 commit c6b9554
Show file tree
Hide file tree
Showing 4 changed files with 199 additions and 21 deletions.
185 changes: 176 additions & 9 deletions src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ use core::error::Error;
#[cfg(all(not(zerocopy_core_error), any(feature = "std", test)))]
use std::error::Error;

use crate::{util::SendSyncPhantomData, KnownLayout, TryFromBytes};
use crate::{util::SendSyncPhantomData, KnownLayout, TryFromBytes, Unaligned};
#[cfg(doc)]
use crate::{FromBytes, Ref};

Expand Down Expand Up @@ -135,6 +135,51 @@ pub enum ConvertError<A, S, V> {
Validity(V),
}

impl<Src, Dst: ?Sized + Unaligned, S, V> From<ConvertError<AlignmentError<Src, Dst>, S, V>>
for ConvertError<Infallible, S, V>
{
/// Infallibly discards the alignment error from this `ConvertError` since
/// `Dst` is unaligned.
///
/// Since [`Dst: Unaligned`], it is impossible to encounter an alignment
/// error. This method permits discarding that alignment error infallibly
/// and replacing it with [`Infallible`].
///
/// [`Dst: Unaligned`]: crate::Unaligned
///
/// # Examples
///
/// ```
/// use core::convert::Infallible;
/// use zerocopy::*;
/// # use zerocopy_derive::*;
///
/// #[derive(TryFromBytes, KnownLayout, Unaligned, Immutable)]
/// #[repr(C, packed)]
/// struct Bools {
/// one: bool,
/// two: bool,
/// many: [bool],
/// }
///
/// impl Bools {
/// fn parse(bytes: &[u8]) -> Result<&Bools, UnalignedTryCastError<&[u8], Bools>> {
/// // Since `Bools: Unaligned`, we can infallibly discard
/// // the alignment error.
/// Bools::try_ref_from_bytes(bytes).map_err(Into::into)
/// }
/// }
/// ```
#[inline]
fn from(err: ConvertError<AlignmentError<Src, Dst>, S, V>) -> ConvertError<Infallible, S, V> {
match err {
ConvertError::Alignment(e) => ConvertError::Alignment(Infallible::from(e)),
ConvertError::Size(e) => ConvertError::Size(e),
ConvertError::Validity(e) => ConvertError::Validity(e),
}
}
}

impl<A: fmt::Debug, S: fmt::Debug, V: fmt::Debug> fmt::Debug for ConvertError<A, S, V> {
#[inline]
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
Expand Down Expand Up @@ -177,11 +222,20 @@ pub struct AlignmentError<Src, Dst: ?Sized> {
/// The source value involved in the conversion.
src: Src,
/// The inner destination type inolved in the conversion.
///
/// INVARIANT: An `AlignmentError` may only be constructed if `Dst`'s
/// alignment requirement is greater than one.
dst: SendSyncPhantomData<Dst>,
}

impl<Src, Dst: ?Sized> AlignmentError<Src, Dst> {
pub(crate) fn new(src: Src) -> Self {
/// # Safety
///
/// The caller must ensure that `Dst`'s alignment requirement is greater
/// than one.
pub(crate) unsafe fn new_unchecked(src: Src) -> Self {
// INVARIANT: The caller guarantees that `Dst`'s alignment requirement
// is greater than one.
Self { src, dst: SendSyncPhantomData::default() }
}

Expand All @@ -192,6 +246,9 @@ impl<Src, Dst: ?Sized> AlignmentError<Src, Dst> {
}

pub(crate) fn with_src<NewSrc>(self, new_src: NewSrc) -> AlignmentError<NewSrc, Dst> {
// INVARIANT: `with_src` doesn't change the type of `Dst`, so the
// invariant that `Dst`'s alignment requirement is greater than one is
// preserved.
AlignmentError { src: new_src, dst: SendSyncPhantomData::default() }
}

Expand Down Expand Up @@ -255,6 +312,29 @@ impl<Src, Dst: ?Sized> AlignmentError<Src, Dst> {
}
}

impl<Src, Dst: ?Sized + Unaligned> From<AlignmentError<Src, Dst>> for Infallible {
#[inline(always)]
fn from(_: AlignmentError<Src, Dst>) -> Infallible {
// SAFETY: `AlignmentError`s can only be constructed when `Dst`'s
// alignment requirement is greater than one. In this block, `Dst:
// Unaligned`, which means that its alignment requirement is equal to
// one. Thus, it's not possible to reach here at runtime.
unsafe { core::hint::unreachable_unchecked() }
}
}

#[cfg(test)]
impl<Src, Dst> AlignmentError<Src, Dst> {
// A convenience constructor so that test code doesn't need to write
// `unsafe`.
fn new_checked(src: Src) -> AlignmentError<Src, Dst> {
assert_ne!(core::mem::align_of::<Dst>(), 1);
// SAFETY: The preceding assertion guarantees that `Dst`'s alignment
// requirement is greater than one.
unsafe { AlignmentError::new_unchecked(src) }
}
}

impl<Src, Dst: ?Sized> fmt::Debug for AlignmentError<Src, Dst> {
#[inline]
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
Expand Down Expand Up @@ -295,7 +375,7 @@ where
impl<Src, Dst: ?Sized, S, V> From<AlignmentError<Src, Dst>>
for ConvertError<AlignmentError<Src, Dst>, S, V>
{
#[inline]
#[inline(always)]
fn from(err: AlignmentError<Src, Dst>) -> Self {
Self::Alignment(err)
}
Expand Down Expand Up @@ -438,7 +518,7 @@ where
}

impl<Src, Dst: ?Sized, A, V> From<SizeError<Src, Dst>> for ConvertError<A, SizeError<Src, Dst>, V> {
#[inline]
#[inline(always)]
fn from(err: SizeError<Src, Dst>) -> Self {
Self::Size(err)
}
Expand Down Expand Up @@ -547,7 +627,7 @@ where
impl<Src, Dst: ?Sized + TryFromBytes, A, S> From<ValidityError<Src, Dst>>
for ConvertError<A, S, ValidityError<Src, Dst>>
{
#[inline]
#[inline(always)]
fn from(err: ValidityError<Src, Dst>) -> Self {
Self::Validity(err)
}
Expand Down Expand Up @@ -626,6 +706,57 @@ impl<Src, Dst: ?Sized> CastError<Src, Dst> {
}
}

impl<Src, Dst: ?Sized + Unaligned> From<CastError<Src, Dst>> for SizeError<Src, Dst> {
/// Infallibly extracts the [`SizeError`] from this `CastError` since `Dst`
/// is unaligned.
///
/// Since [`Dst: Unaligned`], it is impossible to encounter an alignment
/// error, and so the only error that can be encountered at runtime is a
/// [`SizeError`]. This method permits extracting that `SizeError`
/// infallibly.
///
/// [`Dst: Unaligned`]: crate::Unaligned
///
/// # Examples
///
/// ```rust
/// use zerocopy::*;
/// # use zerocopy_derive::*;
///
/// #[derive(FromBytes, IntoBytes, KnownLayout, Immutable, Unaligned)]
/// #[repr(C)]
/// struct UdpHeader {
/// src_port: [u8; 2],
/// dst_port: [u8; 2],
/// length: [u8; 2],
/// checksum: [u8; 2],
/// }
///
/// #[derive(FromBytes, IntoBytes, KnownLayout, Immutable, Unaligned)]
/// #[repr(C, packed)]
/// struct UdpPacket {
/// header: UdpHeader,
/// body: [u8],
/// }
///
/// impl UdpPacket {
/// pub fn parse(bytes: &[u8]) -> Result<&UdpPacket, SizeError<&[u8], UdpPacket>> {
/// // Since `UdpPacket: Unaligned`, we can map the `CastError` to a `SizeError`.
/// UdpPacket::ref_from_bytes(bytes).map_err(Into::into)
/// }
/// }
/// ```
#[inline(always)]
fn from(err: CastError<Src, Dst>) -> SizeError<Src, Dst> {
match err {
#[allow(unreachable_code)]
CastError::Alignment(e) => match Infallible::from(e) {},
CastError::Size(e) => e,
CastError::Validity(i) => match i {},
}
}
}

/// The error type of fallible reference conversions.
///
/// Fallible reference conversions, like [`TryFromBytes::try_ref_from_bytes`]
Expand Down Expand Up @@ -749,6 +880,42 @@ impl<Src, Dst: ?Sized + TryFromBytes> TryReadError<Src, Dst> {
}
}

/// The error type of fallible casts to unaligned types.
///
/// This is like [`TryCastError`], but for unaligned types. It is identical to
/// `TryCastError`, except that its alignment error is [`Infallible`].
///
/// As of this writing, none of zerocopy's API produces this error directly.
/// However, it is useful since it permits users to infallibly discard alignment
/// errors when they can prove statically that alignment errors are impossible.
///
/// # Examples
///
/// ```
/// use core::convert::Infallible;
/// use zerocopy::*;
/// # use zerocopy_derive::*;
///
/// #[derive(TryFromBytes, KnownLayout, Unaligned, Immutable)]
/// #[repr(C, packed)]
/// struct Bools {
/// one: bool,
/// two: bool,
/// many: [bool],
/// }
///
/// impl Bools {
/// fn parse(bytes: &[u8]) -> Result<&Bools, UnalignedTryCastError<&[u8], Bools>> {
/// // Since `Bools: Unaligned`, we can infallibly discard
/// // the alignment error.
/// Bools::try_ref_from_bytes(bytes).map_err(Into::into)
/// }
/// }
/// ```
#[allow(type_alias_bounds)]
pub type UnalignedTryCastError<Src, Dst: ?Sized + TryFromBytes> =
ConvertError<Infallible, SizeError<Src, Dst>, ValidityError<Src, Dst>>;

/// The error type of a failed allocation.
///
/// This type is intended to be deprecated in favor of the standard library's
Expand Down Expand Up @@ -818,7 +985,7 @@ mod tests {
let bytes = &aligned.bytes[1..];
let addr = crate::util::AsAddress::addr(bytes);
assert_eq!(
AlignmentError::<_, elain::Align::<8>>::new(bytes).to_string(),
AlignmentError::<_, elain::Align::<8>>::new_checked(bytes).to_string(),
format!("The conversion failed because the address of the source is not a multiple of the alignment of the destination type.\n\
\nSource type: &[u8]\
\nSource address: 0x{:x} (a multiple of 1)\
Expand All @@ -829,7 +996,7 @@ mod tests {
let bytes = &aligned.bytes[2..];
let addr = crate::util::AsAddress::addr(bytes);
assert_eq!(
AlignmentError::<_, elain::Align::<8>>::new(bytes).to_string(),
AlignmentError::<_, elain::Align::<8>>::new_checked(bytes).to_string(),
format!("The conversion failed because the address of the source is not a multiple of the alignment of the destination type.\n\
\nSource type: &[u8]\
\nSource address: 0x{:x} (a multiple of 2)\
Expand All @@ -840,7 +1007,7 @@ mod tests {
let bytes = &aligned.bytes[3..];
let addr = crate::util::AsAddress::addr(bytes);
assert_eq!(
AlignmentError::<_, elain::Align::<8>>::new(bytes).to_string(),
AlignmentError::<_, elain::Align::<8>>::new_checked(bytes).to_string(),
format!("The conversion failed because the address of the source is not a multiple of the alignment of the destination type.\n\
\nSource type: &[u8]\
\nSource address: 0x{:x} (a multiple of 1)\
Expand All @@ -851,7 +1018,7 @@ mod tests {
let bytes = &aligned.bytes[4..];
let addr = crate::util::AsAddress::addr(bytes);
assert_eq!(
AlignmentError::<_, elain::Align::<8>>::new(bytes).to_string(),
AlignmentError::<_, elain::Align::<8>>::new_checked(bytes).to_string(),
format!("The conversion failed because the address of the source is not a multiple of the alignment of the destination type.\n\
\nSource type: &[u8]\
\nSource address: 0x{:x} (a multiple of 4)\
Expand Down
10 changes: 7 additions & 3 deletions src/pointer/ptr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -792,8 +792,8 @@ mod _transitions {
where
T: Sized,
{
if !crate::util::aligned_to::<_, T>(self.as_non_null()) {
return Err(AlignmentError::new(self));
if let Err(err) = crate::util::validate_aligned_to::<_, T>(self.as_non_null()) {
return Err(err.with_src(self));
}

// SAFETY: We just checked the alignment.
Expand Down Expand Up @@ -1204,7 +1204,11 @@ mod _casts {
let (elems, split_at) = match maybe_metadata {
Ok((elems, split_at)) => (elems, split_at),
Err(MetadataCastError::Alignment) => {
return Err(CastError::Alignment(AlignmentError::new(self)))
// SAFETY: Since `validate_cast_and_convert_metadata`
// returned an alignment error, `U` must have an alignment
// requirement greater than one.
let err = unsafe { AlignmentError::<_, U>::new_unchecked(self) };
return Err(CastError::Alignment(err));
}
Err(MetadataCastError::Size) => return Err(CastError::Size(SizeError::new(self))),
};
Expand Down
12 changes: 6 additions & 6 deletions src/ref.rs
Original file line number Diff line number Diff line change
Expand Up @@ -202,8 +202,8 @@ where
if bytes.len() != mem::size_of::<T>() {
return Err(SizeError::new(bytes).into());
}
if !util::aligned_to::<_, T>(bytes.deref()) {
return Err(AlignmentError::new(bytes).into());
if let Err(err) = util::validate_aligned_to::<_, T>(bytes.deref()) {
return Err(err.with_src(bytes).into());
}

// SAFETY: We just validated size and alignment.
Expand All @@ -220,8 +220,8 @@ where
if bytes.len() < mem::size_of::<T>() {
return Err(SizeError::new(bytes).into());
}
if !util::aligned_to::<_, T>(bytes.deref()) {
return Err(AlignmentError::new(bytes).into());
if let Err(err) = util::validate_aligned_to::<_, T>(bytes.deref()) {
return Err(err.with_src(bytes).into());
}
let (bytes, suffix) =
bytes.split_at(mem::size_of::<T>()).map_err(|b| SizeError::new(b).into())?;
Expand All @@ -243,8 +243,8 @@ where
return Err(SizeError::new(bytes).into());
};
let (prefix, bytes) = bytes.split_at(split_at).map_err(|b| SizeError::new(b).into())?;
if !util::aligned_to::<_, T>(bytes.deref()) {
return Err(AlignmentError::new(bytes).into());
if let Err(err) = util::validate_aligned_to::<_, T>(bytes.deref()) {
return Err(err.with_src(bytes).into());
}
// SAFETY: Since `split_at` is defined as `bytes_len - size_of::<T>()`,
// the `bytes` which results from `let (prefix, bytes) =
Expand Down
13 changes: 10 additions & 3 deletions src/util/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ use core::{
};

use crate::{
error::AlignmentError,
pointer::invariant::{self, Invariants},
Unalign,
};
Expand Down Expand Up @@ -547,14 +548,20 @@ impl<T: ?Sized> AsAddress for *mut T {
}
}

/// Is `t` aligned to `align_of::<U>()`?
/// Validates that `t` is aligned to `align_of::<U>()`.
#[inline(always)]
pub(crate) fn aligned_to<T: AsAddress, U>(t: T) -> bool {
pub(crate) fn validate_aligned_to<T: AsAddress, U>(t: T) -> Result<(), AlignmentError<(), U>> {
// `mem::align_of::<U>()` is guaranteed to return a non-zero value, which in
// turn guarantees that this mod operation will not panic.
#[allow(clippy::arithmetic_side_effects)]
let remainder = t.addr() % mem::align_of::<U>();
remainder == 0
if remainder == 0 {
Ok(())
} else {
// SAFETY: We just confirmed that `t.addr() % align_of::<U>() != 0`.
// That's only possible if `align_of::<U>() > 1`.
Err(unsafe { AlignmentError::new_unchecked(()) })
}
}

/// Returns the bytes needed to pad `len` to the next multiple of `align`.
Expand Down

0 comments on commit c6b9554

Please sign in to comment.