Skip to content

Commit

Permalink
Fix #140 (#141)
Browse files Browse the repository at this point in the history
* Fix #140

+ Fix SortIndex()
+ Add SortIndexStable()

* `any` is not supported in Go1.15

---------

Co-authored-by: Chewxy <[email protected]>
  • Loading branch information
ksw2000 and chewxy authored Apr 9, 2024
1 parent 01c8417 commit c35555f
Show file tree
Hide file tree
Showing 2 changed files with 96 additions and 27 deletions.
61 changes: 34 additions & 27 deletions api_utils.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package tensor

import (
"log"
"math"
"math/rand"
"reflect"
Expand All @@ -10,41 +9,49 @@ import (
"github.com/chewxy/math32"
)

// SortIndex is similar to numpy's argsort
// TODO: tidy this up
// SortIndex: Similar to numpy's argsort.
// Returns indices for sorting a slice in increasing order.
// Input slice remains unchanged.
// SortIndex may not be stable; for stability, use SortIndexStable.
func SortIndex(in interface{}) (out []int) {
return sortIndex(in, sort.Slice)
}

// SortIndexStable: Similar to SortIndex, but stable.
// Returns indices for sorting a slice in increasing order.
// Input slice remains unchanged.
func SortIndexStable(in interface{}) (out []int) {
return sortIndex(in, sort.SliceStable)
}

func sortIndex(in interface{}, sortFunc func(x interface{}, less func(i int, j int) bool)) (out []int) {
switch list := in.(type) {
case []int:
orig := make([]int, len(list))
out = make([]int, len(list))
copy(orig, list)
sort.Ints(list)
for i, s := range list {
for j, o := range orig {
if o == s {
out[i] = j
break
}
}
for i := 0; i < len(list); i++ {
out[i] = i
}
sortFunc(out, func(i, j int) bool {
return list[out[i]] < list[out[j]]
})
case []float64:
orig := make([]float64, len(list))
out = make([]int, len(list))
copy(orig, list)
sort.Float64s(list)

for i, s := range list {
for j, o := range orig {
if o == s {
out[i] = j
break
}
}
for i := 0; i < len(list); i++ {
out[i] = i
}
sortFunc(out, func(i, j int) bool {
return list[out[i]] < list[out[j]]
})
case sort.Interface:
sort.Sort(list)

log.Printf("TODO: SortIndex for sort.Interface not yet done.")
out = make([]int, list.Len())
for i := 0; i < list.Len(); i++ {
out[i] = i
}
sortFunc(out, func(i, j int) bool {
return list.Less(out[i], out[j])
})
default:
panic("The slice type is not currently supported.")
}

return
Expand Down
62 changes: 62 additions & 0 deletions api_utils_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
package tensor

import (
"testing"
)

type testInt []int

func (m testInt) Less(i, j int) bool { return m[i] < m[j] }
func (m testInt) Len() int { return len(m) }
func (m testInt) Swap(i, j int) { m[i], m[j] = m[j], m[i] }

func TestSortIndexInts(t *testing.T) {
in := []int{9, 8, 7, 6, 5, 4, 10, -1, -2, -4, 11, 13, 15, 100, 99}
inCopy := make([]int, len(in))
copy(inCopy, in)
out := SortIndex(in)
for i := 1; i < len(out); i++ {
if inCopy[out[i]] < inCopy[out[i-1]] {
t.Fatalf("Unexpected output")
}
}
for i := range in {
if in[i] != inCopy[i] {
t.Fatalf("The input slice should not be changed")
}
}
}

func TestSortIndexFloats(t *testing.T) {
in := []float64{.9, .8, .7, .6, .5, .4, .10, -.1, -.2, -.4, .11, .13, .15, .100, .99}
inCopy := make([]float64, len(in))
copy(inCopy, in)
out := SortIndex(in)
for i := 1; i < len(out); i++ {
if inCopy[out[i]] < inCopy[out[i-1]] {
t.Fatalf("Unexpected output")
}
}
for i := range in {
if in[i] != inCopy[i] {
t.Fatalf("The input slice should not be changed")
}
}
}

func TestSortIndexSortInterface(t *testing.T) {
in := testInt{9, 8, 7, 6, 5, 4, 10, -1, -2, -4, 11, 13, 15, 100, 99}
inCopy := make(testInt, len(in))
copy(inCopy, in)
out := SortIndex(in)
for i := 1; i < len(out); i++ {
if inCopy[out[i]] < inCopy[out[i-1]] {
t.Fatalf("Unexpected output")
}
}
for i := range in {
if in[i] != inCopy[i] {
t.Fatalf("The input slice should not be changed")
}
}
}

0 comments on commit c35555f

Please sign in to comment.