diff --git a/lang/go/context.go b/lang/go/context.go index 52005d5..e39410d 100644 --- a/lang/go/context.go +++ b/lang/go/context.go @@ -60,6 +60,16 @@ type Context interface { // OutputPath returns the output path relative to the plugin's output destination OutputPath(entity pgs.Entity) pgs.FilePath + + // FieldTypeImportPath returns name of the Field type's package as it would appear in + // Go source generated by the official protoc-gen-go plugin. + // For builtin types empty FieldPath will be returned. + FieldTypePackageName(field pgs.Field) pgs.Name + + // FieldTypeImportPath returns the Go import path of the type of the Field + // as it would be included in an import block in a Go file. + // For builtin types empty FieldPath will be returned. + FieldTypeImportPath(field pgs.Field) pgs.FilePath } type context struct{ p pgs.Parameters } diff --git a/lang/go/package.go b/lang/go/package.go index f6b67c5..1653856 100644 --- a/lang/go/package.go +++ b/lang/go/package.go @@ -14,34 +14,6 @@ import ( var nonAlphaNumPattern = regexp.MustCompile("[^a-zA-Z0-9]") -func (c context) PackageName(node pgs.Node) pgs.Name { - e, ok := node.(pgs.Entity) - if !ok { - e = node.(pgs.Package).Files()[0] - } - - _, pkg := c.optionPackage(e) - - // use import_path parameter ONLY if there is no go_package option in the file. - if ip := c.p.Str("import_path"); ip != "" && - e.File().Descriptor().GetOptions().GetGoPackage() == "" { - pkg = ip - } - - // if the package name is a Go keyword, prefix with '_' - if token.Lookup(pkg).IsKeyword() { - pkg = "_" + pkg - } - - // if package starts with digit, prefix with `_` - if r, _ := utf8.DecodeRuneInString(pkg); unicode.IsDigit(r) { - pkg = "_" + pkg - } - - // package name is kosher - return pgs.Name(pkg) -} - func gogoType(f pgs.Field) (pgs.FilePath, TypeName, bool) { ft := f.Type() switch { @@ -106,14 +78,59 @@ func gogoType(f pgs.Field) (pgs.FilePath, TypeName, bool) { return "", TypeName(typeName), true } -func (c gogoContext) PackageName(node pgs.Node) pgs.Name { - f, ok := node.(pgs.Field) +func (c context) PackageName(node pgs.Node) pgs.Name { + e, ok := node.(pgs.Entity) if !ok { - return c.context.PackageName(node) + e = node.(pgs.Package).Files()[0] + } + + _, pkg := c.optionPackage(e) + + // use import_path parameter ONLY if there is no go_package option in the file. + if ip := c.p.Str("import_path"); ip != "" && + e.File().Descriptor().GetOptions().GetGoPackage() == "" { + pkg = ip } + + // if the package name is a Go keyword, prefix with '_' + if token.Lookup(pkg).IsKeyword() { + pkg = "_" + pkg + } + + // if package starts with digit, prefix with `_` + if r, _ := utf8.DecodeRuneInString(pkg); unicode.IsDigit(r) { + pkg = "_" + pkg + } + + // package name is kosher + return pgs.Name(pkg) +} + +func (c context) FieldTypePackageName(f pgs.Field) pgs.Name { + var en pgs.Entity + switch ft := f.Type(); { + case ft.IsEmbed(): + en = ft.Embed() + case ft.IsEnum(): + en = ft.Enum() + case ft.IsRepeated(), ft.IsMap(): + el := ft.Element() + switch { + case el.IsEmbed(): + en = el.Embed() + case el.IsEnum(): + en = el.Enum() + } + default: + return pgs.Name("") + } + return c.PackageName(en) +} + +func (c gogoContext) FieldTypePackageName(f pgs.Field) pgs.Name { pkg, _, ok := gogoType(f) if !ok { - return c.context.PackageName(node) + return c.context.FieldTypePackageName(f) } return pgs.Name(nonAlphaNumPattern.ReplaceAllString(string(pkg), "_")) } @@ -124,14 +141,31 @@ func (c context) ImportPath(e pgs.Entity) pgs.FilePath { return pgs.FilePath(path) } -func (c gogoContext) ImportPath(e pgs.Entity) pgs.FilePath { - f, ok := e.(pgs.Field) - if !ok { - return c.context.ImportPath(e) +func (c context) FieldTypeImportPath(f pgs.Field) pgs.FilePath { + var en pgs.Entity + switch ft := f.Type(); { + case ft.IsEmbed(): + en = ft.Embed() + case ft.IsEnum(): + en = ft.Enum() + case ft.IsRepeated(), ft.IsMap(): + el := ft.Element() + switch { + case el.IsEmbed(): + en = el.Embed() + case el.IsEnum(): + en = el.Enum() + } + default: + return pgs.FilePath("") } + return c.ImportPath(en) +} + +func (c gogoContext) FieldTypeImportPath(f pgs.Field) pgs.FilePath { pkg, _, ok := gogoType(f) if !ok { - return c.context.ImportPath(e) + return c.context.FieldTypeImportPath(f) } return pkg }