Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ML-KEM: AVX2 target feature edition #636

Merged
merged 18 commits into from
Oct 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 51 additions & 0 deletions libcrux-intrinsics/src/avx2.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,70 +7,83 @@ pub type Vec256 = __m256i;
pub type Vec128 = __m128i;
pub type Vec256Float = __m256;

/// Writes the 32 bytes of `vector` into `output` using the unaligned store
/// intrinsic `_mm256_storeu_si256`.
///
/// `output` is expected to hold exactly 32 bytes; this is verified only by a
/// debug assertion, so builds without debug assertions perform no length check.
#[inline(always)]
pub fn mm256_storeu_si256_u8(output: &mut [u8], vector: Vec256) {
debug_assert_eq!(output.len(), 32);
// SAFETY: the pointer comes from a live mutable slice and the unaligned
// store has no alignment requirement. NOTE(review): the 32-byte bound is
// enforced only by the debug assertion above — confirm callers uphold it.
unsafe {
_mm256_storeu_si256(output.as_mut_ptr() as *mut Vec256, vector);
}
}
/// Writes `vector` into `output` as sixteen `i16` values using the unaligned
/// store intrinsic `_mm256_storeu_si256`.
///
/// `output` is expected to hold exactly 16 elements; this is verified only by
/// a debug assertion.
#[inline(always)]
pub fn mm256_storeu_si256_i16(output: &mut [i16], vector: Vec256) {
debug_assert_eq!(output.len(), 16);
// SAFETY: the pointer comes from a live mutable slice and the unaligned
// store has no alignment requirement. NOTE(review): the 16-element bound is
// enforced only by the debug assertion above — confirm callers uphold it.
unsafe {
_mm256_storeu_si256(output.as_mut_ptr() as *mut Vec256, vector);
}
}
/// Writes `vector` into `output` as eight `i32` values using the unaligned
/// store intrinsic `_mm256_storeu_si256`.
///
/// `output` is expected to hold exactly 8 elements; this is verified only by
/// a debug assertion.
#[inline(always)]
pub fn mm256_storeu_si256_i32(output: &mut [i32], vector: Vec256) {
debug_assert_eq!(output.len(), 8);
// SAFETY: the pointer comes from a live mutable slice and the unaligned
// store has no alignment requirement. NOTE(review): the 8-element bound is
// enforced only by the debug assertion above — confirm callers uphold it.
unsafe {
_mm256_storeu_si256(output.as_mut_ptr() as *mut Vec256, vector);
}
}

/// Writes the 128-bit `vector` to the start of `output` (its first eight
/// `i16` elements) using the unaligned store intrinsic `_mm_storeu_si128`.
///
/// The debug assertion here only requires *at least* 8 elements, so a longer
/// slice is accepted and its remaining elements are left untouched.
#[inline(always)]
pub fn mm_storeu_si128(output: &mut [i16], vector: Vec128) {
debug_assert!(output.len() >= 8);
// SAFETY: the pointer comes from a live mutable slice and the unaligned
// store has no alignment requirement. NOTE(review): the minimum length is
// enforced only by the debug assertion above — confirm callers uphold it.
unsafe {
_mm_storeu_si128(output.as_mut_ptr() as *mut Vec128, vector);
}
}
/// Writes the 128-bit `vector` into `output` as four `i32` values using the
/// unaligned store intrinsic `_mm_storeu_si128`.
///
/// `output` is expected to hold exactly 4 elements; this is verified only by
/// a debug assertion.
#[inline(always)]
pub fn mm_storeu_si128_i32(output: &mut [i32], vector: Vec128) {
debug_assert_eq!(output.len(), 4);
// SAFETY: the pointer comes from a live mutable slice and the unaligned
// store has no alignment requirement. NOTE(review): the 4-element bound is
// enforced only by the debug assertion above — confirm callers uphold it.
unsafe {
_mm_storeu_si128(output.as_mut_ptr() as *mut Vec128, vector);
}
}

/// Writes the 16 bytes of the 128-bit `vector` into `output` using the
/// unaligned store intrinsic `_mm_storeu_si128`.
///
/// `output` is expected to hold exactly 16 bytes; this is verified only by a
/// debug assertion.
#[inline(always)]
pub fn mm_storeu_bytes_si128(output: &mut [u8], vector: Vec128) {
debug_assert_eq!(output.len(), 16);
// SAFETY: the pointer comes from a live mutable slice and the unaligned
// store has no alignment requirement. NOTE(review): the 16-byte bound is
// enforced only by the debug assertion above — confirm callers uphold it.
unsafe {
_mm_storeu_si128(output.as_mut_ptr() as *mut Vec128, vector);
}
}

