diff --git a/crates/core_arch/Cargo.toml b/crates/core_arch/Cargo.toml index a1bb168ee7..cd83058663 100644 --- a/crates/core_arch/Cargo.toml +++ b/crates/core_arch/Cargo.toml @@ -23,3 +23,6 @@ maintenance = { status = "experimental" } [dev-dependencies] stdarch-test = { version = "0.*", path = "../stdarch-test" } std_detect = { version = "0.*", path = "../std_detect" } + +[target.'cfg(all(target_arch = "x86_64", target_os = "linux"))'.dev-dependencies] +syscalls = { version = "0.6.18", default-features = false } diff --git a/crates/core_arch/missing-x86.md b/crates/core_arch/missing-x86.md index e8f16f7e69..2ea66f3e6e 100644 --- a/crates/core_arch/missing-x86.md +++ b/crates/core_arch/missing-x86.md @@ -2,7 +2,6 @@
["AMX-BF16"]

* [ ] [`__tile_dpbf16ps`](https://software.intel.com/sites/landingpage/IntrinsicsGuide/#text=__tile_dpbf16ps) - * [ ] [`_tile_dpbf16ps`](https://software.intel.com/sites/landingpage/IntrinsicsGuide/#text=_tile_dpbf16ps)

@@ -10,15 +9,12 @@ * [ ] [`__tile_cmmimfp16ps`](https://software.intel.com/sites/landingpage/IntrinsicsGuide/#text=__tile_cmmimfp16ps) * [ ] [`__tile_cmmrlfp16ps`](https://software.intel.com/sites/landingpage/IntrinsicsGuide/#text=__tile_cmmrlfp16ps) - * [ ] [`_tile_cmmimfp16ps`](https://software.intel.com/sites/landingpage/IntrinsicsGuide/#text=_tile_cmmimfp16ps) - * [ ] [`_tile_cmmrlfp16ps`](https://software.intel.com/sites/landingpage/IntrinsicsGuide/#text=_tile_cmmrlfp16ps)

["AMX-FP16"]

* [ ] [`__tile_dpfp16ps`](https://software.intel.com/sites/landingpage/IntrinsicsGuide/#text=__tile_dpfp16ps) - * [ ] [`_tile_dpfp16ps`](https://software.intel.com/sites/landingpage/IntrinsicsGuide/#text=_tile_dpfp16ps)

@@ -28,10 +24,6 @@ * [ ] [`__tile_dpbsud`](https://software.intel.com/sites/landingpage/IntrinsicsGuide/#text=__tile_dpbsud) * [ ] [`__tile_dpbusd`](https://software.intel.com/sites/landingpage/IntrinsicsGuide/#text=__tile_dpbusd) * [ ] [`__tile_dpbuud`](https://software.intel.com/sites/landingpage/IntrinsicsGuide/#text=__tile_dpbuud) - * [ ] [`_tile_dpbssd`](https://software.intel.com/sites/landingpage/IntrinsicsGuide/#text=_tile_dpbssd) - * [ ] [`_tile_dpbsud`](https://software.intel.com/sites/landingpage/IntrinsicsGuide/#text=_tile_dpbsud) - * [ ] [`_tile_dpbusd`](https://software.intel.com/sites/landingpage/IntrinsicsGuide/#text=_tile_dpbusd) - * [ ] [`_tile_dpbuud`](https://software.intel.com/sites/landingpage/IntrinsicsGuide/#text=_tile_dpbuud)

@@ -41,13 +33,6 @@ * [ ] [`__tile_stored`](https://software.intel.com/sites/landingpage/IntrinsicsGuide/#text=__tile_stored) * [ ] [`__tile_stream_loadd`](https://software.intel.com/sites/landingpage/IntrinsicsGuide/#text=__tile_stream_loadd) * [ ] [`__tile_zero`](https://software.intel.com/sites/landingpage/IntrinsicsGuide/#text=__tile_zero) - * [ ] [`_tile_loadconfig`](https://software.intel.com/sites/landingpage/IntrinsicsGuide/#text=_tile_loadconfig) - * [ ] [`_tile_loadd`](https://software.intel.com/sites/landingpage/IntrinsicsGuide/#text=_tile_loadd) - * [ ] [`_tile_release`](https://software.intel.com/sites/landingpage/IntrinsicsGuide/#text=_tile_release) - * [ ] [`_tile_storeconfig`](https://software.intel.com/sites/landingpage/IntrinsicsGuide/#text=_tile_storeconfig) - * [ ] [`_tile_stored`](https://software.intel.com/sites/landingpage/IntrinsicsGuide/#text=_tile_stored) - * [ ] [`_tile_stream_loadd`](https://software.intel.com/sites/landingpage/IntrinsicsGuide/#text=_tile_stream_loadd) - * [ ] [`_tile_zero`](https://software.intel.com/sites/landingpage/IntrinsicsGuide/#text=_tile_zero)

diff --git a/crates/core_arch/src/lib.rs b/crates/core_arch/src/lib.rs index a7a02783e0..d6f7de619e 100644 --- a/crates/core_arch/src/lib.rs +++ b/crates/core_arch/src/lib.rs @@ -35,6 +35,7 @@ generic_arg_infer, asm_experimental_arch, sha512_sm_x86, + x86_amx_intrinsics, f16 )] #![cfg_attr(test, feature(test, abi_vectorcall, stdarch_internal))] diff --git a/crates/core_arch/src/x86_64/amx.rs b/crates/core_arch/src/x86_64/amx.rs new file mode 100644 index 0000000000..547dc2a67e --- /dev/null +++ b/crates/core_arch/src/x86_64/amx.rs @@ -0,0 +1,604 @@ +/// Load tile configuration from a 64-byte memory location specified by mem_addr. +/// The tile configuration format is specified below, and includes the tile type pallette, +/// the number of bytes per row, and the number of rows. If the specified pallette_id is zero, +/// that signifies the init state for both the tile config and the tile data, and the tiles are zeroed. +/// Any invalid configurations will result in #GP fault. +/// +/// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_tile_loadconfig&ig_expand=6875) +#[inline] +#[target_feature(enable = "amx-tile")] +#[unstable(feature = "x86_amx_intrinsics", issue = "126622")] +pub unsafe fn _tile_loadconfig(mem_addr: *const u8) { + ldtilecfg(mem_addr); +} + +/// Stores the current tile configuration to a 64-byte memory location specified by mem_addr. +/// The tile configuration format is specified below, and includes the tile type pallette, +/// the number of bytes per row, and the number of rows. If tiles are not configured, all zeroes will be stored to memory. +/// +/// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_tile_storeconfig&ig_expand=6879) +#[inline] +#[target_feature(enable = "amx-tile")] +#[unstable(feature = "x86_amx_intrinsics", issue = "126622")] +pub unsafe fn _tile_storeconfig(mem_addr: *mut u8) { + sttilecfg(mem_addr); +} + +/// Load tile rows from memory specifieid by base address and stride into destination tile dst using the tile configuration previously configured via _tile_loadconfig. +/// +/// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_tile_loadd&ig_expand=6877) +#[inline] +#[rustc_legacy_const_generics(0)] +#[target_feature(enable = "amx-tile")] +#[unstable(feature = "x86_amx_intrinsics", issue = "126622")] +pub unsafe fn _tile_loadd(base: *const u8, stride: usize) { + static_assert_uimm_bits!(DST, 3); + tileloadd64(DST as i8, base, stride); +} + +/// Release the tile configuration to return to the init state, which releases all storage it currently holds. +/// +/// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_tile_release&ig_expand=6878) +#[inline] +#[target_feature(enable = "amx-tile")] +#[unstable(feature = "x86_amx_intrinsics", issue = "126622")] +pub unsafe fn _tile_release() { + tilerelease(); +} + +/// Store the tile specified by src to memory specifieid by base address and stride using the tile configuration previously configured via _tile_loadconfig. +/// +/// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_tile_stored&ig_expand=6881) +#[inline] +#[rustc_legacy_const_generics(0)] +#[target_feature(enable = "amx-tile")] +#[unstable(feature = "x86_amx_intrinsics", issue = "126622")] +pub unsafe fn _tile_stored(base: *mut u8, stride: usize) { + static_assert_uimm_bits!(DST, 3); + tilestored64(DST as i8, base, stride); +} + +/// Load tile rows from memory specifieid by base address and stride into destination tile dst using the tile configuration +/// previously configured via _tile_loadconfig. This intrinsic provides a hint to the implementation that the data will +/// likely not be reused in the near future and the data caching can be optimized accordingly. +/// +/// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_tile_stream_loadd&ig_expand=6883) +#[inline] +#[rustc_legacy_const_generics(0)] +#[target_feature(enable = "amx-tile")] +#[unstable(feature = "x86_amx_intrinsics", issue = "126622")] +pub unsafe fn _tile_stream_loadd(base: *const u8, stride: usize) { + static_assert_uimm_bits!(DST, 3); + tileloaddt164(DST as i8, base, stride); +} + +/// Zero the tile specified by tdest. +/// +/// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_tile_zero&ig_expand=6885) +#[inline] +#[rustc_legacy_const_generics(0)] +#[target_feature(enable = "amx-tile")] +#[unstable(feature = "x86_amx_intrinsics", issue = "126622")] +pub unsafe fn _tile_zero() { + static_assert_uimm_bits!(DST, 3); + tilezero(DST as i8); +} + +/// Compute dot-product of BF16 (16-bit) floating-point pairs in tiles a and b, +/// accumulating the intermediate single-precision (32-bit) floating-point elements +/// with elements in dst, and store the 32-bit result back to tile dst. +/// +/// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_tile_dpbf16ps&ig_expand=6864) +#[inline] +#[rustc_legacy_const_generics(0, 1, 2)] +#[target_feature(enable = "amx-bf16")] +#[unstable(feature = "x86_amx_intrinsics", issue = "126622")] +pub unsafe fn _tile_dpbf16ps() { + static_assert_uimm_bits!(DST, 3); + static_assert_uimm_bits!(A, 3); + static_assert_uimm_bits!(B, 3); + tdpbf16ps(DST as i8, A as i8, B as i8); +} + +/// Compute dot-product of bytes in tiles with a source/destination accumulator. +/// Multiply groups of 4 adjacent pairs of signed 8-bit integers in a with corresponding +/// signed 8-bit integers in b, producing 4 intermediate 32-bit results. +/// Sum these 4 results with the corresponding 32-bit integer in dst, and store the 32-bit result back to tile dst. +/// +/// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_tile_dpbssd&ig_expand=6866) +#[inline] +#[rustc_legacy_const_generics(0, 1, 2)] +#[target_feature(enable = "amx-int8")] +#[unstable(feature = "x86_amx_intrinsics", issue = "126622")] +pub unsafe fn _tile_dpbssd() { + static_assert_uimm_bits!(DST, 3); + static_assert_uimm_bits!(A, 3); + static_assert_uimm_bits!(B, 3); + tdpbssd(DST as i8, A as i8, B as i8); +} + +/// Compute dot-product of bytes in tiles with a source/destination accumulator. +/// Multiply groups of 4 adjacent pairs of signed 8-bit integers in a with corresponding +/// unsigned 8-bit integers in b, producing 4 intermediate 32-bit results. +/// Sum these 4 results with the corresponding 32-bit integer in dst, and store the 32-bit result back to tile dst. +/// +/// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_tile_dpbsud&ig_expand=6868) +#[inline] +#[rustc_legacy_const_generics(0, 1, 2)] +#[target_feature(enable = "amx-int8")] +#[unstable(feature = "x86_amx_intrinsics", issue = "126622")] +pub unsafe fn _tile_dpbsud() { + static_assert_uimm_bits!(DST, 3); + static_assert_uimm_bits!(A, 3); + static_assert_uimm_bits!(B, 3); + tdpbsud(DST as i8, A as i8, B as i8); +} + +/// Compute dot-product of bytes in tiles with a source/destination accumulator. +/// Multiply groups of 4 adjacent pairs of unsigned 8-bit integers in a with corresponding +/// signed 8-bit integers in b, producing 4 intermediate 32-bit results. +/// Sum these 4 results with the corresponding 32-bit integer in dst, and store the 32-bit result back to tile dst. +/// +/// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_tile_dpbusd&ig_expand=6870) +#[inline] +#[rustc_legacy_const_generics(0, 1, 2)] +#[target_feature(enable = "amx-int8")] +#[unstable(feature = "x86_amx_intrinsics", issue = "126622")] +pub unsafe fn _tile_dpbusd() { + static_assert_uimm_bits!(DST, 3); + static_assert_uimm_bits!(A, 3); + static_assert_uimm_bits!(B, 3); + tdpbusd(DST as i8, A as i8, B as i8); +} + +/// Compute dot-product of bytes in tiles with a source/destination accumulator. +/// Multiply groups of 4 adjacent pairs of unsigned 8-bit integers in a with corresponding +/// unsigned 8-bit integers in b, producing 4 intermediate 32-bit results. +/// Sum these 4 results with the corresponding 32-bit integer in dst, and store the 32-bit result back to tile dst. +/// +/// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_tile_dpbuud&ig_expand=6872) +#[inline] +#[rustc_legacy_const_generics(0, 1, 2)] +#[target_feature(enable = "amx-int8")] +#[unstable(feature = "x86_amx_intrinsics", issue = "126622")] +pub unsafe fn _tile_dpbuud() { + static_assert_uimm_bits!(DST, 3); + static_assert_uimm_bits!(A, 3); + static_assert_uimm_bits!(B, 3); + tdpbuud(DST as i8, A as i8, B as i8); +} + +/// Compute dot-product of FP16 (16-bit) floating-point pairs in tiles a and b, +/// accumulating the intermediate single-precision (32-bit) floating-point elements +/// with elements in dst, and store the 32-bit result back to tile dst. +/// +/// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_tile_dpfp16ps&ig_expand=6874) +#[inline] +#[rustc_legacy_const_generics(0, 1, 2)] +#[target_feature(enable = "amx-fp16")] +#[unstable(feature = "x86_amx_intrinsics", issue = "126622")] +pub unsafe fn _tile_dpfp16ps() { + static_assert_uimm_bits!(DST, 3); + static_assert_uimm_bits!(A, 3); + static_assert_uimm_bits!(B, 3); + tdpfp16ps(DST as i8, A as i8, B as i8); +} + +/// Perform matrix multiplication of two tiles containing complex elements and accumulate the results into a packed single precision tile. +/// Each dword element in input tiles a and b is interpreted as a complex number with FP16 real part and FP16 imaginary part. +/// Calculates the imaginary part of the result. For each possible combination of (row of a, column of b), +/// it performs a set of multiplication and accumulations on all corresponding complex numbers (one from a and one from b). +/// The imaginary part of the a element is multiplied with the real part of the corresponding b element, and the real part of +/// the a element is multiplied with the imaginary part of the corresponding b elements. The two accumulated results are added, +/// and then accumulated into the corresponding row and column of dst. +/// +/// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_tile_cmmimfp16ps&ig_expand=6860) +#[inline] +#[rustc_legacy_const_generics(0, 1, 2)] +#[target_feature(enable = "amx-complex")] +#[unstable(feature = "x86_amx_intrinsics", issue = "126622")] +pub unsafe fn _tile_cmmimfp16ps() { + static_assert_uimm_bits!(DST, 3); + static_assert_uimm_bits!(A, 3); + static_assert_uimm_bits!(B, 3); + tcmmimfp16ps(DST as i8, A as i8, B as i8); +} + +/// Perform matrix multiplication of two tiles containing complex elements and accumulate the results into a packed single precision tile. +/// Each dword element in input tiles a and b is interpreted as a complex number with FP16 real part and FP16 imaginary part. +/// Calculates the real part of the result. For each possible combination of (row of a, column of b), +/// it performs a set of multiplication and accumulations on all corresponding complex numbers (one from a and one from b). +/// The real part of the a element is multiplied with the real part of the corresponding b element, and the negated imaginary part of +/// the a element is multiplied with the imaginary part of the corresponding b elements. +/// The two accumulated results are added, and then accumulated into the corresponding row and column of dst. +/// +/// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_tile_cmmrlfp16ps&ig_expand=6862) +#[inline] +#[rustc_legacy_const_generics(0, 1, 2)] +#[target_feature(enable = "amx-complex")] +#[unstable(feature = "x86_amx_intrinsics", issue = "126622")] +pub unsafe fn _tile_cmmrlfp16ps() { + static_assert_uimm_bits!(DST, 3); + static_assert_uimm_bits!(A, 3); + static_assert_uimm_bits!(B, 3); + tcmmrlfp16ps(DST as i8, A as i8, B as i8); +} + +#[allow(improper_ctypes)] +extern "C" { + #[link_name = "llvm.x86.ldtilecfg"] + fn ldtilecfg(mem_addr: *const u8); + #[link_name = "llvm.x86.sttilecfg"] + fn sttilecfg(mem_addr: *mut u8); + #[link_name = "llvm.x86.tileloadd64"] + fn tileloadd64(dst: i8, base: *const u8, stride: usize); + #[link_name = "llvm.x86.tileloaddt164"] + fn tileloaddt164(dst: i8, base: *const u8, stride: usize); + #[link_name = "llvm.x86.tilerelease"] + fn tilerelease(); + #[link_name = "llvm.x86.tilestored64"] + fn tilestored64(dst: i8, base: *mut u8, stride: usize); + #[link_name = "llvm.x86.tilezero"] + fn tilezero(dst: i8); + #[link_name = "llvm.x86.tdpbf16ps"] + fn tdpbf16ps(dst: i8, a: i8, b: i8); + #[link_name = "llvm.x86.tdpbuud"] + fn tdpbuud(dst: i8, a: i8, b: i8); + #[link_name = "llvm.x86.tdpbusd"] + fn tdpbusd(dst: i8, a: i8, b: i8); + #[link_name = "llvm.x86.tdpbsud"] + fn tdpbsud(dst: i8, a: i8, b: i8); + #[link_name = "llvm.x86.tdpbssd"] + fn tdpbssd(dst: i8, a: i8, b: i8); + #[link_name = "llvm.x86.tdpfp16ps"] + fn tdpfp16ps(dst: i8, a: i8, b: i8); + #[link_name = "llvm.x86.tcmmimfp16ps"] + fn tcmmimfp16ps(dst: i8, a: i8, b: i8); + #[link_name = "llvm.x86.tcmmrlfp16ps"] + fn tcmmrlfp16ps(dst: i8, a: i8, b: i8); +} + +#[cfg(test)] +mod tests { + use crate::core_arch::x86::_mm_cvtness_sbh; + use crate::core_arch::x86_64::*; + use core::mem::transmute; + use stdarch_test::simd_test; + #[cfg(target_os = "linux")] + use syscalls::{syscall, Sysno}; + + #[allow(non_camel_case_types)] + #[repr(packed)] + #[derive(Copy, Clone, Default, Debug, PartialEq)] + struct __tilecfg { + /// 0 `or` 1 + palette: u8, + start_row: u8, + /// reserved, must be zero + reserved_a0: [u8; 14], + /// number of bytes of one row in each tile + colsb: [u16; 8], + /// reserved, must be zero + reserved_b0: [u16; 8], + /// number of rows in each tile + rows: [u8; 8], + /// reserved, must be zero + reserved_c0: [u8; 8], + } + + impl __tilecfg { + fn new(palette: u8, start_row: u8, colsb: [u16; 8], rows: [u8; 8]) -> Self { + Self { + palette, + start_row, + reserved_a0: [0u8; 14], + colsb, + reserved_b0: [0u16; 8], + rows, + reserved_c0: [0u8; 8], + } + } + + const fn as_ptr(&self) -> *const u8 { + self as *const Self as *const u8 + } + + fn as_mut_ptr(&mut self) -> *mut u8 { + self as *mut Self as *mut u8 + } + } + + #[cfg(not(target_os = "linux"))] + #[target_feature(enable = "amx-tile")] + fn _init_amx() {} + + #[cfg(target_os = "linux")] + #[target_feature(enable = "amx-tile")] + #[inline] + unsafe fn _init_amx() { + let mut ret: usize; + let mut xfeatures: usize = 0; + ret = syscall!(Sysno::arch_prctl, 0x1022, &mut xfeatures as *mut usize) + .expect("arch_prctl ARCH_GET_XCOMP_PERM syscall failed"); + if ret != 0 { + panic!("Failed to get XFEATURES"); + } else { + match 0b11 & (xfeatures >> 17) { + 0 => panic!("AMX is not available"), + 1 => { + ret = syscall!(Sysno::arch_prctl, 0x1023, 18) + .expect("arch_prctl ARCH_REQ_XCOMP_PERM syscall failed"); + if ret != 0 { + panic!("Failed to enable AMX"); + } + } + 3 => {} + _ => unreachable!(), + } + } + } + + #[simd_test(enable = "amx-tile")] + unsafe fn test_tile_loadconfig() { + let config = __tilecfg::default(); + _tile_loadconfig(config.as_ptr()); + _tile_release(); + } + + #[simd_test(enable = "amx-tile")] + unsafe fn test_tile_storeconfig() { + let config = __tilecfg::new(1, 0, [32; 8], [8; 8]); + _tile_loadconfig(config.as_ptr()); + let mut _config = __tilecfg::default(); + _tile_storeconfig(_config.as_mut_ptr()); + _tile_release(); + assert_eq!(config, _config); + } + + #[simd_test(enable = "amx-tile")] + unsafe fn test_tile_zero() { + _init_amx(); + let mut config = __tilecfg::default(); + config.palette = 1; + config.colsb[0] = 64; + config.rows[0] = 16; + _tile_loadconfig(config.as_ptr()); + _tile_zero::<0>(); + let mut out = [[1_i8; 64]; 16]; + _tile_stored::<0>(&mut out as *mut [i8; 64] as *mut u8, 64); + _tile_release(); + assert_eq!(out, [[0; 64]; 16]); + } + + #[simd_test(enable = "amx-tile")] + unsafe fn test_tile_stored() { + _init_amx(); + let mut config = __tilecfg::default(); + config.palette = 1; + config.colsb[0] = 64; + config.rows[0] = 16; + _tile_loadconfig(config.as_ptr()); + _tile_zero::<0>(); + let mut out = [[1_i8; 64]; 16]; + _tile_stored::<0>(&mut out as *mut [i8; 64] as *mut u8, 64); + _tile_release(); + assert_eq!(out, [[0; 64]; 16]); + } + + #[simd_test(enable = "amx-tile")] + unsafe fn test_tile_loadd() { + _init_amx(); + let mut config = __tilecfg::default(); + config.palette = 1; + config.colsb[0] = 64; + config.rows[0] = 16; + _tile_loadconfig(config.as_ptr()); + _tile_zero::<0>(); + let mat = [1_i8; 1024]; + _tile_loadd::<0>(&mat as *const i8 as *const u8, 64); + let mut out = [[0_i8; 64]; 16]; + _tile_stored::<0>(&mut out as *mut [i8; 64] as *mut u8, 64); + _tile_release(); + assert_eq!(out, [[1; 64]; 16]); + } + + #[simd_test(enable = "amx-tile")] + unsafe fn test_tile_stream_loadd() { + _init_amx(); + let mut config = __tilecfg::default(); + config.palette = 1; + config.colsb[0] = 64; + config.rows[0] = 16; + _tile_loadconfig(config.as_ptr()); + _tile_zero::<0>(); + let mat = [1_i8; 1024]; + _tile_stream_loadd::<0>(&mat as *const i8 as *const u8, 64); + let mut out = [[0_i8; 64]; 16]; + _tile_stored::<0>(&mut out as *mut [i8; 64] as *mut u8, 64); + _tile_release(); + assert_eq!(out, [[1; 64]; 16]); + } + + #[simd_test(enable = "amx-tile")] + unsafe fn test_tile_release() { + _tile_release(); + } + + #[simd_test(enable = "amx-bf16,avx512f")] + unsafe fn test_tile_dpbf16ps() { + _init_amx(); + let bf16_1: u16 = _mm_cvtness_sbh(1.0).to_bits(); + let bf16_2: u16 = _mm_cvtness_sbh(2.0).to_bits(); + let ones: [u8; 1024] = transmute([bf16_1; 512]); + let twos: [u8; 1024] = transmute([bf16_2; 512]); + let mut res = [[0f32; 16]; 16]; + let mut config = __tilecfg::default(); + config.palette = 1; + (0..=2).for_each(|i| { + config.colsb[i] = 64; + config.rows[i] = 16; + }); + _tile_loadconfig(config.as_ptr()); + _tile_zero::<0>(); + _tile_loadd::<1>(&ones as *const u8, 64); + _tile_loadd::<2>(&twos as *const u8, 64); + _tile_dpbf16ps::<0, 1, 2>(); + _tile_stored::<0>(&mut res as *mut [f32; 16] as *mut u8, 64); + _tile_release(); + assert_eq!(res, [[64f32; 16]; 16]); + } + + #[simd_test(enable = "amx-int8")] + unsafe fn test_tile_dpbssd() { + _init_amx(); + let ones = [-1_i8; 1024]; + let twos = [-2_i8; 1024]; + let mut res = [[0_i32; 16]; 16]; + let mut config = __tilecfg::default(); + config.palette = 1; + (0..=2).for_each(|i| { + config.colsb[i] = 64; + config.rows[i] = 16; + }); + _tile_loadconfig(config.as_ptr()); + _tile_zero::<0>(); + _tile_loadd::<1>(&ones as *const i8 as *const u8, 64); + _tile_loadd::<2>(&twos as *const i8 as *const u8, 64); + _tile_dpbssd::<0, 1, 2>(); + _tile_stored::<0>(&mut res as *mut [i32; 16] as *mut u8, 64); + _tile_release(); + assert_eq!(res, [[128_i32; 16]; 16]); + } + + #[simd_test(enable = "amx-int8")] + unsafe fn test_tile_dpbsud() { + _init_amx(); + let ones = [-1_i8; 1024]; + let twos = [2_u8; 1024]; + let mut res = [[0_i32; 16]; 16]; + let mut config = __tilecfg::default(); + config.palette = 1; + (0..=2).for_each(|i| { + config.colsb[i] = 64; + config.rows[i] = 16; + }); + _tile_loadconfig(config.as_ptr()); + _tile_zero::<0>(); + _tile_loadd::<1>(&ones as *const i8 as *const u8, 64); + _tile_loadd::<2>(&twos as *const u8, 64); + _tile_dpbsud::<0, 1, 2>(); + _tile_stored::<0>(&mut res as *mut [i32; 16] as *mut u8, 64); + _tile_release(); + assert_eq!(res, [[-128_i32; 16]; 16]); + } + + #[simd_test(enable = "amx-int8")] + unsafe fn test_tile_dpbusd() { + _init_amx(); + let ones = [1_u8; 1024]; + let twos = [-2_i8; 1024]; + let mut res = [[0_i32; 16]; 16]; + let mut config = __tilecfg::default(); + config.palette = 1; + (0..=2).for_each(|i| { + config.colsb[i] = 64; + config.rows[i] = 16; + }); + _tile_loadconfig(config.as_ptr()); + _tile_zero::<0>(); + _tile_loadd::<1>(&ones as *const u8, 64); + _tile_loadd::<2>(&twos as *const i8 as *const u8, 64); + _tile_dpbusd::<0, 1, 2>(); + _tile_stored::<0>(&mut res as *mut [i32; 16] as *mut u8, 64); + _tile_release(); + assert_eq!(res, [[-128_i32; 16]; 16]); + } + + #[simd_test(enable = "amx-int8")] + unsafe fn test_tile_dpbuud() { + _init_amx(); + let ones = [1_u8; 1024]; + let twos = [2_u8; 1024]; + let mut res = [[0_i32; 16]; 16]; + let mut config = __tilecfg::default(); + config.palette = 1; + (0..=2).for_each(|i| { + config.colsb[i] = 64; + config.rows[i] = 16; + }); + _tile_loadconfig(config.as_ptr()); + _tile_zero::<0>(); + _tile_loadd::<1>(&ones as *const u8, 64); + _tile_loadd::<2>(&twos as *const u8, 64); + _tile_dpbuud::<0, 1, 2>(); + _tile_stored::<0>(&mut res as *mut [i32; 16] as *mut u8, 64); + _tile_release(); + assert_eq!(res, [[128_i32; 16]; 16]); + } + + #[simd_test(enable = "amx-fp16")] + unsafe fn test_tile_dpfp16ps() { + _init_amx(); + let ones = [1f16; 512]; + let twos = [2f16; 512]; + let mut res = [[0f32; 16]; 16]; + let mut config = __tilecfg::default(); + config.palette = 1; + (0..=2).for_each(|i| { + config.colsb[i] = 64; + config.rows[i] = 16; + }); + _tile_loadconfig(config.as_ptr()); + _tile_zero::<0>(); + _tile_loadd::<1>(&ones as *const f16 as *const u8, 64); + _tile_loadd::<2>(&twos as *const f16 as *const u8, 64); + _tile_dpfp16ps::<0, 1, 2>(); + _tile_stored::<0>(&mut res as *mut [f32; 16] as *mut u8, 64); + _tile_release(); + assert_eq!(res, [[64f32; 16]; 16]); + } + + #[simd_test(enable = "amx-complex")] + unsafe fn test_tile_cmmimfp16ps() { + _init_amx(); + let ones = [1f16; 512]; + let twos = [2f16; 512]; + let mut res = [[0f32; 16]; 16]; + let mut config = __tilecfg::default(); + config.palette = 1; + (0..=2).for_each(|i| { + config.colsb[i] = 64; + config.rows[i] = 16; + }); + _tile_loadconfig(config.as_ptr()); + _tile_zero::<0>(); + _tile_loadd::<1>(&ones as *const f16 as *const u8, 64); + _tile_loadd::<2>(&twos as *const f16 as *const u8, 64); + _tile_cmmimfp16ps::<0, 1, 2>(); + _tile_stored::<0>(&mut res as *mut [f32; 16] as *mut u8, 64); + _tile_release(); + assert_eq!(res, [[64f32; 16]; 16]); + } + + #[simd_test(enable = "amx-complex")] + unsafe fn test_tile_cmmrlfp16ps() { + _init_amx(); + let ones = [1f16; 512]; + let twos = [2f16; 512]; + let mut res = [[0f32; 16]; 16]; + let mut config = __tilecfg::default(); + config.palette = 1; + (0..=2).for_each(|i| { + config.colsb[i] = 64; + config.rows[i] = 16; + }); + _tile_loadconfig(config.as_ptr()); + _tile_zero::<0>(); + _tile_loadd::<1>(&ones as *const f16 as *const u8, 64); + _tile_loadd::<2>(&twos as *const f16 as *const u8, 64); + _tile_cmmrlfp16ps::<0, 1, 2>(); + _tile_stored::<0>(&mut res as *mut [f32; 16] as *mut u8, 64); + _tile_release(); + assert_eq!(res, [[0f32; 16]; 16]); + } +} diff --git a/crates/core_arch/src/x86_64/mod.rs b/crates/core_arch/src/x86_64/mod.rs index e4ad644edf..32ebf87d9c 100644 --- a/crates/core_arch/src/x86_64/mod.rs +++ b/crates/core_arch/src/x86_64/mod.rs @@ -77,3 +77,7 @@ pub use self::bt::*; mod avx512fp16; #[unstable(feature = "stdarch_x86_avx512_f16", issue = "127213")] pub use self::avx512fp16::*; + +mod amx; +#[unstable(feature = "x86_amx_intrinsics", issue = "126622")] +pub use self::amx::*; diff --git a/crates/stdarch-verify/src/lib.rs b/crates/stdarch-verify/src/lib.rs index efb5d50e26..3a5588bbfe 100644 --- a/crates/stdarch-verify/src/lib.rs +++ b/crates/stdarch-verify/src/lib.rs @@ -215,6 +215,7 @@ fn to_type(t: &syn::Type) -> proc_macro2::TokenStream { "u32" => quote! { &U32 }, "u64" => quote! { &U64 }, "u128" => quote! { &U128 }, + "usize" => quote! { &USIZE }, "u8" => quote! { &U8 }, "p8" => quote! { &P8 }, "p16" => quote! { &P16 }, diff --git a/crates/stdarch-verify/tests/x86-intel.rs b/crates/stdarch-verify/tests/x86-intel.rs index 23d68436c5..0b87ac7d2a 100644 --- a/crates/stdarch-verify/tests/x86-intel.rs +++ b/crates/stdarch-verify/tests/x86-intel.rs @@ -10,6 +10,7 @@ use serde::Deserialize; const PRINT_INSTRUCTION_VIOLATIONS: bool = false; const PRINT_MISSING_LISTS: bool = false; const PRINT_MISSING_LISTS_MARKDOWN: bool = false; +const SS: u8 = (8 * core::mem::size_of::()) as u8; struct Function { name: &'static str, @@ -36,6 +37,7 @@ static U16: Type = Type::PrimUnsigned(16); static U32: Type = Type::PrimUnsigned(32); static U64: Type = Type::PrimUnsigned(64); static U128: Type = Type::PrimUnsigned(128); +static USIZE: Type = Type::PrimUnsigned(SS); static ORDERING: Type = Type::Ordering; static M128: Type = Type::M128; @@ -708,7 +710,7 @@ fn equate( intel = intel.replace("const ", ""); intel = intel.replace('*', " const*"); } - if etype == "IMM" { + if etype == "IMM" || intel == "constexpr int" { // The _bittest intrinsics claim to only accept immediates but actually // accept run-time values as well. if !is_const && !intrinsic.starts_with("_bittest") { @@ -727,7 +729,7 @@ fn equate( (&Type::PrimFloat(64), "double") => {} (&Type::PrimSigned(8), "__int8" | "char") => {} (&Type::PrimSigned(16), "__int16" | "short") => {} - (&Type::PrimSigned(32), "__int32" | "const int" | "int") => {} + (&Type::PrimSigned(32), "__int32" | "constexpr int" | "const int" | "int") => {} (&Type::PrimSigned(64), "__int64" | "long long") => {} (&Type::PrimUnsigned(8), "unsigned char") => {} (&Type::PrimUnsigned(16), "unsigned short") => {} @@ -736,7 +738,7 @@ fn equate( &Type::PrimUnsigned(32), "unsigned __int32" | "unsigned int" | "unsigned long" | "const unsigned int", ) => {} - (&Type::PrimUnsigned(64), "unsigned __int64") => {} + (&Type::PrimUnsigned(64), "unsigned __int64" | "size_t") => {} (&Type::M128, "__m128") => {} (&Type::M128BH, "__m128bh") => {}