diff --git a/mockgen/model/model.go b/mockgen/model/model.go index e2dde53..f084e62 100644 --- a/mockgen/model/model.go +++ b/mockgen/model/model.go @@ -20,6 +20,7 @@ import ( "fmt" "io" "reflect" + "regexp" "strings" ) @@ -415,10 +416,7 @@ func typeFromType(t reflect.Type) (Type, error) { } if imp := t.PkgPath(); imp != "" { - return &NamedType{ - Package: impPath(imp), - Type: t.Name(), - }, nil + return typeFromNamedType(imp, t.Name()), nil } // only unnamed or predeclared types after here @@ -502,6 +500,51 @@ func typeFromType(t reflect.Type) (Type, error) { return nil, fmt.Errorf("can't yet turn %v (%v) into a model.Type", t, t.Kind()) } +var genericRegex = regexp.MustCompile(`([^[\]]+)\[([^]]+(?:\[[^\]]*\])?[^[]*)\]`) + +func typeFromNamedType(typePackage string, typeName string) Type { + namedType := &NamedType{ + Package: impPath(typePackage), + Type: typeName, + } + + match := genericRegex.FindStringSubmatch(typeName) + if len(match) < 3 { + // not a generic type + return namedType + } + + namedType.TypeParams = &TypeParametersType{} + + // likely a generic type + // e.g. Foo[Bar] / Foo[Baz.Bar] / Foo[Bar, Baz.Bar] + namedType.Type = match[1] // e.g. Foo + for _, typeParam := range strings.Split(match[2], ",") { + typeParam = strings.TrimSpace(typeParam) + + var typeParamType Type + + packageDotIdx := strings.LastIndex(typeParam, ".") + if packageDotIdx == -1 { + typeParamType = PredeclaredType(typeParam) // e.g. Bar + } else { + typeName := typeParam[packageDotIdx+1:] // e.g. Bar + packageName := typeParam[:packageDotIdx] // e.g. Baz + typeParamType = &NamedType{ + Package: impPath(packageName), + Type: typeName, + } + } + + namedType.TypeParams.TypeParameters = append( + namedType.TypeParams.TypeParameters, + typeParamType, + ) + } + + return namedType +} + // impPath sanitizes the package path returned by `PkgPath` method of a reflect Type so that // it is importable. PkgPath might return a path that includes "vendor". These paths do not // compile, so we need to remove everything up to and including "/vendor/". diff --git a/mockgen/model/model_test.go b/mockgen/model/model_test.go index 02ad6be..9164f78 100644 --- a/mockgen/model/model_test.go +++ b/mockgen/model/model_test.go @@ -34,3 +34,205 @@ func TestImpPath(t *testing.T) { }) } } + +func Test_typeFromNamedType(t *testing.T) { + testCases := []struct { + inputTypePackage string + inputTypeName string + expectedType *NamedType + }{ + // not a generic - foo.Bar + { + inputTypePackage: "foo", + inputTypeName: "Bar", + expectedType: &NamedType{ + Package: "foo", + Type: "Bar", + }, + }, + // generic - foo.T[int] + { + inputTypePackage: "foo", + inputTypeName: "T[int]", + expectedType: &NamedType{ + Package: "foo", + Type: "T", + TypeParams: &TypeParametersType{ + TypeParameters: []Type{ + PredeclaredType("int"), + }, + }, + }, + }, + // generic - foo.T[int[]] + { + inputTypePackage: "foo", + inputTypeName: "T[int[]]", + expectedType: &NamedType{ + Package: "foo", + Type: "T", + TypeParams: &TypeParametersType{ + TypeParameters: []Type{ + PredeclaredType("int[]"), + }, + }, + }, + }, + // FIXME: broken case + // generic - foo.T[int[], bool, string[]] + // { + // inputTypePackage: "foo", + // inputTypeName: "T[int[], bool, string[]]", + // expectedType: &NamedType{ + // Package: "foo", + // Type: "T", + // TypeParams: &TypeParametersType{ + // TypeParameters: []Type{ + // PredeclaredType("int[]"), + // PredeclaredType("bool"), + // PredeclaredType("string[]"), + // }, + // }, + // }, + // }, + // generic - foo.T[int, string, int] + { + inputTypePackage: "foo", + inputTypeName: "T[ int,string, int]", + expectedType: &NamedType{ + Package: "foo", + Type: "T", + TypeParams: &TypeParametersType{ + TypeParameters: []Type{ + PredeclaredType("int"), + PredeclaredType("string"), + PredeclaredType("int"), + }, + }, + }, + }, + // generic - foo.T[context.Context] + { + inputTypePackage: "foo", + inputTypeName: "T[context.Context]", + expectedType: &NamedType{ + Package: "foo", + Type: "T", + TypeParams: &TypeParametersType{ + TypeParameters: []Type{ + &NamedType{ + Package: "context", + Type: "Context", + }, + }, + }, + }, + }, + // generic - foo.T[context.Context, github.com/foo/bar.X] + { + inputTypePackage: "foo", + inputTypeName: "T[context.Context , github.com/foo/bar.X ]", + expectedType: &NamedType{ + Package: "foo", + Type: "T", + TypeParams: &TypeParametersType{ + TypeParameters: []Type{ + &NamedType{ + Package: "context", + Type: "Context", + }, + &NamedType{ + Package: "github.com/foo/bar", + Type: "X", + }, + }, + }, + }, + }, + // generic - foo.T[context.Context, github.com/foo/bar.🤣, int] + { + inputTypePackage: "foo", + inputTypeName: "T[context.Context , gtihub.com/foo/bar.🤣, int]", + expectedType: &NamedType{ + Package: "foo", + Type: "T", + TypeParams: &TypeParametersType{ + TypeParameters: []Type{ + &NamedType{ + Package: "context", + Type: "Context", + }, + &NamedType{ + Package: "github.com/foo/bar", + Type: "🤣", + }, + PredeclaredType("int"), + }, + }, + }, + }, + // FIXME: broken case + // generic - foo.T[bool[], context.Context[], github.com/foo/bar.X[], int[]] + // { + // inputTypePackage: "foo", + // inputTypeName: "foo.T[bool[], context.Context[], github.com/foo/bar.X[], int[]]", + // expectedType: &NamedType{ + // Package: "foo", + // Type: "T", + // TypeParams: &TypeParametersType{ + // TypeParameters: []Type{ + // PredeclaredType("bool[]"), + // &ArrayType{ + // Len: -1, + // Type: &NamedType{ + // Package: "context", + // Type: "Context", + // }, + // }, + // &ArrayType{ + // Len: -1, + // Type: &NamedType{ + // Package: "github.com/foo/bar", + // Type: "X", + // }, + // }, + // PredeclaredType("int[]"), + // }, + // }, + // }, + // }, + } + + for idx := range testCases { + tc := testCases[idx] + t.Run(fmt.Sprintf("%s.%s", tc.inputTypePackage, tc.inputTypeName), func(t *testing.T) { + t.Log("input:", tc.inputTypePackage, tc.inputTypeName) + + got := typeFromNamedType(tc.inputTypePackage, tc.inputTypeName) + gotNamedType, ok := got.(*NamedType) + if !ok { + t.Errorf("got %T; want *NamedType", got) + } + expected := tc.expectedType + if gotNamedType.Package != expected.Package { + t.Errorf("got %s; want %s", gotNamedType.Package, tc.expectedType.Package) + } + if expected.TypeParams == nil { + if gotNamedType.TypeParams != nil { + t.Errorf("got %s; want nil", gotNamedType.TypeParams) + } + } else { + if gotNamedType.TypeParams == nil { + t.Errorf("got nil; want %s", expected.TypeParams) + } + + pm := map[string]string{} + expectedString := expected.String(pm, "") + gotString := gotNamedType.String(pm, "") + if gotString != expectedString { + t.Errorf("got %q; want %q", gotString, expectedString) + } + } + }) + } +}