Skip to content

Commit

Permalink
aes_gcm: Use overlapping::PartialBlock for opening operations.
Browse files Browse the repository at this point in the history
  • Loading branch information
briansmith committed Jan 2, 2025
1 parent a2a8e9b commit 44467df
Show file tree
Hide file tree
Showing 6 changed files with 98 additions and 39 deletions.
1 change: 1 addition & 0 deletions src/aead/aes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ pub(super) mod hw;
pub(super) mod vp;

pub type Overlapping<'o> = overlapping::Overlapping<'o, u8>;
pub type OverlappingPartialBlock<'o> = overlapping::PartialBlock<'o, u8, BLOCK_LEN>;

cfg_if! {
if #[cfg(any(target_arch = "aarch64", target_arch = "x86_64"))] {
Expand Down
45 changes: 25 additions & 20 deletions src/aead/aes_gcm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,14 @@
// CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.

use super::{
aes::{self, Counter, Overlapping, BLOCK_LEN, ZERO_BLOCK},
aes::{self, Counter, Overlapping, OverlappingPartialBlock, BLOCK_LEN, ZERO_BLOCK},
gcm,
overlapping::SrcIndexError,
shift, Aad, Nonce, Tag,
Aad, Nonce, Tag,
};
use crate::{
cpu, error,
cpu,
error::{self, InputTooLongError},
polyfill::{slice, sliceutil::overwrite_at_start, usize_from_u64_saturated},
};
use core::ops::RangeFrom;
Expand Down Expand Up @@ -343,7 +344,11 @@ pub(super) fn open(
Some(partial) => partial,
None => unreachable!(),
};
open_finish(aes_key, auth, in_out_slice, src, ctr, tag_iv)
let in_out = Overlapping::new(in_out_slice, src)
.unwrap_or_else(|SrcIndexError { .. }| unreachable!());
let in_out = OverlappingPartialBlock::new(in_out)
.unwrap_or_else(|InputTooLongError { .. }| unreachable!());
open_finish(aes_key, auth, in_out, ctr, tag_iv)
}

#[cfg(target_arch = "aarch64")]
Expand Down Expand Up @@ -387,7 +392,11 @@ pub(super) fn open(
}
}
let remainder = &mut in_out_slice[whole_len..];
open_finish(aes_key, auth, remainder, src, ctr, tag_iv)
let remainder = Overlapping::new(remainder, src)
.unwrap_or_else(|SrcIndexError { .. }| unreachable!());
let remainder = OverlappingPartialBlock::new(remainder)
.unwrap_or_else(|InputTooLongError { .. }| unreachable!());
open_finish(aes_key, auth, remainder, ctr, tag_iv)
}

#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
Expand Down Expand Up @@ -466,31 +475,27 @@ fn open_strided<A: aes::EncryptBlock + aes::EncryptCtr32, G: gcm::UpdateBlocks +
}
}

open_finish(
aes_key,
auth,
&mut in_out_slice[whole_len..],
src,
ctr,
tag_iv,
)
let in_out = Overlapping::new(&mut in_out_slice[whole_len..], src)
.unwrap_or_else(|SrcIndexError { .. }| unreachable!());
let in_out = OverlappingPartialBlock::new(in_out)
.unwrap_or_else(|InputTooLongError { .. }| unreachable!());

open_finish(aes_key, auth, in_out, ctr, tag_iv)
}

fn open_finish<A: aes::EncryptBlock, G: gcm::Gmult>(
aes_key: &A,
mut auth: gcm::Context<G>,
remainder: &mut [u8],
src: RangeFrom<usize>,
remainder: OverlappingPartialBlock<'_>,
ctr: Counter,
tag_iv: aes::Iv,
) -> Result<Tag, error::Unspecified> {
shift::shift_partial((src.start, remainder), |remainder| {
if remainder.len() > 0 {
let mut input = ZERO_BLOCK;
overwrite_at_start(&mut input, remainder);
overwrite_at_start(&mut input, remainder.input());
auth.update_block(input);
aes_key.encrypt_iv_xor_block(ctr.into(), input)
});

remainder.overwrite_at_start(aes_key.encrypt_iv_xor_block(ctr.into(), input));
}
Ok(finish(aes_key, auth, tag_iv))
}

