From e86e18cb7ae50c7610dfe86481a0f6c6426a0398 Mon Sep 17 00:00:00 2001 From: Jack O'Connor Date: Tue, 18 Jul 2023 00:35:37 -0700 Subject: [PATCH] AVX-512 hash_parents --- c/blake3_avx512_x86-64_unix.S | 108 +++++++++++++++++++++++++++++++++- rust/guts/src/avx512.rs | 18 ++++++ 2 files changed, 123 insertions(+), 3 deletions(-) diff --git a/c/blake3_avx512_x86-64_unix.S b/c/blake3_avx512_x86-64_unix.S index 99ce68ff2..b6135b21f 100644 --- a/c/blake3_avx512_x86-64_unix.S +++ b/c/blake3_avx512_x86-64_unix.S @@ -25,6 +25,8 @@ .global _blake3_guts_avx512_compress_xof .global blake3_guts_avx512_hash_chunks_16_exact .global _blake3_guts_avx512_hash_chunks_16_exact +.global blake3_guts_avx512_hash_parents_16_exact +.global _blake3_guts_avx512_hash_parents_16_exact .global blake3_guts_avx512_xof_16_exact .global _blake3_guts_avx512_xof_16_exact .global blake3_guts_avx512_xof_xor_16_exact @@ -3542,7 +3544,7 @@ blake3_guts_avx512_kernel_16: // rdi: block pointer // esi: [unused] -// rdx: cv +// rdx: [unused] // rcx: counter // r8d: flags // r9: out pointer @@ -3686,7 +3688,7 @@ blake3_guts_avx512_hash_blocks_16_exact: // rdi: block pointer // esi: [unused] -// rdx: cv +// rdx: key // rcx: counter // r8d: flags // r9: out pointer @@ -3741,7 +3743,7 @@ blake3_guts_avx512_hash_chunks_16_exact: or r8d, 0x2 call blake3_guts_avx512_hash_blocks_16_exact - // write aligned, transposed outputs with a stride of 2*MAX_SIMD_DEGREE words + // write aligned+transposed outputs with a stride of 2*MAX_SIMD_DEGREE words vmovdqa32 ZMMWORD PTR [r9+0x0*0x80],zmm0 vmovdqa32 ZMMWORD PTR [r9+0x1*0x80],zmm1 vmovdqa32 ZMMWORD PTR [r9+0x2*0x80],zmm2 @@ -3752,6 +3754,100 @@ blake3_guts_avx512_hash_chunks_16_exact: vmovdqa32 ZMMWORD PTR [r9+0x7*0x80],zmm7 ret +// rdi: aligned+transposed input +// rsi: [unused] +// rdx: key +// ecx: flags +// r8: out pointer +.p2align 6 +_blake3_guts_avx512_hash_parents_16_exact: +blake3_guts_avx512_hash_parents_16_exact: + // load the transposed CVs and split alternating 
words into the low and + // high halves of the input vectors + vmovdqa32 zmm0, ZMMWORD PTR [EVEN+rip] + vmovdqa32 zmm1, ZMMWORD PTR [ODD+rip] + vmovdqa32 zmm16, ZMMWORD PTR [rdi+0x0*0x40] + vmovdqa32 zmm2, ZMMWORD PTR [rdi+0x1*0x40] + vmovdqa32 zmm24, zmm16 + vpermt2d zmm16, zmm0, zmm2 + vpermt2d zmm24, zmm1, zmm2 + vmovdqa32 zmm17, ZMMWORD PTR [rdi+0x2*0x40] + vmovdqa32 zmm2, ZMMWORD PTR [rdi+0x3*0x40] + vmovdqa32 zmm25, zmm17 + vpermt2d zmm17, zmm0, zmm2 + vpermt2d zmm25, zmm1, zmm2 + vmovdqa32 zmm18, ZMMWORD PTR [rdi+0x4*0x40] + vmovdqa32 zmm2, ZMMWORD PTR [rdi+0x5*0x40] + vmovdqa32 zmm26, zmm18 + vpermt2d zmm18, zmm0, zmm2 + vpermt2d zmm26, zmm1, zmm2 + vmovdqa32 zmm19, ZMMWORD PTR [rdi+0x6*0x40] + vmovdqa32 zmm2, ZMMWORD PTR [rdi+0x7*0x40] + vmovdqa32 zmm27, zmm19 + vpermt2d zmm19, zmm0, zmm2 + vpermt2d zmm27, zmm1, zmm2 + vmovdqa32 zmm20, ZMMWORD PTR [rdi+0x8*0x40] + vmovdqa32 zmm2, ZMMWORD PTR [rdi+0x9*0x40] + vmovdqa32 zmm28, zmm20 + vpermt2d zmm20, zmm0, zmm2 + vpermt2d zmm28, zmm1, zmm2 + vmovdqa32 zmm21, ZMMWORD PTR [rdi+0xa*0x40] + vmovdqa32 zmm2, ZMMWORD PTR [rdi+0xb*0x40] + vmovdqa32 zmm29, zmm21 + vpermt2d zmm21, zmm0, zmm2 + vpermt2d zmm29, zmm1, zmm2 + vmovdqa32 zmm22, ZMMWORD PTR [rdi+0xc*0x40] + vmovdqa32 zmm2, ZMMWORD PTR [rdi+0xd*0x40] + vmovdqa32 zmm30, zmm22 + vpermt2d zmm22, zmm0, zmm2 + vpermt2d zmm30, zmm1, zmm2 + vmovdqa32 zmm23, ZMMWORD PTR [rdi+0xe*0x40] + vmovdqa32 zmm2, ZMMWORD PTR [rdi+0xf*0x40] + vmovdqa32 zmm31, zmm23 + vpermt2d zmm23, zmm0, zmm2 + vpermt2d zmm31, zmm1, zmm2 + // broadcast the key + vpbroadcastd zmm0,DWORD PTR [rdx] + vpbroadcastd zmm1,DWORD PTR [rdx+0x4] + vpbroadcastd zmm2,DWORD PTR [rdx+0x8] + vpbroadcastd zmm3,DWORD PTR [rdx+0xc] + vpbroadcastd zmm4,DWORD PTR [rdx+0x10] + vpbroadcastd zmm5,DWORD PTR [rdx+0x14] + vpbroadcastd zmm6,DWORD PTR [rdx+0x18] + vpbroadcastd zmm7,DWORD PTR [rdx+0x1c] + // zero the counter + mov eax, 0 + vpbroadcastd zmm12,eax + vpbroadcastd zmm13,eax + // broadcast the block length + mov 
eax, 64 + vpbroadcastd zmm14, eax + // broadcast the flags + vpbroadcastd zmm15, ecx + + // execute the kernel + call blake3_guts_avx512_kernel_16 + + // xor the two halves of the state + vpxord zmm0, zmm0, zmm8 + vpxord zmm1, zmm1, zmm9 + vpxord zmm2, zmm2, zmm10 + vpxord zmm3, zmm3, zmm11 + vpxord zmm4, zmm4, zmm12 + vpxord zmm5, zmm5, zmm13 + vpxord zmm6, zmm6, zmm14 + vpxord zmm7, zmm7, zmm15 + // write aligned+transposed outputs with a stride of 2*MAX_SIMD_DEGREE words + vmovdqa32 ZMMWORD PTR [r8+0x0*0x80],zmm0 + vmovdqa32 ZMMWORD PTR [r8+0x1*0x80],zmm1 + vmovdqa32 ZMMWORD PTR [r8+0x2*0x80],zmm2 + vmovdqa32 ZMMWORD PTR [r8+0x3*0x80],zmm3 + vmovdqa32 ZMMWORD PTR [r8+0x4*0x80],zmm4 + vmovdqa32 ZMMWORD PTR [r8+0x5*0x80],zmm5 + vmovdqa32 ZMMWORD PTR [r8+0x6*0x80],zmm6 + vmovdqa32 ZMMWORD PTR [r8+0x7*0x80],zmm7 + ret + // rdi: block pointer // esi: block_len // rdx: cv @@ -3972,6 +4068,12 @@ INDEX0: INDEX1: .long 4, 5, 6, 7, 20, 21, 22, 23 .long 12, 13, 14, 15, 28, 29, 30, 31 +EVEN: + .long 0, 2, 4, 6, 8, 10, 12, 14 + .long 16, 18, 20, 22, 24, 26, 28, 30 +ODD: + .long 1, 3, 5, 7, 9, 11, 13, 15 + .long 17, 19, 21, 23, 25, 27, 29, 31 ADD0: .long 0, 1, 2, 3, 4, 5, 6, 7 .long 8, 9, 10, 11, 12, 13, 14, 15 diff --git a/rust/guts/src/avx512.rs b/rust/guts/src/avx512.rs index dc4ce4a30..afb65f25e 100644 --- a/rust/guts/src/avx512.rs +++ b/rust/guts/src/avx512.rs @@ -27,6 +27,13 @@ extern "C" { flags: u32, transposed_output: *mut u32, ); + fn blake3_guts_avx512_hash_parents_16_exact( + transposed_input: *const u32, + num_parents: usize, + key: *const CVBytes, + flags: u32, + transposed_output: *mut u32, + ); fn blake3_guts_avx512_xof_16_exact( block: *const BlockBytes, block_len: u32, @@ -83,6 +90,17 @@ unsafe extern "C" fn hash_parents( flags: u32, transposed_output: *mut u32, // may overlap the input ) { + debug_assert!(num_parents <= 16); + if num_parents == 16 { + blake3_guts_avx512_hash_parents_16_exact( + transposed_input, + num_parents, + key, + flags, + 
transposed_output, + ); + return; + } crate::hash_parents_using_compress( blake3_guts_avx512_compress, transposed_input,