-
Notifications
You must be signed in to change notification settings - Fork 7
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
12 changed files
with
317 additions
and
77 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
package math32 | ||
|
||
var ( | ||
useAVX512 bool // nolint unused | ||
useNEON bool // nolint unused | ||
) | ||
|
||
// Dot two vectors. | ||
func Dot(a, b []float32) (ret float32) { | ||
if len(a) != len(b) { | ||
panic("slice lengths do not match") | ||
} | ||
|
||
return dot(a, b) | ||
} | ||
|
||
func dotGeneric(a, b []float32) float32 { | ||
var ret float32 | ||
for i := range a { | ||
ret += a[i] * b[i] | ||
} | ||
|
||
return ret | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
//go:build amd64 && !noasm | ||
|
||
package math32 | ||
|
||
import ( | ||
"golang.org/x/sys/cpu" | ||
) | ||
|
||
func init() { | ||
useAVX512 = cpu.X86.HasAVX512 | ||
} | ||
|
||
func dot(a, b []float32) float32 { | ||
switch { | ||
case useAVX512: | ||
return dotGeneric(a, b) // TODO | ||
default: | ||
return dotGeneric(a, b) | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
//go:build arm64 && !noasm | ||
|
||
package math32 | ||
|
||
import ( | ||
"unsafe" | ||
|
||
"golang.org/x/sys/cpu" | ||
) | ||
|
||
func init() { | ||
useNEON = cpu.ARM64.HasASIMD | ||
} | ||
|
||
//go:noescape | ||
func vdotNEON(a unsafe.Pointer, b unsafe.Pointer, n uintptr, ret unsafe.Pointer) | ||
|
||
func dot(a, b []float32) float32 { | ||
switch { | ||
case useNEON: | ||
var ret float32 | ||
|
||
if len(a) > 0 { | ||
vdotNEON(unsafe.Pointer(&a[0]), unsafe.Pointer(&b[0]), uintptr(len(a)), unsafe.Pointer(&ret)) | ||
} | ||
|
||
return ret | ||
default: | ||
return dotGeneric(a, b) | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,67 @@ | ||
//go:build !noasm && arm64 | ||
|
||
TEXT ·vdotNEON(SB), $0-32 | ||
MOVD a+0(FP), R0 // Move the value of 'a' into register R0 | ||
MOVD b+8(FP), R1 // Move the value of 'b' into register R1 | ||
MOVD n+16(FP), R2 // Move the value of 'n' into register R2 | ||
MOVD ret+24(FP), R3 // Move the address of the return value into register R3 | ||
WORD $0xa9bf7bfd // Save the frame pointer and link register to the stack | ||
WORD $0x91000c48 // Add the value of register x2 to register x8 and store the result in x8 | ||
WORD $0xf100005f // Compare the value in register x2 with 0 and set flags | ||
WORD $0x9a82b108 // Conditional select: If the previous comparison result is less than, set x8 to x2, else keep x8 unchanged | ||
WORD $0x9342fd0a // Arithmetic shift right: Shift the value in x8 right by the number of bits specified in x10 and store the result in x10 | ||
WORD $0x927ef508 // Bitwise AND: Perform a bitwise AND operation between the values in x8 and x8, store the result in x8 | ||
WORD $0x7100055f // Compare the value in w10 with 0 and set flags | ||
WORD $0xcb080048 // Subtract the value in x8 from the value in x2 and store the result in x8 | ||
WORD $0x910003fd // Move the value in the stack pointer to register x29 | ||
WORD $0x540002ab // Branch to label .LBB4_5 if the previous comparison result is less than | ||
WORD $0x3cc10400 // Load a quadword from the memory address stored in register x0 into vector register q0 | ||
WORD $0x3cc10421 // Load a quadword from the memory address stored in register x1 into vector register q1 | ||
WORD $0x71000549 // Subtract the value in w10 from w9 and set flags | ||
WORD $0x6e21dc00 // Multiply the vectors in v0 and v1 element-wise and store the result in v0 | ||
WORD $0x54000200 // Branch to label .LBB4_6 if the previous comparison result is equal to | ||
WORD $0xb27d7beb // Move the value in register x11 to register x11 | ||
WORD $0x8b0a096a // Add the value in x11 shifted left by the value in x10 to the value in x10 and store the result in x10 | ||
WORD $0x927e7d4a // Bitwise AND: Perform a bitwise AND operation between the values in x10 and x10, store the result in x10 | ||
WORD $0x9100114b // Add the value in x11 to the value in x10 and store the result in x11 | ||
WORD $0x8b0b080a // Add the value in x0 to the value in x11 shifted left by the value in x10 and store the result in x10 | ||
WORD $0xaa0103ec // Move the value in register x1 to register x12 | ||
|
||
LBB4_3: | ||
WORD $0x3cc10401 // Load a quadword from the memory address stored in register x0 into vector register q1 | ||
WORD $0x3cc10582 // Load a quadword from the memory address stored in register x12 into vector register q2 | ||
WORD $0x71000529 // Subtract the value in w10 from w9 and set flags | ||
WORD $0x6e22dc21 // Multiply the vectors in v1 and v2 element-wise and store the result in v1 | ||
WORD $0x4e21d400 // Add the vectors in v0 and v1 and store the result in v0 | ||
WORD $0x54ffff61 // Branch to label .LBB4_3 if the previous comparison result is not equal to | ||
WORD $0x8b0b0821 // Add the value in x11 to the value in x1 and store the result in x1 | ||
WORD $0xaa0a03e0 // Move the value in register x10 to register x0 | ||
WORD $0x14000001 // Unconditional branch to label .LBB4_6 | ||
|
||
LBB4_5: | ||
LBB4_6: | ||
WORD $0x1e2703e1 // Move zero to register s1 | ||
WORD $0x5e0c0402 // Move the value in register v0.s[1] to register s2 | ||
WORD $0x5e140403 // Move the value in register v0.s[2] to register s3 | ||
WORD $0x5e1c0404 // Move the value in register v0.s[3] to register s4 | ||
WORD $0x1e212800 // Add the value in s0 to the value in s1 and store the result in s0 | ||
WORD $0x1e202840 // Add the value in s2 to the value in s0 and store the result in s0 | ||
WORD $0x1e202860 // Add the value in s3 to the value in s0 and store the result in s0 | ||
WORD $0x1e202880 // Add the value in s4 to the value in s0 and store the result in s0 | ||
WORD $0x7100011f // Compare the value in w8 with 0 and set flags | ||
WORD $0xbd000060 // Store the value in register s0 to the memory address stored in register x3 | ||
WORD $0x5400012d // Branch to label .LBB4_9 if the previous comparison result is less than or equal to | ||
WORD $0x92407d08 // Bitwise AND: Perform a bitwise AND operation between the values in x8 and x8, store the result in x8 | ||
|
||
LBB4_8: | ||
WORD $0xbc404401 // Load a single precision floating-point value from the memory address stored in register x0 into register s1 | ||
WORD $0xbc404422 // Load a single precision floating-point value from the memory address stored in register x1 into register s2 | ||
WORD $0xf1000508 // Subtract the value in x8 from 0 and set flags | ||
WORD $0x1e220821 // Multiply the values in s1 and s2 and store the result in s1 | ||
WORD $0x1e212800 // Add the value in s0 to the value in s1 and store the result in s0 | ||
WORD $0xbd000060 // Store the value in register s0 to the memory address stored in register x3 | ||
WORD $0x54ffff41 // Branch to label .LBB4_8 if the previous comparison result is not equal to | ||
|
||
LBB4_9: | ||
WORD $0xa8c17bfd // Load the frame pointer and link register from the stack | ||
WORD $0xd65f03c0 // Return |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
//go:build noasm || (!arm64 && !amd64) | ||
|
||
package math32 | ||
|
||
func dot(a, b []float32) float32 { | ||
return dotGeneric(a, b) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,52 @@ | ||
package math32 | ||
|
||
import ( | ||
"math/rand" | ||
"testing" | ||
|
||
"github.com/stretchr/testify/assert" | ||
) | ||
|
||
func TestDot(t *testing.T) { | ||
tests := []struct { | ||
name string | ||
a, b []float32 | ||
expected float32 | ||
}{ | ||
{"Positive values", []float32{1, 2, 3}, []float32{4, 5, 6}, 32.0}, | ||
{"Negative values", []float32{-1, -2, -3}, []float32{-4, -5, -6}, 32.0}, | ||
{"Mixed values", []float32{1, -2, 3}, []float32{-4, 5, -6}, -32.0}, | ||
{"Zero values", []float32{0, 0, 0}, []float32{0, 0, 0}, 0.0}, | ||
{"Different lengths", []float32{1, 2}, []float32{3, 4, 5}, 0.0}, // Expecting panic | ||
} | ||
|
||
for _, tc := range tests { | ||
t.Run(tc.name, func(t *testing.T) { | ||
if tc.name == "Different lengths" { | ||
assert.Panics(t, func() { Dot(tc.a, tc.b) }) | ||
} else { | ||
result := Dot(tc.a, tc.b) | ||
assert.Equal(t, tc.expected, result) | ||
} | ||
}) | ||
} | ||
} | ||
|
||
func BenchmarkDot(b *testing.B) { | ||
// Generate random float32 slices for benchmarking. | ||
const size = 1000000 // Size of slices | ||
va := make([]float32, size) | ||
vb := make([]float32, size) | ||
|
||
for i := range va { | ||
va[i] = rand.Float32() // nolint gosec | ||
vb[i] = rand.Float32() // nolint gosec | ||
} | ||
|
||
// Run the Dot function b.N times and measure the time taken. | ||
b.ResetTimer() | ||
|
||
for i := 0; i < b.N; i++ { | ||
_ = Dot(va, vb) | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,26 +1,29 @@ | ||
package util | ||
|
||
import "math" | ||
type Signed interface { | ||
~int | ~int8 | ~int16 | ~int32 | ~int64 | ||
} | ||
|
||
func CosineSimilarity(matrix1, matrix2 [][]float64) float64 { | ||
dotProduct := 0.0 | ||
magnitude1 := 0.0 | ||
magnitude2 := 0.0 | ||
type Unsigned interface { | ||
~uint | ~uint8 | ~uint16 | ~uint32 | ~uint64 | ~uintptr | ||
} | ||
|
||
for i := 0; i < len(matrix1); i++ { | ||
for j := 0; j < len(matrix1[0]); j++ { | ||
dotProduct += matrix1[i][j] * matrix2[i][j] | ||
magnitude1 += math.Pow(matrix1[i][j], 2) | ||
magnitude2 += math.Pow(matrix2[i][j], 2) | ||
} | ||
} | ||
type Integer interface { | ||
Signed | Unsigned | ||
} | ||
|
||
magnitude1 = math.Sqrt(magnitude1) | ||
magnitude2 = math.Sqrt(magnitude2) | ||
type Float interface { | ||
~float32 | ~float64 | ||
} | ||
|
||
type Number interface { | ||
Integer | Float | ||
} | ||
|
||
if magnitude1 == 0 || magnitude2 == 0 { | ||
return 0.0 // Handle zero magnitude case | ||
func Min[T Number](a, b T) T { | ||
if a < b { | ||
return a | ||
} | ||
|
||
return dotProduct / (magnitude1 * magnitude2) | ||
return b | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,59 +1 @@ | ||
package util | ||
|
||
import ( | ||
"testing" | ||
|
||
"github.com/stretchr/testify/assert" | ||
) | ||
|
||
func TestCosineSimilarity(t *testing.T) { | ||
t.Run("Default", func(t *testing.T) { | ||
matrix1 := [][]float64{{1.0, 2.0}, {3.0, 4.0}} | ||
matrix2 := [][]float64{{1.0, 0.0}, {0.0, 1.0}} | ||
|
||
expected := 0.6454972243679028 // Expected cosine similarity value | ||
|
||
result := CosineSimilarity(matrix1, matrix2) | ||
assert.InDelta(t, expected, result, 1e-9) | ||
}) | ||
|
||
t.Run("Zero matrices", func(t *testing.T) { | ||
matrix1 := [][]float64{} | ||
matrix2 := [][]float64{} | ||
|
||
expected := 0.0 // Both matrices are empty, so cosine similarity should be 0 | ||
|
||
result := CosineSimilarity(matrix1, matrix2) | ||
assert.InDelta(t, expected, result, 1e-9) | ||
}) | ||
|
||
t.Run("Matrices with all zeros", func(t *testing.T) { | ||
matrix1 := [][]float64{{0.0, 0.0}, {0.0, 0.0}} | ||
matrix2 := [][]float64{{0.0, 0.0}, {0.0, 0.0}} | ||
|
||
expected := 0.0 // Both matrices have all zeros, so cosine similarity should be 0 | ||
|
||
result := CosineSimilarity(matrix1, matrix2) | ||
assert.InDelta(t, expected, result, 1e-9) | ||
}) | ||
|
||
t.Run("Matrices with orthogonal vectors", func(t *testing.T) { | ||
matrix1 := [][]float64{{1.0, 0.0}, {0.0, 1.0}} | ||
matrix2 := [][]float64{{0.0, 1.0}, {1.0, 0.0}} | ||
|
||
expected := 0.0 // Matrices have orthogonal vectors, so cosine similarity should be 0 | ||
|
||
result := CosineSimilarity(matrix1, matrix2) | ||
assert.InDelta(t, expected, result, 1e-9) | ||
}) | ||
|
||
t.Run("Matrices with identical vectors", func(t *testing.T) { | ||
matrix1 := [][]float64{{1.0, 2.0}, {3.0, 4.0}} | ||
matrix2 := [][]float64{{1.0, 2.0}, {3.0, 4.0}} | ||
|
||
expected := 1.0 // Matrices have identical vectors, so cosine similarity should be 1 | ||
|
||
result := CosineSimilarity(matrix1, matrix2) | ||
assert.InDelta(t, expected, result, 1e-9) | ||
}) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
package metric | ||
|
||
import ( | ||
"github.com/hupe1980/golc/internal/math32" | ||
) | ||
|
||
// Magnitude calculates the magnitude (length) of a float32 slice. | ||
func Magnitude(a []float32) float32 { | ||
return math32.Sqrt(math32.Dot(a, a)) | ||
} | ||
|
||
// CosineSimilarity calculates the cosine similarity between two float32 slices. | ||
func CosineSimilarity(a, b []float32) float32 { | ||
dotProduct := math32.Dot(a, b) | ||
magnitudeA := Magnitude(a) | ||
magnitudeB := Magnitude(b) | ||
|
||
// Avoid division by zero | ||
if magnitudeA == 0 || magnitudeB == 0 { | ||
return 0 | ||
} | ||
|
||
return dotProduct / (magnitudeA * magnitudeB) | ||
} | ||
|
||
// CosineDistance calculates the cosine distance between two float32 slices. | ||
func CosineDistance(a, b []float32) float32 { | ||
return 1 - CosineSimilarity(a, b) | ||
} |
Oops, something went wrong.