diff --git a/api_utils.go b/api_utils.go index 3cf55f0..2841d45 100644 --- a/api_utils.go +++ b/api_utils.go @@ -1,7 +1,6 @@ package tensor import ( - "log" "math" "math/rand" "reflect" @@ -10,41 +9,49 @@ import ( "github.com/chewxy/math32" ) -// SortIndex is similar to numpy's argsort -// TODO: tidy this up +// SortIndex: Similar to numpy's argsort. +// Returns indices for sorting a slice in increasing order. +// Input slice remains unchanged. +// SortIndex may not be stable; for stability, use SortIndexStable. func SortIndex(in interface{}) (out []int) { + return sortIndex(in, sort.Slice) +} + +// SortIndexStable: Similar to SortIndex, but stable. +// Returns indices for sorting a slice in increasing order. +// Input slice remains unchanged. +func SortIndexStable(in interface{}) (out []int) { + return sortIndex(in, sort.SliceStable) +} + +func sortIndex(in interface{}, sortFunc func(x interface{}, less func(i int, j int) bool)) (out []int) { switch list := in.(type) { case []int: - orig := make([]int, len(list)) out = make([]int, len(list)) - copy(orig, list) - sort.Ints(list) - for i, s := range list { - for j, o := range orig { - if o == s { - out[i] = j - break - } - } + for i := 0; i < len(list); i++ { + out[i] = i } + sortFunc(out, func(i, j int) bool { + return list[out[i]] < list[out[j]] + }) case []float64: - orig := make([]float64, len(list)) out = make([]int, len(list)) - copy(orig, list) - sort.Float64s(list) - - for i, s := range list { - for j, o := range orig { - if o == s { - out[i] = j - break - } - } + for i := 0; i < len(list); i++ { + out[i] = i } + sortFunc(out, func(i, j int) bool { + return list[out[i]] < list[out[j]] + }) case sort.Interface: - sort.Sort(list) - - log.Printf("TODO: SortIndex for sort.Interface not yet done.") + out = make([]int, list.Len()) + for i := 0; i < list.Len(); i++ { + out[i] = i + } + sortFunc(out, func(i, j int) bool { + return list.Less(out[i], out[j]) + }) + default: + panic("The slice type is not currently supported.") } return diff --git a/api_utils_test.go b/api_utils_test.go new file mode 100644 index 0000000..243cc6f --- /dev/null +++ b/api_utils_test.go @@ -0,0 +1,62 @@ +package tensor + +import ( + "testing" +) + +type testInt []int + +func (m testInt) Less(i, j int) bool { return m[i] < m[j] } +func (m testInt) Len() int { return len(m) } +func (m testInt) Swap(i, j int) { m[i], m[j] = m[j], m[i] } + +func TestSortIndexInts(t *testing.T) { + in := []int{9, 8, 7, 6, 5, 4, 10, -1, -2, -4, 11, 13, 15, 100, 99} + inCopy := make([]int, len(in)) + copy(inCopy, in) + out := SortIndex(in) + for i := 1; i < len(out); i++ { + if inCopy[out[i]] < inCopy[out[i-1]] { + t.Fatalf("Unexpected output") + } + } + for i := range in { + if in[i] != inCopy[i] { + t.Fatalf("The input slice should not be changed") + } + } +} + +func TestSortIndexFloats(t *testing.T) { + in := []float64{.9, .8, .7, .6, .5, .4, .10, -.1, -.2, -.4, .11, .13, .15, .100, .99} + inCopy := make([]float64, len(in)) + copy(inCopy, in) + out := SortIndex(in) + for i := 1; i < len(out); i++ { + if inCopy[out[i]] < inCopy[out[i-1]] { + t.Fatalf("Unexpected output") + } + } + for i := range in { + if in[i] != inCopy[i] { + t.Fatalf("The input slice should not be changed") + } + } +} + +func TestSortIndexSortInterface(t *testing.T) { + in := testInt{9, 8, 7, 6, 5, 4, 10, -1, -2, -4, 11, 13, 15, 100, 99} + inCopy := make(testInt, len(in)) + copy(inCopy, in) + out := SortIndex(in) + for i := 1; i < len(out); i++ { + if inCopy[out[i]] < inCopy[out[i-1]] { + t.Fatalf("Unexpected output") + } + } + for i := range in { + if in[i] != inCopy[i] { + t.Fatalf("The input slice should not be changed") + } + } +}