diff --git a/Cargo.toml b/Cargo.toml index 117ebed0d7..cfc32c85d2 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -69,6 +69,8 @@ optional = true zerocopy-derive = { version = "=0.7.3", path = "zerocopy-derive" } [dev-dependencies] +assert_matches = "1.5" +itertools = "0.11" rand = { version = "0.8.5", features = ["small_rng"] } rustversion = "1.0" static_assertions = "1.1" diff --git a/src/lib.rs b/src/lib.rs index 4da4e3db76..7c9f98cadf 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -219,12 +219,21 @@ pub struct DstLayout { /// `size_of::()`. For DSTs, the size represents the size of the type /// when the trailing slice field contains 0 elements. /// - For all types, the alignment represents the alignment of the type. + // TODO: If we end up replacing this with separate size and alignment to + // make Kani happy, file an issue to eventually adopt the stdlib's + // `Alignment` type trick. _base_layout: Layout, /// For sized types, `None`. For DSTs, the size of the element type of the /// trailing slice. _trailing_slice_elem_size: Option, } +#[cfg_attr(test, derive(Copy, Clone, Debug))] +enum _CastType { + _Prefix, + _Suffix, +} + impl DstLayout { /// Constructs a `DstLayout` which describes `T`. /// @@ -251,6 +260,162 @@ impl DstLayout { _trailing_slice_elem_size: Some(mem::size_of::()), } } + + /// TODO + /// + /// The caller is responsible for ensuring that `addr + bytes_len` does not + /// overflow `usize`. + /// + /// # Panics + /// + /// If `addr + bytes_len` overflows `usize`, `validate_cast` may panic, or + /// it may return incorrect results. No guarantees are made about when + /// `validate_cast` will panic. The caller should not rely on + /// `validate_cast` panicking in any particular condition, even if + /// `debug_assertions` are enabled. + const fn _validate_cast( + &self, + addr: usize, + bytes_len: usize, + cast_type: _CastType, + ) -> Option<(usize, usize)> { + // `debug_assert!`, but with `#[allow(clippy::arithmetic_side_effects)]`. + macro_rules! __debug_assert { + ($e:expr) => { + debug_assert!({ + #[allow(clippy::arithmetic_side_effects)] + let e = $e; + e + }) + }; + } + + let base_size = self._base_layout.size(); + + // Precondition + __debug_assert!(addr.checked_add(bytes_len).is_some()); + + // LEMMA 0: max_slice_bytes + base_size == bytes_len + // + // LEMMA 1: base_size <= bytes_len: + // - If `base_size > bytes_len`, `bytes_len.checked_sub(base_size)` + // returns `None`, and we return. + // + // TODO(#67): Once our MSRV is 1.65, use let-else: + // https://blog.rust-lang.org/2022/11/03/Rust-1.65.0.html#let-else-statements + let max_slice_bytes = if let Some(max_byte_slice) = bytes_len.checked_sub(base_size) { + max_byte_slice + } else { + return None; + }; + + // Lemma 0 + __debug_assert!(max_slice_bytes + base_size == bytes_len); + + // Lemma 1 + __debug_assert!(base_size <= bytes_len); + + let (elems, self_bytes) = if let Some(elem_size) = self._trailing_slice_elem_size { + // TODO(#67): Once our MSRV is 1.65, use let-else: + // https://blog.rust-lang.org/2022/11/03/Rust-1.65.0.html#let-else-statements + let elem_size = if let Some(elem_size) = NonZeroUsize::new(elem_size) { + elem_size + } else { + panic!("attempted to cast to slice type with zero-sized element"); + }; + + // Guaranteed not to divide by 0 because `elem_size` is a + // `NonZeroUsize`. + #[allow(clippy::arithmetic_side_effects)] + let elems = max_slice_bytes / elem_size.get(); + + // NOTE: Another option for this step in the algorithm is to set + // `slice_bytes = elems * elem_size`. However, using multiplication + // causes Kani to choke. In practice, the compiler is likely to + // generate identical machine code in both cases. Note that this + // divide-then-mod approach is trivially optimizable into a single + // operation that computes both the quotient and the remainder. + + // First line is guaranteed not to mod by 0 because `elem_size` is a + // `NonZeroUsize`. Second line is guaranteed not to underflow + // because `rem <= max_slice_bytes` thanks to the mod operation. + // + // LEMMA 2: slice_bytes <= max_slice_bytes + #[allow(clippy::arithmetic_side_effects)] + let rem = max_slice_bytes % elem_size.get(); + #[allow(clippy::arithmetic_side_effects)] + let slice_bytes = max_slice_bytes - rem; + + // Lemma 2 + __debug_assert!(slice_bytes <= max_slice_bytes); + + // Guaranteed not to overflow: + // - max_slice_bytes + base_size == bytes_len (lemma 0) + // - slice_bytes <= max_slice_bytes (lemma 2) + // - slice_bytes + base_size <= bytes_len (substitution) + // - bytes_len <= usize::MAX (bytes_len: usize) + // - slice_bytes + base_size <= usize::MAX (substitution) + // + // LEMMA 3: self_bytes <= bytes_len: TODO + #[allow(clippy::arithmetic_side_effects)] + let self_bytes = base_size + slice_bytes; + + // Lemma 3 + __debug_assert!(self_bytes <= bytes_len); + + (elems, self_bytes) + } else { + (0, base_size) + }; + + // LEMMA 4: self_bytes <= bytes_len: + // - `if` branch returns `self_bytes`; lemma 3 guarantees `self_bytes <= + // bytes_len` + // - `else` branch returns `base_size`; lemma 1 guarantees `base_size <= + // bytes_len` + + // Lemma 5 + __debug_assert!(self_bytes <= bytes_len); + + // `self_addr` indicates where in the given byte range the `Self` will + // start. If we're doing a prefix cast, it starts at the beginning. If + // we're doing a suffix cast, it starts after whatever bytes are + // remaining. + let (self_addr, split_at) = match cast_type { + _CastType::_Prefix => (addr, self_bytes), + _CastType::_Suffix => { + // Guaranteed not to underflow because `self_bytes <= bytes_len` + // (lemma 4). + // + // LEMMA 5: split_at == bytes_len - self_bytes + #[allow(clippy::arithmetic_side_effects)] + let split_at = bytes_len - self_bytes; + + // Lemma 5 + __debug_assert!(split_at == bytes_len - self_bytes); + + // Guaranteed not to overflow: + // - addr + bytes_len <= usize::MAX (method precondition) + // - split_at == bytes_len - self_bytes (lemma 5) + // - addr + split_at == addr + bytes_len - self_bytes (substitution) + // - addr + split_at <= addr + bytes_len + // - addr + split_at <= usize::MAX (substitution) + #[allow(clippy::arithmetic_side_effects)] + let self_addr = addr + split_at; + + (self_addr, split_at) + } + }; + + // Guaranteed not to divide by 0 because `.align()` guarantees that it + // returns a non-zero value. + #[allow(clippy::arithmetic_side_effects)] + if self_addr % self._base_layout.align() != 0 { + return None; + } + + Some((elems, split_at)) + } } /// A trait which carries information about a type's layout that is used by the @@ -2738,6 +2903,99 @@ mod tests { } } + // This test takes a long time when running under Miri, so we skip it in + // that case. This is acceptable because this is a logic test that doesn't + // attempt to expose UB. + #[test] + #[cfg_attr(miri, ignore)] + fn test_validate_cast() { + let layout = |base_size, align, _trailing_slice_elem_size| DstLayout { + _base_layout: Layout::from_size_align(base_size, align).unwrap(), + _trailing_slice_elem_size, + }; + + macro_rules! test { + ( + layout($base_size:tt, $align:tt, $trailing_size:tt) + .validate_cast($addr:tt, $bytes_len:tt, $cast_type:tt), $expect:pat $(,)? + ) => { + itertools::iproduct!( + test!(@generate_usize $base_size), + test!(@generate_align $align), + test!(@generate_opt_usize $trailing_size), + test!(@generate_usize $addr), + test!(@generate_usize $bytes_len), + test!(@generate_cast_type $cast_type) + ).for_each(|(base_size, align, trailing_size, addr, bytes_len, cast_type)| { + assert_matches::assert_matches!( + layout(base_size, align, trailing_size)._validate_cast(addr, bytes_len, cast_type), $expect, + "layout({base_size}, {align}, {trailing_size:?}).validate_cast({addr}, {bytes_len}, {cast_type:?})", + ); + }); + }; + (@generate_usize _) => { 0..8 }; + (@generate_align _) => { [1, 2, 4, 8, 16] }; + (@generate_opt_usize _) => { [None].into_iter().chain((0..8).map(Some).into_iter()) }; + (@generate_cast_type _) => { [_CastType::_Prefix, _CastType::_Suffix] }; + (@generate_cast_type $variant:ident) => { [_CastType::$variant] }; + // Some expressions need to be wrapped in parentheses in order to be + // valid `tt`s (required by the top match pattern). See the comment + // below for more details. This arm removes these parentheses to + // avoid generating an `unused_parens` warning. + (@$_:ident ($vals:expr)) => { $vals }; + (@$_:ident $vals:expr) => { $vals }; + } + + // The format of all of these test cases is: + // + // layout(_, _, _).validate_cast(_, _, _), Some((_, _)) + // | | | | | | | | + // base_size ----+ | | | | | | | + // align -----------+ | | | | | | + // trailing_size ------+ | | | | | + // addr --------------------------------+ | | | | + // bytes_len ------------------------------+ | | | + // cast_type ---------------------------------+ | | + // elems -----------------------------------------------+ | + // split_at -----------------------------------------------+ + // + // Each argument can either be an iterator or a wildcard. Each + // wildcarded variable is implicitly replaced by an iterator over a + // representative sample of values for that variable. Each `test!` + // invocation iterates over every combination of values provided by each + // variable's iterator (ie, the cartesian product) and validates that + // the results are expected. + // + // The final argument uses the same syntax, but it has a different + // meaning. The final argument is a pattern that will be supplied to + // `assert_matches!` to validate the computed result for each + // combination of input values. + // + // Note that the meta-variables that match these variables have the `tt` + // type, and some valid expressions are not valid `tt`s (such as + // `a..b`). In this case, wrap the expression in parentheses, and it + // will become valid `tt`. + + // base_size is too big for the memory region. + test!(layout((1..8), _, _).validate_cast(_, [0], _), None); + test!(layout((2..8), _, _).validate_cast(_, [1], _), None); + + // addr is unaligned + test!(layout(_, [2], [None]).validate_cast([1, 3, 5, 7, 9], _, _Prefix), None); + test!(layout(_, [2], ((1..8).map(Some))).validate_cast([1, 3, 5, 7, 9], _, _Prefix), None); + + // TODO: Test Suffix cast failure cases, especially regarding alignment. + + // TDOO: Success cases + } + + #[test] + fn test_validate_cast_panics() { + // TODO: Test for these cases: + // - addr + bytes overflows usize + // - zero-sized trailing element type + } + #[test] fn test_known_layout() { // Test that `$ty` and `ManuallyDrop<$ty>` have the expected layout.