diff --git a/src/aead/aes.rs b/src/aead/aes.rs index 5f0585cb7..826c21a59 100644 --- a/src/aead/aes.rs +++ b/src/aead/aes.rs @@ -33,6 +33,7 @@ pub(super) mod hw; pub(super) mod vp; pub type Overlapping<'o> = overlapping::Overlapping<'o, u8>; +pub type OverlappingPartialBlock<'o> = overlapping::PartialBlock<'o, u8, BLOCK_LEN>; cfg_if! { if #[cfg(any(target_arch = "aarch64", target_arch = "x86_64"))] { diff --git a/src/aead/aes_gcm.rs b/src/aead/aes_gcm.rs index fabaef95f..02c4a34fa 100644 --- a/src/aead/aes_gcm.rs +++ b/src/aead/aes_gcm.rs @@ -13,13 +13,14 @@ // CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. use super::{ - aes::{self, Counter, Overlapping, BLOCK_LEN, ZERO_BLOCK}, + aes::{self, Counter, Overlapping, OverlappingPartialBlock, BLOCK_LEN, ZERO_BLOCK}, gcm, overlapping::SrcIndexError, - shift, Aad, Nonce, Tag, + Aad, Nonce, Tag, }; use crate::{ - cpu, error, + cpu, + error::{self, InputTooLongError}, polyfill::{slice, sliceutil::overwrite_at_start, usize_from_u64_saturated}, }; use core::ops::RangeFrom; @@ -343,7 +344,11 @@ pub(super) fn open( Some(partial) => partial, None => unreachable!(), }; - open_finish(aes_key, auth, in_out_slice, src, ctr, tag_iv) + let in_out = Overlapping::new(in_out_slice, src) + .unwrap_or_else(|SrcIndexError { .. }| unreachable!()); + let in_out = OverlappingPartialBlock::new(in_out) + .unwrap_or_else(|InputTooLongError { .. }| unreachable!()); + open_finish(aes_key, auth, in_out, ctr, tag_iv) } #[cfg(target_arch = "aarch64")] @@ -466,31 +471,27 @@ fn open_strided( aes_key: &A, mut auth: gcm::Context, - remainder: &mut [u8], - src: RangeFrom, + remainder: OverlappingPartialBlock<'_>, ctr: Counter, tag_iv: aes::Iv, ) -> Result { - shift::shift_partial((src.start, remainder), |remainder| { + if remainder.len() > 0 { let mut input = ZERO_BLOCK; - overwrite_at_start(&mut input, remainder); + overwrite_at_start(&mut input, remainder.input()); auth.update_block(input); - aes_key.encrypt_iv_xor_block(ctr.into(), input) - }); - + remainder.overwrite_at_start(aes_key.encrypt_iv_xor_block(ctr.into(), input)); + } Ok(finish(aes_key, auth, tag_iv)) } diff --git a/src/aead/overlapping/base.rs b/src/aead/overlapping/base.rs index f4e217f41..eeb381d25 100644 --- a/src/aead/overlapping/base.rs +++ b/src/aead/overlapping/base.rs @@ -50,6 +50,14 @@ impl<'o, T> Overlapping<'o, T> { pub fn into_slice_src_mut(self) -> (&'o mut [T], RangeFrom) { (self.in_out, self.src) } + + pub(super) fn into_unwritten_output(self) -> &'o mut [T] { + let len = self.len(); + match self.in_out.get_mut(..len) { + Some(unwritten_output) => unwritten_output, + None => unreachable!(), + } + } } impl Overlapping<'_, T> { diff --git a/src/aead/overlapping/mod.rs b/src/aead/overlapping/mod.rs index fb9bfb3d2..17ce533c6 100644 --- a/src/aead/overlapping/mod.rs +++ b/src/aead/overlapping/mod.rs @@ -12,6 +12,10 @@ // OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN // CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. -pub use self::base::{Overlapping, SrcIndexError}; +pub use self::{ + base::{Overlapping, SrcIndexError}, + partial_block::PartialBlock, +}; mod base; +mod partial_block; diff --git a/src/aead/overlapping/partial_block.rs b/src/aead/overlapping/partial_block.rs new file mode 100644 index 000000000..63baa9283 --- /dev/null +++ b/src/aead/overlapping/partial_block.rs @@ -0,0 +1,60 @@ +// Copyright 2024 Brian Smith. +// +// Permission to use, copy, modify, and/or distribute this software for any +// purpose with or without fee is hereby granted, provided that the above +// copyright notice and this permission notice appear in all copies. +// +// THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHORS DISCLAIM ALL WARRANTIES +// WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF +// MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHORS BE LIABLE FOR ANY +// SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES +// WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION +// OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN +// CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. + +use super::Overlapping; +use crate::error::InputTooLongError; + +pub struct PartialBlock<'i, T, const BLOCK_LEN: usize> { + // invariant: `self.in_out.len() < BLOCK_LEN`. + in_out: Overlapping<'i, T>, +} + +impl<'i, T, const BLOCK_LEN: usize> PartialBlock<'i, T, BLOCK_LEN> { + #[inline(always)] + pub fn new(in_out: Overlapping<'i, T>) -> Result { + let len = in_out.len(); + if len >= BLOCK_LEN { + return Err(InputTooLongError::new(len)); + } + Ok(Self { in_out }) + } + + pub fn overwrite_at_start(self, padded: [T; BLOCK_LEN]) + where + T: Copy, + { + let len = self.len(); + let output = self.in_out.into_unwritten_output(); + assert!(output.len() <= padded.len()); + output.copy_from_slice(&padded[..len]); + } +} + +impl PartialBlock<'_, T, BLOCK_LEN> { + #[inline(always)] + pub fn input(&self) -> &[T] { + let r = self.in_out.input(); + // Help the optimizer optimize the caller using the invariant. + // TODO: Does this actually help? + if r.len() >= BLOCK_LEN { + unreachable!() + } + r + } + + #[inline(always)] + pub fn len(&self) -> usize { + self.input().len() + } +} diff --git a/src/aead/shift.rs b/src/aead/shift.rs index d0cc74de4..0e253f43d 100644 --- a/src/aead/shift.rs +++ b/src/aead/shift.rs @@ -12,8 +12,6 @@ // OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN // CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. -use crate::polyfill::sliceutil::overwrite_at_start; - #[cfg(target_arch = "x86")] pub fn shift_full_blocks( in_out: super::overlapping::Overlapping<'_, u8>, @@ -32,19 +30,3 @@ pub fn shift_full_blocks( *output = block; } } - -pub fn shift_partial( - (in_prefix_len, in_out): (usize, &mut [u8]), - transform: impl FnOnce(&[u8]) -> [u8; BLOCK_LEN], -) { - let (block, in_out_len) = { - let input = &in_out[in_prefix_len..]; - let in_out_len = input.len(); - if in_out_len == 0 { - return; - } - debug_assert!(in_out_len < BLOCK_LEN); - (transform(input), in_out_len) - }; - overwrite_at_start(&mut in_out[..in_out_len], &block); -}