Skip to content

Commit

Permalink
Add KnownLayout::validate_cast
Browse files Browse the repository at this point in the history
TODO: Tests

Co-authored-by: Jack Wrenn <[email protected]>
  • Loading branch information
joshlf and jswrenn committed Sep 8, 2023
1 parent 12e7fac commit 7adbc94
Show file tree
Hide file tree
Showing 2 changed files with 330 additions and 0 deletions.
2 changes: 2 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,8 @@ optional = true
zerocopy-derive = { version = "=0.7.3", path = "zerocopy-derive" }

[dev-dependencies]
assert_matches = "1.5"
itertools = "0.11"
rand = { version = "0.8.5", features = ["small_rng"] }
rustversion = "1.0"
static_assertions = "1.1"
Expand Down
328 changes: 328 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -219,12 +219,21 @@ pub struct DstLayout {
/// `size_of::<T>()`. For DSTs, the size represents the size of the type
/// when the trailing slice field contains 0 elements.
/// - For all types, the alignment represents the alignment of the type.
// TODO: If we end up replacing this with separate size and alignment to
// make Kani happy, file an issue to eventually adopt the stdlib's
// `Alignment` type trick.
_base_layout: Layout,
/// For sized types, `None`. For DSTs, the size of the element type of the
/// trailing slice.
_trailing_slice_elem_size: Option<usize>,
}

#[cfg_attr(test, derive(Copy, Clone, Debug))]
enum _CastType {
_Prefix,
_Suffix,
}

impl DstLayout {
/// Constructs a `DstLayout` which describes `T`.
///
Expand All @@ -251,6 +260,193 @@ impl DstLayout {
_trailing_slice_elem_size: Some(mem::size_of::<T>()),
}
}

