Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add UnmarshallValidateMsg to generated functions #25

Merged
merged 10 commits into from
Aug 11, 2023
22 changes: 22 additions & 0 deletions gen/elem.go
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,7 @@ func (c *common) SetVarname(s string) { c.vname = s }
func (c *common) Varname() string { return c.vname }
func (c *common) Alias(typ string) { c.alias = typ }
func (c *common) SortInterface() string { return "" }
func (c *common) LessFunction() string { return "" }
func (c *common) SetAllocBound(s string) { c.allocbound = s }
func (c *common) AllocBound() string { return c.allocbound }
func (c *common) SetMaxTotalBytes(s string) { c.maxtotalbytes = s }
Expand Down Expand Up @@ -229,6 +230,9 @@ type Elem interface {
// slice of this type.
SortInterface() string

// LessFunction returns the Less implementation for values of this type.
LessFunction() string

// Comparable returns whether the type is comparable, along the lines
// of the Go spec (https://golang.org/ref/spec#Comparison_operators),
// used to determine whether we can compare to a zero value to determine
Expand Down Expand Up @@ -909,6 +913,14 @@ func (s *BaseElem) SortInterface() string {
return ""
}

func (s *BaseElem) LessFunction() string {
lessThan, ok := lessFunctions[s.TypeName()]
if ok {
return lessThan
}
return ""
}

func (k Primitive) String() string {
switch k {
case String:
Expand Down Expand Up @@ -990,3 +1002,13 @@ func SetSortInterface(sorttype string, sortintf string) {

sortInterface[sorttype] = sortintf
}

var lessFunctions map[string]string

func SetLessFunction(sorttype string, lessfn string) {
if lessFunctions == nil {
lessFunctions = make(map[string]string)
}

lessFunctions[sorttype] = lessfn
}
45 changes: 44 additions & 1 deletion gen/unmarshal.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,12 +60,17 @@ func (u *unmarshalGen) Execute(p Elem) ([]string, error) {
u.p.printf("\n return ((*(%s))(%s)).UnmarshalMsg(bts)", baseType, c)
u.p.printf("\n}")

u.p.printf("\nfunc (%s %s) UnmarshalValidateMsg(bts []byte) ([]byte, error) {", c, methodRecv)
u.p.printf("\n return ((*(%s))(%s)).UnmarshalValidateMsg(bts)", baseType, c)
u.p.printf("\n}")

u.p.printf("\nfunc (_ %[2]s) CanUnmarshalMsg(%[1]s interface{}) bool {", c, methodRecv)
u.p.printf("\n _, ok := (%s).(%s)", c, methodRecv)
u.p.printf("\n return ok")
u.p.printf("\n}")

u.topics.Add(methodRecv, "UnmarshalMsg")
u.topics.Add(methodRecv, "UnmarshalValidateMsg")
u.topics.Add(methodRecv, "CanUnmarshalMsg")

return u.msgs, u.p.err
Expand All @@ -75,7 +80,7 @@ func (u *unmarshalGen) Execute(p Elem) ([]string, error) {
c := p.Varname()
methodRecv := methodReceiver(p)

u.p.printf("\nfunc (%s %s) UnmarshalMsg(bts []byte) (o []byte, err error) {", c, methodRecv)
u.p.printf("\nfunc (%s %s) unmarshalMsg(bts []byte, validate bool) (o []byte, err error) {", c, methodRecv)
next(u, p)
u.p.print("\no = bts")

Expand All @@ -91,12 +96,21 @@ func (u *unmarshalGen) Execute(p Elem) ([]string, error) {
}
u.p.nakedReturn()

u.p.printf("\nfunc (%s %s) UnmarshalMsg(bts []byte) (o []byte, err error) {", c, methodRecv)
u.p.printf("\n return %s.unmarshalMsg(bts, false)", c)
u.p.printf("\n}")

u.p.printf("\nfunc (%s %s) UnmarshalValidateMsg(bts []byte) (o []byte, err error) {", c, methodRecv)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

add a code that errs if LessFunction not set for this type?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LessFunction missing is a compile time issue not runtime. Should I just emit a panic inside the code if LessFn is expected and missing?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah I see yeah I can just return a nil slice and error since I already have error interface.

u.p.printf("\n return %s.unmarshalMsg(bts, true)", c)
u.p.printf("\n}")

u.p.printf("\nfunc (_ %[2]s) CanUnmarshalMsg(%[1]s interface{}) bool {", c, methodRecv)
u.p.printf("\n _, ok := (%s).(%s)", c, methodRecv)
u.p.printf("\n return ok")
u.p.printf("\n}")

u.topics.Add(methodRecv, "UnmarshalMsg")
u.topics.Add(methodRecv, "UnmarshalValidateMsg")
u.topics.Add(methodRecv, "CanUnmarshalMsg")

return u.msgs, u.p.err
Expand Down Expand Up @@ -144,8 +158,13 @@ func (u *unmarshalGen) mapstruct(s *Struct) {
u.needsField()
sz := randIdent()
isnil := randIdent()
last := randIdent()
lastIsSet := randIdent()
u.p.declare(sz, "int")
u.p.declare(last, "string")
u.p.declare(lastIsSet, "bool")
u.p.declare(isnil, "bool")
u.p.printf("\n_=%s;\n_=%s", last, lastIsSet) // we might not use these for empty structs

// go-codec compat: decode an array as sequential elements from this struct,
// in the order they are defined in the Go type (as opposed to canonical
Expand All @@ -155,6 +174,11 @@ func (u *unmarshalGen) mapstruct(s *Struct) {

u.assignAndCheck(sz, isnil, arrayHeader)

u.p.print("\nif validate {") // map encoded as array => non canonical
u.p.print("\nerr = &msgp.ErrNonCanonical{}")
u.p.print("\nreturn")
u.p.print("\n}")

u.ctx.PushString("struct-from-array")
for i := range s.Fields {
if !ast.IsExported(s.Fields[i].FieldName) {
Expand Down Expand Up @@ -195,13 +219,19 @@ func (u *unmarshalGen) mapstruct(s *Struct) {
return
}
u.p.printf("\ncase \"%s\":", s.Fields[i].FieldTag)
u.p.printf("\nif validate && %s && \"%s\" < %s {", lastIsSet, s.Fields[i].FieldTag, last)
u.p.print("\nerr = &msgp.ErrNonCanonical{}")
u.p.printf("\nreturn")
u.p.print("\n}")
u.ctx.PushString(s.Fields[i].FieldName)
next(u, s.Fields[i].FieldElem)
u.ctx.Pop()
u.p.printf("\n%s = \"%s\"", last, s.Fields[i].FieldTag)
}
u.p.print("\ndefault:\nerr = msgp.ErrNoField(string(field))")
u.p.wrapErrCheck(u.ctx.ArgsStr())
u.p.print("\n}") // close switch
u.p.printf("\n%s = true", lastIsSet)
u.p.print("\n}") // close for loop
u.p.print("\n}") // close else statement for array decode
}
Expand Down Expand Up @@ -325,9 +355,22 @@ func (u *unmarshalGen) gMap(m *Map) {
u.msgs = append(u.msgs, resizemsgs...)

// loop and get key,value
last := randIdent()
lastSet := randIdent()
u.p.printf("\nvar %s %s; _ = %s", last, m.Key.TypeName(), last) // we might not use the sort if it's not defined
algorandskiy marked this conversation as resolved.
Show resolved Hide resolved
u.p.declare(lastSet, "bool")
u.p.printf("\n_ = %s", lastSet) // we might not use the flag
u.p.printf("\nfor %s > 0 {", sz)
u.p.printf("\nvar %s %s; var %s %s; %s--", m.Keyidx, m.Key.TypeName(), m.Validx, m.Value.TypeName(), sz)
next(u, m.Key)
if m.Key.LessFunction() != "" {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe do not require less func for all types? since numbers, strings, arrays are comparable, byte slices can be generated with bytes.Compare if m.Key.TypeName() == []bytes - i.e. to make this feature more usable by requiring less work by a user.

Btw, it looks like I can call UnmarshalValidateMsg on a type without providing LessFunction, this code would not be generated and I might get no err even there is an ordering issue.

I guess we need to go one of the roads:

  1. allow UnmarshalValidateMsg only for types with LessFunction provided.
  2. have UnmarshalValidateMsg to generate own code by propagating validate=true/false flag down in generator methods: UnmarshalMsg would generate regular stuff as now without if validate overhead, and UnmarshalValidateMsg would generate code with if validate + err if less function not provided but needed.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe do not require less func for all types? since numbers, strings, arrays are comparable, byte slices can be generated with bytes.Compare if m.Key.TypeName() == []bytes - i.e. to make this feature more usable by requiring less work by a user.

I agree -- I can make it smarter

I guess we need to go one of the roads:

  1. allow UnmarshalValidateMsg only for types with LessFunction provided.
  2. have UnmarshalValidateMsg to generate own code by propagating validate=true/false flag down in generator methods: UnmarshalMsg would generate regular stuff as now without if validate overhead, and UnmarshalValidateMsg would generate code with if validate + err if less function not provided but needed.

These two don't seem mutually exclusive to me. The first one is a correctness issue and the second seems like a performance concern.

I could fix the issue by including the if validate statement as separate and then inside that block do another if check if we have the LessFn and return an error uncoditionally if LessFn is missing.

u.p.printf("\nif validate && %s && %s(%s, %s) {", lastSet, m.Key.LessFunction(), m.Keyidx, last)
u.p.printf("\nerr = &msgp.ErrNonCanonical{}")
u.p.printf("\nreturn")
u.p.printf("\n}")
}
u.p.printf("\n%s=%s", last, m.Keyidx)
u.p.printf("\n%s=true", lastSet)
u.ctx.PushVar(m.Keyidx)
next(u, m.Value)
u.ctx.Pop()
Expand Down
14 changes: 13 additions & 1 deletion msgp/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,6 @@ func Resumable(e error) bool {
//
// ErrShortBytes is not wrapped with any context due to backward compatibility
// issues with the public API.
//
func WrapError(err error, ctx ...interface{}) error {
switch e := err.(type) {
case errShort:
Expand Down Expand Up @@ -344,3 +343,16 @@ func (e *ErrUnsupportedType) withContext(ctx string) error {
o.ctx = addCtx(o.ctx, ctx)
return &o
}

// ErrNonCanonical is returned
// when unmarshaller detects that
// the message is not canonically encoded (pre-sorted)
type ErrNonCanonical struct{}

// Error implements error
func (e *ErrNonCanonical) Error() string {
return fmt.Sprintf("msgp: non-canonical encoding detected")
}

// Resumable returns false for errNonCanonical
func (e *ErrNonCanonical) Resumable() bool { return false }
8 changes: 8 additions & 0 deletions msgp/read.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,3 +80,11 @@ type Unmarshaler interface {
UnmarshalMsg([]byte) ([]byte, error)
CanUnmarshalMsg(o interface{}) bool
}

// UnmarshalerValidator extends the Unmarshaler interface
// and requires an additional UnmarshalValidateMsg method
// that checks whether the encoded bytes follow canonical encoding rules
type UnmarshalerValidator interface {
Unmarshaler
UnmarshalValidateMsg([]byte) ([]byte, error)
}
7 changes: 5 additions & 2 deletions parse/directives.go
Original file line number Diff line number Diff line change
Expand Up @@ -164,15 +164,18 @@ func astuple(text []string, f *FileSet) error {
return nil
}

//msgp:sort {Type} {SortInterface}
//msgp:sort {Type} {SortInterface} {LessFunction}
func sortintf(text []string, f *FileSet) error {
if len(text) != 3 {
if len(text) != 4 {
return nil
}
sortType := strings.TrimSpace(text[1])
sortIntf := strings.TrimSpace(text[2])
lessFn := strings.TrimSpace(text[3])
gen.SetSortInterface(sortType, sortIntf)
infof("sorting %s using %s\n", sortType, sortIntf)
gen.SetLessFunction(sortType, lessFn)
infof("less fn %s using %s\n", sortType, lessFn)
return nil
}

Expand Down