diff --git a/src/aead/aes_gcm.rs b/src/aead/aes_gcm.rs index 7c3173ffa..3ec133b15 100644 --- a/src/aead/aes_gcm.rs +++ b/src/aead/aes_gcm.rs @@ -316,34 +316,34 @@ pub(super) fn open( xi, ) }; - let in_out = match in_out_slice.get_mut(processed..) { - Some(remaining) => remaining, - None => { + let in_out_slice = in_out_slice.get_mut(processed..).unwrap_or_else(|| { + // This can't happen. If it did, then the assembly already + // caused a buffer overflow. + unreachable!() + }); + // Authenticate any remaining whole blocks. + let in_out = Overlapping::new(in_out_slice, src.clone()).unwrap_or_else( + |SrcIndexError { .. }| { // This can't happen. If it did, then the assembly already - // caused a buffer overflow. + // overwrote part of the remaining input. unreachable!() - } - }; - // Authenticate any remaining whole blocks. - let input = match in_out.get(src.clone()) { - Some(remaining_input) => remaining_input, - None => unreachable!(), - }; - let (whole, _) = slice::as_chunks(input); + }, + ); + let (whole, _) = slice::as_chunks(in_out.input()); auth.update_blocks(whole); let whole_len = slice::flatten(whole).len(); // Decrypt any remaining whole blocks. - let whole = Overlapping::new(&mut in_out[..(src.start + whole_len)], src.clone()) + let whole = Overlapping::new(&mut in_out_slice[..(src.start + whole_len)], src.clone()) .map_err(error::erase::)?; aes_key.ctr32_encrypt_within(whole, &mut ctr); - let in_out = match in_out.get_mut(whole_len..) { + let in_out_slice = match in_out_slice.get_mut(whole_len..) { Some(partial) => partial, None => unreachable!(), }; - open_finish(aes_key, auth, in_out, src, ctr, tag_iv) + open_finish(aes_key, auth, in_out_slice, src, ctr, tag_iv) } #[cfg(target_arch = "aarch64")] diff --git a/src/aead/overlapping/base.rs b/src/aead/overlapping/base.rs index cd5a3f5dd..17f501f83 100644 --- a/src/aead/overlapping/base.rs +++ b/src/aead/overlapping/base.rs @@ -15,6 +15,7 @@ use core::ops::RangeFrom; pub struct Overlapping<'o, T> { + // Invariant: `assert!(self.in_out.get_mut(self.src.clone()).is_some())`. in_out: &'o mut [T], src: RangeFrom, } @@ -39,8 +40,15 @@ impl<'o, T> Overlapping<'o, T> { impl Overlapping<'_, T> { pub fn len(&self) -> usize { - self.in_out[self.src.clone()].len() + self.input().len() } + + pub fn input(&self) -> &[T] { + self.in_out.get(self.src.clone()).unwrap_or_else(|| { + unreachable!() // Invariant + }) + } + pub fn into_input_output_len(self) -> (*const T, *mut T, usize) { let len = self.len(); let output = self.in_out.as_mut_ptr();