Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: reflect generic type #98

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 47 additions & 4 deletions mockgen/model/model.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (
"fmt"
"io"
"reflect"
"regexp"
"strings"
)

Expand Down Expand Up @@ -415,10 +416,7 @@ func typeFromType(t reflect.Type) (Type, error) {
}

if imp := t.PkgPath(); imp != "" {
return &NamedType{
Package: impPath(imp),
Type: t.Name(),
}, nil
return typeFromNamedType(imp, t.Name()), nil
}

// only unnamed or predeclared types after here
Expand Down Expand Up @@ -502,6 +500,51 @@ func typeFromType(t reflect.Type) (Type, error) {
return nil, fmt.Errorf("can't yet turn %v (%v) into a model.Type", t, t.Kind())
}

var genericRegex = regexp.MustCompile(`([^[\]]+)\[([^]]+(?:\[[^\]]*\])?[^[]*)\]`)

func typeFromNamedType(typePackage string, typeName string) Type {
namedType := &NamedType{
Package: impPath(typePackage),
Type: typeName,
}

match := genericRegex.FindStringSubmatch(typeName)
if len(match) < 3 {
// not a generic type
return namedType
}

namedType.TypeParams = &TypeParametersType{}

// likely a generic type
// e.g. Foo[Bar] / Foo[Baz.Bar] / Foo[Bar, Baz.Bar]
namedType.Type = match[1] // e.g. Foo
for _, typeParam := range strings.Split(match[2], ",") {
typeParam = strings.TrimSpace(typeParam)

var typeParamType Type

packageDotIdx := strings.LastIndex(typeParam, ".")
if packageDotIdx == -1 {
typeParamType = PredeclaredType(typeParam) // e.g. Bar
} else {
typeName := typeParam[packageDotIdx+1:] // e.g. Bar
packageName := typeParam[:packageDotIdx] // e.g. Baz
typeParamType = &NamedType{
Package: impPath(packageName),
Type: typeName,
}
}

namedType.TypeParams.TypeParameters = append(
namedType.TypeParams.TypeParameters,
typeParamType,
)
}

return namedType
}

// impPath sanitizes the package path returned by `PkgPath` method of a reflect Type so that
// it is importable. PkgPath might return a path that includes "vendor". These paths do not
// compile, so we need to remove everything up to and including "/vendor/".
Expand Down
202 changes: 202 additions & 0 deletions mockgen/model/model_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,3 +34,205 @@ func TestImpPath(t *testing.T) {
})
}
}

func Test_typeFromNamedType(t *testing.T) {
testCases := []struct {
inputTypePackage string
inputTypeName string
expectedType *NamedType
}{
// not a generic - foo.Bar
{
inputTypePackage: "foo",
inputTypeName: "Bar",
expectedType: &NamedType{
Package: "foo",
Type: "Bar",
},
},
// generic - foo.T[int]
{
inputTypePackage: "foo",
inputTypeName: "T[int]",
expectedType: &NamedType{
Package: "foo",
Type: "T",
TypeParams: &TypeParametersType{
TypeParameters: []Type{
PredeclaredType("int"),
},
},
},
},
// generic - foo.T[int[]]
{
inputTypePackage: "foo",
inputTypeName: "T[int[]]",
expectedType: &NamedType{
Package: "foo",
Type: "T",
TypeParams: &TypeParametersType{
TypeParameters: []Type{
PredeclaredType("int[]"),
},
},
},
},
// FIXME: broken case
// generic - foo.T[int[], bool, string[]]
// {
// inputTypePackage: "foo",
// inputTypeName: "T[int[], bool, string[]]",
// expectedType: &NamedType{
// Package: "foo",
// Type: "T",
// TypeParams: &TypeParametersType{
// TypeParameters: []Type{
// PredeclaredType("int[]"),
// PredeclaredType("bool"),
// PredeclaredType("string[]"),
// },
// },
// },
// },
// generic - foo.T[int, string, int]
{
inputTypePackage: "foo",
inputTypeName: "T[ int,string, int]",
expectedType: &NamedType{
Package: "foo",
Type: "T",
TypeParams: &TypeParametersType{
TypeParameters: []Type{
PredeclaredType("int"),
PredeclaredType("string"),
PredeclaredType("int"),
},
},
},
},
// generic - foo.T[context.Context]
{
inputTypePackage: "foo",
inputTypeName: "T[context.Context]",
expectedType: &NamedType{
Package: "foo",
Type: "T",
TypeParams: &TypeParametersType{
TypeParameters: []Type{
&NamedType{
Package: "context",
Type: "Context",
},
},
},
},
},
// generic - foo.T[context.Context, github.com/foo/bar.X]
{
inputTypePackage: "foo",
inputTypeName: "T[context.Context , github.com/foo/bar.X ]",
expectedType: &NamedType{
Package: "foo",
Type: "T",
TypeParams: &TypeParametersType{
TypeParameters: []Type{
&NamedType{
Package: "context",
Type: "Context",
},
&NamedType{
Package: "github.com/foo/bar",
Type: "X",
},
},
},
},
},
// generic - foo.T[context.Context, github.com/foo/bar.🤣, int]
{
inputTypePackage: "foo",
inputTypeName: "T[context.Context , gtihub.com/foo/bar.🤣, int]",
expectedType: &NamedType{
Package: "foo",
Type: "T",
TypeParams: &TypeParametersType{
TypeParameters: []Type{
&NamedType{
Package: "context",
Type: "Context",
},
&NamedType{
Package: "github.com/foo/bar",
Type: "🤣",
},
PredeclaredType("int"),
},
},
},
},
// FIXME: broken case
// generic - foo.T[bool[], context.Context[], github.com/foo/bar.X[], int[]]
// {
// inputTypePackage: "foo",
// inputTypeName: "foo.T[bool[], context.Context[], github.com/foo/bar.X[], int[]]",
// expectedType: &NamedType{
// Package: "foo",
// Type: "T",
// TypeParams: &TypeParametersType{
// TypeParameters: []Type{
// PredeclaredType("bool[]"),
// &ArrayType{
// Len: -1,
// Type: &NamedType{
// Package: "context",
// Type: "Context",
// },
// },
// &ArrayType{
// Len: -1,
// Type: &NamedType{
// Package: "github.com/foo/bar",
// Type: "X",
// },
// },
// PredeclaredType("int[]"),
// },
// },
// },
// },
}

for idx := range testCases {
tc := testCases[idx]
t.Run(fmt.Sprintf("%s.%s", tc.inputTypePackage, tc.inputTypeName), func(t *testing.T) {
t.Log("input:", tc.inputTypePackage, tc.inputTypeName)

got := typeFromNamedType(tc.inputTypePackage, tc.inputTypeName)
gotNamedType, ok := got.(*NamedType)
if !ok {
t.Errorf("got %T; want *NamedType", got)
}
expected := tc.expectedType
if gotNamedType.Package != expected.Package {
t.Errorf("got %s; want %s", gotNamedType.Package, tc.expectedType.Package)
}
if expected.TypeParams == nil {
if gotNamedType.TypeParams != nil {
t.Errorf("got %s; want nil", gotNamedType.TypeParams)
}
} else {
if gotNamedType.TypeParams == nil {
t.Errorf("got nil; want %s", expected.TypeParams)
}

pm := map[string]string{}
expectedString := expected.String(pm, "")
gotString := gotNamedType.String(pm, "")
if gotString != expectedString {
t.Errorf("got %q; want %q", gotString, expectedString)
}
}
})
}
}