/// Loads 16 bytes from `input` into a 128-bit vector using the unaligned load
/// intrinsic `_mm_loadu_si128`.
///
/// `input` is expected to hold exactly 16 bytes; this is verified only by a
/// debug assertion.
#[inline(always)]
pub fn mm_loadu_si128(input: &[u8]) -> Vec128 {
debug_assert_eq!(input.len(), 16);
// SAFETY: the pointer comes from a live slice and the unaligned load has no
// alignment requirement. NOTE(review): the 16-byte bound is enforced only by
// the debug assertion above — confirm callers uphold it.
unsafe { _mm_loadu_si128(input.as_ptr() as *const Vec128) }
}

/// Loads 32 bytes from `input` into a 256-bit vector using the unaligned load
/// intrinsic `_mm256_loadu_si256`.
///
/// `input` is expected to hold exactly 32 bytes; this is verified only by a
/// debug assertion.
#[inline(always)]
pub fn mm256_loadu_si256_u8(input: &[u8]) -> Vec256 {
debug_assert_eq!(input.len(), 32);
// SAFETY: the pointer comes from a live slice and the unaligned load has no
// alignment requirement. NOTE(review): the 32-byte bound is enforced only by
// the debug assertion above — confirm callers uphold it.
unsafe { _mm256_loadu_si256(input.as_ptr() as *const Vec256) }
}
/// Loads sixteen `i16` values from `input` into a 256-bit vector using the
/// unaligned load intrinsic `_mm256_loadu_si256`.
///
/// `input` is expected to hold exactly 16 elements; this is verified only by
/// a debug assertion.
#[inline(always)]
pub fn mm256_loadu_si256_i16(input: &[i16]) -> Vec256 {
debug_assert_eq!(input.len(), 16);
// SAFETY: the pointer comes from a live slice and the unaligned load has no
// alignment requirement. NOTE(review): the 16-element bound is enforced only
// by the debug assertion above — confirm callers uphold it.
unsafe { _mm256_loadu_si256(input.as_ptr() as *const Vec256) }
}
/// Loads eight `i32` values from `input` into a 256-bit vector using the
/// unaligned load intrinsic `_mm256_loadu_si256`.
///
/// `input` is expected to hold exactly 8 elements; this is verified only by a
/// debug assertion.
#[inline(always)]
pub fn mm256_loadu_si256_i32(input: &[i32]) -> Vec256 {
debug_assert_eq!(input.len(), 8);
// SAFETY: the pointer comes from a live slice and the unaligned load has no
// alignment requirement. NOTE(review): the 8-element bound is enforced only
// by the debug assertion above — confirm callers uphold it.
unsafe { _mm256_loadu_si256(input.as_ptr() as *const Vec256) }
}

/// Returns a 256-bit vector with every bit cleared, via `_mm256_setzero_si256`.
#[inline(always)]
pub fn mm256_setzero_si256() -> Vec256 {
// SAFETY: takes no inputs and touches no memory. NOTE(review): relies on the
// CPU supporting the intrinsic's SIMD extension — confirm callers guarantee it.
unsafe { _mm256_setzero_si256() }
}
/// Builds a 256-bit vector from two 128-bit halves via `_mm256_set_m128i`:
/// `hi` becomes the upper 128 bits and `lo` the lower 128 bits.
#[inline(always)]
pub fn mm256_set_m128i(hi: Vec128, lo: Vec128) -> Vec256 {
// SAFETY: operates on values only, no pointers involved. NOTE(review): relies
// on the CPU supporting the intrinsic's SIMD extension — confirm with callers.
unsafe { _mm256_set_m128i(hi, lo) }
}

#[inline(always)]
pub fn mm_set_epi8(
byte15: u8,
byte14: u8,
Expand Down Expand Up @@ -111,6 +124,7 @@ pub fn mm_set_epi8(
}
}

#[inline(always)]
pub fn mm256_set_epi8(
byte31: i8,
byte30: i8,
Expand Down Expand Up @@ -154,9 +168,11 @@ pub fn mm256_set_epi8(
}
}

