Skip to content

Commit

Permalink
fix: Select the corret schemas that correspond to the messages used
Browse files Browse the repository at this point in the history
  • Loading branch information
zaakn committed Nov 9, 2023
1 parent ee84fd2 commit dfd4412
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 12 deletions.
35 changes: 27 additions & 8 deletions cmd/protoc-gen-openapi/generator/generator.go
Original file line number Diff line number Diff line change
Expand Up @@ -133,11 +133,15 @@ func (g *OpenAPIv3Generator) buildDocumentV3() *v3.Document {
// While we have required schemas left to generate, go through the files again
// looking for the related message and adding them to the document if required.
for len(g.reflect.requiredSchemas) > 0 {
count := len(g.reflect.requiredSchemas)
for _, file := range g.plugin.Files {
g.addSchemasForMessagesToDocumentV3(d, file.Messages)
}
g.reflect.requiredSchemas = g.reflect.requiredSchemas[count:len(g.reflect.requiredSchemas)]
// clear the generated schemas
for schema := range g.reflect.requiredSchemas {
if contains(g.generatedSchemas, schema) {
delete(g.reflect.requiredSchemas, schema)
}
}
}

// If there is only 1 service, then use it's title for the
Expand Down Expand Up @@ -771,12 +775,14 @@ func (g *OpenAPIv3Generator) addPathsToDocumentV3(d *v3.Document, services []*pr
}
}

// addSchemaForMessageToDocumentV3 adds the schema to the document if required
// addSchemaToDocumentV3 adds the schema to the document if required
func (g *OpenAPIv3Generator) addSchemaToDocumentV3(d *v3.Document, schema *v3.NamedSchemaOrReference) {
if contains(g.generatedSchemas, schema.Name) {
return
// check if schema already exists in Schemas, instead of checking "generated"
for _, prop := range d.Components.Schemas.AdditionalProperties {
if prop.Name == schema.Name {
return
}
}
g.generatedSchemas = append(g.generatedSchemas, schema.Name)
d.Components.Schemas.AdditionalProperties = append(d.Components.Schemas.AdditionalProperties, schema)
}

Expand All @@ -789,12 +795,25 @@ func (g *OpenAPIv3Generator) addSchemasForMessagesToDocumentV3(d *v3.Document, m
}

schemaName := g.reflect.formatMessageName(message.Desc)
fqSchemaName := g.reflect.formatPackageMessageName(message.Desc)

// Only generate this if we need it and haven't already generated it.
if !contains(g.reflect.requiredSchemas, schemaName) ||
contains(g.generatedSchemas, schemaName) {
requiredFQSchema, ok := g.reflect.requiredSchemas[schemaName]
if !ok {
continue
} else if requiredFQSchema != fqSchemaName {
// "schemaName" with same name is required, but it's not the actual
// schema with "fqSchemaName". Try to use the fully-qualified schema.
if _, ok = g.reflect.requiredSchemas[fqSchemaName]; !ok {
continue
}
// use fully-qualified name as schema name if there are same named messages
schemaName = fqSchemaName
}
if contains(g.generatedSchemas, schemaName) {
continue
}
g.generatedSchemas = append(g.generatedSchemas, schemaName)

typeName := g.reflect.fullMessageTypeName(message.Desc)
messageDescription := g.filterCommentString(message.Comments.Leading)
Expand Down
25 changes: 21 additions & 4 deletions cmd/protoc-gen-openapi/generator/reflector.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,15 +33,17 @@ const (
type OpenAPIv3Reflector struct {
conf Configuration

requiredSchemas []string // Names of schemas which are used through references.
// Names of schemas which are used through references.
// map: schema name will be used actually -> fully-qualified schema name
requiredSchemas map[string]string
}

// NewOpenAPIv3Reflector creates a new reflector.
func NewOpenAPIv3Reflector(conf Configuration) *OpenAPIv3Reflector {
return &OpenAPIv3Reflector{
conf: conf,

requiredSchemas: make([]string, 0),
requiredSchemas: make(map[string]string, 0),
}
}

Expand Down Expand Up @@ -86,6 +88,14 @@ func (r *OpenAPIv3Reflector) formatMessageName(message protoreflect.MessageDescr
return name
}

// formatPackageMessageName returns the fully-qualified name of a message.
func (r *OpenAPIv3Reflector) formatPackageMessageName(message protoreflect.MessageDescriptor) string {
package_name := string(message.ParentFile().Package())
name := package_name + "." + r.getMessageName(message)

return name
}

func (r *OpenAPIv3Reflector) formatFieldName(field protoreflect.FieldDescriptor) string {
if *r.conf.Naming == "proto" {
return string(field.Name())
Expand Down Expand Up @@ -116,8 +126,15 @@ func (r *OpenAPIv3Reflector) responseContentForMessage(message protoreflect.Mess

func (r *OpenAPIv3Reflector) schemaReferenceForMessage(message protoreflect.MessageDescriptor) string {
schemaName := r.formatMessageName(message)
if !contains(r.requiredSchemas, schemaName) {
r.requiredSchemas = append(r.requiredSchemas, schemaName)
fqSchemaName := r.formatPackageMessageName(message)
requiredFQSchema, ok := r.requiredSchemas[schemaName]
if !ok {
// new required, use schemaName
r.requiredSchemas[schemaName] = fqSchemaName
} else if requiredFQSchema != fqSchemaName {
// use the fully-qualified schema name as there are same named messages
schemaName = fqSchemaName
r.requiredSchemas[schemaName] = fqSchemaName
}
return "#/components/schemas/" + schemaName
}
Expand Down

0 comments on commit dfd4412

Please sign in to comment.