diff --git a/Cargo.toml b/Cargo.toml
index 4e78d72..b052b80 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -43,7 +43,6 @@ path = "src/utils.rs"
 default = []
 # default = [ "grinding" ]
 grinding = []
-avx256 = ["arith/avx256"]
 
 [workspace]
 members = ["arith", "bi-kzg"]
diff --git a/arith/Cargo.toml b/arith/Cargo.toml
index d05aaed..f31a0bb 100644
--- a/arith/Cargo.toml
+++ b/arith/Cargo.toml
@@ -27,5 +27,3 @@ harness = false
 name = "ext_field"
 harness = false
 
-[features]
-avx256 = []
diff --git a/arith/src/extension_field.rs b/arith/src/extension_field.rs
index e0f3547..36a9c30 100644
--- a/arith/src/extension_field.rs
+++ b/arith/src/extension_field.rs
@@ -1,16 +1,12 @@
 mod fr_ext;
-// mod gf2_127;
 mod gf2_128;
 mod gf2_128x8;
 mod m31_ext;
 mod m31_ext3x16;
 
 use crate::{Field, FieldSerde};
-// pub use gf2_127::*;
 pub use gf2_128::*;
 pub use gf2_128x8::GF2_128x8;
-#[cfg(target_arch = "x86_64")]
-pub use gf2_128x8::GF2_128x8_256;
 pub use m31_ext::M31Ext3;
 pub use m31_ext3x16::M31Ext3x16;
 
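With the `avx256` cargo feature removed, the 256-bit versus 512-bit choice is made entirely by the compile target rather than by a feature flag. A minimal sketch of the pattern this diff adopts and of how a build opts in (the `Backend` names are hypothetical stand-ins, not types from this repo):

```rust
// Default x86_64 builds get the 256-bit backends:
//     cargo build --release
// The AVX-512 backends are selected by enabling the target feature, e.g.:
//     RUSTFLAGS="-C target-feature=+avx512f" cargo build --release
// or RUSTFLAGS="-C target-cpu=native" on AVX-512 hardware.

// Hypothetical stand-ins for the real SIMD implementations.
pub struct Avx256Impl;
pub struct Avx512Impl;

// The cfg pattern the diff applies to GF2_128, GF2_128x8, and M31x16:
#[cfg(all(target_arch = "x86_64", target_feature = "avx512f"))]
pub type Backend = Avx512Impl;
#[cfg(all(target_arch = "x86_64", not(target_feature = "avx512f")))]
pub type Backend = Avx256Impl;
```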
diff --git a/arith/src/extension_field/gf2_127.rs b/arith/src/extension_field/gf2_127.rs
deleted file mode 100644
index 306fa59..0000000
--- a/arith/src/extension_field/gf2_127.rs
+++ /dev/null
@@ -1,9 +0,0 @@
-#[cfg(target_arch = "aarch64")]
-pub(crate) mod neon;
-#[cfg(target_arch = "aarch64")]
-pub type GF2_127 = neon::NeonGF2_127;
-
-#[cfg(target_arch = "x86_64")]
-mod avx;
-#[cfg(target_arch = "x86_64")]
-pub type GF2_127 = avx::AVX512GF2_127;
diff --git a/arith/src/extension_field/gf2_127/avx.rs b/arith/src/extension_field/gf2_127/avx.rs
deleted file mode 100644
index 7548045..0000000
--- a/arith/src/extension_field/gf2_127/avx.rs
+++ /dev/null
@@ -1,367 +0,0 @@
-use std::iter::{Product, Sum};
-use std::{
-    arch::x86_64::*,
-    mem::transmute,
-    ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign},
-};
-
-use crate::{field_common, ExtensionField, Field, FieldSerde, FieldSerdeResult, GF2};
-
-#[derive(Debug, Clone, Copy)]
-pub struct AVX512GF2_127 {
-    pub v: __m128i,
-}
-
-field_common!(AVX512GF2_127);
-
-impl FieldSerde for AVX512GF2_127 {
-    const SERIALIZED_SIZE: usize = 16;
-
-    #[inline(always)]
-    fn serialize_into<W: std::io::Write>(&self, mut writer: W) -> FieldSerdeResult<()> {
-        unsafe { writer.write_all(transmute::<__m128i, [u8; 16]>(self.v).as_ref())? };
-        Ok(())
-    }
-
-    #[inline(always)]
-    fn deserialize_from<R: std::io::Read>(mut reader: R) -> FieldSerdeResult<Self> {
-        let mut u = [0u8; Self::SERIALIZED_SIZE];
-        reader.read_exact(&mut u)?;
-        u[Self::SERIALIZED_SIZE - 1] &= 0x7F; // Should we do a modular operation here?
-
-        unsafe {
-            Ok(AVX512GF2_127 {
-                v: transmute::<[u8; Self::SERIALIZED_SIZE], __m128i>(u),
-            })
-        }
-    }
-
-    #[inline(always)]
-    fn try_deserialize_from_ecc_format<R: std::io::Read>(mut reader: R) -> FieldSerdeResult<Self> {
-        let mut u = [0u8; 32];
-        reader.read_exact(&mut u)?;
-        assert!(u[15] <= 0x7F); // and ignoring 16 - 31
-        Ok(unsafe {
-            AVX512GF2_127 {
-                v: transmute::<[u8; 16], __m128i>(u[..16].try_into().unwrap()),
-            }
-        })
-    }
-}
-
-// mod x^127 + x + 1
-impl Field for AVX512GF2_127 {
-    const NAME: &'static str = "Galios Field 2^127";
-
-    const SIZE: usize = 128 / 8;
-
-    const FIELD_SIZE: usize = 127; // in bits
-
-    const ZERO: Self = AVX512GF2_127 {
-        v: unsafe { std::mem::zeroed() },
-    };
-
-    const ONE: Self = AVX512GF2_127 {
-        v: unsafe { std::mem::transmute::<[i32; 4], __m128i>([1, 0, 0, 0]) },
-    };
-
-    const INV_2: Self = AVX512GF2_127 {
-        v: unsafe { std::mem::zeroed() },
-    }; // should not be used
-
-    #[inline(always)]
-    fn zero() -> Self {
-        AVX512GF2_127 {
-            v: unsafe { std::mem::zeroed() },
-        }
-    }
-
-    #[inline(always)]
-    fn one() -> Self {
-        AVX512GF2_127 {
-            v: unsafe { std::mem::transmute::<[i32; 4], __m128i>([1, 0, 0, 0]) },
-        }
-    }
-
-    #[inline(always)]
-    fn random_unsafe(mut rng: impl rand::RngCore) -> Self {
-        let mut u = [0u8; 16];
-        rng.fill_bytes(&mut u);
-        u[15] &= 0x7F;
-        unsafe {
-            AVX512GF2_127 {
-                v: *(u.as_ptr() as *const __m128i),
-            }
-        }
-    }
-
-    #[inline(always)]
-    fn random_bool(mut rng: impl rand::RngCore) -> Self {
-        AVX512GF2_127 {
-            v: unsafe { std::mem::transmute::<[u32; 4], __m128i>([rng.next_u32() % 2, 0, 0, 0]) },
-        }
-    }
-
-    #[inline(always)]
-    fn is_zero(&self) -> bool {
-        unsafe { std::mem::transmute::<__m128i, [u8; 16]>(self.v) == [0; 16] }
-    }
-
-    #[inline(always)]
-    fn exp(&self, exponent: u128) -> Self {
-        let mut e = exponent;
-        let mut res = Self::one();
-        let mut t = *self;
-        while e > 0 {
-            if e & 1 == 1 {
-                res *= t;
-            }
-            t = t * t;
-            e >>= 1;
-        }
-        res
-    }
-
-    #[inline(always)]
-    fn inv(&self) -> Option<Self> {
-        if self.is_zero() {
-            return None;
-        }
-        let p_m2 = (1u128 << 127) - 2;
-        Some(Self::exp(self, p_m2))
-    }
-
-    #[inline(always)]
-    fn square(&self) -> Self {
-        self * self
-    }
-
-    #[inline(always)]
-    fn as_u32_unchecked(&self) -> u32 {
-        unimplemented!("u32 for GF127 doesn't make sense")
-    }
-
-    #[inline(always)]
-    fn from_uniform_bytes(bytes: &[u8; 32]) -> Self {
-        let mut bytes = bytes.clone();
-        bytes[15] &= 0x7F;
-
-        unsafe {
-            AVX512GF2_127 {
-                v: transmute::<[u8; 16], __m128i>(bytes[..16].try_into().unwrap()),
-            }
-        }
-    }
-}
-
-impl ExtensionField for AVX512GF2_127 {
-    const DEGREE: usize = 127;
-
-    const W: u32 = 0x87;
-
-    const X: Self = AVX512GF2_127 {
-        v: unsafe { std::mem::transmute::<[i32; 4], __m128i>([2, 0, 0, 0]) },
-    };
-
-    type BaseField = GF2;
-
-    #[inline(always)]
-    fn mul_by_base_field(&self, base: &Self::BaseField) -> Self {
-        if base.v == 0 {
-            Self::zero()
-        } else {
-            *self
-        }
-    }
-
-    #[inline(always)]
-    fn add_by_base_field(&self, base: &Self::BaseField) -> Self {
-        let mut res = *self;
-        res.v = unsafe { _mm_xor_si128(res.v, _mm_set_epi64x(0, base.v as i64)) };
-        res
-    }
-
-    ///
-    #[inline]
-    fn mul_by_x(&self) -> Self {
-        unsafe {
-            // Shift left by 1 bit
-            let shifted = _mm_slli_epi64(self.v, 1);
-
-            // Get the most significant bit and move it
-            let msb = _mm_srli_epi64(self.v, 63);
-            let msb_moved = _mm_slli_si128(msb, 8);
-
-            // Combine the shifted value with the moved msb
-            let shifted_consolidated = _mm_or_si128(shifted, msb_moved);
-
-            // Create the reduction value (0b11) and the comparison value (1)
-            let reduction = {
-                let multiplier = _mm_set_epi64x(0, 0b11);
-                let one = _mm_set_epi64x(0, 1);
-
-                // Check if the MSB was 1 and create a mask
-                let mask = _mm_cmpeq_epi64(
-                    _mm_srli_si128(_mm_srli_epi64(shifted, 63), 8),
-                    one);
-
-                _mm_and_si128(mask, multiplier)
-            };
-
-            // Apply the reduction conditionally
-            let res = _mm_xor_si128(shifted_consolidated, reduction);
-
-            Self { v: res }
-        }
-    }
-}
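In `mul_by_x` above, the 128-bit shift is assembled from two 64-bit shifts (the lane-crossing bit moves via `_mm_slli_si128`), and the reduction uses x^127 ≡ x + 1 (mod x^127 + x + 1): if the shift pushes a coefficient into the x^127 slot, it folds back in as the low bits 0b11. A scalar `u128` sketch of the same step (illustrative, not from the codebase):

```rust
/// Multiply a GF(2^127) element (stored in the low 127 bits of a u128)
/// by x, reducing modulo x^127 + x + 1. Illustrative scalar sketch.
fn mul_by_x_scalar(a: u128) -> u128 {
    let shifted = a << 1;
    let overflow = (shifted >> 127) & 1; // coefficient that lands on x^127
    // x^127 = x + 1, so the overflow folds back in as 0b11.
    (shifted & ((1u128 << 127) - 1)) ^ (overflow * 0b11)
}
```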
-
-impl From<GF2> for AVX512GF2_127 {
-    #[inline(always)]
-    fn from(v: GF2) -> Self {
-        AVX512GF2_127 {
-            v: unsafe { _mm_set_epi64x(0, v.v as i64) },
-        }
-    }
-}
-
-const X0TO126_MASK: __m128i = unsafe {
-    transmute::<[u8; 16], __m128i>([
-        0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF,
-        0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x7F,
-    ])
-};
-const X127_MASK: __m128i = unsafe {
-    transmute::<[u8; 16], __m128i>([
-        0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
-        0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x80,
-    ])
-};
-const X127_REMINDER: __m128i = unsafe {
-    transmute::<[u8; 16], __m128i>([
-        0b11, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
-        0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x80,
-    ])
-};
-
-#[inline(always)]
-unsafe fn mm_bitshift_left<const count: i32>(x: __m128i) -> __m128i {
-    let mut carry = _mm_bslli_si128(x, 8);
-    carry = _mm_srli_epi64(carry, 64 - count);
-    let x = _mm_slli_epi64(x, count);
-    _mm_or_si128(x, carry)
-}
-
-#[inline]
-unsafe fn gfmul(a: __m128i, b: __m128i) -> __m128i {
-    let xmm_mask = _mm_setr_epi32((0xFFffffff_u32) as i32, 0x0, 0x0, 0x0);
-
-    // a = a0|a1, b = b0|b1
-
-    let mut tmp3 = _mm_clmulepi64_si128(a, b, 0x00); // tmp3 = a0 * b0
-    let mut tmp6 = _mm_clmulepi64_si128(a, b, 0x11); // tmp6 = a1 * b1
-
-    // 78 = 0b0100_1110
-    let mut tmp4 = _mm_shuffle_epi32(a, 78); // tmp4 = a1|a0
-    let mut tmp5 = _mm_shuffle_epi32(b, 78); // tmp5 = b1|b0
-    tmp4 = _mm_xor_si128(tmp4, a); // tmp4 = (a0 + a1) | (a0 + a1)
-    tmp5 = _mm_xor_si128(tmp5, b); // tmp5 = (b0 + b1) | (b0 + b1)
-
-    tmp4 = _mm_clmulepi64_si128(tmp4, tmp5, 0x00); // tmp4 = (a0 + a1) * (b0 + b1)
-    tmp4 = _mm_xor_si128(tmp4, tmp3); // tmp4 = (a0 + a1) * (b0 + b1) - a0 * b0
-    tmp4 = _mm_xor_si128(tmp4, tmp6); // tmp4 = (a0 + a1) * (b0 + b1) - a0 * b0 - a1 * b1 = a0 * b1 + a1 * b0
-
-    // tmp4 = e1 | e0
-    tmp5 = _mm_slli_si128(tmp4, 8); // tmp5 = e0 | 00
-    tmp4 = _mm_srli_si128(tmp4, 8); // tmp4 = 00 | e1
-    tmp3 = _mm_xor_si128(tmp3, tmp5); // the lower 128 bits, deg 0 - 127
-    tmp6 = _mm_xor_si128(tmp6, tmp4); // the higher 128 bits, deg 128 - 252, the 124 least significant bits are non-zero
-
-    // x^0 - x^126
-    let x0to126 = _mm_and_si128(tmp3, X0TO126_MASK);
-
-    // x^127
-    tmp4 = _mm_and_si128(tmp3, X127_MASK);
-    tmp4 = _mm_cmpeq_epi64(tmp4, X127_MASK);
-    tmp4 = _mm_srli_si128(tmp4, 15);
-    let x127 = _mm_and_si128(tmp4, X127_REMINDER);
-
-    // x^128 - x^252
-    let x128to252 =
-        _mm_and_si128(
-            mm_bitshift_left::<2>(tmp6),
-            mm_bitshift_left::<1>(tmp6),
-        );
-
-    _mm_and_si128(_mm_and_si128(x0to126, x127), x128to252)
-
-    // let mut tmp7 = _mm_srli_epi32(tmp6, 31);
-    // let mut tmp8 = _mm_srli_epi32(tmp6, 30);
-    // let tmp9 = _mm_srli_epi32(tmp6, 25);
-
-    // tmp7 = _mm_xor_si128(tmp7, tmp8);
-    // tmp7 = _mm_xor_si128(tmp7, tmp9);
-
-    // tmp8 = _mm_shuffle_epi32(tmp7, 147);
-    // tmp7 = _mm_and_si128(xmm_mask, tmp8);
-    // tmp8 = _mm_andnot_si128(xmm_mask, tmp8);
-
-    // tmp3 = _mm_xor_si128(tmp3, tmp8);
-    // tmp6 = _mm_xor_si128(tmp6, tmp7);
-
-    // let tmp10 = _mm_slli_epi32(tmp6, 1);
-    // tmp3 = _mm_xor_si128(tmp3, tmp10);
-
-    // let tmp11 = _mm_slli_epi32(tmp6, 2);
-    // tmp3 = _mm_xor_si128(tmp3, tmp11);
-
-    // let tmp12 = _mm_slli_epi32(tmp6, 7);
-    // tmp3 = _mm_xor_si128(tmp3, tmp12);
-
-    // _mm_xor_si128(tmp3, tmp6)
-}
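`gfmul` above is a Karatsuba-style carry-less multiply: two `_mm_clmulepi64_si128` calls produce the 64x64 partial products a0·b0 and a1·b1, and the cross term falls out of (a0 ⊕ a1)·(b0 ⊕ b1) ⊕ a0·b0 ⊕ a1·b1, after which the high half is folded back through the reduction constants above. For readers without the intrinsics reference, a portable sketch of what one `clmul` lane computes (illustrative only):

```rust
/// Carry-less (polynomial) multiplication of two 64-bit operands over GF(2),
/// i.e. what a single lane pair of _mm_clmulepi64_si128 computes.
/// Illustrative scalar sketch.
fn clmul64(a: u64, b: u64) -> u128 {
    let mut acc: u128 = 0;
    for i in 0..64 {
        if (b >> i) & 1 == 1 {
            acc ^= (a as u128) << i; // XOR replaces addition in GF(2)[x]
        }
    }
    acc
}

// Karatsuba combine, mirroring the tmp3/tmp6/tmp4 dance above:
//   lo  = clmul64(a0, b0)
//   hi  = clmul64(a1, b1)
//   mid = clmul64(a0 ^ a1, b0 ^ b1) ^ lo ^ hi   // = a0*b1 + a1*b0
```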
-
-impl Default for AVX512GF2_127 {
-    #[inline(always)]
-    fn default() -> Self {
-        Self::zero()
-    }
-}
-
-impl PartialEq for AVX512GF2_127 {
-    #[inline(always)]
-    fn eq(&self, other: &Self) -> bool {
-        unsafe { _mm_test_all_ones(_mm_cmpeq_epi8(self.v, other.v)) == 1 }
-    }
-}
-
-impl Neg for AVX512GF2_127 {
-    type Output = Self;
-
-    #[inline(always)]
-    fn neg(self) -> Self {
-        self
-    }
-}
-
-impl From<u32> for AVX512GF2_127 {
-    #[inline(always)]
-    fn from(v: u32) -> Self {
-        AVX512GF2_127 {
-            v: unsafe { std::mem::transmute::<[u32; 4], __m128i>([v, 0, 0, 0]) },
-        }
-    }
-}
-
-#[inline(always)]
-fn add_internal(a: &AVX512GF2_127, b: &AVX512GF2_127) -> AVX512GF2_127 {
-    AVX512GF2_127 {
-        v: unsafe { _mm_xor_si128(a.v, b.v) },
-    }
-}
-
-#[inline(always)]
-fn sub_internal(a: &AVX512GF2_127, b: &AVX512GF2_127) -> AVX512GF2_127 {
-    AVX512GF2_127 {
-        v: unsafe { _mm_xor_si128(a.v, b.v) },
-    }
-}
-
-#[inline(always)]
-fn mul_internal(a: &AVX512GF2_127, b: &AVX512GF2_127) -> AVX512GF2_127 {
-    AVX512GF2_127 {
-        v: unsafe { gfmul(a.v, b.v) },
-    }
-}
diff --git a/arith/src/extension_field/gf2_127/neon.rs b/arith/src/extension_field/gf2_127/neon.rs
deleted file mode 100644
index e69de29..0000000
diff --git a/arith/src/extension_field/gf2_128.rs b/arith/src/extension_field/gf2_128.rs
index e5cb17b..8ba6fc0 100644
--- a/arith/src/extension_field/gf2_128.rs
+++ b/arith/src/extension_field/gf2_128.rs
@@ -3,7 +3,12 @@ pub(crate) mod neon;
 #[cfg(target_arch = "aarch64")]
 pub type GF2_128 = neon::NeonGF2_128;
 
-#[cfg(target_arch = "x86_64")]
-mod avx;
-#[cfg(target_arch = "x86_64")]
-pub type GF2_128 = avx::AVX512GF2_128;
+#[cfg(all(target_arch = "x86_64", target_feature = "avx512f"))]
+mod avx512;
+#[cfg(all(target_arch = "x86_64", target_feature = "avx512f"))]
+pub type GF2_128 = avx512::AVX512GF2_128;
+
+#[cfg(all(target_arch = "x86_64", not(target_feature = "avx512f")))]
+mod avx256;
+#[cfg(all(target_arch = "x86_64", not(target_feature = "avx512f")))]
+pub type GF2_128 = avx256::AVX512GF2_128;
diff --git a/arith/src/extension_field/gf2_128/avx.rs b/arith/src/extension_field/gf2_128/avx512.rs
similarity index 100%
rename from arith/src/extension_field/gf2_128/avx.rs
rename to arith/src/extension_field/gf2_128/avx512.rs
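Because the alias is fixed at compile time, there is no runtime dispatch; a `cfg!`-based probe (function name illustrative, not part of the diff) can confirm which backend a given binary was built with:

```rust
/// Report which GF2_128 backend this compilation selected.
/// Illustrative helper mirroring the cfg conditions above.
pub fn gf2_128_backend() -> &'static str {
    if cfg!(all(target_arch = "x86_64", target_feature = "avx512f")) {
        "x86_64/avx512"
    } else if cfg!(target_arch = "x86_64") {
        "x86_64/avx256"
    } else if cfg!(target_arch = "aarch64") {
        "aarch64/neon"
    } else {
        "unsupported"
    }
}
```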
diff --git a/arith/src/extension_field/gf2_128x4/avx256.rs b/arith/src/extension_field/gf2_128x4/avx256.rs
deleted file mode 100644
index 1783f70..0000000
--- a/arith/src/extension_field/gf2_128x4/avx256.rs
+++ /dev/null
@@ -1,466 +0,0 @@
-use crate::field_common;
-
-use crate::{Field, FieldSerde, FieldSerdeResult, SimdField, GF2_128};
-use std::fmt::Debug;
-use std::{
-    arch::x86_64::*,
-    iter::{Product, Sum},
-    mem::transmute,
-    ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign},
-};
-
-#[derive(Clone, Copy)]
-pub struct AVX256GF2_128x4 {
-    data: [__m256i; 2],
-}
-
-field_common!(AVX256GF2_128x4);
-
-impl AVX256GF2_128x4 {
-    #[inline(always)]
-    pub(crate) fn pack_full(data: __m128i) -> [__m256i; 2] {
-        unsafe { [_mm256_broadcast_i32x4(data), _mm256_broadcast_i32x4(data)] }
-    }
-}
-
-impl FieldSerde for AVX256GF2_128x4 {
-    const SERIALIZED_SIZE: usize = 512 / 8;
-
-    #[inline(always)]
-    fn serialize_into<W: std::io::Write>(&self, mut writer: W) -> FieldSerdeResult<()> {
-        unsafe {
-            let mut data = [0u8; 64];
-            _mm256_storeu_si256(data.as_mut_ptr() as *mut i32, self.data[0]);
-            _mm256_storeu_si256(data.as_mut_ptr().add(32) as *mut i32, self.data[1]);
-            writer.write_all(&data)?;
-        }
-        Ok(())
-    }
-
-    #[inline(always)]
-    fn deserialize_from<R: std::io::Read>(mut reader: R) -> FieldSerdeResult<Self> {
-        let mut data = [0u8; Self::SERIALIZED_SIZE];
-        reader.read_exact(&mut data)?;
-        unsafe {
-            Ok(Self {
-                data: [
-                    _mm256_loadu_si256(data.as_ptr() as *const i32),
-                    _mm256_loadu_si256(data.as_ptr().add(8) as *const i32),
-                ],
-            })
-        }
-    }
-
-    #[inline(always)]
-    fn try_deserialize_from_ecc_format<R: std::io::Read>(mut reader: R) -> FieldSerdeResult<Self> {
-        let mut buf = [0u8; 32];
-        reader.read_exact(&mut buf)?;
-        let data: __m128i = unsafe { _mm_loadu_si128(buf.as_ptr() as *const __m128i) };
-        Ok(Self {
-            data: Self::pack_full(data),
-        })
-    }
-}
-
-const PACKED_0: __m256i = unsafe { transmute([0; 4]) };
-
-const PACKED_INV_2: __m256i = unsafe {
-    transmute([
-        67_u64,
-        (1_u64) << 63,
-        67_u64,
-        (1_u64) << 63,
-    ])
-};
-
-// p(x) = x^128 + x^7 + x^2 + x + 1
-impl Field for AVX256GF2_128x4 {
-    const NAME: &'static str = "AVX256 Galios Field 2^128";
-
-    // size in bytes
-    const SIZE: usize = 512 / 8;
-
-    const ZERO: Self = Self { data: PACKED_0 };
-
-    const INV_2: Self = Self { data: PACKED_INV_2 };
-
-    const FIELD_SIZE: usize = 128;
-
-    #[inline(always)]
-    fn zero() -> Self {
-        unsafe {
-            let zero = _mm256_setzero_si256();
-            Self { data: [zero, zero] }
-        }
-    }
-
-    #[inline(always)]
-    fn is_zero(&self) -> bool {
-        unsafe {
-            let zero = _mm256_setzero_si256();
-            let cmp = _mm256_cmpeq_epi64_mask(self.data[0], zero)
-                & _mm256_cmpeq_epi64_mask(self.data[1], zero);
-            cmp == 0xFF // All 8 64-bit integers are equal (zero)
-        }
-    }
-
-    #[inline(always)]
-    fn one() -> Self {
-        unsafe {
-            let one = _mm256_set_epi64(0, 1, 0, 1);
-            Self { data: [one, one] }
-        }
-    }
-
-    #[inline(always)]
-    fn random_unsafe(mut rng: impl rand::RngCore) -> Self {
-        let data = unsafe {
-            _mm256_set_epi64(
-                rng.next_u64() as i64,
-                rng.next_u64() as i64,
-                rng.next_u64() as i64,
-                rng.next_u64() as i64,
-            )
-        };
-        Self { data }
-    }
-
-    #[inline(always)]
-    fn random_bool(mut rng: impl rand::RngCore) -> Self {
-        let data = unsafe {
-            _mm256_set_epi64(
-                0,
-                (rng.next_u64() % 2) as i64,
-                0,
-                (rng.next_u64() % 2) as i64,
-            )
-        };
-        Self { data }
-    }
-
-    #[inline(always)]
-    fn exp(&self, exponent: u128) -> Self {
-        let mut e = exponent;
-        let mut res = Self::one();
-        let mut t = *self;
-        while e != 0 {
-            let b = e & 1;
-            if b == 1 {
-                res *= t;
-            }
-            t = t * t;
-            e >>= 1;
-        }
-        res
-    }
-
-    #[inline(always)]
-    fn inv(&self) -> Option<Self> {
-        if self.is_zero() {
-            return None;
-        }
-        let p_m2 = !(0u128) - 1;
-        Some(Self::exp(self, p_m2))
-    }
-
-    #[inline(always)]
-    fn as_u32_unchecked(&self) -> u32 {
-        unimplemented!("self is a vector, cannot convert to u32")
-    }
-
-    #[inline(always)]
-    fn from_uniform_bytes(_bytes: &[u8; 32]) -> Self {
-        todo!()
-    }
-
-    #[inline(always)]
-    fn square(&self) -> Self {
-        *self * *self
-    }
-
-    #[inline(always)]
-    fn double(&self) -> Self {
-        Self::ZERO
-    }
-
-    #[inline(always)]
-    fn mul_by_2(&self) -> Self {
-        Self::ZERO
-    }
-
-    #[inline(always)]
-    fn mul_by_3(&self) -> Self {
-        *self
-    }
-
-    #[inline(always)]
-    fn mul_by_5(&self) -> Self {
-        *self
-    }
-
-    #[inline(always)]
-    fn mul_by_6(&self) -> Self {
-        Self::ZERO
-    }
-}
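`inv` relies on Fermat's little theorem in the multiplicative group: for nonzero a in GF(2^128), a^(2^128 - 1) = 1, so a^(2^128 - 2) is the inverse, and `!(0u128) - 1` is exactly 2^128 - 2 (evaluated by `exp` via square-and-multiply). The identity is easiest to sanity-check in a toy field; a sketch over GF(2^3) with p(x) = x^3 + x + 1, small enough to verify exhaustively (illustrative, not from the codebase):

```rust
/// Carry-less multiply in GF(2^3), reducing by p(x) = x^3 + x + 1.
fn gf8_mul(mut a: u8, mut b: u8) -> u8 {
    let mut acc = 0u8;
    for _ in 0..3 {
        if b & 1 == 1 {
            acc ^= a;
        }
        b >>= 1;
        a <<= 1;
        if a & 0b1000 != 0 {
            a ^= 0b1011; // fold x^3 back in as x + 1
        }
    }
    acc
}

/// Fermat inversion: a^(2^3 - 2) = a^6 by square-and-multiply.
fn gf8_inv(a: u8) -> u8 {
    let a2 = gf8_mul(a, a);
    let a4 = gf8_mul(a2, a2);
    gf8_mul(a4, a2) // a^6
}

#[test]
fn fermat_inverse_holds() {
    for a in 1u8..8 {
        assert_eq!(gf8_mul(a, gf8_inv(a)), 1);
    }
}
```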
-/*
-credit to intel for the original implementation
-void gfmul(__m128i a, __m128i b, __m128i *res) {
-    __m128i tmp0, tmp1, tmp2, tmp3, tmp4, tmp5, tmp6;
-    __m128i tmp7, tmp8, tmp9, tmp10, tmp11, tmp12;
-    __m128i XMMMASK = _mm_setr_epi32(0xffffffff, 0x0, 0x0, 0x0);
-
-    // a = a0|a1, b = b0|b1
-
-    tmp3 = _mm_clmulepi64_si128(a, b, 0x00); // tmp3 = a0 * b0
-    tmp6 = _mm_clmulepi64_si128(a, b, 0x11); // tmp6 = a1 * b1
-
-    tmp4 = _mm_shuffle_epi32(a, 78); // tmp4 = a1|a0
-    tmp5 = _mm_shuffle_epi32(b, 78); // tmp5 = b1|b0
-    tmp4 = _mm_xor_si128(tmp4, a); // tmp4 = (a0 + a1) | (a0 + a1)
-    tmp5 = _mm_xor_si128(tmp5, b); // tmp5 = (b0 + b1) | (b0 + b1)
-
-    tmp4 = _mm_clmulepi64_si128(tmp4, tmp5, 0x00); // tmp4 = (a0 + a1) * (b0 + b1)
-    tmp4 = _mm_xor_si128(tmp4, tmp3); // tmp4 = (a0 + a1) * (b0 + b1) - a0 * b0
-    tmp4 = _mm_xor_si128(tmp4, tmp6); // tmp4 = (a0 + a1) * (b0 + b1) - a0 * b0 - a1 * b1 = a0 * b1 + a1 * b0
-
-    tmp5 = _mm_slli_si128(tmp4, 8);
-    tmp4 = _mm_srli_si128(tmp4, 8);
-    tmp3 = _mm_xor_si128(tmp3, tmp5);
-    tmp6 = _mm_xor_si128(tmp6, tmp4);
-
-    tmp7 = _mm_srli_epi32(tmp6, 31);
-    tmp8 = _mm_srli_epi32(tmp6, 30);
-    tmp9 = _mm_srli_epi32(tmp6, 25);
-
-    tmp7 = _mm_xor_si128(tmp7, tmp8);
-    tmp7 = _mm_xor_si128(tmp7, tmp9);
-
-    tmp8 = _mm_shuffle_epi32(tmp7, 147);
-    tmp7 = _mm_and_si128(XMMMASK, tmp8);
-    tmp8 = _mm_andnot_si128(XMMMASK, tmp8);
-
-    tmp3 = _mm_xor_si128(tmp3, tmp8);
-    tmp6 = _mm_xor_si128(tmp6, tmp7);
-
-    tmp10 = _mm_slli_epi32(tmp6, 1);
-    tmp3 = _mm_xor_si128(tmp3, tmp10);
-
-    tmp11 = _mm_slli_epi32(tmp6, 2);
-    tmp3 = _mm_xor_si128(tmp3, tmp11);
-
-    tmp12 = _mm_slli_epi32(tmp6, 7);
-    tmp3 = _mm_xor_si128(tmp3, tmp12);
-
-    *res = _mm_xor_si128(tmp3, tmp6);
-}
-
-*/
-
-/*
-AVX 512 version
-void gfmul_avx512(__m512i a, __m512i b, __m512i *res) {
-    __m512i tmp0, tmp1, tmp2, tmp3, tmp4, tmp5, tmp6;
-    __m512i tmp7, tmp8, tmp9, tmp10, tmp11, tmp12;
-    __m512i XMMMASK = _mm512_set_epi32(
-        0, 0, 0, 0xffffffff,
-        0, 0, 0, 0xffffffff,
-        0, 0, 0, 0xffffffff,
-        0, 0, 0, 0xffffffff
-    );
-
-    tmp3 = _mm512_clmulepi64_epi128(a, b, 0x00);
-    tmp6 = _mm512_clmulepi64_epi128(a, b, 0x11);
-
-    tmp4 = _mm512_shuffle_epi32(a, _MM_PERM_BADC);
-    tmp5 = _mm512_shuffle_epi32(b, _MM_PERM_BADC);
-    tmp4 = _mm512_xor_si512(tmp4, a);
-    tmp5 = _mm512_xor_si512(tmp5, b);
-
-    tmp4 = _mm512_clmulepi64_epi128(tmp4, tmp5, 0x00);
-    tmp4 = _mm512_xor_si512(tmp4, tmp3);
-    tmp4 = _mm512_xor_si512(tmp4, tmp6);
-
-    tmp5 = _mm512_bslli_epi128(tmp4, 8);
-    tmp4 = _mm512_bsrli_epi128(tmp4, 8);
-    tmp3 = _mm512_xor_si512(tmp3, tmp5);
-    tmp6 = _mm512_xor_si512(tmp6, tmp4);
-
-    tmp7 = _mm512_srli_epi32(tmp6, 31);
-    tmp8 = _mm512_srli_epi32(tmp6, 30);
-    tmp9 = _mm512_srli_epi32(tmp6, 25);
-
-    tmp7 = _mm512_xor_si512(tmp7, tmp8);
-    tmp7 = _mm512_xor_si512(tmp7, tmp9);
-
-    tmp8 = _mm512_shuffle_epi32(tmp7, _MM_PERM_ABCD);
-    tmp7 = _mm512_and_si512(XMMMASK, tmp8);
-    tmp8 = _mm512_andnot_si512(XMMMASK, tmp8);
-
-    tmp3 = _mm512_xor_si512(tmp3, tmp8);
-    tmp6 = _mm512_xor_si512(tmp6, tmp7);
-
-    tmp10 = _mm512_slli_epi32(tmp6, 1);
-    tmp3 = _mm512_xor_si512(tmp3, tmp10);
-
-    tmp11 = _mm512_slli_epi32(tmp6, 2);
-    tmp3 = _mm512_xor_si512(tmp3, tmp11);
-
-    tmp12 = _mm512_slli_epi32(tmp6, 7);
-    tmp3 = _mm512_xor_si512(tmp3, tmp12);
-
-    *res = _mm512_xor_si512(tmp3, tmp6);
-}
- */
-
-impl From<u32> for AVX256GF2_128x4 {
-    #[inline(always)]
-    fn from(v: u32) -> AVX256GF2_128x4 {
-        assert!(v < 2); // only 0 and 1 are allowed
-        let data = unsafe {
-            [
-                _mm256_set_epi64(0, v as i64, 0, v as i64, 0),
-                _mm256_set_epi64(0, v as i64, 0, v as i64, 0),
-            ]
-        };
-        AVX256GF2_128x4 { data }
-    }
-}
-
-impl Neg for AVX256GF2_128x4 {
-    type Output = AVX256GF2_128x4;
-
-    #[inline(always)]
-    fn neg(self) -> AVX256GF2_128x4 {
-        self
-    }
-}
-
-impl Debug for AVX256GF2_128x4 {
-    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
-        let mut data = [0u8; 64];
-        unsafe {
-            _mm256_storeu_si256(data.as_mut_ptr() as *mut __m256i, self.data[0]);
-            _mm256_storeu_si256(data.as_mut_ptr().add(8) as *mut __m256i, self.data[1]);
-        }
-        f.debug_struct("AVX256GF2_128x4")
-            .field("data", &data)
-            .finish()
-    }
-}
-
-impl PartialEq for AVX256GF2_128x4 {
-    #[inline(always)]
-    fn eq(&self, other: &Self) -> bool {
-        unsafe {
-            let cmp = _mm256_cmpeq_epi64_mask(self.data[0], other.data[0])
-                & _mm256_cmpeq_epi64_mask(self.data[1], other.data[1]);
-            cmp == 0xFF // All 8 64-bit integers are equal
-        }
-    }
-}
-
-impl Default for AVX256GF2_128x4 {
-    #[inline(always)]
-    fn default() -> Self {
-        Self::zero()
-    }
-}
-
-impl From<GF2_128> for AVX256GF2_128x4 {
-    #[inline(always)]
-    fn from(v: GF2_128) -> AVX256GF2_128x4 {
-        unsafe {
-            let mut result = [_mm256_setzero_si256(), _mm256_setzero_si256()]; // Initialize a zeroed _m512i
-            result[0] = _mm256_inserti32x4(result[0], v.v, 0); // Insert `a` at position 0
-            result[0] = _mm256_inserti32x4(result[0], v.v, 1); // Insert `b` at position 1
-            result[1] = _mm256_inserti32x4(result[1], v.v, 2); // Insert `c` at position 2
-            result[1] = _mm256_inserti32x4(result[1], v.v, 3); // Insert `d` at position 3
-            AVX256GF2_128x4 { data: result }
-        }
-    }
-}
-
-impl SimdField for AVX256GF2_128x4 {
-    #[inline(always)]
-    fn scale(&self, challenge: &Self::Scalar) -> Self {
-        let simd_challenge = AVX256GF2_128x4::from(*challenge);
-        *self * simd_challenge
-    }
-    type Scalar = GF2_128;
-
-    #[inline(always)]
-    fn pack_size() -> usize {
-        4
-    }
-}
-
-#[inline(always)]
-fn add_internal(a: &AVX256GF2_128x4, b: &AVX256GF2_128x4) -> AVX256GF2_128x4 {
-    unsafe {
-        AVX256GF2_128x4 {
-            data: [
-                _mm256_xor_si256(a.data[0], b.data[0]),
-                _mm256_xor_si256(a.data[1], b.data[1]),
-            ],
-        }
-    }
-}
-
-#[inline(always)]
-fn sub_internal(a: &AVX256GF2_128x4, b: &AVX256GF2_128x4) -> AVX256GF2_128x4 {
-    unsafe {
-        AVX256GF2_128x4 {
-            data: [
-                _mm256_xor_si256(a.data[0], b.data[0]),
-                _mm256_xor_si256(a.data[1], b.data[1]),
-            ],
-        }
-    }
-}
-
-#[inline]
-fn mul_internal(a: &AVX256GF2_128x4, b: &AVX256GF2_128x4) -> AVX256GF2_128x4 {
    unsafe {
-        let xmmmask = _mm256_set_epi32(
-            0,
-            0,
-            0,
-            0xffffffffu32 as i32,
-            0,
-            0,
-            0,
-            0xffffffffu32 as i32,
-        );
-        let mut result = [_mm256_setzero_si256(), _mm256_setzero_si256()];
-        for i in 0..2 {
-            let mut tmp3 = _mm256_clmulepi64_epi128(a.data[i], b.data[i], 0x00);
-            let mut tmp6 = _mm256_clmulepi64_epi128(a.data[i], b.data[i], 0x11);
-
-            let mut tmp4 = _mm256_shuffle_epi32(a.data[i], _MM_PERM_BADC);
-            let mut tmp5 = _mm256_shuffle_epi32(b.data[i], _MM_PERM_BADC);
-            tmp4 = _mm256_xor_si256(tmp4, a.data[i]);
-            tmp5 = _mm256_xor_si256(tmp5, b.data[i]);
-
-            tmp4 = _mm256_clmulepi64_epi128(tmp4, tmp5, 0x00);
-            tmp4 = _mm256_xor_si256(tmp4, tmp3);
-            tmp4 = _mm256_xor_si256(tmp4, tmp6);
-
-            tmp5 = _mm256_bslli_epi128(tmp4, 8);
-            tmp4 = _mm256_bsrli_epi128(tmp4, 8);
-            tmp3 = _mm256_xor_si256(tmp3, tmp5);
-            tmp6 = _mm256_xor_si256(tmp6, tmp4);
-
-            let tmp7 = _mm256_srli_epi32(tmp6, 31);
-            let tmp8 = _mm256_srli_epi32(tmp6, 30);
-            let tmp9 = _mm256_srli_epi32(tmp6, 25);
-
-            let mut tmp7 = _mm256_xor_si256(tmp7, tmp8);
-            tmp7 = _mm256_xor_si256(tmp7, tmp9);
-
-            let mut tmp8 = _mm256_shuffle_epi32(tmp7, _MM_PERM_CBAD);
-            tmp7 = _mm256_and_si256(xmmmask, tmp8);
-            tmp8 = _mm256_andnot_si256(xmmmask, tmp8);
-
-            tmp3 = _mm256_xor_si256(tmp3, tmp8);
-            tmp6 = _mm256_xor_si256(tmp6, tmp7);
-
-            let tmp10 = _mm256_slli_epi32(tmp6, 1);
-            tmp3 = _mm256_xor_si256(tmp3, tmp10);
-
-            let tmp11 = _mm256_slli_epi32(tmp6, 2);
-            tmp3 = _mm256_xor_si256(tmp3, tmp11);
-
-            let tmp12 = _mm256_slli_epi32(tmp6, 7);
-            tmp3 = _mm256_xor_si256(tmp3, tmp12);
-
-            result[i] = _mm256_xor_si256(tmp3, tmp6);
-        }
-        AVX256GF2_128x4 { data: result }
-    }
-}
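The shift-by-31/30/25 and shift-by-1/2/7 steps in `mul_internal` implement the fold modulo p(x) = x^128 + x^7 + x^2 + x + 1: every high-half bit at degree 128 + i re-enters the low half as x^i · (x^7 + x^2 + x + 1). A scalar sketch of that reduction (illustrative, assuming a 256-bit carry-less product split into `hi` and `lo` halves):

```rust
/// Reduce a 256-bit carry-less product (hi, lo) modulo
/// p(x) = x^128 + x^7 + x^2 + x + 1, the GF(2^128) modulus used here.
/// `hi` holds degrees 128..255, `lo` degrees 0..127. Illustrative sketch.
fn reduce_gf2_128(hi: u128, lo: u128) -> u128 {
    // x^128 = x^7 + x^2 + x + 1, so fold hi back in at shifts 7, 2, 1, 0.
    let fold = |h: u128| (h << 7) ^ (h << 2) ^ (h << 1) ^ h;
    // Bits shifted out past 2^128 land at degree >= 128 again and need one
    // more pass; two passes suffice because deg(p) - 128 = 7 < 64.
    let carry = (hi >> 121) ^ (hi >> 126) ^ (hi >> 127);
    lo ^ fold(hi) ^ fold(carry)
}
```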
diff --git a/arith/src/extension_field/gf2_128x8.rs b/arith/src/extension_field/gf2_128x8.rs
index ba39310..e250159 100644
--- a/arith/src/extension_field/gf2_128x8.rs
+++ b/arith/src/extension_field/gf2_128x8.rs
@@ -3,13 +3,12 @@ pub(crate) mod neon;
 #[cfg(target_arch = "aarch64")]
 pub type GF2_128x8 = neon::NeonGF2_128x8;
 
-#[cfg(target_arch = "x86_64")]
-mod avx;
-#[cfg(target_arch = "x86_64")]
+#[cfg(all(target_arch = "x86_64", target_feature = "avx512f"))]
+mod avx512;
+#[cfg(all(target_arch = "x86_64", target_feature = "avx512f"))]
+pub type GF2_128x8 = avx512::AVX512GF2_128x8;
+
+#[cfg(all(target_arch = "x86_64", not(target_feature = "avx512f")))]
 mod avx256;
-#[cfg(target_arch = "x86_64")]
-pub type GF2_128x8_256 = avx256::AVX256GF2_128x8;
-#[cfg(all(target_arch = "x86_64", feature = "avx256"))]
+#[cfg(all(target_arch = "x86_64", not(target_feature = "avx512f")))]
 pub type GF2_128x8 = avx256::AVX256GF2_128x8;
-#[cfg(all(target_arch = "x86_64", not(feature = "avx256")))]
-pub type GF2_128x8 = avx::AVX512GF2_128x8;
diff --git a/arith/src/extension_field/gf2_128x8/avx.rs b/arith/src/extension_field/gf2_128x8/avx512.rs
similarity index 100%
rename from arith/src/extension_field/gf2_128x8/avx.rs
rename to arith/src/extension_field/gf2_128x8/avx512.rs
diff --git a/arith/src/field/m31.rs b/arith/src/field/m31.rs
index 628f3ad..6c12ff2 100644
--- a/arith/src/field/m31.rs
+++ b/arith/src/field/m31.rs
@@ -1,11 +1,12 @@
 mod m31x16;
 pub use m31x16::M31x16;
 
-#[cfg(target_arch = "x86_64")]
-pub(crate) mod m31_avx;
-#[cfg(target_arch = "x86_64")]
+#[cfg(all(target_arch = "x86_64", not(target_feature = "avx512f")))]
 pub(crate) mod m31_avx256;
+#[cfg(all(target_arch = "x86_64", target_feature = "avx512f"))]
+pub(crate) mod m31_avx512;
+
-#[cfg(target_arch = "x86_64")]
+#[cfg(all(target_arch = "x86_64", not(target_feature = "avx512f")))]
 pub type M31x16_256 = m31_avx256::AVXM31;
diff --git a/arith/src/field/m31/m31_avx.rs b/arith/src/field/m31/m31_avx512.rs
similarity index 100%
rename from arith/src/field/m31/m31_avx.rs
rename to arith/src/field/m31/m31_avx512.rs
diff --git a/arith/src/field/m31/m31x16.rs b/arith/src/field/m31/m31x16.rs
index c00456d..b111b6e 100644
--- a/arith/src/field/m31/m31x16.rs
+++ b/arith/src/field/m31/m31x16.rs
@@ -1,14 +1,12 @@
 // A M31x16 stores 512 bits of data.
 // With AVX it stores a single __m512i element.
 // With NEON it stores four uint32x4_t elements.
-#[cfg(target_arch = "x86_64")]
-cfg_if::cfg_if! {
-    if #[cfg(feature = "avx256")] {
-        pub type M31x16 = super::m31_avx256::AVXM31;
-    } else {
-        pub type M31x16 = super::m31_avx::AVXM31;
-    }
-}
 
 #[cfg(target_arch = "aarch64")]
 pub type M31x16 = super::m31_neon::NeonM31;
+
+#[cfg(all(target_arch = "x86_64", target_feature = "avx512f"))]
+pub type M31x16 = super::m31_avx512::AVXM31;
+
+#[cfg(all(target_arch = "x86_64", not(target_feature = "avx512f")))]
+pub type M31x16 = super::m31_avx256::AVXM31;
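Because `M31x16`, `GF2_128`, and `GF2_128x8` now silently resolve to different SIMD types per target, a backend-agnostic property test is a cheap way to keep the NEON, AVX2, and AVX-512 paths in agreement; a sketch against the `Field` API shown in this diff (import paths assumed, not verified against the crate root):

```rust
#[cfg(test)]
mod backend_smoke {
    use arith::{Field, M31x16}; // assumed re-exports from the arith crate

    // Field axioms that must hold for whichever backend the cfg gating
    // picked; the test source is identical across all targets.
    #[test]
    fn field_axioms_hold_for_selected_backend() {
        let mut rng = rand::thread_rng();
        for _ in 0..100 {
            let a = M31x16::random_unsafe(&mut rng);
            let b = M31x16::random_unsafe(&mut rng);
            let c = M31x16::random_unsafe(&mut rng);
            assert_eq!((a + b) * c, a * c + b * c); // distributivity
            assert_eq!(a * M31x16::one(), a);       // multiplicative identity
            assert_eq!(a - a, M31x16::zero());      // additive inverse
        }
    }
}
```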