From f84e83d86b3f8177d3c16090789f911ea5f07fd9 Mon Sep 17 00:00:00 2001 From: hupe1980 Date: Sat, 9 Mar 2024 22:47:51 +0100 Subject: [PATCH] Add (simd) cosine similarity --- .github/workflows/build.yml | 2 +- go.mod | 2 +- internal/math32/floats.go | 24 ++++++++++++ internal/math32/floats_amd64.go | 20 ++++++++++ internal/math32/floats_arm64.go | 31 +++++++++++++++ internal/math32/floats_arm64.s | 67 +++++++++++++++++++++++++++++++++ internal/math32/floats_noasm.go | 7 ++++ internal/math32/floats_test.go | 52 +++++++++++++++++++++++++ internal/util/math.go | 37 +++++++++--------- internal/util/math_test.go | 58 ---------------------------- metric/cosine.go | 29 ++++++++++++++ metric/cosine_test.go | 65 ++++++++++++++++++++++++++++++++ 12 files changed, 317 insertions(+), 77 deletions(-) create mode 100644 internal/math32/floats.go create mode 100644 internal/math32/floats_amd64.go create mode 100644 internal/math32/floats_arm64.go create mode 100644 internal/math32/floats_arm64.s create mode 100644 internal/math32/floats_noasm.go create mode 100644 internal/math32/floats_test.go create mode 100644 metric/cosine.go create mode 100644 metric/cosine_test.go diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 574b68f..7f9f80b 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -26,7 +26,7 @@ jobs: - name: Set up Go uses: actions/setup-go@v4 with: - go-version: 1.21.x + go-version: 1.22.x - name: Run Linter uses: golangci/golangci-lint-action@v3 diff --git a/go.mod b/go.mod index 637610a..c9a6022 100644 --- a/go.mod +++ b/go.mod @@ -18,6 +18,7 @@ require ( github.com/stretchr/testify v1.9.0 github.com/weaviate/weaviate v1.24.1 golang.org/x/net v0.22.0 + golang.org/x/sys v0.18.0 google.golang.org/grpc v1.62.1 google.golang.org/protobuf v1.33.0 ) @@ -97,7 +98,6 @@ require ( golang.org/x/exp v0.0.0-20240222234643-814bf88cf225 // indirect golang.org/x/mod v0.16.0 // indirect golang.org/x/oauth2 v0.18.0 // indirect - 
golang.org/x/sys v0.18.0 // indirect
 	golang.org/x/text v0.14.0 // indirect
 	golang.org/x/tools v0.19.0 // indirect
 	google.golang.org/api v0.169.0 // indirect
diff --git a/internal/math32/floats.go b/internal/math32/floats.go
new file mode 100644
index 0000000..9ce32bf
--- /dev/null
+++ b/internal/math32/floats.go
@@ -0,0 +1,24 @@
+package math32
+
+var (
+	useAVX512 bool //nolint:unused // assigned only by the amd64 init
+	useNEON   bool //nolint:unused // assigned only by the arm64 init
+)
+
+// Dot returns the dot product of a and b. It panics if their lengths differ.
+func Dot(a, b []float32) (ret float32) {
+	if len(a) != len(b) {
+		panic("slice lengths do not match")
+	}
+
+	return dot(a, b)
+}
+
+func dotGeneric(a, b []float32) float32 {
+	var ret float32
+	for i := range a {
+		ret += a[i] * b[i]
+	}
+
+	return ret
+}
diff --git a/internal/math32/floats_amd64.go b/internal/math32/floats_amd64.go
new file mode 100644
index 0000000..f46ffe6
--- /dev/null
+++ b/internal/math32/floats_amd64.go
@@ -0,0 +1,20 @@
+//go:build amd64 && !noasm
+
+package math32
+
+import (
+	"golang.org/x/sys/cpu"
+)
+
+func init() {
+	useAVX512 = cpu.X86.HasAVX512
+}
+
+func dot(a, b []float32) float32 {
+	switch {
+	case useAVX512:
+		return dotGeneric(a, b) // TODO: no AVX-512 kernel yet, so this is the scalar loop as well
+	default:
+		return dotGeneric(a, b)
+	}
+}
diff --git a/internal/math32/floats_arm64.go b/internal/math32/floats_arm64.go
new file mode 100644
index 0000000..a23a7b8
--- /dev/null
+++ b/internal/math32/floats_arm64.go
@@ -0,0 +1,31 @@
+//go:build arm64 && !noasm
+
+package math32
+
+import (
+	"unsafe"
+
+	"golang.org/x/sys/cpu"
+)
+
+func init() {
+	useNEON = cpu.ARM64.HasASIMD
+}
+
+//go:noescape
+func vdotNEON(a unsafe.Pointer, b unsafe.Pointer, n uintptr, ret unsafe.Pointer)
+
+func dot(a, b []float32) float32 {
+	switch {
+	case useNEON && len(a) >= 4:
+		// vdotNEON seeds its vector accumulator from the first four-element
+		// chunk and reads it uninitialized when n < 4, so shorter inputs
+		// (including empty ones) take the scalar path below.
+		var ret float32
+		vdotNEON(unsafe.Pointer(&a[0]), unsafe.Pointer(&b[0]), uintptr(len(a)), unsafe.Pointer(&ret))
+
+		return ret
+	default:
+		return dotGeneric(a, b)
+	}
+}
diff --git a/internal/math32/floats_arm64.s b/internal/math32/floats_arm64.s
new file mode 100644
index
0000000..42393a4
--- /dev/null
+++ b/internal/math32/floats_arm64.s
@@ -0,0 +1,67 @@
+//go:build !noasm && arm64
+
+TEXT ·vdotNEON(SB), $0-32
+	MOVD a+0(FP), R0 // R0 = a, pointer to the first operand
+	MOVD b+8(FP), R1 // R1 = b, pointer to the second operand
+	MOVD n+16(FP), R2 // R2 = n, number of float32 elements
+	MOVD ret+24(FP), R3 // R3 = ret, address the float32 result is stored to
+	WORD $0xa9bf7bfd // stp x29, x30, [sp, #-16]! ; push frame pointer and link register
+	WORD $0x91000c48 // add x8, x2, #3 ; x8 = n + 3
+	WORD $0xf100005f // cmp x2, #0
+	WORD $0x9a82b108 // csel x8, x8, x2, lt ; x8 = n + 3 if n < 0, else n (bias so the shift rounds toward zero)
+	WORD $0x9342fd0a // asr x10, x8, #2 ; x10 = n / 4, the number of four-float chunks
+	WORD $0x927ef508 // and x8, x8, #0xfffffffffffffffc ; x8 = n rounded toward zero to a multiple of 4
+	WORD $0x7100055f // cmp w10, #1
+	WORD $0xcb080048 // sub x8, x2, x8 ; x8 = n % 4, length of the scalar tail
+	WORD $0x910003fd // mov x29, sp
+	WORD $0x540002ab // b.lt LBB4_5 ; fewer than one chunk: skip the vector code (v0 is not written on this path)
+	WORD $0x3cc10400 // ldr q0, [x0], #16 ; q0 = a[0:4], a += 4
+	WORD $0x3cc10421 // ldr q1, [x1], #16 ; q1 = b[0:4], b += 4
+	WORD $0x71000549 // subs w9, w10, #1 ; w9 = chunks remaining after the first
+	WORD $0x6e21dc00 // fmul v0.4s, v0.4s, v1.4s ; accumulator v0 = a[0:4] * b[0:4]
+	WORD $0x54000200 // b.eq LBB4_6 ; exactly one chunk: go straight to the lane reduction
+	WORD $0xb27d7beb // mov x11, #0x3fffffff8
+	WORD $0x8b0a096a // add x10, x11, x10, lsl #2 ; x10 = 4 * chunks + 0x3fffffff8
+	WORD $0x927e7d4a // and x10, x10, #0x3fffffffc ; x10 = 4 * (chunks - 2), taken as a 32-bit count
+	WORD $0x9100114b // add x11, x10, #4 ; x11 = 4 * (chunks - 1), floats consumed by the loop below
+	WORD $0x8b0b080a // add x10, x0, x11, lsl #2 ; x10 = value of a once the loop has finished
+	WORD $0xaa0103ec // mov x12, x1 ; x12 walks b inside the loop
+
+LBB4_3:
+	WORD $0x3cc10401 // ldr q1, [x0], #16 ; q1 = next four floats of a
+	WORD $0x3cc10582 // ldr q2, [x12], #16 ; q2 = next four floats of b
+	WORD $0x71000529 // subs w9, w9, #1 ; one chunk fewer to go
+	WORD $0x6e22dc21 // fmul v1.4s, v1.4s, v2.4s ; lane-wise products
+	WORD $0x4e21d400 // fadd v0.4s, v0.4s, v1.4s ; accumulate into v0
+	WORD $0x54ffff61 // b.ne LBB4_3 ; loop while chunks remain
+	WORD $0x8b0b0821 // add x1, x1, x11, lsl #2 ; advance b past the chunks the loop consumed
+	WORD $0xaa0a03e0 // mov x0, x10 ; a = position after the looped chunks
+	WORD $0x14000001 // b LBB4_6 ; branch to the next instruction
+
+LBB4_5: // shares its address with LBB4_6; when entered from the b.lt above, v0 has never been written
+LBB4_6:
+	WORD $0x1e2703e1 // fmov s1, wzr ; s1 = 0.0
+	WORD $0x5e0c0402 // mov s2, v0.s[1]
+	WORD $0x5e140403 // mov s3, v0.s[2]
+	WORD $0x5e1c0404 // mov s4, v0.s[3]
+	WORD $0x1e212800 // fadd s0, s0, s1 ; s0 = v0.s[0] + 0.0
+	WORD $0x1e202840 // fadd s0, s2, s0
+	WORD $0x1e202860 // fadd s0, s3, s0
+	WORD $0x1e202880 // fadd s0, s4, s0 ; s0 = sum of the four accumulator lanes
+	WORD $0x7100011f // cmp w8, #0
+	WORD $0xbd000060 // str s0, [x3] ; *ret = sum of the vector part
+	WORD $0x5400012d // b.le LBB4_9 ; no tail elements: done
+	WORD $0x92407d08 // and x8, x8, #0xffffffff ; zero-extend the tail count
+
+LBB4_8:
+	WORD $0xbc404401 // ldr s1, [x0], #4 ; s1 = *a, a++
+	WORD $0xbc404422 // ldr s2, [x1], #4 ; s2 = *b, b++
+	WORD $0xf1000508 // subs x8, x8, #1 ; one tail element fewer to go
+	WORD $0x1e220821 // fmul s1, s1, s2
+	WORD $0x1e212800 // fadd s0, s0, s1 ; add the product to the running sum
+	WORD $0xbd000060 // str s0, [x3] ; *ret = running sum
+	WORD $0x54ffff41 // b.ne LBB4_8 ; loop while tail elements remain
+
+LBB4_9:
+	WORD $0xa8c17bfd // ldp x29, x30, [sp], #16 ; pop frame pointer and link register
+	WORD $0xd65f03c0 // ret
diff --git a/internal/math32/floats_noasm.go b/internal/math32/floats_noasm.go
new file mode 100644
index 0000000..e28622a
--- /dev/null
+++ b/internal/math32/floats_noasm.go
@@ -0,0 +1,7 @@
+//go:build noasm || (!arm64 && !amd64)
+
+package math32
+
+func dot(a, b []float32) float32 {
+	return dotGeneric(a, b)
+}
diff --git a/internal/math32/floats_test.go b/internal/math32/floats_test.go
new file mode 100644
index 0000000..519ac04
--- /dev/null
+++ b/internal/math32/floats_test.go
@@ -0,0 +1,52 @@
+package math32
+
+import (
+	"math/rand"
+	"testing"
+
+	"github.com/stretchr/testify/assert"
+)
+
+func TestDot(t *testing.T) {
+	tests := []struct {
+		name     string
+		a, b     []float32
+		expected float32
+	}{
+		{"Positive values", []float32{1, 2, 3}, []float32{4, 5, 6}, 32.0},
+		{"Negative values", []float32{-1, -2, -3}, []float32{-4, -5, -6}, 32.0},
+		{"Mixed values", []float32{1, -2, 3}, []float32{-4, 5, -6}, -32.0},
+		{"Zero values", []float32{0, 0, 0}, []float32{0, 0, 0}, 0.0},
+		{"Different lengths", []float32{1, 2}, []float32{3, 4, 5}, 0.0}, // expected is ignored: this case must panic
+	}
+
+	for _, tc := range tests {
+		t.Run(tc.name, func(t *testing.T) {
+			if tc.name == "Different lengths" {
+				assert.Panics(t, func() { Dot(tc.a, tc.b) })
+			} else {
+				result := Dot(tc.a, tc.b)
+				assert.Equal(t, tc.expected, result)
+			}
+		})
+	}
+}
+
+func BenchmarkDot(b *testing.B) {
+	// Build two random operands outside the timed section.
+	const size = 1000000 // elements per slice
+	va := make([]float32, size)
+	vb := make([]float32, size)
+
+	for i := range va {
+		va[i] = rand.Float32() // nolint gosec
+		vb[i] = rand.Float32() // nolint gosec
+	}
+
+	// Exclude the setup above from the measurement, then call Dot b.N times.
+	b.ResetTimer()
+
+	for i := 0; i < b.N; i++ {
+		_ = Dot(va, vb)
+	}
+}
diff --git a/internal/util/math.go b/internal/util/math.go
index 2d4317a..2f740a4 100644
--- a/internal/util/math.go
+++ b/internal/util/math.go
@@ -1,26 +1,29 @@
 package util
 
-import "math"
+type Signed interface {
+	~int | ~int8 | ~int16 | ~int32 | ~int64
+}
 
-func CosineSimilarity(matrix1, matrix2 [][]float64) float64 {
-	dotProduct := 0.0
-	magnitude1 := 0.0
-	magnitude2 := 0.0
+type Unsigned interface {
+	~uint | ~uint8 | ~uint16 | ~uint32 | ~uint64 | ~uintptr
+}
 
-	for i := 0; i < len(matrix1); i++ {
-		for j := 0; j < len(matrix1[0]); j++ {
-			dotProduct += matrix1[i][j] * matrix2[i][j]
-			magnitude1 += math.Pow(matrix1[i][j], 2)
-			magnitude2 += math.Pow(matrix2[i][j], 2)
-		}
-	}
+type Integer interface {
+	Signed | Unsigned
+}
 
-	magnitude1 = math.Sqrt(magnitude1)
-	magnitude2 = math.Sqrt(magnitude2)
+type Float interface {
+	~float32 | ~float64
+}
+
+type Number interface {
+	Integer | Float
+}
 
-	if magnitude1 == 0 || magnitude2 == 0 {
-		return 0.0 // Handle zero magnitude case
+func Min[T Number](a, b T) T {
+	if a < b {
+		return a
 	}
 
-	return dotProduct / (magnitude1 * magnitude2)
+	return b
 }
diff --git
a/internal/util/math_test.go b/internal/util/math_test.go index 033b8bb..c7d8682 100644 --- a/internal/util/math_test.go +++ b/internal/util/math_test.go @@ -1,59 +1 @@ package util - -import ( - "testing" - - "github.com/stretchr/testify/assert" -) - -func TestCosineSimilarity(t *testing.T) { - t.Run("Default", func(t *testing.T) { - matrix1 := [][]float64{{1.0, 2.0}, {3.0, 4.0}} - matrix2 := [][]float64{{1.0, 0.0}, {0.0, 1.0}} - - expected := 0.6454972243679028 // Expected cosine similarity value - - result := CosineSimilarity(matrix1, matrix2) - assert.InDelta(t, expected, result, 1e-9) - }) - - t.Run("Zero matrices", func(t *testing.T) { - matrix1 := [][]float64{} - matrix2 := [][]float64{} - - expected := 0.0 // Both matrices are empty, so cosine similarity should be 0 - - result := CosineSimilarity(matrix1, matrix2) - assert.InDelta(t, expected, result, 1e-9) - }) - - t.Run("Matrices with all zeros", func(t *testing.T) { - matrix1 := [][]float64{{0.0, 0.0}, {0.0, 0.0}} - matrix2 := [][]float64{{0.0, 0.0}, {0.0, 0.0}} - - expected := 0.0 // Both matrices have all zeros, so cosine similarity should be 0 - - result := CosineSimilarity(matrix1, matrix2) - assert.InDelta(t, expected, result, 1e-9) - }) - - t.Run("Matrices with orthogonal vectors", func(t *testing.T) { - matrix1 := [][]float64{{1.0, 0.0}, {0.0, 1.0}} - matrix2 := [][]float64{{0.0, 1.0}, {1.0, 0.0}} - - expected := 0.0 // Matrices have orthogonal vectors, so cosine similarity should be 0 - - result := CosineSimilarity(matrix1, matrix2) - assert.InDelta(t, expected, result, 1e-9) - }) - - t.Run("Matrices with identical vectors", func(t *testing.T) { - matrix1 := [][]float64{{1.0, 2.0}, {3.0, 4.0}} - matrix2 := [][]float64{{1.0, 2.0}, {3.0, 4.0}} - - expected := 1.0 // Matrices have identical vectors, so cosine similarity should be 1 - - result := CosineSimilarity(matrix1, matrix2) - assert.InDelta(t, expected, result, 1e-9) - }) -} diff --git a/metric/cosine.go b/metric/cosine.go new file mode 100644 
index 0000000..f4743b7
--- /dev/null
+++ b/metric/cosine.go
@@ -0,0 +1,29 @@
+package metric
+
+import (
+	"github.com/hupe1980/golc/internal/math32"
+)
+
+// Magnitude returns math32.Sqrt of the dot product of a with itself.
+func Magnitude(a []float32) float32 {
+	sumSquares := math32.Dot(a, a)
+	return math32.Sqrt(sumSquares)
+}
+
+// CosineSimilarity returns Dot(a, b) / (|a| * |b|), or 0 if either magnitude is 0.
+func CosineSimilarity(a, b []float32) float32 {
+	inner := math32.Dot(a, b)
+	normA, normB := Magnitude(a), Magnitude(b)
+
+	if normA != 0 && normB != 0 {
+		return inner / (normA * normB)
+	}
+
+	return 0 // the quotient is undefined for a zero magnitude
+}
+
+// CosineDistance returns 1 minus the cosine similarity of a and b.
+func CosineDistance(a, b []float32) float32 {
+	similarity := CosineSimilarity(a, b)
+	return 1 - similarity
+}
diff --git a/metric/cosine_test.go b/metric/cosine_test.go
new file mode 100644
index 0000000..0c750fa
--- /dev/null
+++ b/metric/cosine_test.go
@@ -0,0 +1,65 @@
+package metric
+
+import (
+	"math/rand"
+	"testing"
+
+	"github.com/stretchr/testify/assert"
+)
+
+func TestMagnitude(t *testing.T) {
+	cases := []struct {
+		name string
+		vec  []float32
+		want float32
+	}{
+		{name: "Positive values", vec: []float32{3, 4}, want: 5.0},
+		{name: "Negative values", vec: []float32{-3, -4}, want: 5.0},
+		{name: "Mixed values", vec: []float32{3, -4}, want: 5.0},
+	}
+
+	for _, tt := range cases {
+		t.Run(tt.name, func(t *testing.T) {
+			// Every fixture is a 3-4-5 triple, so the expected magnitude is 5.
+			assert.Equal(t, tt.want, Magnitude(tt.vec))
+		})
+	}
+}
+
+func TestCosineDistance(t *testing.T) {
+	cases := []struct {
+		name string
+		u, v []float32
+		want float32
+	}{
+		{name: "Orthogonal vectors", u: []float32{1, 0}, v: []float32{0, 1}, want: 1.0},
+		{name: "Parallel vectors", u: []float32{1, 0}, v: []float32{1, 0}, want: 0.0},
+		{name: "Opposite vectors", u: []float32{1, 0}, v: []float32{-1, 0}, want: 2.0},
+	}
+
+	for _, tt := range cases {
+		t.Run(tt.name, func(t *testing.T) {
+			// Axis-aligned unit vectors give similarities of 0, 1 and -1.
+			assert.Equal(t, tt.want, CosineDistance(tt.u, tt.v))
+		})
+	}
+}
+
+// BenchmarkCosineSimilarity times CosineSimilarity on two random 10000-element slices.
+func BenchmarkCosineSimilarity(b *testing.B) {
+	const size = 10000
+
+	// Fill both operands with pseudo-random values before the timed section.
+	va, vb := make([]float32, size), make([]float32, size)
+
+	for i := range va {
+		va[i] = rand.Float32() // nolint gosec
+		vb[i] = rand.Float32() // nolint gosec
+	}
+
+	b.ResetTimer()
+
+	for n := 0; n < b.N; n++ {
+		CosineSimilarity(va, vb)
+	}
+}