/// Broadcasts `constant` into every 16-bit lane of a 256-bit vector, via
/// `_mm256_set1_epi16`.
#[inline(always)]
pub fn mm256_set1_epi16(constant: i16) -> Vec256 {
// SAFETY: operates on values only, no pointers involved. NOTE(review): relies
// on the CPU supporting the intrinsic's SIMD extension — confirm with callers.
unsafe { _mm256_set1_epi16(constant) }
}
#[inline(always)]
pub fn mm256_set_epi16(
input15: i16,
input14: i16,
Expand Down Expand Up @@ -242,21 +258,26 @@ pub fn mm256_abs_epi32(a: Vec256) -> Vec256 {
unsafe { _mm256_abs_epi32(a) }
}

/// Lane-wise subtraction `lhs - rhs` over the 16-bit lanes of two 256-bit
/// vectors, via `_mm256_sub_epi16`.
#[inline(always)]
pub fn mm256_sub_epi16(lhs: Vec256, rhs: Vec256) -> Vec256 {
// SAFETY: operates on values only, no pointers involved. NOTE(review): relies
// on the CPU supporting the intrinsic's SIMD extension — confirm with callers.
unsafe { _mm256_sub_epi16(lhs, rhs) }
}
/// Lane-wise subtraction `lhs - rhs` over the 32-bit lanes of two 256-bit
/// vectors, via `_mm256_sub_epi32`.
#[inline(always)]
pub fn mm256_sub_epi32(lhs: Vec256, rhs: Vec256) -> Vec256 {
// SAFETY: operates on values only, no pointers involved. NOTE(review): relies
// on the CPU supporting the intrinsic's SIMD extension — confirm with callers.
unsafe { _mm256_sub_epi32(lhs, rhs) }
}

/// Lane-wise subtraction `lhs - rhs` over the 16-bit lanes of two 128-bit
/// vectors, via `_mm_sub_epi16`.
#[inline(always)]
pub fn mm_sub_epi16(lhs: Vec128, rhs: Vec128) -> Vec128 {
// SAFETY: operates on values only, no pointers involved. NOTE(review): relies
// on the CPU supporting the intrinsic's SIMD extension — confirm with callers.
unsafe { _mm_sub_epi16(lhs, rhs) }
}

/// Multiplies the 16-bit lanes of `lhs` and `rhs` and keeps the low 16 bits of
/// each product, via `_mm256_mullo_epi16`.
#[inline(always)]
pub fn mm256_mullo_epi16(lhs: Vec256, rhs: Vec256) -> Vec256 {
// SAFETY: operates on values only, no pointers involved. NOTE(review): relies
// on the CPU supporting the intrinsic's SIMD extension — confirm with callers.
unsafe { _mm256_mullo_epi16(lhs, rhs) }
}

/// Multiplies the 16-bit lanes of two 128-bit vectors and keeps the low 16
/// bits of each product, via `_mm_mullo_epi16`.
#[inline(always)]
pub fn mm_mullo_epi16(lhs: Vec128, rhs: Vec128) -> Vec128 {
// SAFETY: operates on values only, no pointers involved. NOTE(review): relies
// on the CPU supporting the intrinsic's SIMD extension — confirm with callers.
unsafe { _mm_mullo_epi16(lhs, rhs) }
}
Expand Down Expand Up @@ -289,18 +310,22 @@ pub fn mm256_movemask_ps(a: Vec256Float) -> i32 {
unsafe { _mm256_movemask_ps(a) }
}

/// Multiplies the signed 16-bit lanes of two 128-bit vectors and keeps the
/// high 16 bits of each 32-bit product, via `_mm_mulhi_epi16`.
#[inline(always)]
pub fn mm_mulhi_epi16(lhs: Vec128, rhs: Vec128) -> Vec128 {
// SAFETY: operates on values only, no pointers involved. NOTE(review): relies
// on the CPU supporting the intrinsic's SIMD extension — confirm with callers.
unsafe { _mm_mulhi_epi16(lhs, rhs) }
}

/// Multiplies the 32-bit lanes of `lhs` and `rhs` and keeps the low 32 bits of
/// each product, via `_mm256_mullo_epi32`.
#[inline(always)]
pub fn mm256_mullo_epi32(lhs: Vec256, rhs: Vec256) -> Vec256 {
// SAFETY: operates on values only, no pointers involved. NOTE(review): relies
// on the CPU supporting the intrinsic's SIMD extension — confirm with callers.
unsafe { _mm256_mullo_epi32(lhs, rhs) }
}

/// Multiplies the signed 16-bit lanes of `lhs` and `rhs` and keeps the high 16
/// bits of each 32-bit product, via `_mm256_mulhi_epi16`.
#[inline(always)]
pub fn mm256_mulhi_epi16(lhs: Vec256, rhs: Vec256) -> Vec256 {
// SAFETY: operates on values only, no pointers involved. NOTE(review): relies
// on the CPU supporting the intrinsic's SIMD extension — confirm with callers.
unsafe { _mm256_mulhi_epi16(lhs, rhs) }
}

/// Multiplies the low unsigned 32-bit half of each 64-bit lane of `lhs` and
/// `rhs`, producing unsigned 64-bit products, via `_mm256_mul_epu32`.
#[inline(always)]
pub fn mm256_mul_epu32(lhs: Vec256, rhs: Vec256) -> Vec256 {
// SAFETY: operates on values only, no pointers involved. NOTE(review): relies
// on the CPU supporting the intrinsic's SIMD extension — confirm with callers.
unsafe { _mm256_mul_epu32(lhs, rhs) }
}
Expand All @@ -320,102 +345,126 @@ pub fn mm256_or_si256(a: Vec256, b: Vec256) -> Vec256 {
unsafe { _mm256_or_si256(a, b) }
}

/// Tests whether `lhs & rhs` is all zeros, via `_mm256_testz_si256`.
///
/// Returns `1` when the bitwise AND of the two vectors has no bit set, and `0`
/// otherwise.
#[inline(always)]
pub fn mm256_testz_si256(lhs: Vec256, rhs: Vec256) -> i32 {
// SAFETY: operates on values only, no pointers involved. NOTE(review): relies
// on the CPU supporting the intrinsic's SIMD extension — confirm with callers.
unsafe { _mm256_testz_si256(lhs, rhs) }
}

/// Bitwise XOR of two 256-bit vectors, via `_mm256_xor_si256`.
#[inline(always)]
pub fn mm256_xor_si256(lhs: Vec256, rhs: Vec256) -> Vec256 {
// SAFETY: operates on values only, no pointers involved. NOTE(review): relies
// on the CPU supporting the intrinsic's SIMD extension — confirm with callers.
unsafe { _mm256_xor_si256(lhs, rhs) }
}

/// Arithmetic (sign-filling) right shift of every 16-bit lane of `vector` by
/// the compile-time constant `SHIFT_BY`, via `_mm256_srai_epi16`.
///
/// `SHIFT_BY` is expected to lie in `0..16`; this is checked only by a debug
/// assertion.
#[inline(always)]
pub fn mm256_srai_epi16<const SHIFT_BY: i32>(vector: Vec256) -> Vec256 {
debug_assert!(SHIFT_BY >= 0 && SHIFT_BY < 16);
// SAFETY: operates on values only, no pointers involved. NOTE(review): relies
// on the CPU supporting the intrinsic's SIMD extension — confirm with callers.
unsafe { _mm256_srai_epi16(vector, SHIFT_BY) }
}
/// Arithmetic (sign-filling) right shift of every 32-bit lane of `vector` by
/// the compile-time constant `SHIFT_BY`, via `_mm256_srai_epi32`.
///
/// `SHIFT_BY` is expected to lie in `0..32`; this is checked only by a debug
/// assertion.
#[inline(always)]
pub fn mm256_srai_epi32<const SHIFT_BY: i32>(vector: Vec256) -> Vec256 {
debug_assert!(SHIFT_BY >= 0 && SHIFT_BY < 32);
// SAFETY: operates on values only, no pointers involved. NOTE(review): relies
// on the CPU supporting the intrinsic's SIMD extension — confirm with callers.
unsafe { _mm256_srai_epi32(vector, SHIFT_BY) }
}

/// Logical (zero-filling) right shift of every 16-bit lane of `vector` by the
/// compile-time constant `SHIFT_BY`, via `_mm256_srli_epi16`.
///
/// `SHIFT_BY` is expected to lie in `0..16`; this is checked only by a debug
/// assertion.
#[inline(always)]
pub fn mm256_srli_epi16<const SHIFT_BY: i32>(vector: Vec256) -> Vec256 {
debug_assert!(SHIFT_BY >= 0 && SHIFT_BY < 16);
// SAFETY: operates on values only, no pointers involved. NOTE(review): relies
// on the CPU supporting the intrinsic's SIMD extension — confirm with callers.
unsafe { _mm256_srli_epi16(vector, SHIFT_BY) }
}
/// Logical (zero-filling) right shift of every 32-bit lane of `vector` by the
/// compile-time constant `SHIFT_BY`, via `_mm256_srli_epi32`.
///
/// `SHIFT_BY` is expected to lie in `0..32`; this is checked only by a debug
/// assertion.
#[inline(always)]
pub fn mm256_srli_epi32<const SHIFT_BY: i32>(vector: Vec256) -> Vec256 {
debug_assert!(SHIFT_BY >= 0 && SHIFT_BY < 32);
// SAFETY: operates on values only, no pointers involved. NOTE(review): relies
// on the CPU supporting the intrinsic's SIMD extension — confirm with callers.
unsafe { _mm256_srli_epi32(vector, SHIFT_BY) }
}

/// Logical (zero-filling) right shift of both 64-bit lanes of the 128-bit
/// `vector` by the compile-time constant `SHIFT_BY`, via `_mm_srli_epi64`.
///
/// `SHIFT_BY` is expected to lie in `0..64`; this is checked only by a debug
/// assertion.
#[inline(always)]
pub fn mm_srli_epi64<const SHIFT_BY: i32>(vector: Vec128) -> Vec128 {
debug_assert!(SHIFT_BY >= 0 && SHIFT_BY < 64);
// SAFETY: operates on values only, no pointers involved. NOTE(review): relies
// on the CPU supporting the intrinsic's SIMD extension — confirm with callers.
unsafe { _mm_srli_epi64(vector, SHIFT_BY) }
}
/// Logical (zero-filling) right shift of every 64-bit lane of `vector` by the
/// compile-time constant `SHIFT_BY`, via `_mm256_srli_epi64`.
///
/// `SHIFT_BY` is expected to lie in `0..64`; this is checked only by a debug
/// assertion.
#[inline(always)]
pub fn mm256_srli_epi64<const SHIFT_BY: i32>(vector: Vec256) -> Vec256 {
debug_assert!(SHIFT_BY >= 0 && SHIFT_BY < 64);
// SAFETY: operates on values only, no pointers involved. NOTE(review): relies
// on the CPU supporting the intrinsic's SIMD extension — confirm with callers.
unsafe { _mm256_srli_epi64(vector, SHIFT_BY) }
}

/// Left shift of every 16-bit lane of `vector` by the compile-time constant
/// `SHIFT_BY`, shifting in zeros, via `_mm256_slli_epi16`.
///
/// `SHIFT_BY` is expected to lie in `0..16`; this is checked only by a debug
/// assertion.
#[inline(always)]
pub fn mm256_slli_epi16<const SHIFT_BY: i32>(vector: Vec256) -> Vec256 {
debug_assert!(SHIFT_BY >= 0 && SHIFT_BY < 16);
// SAFETY: operates on values only, no pointers involved. NOTE(review): relies
// on the CPU supporting the intrinsic's SIMD extension — confirm with callers.
unsafe { _mm256_slli_epi16(vector, SHIFT_BY) }
}

/// Left shift of every 32-bit lane of `vector` by the compile-time constant
/// `SHIFT_BY`, shifting in zeros, via `_mm256_slli_epi32`.
///
/// `SHIFT_BY` is expected to lie in `0..32`; this is checked only by a debug
/// assertion.
#[inline(always)]
pub fn mm256_slli_epi32<const SHIFT_BY: i32>(vector: Vec256) -> Vec256 {
debug_assert!(SHIFT_BY >= 0 && SHIFT_BY < 32);
// SAFETY: operates on values only, no pointers involved. NOTE(review): relies
// on the CPU supporting the intrinsic's SIMD extension — confirm with callers.
unsafe { _mm256_slli_epi32(vector, SHIFT_BY) }
}

/// Rearranges the bytes of the 128-bit `vector` according to the per-byte
/// selectors in `control`, via `_mm_shuffle_epi8`.
#[inline(always)]
pub fn mm_shuffle_epi8(vector: Vec128, control: Vec128) -> Vec128 {
// SAFETY: operates on values only, no pointers involved. NOTE(review): relies
// on the CPU supporting the intrinsic's SIMD extension — confirm with callers.
unsafe { _mm_shuffle_epi8(vector, control) }
}
/// Rearranges the bytes of `vector` according to the per-byte selectors in
/// `control`, via `_mm256_shuffle_epi8`.
///
/// The underlying instruction shuffles within each 128-bit half
/// independently; bytes never cross between the two halves.
#[inline(always)]
pub fn mm256_shuffle_epi8(vector: Vec256, control: Vec256) -> Vec256 {
// SAFETY: operates on values only, no pointers involved. NOTE(review): relies
// on the CPU supporting the intrinsic's SIMD extension — confirm with callers.
unsafe { _mm256_shuffle_epi8(vector, control) }
}
/// Rearranges the 32-bit elements within each 128-bit half of `vector` as
/// selected by the compile-time immediate `CONTROL`, via `_mm256_shuffle_epi32`.
///
/// `CONTROL` is expected to fit in 8 bits (`0..256`); this is checked only by
/// a debug assertion.
#[inline(always)]
pub fn mm256_shuffle_epi32<const CONTROL: i32>(vector: Vec256) -> Vec256 {
debug_assert!(CONTROL >= 0 && CONTROL < 256);
// SAFETY: operates on values only, no pointers involved. NOTE(review): relies
// on the CPU supporting the intrinsic's SIMD extension — confirm with callers.
unsafe { _mm256_shuffle_epi32(vector, CONTROL) }
}

/// Permutes the four 64-bit elements of `vector` as selected by the
/// compile-time immediate `CONTROL`, via `_mm256_permute4x64_epi64`.
///
/// `CONTROL` is expected to fit in 8 bits (`0..256`); this is checked only by
/// a debug assertion.
#[inline(always)]
pub fn mm256_permute4x64_epi64<const CONTROL: i32>(vector: Vec256) -> Vec256 {
debug_assert!(CONTROL >= 0 && CONTROL < 256);
// SAFETY: operates on values only, no pointers involved. NOTE(review): relies
// on the CPU supporting the intrinsic's SIMD extension — confirm with callers.
unsafe { _mm256_permute4x64_epi64(vector, CONTROL) }
}

/// Interleaves the high 64-bit element of each 128-bit half of `lhs` and
/// `rhs`, via `_mm256_unpackhi_epi64`.
#[inline(always)]
pub fn mm256_unpackhi_epi64(lhs: Vec256, rhs: Vec256) -> Vec256 {
// SAFETY: operates on values only, no pointers involved. NOTE(review): relies
// on the CPU supporting the intrinsic's SIMD extension — confirm with callers.
unsafe { _mm256_unpackhi_epi64(lhs, rhs) }
}

/// Interleaves the low 32-bit elements of each 128-bit half of `lhs` and
/// `rhs`, via `_mm256_unpacklo_epi32`.
#[inline(always)]
pub fn mm256_unpacklo_epi32(lhs: Vec256, rhs: Vec256) -> Vec256 {
// SAFETY: operates on values only, no pointers involved. NOTE(review): relies
// on the CPU supporting the intrinsic's SIMD extension — confirm with callers.
unsafe { _mm256_unpacklo_epi32(lhs, rhs) }
}

/// Interleaves the high 32-bit elements of each 128-bit half of `lhs` and
/// `rhs`, via `_mm256_unpackhi_epi32`.
#[inline(always)]
pub fn mm256_unpackhi_epi32(lhs: Vec256, rhs: Vec256) -> Vec256 {
// SAFETY: operates on values only, no pointers involved. NOTE(review): relies
// on the CPU supporting the intrinsic's SIMD extension — confirm with callers.
unsafe { _mm256_unpackhi_epi32(lhs, rhs) }
}

/// Reinterprets a 256-bit vector as a 128-bit one, yielding its lower 128
/// bits, via `_mm256_castsi256_si128`.
#[inline(always)]
pub fn mm256_castsi256_si128(vector: Vec256) -> Vec128 {
// SAFETY: operates on values only, no pointers involved. NOTE(review): relies
// on the CPU supporting the intrinsic's SIMD extension — confirm with callers.
unsafe { _mm256_castsi256_si128(vector) }
}
/// Reinterprets a 128-bit vector as a 256-bit one via `_mm256_castsi128_si256`;
/// `vector` occupies the lower 128 bits of the result.
///
/// Per the intrinsic's documentation the upper 128 bits are unspecified, so
/// callers should not depend on their contents.
#[inline(always)]
pub fn mm256_castsi128_si256(vector: Vec128) -> Vec256 {
// SAFETY: operates on values only, no pointers involved. NOTE(review): relies
// on the CPU supporting the intrinsic's SIMD extension — confirm with callers.
unsafe { _mm256_castsi128_si256(vector) }
}

/// Sign-extends the eight 16-bit lanes of the 128-bit `vector` into eight
/// 32-bit lanes of a 256-bit vector, via `_mm256_cvtepi16_epi32`.
#[inline(always)]
pub fn mm256_cvtepi16_epi32(vector: Vec128) -> Vec256 {
// SAFETY: operates on values only, no pointers involved. NOTE(review): relies
// on the CPU supporting the intrinsic's SIMD extension — confirm with callers.
unsafe { _mm256_cvtepi16_epi32(vector) }
}

/// Narrows the signed 16-bit lanes of `lhs` and `rhs` to signed 8-bit lanes
/// with saturation and packs them into one 128-bit vector, via
/// `_mm_packs_epi16`.
#[inline(always)]
pub fn mm_packs_epi16(lhs: Vec128, rhs: Vec128) -> Vec128 {
// SAFETY: operates on values only, no pointers involved. NOTE(review): relies
// on the CPU supporting the intrinsic's SIMD extension — confirm with callers.
unsafe { _mm_packs_epi16(lhs, rhs) }
}
/// Narrows the signed 32-bit lanes of `lhs` and `rhs` to signed 16-bit lanes
/// with saturation and packs them into one 256-bit vector, via
/// `_mm256_packs_epi32`.
///
/// The underlying instruction packs each 128-bit half independently.
#[inline(always)]
pub fn mm256_packs_epi32(lhs: Vec256, rhs: Vec256) -> Vec256 {
// SAFETY: operates on values only, no pointers involved. NOTE(review): relies
// on the CPU supporting the intrinsic's SIMD extension — confirm with callers.
unsafe { _mm256_packs_epi32(lhs, rhs) }
}

/// Extracts one 128-bit half of `vector`, chosen by the compile-time
/// immediate `CONTROL`, via `_mm256_extracti128_si256`.
///
/// `CONTROL` is expected to be `0` (lower half) or `1` (upper half); this is
/// checked only by a debug assertion.
#[inline(always)]
pub fn mm256_extracti128_si256<const CONTROL: i32>(vector: Vec256) -> Vec128 {
debug_assert!(CONTROL == 0 || CONTROL == 1);
// SAFETY: operates on values only, no pointers involved. NOTE(review): relies
// on the CPU supporting the intrinsic's SIMD extension — confirm with callers.
unsafe { _mm256_extracti128_si256(vector, CONTROL) }
}

#[inline(always)]
pub fn mm256_inserti128_si256<const CONTROL: i32>(vector: Vec256, vector_i128: Vec128) -> Vec256 {
debug_assert!(CONTROL == 0 || CONTROL == 1);
unsafe { _mm256_inserti128_si256(vector, vector_i128, CONTROL) }
Expand Down Expand Up @@ -465,9 +514,11 @@ pub fn mm256_srlv_epi64(vector: Vec256, counts: Vec256) -> Vec256 {
unsafe { _mm256_srlv_epi64(vector, counts) }
}

/// Left-shifts each 32-bit lane of the 128-bit `vector` by the amount held in
/// the corresponding lane of `counts`, via `_mm_sllv_epi32`.
#[inline(always)]
pub fn mm_sllv_epi32(vector: Vec128, counts: Vec128) -> Vec128 {
// SAFETY: operates on values only, no pointers involved. NOTE(review): relies
// on the CPU supporting the intrinsic's SIMD extension — confirm with callers.
unsafe { _mm_sllv_epi32(vector, counts) }
}
/// Left-shifts each 32-bit lane of `vector` by the amount held in the
/// corresponding lane of `counts`, via `_mm256_sllv_epi32`.
#[inline(always)]
pub fn mm256_sllv_epi32(vector: Vec256, counts: Vec256) -> Vec256 {
// SAFETY: operates on values only, no pointers involved. NOTE(review): relies
// on the CPU supporting the intrinsic's SIMD extension — confirm with callers.
unsafe { _mm256_sllv_epi32(vector, counts) }
}
Expand Down
8 changes: 4 additions & 4 deletions libcrux-ml-kem/c/code_gen.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
This code was generated with the following revisions:
Charon: 45f5a34f336e35c6cc2253bc90cbdb8d812cefa9
Charon: 3a133fe0eee9bd3928d5bb16c24ddd2dd0f3ee7f
Eurydice: 1fff1c51ae6e6c87eafd28ec9d5594f54bc91c0c
Karamel: 8c3612018c25889288da6857771be3ad03b75bcd
F*: 5643e656b989aca7629723653a2570c7df6252b9-dirty
Libcrux: 2e8f138dbcbfbfabf4bbd994c8587ec00d197102
Karamel: c31a22c1e07d2118c07ee5cebb640d863e31a198
F*: 2c32d6e230851bbceadac7a21fc418fa2bb7e4bc
Libcrux: 99b4e0ae6147eb731652e0ee355fc77d2c160664
8 changes: 4 additions & 4 deletions libcrux-ml-kem/c/internal/libcrux_core.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,11 @@
* SPDX-License-Identifier: MIT or Apache-2.0
*
* This code was generated with the following revisions:
* Charon: 45f5a34f336e35c6cc2253bc90cbdb8d812cefa9
* Charon: 3a133fe0eee9bd3928d5bb16c24ddd2dd0f3ee7f
* Eurydice: 1fff1c51ae6e6c87eafd28ec9d5594f54bc91c0c
* Karamel: 8c3612018c25889288da6857771be3ad03b75bcd
* F*: 5643e656b989aca7629723653a2570c7df6252b9-dirty
* Libcrux: 2e8f138dbcbfbfabf4bbd994c8587ec00d197102
* Karamel: c31a22c1e07d2118c07ee5cebb640d863e31a198
* F*: 2c32d6e230851bbceadac7a21fc418fa2bb7e4bc
* Libcrux: 99b4e0ae6147eb731652e0ee355fc77d2c160664
*/

#ifndef __internal_libcrux_core_H
Expand Down
8 changes: 4 additions & 4 deletions libcrux-ml-kem/c/internal/libcrux_mlkem_avx2.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,11 @@
* SPDX-License-Identifier: MIT or Apache-2.0
*
* This code was generated with the following revisions:
* Charon: 45f5a34f336e35c6cc2253bc90cbdb8d812cefa9
* Charon: 3a133fe0eee9bd3928d5bb16c24ddd2dd0f3ee7f
* Eurydice: 1fff1c51ae6e6c87eafd28ec9d5594f54bc91c0c
* Karamel: 8c3612018c25889288da6857771be3ad03b75bcd
* F*: 5643e656b989aca7629723653a2570c7df6252b9-dirty
* Libcrux: 2e8f138dbcbfbfabf4bbd994c8587ec00d197102
* Karamel: c31a22c1e07d2118c07ee5cebb640d863e31a198
* F*: 2c32d6e230851bbceadac7a21fc418fa2bb7e4bc
* Libcrux: 99b4e0ae6147eb731652e0ee355fc77d2c160664
*/

#ifndef __internal_libcrux_mlkem_avx2_H
Expand Down
8 changes: 4 additions & 4 deletions libcrux-ml-kem/c/internal/libcrux_mlkem_portable.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,11 @@
* SPDX-License-Identifier: MIT or Apache-2.0
*
* This code was generated with the following revisions:
* Charon: 45f5a34f336e35c6cc2253bc90cbdb8d812cefa9
* Charon: 3a133fe0eee9bd3928d5bb16c24ddd2dd0f3ee7f
* Eurydice: 1fff1c51ae6e6c87eafd28ec9d5594f54bc91c0c
* Karamel: 8c3612018c25889288da6857771be3ad03b75bcd
* F*: 5643e656b989aca7629723653a2570c7df6252b9-dirty
* Libcrux: 2e8f138dbcbfbfabf4bbd994c8587ec00d197102
* Karamel: c31a22c1e07d2118c07ee5cebb640d863e31a198
* F*: 2c32d6e230851bbceadac7a21fc418fa2bb7e4bc
* Libcrux: 99b4e0ae6147eb731652e0ee355fc77d2c160664
*/

#ifndef __internal_libcrux_mlkem_portable_H
Expand Down
8 changes: 4 additions & 4 deletions libcrux-ml-kem/c/internal/libcrux_sha3_avx2.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,11 @@
* SPDX-License-Identifier: MIT or Apache-2.0
*
* This code was generated with the following revisions:
* Charon: 45f5a34f336e35c6cc2253bc90cbdb8d812cefa9
* Charon: 3a133fe0eee9bd3928d5bb16c24ddd2dd0f3ee7f
* Eurydice: 1fff1c51ae6e6c87eafd28ec9d5594f54bc91c0c
* Karamel: 8c3612018c25889288da6857771be3ad03b75bcd
* F*: 5643e656b989aca7629723653a2570c7df6252b9-dirty
* Libcrux: 2e8f138dbcbfbfabf4bbd994c8587ec00d197102
* Karamel: c31a22c1e07d2118c07ee5cebb640d863e31a198
* F*: 2c32d6e230851bbceadac7a21fc418fa2bb7e4bc
* Libcrux: 99b4e0ae6147eb731652e0ee355fc77d2c160664
*/

#ifndef __internal_libcrux_sha3_avx2_H
Expand Down
8 changes: 4 additions & 4 deletions libcrux-ml-kem/c/internal/libcrux_sha3_internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,11 @@
* SPDX-License-Identifier: MIT or Apache-2.0
*
* This code was generated with the following revisions:
* Charon: 45f5a34f336e35c6cc2253bc90cbdb8d812cefa9
* Charon: 3a133fe0eee9bd3928d5bb16c24ddd2dd0f3ee7f
* Eurydice: 1fff1c51ae6e6c87eafd28ec9d5594f54bc91c0c
* Karamel: 8c3612018c25889288da6857771be3ad03b75bcd
* F*: 5643e656b989aca7629723653a2570c7df6252b9-dirty
* Libcrux: 2e8f138dbcbfbfabf4bbd994c8587ec00d197102
* Karamel: c31a22c1e07d2118c07ee5cebb640d863e31a198
* F*: 2c32d6e230851bbceadac7a21fc418fa2bb7e4bc
* Libcrux: 99b4e0ae6147eb731652e0ee355fc77d2c160664
*/

#ifndef __internal_libcrux_sha3_internal_H
Expand Down
8 changes: 4 additions & 4 deletions libcrux-ml-kem/c/libcrux_core.c
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,11 @@
* SPDX-License-Identifier: MIT or Apache-2.0
*
* This code was generated with the following revisions:
* Charon: 45f5a34f336e35c6cc2253bc90cbdb8d812cefa9
* Charon: 3a133fe0eee9bd3928d5bb16c24ddd2dd0f3ee7f
* Eurydice: 1fff1c51ae6e6c87eafd28ec9d5594f54bc91c0c
* Karamel: 8c3612018c25889288da6857771be3ad03b75bcd
* F*: 5643e656b989aca7629723653a2570c7df6252b9-dirty
* Libcrux: 2e8f138dbcbfbfabf4bbd994c8587ec00d197102
* Karamel: c31a22c1e07d2118c07ee5cebb640d863e31a198
* F*: 2c32d6e230851bbceadac7a21fc418fa2bb7e4bc
* Libcrux: 99b4e0ae6147eb731652e0ee355fc77d2c160664
*/

#include "internal/libcrux_core.h"
Expand Down
Loading
Loading