Expand Down
8 changes: 8 additions & 0 deletions src/aead/overlapping/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,14 @@ impl<'o, T> Overlapping<'o, T> {
pub fn into_slice_src_mut(self) -> (&'o mut [T], RangeFrom<usize>) {
(self.in_out, self.src)
}

pub(super) fn into_unwritten_output(self) -> &'o mut [T] {
let len = self.len();
self.in_out.get_mut(..len).unwrap_or_else(|| {
// The invariant ensures this succeeds.
unreachable!()
})
}
}

impl<T> Overlapping<'_, T> {
Expand Down
6 changes: 5 additions & 1 deletion src/aead/overlapping/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,10 @@
// OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN
// CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.

pub use self::base::{Overlapping, SrcIndexError};
pub use self::{
base::{Overlapping, SrcIndexError},
partial_block::PartialBlock,
};

mod base;
mod partial_block;
59 changes: 59 additions & 0 deletions src/aead/overlapping/partial_block.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
// Copyright 2024 Brian Smith.
//
// Permission to use, copy, modify, and/or distribute this software for any
// purpose with or without fee is hereby granted, provided that the above
// copyright notice and this permission notice appear in all copies.
//
// THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHORS DISCLAIM ALL WARRANTIES
// WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
// MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHORS BE LIABLE FOR ANY
// SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
// WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION
// OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN
// CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.

use super::Overlapping;
use crate::error::InputTooLongError;

pub struct PartialBlock<'i, T, const BLOCK_LEN: usize> {
// invariant: `self.in_out.len() < BLOCK_LEN`.
in_out: Overlapping<'i, T>,
}

impl<'i, T, const BLOCK_LEN: usize> PartialBlock<'i, T, BLOCK_LEN> {
pub fn new(in_out: Overlapping<'i, T>) -> Result<Self, InputTooLongError> {
let len = in_out.len();
if len >= BLOCK_LEN {
return Err(InputTooLongError::new(len));
}
Ok(Self { in_out })
}

pub fn overwrite_at_start(self, padded: [T; BLOCK_LEN])
where
T: Copy,
{
let len = self.len();
let output = self.in_out.into_unwritten_output();
assert!(output.len() <= padded.len());
output.copy_from_slice(&padded[..len]);
}
}

impl<T, const BLOCK_LEN: usize> PartialBlock<'_, T, BLOCK_LEN> {
#[inline(always)]
pub fn input(&self) -> &[T] {
let r = self.in_out.input();
// Help the optimizer optimize the caller using the invariant.
// TODO: Does this actually help?
if r.len() >= BLOCK_LEN {
unreachable!()
}
r
}

#[inline(always)]
pub fn len(&self) -> usize {
self.input().len()
}
}
18 changes: 0 additions & 18 deletions src/aead/shift.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,6 @@
// OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN
// CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.

use crate::polyfill::sliceutil::overwrite_at_start;

#[cfg(target_arch = "x86")]
pub fn shift_full_blocks<const BLOCK_LEN: usize>(
in_out: super::overlapping::Overlapping<'_, u8>,
Expand All @@ -32,19 +30,3 @@ pub fn shift_full_blocks<const BLOCK_LEN: usize>(
*output = block;
}
}

pub fn shift_partial<const BLOCK_LEN: usize>(
(in_prefix_len, in_out): (usize, &mut [u8]),
transform: impl FnOnce(&[u8]) -> [u8; BLOCK_LEN],
) {
let (block, in_out_len) = {
let input = &in_out[in_prefix_len..];
let in_out_len = input.len();
if in_out_len == 0 {
return;
}
debug_assert!(in_out_len < BLOCK_LEN);
(transform(input), in_out_len)
};
overwrite_at_start(&mut in_out[..in_out_len], &block);
}

0 comments on commit 44467df

Please sign in to comment.