diff --git a/gen/elem.go b/gen/elem.go index 707b56b..68a3356 100644 --- a/gen/elem.go +++ b/gen/elem.go @@ -165,6 +165,7 @@ func (c *common) SetVarname(s string) { c.vname = s } func (c *common) Varname() string { return c.vname } func (c *common) Alias(typ string) { c.alias = typ } func (c *common) SortInterface() string { return "" } +func (c *common) LessFunction() string { return "" } func (c *common) SetAllocBound(s string) { c.allocbound = s } func (c *common) AllocBound() string { return c.allocbound } func (c *common) SetMaxTotalBytes(s string) { c.maxtotalbytes = s } @@ -229,6 +230,9 @@ type Elem interface { // slice of this type. SortInterface() string + // LessFunction returns the name of the less function registered for this type, or "" if there is none. + LessFunction() string + // Comparable returns whether the type is comparable, along the lines // of the Go spec (https://golang.org/ref/spec#Comparison_operators), // used to determine whether we can compare to a zero value to determine @@ -909,6 +913,14 @@ func (s *BaseElem) SortInterface() string { return "" } +func (s *BaseElem) LessFunction() string { + lessThan, ok := lessFunctions[s.TypeName()] + if ok { + return lessThan + } + return "" +} + func (k Primitive) String() string { switch k { case String: @@ -990,3 +1002,13 @@ func SetSortInterface(sorttype string, sortintf string) { sortInterface[sorttype] = sortintf } + +var lessFunctions map[string]string // type name -> less function name; populated by SetLessFunction + +func SetLessFunction(sorttype string, lessfn string) { + if lessFunctions == nil { + lessFunctions = make(map[string]string) + } + + lessFunctions[sorttype] = lessfn +} diff --git a/gen/unmarshal.go b/gen/unmarshal.go index 4ad20f4..b86b2ca 100644 --- a/gen/unmarshal.go +++ b/gen/unmarshal.go @@ -60,12 +60,17 @@ func (u *unmarshalGen) Execute(p Elem) ([]string, error) { u.p.printf("\n return ((*(%s))(%s)).UnmarshalMsg(bts)", baseType, c) u.p.printf("\n}") + u.p.printf("\nfunc (%s %s) UnmarshalValidateMsg(bts []byte) ([]byte, error) {", c, methodRecv) + u.p.printf("\n 
return ((*(%s))(%s)).UnmarshalValidateMsg(bts)", baseType, c) + u.p.printf("\n}") + u.p.printf("\nfunc (_ %[2]s) CanUnmarshalMsg(%[1]s interface{}) bool {", c, methodRecv) u.p.printf("\n _, ok := (%s).(%s)", c, methodRecv) u.p.printf("\n return ok") u.p.printf("\n}") u.topics.Add(methodRecv, "UnmarshalMsg") + u.topics.Add(methodRecv, "UnmarshalValidateMsg") u.topics.Add(methodRecv, "CanUnmarshalMsg") return u.msgs, u.p.err @@ -75,7 +80,7 @@ func (u *unmarshalGen) Execute(p Elem) ([]string, error) { c := p.Varname() methodRecv := methodReceiver(p) - u.p.printf("\nfunc (%s %s) UnmarshalMsg(bts []byte) (o []byte, err error) {", c, methodRecv) + u.p.printf("\nfunc (%s %s) unmarshalMsg(bts []byte, validate bool) (o []byte, err error) {", c, methodRecv) next(u, p) u.p.print("\no = bts") @@ -91,12 +96,21 @@ func (u *unmarshalGen) Execute(p Elem) ([]string, error) { } u.p.nakedReturn() + u.p.printf("\nfunc (%s %s) UnmarshalMsg(bts []byte) (o []byte, err error) {", c, methodRecv) + u.p.printf("\n return %s.unmarshalMsg(bts, false)", c) + u.p.printf("\n}") + + u.p.printf("\nfunc (%s %s) UnmarshalValidateMsg(bts []byte) (o []byte, err error) {", c, methodRecv) + u.p.printf("\n return %s.unmarshalMsg(bts, true)", c) + u.p.printf("\n}") + u.p.printf("\nfunc (_ %[2]s) CanUnmarshalMsg(%[1]s interface{}) bool {", c, methodRecv) u.p.printf("\n _, ok := (%s).(%s)", c, methodRecv) u.p.printf("\n return ok") u.p.printf("\n}") u.topics.Add(methodRecv, "UnmarshalMsg") + u.topics.Add(methodRecv, "UnmarshalValidateMsg") u.topics.Add(methodRecv, "CanUnmarshalMsg") return u.msgs, u.p.err @@ -144,8 +158,13 @@ func (u *unmarshalGen) mapstruct(s *Struct) { u.needsField() sz := randIdent() isnil := randIdent() + last := randIdent() // generated variable: tag of the most recently decoded field + lastIsSet := randIdent() // generated variable: set to true after each processed map entry u.p.declare(sz, "int") + u.p.declare(last, "string") + u.p.declare(lastIsSet, "bool") u.p.declare(isnil, "bool") + u.p.printf("\n_=%s;\n_=%s", last, lastIsSet) // we might not use these for empty structs // go-codec compat: decode an array 
as sequential elements from this struct, // in the order they are defined in the Go type (as opposed to canonical @@ -155,6 +174,11 @@ func (u *unmarshalGen) mapstruct(s *Struct) { u.assignAndCheck(sz, isnil, arrayHeader) + u.p.print("\nif validate {") // map encoded as array => non canonical + u.p.print("\nerr = &msgp.ErrNonCanonical{}") + u.p.print("\nreturn") + u.p.print("\n}") + u.ctx.PushString("struct-from-array") for i := range s.Fields { if !ast.IsExported(s.Fields[i].FieldName) { @@ -195,13 +219,19 @@ func (u *unmarshalGen) mapstruct(s *Struct) { return } u.p.printf("\ncase \"%s\":", s.Fields[i].FieldTag) + u.p.printf("\nif validate && %s && \"%s\" <= %s {", lastIsSet, s.Fields[i].FieldTag, last) // <= (not <) so a repeated key is rejected too + u.p.print("\nerr = &msgp.ErrNonCanonical{}") + u.p.print("\nreturn") + u.p.print("\n}") u.ctx.PushString(s.Fields[i].FieldName) next(u, s.Fields[i].FieldElem) u.ctx.Pop() + u.p.printf("\n%s = \"%s\"", last, s.Fields[i].FieldTag) } u.p.print("\ndefault:\nerr = msgp.ErrNoField(string(field))") u.p.wrapErrCheck(u.ctx.ArgsStr()) u.p.print("\n}") // close switch + u.p.printf("\n%s = true", lastIsSet) u.p.print("\n}") // close for loop u.p.print("\n}") // close else statement for array decode } @@ -325,9 +355,27 @@ func (u *unmarshalGen) gMap(m *Map) { u.msgs = append(u.msgs, resizemsgs...) 
// loop and get key,value + last := randIdent() + lastSet := randIdent() + u.p.printf("\nvar %s %s; _ = %s", last, m.Key.TypeName(), last) // we might not use it if no less function is defined + u.p.declare(lastSet, "bool") + u.p.printf("\n_ = %s", lastSet) // we might not use the flag u.p.printf("\nfor %s > 0 {", sz) u.p.printf("\nvar %s %s; var %s %s; %s--", m.Keyidx, m.Key.TypeName(), m.Validx, m.Value.TypeName(), sz) next(u, m.Key) + u.p.printf("\nif validate {") + if m.Key.LessFunction() != "" { + u.p.printf("\nif %s && !%s(%s, %s) {", lastSet, m.Key.LessFunction(), last, m.Keyidx) // require last < key strictly; equal (duplicate) keys are non-canonical + u.p.printf("\nerr = &msgp.ErrNonCanonical{}") + u.p.printf("\nreturn") + u.p.printf("\n}") + } else { + u.p.printf("\nerr = &msgp.ErrMissingLessFn{}") + u.p.printf("\nreturn") + } + u.p.printf("\n}") // close if validate block + u.p.printf("\n%s=%s", last, m.Keyidx) + u.p.printf("\n%s=true", lastSet) u.ctx.PushVar(m.Keyidx) next(u, m.Value) u.ctx.Pop() diff --git a/msgp/defs.go b/msgp/defs.go index b1188b0..7660d59 100644 --- a/msgp/defs.go +++ b/msgp/defs.go @@ -5,16 +5,19 @@ // generator implement the Marshaler/Unmarshaler and Encodable/Decodable interfaces. // // This package defines four "families" of functions: -// - AppendXxxx() appends an object to a []byte in MessagePack encoding. -// - ReadXxxxBytes() reads an object from a []byte and returns the remaining bytes. -// - (*Writer).WriteXxxx() writes an object to the buffered *Writer type. -// - (*Reader).ReadXxxx() reads an object from a buffered *Reader type. +// - AppendXxxx() appends an object to a []byte in MessagePack encoding. +// - ReadXxxxBytes() reads an object from a []byte and returns the remaining bytes. +// - (*Writer).WriteXxxx() writes an object to the buffered *Writer type. +// - (*Reader).ReadXxxx() reads an object from a buffered *Reader type. 
// // Once a type has satisfied the `Encodable` and `Decodable` interfaces, // it can be written and read from arbitrary `io.Writer`s and `io.Reader`s using -// msgp.Encode(io.Writer, msgp.Encodable) +// +// msgp.Encode(io.Writer, msgp.Encodable) +// // and -// msgp.Decode(io.Reader, msgp.Decodable) +// +// msgp.Decode(io.Reader, msgp.Decodable) // // There are also methods for converting MessagePack to JSON without // an explicit de-serialization step. @@ -23,6 +26,8 @@ // the wiki at http://github.com/tinylib/msgp package msgp +import "bytes" + const last4 = 0x0f const first4 = 0xf0 const last5 = 0x1f @@ -140,3 +145,62 @@ const ( mmap16 uint8 = 0xde mmap32 uint8 = 0xdf ) + +// The functions below are exported less-than comparators for the built-in types. The msgp:sort +// directive uses them as defaults when it is given no explicit less function. + +func IntLess(a, b int) bool { + return a < b +} + +func Int8Less(a, b int8) bool { + return a < b +} + +func Int16Less(a, b int16) bool { + return a < b +} + +func Int32Less(a, b int32) bool { + return a < b +} + +func Int64Less(a, b int64) bool { + return a < b +} + +func UintLess(a, b uint) bool { + return a < b +} + +func Uint8Less(a, b uint8) bool { + return a < b +} + +func Uint16Less(a, b uint16) bool { + return a < b +} + +func Uint32Less(a, b uint32) bool { + return a < b +} + +func Uint64Less(a, b uint64) bool { + return a < b +} + +func Float32Less(a, b float32) bool { + return a < b // false if either operand is NaN +} + +func Float64Less(a, b float64) bool { + return a < b // false if either operand is NaN +} + +func BytesLess(a, b []byte) bool { + return bytes.Compare(a, b) < 0 // lexicographic byte-wise order +} + +func StringLess(a, b string) bool { + return a < b +} diff --git a/msgp/errors.go b/msgp/errors.go index 39d9286..da63cf2 100644 --- a/msgp/errors.go +++ b/msgp/errors.go @@ -81,7 +81,6 @@ func Resumable(e error) bool { // // ErrShortBytes is not wrapped with any context due to backward compatibility // issues with the public API. 
-// func WrapError(err error, ctx ...interface{}) error { switch e := err.(type) { case errShort: @@ -344,3 +343,29 @@ func (e *ErrUnsupportedType) withContext(ctx string) error { o.ctx = addCtx(o.ctx, ctx) return &o } + +// ErrNonCanonical is returned +// when unmarshaller detects that +// the message is not canonically encoded (pre-sorted) +type ErrNonCanonical struct{} + +// Error implements error +func (e *ErrNonCanonical) Error() string { + return "msgp: non-canonical encoding detected" +} + +// Resumable returns false for ErrNonCanonical +func (e *ErrNonCanonical) Resumable() bool { return false } + +// ErrMissingLessFn is returned +// when canonical validation is requested for a map +// whose key type has no less function registered +type ErrMissingLessFn struct{} + +// Error implements error +func (e *ErrMissingLessFn) Error() string { + return "msgp: can't validate canonicity: missing LessFn" +} + +// Resumable returns false for ErrMissingLessFn +func (e *ErrMissingLessFn) Resumable() bool { return false } diff --git a/msgp/read.go b/msgp/read.go index 1eed864..b99aff8 100644 --- a/msgp/read.go +++ b/msgp/read.go @@ -80,3 +80,11 @@ type Unmarshaler interface { UnmarshalMsg([]byte) ([]byte, error) CanUnmarshalMsg(o interface{}) bool } + +// UnmarshalerValidator extends the Unmarshaler interface +// and requires an additional UnmarshalValidateMsg method +// that checks whether the encoded bytes follow canonical encoding rules +type UnmarshalerValidator interface { + Unmarshaler + UnmarshalValidateMsg([]byte) ([]byte, error) +} diff --git a/parse/directives.go b/parse/directives.go index 75d0011..1ce45b2 100644 --- a/parse/directives.go +++ b/parse/directives.go @@ -31,12 +31,31 @@ var directives = map[string]directive{ _postunmarshalcheck: postunmarshalcheck, } +// map of base types with predefined LessFunctions used in the sort directive +var lessFns = map[string]string{ + "int": "msgp.IntLess", + "int8": "msgp.Int8Less", + "int16": 
"msgp.Int16Less", + "int32": "msgp.Int32Less", + "int64": "msgp.Int64Less", + "uint": "msgp.UintLess", + "uint8": "msgp.Uint8Less", + "uint16": "msgp.Uint16Less", + "uint32": "msgp.Uint32Less", + "uint64": "msgp.Uint64Less", + "float32": "msgp.Float32Less", + "float64": "msgp.Float64Less", + "string": "msgp.StringLess", + "[]byte": "msgp.BytesLess", +} + const _postunmarshalcheck = "postunmarshalcheck" var errNotEnoughArguments = errors.New("postunmarshalcheck did not receive enough arguments. expected at least 3") -//msgp:postunmarshalcheck {Type} {funcName} {funcName} ... // the functions should have no params, and output zero. +// +//msgp:postunmarshalcheck {Type} {funcName} {funcName} ... func postunmarshalcheck(text []string, f *FileSet) error { if len(text) < 3 { return errNotEnoughArguments @@ -164,15 +183,25 @@ func astuple(text []string, f *FileSet) error { return nil } -//msgp:sort {Type} {SortInterface} +//msgp:sort {Type} {SortInterface} [{LessFunction}] func sortintf(text []string, f *FileSet) error { - if len(text) != 3 { + if len(text) != 4 && len(text) != 3 { return nil } sortType := strings.TrimSpace(text[1]) sortIntf := strings.TrimSpace(text[2]) gen.SetSortInterface(sortType, sortIntf) infof("sorting %s using %s\n", sortType, sortIntf) + var lessFn string + if len(text) == 4 { + lessFn = strings.TrimSpace(text[3]) + } else if fn, ok := lessFns[sortType]; ok { + lessFn = fn + } else { + return fmt.Errorf("no default less function for %s and no function is provided", sortType) // report through the error return instead of panicking + } + gen.SetLessFunction(sortType, lessFn) + infof("less fn %s using %s\n", sortType, lessFn) return nil }