From d8d49d7ea1736ad9b342094719dac4980cdec943 Mon Sep 17 00:00:00 2001 From: TrAyZeN Date: Thu, 5 Dec 2024 13:27:58 +0100 Subject: [PATCH] Allow to pass different trace and plaintext types --- src/distinguishers/cpa.rs | 10 ++++++---- src/distinguishers/cpa_normal.rs | 10 +++++----- src/distinguishers/dpa.rs | 2 +- 3 files changed, 12 insertions(+), 10 deletions(-) diff --git a/src/distinguishers/cpa.rs b/src/distinguishers/cpa.rs index e92bbf4..3e87393 100644 --- a/src/distinguishers/cpa.rs +++ b/src/distinguishers/cpa.rs @@ -15,7 +15,7 @@ use std::{iter::zip, ops::Add}; /// use ndarray::array; /// /// let traces = array![ -/// [77, 137, 51, 91], +/// [77u8, 137, 51, 91], /// [72, 61, 91, 83], /// [39, 49, 52, 23], /// [26, 114, 63, 45], @@ -44,9 +44,9 @@ use std::{iter::zip, ops::Add}; /// # Panics /// - Panic if `traces.shape()[0] != plaintexts.shape()[0]` /// - Panic if `batch_size` is 0. -pub fn cpa( +pub fn cpa( traces: ArrayView2, - plaintexts: ArrayView2, + plaintexts: ArrayView2

, guess_range: usize, target_byte: usize, leakage_func: F, @@ -54,6 +54,7 @@ pub fn cpa( ) -> Cpa where T: Into + Copy + Sync, + P: Into + Copy + Sync, F: Fn(usize, usize) -> usize + Send + Sync + Copy, { assert_eq!(traces.shape()[0], plaintexts.shape()[0]); @@ -171,9 +172,10 @@ where /// # Panics /// Panic in debug if `trace.shape()[0] != self.num_samples`. - pub fn update(&mut self, trace: ArrayView1, plaintext: ArrayView1) + pub fn update(&mut self, trace: ArrayView1, plaintext: ArrayView1

) where T: Into + Copy, + P: Into + Copy, { debug_assert_eq!(trace.shape()[0], self.num_samples); diff --git a/src/distinguishers/cpa_normal.rs b/src/distinguishers/cpa_normal.rs index aa9e37f..efb7b1f 100644 --- a/src/distinguishers/cpa_normal.rs +++ b/src/distinguishers/cpa_normal.rs @@ -42,16 +42,16 @@ use crate::distinguishers::cpa::Cpa; /// # Panics /// - Panic if `traces.shape()[0] != plaintexts.shape()[0]` /// - Panic if `batch_size` is 0. -pub fn cpa( +pub fn cpa( traces: ArrayView2, - plaintexts: ArrayView2, + plaintexts: ArrayView2

, guess_range: usize, leakage_func: F, batch_size: usize, ) -> Cpa where T: Into + Copy + Sync, - U: Into + Copy + Sync, + P: Into + Copy + Sync, F: Fn(ArrayView1, usize) -> usize + Send + Sync + Copy, { assert_eq!(traces.shape()[0], plaintexts.shape()[0]); @@ -127,10 +127,10 @@ where /// # Panics /// - Panic in debug if `trace_batch.shape()[0] != plaintext_batch.shape()[0]`. /// - Panic in debug if `trace_batch.shape()[1] != self.num_samples`. - pub fn update(&mut self, trace_batch: ArrayView2, plaintext_batch: ArrayView2) + pub fn update(&mut self, trace_batch: ArrayView2, plaintext_batch: ArrayView2

) where T: Into + Copy, - U: Into + Copy, + P: Into + Copy, { debug_assert_eq!(trace_batch.shape()[0], plaintext_batch.shape()[0]); debug_assert_eq!(trace_batch.shape()[1], self.num_samples); diff --git a/src/distinguishers/dpa.rs b/src/distinguishers/dpa.rs index 64644ef..98fd422 100644 --- a/src/distinguishers/dpa.rs +++ b/src/distinguishers/dpa.rs @@ -52,7 +52,7 @@ use crate::util::{argmax_by, argsort_by, max_per_row}; /// /// # Panics /// Panic if `batch_size` is not strictly positive. -pub fn dpa( +pub fn dpa( traces: ArrayView2, metadata: ArrayView1, guess_range: usize,