/// Validates that a cast is sound from a layout perspective.
///
/// Validates that the size and alignment requirements of a type with the
/// layout described in `self` would not be violated by performing a
/// `cast_type` cast from a pointer with address `addr` which refers to a
/// memory region of size `bytes_len`.
///
/// If the cast is valid, `validate_cast` returns `(elems, split_at)`. If
/// `self` describes a dynamically-sized type, then `elems` is the maximum
/// number of trailing slice elements for which a cast would be valid (for
/// sized types, `elem` is meaningless and should be ignored). `split_at` is
/// the index at which to split the memory region in order for the prefix
/// (suffix) to contain the result of the cast, and in order for the
/// remaining suffix (prefix) to contain the leftover bytes.
///
/// There are three conditions under which a cast can fail:
/// - The smallest possible value for the type is larger than the provided
/// memory region
/// - A prefix cast is requested, and `addr` does not satisfy `self`'s
/// alignment requirement
/// - A suffix cast is requested, and `addr + bytes_len` does not satisfy
/// `self`'s alignment requirement (TODO: Is this the precise condition
/// under which a suffix cast can fail?)
///
/// # Safety
///
/// The caller may assume that this implementation is correct, and may rely
/// on that assumption for the soundness of their code. In particular, the
/// caller may assume that:
/// - A pointer to the type (for dynamically sized types, this includes
/// `elems` as its pointer metadata) describes an object of size `size <=
/// bytes_len`
/// - If this is a prefix cast, `addr` satisfies `self`'s alignment
/// - If this is a suffix cast, `addr + bytes_len - size` satisfies `self`'s
/// alignment
///
/// # Panics
///
/// If `addr + bytes_len` overflows `usize`, `validate_cast` may panic, or
/// it may return incorrect results. No guarantees are made about when
/// `validate_cast` will panic. The caller should not rely on
/// `validate_cast` panicking in any particular condition, even if
/// `debug_assertions` are enabled.
const fn _validate_cast(
&self,
addr: usize,
bytes_len: usize,
cast_type: _CastType,
) -> Option<(usize, usize)> {
// `debug_assert!`, but with `#[allow(clippy::arithmetic_side_effects)]`.
macro_rules! __debug_assert {
($e:expr $(, $msg:expr)?) => {
debug_assert!({
#[allow(clippy::arithmetic_side_effects)]
let e = $e;
e
} $(, $msg)?);
};
}

let base_size = self._base_layout.size();

// Precondition
__debug_assert!(addr.checked_add(bytes_len).is_some(), "`addr` + `bytes_len` > usize::MAX");

// LEMMA 0: max_slice_bytes + base_size == bytes_len
//
// LEMMA 1: base_size <= bytes_len:
// - If `base_size > bytes_len`, `bytes_len.checked_sub(base_size)`
// returns `None`, and we return.
//
// TODO(#67): Once our MSRV is 1.65, use let-else:
// https://blog.rust-lang.org/2022/11/03/Rust-1.65.0.html#let-else-statements
let max_slice_bytes = if let Some(max_byte_slice) = bytes_len.checked_sub(base_size) {
max_byte_slice
} else {
return None;
};

// Lemma 0
__debug_assert!(max_slice_bytes + base_size == bytes_len);

// Lemma 1
__debug_assert!(base_size <= bytes_len);

let (elems, self_bytes) = if let Some(elem_size) = self._trailing_slice_elem_size {
// TODO(#67): Once our MSRV is 1.65, use let-else:
// https://blog.rust-lang.org/2022/11/03/Rust-1.65.0.html#let-else-statements
let elem_size = if let Some(elem_size) = NonZeroUsize::new(elem_size) {
elem_size
} else {
panic!("attempted to cast to slice type with zero-sized element");
};

// Guaranteed not to divide by 0 because `elem_size` is a
// `NonZeroUsize`.
#[allow(clippy::arithmetic_side_effects)]
let elems = max_slice_bytes / elem_size.get();

// NOTE: Another option for this step in the algorithm is to set
// `slice_bytes = elems * elem_size`. However, using multiplication
// causes Kani to choke. In practice, the compiler is likely to
// generate identical machine code in both cases. Note that this
// divide-then-mod approach is trivially optimizable into a single
// operation that computes both the quotient and the remainder.

// First line is guaranteed not to mod by 0 because `elem_size` is a
// `NonZeroUsize`. Second line is guaranteed not to underflow
// because `rem <= max_slice_bytes` thanks to the mod operation.
//
// LEMMA 2: slice_bytes <= max_slice_bytes
#[allow(clippy::arithmetic_side_effects)]
let rem = max_slice_bytes % elem_size.get();
#[allow(clippy::arithmetic_side_effects)]
let slice_bytes = max_slice_bytes - rem;

// Lemma 2
__debug_assert!(slice_bytes <= max_slice_bytes);

// Guaranteed not to overflow:
// - max_slice_bytes + base_size == bytes_len (lemma 0)
// - slice_bytes <= max_slice_bytes (lemma 2)
// - slice_bytes + base_size <= bytes_len (substitution)
// - bytes_len <= usize::MAX (bytes_len: usize)
// - slice_bytes + base_size <= usize::MAX (substitution)
//
// LEMMA 3: self_bytes <= bytes_len: TODO
#[allow(clippy::arithmetic_side_effects)]
let self_bytes = base_size + slice_bytes;

// Lemma 3
__debug_assert!(self_bytes <= bytes_len);

(elems, self_bytes)
} else {
(0, base_size)
};

// LEMMA 4: self_bytes <= bytes_len:
// - `if` branch returns `self_bytes`; lemma 3 guarantees `self_bytes <=
// bytes_len`
// - `else` branch returns `base_size`; lemma 1 guarantees `base_size <=
// bytes_len`

// Lemma 5
__debug_assert!(self_bytes <= bytes_len);

// `self_addr` indicates where in the given byte range the `Self` will
// start. If we're doing a prefix cast, it starts at the beginning. If
// we're doing a suffix cast, it starts after whatever bytes are
// remaining.
let (self_addr, split_at) = match cast_type {
_CastType::_Prefix => (addr, self_bytes),
_CastType::_Suffix => {
// Guaranteed not to underflow because `self_bytes <= bytes_len`
// (lemma 4).
//
// LEMMA 5: split_at == bytes_len - self_bytes
#[allow(clippy::arithmetic_side_effects)]
let split_at = bytes_len - self_bytes;

// Lemma 5
__debug_assert!(split_at == bytes_len - self_bytes);

// Guaranteed not to overflow:
// - addr + bytes_len <= usize::MAX (method precondition)
// - split_at == bytes_len - self_bytes (lemma 5)
// - addr + split_at == addr + bytes_len - self_bytes (substitution)
// - addr + split_at <= addr + bytes_len
// - addr + split_at <= usize::MAX (substitution)
#[allow(clippy::arithmetic_side_effects)]
let self_addr = addr + split_at;

(self_addr, split_at)
}
};

// Guaranteed not to divide by 0 because `.align()` guarantees that it
// returns a non-zero value.
#[allow(clippy::arithmetic_side_effects)]
if self_addr % self._base_layout.align() != 0 {
return None;
}

Some((elems, split_at))
}
}

/// A trait which carries information about a type's layout that is used by the
Expand Down Expand Up @@ -2738,6 +2934,138 @@ mod tests {
}
}

