From 58c677e2099b3f86bc8ddb1c9390345d87b14698 Mon Sep 17 00:00:00 2001 From: Phil Kedy Date: Wed, 19 Oct 2022 14:07:49 -0400 Subject: [PATCH] Adding Read/Write Time functions --- decoder.go | 139 +++++++++++++++++++++++++++++++++++++++++++++++++- encoder.go | 87 +++++++++++++++++++++++++++++++ interfaces.go | 8 +++ sizer.go | 57 ++++++++++++++++++++- 4 files changed, 288 insertions(+), 3 deletions(-) diff --git a/decoder.go b/decoder.go index 71e7bd6..e70a4b2 100644 --- a/decoder.go +++ b/decoder.go @@ -1,8 +1,10 @@ package msgpack import ( + "encoding/binary" "math" "strconv" + "time" ) type Decoder struct { @@ -345,6 +347,80 @@ func (d *Decoder) ReadNillableFloat64() (*float64, error) { return &val, err } +func (d *Decoder) ReadTime() (time.Time, error) { + prefix, err := d.reader.PeekUint8() + if err != nil { + return time.Time{}, err + } + + if isString(prefix) { + str, err := d.ReadString() + if err != nil { + return time.Time{}, err + } + return time.Parse(time.RFC3339Nano, str) + } + + d.reader.Discard(1) + extID, extLen, err := d.extHeader(prefix) + if err != nil { + return time.Time{}, err + } + + // NodeJS seems to use extID 13. + if extID != -1 && extID != 13 { + return time.Time{}, ReadError{"msgpack: invalid time ext id=" + strconv.FormatUint(uint64(extID), 10)} + } + + tm, err := d.decodeTime(extLen) + if err != nil { + return tm, err + } + + if tm.IsZero() { + // Zero time does not have timezone information. + return tm.UTC(), nil + } + + return tm, nil +} + +func (d *Decoder) ReadNillableTime() (*time.Time, error) { + isNil, err := d.IsNextNil() + if isNil || err != nil { + return nil, err + } + val, err := d.ReadTime() + if err != nil { + return nil, err + } + return &val, err +} + +func (d *Decoder) decodeTime(extLen uint32) (time.Time, error) { + b, err := d.reader.GetBytes(extLen) + if err != nil { + return time.Time{}, err + } + + switch len(b) { + case 4: + sec := binary.BigEndian.Uint32(b) + return time.Unix(int64(sec), 0), nil + case 8: + sec := binary.BigEndian.Uint64(b) + nsec := int64(sec >> 34) + sec &= 0x00000003ffffffff + return time.Unix(int64(sec), nsec), nil + case 12: + nsec := binary.BigEndian.Uint32(b) + sec := binary.BigEndian.Uint64(b[4:]) + return time.Unix(int64(sec), int64(nsec)), nil + default: + return time.Time{}, ReadError{"msgpack: invalid time ext len=" + strconv.FormatUint(uint64(extLen), 10)} + } +} + func (d *Decoder) ReadString() (string, error) { strLen, err := d.readStringLength() return d.readString(strLen, err) @@ -378,13 +454,14 @@ func (d *Decoder) readStringLength() (uint32, error) { case FormatString8: v, err := d.reader.GetUint8() return uint32(v), err - case FormatString16: + case FormatString16, FormatArray16: v, err := d.reader.GetUint16() return uint32(v), err - case FormatString32: + case FormatString32, FormatArray32: v, err := d.reader.GetUint32() return v, err } + return 0, ReadError{"bad prefix for string length"} } @@ -744,6 +821,54 @@ func (d *Decoder) readMap(m map[any]any, length uint32) error { return nil } +func (d *Decoder) extHeader(c byte) (int8, uint32, error) { + extLen, err := d.parseExtLen(c) + if err != nil { + return 0, 0, err + } + + extID, err := d.readCode() + if err != nil { + return 0, 0, err + } + + return int8(extID), extLen, nil +} + +func (d *Decoder) readCode() (byte, error) { + c, err := d.reader.GetUint8() + if err != nil { + return 0, err + } + return c, nil +} + +func (d *Decoder) parseExtLen(c byte) (uint32, error) { + switch c { + case FormatFixExt1: + return 1, nil + case FormatFixExt2: + return 2, nil + case FormatFixExt4: + return 4, nil + case FormatFixExt8: + return 8, nil + case FormatFixExt16: + return 16, nil + case FormatExt8: + n, err := d.ReadUint8() + return uint32(n), err + case FormatExt16: + n, err := d.ReadUint16() + return uint32(n), err + case FormatExt32: + n, err := d.ReadUint32() + return n, err + default: + return 0, ReadError{"msgpack: invalid code=" + strconv.FormatUint(uint64(c), 16) + " decoding ext len"} + } +} + func (d *Decoder) Err() error { return d.reader.Err() } @@ -797,6 +922,16 @@ func isFixedString(u byte) bool { return (u & 0xe0) == FormatFixString } +func isString(u byte) bool { + return isFixedString(u) || + u == FormatString8 || + u == FormatString16 || + u == FormatString32 || + isFixedArray(u) || + u == FormatArray16 || + u == FormatArray32 +} + type ReadError struct { message string } diff --git a/encoder.go b/encoder.go index 29fa6a0..772bcfa 100644 --- a/encoder.go +++ b/encoder.go @@ -1,7 +1,9 @@ package msgpack import ( + "encoding/binary" "math" + "time" ) type Encoder struct { @@ -215,6 +217,44 @@ func (e *Encoder) WriteNillableString(value *string) { } } +func (e *Encoder) WriteTime(tm time.Time) { + var timeBuf [12]byte + b := e.encodeTime(tm, timeBuf[:]) + e.encodeExtLen(len(b)) + e.reader.SetInt8(-1) + e.reader.SetBytes(b) +} + +func (e *Encoder) WriteNillableTime(value *time.Time) { + if value == nil { + e.WriteNil() + } else { + e.WriteTime(*value) + } +} + +func (e *Encoder) encodeTime(tm time.Time, timeBuf []byte) []byte { + secs := uint64(tm.Unix()) + if secs>>34 == 0 { + data := uint64(tm.Nanosecond())<<34 | secs + + if data&0xffffffff00000000 == 0 { + b := timeBuf[:4] + binary.BigEndian.PutUint32(b, uint32(data)) + return b + } + + b := timeBuf[:8] + binary.BigEndian.PutUint64(b, data) + return b + } + + b := timeBuf[:12] + binary.BigEndian.PutUint32(b, uint32(tm.Nanosecond())) + binary.BigEndian.PutUint64(b[4:], secs) + return b +} + func (e *Encoder) writeBinLength(length uint32) { if length <= math.MaxUint8 { e.reader.SetUint8(FormatBin8) @@ -477,6 +517,53 @@ func (e *Encoder) WriteAny(value any) { } } +func (e *Encoder) encodeExtLen(l int) error { + switch l { + case 1: + return e.reader.SetUint8(FormatFixExt1) + case 2: + return e.reader.SetUint8(FormatFixExt2) + case 4: + return e.reader.SetUint8(FormatFixExt4) + case 8: + return e.reader.SetUint8(FormatFixExt8) + case 16: + return e.reader.SetUint8(FormatFixExt16) + } + if l <= math.MaxUint8 { + return e.write1(FormatExt8, uint8(l)) + } + if l <= math.MaxUint16 { + return e.write2(FormatExt16, uint16(l)) + } + return e.write4(FormatExt32, uint32(l)) +} + +func (e *Encoder) write1(code byte, n uint8) error { + var buf [2]byte + buf[0] = code + buf[1] = n + return e.reader.SetBytes(buf[:]) +} + +func (e *Encoder) write2(code byte, n uint16) error { + var buf [3]byte + buf[0] = code + buf[1] = byte(n >> 8) + buf[2] = byte(n) + return e.reader.SetBytes(buf[:]) +} + +func (e *Encoder) write4(code byte, n uint32) error { + var buf [5]byte + buf[0] = code + buf[1] = byte(n >> 24) + buf[2] = byte(n >> 16) + buf[3] = byte(n >> 8) + buf[4] = byte(n) + return e.reader.SetBytes(buf[:]) +} + func (e *Encoder) Err() error { return e.reader.Err() } diff --git a/interfaces.go b/interfaces.go index 8cae3b5..21e75a2 100644 --- a/interfaces.go +++ b/interfaces.go @@ -1,5 +1,9 @@ package msgpack +import ( + "time" +) + // Reader is the interface for reading data from the MessagePack format. type Reader interface { IsNextNil() (bool, error) @@ -27,6 +31,8 @@ type Reader interface { ReadNillableFloat64() (*float64, error) ReadString() (string, error) ReadNillableString() (*string, error) + ReadTime() (time.Time, error) + ReadNillableTime() (*time.Time, error) ReadByteArray() ([]byte, error) ReadNillableByteArray() ([]byte, error) ReadArraySize() (uint32, error) @@ -63,6 +69,8 @@ type Writer interface { WriteNillableFloat64(value *float64) WriteString(value string) WriteNillableString(value *string) + WriteTime(value time.Time) + WriteNillableTime(value *time.Time) WriteByteArray(value []byte) WriteNillableByteArray(value []byte) WriteArraySize(length uint32) diff --git a/sizer.go b/sizer.go index 06a1253..c1a1dc9 100644 --- a/sizer.go +++ b/sizer.go @@ -1,6 +1,9 @@ package msgpack -import "math" +import ( + "math" + "time" +) type Sizer struct { length uint32 @@ -45,6 +48,50 @@ func (s *Sizer) writeStringLength(length uint32) { } } +func (s *Sizer) WriteTime(value time.Time) { + l := s.encodeTime(value) + s.encodeExtLen(l) + s.length += 1 + uint32(l) +} + +func (s *Sizer) encodeTime(tm time.Time) int { + secs := uint64(tm.Unix()) + if secs>>34 == 0 { + data := uint64(tm.Nanosecond())<<34 | secs + + if data&0xffffffff00000000 == 0 { + return 4 + } + + return 8 + } + + return 12 +} + +func (s *Sizer) encodeExtLen(l int) { + switch l { + case 1, 2, 4, 8, 16: + s.length++ + return + } + if l <= math.MaxUint8 { + s.length += 2 + } else if l <= math.MaxUint16 { + s.length += 3 + } else { + s.length += 5 + } +} + +func (s *Sizer) WriteNillableTime(value *time.Time) { + if value == nil { + s.WriteNil() + } else { + s.WriteTime(*value) + } +} + func (s *Sizer) WriteBool(value bool) { s.length++ } @@ -265,6 +312,8 @@ func (s *Sizer) WriteAny(value any) { s.WriteFloat64(v) case string: s.WriteString(v) + case time.Time: + s.WriteTime(v) case []byte: s.WriteByteArray(v) case []interface{}: @@ -279,6 +328,12 @@ func (s *Sizer) WriteAny(value any) { for _, v := range v { s.WriteString(v) } + case []time.Time: + size := uint32(len(v)) + s.WriteArraySize(size) + for _, v := range v { + s.WriteTime(v) + } case []bool: size := uint32(len(v)) s.WriteArraySize(size)