From f1cd809a2ca9c6991ec6f36443b8dc1c4613658a Mon Sep 17 00:00:00 2001 From: Ian Suvak Date: Tue, 13 Jun 2023 09:31:44 -0400 Subject: [PATCH] Implement XyzMaxSize() generator - Adds maxsize.go that generates {typeName}MaxSize() functions - Adds maxtotalbytes annotation that limits the total number of bytes that a struct field can encode to - Adds allocbound handling to string types --- .gitignore | 3 +- gen/elem.go | 40 ++++-- gen/maxsize.go | 332 +++++++++++++++++++++++++------------------- gen/spec.go | 10 +- gen/unmarshal.go | 12 ++ main.go | 2 +- msgp/write.go | 10 ++ parse/directives.go | 2 +- parse/getast.go | 46 +++++- 9 files changed, 293 insertions(+), 164 deletions(-) diff --git a/.gitignore b/.gitignore index 2a4373e..2d19289 100644 --- a/.gitignore +++ b/.gitignore @@ -9,4 +9,5 @@ msgp/cover.out *~ *.coverprofile .idea/ -cover.out \ No newline at end of file +.vscode/ +cover.out diff --git a/gen/elem.go b/gen/elem.go index 2a339ed..707b56b 100644 --- a/gen/elem.go +++ b/gen/elem.go @@ -155,20 +155,23 @@ func (c Callback) GetName() string { return c.Fname } // common data/methods for every Elem type common struct { - vname, alias string - allocbound string - callbacks []Callback -} - -func (c *common) SetVarname(s string) { c.vname = s } -func (c *common) Varname() string { return c.vname } -func (c *common) Alias(typ string) { c.alias = typ } -func (c *common) SortInterface() string { return "" } -func (c *common) SetAllocBound(s string) { c.allocbound = s } -func (c *common) AllocBound() string { return c.allocbound } -func (c *common) GetCallbacks() []Callback { return c.callbacks } -func (c *common) AddCallback(cb Callback) { c.callbacks = append(c.callbacks, cb) } -func (c *common) hidden() {} + vname, alias string + allocbound string + maxtotalbytes string + callbacks []Callback +} + +func (c *common) SetVarname(s string) { c.vname = s } +func (c *common) Varname() string { return c.vname } +func (c *common) Alias(typ string) { c.alias = typ } +func (c *common) SortInterface() string { return "" } +func (c *common) SetAllocBound(s string) { c.allocbound = s } +func (c *common) AllocBound() string { return c.allocbound } +func (c *common) SetMaxTotalBytes(s string) { c.maxtotalbytes = s } +func (c *common) MaxTotalBytes() string { return c.maxtotalbytes } +func (c *common) GetCallbacks() []Callback { return c.callbacks } +func (c *common) AddCallback(cb Callback) { c.callbacks = append(c.callbacks, cb) } +func (c *common) hidden() {} func IsDangling(e Elem) bool { if be, ok := e.(*BaseElem); ok && be.Dangling() { @@ -241,6 +244,15 @@ type Elem interface { // when decoding this type. Meaningful for slices and maps. AllocBound() string + // SetMaxTotalBytes specifies the maximum number of bytes to allocate when + // decoding this type. + // Blank means unspecified bound. "-" means no bound. + SetMaxTotalBytes(bound string) + + // MaxTotalBytes specifies the maximum number of bytes to allocate when + // decoding this type. Meaningful for slices of strings or byteslices. + MaxTotalBytes() string + // AddCallback adds to the elem a Callback it should call at the end of marshaling AddCallback(Callback) diff --git a/gen/maxsize.go b/gen/maxsize.go index d3cbb3d..5052249 100644 --- a/gen/maxsize.go +++ b/gen/maxsize.go @@ -1,78 +1,84 @@ package gen import ( + "bytes" "fmt" "go/ast" "io" + "reflect" "strconv" + "strings" "github.com/algorand/msgp/msgp" ) -type sizeState uint8 +type maxSizeState uint8 const ( // need to write "s = ..." - assign sizeState = iota + assignM maxSizeState = iota // need to write "s += ..." - add + addM // can just append "+ ..." - expr + exprM + + multM + // the result is multiplied by whatever is preceeding it ) -func sizes(w io.Writer, topics *Topics) *sizeGen { - return &sizeGen{ +func maxSizes(w io.Writer, topics *Topics) *maxSizeGen { + return &maxSizeGen{ p: printer{w: w}, - state: assign, + state: assignM, topics: topics, } } -type sizeGen struct { +type maxSizeGen struct { passes p printer - state sizeState + state maxSizeState ctx *Context topics *Topics } -func (s *sizeGen) Method() Method { return Size } +func (s *maxSizeGen) Method() Method { return MaxSize } -func (s *sizeGen) Apply(dirs []string) error { +func (s *maxSizeGen) Apply(dirs []string) error { return nil } -func builtinSize(typ string) string { - return "msgp." + typ + "Size" -} - // this lets us chain together addition // operations where possible -func (s *sizeGen) addConstant(sz string) { +func (s *maxSizeGen) addConstant(sz string) { if !s.p.ok() { return } switch s.state { - case assign: + case assignM: s.p.print("\ns = " + sz) - s.state = expr + s.state = exprM return - case add: + case addM: s.p.print("\ns += " + sz) - s.state = expr + s.state = exprM return - case expr: + case exprM: s.p.print(" + " + sz) return + case multM: + s.p.print(" * ( " + sz + ")") + s.state = addM + return } panic("unknown size state") } -func (s *sizeGen) Execute(p Elem) ([]string, error) { +func (s *maxSizeGen) Execute(p Elem) ([]string, error) { if !s.p.ok() { return nil, s.p.err } @@ -85,33 +91,30 @@ func (s *sizeGen) Execute(p Elem) ([]string, error) { // to not affect other code that will use p. p = p.Copy() - s.p.comment("Msgsize returns an upper bound estimate of the number of bytes occupied by the serialized message") + s.p.comment("MaxSize returns a maximum valid message size for this message type") if IsDangling(p) { baseType := p.(*BaseElem).IdentName - ptrName := p.Varname() - receiver := methodReceiver(p) - s.p.printf("\nfunc (%s %s) Msgsize() int {", ptrName, receiver) - s.p.printf("\n return ((*(%s))(%s)).Msgsize()", baseType, ptrName) + s.p.printf("\nfunc %s int{", getMaxSizeMethod(p.TypeName())) + s.p.printf("\n return %s", getMaxSizeMethod(baseType)) s.p.printf("\n}") - s.topics.Add(receiver, "Msgsize") + s.topics.Add(baseType, getMaxSizeMethod(baseType)) return nil, s.p.err } s.ctx = &Context{} s.ctx.PushString(p.TypeName()) - ptrName := p.Varname() - receiver := imutMethodReceiver(p) - s.p.printf("\nfunc (%s %s) Msgsize() (s int) {", ptrName, receiver) - s.state = assign + // receiver := imutMethodReceiver(p) + s.p.printf("\nfunc %s (s int) {", getMaxSizeMethod(p.TypeName())) + s.state = assignM next(s, p) s.p.nakedReturn() - s.topics.Add(receiver, "Msgsize") + s.topics.Add(p.TypeName(), getMaxSizeMethod(p.TypeName())) return nil, s.p.err } -func (s *sizeGen) gStruct(st *Struct) { +func (s *maxSizeGen) gStruct(st *Struct) { if !s.p.ok() { return } @@ -152,174 +155,225 @@ func (s *sizeGen) gStruct(st *Struct) { } } -func (s *sizeGen) gPtr(p *Ptr) { - s.state = add // inner must use add - s.p.printf("\nif %s == nil {\ns += msgp.NilSize\n} else {", p.Varname()) +func (s *maxSizeGen) gPtr(p *Ptr) { + s.state = addM // inner must use add next(s, p.Value) - s.state = add // closing block; reset to add - s.p.closeblock() + s.state = addM // closing block; reset to add } -func (s *sizeGen) gSlice(sl *Slice) { +func (s *maxSizeGen) gSlice(sl *Slice) { if !s.p.ok() { return } + s.state = addM + s.p.comment("Calculating size of slice: " + sl.Varname()) + if (sl.AllocBound() == "" || sl.AllocBound() == "-") && (sl.MaxTotalBytes() == "" || sl.MaxTotalBytes() == "-") { + s.p.printf("\npanic(\"Slice %s is unbounded\")", sl.Varname()) + s.state = addM // reset the add to prevent further + expressions from being added to the end the panic statement + return + } s.addConstant(builtinSize(arrayHeader)) - // if the slice's element is a fixed size - // (e.g. float64, [32]int, etc.), then - // print the length times the element size directly - if str, ok := fixedsizeExpr(sl.Els); ok { - s.addConstant(fmt.Sprintf("(%s * (%s))", lenExpr(sl), str)) + // use maxtotalbytes if it's available + if sl.common.MaxTotalBytes() != "" && sl.common.MaxTotalBytes() != "-" { + s.addConstant(sl.common.MaxTotalBytes()) return } - // add inside the range block, and immediately after - s.state = add - s.p.rangeBlock(s.ctx, sl.Index, sl.Varname(), s, sl.Els) - s.state = add + topLevelAllocBound := sl.AllocBound() + childElement := sl.Els + if sl.Els.AllocBound() == "" && len(strings.Split(sl.AllocBound(), ",")) > 1 { + splitIndex := strings.Index(sl.AllocBound(), ",") + childElement = sl.Els.Copy() + childElement.SetAllocBound(sl.AllocBound()[splitIndex+1:]) + topLevelAllocBound = sl.AllocBound()[:splitIndex] + } + + if str, err := maxSizeExpr(childElement); err == nil { + s.addConstant(fmt.Sprintf("((%s) * (%s))", topLevelAllocBound, str)) + } else { + s.p.printf("\npanic(\"Unable to determine max size: %s\")", err) + } + s.state = addM + return } -func (s *sizeGen) gArray(a *Array) { +func (s *maxSizeGen) gArray(a *Array) { if !s.p.ok() { return } + // If this is not the first line where we define s = ... then we need to reset the state + // to addM so that the comment is printed correctly on a newline + if s.state != assignM { + s.state = addM + } + s.p.comment("Calculating size of array: " + a.Varname()) s.addConstant(builtinSize(arrayHeader)) - // if the array's children are a fixed - // size, we can compile an expression - // that always represents the array's wire size - if str, ok := fixedsizeExpr(a); ok { - s.addConstant(str) - return - } + if str, err := maxSizeExpr(a.Els); err == nil { + s.addConstant(fmt.Sprintf("((%s) * (%s))", a.Size, str)) + } else { + s.p.printf("\npanic(\"Unable to determine max size: %s\")", err) - s.state = add - s.p.rangeBlock(s.ctx, a.Index, a.Varname(), s, a.Els) - s.state = add + } + s.state = addM + return } -func (s *sizeGen) gMap(m *Map) { - s.addConstant(builtinSize(mapHeader)) +func (s *maxSizeGen) gMap(m *Map) { vn := m.Varname() - s.p.printf("\nif %s != nil {", vn) - s.p.printf("\nfor %s, %s := range %s {", m.Keyidx, m.Validx, vn) - s.p.printf("\n_ = %s", m.Keyidx) // we may not use the key - s.p.printf("\n_ = %s", m.Validx) // we may not use the value - s.p.printf("\ns += 0") - s.state = expr - s.ctx.PushVar(m.Keyidx) + s.state = addM + s.addConstant(builtinSize(mapHeader)) + topLevelAllocBound := m.AllocBound() + if topLevelAllocBound != "" && topLevelAllocBound == "-" { + s.p.printf("\npanic(\"Map %s is unbounded\")", m.Varname()) + s.state = addM // reset the add to prevent further + expressions from being added to the end the panic statement + return + } + splitBounds := strings.Split(m.AllocBound(), ",") + if len(splitBounds) > 1 { + topLevelAllocBound = splitBounds[0] + m.Key.SetAllocBound(splitBounds[1]) + if len(splitBounds) > 2 { + m.Value.SetAllocBound(splitBounds[2]) + } + } + + s.p.comment("Adding size of map keys for " + vn) + s.p.printf("\ns += %s", topLevelAllocBound) + s.state = multM next(s, m.Key) + + s.p.comment("Adding size of map values for " + vn) + s.p.printf("\ns += %s", topLevelAllocBound) + s.state = multM next(s, m.Value) - s.ctx.Pop() - s.p.closeblock() - s.p.closeblock() - s.state = add + + s.state = addM } -func (s *sizeGen) gBase(b *BaseElem) { +func (s *maxSizeGen) gBase(b *BaseElem) { if !s.p.ok() { return } + if b.MaxTotalBytes() != "" { + s.p.comment("Using maxtotalbytes for: " + b.Varname()) + s.state = addM + s.addConstant(b.MaxTotalBytes()) + s.state = addM + return + } if b.Convert && b.ShimMode == Convert { - s.state = add + s.state = addM vname := randIdent() s.p.printf("\nvar %s %s", vname, b.BaseType()) // ensure we don't get "unused variable" warnings from outer slice iterations s.p.printf("\n_ = %s", b.Varname()) - s.p.printf("\ns += %s", basesizeExpr(b.Value, vname, b.BaseName())) - s.state = expr + value, err := baseMaxSizeExpr(b.Value, vname, b.BaseName(), b.TypeName(), b.common.AllocBound()) + if err != nil { + s.p.printf("\npanic(\"Unable to determine max size: %s\")", err) + s.state = addM // reset the add to prevent further + expressions from being added to the end the panic statement + return + } + s.p.printf("\ns += %s", value) + s.state = exprM } else { vname := b.Varname() if b.Convert { vname = tobaseConvert(b) } - s.addConstant(basesizeExpr(b.Value, vname, b.BaseName())) + value, err := baseMaxSizeExpr(b.Value, vname, b.BaseName(), b.TypeName(), b.common.AllocBound()) + if err != nil { + s.p.printf("\npanic(\"Unable to determine max size: %s\")", err) + s.state = addM // reset the add to prevent further + expressions from being added to the end the panic statement + return + } + s.addConstant(value) } } -// returns "len(slice)" -func lenExpr(sl *Slice) string { - return "len(" + sl.Varname() + ")" -} - -// is a given primitive always the same (max) -// size on the wire? -func fixedSize(p Primitive) bool { - switch p { - case Intf, Ext, IDENT, Bytes, String: - return false - default: - return true +func baseMaxSizeExpr(value Primitive, vname, basename, typename string, allocbound string) (string, error) { + if typename == "msgp.Raw" { + return "", fmt.Errorf("MaxSize() not implemented for Raw type") } -} - -// strip reference from string -func stripRef(s string) string { - if s[0] == '&' { - return s[1:] + switch value { + case Ext: + return "", fmt.Errorf("MaxSize() not implemented for Ext type") + case Intf: + return "", fmt.Errorf("MaxSize() not implemented for Interfaces") + case IDENT: + return getMaxSizeMethod(typename), nil + case Bytes: + if allocbound == "" || allocbound == "-" { + return "", fmt.Errorf("Byteslice type %s is unbounded", vname) + } + return "msgp.BytesPrefixSize + " + allocbound, nil + case String: + if allocbound == "" || allocbound == "-" { + return "", fmt.Errorf("String type %s is unbounded", vname) + } + return "msgp.StringPrefixSize + " + allocbound, nil + default: + return builtinSize(basename), nil } - return s } // return a fixed-size expression, if possible. -// only possible for *BaseElem and *Array. -// returns (expr, ok) -func fixedsizeExpr(e Elem) (string, bool) { +// only possible for *BaseElem, *Array and Struct. +// returns (expr, err) +func maxSizeExpr(e Elem) (string, error) { switch e := e.(type) { case *Array: - if str, ok := fixedsizeExpr(e.Els); ok { - return fmt.Sprintf("(%s * (%s))", e.Size, str), true + if str, err := maxSizeExpr(e.Els); err == nil { + return fmt.Sprintf("(%s * (%s))", e.Size, str), nil + } else { + return "", err } case *BaseElem: if fixedSize(e.Value) { - return builtinSize(e.BaseName()), true + return builtinSize(e.BaseName()), nil + } else if (e.TypeName()) == "msgp.Raw" { + return "", fmt.Errorf("Raw type is unbounded") + } else if (e.Value) == String { + if e.AllocBound() == "" || e.AllocBound() == "-" { + return "", fmt.Errorf("String type is unbounded for %s", e.Varname()) + } + return fmt.Sprintf("(msgp.StringPrefixSize + %s)", e.AllocBound()), nil + } else if (e.Value) == IDENT { + return fmt.Sprintf("(%s)", getMaxSizeMethod(e.TypeName())), nil + } else if (e.Value) == Bytes { + if e.AllocBound() == "" || e.AllocBound() == "-" { + return "", fmt.Errorf("Inner byteslice type is unbounded") + } + return fmt.Sprintf("(msgp.BytesPrefixSize + %s)", e.AllocBound()), nil } case *Struct: - var str string - for _, f := range e.Fields { - if fs, ok := fixedsizeExpr(f.FieldElem); ok { - if str == "" { - str = fs - } else { - str += "+" + fs - } - } else { - return "", false - } + return fmt.Sprintf("(%s)", getMaxSizeMethod(e.TypeName())), nil + case *Slice: + if e.AllocBound() == "" || e.AllocBound() == "-" { + return "", fmt.Errorf("Slice %s is unbounded", e.Varname()) } - var hdrlen int - mhdr := msgp.AppendMapHeader(nil, uint32(len(e.Fields))) - hdrlen += len(mhdr) - var strbody []byte - for _, f := range e.Fields { - strbody = msgp.AppendString(strbody[:0], f.FieldTag) - hdrlen += len(strbody) + if str, err := maxSizeExpr(e.Els); err == nil { + return fmt.Sprintf("(%s * (%s))", e.AllocBound(), str), nil + } else { + return "", err } - return fmt.Sprintf("%d + %s", hdrlen, str), true } - return "", false + return fmt.Sprintf("%s, %s", e.TypeName(), reflect.TypeOf(e)), nil } -// print size expression of a variable name -func basesizeExpr(value Primitive, vname, basename string) string { - switch value { - case Ext: - return "msgp.ExtensionPrefixSize + " + stripRef(vname) + ".Len()" - case Intf: - return "msgp.GuessSize(" + vname + ")" - case IDENT: - return vname + ".Msgsize()" - case Bytes: - return "msgp.BytesPrefixSize + len(" + vname + ")" - case String: - return "msgp.StringPrefixSize + len(" + vname + ")" - default: - return builtinSize(basename) +func getMaxSizeMethod(typeName string) (s string) { + var pos int + dotIndex := strings.Index(typeName, ".") + if dotIndex != -1 { + pos = dotIndex + 1 } + b := []byte(typeName) + b[pos] = bytes.ToUpper(b)[pos] + return string(b) + "MaxSize()" } diff --git a/gen/spec.go b/gen/spec.go index 12d21d6..5567b72 100644 --- a/gen/spec.go +++ b/gen/spec.go @@ -39,11 +39,13 @@ func (m Method) String() string { return "size" case IsZero: return "iszero" + case MaxSize: + return "maxsize" case Test: return "test" default: // return e.g. "marshal+unmarshal+test" - modes := [...]Method{Marshal, Unmarshal, Size, IsZero, Test} + modes := [...]Method{Marshal, Unmarshal, Size, IsZero, MaxSize, Test} any := false nm := "" for _, mm := range modes { @@ -71,6 +73,8 @@ func strtoMeth(s string) Method { return Size case "iszero": return IsZero + case "maxsize": + return MaxSize case "test": return Test default: @@ -84,6 +88,7 @@ const ( Size // msgp.Sizer IsZero // implement MsgIsZero() Test // generate tests + MaxSize // msgp.MaxSize invalidmeth // this isn't a method marshaltest = Marshal | Unmarshal | Test // tests for Marshaler and Unmarshaler ) @@ -109,6 +114,9 @@ func NewPrinter(m Method, topics *Topics, out io.Writer, tests io.Writer) *Print if m.isset(IsZero) { gens = append(gens, isZeros(out, topics)) } + if m.isset(MaxSize) { + gens = append(gens, maxSizes(out, topics)) + } if m.isset(marshaltest) { gens = append(gens, mtest(tests)) } diff --git a/gen/unmarshal.go b/gen/unmarshal.go index d4e9495..4ad20f4 100644 --- a/gen/unmarshal.go +++ b/gen/unmarshal.go @@ -237,6 +237,18 @@ func (u *unmarshalGen) gBase(b *BaseElem) { u.p.printf("\nbts, err = msgp.ReadExtensionBytes(bts, %s)", lowered) case IDENT: u.p.printf("\nbts, err = %s.UnmarshalMsg(bts)", lowered) + case String: + if b.common.AllocBound() != "" { + sz := randIdent() + u.p.printf("\nvar %s int", sz) + u.p.printf("\n%s, err = msgp.ReadBytesBytesHeader(bts)", sz) + u.p.wrapErrCheck(u.ctx.ArgsStr()) + u.p.printf("\nif %s > %s {", sz, b.common.AllocBound()) + u.p.printf("\nerr = msgp.ErrOverflow(uint64(%s), uint64(%s))", sz, b.common.AllocBound()) + u.p.printf("\nreturn") + u.p.printf("\n}") + } + u.p.printf("\n%s, bts, err = msgp.ReadStringBytes(bts)", refname) default: u.p.printf("\n%s, bts, err = msgp.Read%sBytes(bts)", refname, b.BaseName()) } diff --git a/main.go b/main.go index 5c49d43..8a157d2 100644 --- a/main.go +++ b/main.go @@ -59,7 +59,7 @@ func main() { var mode gen.Method if *marshal { - mode |= (gen.Marshal | gen.Unmarshal | gen.Size | gen.IsZero) + mode |= (gen.Marshal | gen.Unmarshal | gen.Size | gen.IsZero | gen.MaxSize) } if *tests { mode |= gen.Test diff --git a/msgp/write.go b/msgp/write.go index d8f355d..da23cf7 100644 --- a/msgp/write.go +++ b/msgp/write.go @@ -11,6 +11,16 @@ type Sizer interface { Msgsize() int } +// MaxSizer is an interface implemented +// by types that can determine their max +// when implemented. +// This interface is optional, but +// implementations may use this as a way to limit +// number of bytes read during deserialization +type MaxSizer interface { + MaxSize() int +} + // Require ensures that cap(old)-len(old) >= extra. // It might be that this is impossible because len(old)+extra // overflows int. If so, Require will not grow the slice, diff --git a/parse/directives.go b/parse/directives.go index 3c5f558..75d0011 100644 --- a/parse/directives.go +++ b/parse/directives.go @@ -27,7 +27,7 @@ var directives = map[string]directive{ "tuple": astuple, "sort": sortintf, "allocbound": allocbound, - // _postunmarshalcheck is used to add callbacks to the end of unmarshling that are tied to a specific Element. + // _postunmarshalcheck is used to add callbacks to the end of un-marshalling that are tied to a specific Element. _postunmarshalcheck: postunmarshalcheck, } diff --git a/parse/getast.go b/parse/getast.go index cd02d28..d0ff380 100644 --- a/parse/getast.go +++ b/parse/getast.go @@ -167,10 +167,10 @@ func (f *FileSet) applyDirectives() { // into just one level of indirection. // In other words, if we have: // -// type A uint64 -// type B A -// type C B -// type D C +// type A uint64 +// type B A +// type C B +// type D C // // ... then we want to end up // figuring out that D is just a uint64. @@ -252,6 +252,8 @@ func strToMethod(s string) gen.Method { return gen.Marshal case "unmarshal": return gen.Unmarshal + case "maxsize": + return gen.MaxSize default: return 0 } @@ -407,6 +409,7 @@ func (fs *FileSet) getField(importPrefix string, f *ast.Field) []gen.StructField var extension, flatten bool var allocbound string var allocbounds []string + var maxtotalbytes string // always flatten embedded structs flatten = true @@ -423,6 +426,9 @@ func (fs *FileSet) getField(importPrefix string, f *ast.Field) []gen.StructField if strings.HasPrefix(tag, "allocbound=") { allocbounds = append(allocbounds, strings.Split(tag, "=")[1]) } + if strings.HasPrefix(tag, "maxtotalbytes=") { + maxtotalbytes = strings.Split(tag, "=")[1] + } } // ignore "-" fields if tags[0] == "-" { @@ -472,6 +478,31 @@ func (fs *FileSet) getField(importPrefix string, f *ast.Field) []gen.StructField } return sf } + + // resolve local package type aliases that referenced in this package structs + resolveAlias := func(el gen.Elem) { + if a, ok := fs.Aliases[el.TypeName()]; ok { + if b, ok := a.(*ast.SelectorExpr); ok { + if c, ok := b.X.(*ast.Ident); ok { + el.Alias(c.Name + "." + b.Sel.Name) + } + } else if b, ok := a.(*ast.Ident); ok { + el.Alias(b.Name) + } + } + } + // resolve field alias type + resolveAlias(ex) + // resolve field map type that have alias type key or value + if m, ok := ex.(*gen.Map); ok { + resolveAlias(m.Key) + resolveAlias(m.Value) + } + // resolve field slice type that have alias type element + if m, ok := ex.(*gen.Slice); ok { + resolveAlias(m.Els) + } + sf[0].FieldElem = ex if sf[0].FieldTag == "" { sf[0].FieldTag = sf[0].FieldName @@ -480,6 +511,7 @@ func (fs *FileSet) getField(importPrefix string, f *ast.Field) []gen.StructField sf[0].FieldTagParts = []string{sf[0].FieldName} } sf[0].FieldElem.SetAllocBound(allocbound) + sf[0].FieldElem.SetMaxTotalBytes(maxtotalbytes) // validate extension if extension { @@ -543,9 +575,9 @@ func (fs *FileSet) getFieldsFromEmbeddedStruct(importPrefix string, f ast.Expr) // // so, for a struct like // -// type A struct { -// io.Writer -// } +// type A struct { +// io.Writer +// } // // we want "Writer" func embedded(f ast.Expr) string {