Skip to content

Commit

Permalink
Support compare functions with SortSlices and SortMaps
Browse files Browse the repository at this point in the history
The SortSlices and SortMaps options predate generics and
accept an interface{}, so it is possible with reflection
to support other function signatures than "func(T, T) bool".

In particular, the Go ecosystem is increasingly moving towards
"func(T, T) int" as the signature for ordering as evidenced
by the newer slices.SortFunc function in stdlib.

Thus, modernize cmpopts by supporting "func(T, T) int".

Also, bump the minimum version to Go 1.21 to match
the minimum supported version of google.golang.org/protobuf.

Fixes #365
  • Loading branch information
dsnet committed Oct 23, 2024
1 parent c3ad843 commit d53b1ba
Show file tree
Hide file tree
Showing 5 changed files with 87 additions and 36 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ jobs:
test:
strategy:
matrix:
go-version: [1.18.x, 1.19.x, 1.20.x, 1.21.x]
go-version: [1.21.x]
os: [ubuntu-latest, macos-latest]
runs-on: ${{ matrix.os }}
steps:
Expand Down
64 changes: 44 additions & 20 deletions cmp/cmpopts/sort.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,22 +14,29 @@ import (
)

// SortSlices returns a [cmp.Transformer] option that sorts all []V.
// The less function must be of the form "func(T, T) bool" which is used to
// sort any slice with element type V that is assignable to T.
// The lessOrCompareFunc function must be either
// a less function of the form "func(T, T) bool" or
// a compare function of the format "func(T, T) int"
// which is used to sort any slice with element type V that is assignable to T.
//
// The less function must be:
// A less function must be:
// - Deterministic: less(x, y) == less(x, y)
// - Irreflexive: !less(x, x)
// - Transitive: if !less(x, y) and !less(y, z), then !less(x, z)
//
// The less function does not have to be "total". That is, if !less(x, y) and
// !less(y, x) for two elements x and y, their relative order is maintained.
// A compare function must be:
// - Deterministic: compare(x, y) == compare(x, y)
// - Irreflexive: compare(x, x) == 0
// - Transitive: if !less(x, y) and !less(y, z), then !less(x, z)
//
// The function does not have to be "total". That is, if x != y, but
// less or compare report inequality, their relative order is maintained.
//
// SortSlices can be used in conjunction with [EquateEmpty].
func SortSlices(lessFunc interface{}) cmp.Option {
vf := reflect.ValueOf(lessFunc)
if !function.IsType(vf.Type(), function.Less) || vf.IsNil() {
panic(fmt.Sprintf("invalid less function: %T", lessFunc))
func SortSlices(lessOrCompareFunc interface{}) cmp.Option {
vf := reflect.ValueOf(lessOrCompareFunc)
if (!function.IsType(vf.Type(), function.Less) && !function.IsType(vf.Type(), function.Compare)) || vf.IsNil() {
panic(fmt.Sprintf("invalid less or compare function: %T", lessOrCompareFunc))
}
ss := sliceSorter{vf.Type().In(0), vf}
return cmp.FilterValues(ss.filter, cmp.Transformer("cmpopts.SortSlices", ss.sort))
Expand Down Expand Up @@ -79,28 +86,40 @@ func (ss sliceSorter) checkSort(v reflect.Value) {
}
func (ss sliceSorter) less(v reflect.Value, i, j int) bool {
vx, vy := v.Index(i), v.Index(j)
return ss.fnc.Call([]reflect.Value{vx, vy})[0].Bool()
vo := ss.fnc.Call([]reflect.Value{vx, vy})[0]
if vo.Kind() == reflect.Bool {
return vo.Bool()
} else {
return vo.Int() < 0
}
}

// SortMaps returns a [cmp.Transformer] option that flattens map[K]V types to be a
// sorted []struct{K, V}. The less function must be of the form
// "func(T, T) bool" which is used to sort any map with key K that is
// assignable to T.
// SortMaps returns a [cmp.Transformer] option that flattens map[K]V types to be
// a sorted []struct{K, V}. The lessOrCompareFunc function must be either
// a less function of the form "func(T, T) bool" or
// a compare function of the format "func(T, T) int"
// which is used to sort any map with key K that is assignable to T.
//
// Flattening the map into a slice has the property that [cmp.Equal] is able to
// use [cmp.Comparer] options on K or the K.Equal method if it exists.
//
// The less function must be:
// A less function must be:
// - Deterministic: less(x, y) == less(x, y)
// - Irreflexive: !less(x, x)
// - Transitive: if !less(x, y) and !less(y, z), then !less(x, z)
// - Total: if x != y, then either less(x, y) or less(y, x)
//
// A compare function must be:
// - Deterministic: compare(x, y) == compare(x, y)
// - Irreflexive: compare(x, x) == 0
// - Transitive: if compare(x, y) < 0 and compare(y, z) < 0, then compare(x, z) < 0
// - Total: if x != y, then compare(x, y) != 0
//
// SortMaps can be used in conjunction with [EquateEmpty].
func SortMaps(lessFunc interface{}) cmp.Option {
vf := reflect.ValueOf(lessFunc)
if !function.IsType(vf.Type(), function.Less) || vf.IsNil() {
panic(fmt.Sprintf("invalid less function: %T", lessFunc))
func SortMaps(lessOrCompareFunc interface{}) cmp.Option {
vf := reflect.ValueOf(lessOrCompareFunc)
if (!function.IsType(vf.Type(), function.Less) && !function.IsType(vf.Type(), function.Compare)) || vf.IsNil() {
panic(fmt.Sprintf("invalid less or compare function: %T", lessOrCompareFunc))
}
ms := mapSorter{vf.Type().In(0), vf}
return cmp.FilterValues(ms.filter, cmp.Transformer("cmpopts.SortMaps", ms.sort))
Expand Down Expand Up @@ -143,5 +162,10 @@ func (ms mapSorter) checkSort(v reflect.Value) {
}
func (ms mapSorter) less(v reflect.Value, i, j int) bool {
vx, vy := v.Index(i).Field(0), v.Index(j).Field(0)
return ms.fnc.Call([]reflect.Value{vx, vy})[0].Bool()
vo := ms.fnc.Call([]reflect.Value{vx, vy})[0]
if vo.Kind() == reflect.Bool {
return vo.Bool()
} else {
return vo.Int() < 0
}
}
48 changes: 34 additions & 14 deletions cmp/cmpopts/util_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,23 @@ func TestOptions(t *testing.T) {
opts: []cmp.Option{SortSlices(func(x, y int) bool { return x < y })},
wantEqual: true,
reason: "equal because SortSlices sorts the slices",
}, {
label: "SortSlices",
x: []int{0, 1, 2, 3, 4, 5, 6, 7, 8, 9},
y: []int{1, 0, 5, 2, 8, 9, 4, 3, 6, 7},
opts: []cmp.Option{SortSlices(func(x, y int) int {
// TODO(Go1.22): Use cmp.Compare.
switch {
case x < y:
return -1
case y > x:
return +1
default:
return 0
}
})},
wantEqual: true,
reason: "equal because SortSlices sorts the slices",
}, {
label: "SortSlices",
x: []MyInt{0, 1, 2, 3, 4, 5, 6, 7, 8, 9},
Expand Down Expand Up @@ -201,6 +218,21 @@ func TestOptions(t *testing.T) {
opts: []cmp.Option{SortMaps(func(x, y time.Time) bool { return x.Before(y) })},
wantEqual: true,
reason: "equal because SortMaps flattens to a slice where Time.Equal can be used",
}, {
label: "SortMaps",
x: map[time.Time]string{
time.Date(2009, time.November, 10, 23, 0, 0, 0, time.UTC): "0th birthday",
time.Date(2010, time.November, 10, 23, 0, 0, 0, time.UTC): "1st birthday",
time.Date(2011, time.November, 10, 23, 0, 0, 0, time.UTC): "2nd birthday",
},
y: map[time.Time]string{
time.Date(2009, time.November, 10, 23, 0, 0, 0, time.UTC).In(time.Local): "0th birthday",
time.Date(2010, time.November, 10, 23, 0, 0, 0, time.UTC).In(time.Local): "1st birthday",
time.Date(2011, time.November, 10, 23, 0, 0, 0, time.UTC).In(time.Local): "2nd birthday",
},
opts: []cmp.Option{SortMaps(func(x, y time.Time) int { return time.Time.Compare(x, y) })},
wantEqual: true,
reason: "equal because SortMaps flattens to a slice where Time.Equal can be used",
}, {
label: "SortMaps",
x: map[MyTime]string{
Expand Down Expand Up @@ -1184,29 +1216,17 @@ func TestPanic(t *testing.T) {
args: args(time.Duration(-1)),
wantPanic: "margin must be a non-negative number",
reason: "negative duration is invalid",
}, {
label: "SortSlices",
fnc: SortSlices,
args: args(strings.Compare),
wantPanic: "invalid less function",
reason: "func(x, y string) int is wrong signature for less",
}, {
label: "SortSlices",
fnc: SortSlices,
args: args((func(_, _ int) bool)(nil)),
wantPanic: "invalid less function",
wantPanic: "invalid less or compare function",
reason: "nil value is not valid",
}, {
label: "SortMaps",
fnc: SortMaps,
args: args(strings.Compare),
wantPanic: "invalid less function",
reason: "func(x, y string) int is wrong signature for less",
}, {
label: "SortMaps",
fnc: SortMaps,
args: args((func(_, _ int) bool)(nil)),
wantPanic: "invalid less function",
wantPanic: "invalid less or compare function",
reason: "nil value is not valid",
}, {
label: "IgnoreFields",
Expand Down
7 changes: 7 additions & 0 deletions cmp/internal/function/func.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ const (

tbFunc // func(T) bool
ttbFunc // func(T, T) bool
ttiFunc // func(T, T) int
trbFunc // func(T, R) bool
tibFunc // func(T, I) bool
trFunc // func(T) R
Expand All @@ -28,11 +29,13 @@ const (
Transformer = trFunc // func(T) R
ValueFilter = ttbFunc // func(T, T) bool
Less = ttbFunc // func(T, T) bool
Compare = ttiFunc // func(T, T) int
ValuePredicate = tbFunc // func(T) bool
KeyValuePredicate = trbFunc // func(T, R) bool
)

var boolType = reflect.TypeOf(true)
var intType = reflect.TypeOf(0)

// IsType reports whether the reflect.Type is of the specified function type.
func IsType(t reflect.Type, ft funcType) bool {
Expand All @@ -49,6 +52,10 @@ func IsType(t reflect.Type, ft funcType) bool {
if ni == 2 && no == 1 && t.In(0) == t.In(1) && t.Out(0) == boolType {
return true
}
case ttiFunc: // func(T, T) int
if ni == 2 && no == 1 && t.In(0) == t.In(1) && t.Out(0) == intType {
return true
}
case trbFunc: // func(T, R) bool
if ni == 2 && no == 1 && t.Out(0) == boolType {
return true
Expand Down
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
module github.com/google/go-cmp

go 1.13
go 1.21

0 comments on commit d53b1ba

Please sign in to comment.