Skip to content

Commit

Permalink
Add avx512
Browse files Browse the repository at this point in the history
  • Loading branch information
hupe1980 committed Mar 21, 2024
1 parent a4df507 commit 56a7f12
Show file tree
Hide file tree
Showing 8 changed files with 305 additions and 73 deletions.
5 changes: 3 additions & 2 deletions internal/math32/floats.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
package math32

var (
useAVX bool // nolint unused
useNEON bool // nolint unused
useAVX bool // nolint unused
useAVX512 bool // nolint unused
useNEON bool // nolint unused
)

// Dot two vectors.
Expand Down
23 changes: 23 additions & 0 deletions internal/math32/floats_amd64.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,16 +10,31 @@ import (

func init() {
useAVX = cpu.X86.HasAVX
useAVX = cpu.X86.HasAVX512
}

//go:noescape
func _dot_product_avx(a, b unsafe.Pointer, n uintptr, result unsafe.Pointer)

//go:noescape
func _dot_product_avx512(a, b unsafe.Pointer, n uintptr, result unsafe.Pointer)

//go:noescape
func _squared_l2_avx(a, b unsafe.Pointer, n uintptr, result unsafe.Pointer)

//go:noescape
func _squared_l2_avx512(a, b unsafe.Pointer, n uintptr, result unsafe.Pointer)

func dot(a, b []float32) float32 {
switch {
case useAVX512:
var ret float32

if len(a) > 0 {
_dot_product_avx512(unsafe.Pointer(&a[0]), unsafe.Pointer(&b[0]), uintptr(len(a)), unsafe.Pointer(&ret))
}

return ret
case useAVX:
var ret float32

Expand All @@ -35,6 +50,14 @@ func dot(a, b []float32) float32 {

func squaredL2(a, b []float32) float32 {
switch {
case useAVX512:
var ret float32

if len(a) > 0 {
_squared_l2_avx512(unsafe.Pointer(&a[0]), unsafe.Pointer(&b[0]), uintptr(len(a)), unsafe.Pointer(&ret))
}

return ret
case useAVX:
var ret float32

Expand Down
67 changes: 0 additions & 67 deletions internal/math32/floats_arm64.s

This file was deleted.

215 changes: 215 additions & 0 deletions internal/math32/floats_avx512.s
Original file line number Diff line number Diff line change
@@ -0,0 +1,215 @@
//go:build !noasm && amd64
// Code generated by GoLC. DO NOT EDIT.

#include "textflag.h"

TEXT ·_dot_product_avx512(SB), $0-32
MOVQ vec1+0(FP), DI
MOVQ vec2+8(FP), SI
MOVQ n+16(FP), DX
MOVQ result+24(FP), CX
BYTE $0x55 // pushq %rbp
WORD $0x8948; BYTE $0xe5 // movq %rsp, %rbp
LONG $0xf8e48348 // andq $-8, %rsp
LONG $0x0f4a8d4c // leaq 15(%rdx), %r9
WORD $0x8548; BYTE $0xd2 // testq %rdx, %rdx
LONG $0xca490f4c // cmovnsq %rdx, %r9
LONG $0xf0e18349 // andq $-16, %r9
WORD $0x8949; BYTE $0xd0 // movq %rdx, %r8
WORD $0x294d; BYTE $0xc8 // subq %r9, %r8
LONG $0xc057f8c5 // vxorps %xmm0, %xmm0, %xmm0
WORD $0x854d; BYTE $0xc9 // testq %r9, %r9
JLE LBB0_1
WORD $0xc031 // xorl %eax, %eax

LBB0_4:
LONG $0x487cf162; WORD $0x0c10; BYTE $0x87 // vmovups (%rdi,%rax,4), %zmm1
LONG $0x4874f162; WORD $0x0c59; BYTE $0x86 // vmulps (%rsi,%rax,4), %zmm1, %zmm1
LONG $0x487cf162; WORD $0xc158 // vaddps %zmm1, %zmm0, %zmm0
LONG $0x10c08348 // addq $16, %rax
WORD $0x394c; BYTE $0xc8 // cmpq %r9, %rax
JL LBB0_4

LBB0_1:
LONG $0xc957f0c5 // vxorps %xmm1, %xmm1, %xmm1
LONG $0xc958fac5 // vaddss %xmm1, %xmm0, %xmm1
LONG $0xd016fac5 // vmovshdup %xmm0, %xmm2
LONG $0xc958eac5 // vaddss %xmm1, %xmm2, %xmm1
LONG $0x0579e3c4; WORD $0x01d0 // vpermilpd $1, %xmm0, %xmm2
LONG $0xc958eac5 // vaddss %xmm1, %xmm2, %xmm1
LONG $0x0479e3c4; WORD $0xffd0 // vpermilps $255, %xmm0, %xmm2
LONG $0xc958eac5 // vaddss %xmm1, %xmm2, %xmm1
LONG $0x197de3c4; WORD $0x01c2 // vextractf128 $1, %ymm0, %xmm2
LONG $0xc958eac5 // vaddss %xmm1, %xmm2, %xmm1
LONG $0xda16fac5 // vmovshdup %xmm2, %xmm3
LONG $0xc958e2c5 // vaddss %xmm1, %xmm3, %xmm1
LONG $0x0579e3c4; WORD $0x01da // vpermilpd $1, %xmm2, %xmm3
LONG $0xc958e2c5 // vaddss %xmm1, %xmm3, %xmm1
LONG $0x0479e3c4; WORD $0xffd2 // vpermilps $255, %xmm2, %xmm2
LONG $0xc958eac5 // vaddss %xmm1, %xmm2, %xmm1
LONG $0x487df362; WORD $0xc219; BYTE $0x02 // vextractf32x4 $2, %zmm0, %xmm2
LONG $0xc958eac5 // vaddss %xmm1, %xmm2, %xmm1
LONG $0xda16fac5 // vmovshdup %xmm2, %xmm3
LONG $0xc958e2c5 // vaddss %xmm1, %xmm3, %xmm1
LONG $0x0579e3c4; WORD $0x01da // vpermilpd $1, %xmm2, %xmm3
LONG $0xc958e2c5 // vaddss %xmm1, %xmm3, %xmm1
LONG $0x0479e3c4; WORD $0xffd2 // vpermilps $255, %xmm2, %xmm2
LONG $0xc958eac5 // vaddss %xmm1, %xmm2, %xmm1
LONG $0x487df362; WORD $0xc019; BYTE $0x03 // vextractf32x4 $3, %zmm0, %xmm0
LONG $0xc958fac5 // vaddss %xmm1, %xmm0, %xmm1
LONG $0xd016fac5 // vmovshdup %xmm0, %xmm2
LONG $0xc958eac5 // vaddss %xmm1, %xmm2, %xmm1
LONG $0x0579e3c4; WORD $0x01d0 // vpermilpd $1, %xmm0, %xmm2
LONG $0xc958eac5 // vaddss %xmm1, %xmm2, %xmm1
LONG $0x0479e3c4; WORD $0xffc0 // vpermilps $255, %xmm0, %xmm0
LONG $0xc158fac5 // vaddss %xmm1, %xmm0, %xmm0
LONG $0x0111fac5 // vmovss %xmm0, (%rcx)
WORD $0x854d; BYTE $0xc0 // testq %r8, %r8
JLE LBB0_2

LBB0_5:
LONG $0x107aa1c4; WORD $0x8f0c // vmovss (%rdi,%r9,4), %xmm1
LONG $0xb971a2c4; WORD $0x8e04 // vfmadd231ss (%rsi,%r9,4), %xmm1, %xmm0
LONG $0x0111fac5 // vmovss %xmm0, (%rcx)
LONG $0x01c18349 // addq $1, %r9
WORD $0x3949; BYTE $0xd1 // cmpq %rdx, %r9
JL LBB0_5

LBB0_2:
WORD $0x8948; BYTE $0xec // movq %rbp, %rsp
BYTE $0x5d // popq %rbp
WORD $0xf8c5; BYTE $0x77 // vzeroupper
BYTE $0xc3 // retq

TEXT ·_squared_l2_avx512(SB), $0-32
MOVQ vec1+0(FP), DI
MOVQ vec2+8(FP), SI
MOVQ n+16(FP), DX
MOVQ result+24(FP), CX
BYTE $0x55 // pushq %rbp
WORD $0x8948; BYTE $0xe5 // movq %rsp, %rbp
LONG $0xf8e48348 // andq $-8, %rsp
LONG $0x0f428d4c // leaq 15(%rdx), %r8
WORD $0x8548; BYTE $0xd2 // testq %rdx, %rdx
LONG $0xc2490f4c // cmovnsq %rdx, %r8
LONG $0x10fa8348 // cmpq $16, %rdx
JL LBB1_1
LONG $0x04f8c149 // sarq $4, %r8
WORD $0x8948; BYTE $0xd0 // movq %rdx, %rax
LONG $0xf0e08348 // andq $-16, %rax
LONG $0x10f88348 // cmpq $16, %rax
JNE LBB1_4
LONG $0xc057f8c5 // vxorps %xmm0, %xmm0, %xmm0
WORD $0xc031 // xorl %eax, %eax
JMP LBB1_6

LBB1_1:
LONG $0xc057f8c5 // vxorps %xmm0, %xmm0, %xmm0
WORD $0x3145; BYTE $0xc0 // xorl %r8d, %r8d
JMP LBB1_9

LBB1_4:
WORD $0x894d; BYTE $0xc1 // movq %r8, %r9
LONG $0xfee18349 // andq $-2, %r9
LONG $0xc057f8c5 // vxorps %xmm0, %xmm0, %xmm0
WORD $0xc031 // xorl %eax, %eax

LBB1_5:
LONG $0x487cf162; WORD $0x0c10; BYTE $0x87 // vmovups (%rdi,%rax,4), %zmm1
QUAD $0x01875410487cf162 // vmovups 64(%rdi,%rax,4), %zmm2
LONG $0x4874f162; WORD $0x0c5c; BYTE $0x86 // vsubps (%rsi,%rax,4), %zmm1, %zmm1
LONG $0x4874f162; WORD $0xc959 // vmulps %zmm1, %zmm1, %zmm1
LONG $0x487cf162; WORD $0xc158 // vaddps %zmm1, %zmm0, %zmm0
QUAD $0x01864c5c486cf162 // vsubps 64(%rsi,%rax,4), %zmm2, %zmm1
LONG $0x4874f162; WORD $0xc959 // vmulps %zmm1, %zmm1, %zmm1
LONG $0x487cf162; WORD $0xc158 // vaddps %zmm1, %zmm0, %zmm0
LONG $0x20c08348 // addq $32, %rax
LONG $0xfec18349 // addq $-2, %r9
JNE LBB1_5

LBB1_6:
LONG $0x01c0f641 // testb $1, %r8b
JE LBB1_8
LONG $0x487cf162; WORD $0x0c10; BYTE $0x87 // vmovups (%rdi,%rax,4), %zmm1
LONG $0x4874f162; WORD $0x0c5c; BYTE $0x86 // vsubps (%rsi,%rax,4), %zmm1, %zmm1
LONG $0x4874f162; WORD $0xc959 // vmulps %zmm1, %zmm1, %zmm1
LONG $0x487cf162; WORD $0xc158 // vaddps %zmm1, %zmm0, %zmm0

LBB1_8:
LONG $0x04e0c149 // shlq $4, %r8

LBB1_9:
LONG $0xc816fac5 // vmovshdup %xmm0, %xmm1
LONG $0xc958fac5 // vaddss %xmm1, %xmm0, %xmm1
LONG $0x0579e3c4; WORD $0x01d0 // vpermilpd $1, %xmm0, %xmm2
LONG $0xc958eac5 // vaddss %xmm1, %xmm2, %xmm1
LONG $0x0479e3c4; WORD $0xffd0 // vpermilps $255, %xmm0, %xmm2
LONG $0xc958eac5 // vaddss %xmm1, %xmm2, %xmm1
LONG $0x197de3c4; WORD $0x01c2 // vextractf128 $1, %ymm0, %xmm2
LONG $0xc958eac5 // vaddss %xmm1, %xmm2, %xmm1
LONG $0xda16fac5 // vmovshdup %xmm2, %xmm3
LONG $0xc958e2c5 // vaddss %xmm1, %xmm3, %xmm1
LONG $0x0579e3c4; WORD $0x01da // vpermilpd $1, %xmm2, %xmm3
LONG $0xc958e2c5 // vaddss %xmm1, %xmm3, %xmm1
LONG $0x0479e3c4; WORD $0xffd2 // vpermilps $255, %xmm2, %xmm2
LONG $0xc958eac5 // vaddss %xmm1, %xmm2, %xmm1
LONG $0x487df362; WORD $0xc219; BYTE $0x02 // vextractf32x4 $2, %zmm0, %xmm2
LONG $0xc958eac5 // vaddss %xmm1, %xmm2, %xmm1
LONG $0xda16fac5 // vmovshdup %xmm2, %xmm3
LONG $0xc958e2c5 // vaddss %xmm1, %xmm3, %xmm1
LONG $0x0579e3c4; WORD $0x01da // vpermilpd $1, %xmm2, %xmm3
LONG $0xc958e2c5 // vaddss %xmm1, %xmm3, %xmm1
LONG $0x0479e3c4; WORD $0xffd2 // vpermilps $255, %xmm2, %xmm2
LONG $0xc958eac5 // vaddss %xmm1, %xmm2, %xmm1
LONG $0x487df362; WORD $0xc019; BYTE $0x03 // vextractf32x4 $3, %zmm0, %xmm0
LONG $0xc958fac5 // vaddss %xmm1, %xmm0, %xmm1
LONG $0xd016fac5 // vmovshdup %xmm0, %xmm2
LONG $0xc958eac5 // vaddss %xmm1, %xmm2, %xmm1
LONG $0x0579e3c4; WORD $0x01d0 // vpermilpd $1, %xmm0, %xmm2
LONG $0xc958eac5 // vaddss %xmm1, %xmm2, %xmm1
LONG $0x0479e3c4; WORD $0xffc0 // vpermilps $255, %xmm0, %xmm0
LONG $0xc158fac5 // vaddss %xmm1, %xmm0, %xmm0
WORD $0x3949; BYTE $0xd0 // cmpq %rdx, %r8
JGE LBB1_14
WORD $0x894d; BYTE $0xc1 // movq %r8, %r9
WORD $0xf749; BYTE $0xd1 // notq %r9
WORD $0x0149; BYTE $0xd1 // addq %rdx, %r9
WORD $0x8948; BYTE $0xd0 // movq %rdx, %rax
LONG $0x03e08348 // andq $3, %rax
JE LBB1_12

LBB1_11:
LONG $0x107aa1c4; WORD $0x870c // vmovss (%rdi,%r8,4), %xmm1
LONG $0x5c72a1c4; WORD $0x860c // vsubss (%rsi,%r8,4), %xmm1, %xmm1
LONG $0xb971e2c4; BYTE $0xc1 // vfmadd231ss %xmm1, %xmm1, %xmm0
LONG $0x01c08349 // addq $1, %r8
LONG $0xffc08348 // addq $-1, %rax
JNE LBB1_11

LBB1_12:
LONG $0x03f98349 // cmpq $3, %r9
JB LBB1_14

LBB1_13:
LONG $0x107aa1c4; WORD $0x870c // vmovss (%rdi,%r8,4), %xmm1
LONG $0x107aa1c4; WORD $0x8754; BYTE $0x04 // vmovss 4(%rdi,%r8,4), %xmm2
LONG $0x5c72a1c4; WORD $0x860c // vsubss (%rsi,%r8,4), %xmm1, %xmm1
LONG $0x5c6aa1c4; WORD $0x8654; BYTE $0x04 // vsubss 4(%rsi,%r8,4), %xmm2, %xmm2
LONG $0xa971e2c4; BYTE $0xc8 // vfmadd213ss %xmm0, %xmm1, %xmm1
LONG $0x107aa1c4; WORD $0x8744; BYTE $0x08 // vmovss 8(%rdi,%r8,4), %xmm0
LONG $0x5c7aa1c4; WORD $0x865c; BYTE $0x08 // vsubss 8(%rsi,%r8,4), %xmm0, %xmm3
LONG $0xa969e2c4; BYTE $0xd1 // vfmadd213ss %xmm1, %xmm2, %xmm2
LONG $0x107aa1c4; WORD $0x8744; BYTE $0x0c // vmovss 12(%rdi,%r8,4), %xmm0
LONG $0x5c7aa1c4; WORD $0x8644; BYTE $0x0c // vsubss 12(%rsi,%r8,4), %xmm0, %xmm0
LONG $0xa961e2c4; BYTE $0xda // vfmadd213ss %xmm2, %xmm3, %xmm3
LONG $0xa979e2c4; BYTE $0xc3 // vfmadd213ss %xmm3, %xmm0, %xmm0
LONG $0x04c08349 // addq $4, %r8
WORD $0x394c; BYTE $0xc2 // cmpq %r8, %rdx
JNE LBB1_13

LBB1_14:
LONG $0x0111fac5 // vmovss %xmm0, (%rcx)
WORD $0x8948; BYTE $0xec // movq %rbp, %rsp
BYTE $0x5d // popq %rbp
WORD $0xf8c5; BYTE $0x77 // vzeroupper
BYTE $0xc3 // retq
2 changes: 0 additions & 2 deletions internal/math32/floats_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@ func TestDot(t *testing.T) {
}
}

// BenchmarkDot-10 7623 157954 ns/op 0 B/op 0 allocs/op
func BenchmarkDot(b *testing.B) {
// Generate random float32 slices for benchmarking.
const size = 1000000 // Size of slices
Expand Down Expand Up @@ -72,7 +71,6 @@ func TestSquaredL2(t *testing.T) {
}
}

// BenchmarkSquaredL2-10 5128 235120 ns/op 0 B/op 0 allocs/op
func BenchmarkSquaredL2(b *testing.B) {
// Generate random float32 slices for benchmarking.
const size = 1000000 // Size of slices
Expand Down
1 change: 1 addition & 0 deletions internal/math32/src/floats_avx.c
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ void _dot_product_avx(float *a, float *b, long n, float *res)
{
*res += temp[j];
}

for (; i < n; i++)
{
*res += a[i] * b[i];
Expand Down
Loading

0 comments on commit 56a7f12

Please sign in to comment.