diff --git a/CHANGELOG.md b/CHANGELOG.md index efe29a9..40e7cc3 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,8 @@ - `gopter.GenParameters` now has a `CloneWithSeed(seed int64)` function to temparary copies to create rerunable sections of code. - Added `gopter.Gen.MapResult` for power-user mappings +- Added `gopter.DeriveGen` to derive a generator and it's shrinker from a + bi-directional mapping (`gopter.BiMapper`) ### Changed - Refactored `commands` package under the hood to allow the use of mutable state. diff --git a/bi_mapper.go b/bi_mapper.go new file mode 100644 index 0000000..cb627c9 --- /dev/null +++ b/bi_mapper.go @@ -0,0 +1,107 @@ +package gopter + +import ( + "fmt" + "reflect" +) + +// BiMapper is a bi-directional (or bijective) mapper of a tuple of values (up) +// to another tuple of values (down). +type BiMapper struct { + UpTypes []reflect.Type + DownTypes []reflect.Type + Downstream reflect.Value + Upstream reflect.Value +} + +// NewBiMapper creates a BiMapper of two functions `downstream` and its +// inverse `upstream`. +// That is: The return values of `downstream` must match the parameters of +// `upstream` and vice versa. +func NewBiMapper(downstream interface{}, upstream interface{}) *BiMapper { + downstreamVal := reflect.ValueOf(downstream) + if downstreamVal.Kind() != reflect.Func { + panic("downstream has to be a function") + } + upstreamVal := reflect.ValueOf(upstream) + if upstreamVal.Kind() != reflect.Func { + panic("upstream has to be a function") + } + + downstreamType := downstreamVal.Type() + upTypes := make([]reflect.Type, downstreamType.NumIn()) + for i := 0; i < len(upTypes); i++ { + upTypes[i] = downstreamType.In(i) + } + downTypes := make([]reflect.Type, downstreamType.NumOut()) + for i := 0; i < len(downTypes); i++ { + downTypes[i] = downstreamType.Out(i) + } + + upstreamType := upstreamVal.Type() + if len(upTypes) != upstreamType.NumOut() { + panic(fmt.Sprintf("upstream is expected to have %d return values", len(upTypes))) + } + for i, upType := range upTypes { + if upstreamType.Out(i) != upType { + panic(fmt.Sprintf("upstream has wrong return type %d: %v != %v", i, upstreamType.Out(i), upType)) + } + } + if len(downTypes) != upstreamType.NumIn() { + panic(fmt.Sprintf("upstream is expected to have %d parameters", len(downTypes))) + } + for i, downType := range downTypes { + if upstreamType.In(i) != downType { + panic(fmt.Sprintf("upstream has wrong parameter type %d: %v != %v", i, upstreamType.In(i), downType)) + } + } + + return &BiMapper{ + UpTypes: upTypes, + DownTypes: downTypes, + Downstream: downstreamVal, + Upstream: upstreamVal, + } +} + +func (b *BiMapper) ConvertUp(down []interface{}) []interface{} { + if len(down) != len(b.DownTypes) { + panic(fmt.Sprintf("Expected %d values != %d", len(b.DownTypes), len(down))) + } + downVals := make([]reflect.Value, len(b.DownTypes)) + for i, val := range down { + if val == nil { + downVals[i] = reflect.Zero(b.DownTypes[i]) + } else { + downVals[i] = reflect.ValueOf(val) + } + } + upVals := b.Upstream.Call(downVals) + up := make([]interface{}, len(upVals)) + for i, upVal := range upVals { + up[i] = upVal.Interface() + } + + return up +} + +func (b *BiMapper) ConvertDown(up []interface{}) []interface{} { + if len(up) != len(b.UpTypes) { + panic(fmt.Sprintf("Expected %d values != %d", len(b.UpTypes), len(up))) + } + upVals := make([]reflect.Value, len(b.UpTypes)) + for i, val := range up { + if val == nil { + upVals[i] = reflect.Zero(b.UpTypes[i]) + } else { + upVals[i] = reflect.ValueOf(val) + } + } + downVals := b.Downstream.Call(upVals) + down := make([]interface{}, len(downVals)) + for i, downVal := range downVals { + down[i] = downVal.Interface() + } + + return down +} diff --git a/bi_mapper_test.go b/bi_mapper_test.go new file mode 100644 index 0000000..5a0166c --- /dev/null +++ b/bi_mapper_test.go @@ -0,0 +1,27 @@ +package gopter_test + +import ( + "testing" + + "github.com/leanovate/gopter" +) + +func TestBiMapperParamNotMatch(t *testing.T) { + defer expectPanic(t, "upstream has wrong parameter type 0: string != int") + gopter.NewBiMapper(func(int) int { return 0 }, func(string) int { return 0 }) +} + +func TestBiMapperReturnNotMatch(t *testing.T) { + defer expectPanic(t, "upstream has wrong return type 0: string != int") + gopter.NewBiMapper(func(int) int { return 0 }, func(int) string { return "" }) +} + +func TestBiMapperInvalidDownstream(t *testing.T) { + defer expectPanic(t, "downstream has to be a function") + gopter.NewBiMapper(1, 2) +} + +func TestBiMapperInvalidUpstream(t *testing.T) { + defer expectPanic(t, "upstream has to be a function") + gopter.NewBiMapper(func(int) int { return 0 }, 2) +} diff --git a/derived_gen.go b/derived_gen.go new file mode 100644 index 0000000..2c44798 --- /dev/null +++ b/derived_gen.go @@ -0,0 +1,122 @@ +package gopter + +import ( + "fmt" + "reflect" +) + +type derivedGen struct { + biMapper *BiMapper + upGens []Gen + upSieves []func(interface{}) bool + upShrinker Shrinker + resultType reflect.Type +} + +func (d *derivedGen) Generate(genParams *GenParameters) *GenResult { + labels := []string{} + up := make([]interface{}, len(d.upGens)) + shrinkers := make([]Shrinker, len(d.upGens)) + sieves := make([]func(v interface{}) bool, len(d.upGens)) + + var ok bool + for i, gen := range d.upGens { + result := gen(genParams) + labels = append(labels, result.Labels...) + shrinkers[i] = result.Shrinker + sieves[i] = result.Sieve + up[i], ok = result.Retrieve() + if !ok { + return &GenResult{ + Shrinker: d.Shrinker, + result: nil, + Labels: result.Labels, + ResultType: d.resultType, + Sieve: d.Sieve, + } + } + } + down := d.biMapper.ConvertDown(up) + if len(down) == 1 { + return &GenResult{ + Shrinker: d.Shrinker, + result: down[0], + Labels: labels, + ResultType: reflect.TypeOf(down[0]), + Sieve: d.Sieve, + } + } + return &GenResult{ + Shrinker: d.Shrinker, + result: down, + Labels: labels, + ResultType: reflect.TypeOf(down), + Sieve: d.Sieve, + } +} + +func (d *derivedGen) Sieve(down interface{}) bool { + if down == nil { + return false + } + downs, ok := down.([]interface{}) + if !ok { + downs = []interface{}{down} + } + ups := d.biMapper.ConvertUp(downs) + for i, up := range ups { + if d.upSieves[i] != nil && !d.upSieves[i](up) { + return false + } + } + return true +} + +func (d *derivedGen) Shrinker(down interface{}) Shrink { + downs, ok := down.([]interface{}) + if !ok { + downs = []interface{}{down} + } + ups := d.biMapper.ConvertUp(downs) + upShrink := d.upShrinker(ups) + + return upShrink.Map(func(shrinkedUps []interface{}) interface{} { + downs := d.biMapper.ConvertDown(shrinkedUps) + if len(downs) == 1 { + return downs[0] + } + return downs + }) +} + +// DeriveGen derives a generator with shrinkers from a sequence of other +// generators mapped by a bijective function (BiMapper) +func DeriveGen(downstream interface{}, upstream interface{}, gens ...Gen) Gen { + biMapper := NewBiMapper(downstream, upstream) + + if len(gens) != len(biMapper.UpTypes) { + panic(fmt.Sprintf("Expected %d generators != %d", len(biMapper.UpTypes), len(gens))) + } + + resultType := reflect.TypeOf([]interface{}{}) + if len(biMapper.DownTypes) == 1 { + resultType = biMapper.DownTypes[0] + } + + sieves := make([]func(interface{}) bool, len(gens)) + shrinkers := make([]Shrinker, len(gens)) + for i, gen := range gens { + result := gen(DefaultGenParameters()) + sieves[i] = result.Sieve + shrinkers[i] = result.Shrinker + } + + derived := &derivedGen{ + biMapper: biMapper, + upGens: gens, + upSieves: sieves, + upShrinker: CombineShrinker(shrinkers...), + resultType: resultType, + } + return derived.Generate +} diff --git a/derived_gen_test.go b/derived_gen_test.go new file mode 100644 index 0000000..af7f89b --- /dev/null +++ b/derived_gen_test.go @@ -0,0 +1,186 @@ +package gopter_test + +import ( + "reflect" + "testing" + + "github.com/leanovate/gopter" + "github.com/leanovate/gopter/gen" +) + +type downStruct struct { + a int + b string + c bool +} + +func TestDeriveGenSingleDown(t *testing.T) { + gen := gopter.DeriveGen( + func(a int, b string, c bool) *downStruct { + return &downStruct{a: a, b: b, c: c} + }, + func(d *downStruct) (int, string, bool) { + return d.a, d.b, d.c + }, + gen.Int(), + gen.AnyString(), + gen.Bool(), + ) + + sample, ok := gen.Sample() + if !ok { + t.Error("Sample not ok") + } + _, ok = sample.(*downStruct) + if !ok { + t.Errorf("%#v is not a downStruct", sample) + } + + shrinker := gen(gopter.DefaultGenParameters()).Shrinker + shrink := shrinker(&downStruct{a: 10, b: "abcd", c: false}) + + shrinkedStructs := make([]*downStruct, 0) + value, next := shrink() + for next { + shrinkedStruct, ok := value.(*downStruct) + if !ok { + t.Errorf("Invalid shrinked value: %#v", value) + } + shrinkedStructs = append(shrinkedStructs, shrinkedStruct) + value, next = shrink() + } + + expected := []*downStruct{ + &downStruct{a: 0, b: "abcd", c: false}, + &downStruct{a: 5, b: "abcd", c: false}, + &downStruct{a: -5, b: "abcd", c: false}, + &downStruct{a: 8, b: "abcd", c: false}, + &downStruct{a: -8, b: "abcd", c: false}, + &downStruct{a: 9, b: "abcd", c: false}, + &downStruct{a: -9, b: "abcd", c: false}, + &downStruct{a: 10, b: "cd", c: false}, + &downStruct{a: 10, b: "ab", c: false}, + &downStruct{a: 10, b: "bcd", c: false}, + &downStruct{a: 10, b: "acd", c: false}, + &downStruct{a: 10, b: "abd", c: false}, + &downStruct{a: 10, b: "abc", c: false}, + } + if !reflect.DeepEqual(shrinkedStructs, expected) { + t.Errorf("%v does not equal %v", shrinkedStructs, expected) + } +} + +func TestDeriveGenSingleDownWithSieves(t *testing.T) { + gen := gopter.DeriveGen( + func(a int, b string, c bool) *downStruct { + return &downStruct{a: a, b: b, c: c} + }, + func(d *downStruct) (int, string, bool) { + return d.a, d.b, d.c + }, + gen.Int().SuchThat(func(i int) bool { + return i%2 == 0 + }), + gen.AnyString(), + gen.Bool(), + ) + + parameters := gopter.DefaultGenParameters() + parameters.Rng.Seed(1234) + + hasNoValue := false + for i := 0; i < 100; i++ { + result := gen(parameters) + _, ok := result.Retrieve() + if !ok { + hasNoValue = true + break + } + } + if !hasNoValue { + t.Error("Sieve is not applied") + } + + sieve := gen(parameters).Sieve + + if !sieve(&downStruct{a: 2, b: "something", c: false}) { + t.Error("Sieve did not pass even") + } + + if sieve(&downStruct{a: 3, b: "something", c: false}) { + t.Error("Sieve did pass odd") + } +} + +func TestDeriveGenMultiDown(t *testing.T) { + gen := gopter.DeriveGen( + func(a int, b string, c bool, d int32) (*downStruct, int64) { + return &downStruct{a: a, b: b, c: c}, int64(a) + int64(d) + }, + func(d *downStruct, diff int64) (int, string, bool, int32) { + return d.a, d.b, d.c, int32(diff - int64(d.a)) + }, + gen.Int(), + gen.AnyString(), + gen.Bool(), + gen.Int32(), + ) + + sample, ok := gen.Sample() + if !ok { + t.Error("Sample not ok") + } + values, ok := sample.([]interface{}) + if !ok || len(values) != 2 { + t.Errorf("%#v is not a slice of interface", sample) + } + _, ok = values[0].(*downStruct) + if !ok { + t.Errorf("%#v is not a downStruct", values[0]) + } + _, ok = values[1].(int64) + if !ok { + t.Errorf("%#v is not a int64", values[1]) + } + + shrinker := gen(gopter.DefaultGenParameters()).Shrinker + shrink := shrinker([]interface{}{&downStruct{a: 10, b: "abcd", c: false}, int64(20)}) + + value, next := shrink() + shrinkedValues := make([][]interface{}, 0) + for next { + shrinked, ok := value.([]interface{}) + if !ok || len(values) != 2 { + t.Errorf("%#v is not a slice of interface", sample) + } + shrinkedValues = append(shrinkedValues, shrinked) + value, next = shrink() + } + + expected := [][]interface{}{ + []interface{}{&downStruct{a: 0, b: "abcd", c: false}, int64(10)}, + []interface{}{&downStruct{a: 5, b: "abcd", c: false}, int64(15)}, + []interface{}{&downStruct{a: -5, b: "abcd", c: false}, int64(5)}, + []interface{}{&downStruct{a: 8, b: "abcd", c: false}, int64(18)}, + []interface{}{&downStruct{a: -8, b: "abcd", c: false}, int64(2)}, + []interface{}{&downStruct{a: 9, b: "abcd", c: false}, int64(19)}, + []interface{}{&downStruct{a: -9, b: "abcd", c: false}, int64(1)}, + []interface{}{&downStruct{a: 10, b: "cd", c: false}, int64(20)}, + []interface{}{&downStruct{a: 10, b: "ab", c: false}, int64(20)}, + []interface{}{&downStruct{a: 10, b: "bcd", c: false}, int64(20)}, + []interface{}{&downStruct{a: 10, b: "acd", c: false}, int64(20)}, + []interface{}{&downStruct{a: 10, b: "abd", c: false}, int64(20)}, + []interface{}{&downStruct{a: 10, b: "abc", c: false}, int64(20)}, + []interface{}{&downStruct{a: 10, b: "abcd", c: false}, int64(10)}, + []interface{}{&downStruct{a: 10, b: "abcd", c: false}, int64(15)}, + []interface{}{&downStruct{a: 10, b: "abcd", c: false}, int64(5)}, + []interface{}{&downStruct{a: 10, b: "abcd", c: false}, int64(18)}, + []interface{}{&downStruct{a: 10, b: "abcd", c: false}, int64(2)}, + []interface{}{&downStruct{a: 10, b: "abcd", c: false}, int64(19)}, + []interface{}{&downStruct{a: 10, b: "abcd", c: false}, int64(1)}, + } + + if !reflect.DeepEqual(shrinkedValues, expected) { + t.Errorf("%v does not equal %v", shrinkedValues, expected) + } +}