// This test takes a long time when running under Miri, so we skip it in
// that case. This is acceptable because this is a logic test that doesn't
// attempt to expose UB.
#[test]
#[cfg_attr(miri, ignore)]
fn test_validate_cast() {
fn layout(
base_size: usize,
align: usize,
_trailing_slice_elem_size: Option<usize>,
) -> DstLayout {
DstLayout {
_base_layout: Layout::from_size_align(base_size, align).unwrap(),
_trailing_slice_elem_size,
}
}

/// This macro accepts arguments in the form of:
///
/// layout(_, _, _).validate_cast(_, _, _), Ok(Some((_, _)))
/// | | | | | | | |
/// base_size ----+ | | | | | | |
/// align -----------+ | | | | | |
/// trailing_size ------+ | | | | |
/// addr --------------------------------+ | | | |
/// bytes_len ------------------------------+ | | |
/// cast_type ---------------------------------+ | |
/// elems --------------------------------------------------+ |
/// split_at --------------------------------------------------+
///
/// Each argument can either be an iterator or a wildcard. Each
/// wildcarded variable is implicitly replaced by an iterator over a
/// representative sample of values for that variable. Each `test!`
/// invocation iterates over every combination of values provided by
/// each variable's iterator (ie, the cartesian product) and validates
/// that the results are expected.
///
/// The final argument uses the same syntax, but it has a different
/// meaning:
/// - If it is `Ok(pat)`, then the pattern `pat` is supplied to
/// `assert_matches!` to validate the computed result for each
/// combination of input values.
/// - If it is `Err(mst)`, then `test!` validates that the call to
/// `validate_cast` panics with the given panic message.
///
/// Note that the meta-variables that match these variables have the
/// `tt` type, and some valid expressions are not valid `tt`s (such as
/// `a..b`). In this case, wrap the expression in parentheses, and it
/// will become valid `tt`.
macro_rules! test {
(
layout($base_size:tt, $align:tt, $trailing_size:tt)
.validate_cast($addr:tt, $bytes_len:tt, $cast_type:tt), $expect:pat $(,)?
) => {
itertools::iproduct!(
test!(@generate_usize $base_size),
test!(@generate_align $align),
test!(@generate_opt_usize $trailing_size),
test!(@generate_usize $addr),
test!(@generate_usize $bytes_len),
test!(@generate_cast_type $cast_type)
).for_each(|(base_size, align, trailing_size, addr, bytes_len, cast_type)| {
let actual = std::panic::catch_unwind(|| {
layout(base_size, align, trailing_size)._validate_cast(addr, bytes_len, cast_type)
}).map_err(|d| {
*d.downcast::<&'static str>().expect("expected string panic message").as_ref()
});
assert_matches::assert_matches!(
actual, $expect,
"layout({base_size}, {align}, {trailing_size:?}).validate_cast({addr}, {bytes_len}, {cast_type:?})",
);
});
};
(@generate_usize _) => { 0..8 };
(@generate_align _) => { [1, 2, 4, 8, 16] };
(@generate_opt_usize _) => { [None].into_iter().chain((0..8).map(Some).into_iter()) };
(@generate_cast_type _) => { [_CastType::_Prefix, _CastType::_Suffix] };
(@generate_cast_type $variant:ident) => { [_CastType::$variant] };
// Some expressions need to be wrapped in parentheses in order to be
// valid `tt`s (required by the top match pattern). See the comment
// below for more details. This arm removes these parentheses to avoid
// generating an `unused_parens` warning.
(@$_:ident ($vals:expr)) => { $vals };
(@$_:ident $vals:expr) => { $vals };
}

// base_size is too big for the memory region.
test!(layout((1..8), _, _).validate_cast(_, [0], _), Ok(None));
test!(layout((2..8), _, _).validate_cast(_, [1], _), Ok(None));

// addr is unaligned
test!(layout(_, [2], [None]).validate_cast([1, 3, 5, 7, 9], _, _Prefix), Ok(None));
test!(
layout(_, [2], ((1..8).map(Some))).validate_cast([1, 3, 5, 7, 9], _, _Prefix),
Ok(None)
);

// TODO: Test Suffix cast failure cases, especially regarding alignment.

// TDOO: Success cases

// Unfortunately, these constants cannot easily be used in the
// implementation of `validate_cast`, since `panic!` consumes a string
// literal, not an expression.
mod messages {
pub(super) const TRAILING: &str =
"attempted to cast to slice type with zero-sized element";
pub(super) const OVERFLOW: &str = "`addr` + `bytes_len` > usize::MAX";
}

// casts with ZST trailing element types are unsupported
test!(layout([1], [1], [Some(0)]).validate_cast([1], [1], _), Err(messages::TRAILING),);

// addr + bytes_len must not overflow usize
test!(
layout([1], [1], _).validate_cast([usize::MAX], (1..100), _),
Err(messages::OVERFLOW)
);
test!(
layout([1], [1], [None]).validate_cast((1..100), [usize::MAX], _),
Err(messages::OVERFLOW)
);
test!(
layout([1], [1], [None]).validate_cast(
[usize::MAX / 2 + 1, usize::MAX],
[usize::MAX / 2 + 1, usize::MAX],
_
),
Err(messages::OVERFLOW)
);
}

#[test]
fn test_known_layout() {
// Test that `$ty` and `ManuallyDrop<$ty>` have the expected layout.
Expand Down

0 comments on commit 7adbc94

Please sign in to comment.