Skip to content

Commit

Permalink
Add (simd) cosine similarity
Browse files Browse the repository at this point in the history
  • Loading branch information
hupe1980 committed Mar 9, 2024
1 parent 605fc56 commit f84e83d
Show file tree
Hide file tree
Showing 12 changed files with 317 additions and 77 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ jobs:
- name: Set up Go
uses: actions/setup-go@v4
with:
go-version: 1.21.x
go-version: 1.22.x

- name: Run Linter
uses: golangci/golangci-lint-action@v3
Expand Down
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ require (
github.com/stretchr/testify v1.9.0
github.com/weaviate/weaviate v1.24.1
golang.org/x/net v0.22.0
golang.org/x/sys v0.18.0
google.golang.org/grpc v1.62.1
google.golang.org/protobuf v1.33.0
)
Expand Down Expand Up @@ -97,7 +98,6 @@ require (
golang.org/x/exp v0.0.0-20240222234643-814bf88cf225 // indirect
golang.org/x/mod v0.16.0 // indirect
golang.org/x/oauth2 v0.18.0 // indirect
golang.org/x/sys v0.18.0 // indirect
golang.org/x/text v0.14.0 // indirect
golang.org/x/tools v0.19.0 // indirect
google.golang.org/api v0.169.0 // indirect
Expand Down
24 changes: 24 additions & 0 deletions internal/math32/floats.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
package math32

// CPU feature flags consulted by the architecture-specific dot
// implementations. Both start out false and are assigned from
// golang.org/x/sys/cpu in the init function of the matching build.
var (
// useAVX512 is assigned from cpu.X86.HasAVX512 in the amd64 build.
useAVX512 bool // nolint unused
// useNEON is assigned from cpu.ARM64.HasASIMD in the arm64 build.
useNEON bool // nolint unused
)

// Dot returns the dot product of a and b as computed by the
// platform-specific dot implementation selected at build time.
//
// It panics with "slice lengths do not match" when len(a) != len(b).
func Dot(a, b []float32) (ret float32) {
	if len(a) == len(b) {
		return dot(a, b)
	}

	panic("slice lengths do not match")
}

// dotGeneric is the portable pure-Go dot product: it walks a from the first
// element to the last and accumulates a[i]*b[i] in a single float32.
//
// Only len(a) elements are read; indexing b panics if b is shorter than a.
func dotGeneric(a, b []float32) float32 {
	var sum float32

	for i, x := range a {
		sum += x * b[i]
	}

	return sum
}
20 changes: 20 additions & 0 deletions internal/math32/floats_amd64.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
//go:build amd64 && !noasm

package math32

import (
"golang.org/x/sys/cpu"
)

// init copies cpu.X86.HasAVX512 into useAVX512 once at package start-up;
// dot reads that flag on every call.
func init() {
useAVX512 = cpu.X86.HasAVX512
}

// dot is the amd64 implementation behind Dot.
//
// The AVX-512 branch is a placeholder: until a vectorised kernel exists it
// runs the same pure-Go loop as the fallback.
func dot(a, b []float32) float32 {
	if useAVX512 {
		return dotGeneric(a, b) // TODO
	}

	return dotGeneric(a, b)
}
31 changes: 31 additions & 0 deletions internal/math32/floats_arm64.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
//go:build arm64 && !noasm

package math32

import (
"unsafe"

"golang.org/x/sys/cpu"
)

// init copies cpu.ARM64.HasASIMD into useNEON once at package start-up;
// dot reads that flag on every call to choose between vdotNEON and
// dotGeneric.
func init() {
useNEON = cpu.ARM64.HasASIMD
}

// vdotNEON has no Go body; it is declared here and implemented in
// floats_arm64.s. The call site in dot passes the addresses of the first
// elements of two float32 slices as a and b, the element count as n, and the
// address of a float32 as ret, which it reads back as the result.
//
//go:noescape
func vdotNEON(a unsafe.Pointer, b unsafe.Pointer, n uintptr, ret unsafe.Pointer)

// dot is the arm64 implementation behind Dot. It runs the NEON assembly
// kernel when the CPU reports ASIMD support and the input holds at least one
// full vector of four float32 values; otherwise it uses the pure-Go loop.
//
// The caller guarantees len(a) == len(b).
func dot(a, b []float32) float32 {
	// vdotNEON works on blocks of four float32 lanes and only initialises its
	// vector accumulator while loading the first block. With fewer than four
	// elements that load is skipped and the kernel sums a register it never
	// wrote, so short (and empty) inputs must take the scalar path.
	const neonLanes = 4

	if !useNEON || len(a) < neonLanes {
		return dotGeneric(a, b)
	}

	var ret float32

	vdotNEON(unsafe.Pointer(&a[0]), unsafe.Pointer(&b[0]), uintptr(len(a)), unsafe.Pointer(&ret))

	return ret
}
67 changes: 67 additions & 0 deletions internal/math32/floats_arm64.s
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
//go:build !noasm && arm64

// vdotNEON(a, b unsafe.Pointer, n uintptr, ret unsafe.Pointer)
//
// Multiplies the n float32 values at a and b pairwise, sums the products and
// stores the float32 result through ret. Full groups of four elements are
// processed with 128-bit vector instructions; the remaining n % 4 elements
// are handled one at a time.
//
// The body is pre-assembled A64 machine code emitted as raw WORDs. Branch
// offsets are baked into those words, so the LBB4_* labels are landmarks only
// and nothing may be inserted between a branch and its target. The comments
// give the instruction each word decodes to.
//
// Registers: x0/x1 cursors into a/b, x2 = n, x3 = ret, x8 = n % 4,
// x10 = n / 4 (later the post-loop a cursor), v0 = four-lane accumulator,
// s0 = scalar running sum.
//
// NOTE(review): when n < 4 the b.lt below jumps to LBB4_5 before anything has
// been loaded into v0, and the lane sum at LBB4_6 then reads v0 as it was on
// entry. Confirm that callers never reach this routine with fewer than four
// elements.
TEXT ·vdotNEON(SB), $0-32
MOVD a+0(FP), R0 // R0 = a
MOVD b+8(FP), R1 // R1 = b
MOVD n+16(FP), R2 // R2 = n
MOVD ret+24(FP), R3 // R3 = ret (address the result is stored to)
WORD $0xa9bf7bfd // stp x29, x30, [sp, #-16]!   push frame pointer and link register
WORD $0x91000c48 // add x8, x2, #3
WORD $0xf100005f // cmp x2, #0
WORD $0x9a82b108 // csel x8, x8, x2, lt         x8 = n+3 if n < 0, else n
WORD $0x9342fd0a // asr x10, x8, #2             x10 = n / 4 (number of four-element blocks)
WORD $0x927ef508 // and x8, x8, #0xfffffffffffffffc
WORD $0x7100055f // cmp w10, #1
WORD $0xcb080048 // sub x8, x2, x8              x8 = n % 4 (tail element count)
WORD $0x910003fd // mov x29, sp
WORD $0x540002ab // b.lt LBB4_5                 no full block: skip the vector section
WORD $0x3cc10400 // ldr q0, [x0], #16           first block of a, advance a
WORD $0x3cc10421 // ldr q1, [x1], #16           first block of b, advance b
WORD $0x71000549 // subs w9, w10, #1            w9 = blocks left after the first
WORD $0x6e21dc00 // fmul v0.4s, v0.4s, v1.4s    accumulator = first block products
WORD $0x54000200 // b.eq LBB4_6                 exactly one block: go straight to the lane sum
WORD $0xb27d7beb // mov x11, #0x3fffffff8
WORD $0x8b0a096a // add x10, x11, x10, lsl #2
WORD $0x927e7d4a // and x10, x10, #0x3fffffffc
WORD $0x9100114b // add x11, x10, #4            x11 = number of floats the loop will consume
WORD $0x8b0b080a // add x10, x0, x11, lsl #2    x10 = a cursor once the loop has finished
WORD $0xaa0103ec // mov x12, x1                 the loop walks b through x12

LBB4_3:
WORD $0x3cc10401 // ldr q1, [x0], #16           next block of a
WORD $0x3cc10582 // ldr q2, [x12], #16          next block of b
WORD $0x71000529 // subs w9, w9, #1
WORD $0x6e22dc21 // fmul v1.4s, v1.4s, v2.4s
WORD $0x4e21d400 // fadd v0.4s, v0.4s, v1.4s    accumulate the block products
WORD $0x54ffff61 // b.ne LBB4_3
WORD $0x8b0b0821 // add x1, x1, x11, lsl #2     move the b cursor past the looped blocks
WORD $0xaa0a03e0 // mov x0, x10                 move the a cursor past the looped blocks
WORD $0x14000001 // b LBB4_6                    (branch to the next instruction)

LBB4_5:
LBB4_6:
WORD $0x1e2703e1 // fmov s1, wzr                s1 = 0
WORD $0x5e0c0402 // mov s2, v0.s[1]
WORD $0x5e140403 // mov s3, v0.s[2]
WORD $0x5e1c0404 // mov s4, v0.s[3]
WORD $0x1e212800 // fadd s0, s0, s1             s0 = lane 0 + 0
WORD $0x1e202840 // fadd s0, s2, s0             + lane 1
WORD $0x1e202860 // fadd s0, s3, s0             + lane 2
WORD $0x1e202880 // fadd s0, s4, s0             + lane 3
WORD $0x7100011f // cmp w8, #0
WORD $0xbd000060 // str s0, [x3]                *ret = sum of the vector lanes
WORD $0x5400012d // b.le LBB4_9                 no tail elements: done
WORD $0x92407d08 // and x8, x8, #0xffffffff     zero-extend the tail count

LBB4_8:
WORD $0xbc404401 // ldr s1, [x0], #4            next tail element of a
WORD $0xbc404422 // ldr s2, [x1], #4            next tail element of b
WORD $0xf1000508 // subs x8, x8, #1
WORD $0x1e220821 // fmul s1, s1, s2
WORD $0x1e212800 // fadd s0, s0, s1
WORD $0xbd000060 // str s0, [x3]                *ret updated after every tail element
WORD $0x54ffff41 // b.ne LBB4_8

LBB4_9:
WORD $0xa8c17bfd // ldp x29, x30, [sp], #16     pop frame pointer and link register
WORD $0xd65f03c0 // ret
7 changes: 7 additions & 0 deletions internal/math32/floats_noasm.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
//go:build noasm || (!arm64 && !amd64)

package math32

// dot is the implementation behind Dot for builds with the noasm tag and for
// targets that are neither arm64 nor amd64; it always runs the pure-Go
// dotGeneric.
func dot(a, b []float32) float32 {
return dotGeneric(a, b)
}
52 changes: 52 additions & 0 deletions internal/math32/floats_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
package math32

import (
"math/rand"
"testing"

"github.com/stretchr/testify/assert"
)

// TestDot checks Dot against hand-computed results.
//
// Besides the basic three-element vectors it covers the sizes that matter to
// a four-lane vector kernel: empty input, a single element, exactly one full
// block, and two blocks followed by a one-element tail. Every expected value
// is a small integer, so the result is exact regardless of summation order.
// The final case verifies the length-mismatch panic.
func TestDot(t *testing.T) {
	tests := []struct {
		name      string
		a, b      []float32
		expected  float32
		wantPanic bool
	}{
		{name: "Positive values", a: []float32{1, 2, 3}, b: []float32{4, 5, 6}, expected: 32.0},
		{name: "Negative values", a: []float32{-1, -2, -3}, b: []float32{-4, -5, -6}, expected: 32.0},
		{name: "Mixed values", a: []float32{1, -2, 3}, b: []float32{-4, 5, -6}, expected: -32.0},
		{name: "Zero values", a: []float32{0, 0, 0}, b: []float32{0, 0, 0}, expected: 0.0},
		{name: "Empty", a: nil, b: nil, expected: 0.0},
		{name: "Single element", a: []float32{2}, b: []float32{3}, expected: 6.0},
		{name: "One full block", a: []float32{1, 2, 3, 4}, b: []float32{5, 6, 7, 8}, expected: 70.0},
		{
			name:     "Two blocks with tail",
			a:        []float32{1, 2, 3, 4, 5, 6, 7, 8, 9},
			b:        []float32{9, 8, 7, 6, 5, 4, 3, 2, 1},
			expected: 165.0,
		},
		{name: "Different lengths", a: []float32{1, 2}, b: []float32{3, 4, 5}, wantPanic: true},
	}

	for _, tc := range tests {
		t.Run(tc.name, func(t *testing.T) {
			if tc.wantPanic {
				assert.Panics(t, func() { Dot(tc.a, tc.b) })

				return
			}

			assert.Equal(t, tc.expected, Dot(tc.a, tc.b))
		})
	}
}

// BenchmarkDot times Dot on two slices of one million random float32 values.
// Input generation happens before the timer is reset, so only the Dot calls
// are measured.
func BenchmarkDot(b *testing.B) {
	const size = 1000000 // elements per input slice

	left, right := make([]float32, size), make([]float32, size)

	for i := 0; i < size; i++ {
		left[i], right[i] = rand.Float32(), rand.Float32() // nolint gosec
	}

	// Exclude the setup above from the measurement.
	b.ResetTimer()

	for n := b.N; n > 0; n-- {
		_ = Dot(left, right)
	}
}
37 changes: 20 additions & 17 deletions internal/util/math.go
Original file line number Diff line number Diff line change
@@ -1,26 +1,29 @@
package util

import "math"
// Signed is a type constraint satisfied by any type whose underlying type is
// int, int8, int16, int32 or int64.
type Signed interface {
~int | ~int8 | ~int16 | ~int32 | ~int64
}

func CosineSimilarity(matrix1, matrix2 [][]float64) float64 {
dotProduct := 0.0
magnitude1 := 0.0
magnitude2 := 0.0
// Unsigned is a type constraint satisfied by any type whose underlying type
// is uint, uint8, uint16, uint32, uint64 or uintptr.
type Unsigned interface {
~uint | ~uint8 | ~uint16 | ~uint32 | ~uint64 | ~uintptr
}

for i := 0; i < len(matrix1); i++ {
for j := 0; j < len(matrix1[0]); j++ {
dotProduct += matrix1[i][j] * matrix2[i][j]
magnitude1 += math.Pow(matrix1[i][j], 2)
magnitude2 += math.Pow(matrix2[i][j], 2)
}
}
// Integer is a type constraint satisfied by every type that satisfies Signed
// or Unsigned.
type Integer interface {
Signed | Unsigned
}

magnitude1 = math.Sqrt(magnitude1)
magnitude2 = math.Sqrt(magnitude2)
// Float is a type constraint satisfied by any type whose underlying type is
// float32 or float64.
type Float interface {
~float32 | ~float64
}

// Number is a type constraint satisfied by every type that satisfies Integer
// or Float.
type Number interface {
Integer | Float
}

if magnitude1 == 0 || magnitude2 == 0 {
return 0.0 // Handle zero magnitude case
// Min returns a when a < b and b otherwise.
func Min[T Number](a, b T) T {
if a < b {
return a
}

// NOTE(review): the next statement uses dotProduct and magnitude1/magnitude2,
// none of which Min declares; it looks like a removed line of the old
// CosineSimilarity that this diff view interleaved here — confirm against
// the real file.
return dotProduct / (magnitude1 * magnitude2)
return b
}
58 changes: 0 additions & 58 deletions internal/util/math_test.go
Original file line number Diff line number Diff line change
@@ -1,59 +1 @@
package util

import (
"testing"

"github.com/stretchr/testify/assert"
)

func TestCosineSimilarity(t *testing.T) {
t.Run("Default", func(t *testing.T) {
matrix1 := [][]float64{{1.0, 2.0}, {3.0, 4.0}}
matrix2 := [][]float64{{1.0, 0.0}, {0.0, 1.0}}

expected := 0.6454972243679028 // Expected cosine similarity value

result := CosineSimilarity(matrix1, matrix2)
assert.InDelta(t, expected, result, 1e-9)
})

t.Run("Zero matrices", func(t *testing.T) {
matrix1 := [][]float64{}
matrix2 := [][]float64{}

expected := 0.0 // Both matrices are empty, so cosine similarity should be 0

result := CosineSimilarity(matrix1, matrix2)
assert.InDelta(t, expected, result, 1e-9)
})

t.Run("Matrices with all zeros", func(t *testing.T) {
matrix1 := [][]float64{{0.0, 0.0}, {0.0, 0.0}}
matrix2 := [][]float64{{0.0, 0.0}, {0.0, 0.0}}

expected := 0.0 // Both matrices have all zeros, so cosine similarity should be 0

result := CosineSimilarity(matrix1, matrix2)
assert.InDelta(t, expected, result, 1e-9)
})

t.Run("Matrices with orthogonal vectors", func(t *testing.T) {
matrix1 := [][]float64{{1.0, 0.0}, {0.0, 1.0}}
matrix2 := [][]float64{{0.0, 1.0}, {1.0, 0.0}}

expected := 0.0 // Matrices have orthogonal vectors, so cosine similarity should be 0

result := CosineSimilarity(matrix1, matrix2)
assert.InDelta(t, expected, result, 1e-9)
})

t.Run("Matrices with identical vectors", func(t *testing.T) {
matrix1 := [][]float64{{1.0, 2.0}, {3.0, 4.0}}
matrix2 := [][]float64{{1.0, 2.0}, {3.0, 4.0}}

expected := 1.0 // Matrices have identical vectors, so cosine similarity should be 1

result := CosineSimilarity(matrix1, matrix2)
assert.InDelta(t, expected, result, 1e-9)
})
}
29 changes: 29 additions & 0 deletions metric/cosine.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
package metric

import (
"github.com/hupe1980/golc/internal/math32"
)

// Magnitude returns the Euclidean length of a: the square root of the dot
// product of a with itself.
func Magnitude(a []float32) float32 {
	sumOfSquares := math32.Dot(a, a)

	return math32.Sqrt(sumOfSquares)
}

// CosineSimilarity returns the dot product of a and b divided by the product
// of their magnitudes. If either magnitude is zero the result is 0 rather
// than the outcome of a division by zero.
//
// The dot product is evaluated first, so slices of different lengths trigger
// math32.Dot's panic before any magnitude is examined.
func CosineSimilarity(a, b []float32) float32 {
	var (
		dot   = math32.Dot(a, b)
		normA = Magnitude(a)
		normB = Magnitude(b)
	)

	if normA != 0 && normB != 0 {
		return dot / (normA * normB)
	}

	// A zero-length vector has no direction to compare.
	return 0
}

// CosineDistance returns one minus the cosine similarity of a and b.
func CosineDistance(a, b []float32) float32 {
	similarity := CosineSimilarity(a, b)

	return 1 - similarity
}
Loading

0 comments on commit f84e83d

Please sign in to comment.