From 4261fcbdbf1dba6c023b983678f425577f9b830a Mon Sep 17 00:00:00 2001 From: Kyle Gray Date: Mon, 5 Aug 2024 13:12:17 -0700 Subject: [PATCH] import: https://github.com/sqlc-dev/sqlc/commit/fe75daef25575b1e7bfc1fd4de60ce4d237901d6 --- internal/opts/options.go | 13 +++++++++++++ internal/result.go | 4 +++- internal/struct.go | 4 ++-- .../templates/go-sql-driver-mysql/copyfromCopy.tmpl | 6 +++--- 4 files changed, 21 insertions(+), 6 deletions(-) diff --git a/internal/opts/options.go b/internal/opts/options.go index 6833242..65484c2 100644 --- a/internal/opts/options.go +++ b/internal/opts/options.go @@ -43,6 +43,9 @@ type Options struct { OmitSqlcVersion bool `json:"omit_sqlc_version,omitempty" yaml:"omit_sqlc_version"` OmitUnusedStructs bool `json:"omit_unused_structs,omitempty" yaml:"omit_unused_structs"` BuildTags string `json:"build_tags,omitempty" yaml:"build_tags"` + Initialisms *[]string `json:"initialisms,omitempty" yaml:"initialisms"` + + InitialismsMap map[string]struct{} `json:"-" yaml:"-"` } type GlobalOptions struct { @@ -111,6 +114,16 @@ func parseOpts(req *plugin.GenerateRequest) (*Options, error) { *options.QueryParameterLimit = 1 } + if options.Initialisms == nil { + options.Initialisms = new([]string) + *options.Initialisms = []string{"id"} + } + + options.InitialismsMap = map[string]struct{}{} + for _, initial := range *options.Initialisms { + options.InitialismsMap[initial] = struct{}{} + } + return &options, nil } diff --git a/internal/result.go b/internal/result.go index c9ae12a..84ff2fa 100644 --- a/internal/result.go +++ b/internal/result.go @@ -259,7 +259,9 @@ func buildQueries(req *plugin.GenerateRequest, options *opts.Options, structs [] EmitPointer: options.EmitParamsStructPointers, } - if len(query.Params) <= qpl { + // if query params is 2, and query params limit is 4 AND this is a copyfrom, we still want to emit the query's model + // otherwise we end up with a copyfrom using a struct without the struct definition + if len(query.Params) <= qpl && query.Cmd != ":copyfrom" { gq.Arg.Emit = false } } diff --git a/internal/struct.go b/internal/struct.go index c747808..ef14af7 100644 --- a/internal/struct.go +++ b/internal/struct.go @@ -32,8 +32,8 @@ func StructName(name string, options *opts.Options) string { }, name) for _, p := range strings.Split(name, "_") { - if p == "id" { - out += "ID" + if _, found := options.InitialismsMap[p]; found { + out += strings.ToUpper(p) } else { out += strings.Title(p) } diff --git a/internal/templates/go-sql-driver-mysql/copyfromCopy.tmpl b/internal/templates/go-sql-driver-mysql/copyfromCopy.tmpl index e6b9061..e21475b 100644 --- a/internal/templates/go-sql-driver-mysql/copyfromCopy.tmpl +++ b/internal/templates/go-sql-driver-mysql/copyfromCopy.tmpl @@ -9,11 +9,11 @@ func convertRowsFor{{.MethodName}}(w *io.PipeWriter, {{.Arg.SlicePair}}) { {{- with $arg := .Arg }} {{- range $arg.CopyFromMySQLFields}} {{- if eq .Type "string"}} - e.AppendString({{if eq (len $arg.CopyFromMySQLFields) 1}}row{{else}}row.{{.Name}}{{end}}) + e.AppendString({{if $arg.Struct}}row.{{.Name}}{{else}}row{{end}}) {{- else if or (eq .Type "[]byte") (eq .Type "json.RawMessage")}} - e.AppendBytes({{if eq (len $arg.CopyFromMySQLFields) 1}}row{{else}}row.{{.Name}}{{end}}) + e.AppendBytes({{if $arg.Struct}}row.{{.Name}}{{else}}row{{end}}) {{- else}} - e.AppendValue({{if eq (len $arg.CopyFromMySQLFields) 1}}row{{else}}row.{{.Name}}{{end}}) + e.AppendValue({{if $arg.Struct}}row.{{.Name}}{{else}}row{{end}}) {{- end}} {{- end}} {{- end}}