From 1750293d251c8e9dbae21a75c9e5e70fb21f4f09 Mon Sep 17 00:00:00 2001 From: David Murray Date: Mon, 10 Jun 2019 17:10:43 +1000 Subject: [PATCH 01/56] basic text writer implementation --- api.go | 77 ++++++++ decimal.go | 74 +++++++ decimal_test.go | 49 +++++ go.mod | 3 + text/utils.go | 159 +++++++++++++++ text/utils_test.go | 129 +++++++++++++ text/writer.go | 456 ++++++++++++++++++++++++++++++++++++++++++++ text/writer_test.go | 339 ++++++++++++++++++++++++++++++++ 8 files changed, 1286 insertions(+) create mode 100644 api.go create mode 100644 decimal.go create mode 100644 decimal_test.go create mode 100644 go.mod create mode 100644 text/utils.go create mode 100644 text/utils_test.go create mode 100644 text/writer.go create mode 100644 text/writer_test.go diff --git a/api.go b/api.go new file mode 100644 index 00000000..7c043045 --- /dev/null +++ b/api.go @@ -0,0 +1,77 @@ +package ion + +import ( + "math/big" + "time" +) + +// Type is the type of an Ion Value. +type Type uint8 + +const ( + // NullType is the type of the (unqualified) null value. + NullType Type = iota + // BoolType is the type of a boolean, true or false. + BoolType + // IntType is the type of a signed integer of arbitrary size. + IntType + // FloatType is the type of a 64-bit floating-point value. + FloatType + // DecimalType is the type of an arbitrary-precision decimal value. + DecimalType + // TimestampType is the type of a timestamp. + TimestampType + // StringType is the type of a Unicode string. + StringType + // SymbolType is the type of an interned string. + SymbolType + // BlobType is the type of a binary large object. + BlobType + // ClobType is the type of a character large object. + ClobType + // StructType is the type of a structure. + StructType + // ListType is the type of a list. + ListType + // SexpType is the type of an s-expression. + SexpType +) + +// A Writer writes Ion values to an output stream. +type Writer interface { + InStruct() bool + Err() error + + FieldName(val string) + TypeAnnotation(val string) + TypeAnnotations(vals ...string) + + BeginStruct() + EndStruct() + + BeginList() + EndList() + + BeginSexp() + EndSexp() + + WriteNull() + WriteNullWithType(t Type) + + WriteBool(val bool) + + WriteInt(val int64) + WriteBigInt(val *big.Int) + WriteFloat(val float64) + WriteDecimal(val *Decimal) + + WriteTimestamp(val time.Time) + + WriteSymbol(val string) + WriteString(val string) + + WriteBlob(val []byte) + WriteClob(val []byte) + + Finish() error +} diff --git a/decimal.go b/decimal.go new file mode 100644 index 00000000..d6c6c84f --- /dev/null +++ b/decimal.go @@ -0,0 +1,74 @@ +package ion + +import ( + "fmt" + "math/big" + "strings" +) + +// TODO: Precision. + +// Decimal is an arbitrary-precision decimal value. +type Decimal struct { + n *big.Int + scale int +} + +// NewDecimal creates a new (big-integer) decimal. +func NewDecimal(n *big.Int) *Decimal { + return NewDecimalWithScale(n, 0) +} + +// NewDecimalWithScale creates a new scaled decimal whose value is +// equal to n * 10^-scale. +func NewDecimalWithScale(n *big.Int, scale int) *Decimal { + return &Decimal{ + n: n, + scale: scale, + } +} + +// TODO: Maths. + +func (d *Decimal) String() string { + switch { + case d.scale == 0: + // Value is an unscaled integer. + return d.n.String() + "." + + case d.scale < 0: + // Value is a scaled integer, nnndsss. + return d.n.String() + "d" + fmt.Sprintf("%d", -d.scale) + + default: + // Value is a downscaled integer nn.nnd-ss + str := d.n.String() + idx := len(str) - d.scale + + prefix := 1 + if d.n.Sign() < 0 { + // Account for leading '-' + prefix++ + } + + if idx >= prefix { + // Put the decimal point in the middle. + return str[:idx] + "." + str[idx:] + } + + // Put the decimal point at the beginning and + // add a (negative) exponent. + b := strings.Builder{} + b.WriteString(str[:prefix]) + + if len(str) > prefix { + b.WriteString(".") + b.WriteString(str[prefix:]) + } + + b.WriteString("d") + b.WriteString(fmt.Sprintf("%d", idx-prefix)) + + return b.String() + } +} diff --git a/decimal_test.go b/decimal_test.go new file mode 100644 index 00000000..d2fb5b03 --- /dev/null +++ b/decimal_test.go @@ -0,0 +1,49 @@ +package ion + +import ( + "math/big" + "testing" +) + +func TestDecimalToString(t *testing.T) { + test := func(n int64, scale int, expected string) { + t.Run(expected, func(t *testing.T) { + d := Decimal{ + n: big.NewInt(n), + scale: scale, + } + actual := d.String() + if actual != expected { + t.Errorf("expected '%v', got '%v'", expected, actual) + } + }) + } + + test(0, 0, "0.") + test(0, -1, "0d1") + test(0, 1, "0d-1") + + test(1, 0, "1.") + test(1, -1, "1d1") + test(1, 1, "1d-1") + + test(-1, 0, "-1.") + test(-1, -1, "-1d1") + test(-1, 1, "-1d-1") + + test(123, 0, "123.") + test(-456, 0, "-456.") + + test(123, -5, "123d5") + test(-456, -5, "-456d5") + + test(123, 1, "12.3") + test(123, 2, "1.23") + test(123, 3, "1.23d-1") + test(123, 4, "1.23d-2") + + test(-456, 1, "-45.6") + test(-456, 2, "-4.56") + test(-456, 3, "-4.56d-1") + test(-456, 4, "-4.56d-2") +} diff --git a/go.mod b/go.mod new file mode 100644 index 00000000..1820885b --- /dev/null +++ b/go.mod @@ -0,0 +1,3 @@ +module github.com/fernomac/ion-go + +go 1.12 diff --git a/text/utils.go b/text/utils.go new file mode 100644 index 00000000..d53c9fca --- /dev/null +++ b/text/utils.go @@ -0,0 +1,159 @@ +package text + +import ( + "io" +) + +func needsQuoting(sym string) bool { + if sym == "" || sym == "null" || sym == "true" || sym == "false" || sym == "nan" { + return true + } + + if isSymbolRef(sym) { + return true + } + + if !isIdentifierStart(sym[0]) { + return true + } + + for i := 1; i < len(sym); i++ { + if !isIdentifierPart(sym[i]) { + return true + } + } + + return false +} + +func isSymbolRef(sym string) bool { + if len(sym) == 0 || sym[0] != '$' { + return false + } + + if len(sym) == 1 { + return false + } + + for i := 1; i < len(sym); i++ { + if !isDigit(sym[i]) { + return false + } + } + + return true +} + +func isIdentifierStart(c byte) bool { + if c >= 'a' && c <= 'z' { + return true + } + if c >= 'A' && c <= 'Z' { + return true + } + if c == '_' || c == '$' { + return true + } + return false +} + +func isIdentifierPart(c byte) bool { + return isIdentifierStart(c) || isDigit(c) +} + +func isDigit(c byte) bool { + return c >= '0' && c <= '9' +} + +func writeSymbol(sym string, out io.Writer) error { + if needsQuoting(sym) { + if err := writeChar('\'', out); err != nil { + return err + } + if err := writeEscapedSymbol(sym, out); err != nil { + return err + } + return writeChar('\'', out) + } else { + return writeString(sym, out) + } +} + +func writeEscapedSymbol(sym string, out io.Writer) error { + for i := 0; i < len(sym); i++ { + c := sym[i] + if c < 32 || c == '\\' || c == '\'' { + if err := writeEscapedChar(c, out); err != nil { + return err + } + } else { + if err := writeChar(c, out); err != nil { + return err + } + } + } + return nil +} + +func writeEscapedString(str string, out io.Writer) error { + for i := 0; i>4)&0xF], hexChars[c&0xF]} + return writeChars(buf, out) + } +} + +func writeString(s string, out io.Writer) error { + _, err := out.Write([]byte(s)) + return err +} + +func writeChars(cs []byte, out io.Writer) error { + _, err := out.Write(cs) + return err +} + +func writeChar(c byte, out io.Writer) error { + _, err := out.Write([]byte{c}) + return err +} diff --git a/text/utils_test.go b/text/utils_test.go new file mode 100644 index 00000000..46062b63 --- /dev/null +++ b/text/utils_test.go @@ -0,0 +1,129 @@ +package text + +import ( + "strings" + "testing" +) + +func TestWriteSymbol(t *testing.T) { + test := func(t *testing.T, sym, expected string) { + t.Run(expected, func(t *testing.T) { + buf := strings.Builder{} + if err := writeSymbol(sym, &buf); err != nil { + t.Fatal(err) + } + actual := buf.String() + if actual != expected { + t.Errorf("expected \"%v\", got \"%v\"", expected, actual) + } + }) + } + + test(t, "", "''") + test(t, "null", "'null'") + test(t, "null.null", "'null.null'") + + test(t, "basic", "basic") + test(t, "_basic_", "_basic_") + test(t, "$basic$", "$basic$") + + test(t, "123", "'123'") + test(t, "$123", "'$123'") + test(t, "abc'def", "'abc\\'def'") + test(t, "abc\"def", "'abc\"def'") +} + +func TestNeedsQuoting(t *testing.T) { + test := func(t *testing.T, sym string, expected bool) { + t.Run(sym, func(t *testing.T) { + actual := needsQuoting(sym) + if actual != expected { + t.Errorf("expected %v, got %v", expected, actual) + } + }) + } + + test(t, "", true) + test(t, "null", true) + test(t, "true", true) + test(t, "false", true) + test(t, "nan", true) + + test(t, "basic", false) + test(t, "_basic_", false) + test(t, "basic$123", false) + test(t, "$", false) + test(t, "$basic", false) + + test(t, "123", true) + test(t, "$123", true) + test(t, "abc.def", true) + test(t, "abc,def", true) + test(t, "abc:def", true) + test(t, "abc{def", true) + test(t, "abc}def", true) + test(t, "abc[def", true) + test(t, "abc]def", true) + test(t, "abc'def", true) + test(t, "abc\"def", true) +} + +func TestIsSymbolRef(t *testing.T) { + testIsSymbolRef(t, "", false) + testIsSymbolRef(t, "1", false) + testIsSymbolRef(t, "a", false) + testIsSymbolRef(t, "$", false) + testIsSymbolRef(t, "$1", true) + testIsSymbolRef(t, "$1234567890", true) + testIsSymbolRef(t, "$a", false) + testIsSymbolRef(t, "$1234a567890", false) +} + +func testIsSymbolRef(t *testing.T, sym string, expected bool) { + t.Run(sym, func(t *testing.T) { + actual := isSymbolRef(sym) + if actual != expected { + t.Errorf("expected %v, got %v", expected, actual) + } + }) +} + +func TestWriteEscapedSymbol(t *testing.T) { + testWriteEscapedSymbol(t, "basic", "basic") + testWriteEscapedSymbol(t, "\"basic\"", "\"basic\"") + testWriteEscapedSymbol(t, "o'clock", "o\\'clock") + testWriteEscapedSymbol(t, "c:\\", "c:\\\\") +} + +func testWriteEscapedSymbol(t *testing.T, sym, expected string) { + t.Run(expected, func(t *testing.T) { + buf := strings.Builder{} + if err := writeEscapedSymbol(sym, &buf); err != nil { + t.Fatal(err) + } + actual := buf.String() + if actual != expected { + t.Errorf("bad encoding of \"%v\": \"%v\"", expected, actual) + } + }) +} + +func TestWriteEscapedChar(t *testing.T) { + testWriteEscapedChar(t, 0, "\\0") + testWriteEscapedChar(t, '\n', "\\n") + testWriteEscapedChar(t, 1, "\\x01") + testWriteEscapedChar(t, '\xFF', "\\xFF") +} + +func testWriteEscapedChar(t *testing.T, c byte, expected string) { + t.Run(expected, func(t *testing.T) { + buf := strings.Builder{} + if err := writeEscapedChar(c, &buf); err != nil { + t.Fatal(err) + } + actual := buf.String() + if actual != expected { + t.Errorf("bad encoding of '%v': \"%v\"", expected, actual) + } + }) +} diff --git a/text/writer.go b/text/writer.go new file mode 100644 index 00000000..b63a33f0 --- /dev/null +++ b/text/writer.go @@ -0,0 +1,456 @@ +package text + +import ( + "encoding/base64" + "errors" + "fmt" + "io" + "math/big" + "strconv" + "strings" + "time" + + "github.com/fernomac/ion-go" +) + +type contextType uint8 + +const ( + topLevelCtx contextType = iota + inStructCtx + inListCtx + inSexpCtx +) + +type context struct { + value contextType + parent *context +} + +var topLevel = &context{value: topLevelCtx, parent: nil} + +type writer struct { + out io.Writer + ctx *context + err error + + fieldName string + typeAnnotations []string + needsSeparator bool +} + +// NewWriter returns a new text writer. +func NewWriter(out io.Writer) ion.Writer { + return &writer{ + out: out, + ctx: topLevel, + } +} + +func (w *writer) push(t contextType) { + ctx := &context{ + value: t, + parent: w.ctx, + } + w.ctx = ctx +} + +func (w *writer) pop() { + if w.ctx.parent == nil { + panic("pop called at the top level") + } + w.ctx = w.ctx.parent +} + +func (w *writer) InStruct() bool { + return (w.ctx.value == inStructCtx) +} + +func (w *writer) Err() error { + return w.err +} + +func (w *writer) FieldName(val string) { + if w.err != nil { + return + } + if !w.InStruct() { + w.err = errors.New("field name called while not in a struct") + return + } + w.fieldName = val +} + +func (w *writer) TypeAnnotation(val string) { + if w.err != nil { + return + } + w.typeAnnotations = append(w.typeAnnotations, val) +} + +func (w *writer) TypeAnnotations(val ...string) { + if w.err != nil { + return + } + w.typeAnnotations = append(w.typeAnnotations, val...) +} + +func (w *writer) beginValue() error { + if w.needsSeparator { + var sep byte + switch w.ctx.value { + case inStructCtx, inListCtx: + sep = ',' + case inSexpCtx: + sep = ' ' + default: + sep = '\n' + } + + if err := writeChar(sep, w.out); err != nil { + return err + } + } + + if w.InStruct() { + if w.fieldName == "" { + return errors.New("field name not set") + } + name := w.fieldName + w.fieldName = "" + + if err := writeSymbol(name, w.out); err != nil { + return err + } + if err := writeChar(':', w.out); err != nil { + return err + } + } + + if len(w.typeAnnotations) > 0 { + as := w.typeAnnotations + w.typeAnnotations = nil + + for _, a := range as { + if err := writeSymbol(a, w.out); err != nil { + return err + } + if err := writeString("::", w.out); err != nil { + return err + } + } + } + + return nil +} + +func (w *writer) endValue() { + w.needsSeparator = true +} + +func (w *writer) begin(t contextType, c byte) error { + if err := w.beginValue(); err != nil { + return err + } + + w.push(t) + w.needsSeparator = false + + return writeChar(c, w.out) +} + +func (w *writer) end(t contextType, c byte) error { + if w.ctx.value != t { + return errors.New("not in an appropriate container") + } + + if err := writeChar(c, w.out); err != nil { + return err + } + + w.fieldName = "" + w.typeAnnotations = nil + w.pop() + w.endValue() + + return nil +} + +func (w *writer) BeginStruct() { + if w.err != nil { + return + } + w.err = w.begin(inStructCtx, '{') +} + +func (w *writer) EndStruct() { + if w.err != nil { + return + } + w.err = w.end(inStructCtx, '}') +} + +func (w *writer) BeginList() { + if w.err != nil { + return + } + w.err = w.begin(inListCtx, '[') +} + +func (w *writer) EndList() { + if w.err != nil { + return + } + w.err = w.end(inListCtx, ']') +} + +func (w *writer) BeginSexp() { + if w.err != nil { + return + } + w.err = w.begin(inSexpCtx, '(') +} + +func (w *writer) EndSexp() { + if w.err != nil { + return + } + w.err = w.end(inSexpCtx, ')') +} + +func (w *writer) writeValue(f func() string) error { + if err := w.beginValue(); err != nil { + return err + } + + sym := f() + if err := writeString(sym, w.out); err != nil { + return err + } + + w.endValue() + return nil +} + +func (w *writer) writeValueStreaming(f func() error) error { + if err := w.beginValue(); err != nil { + return err + } + + if err := f(); err != nil { + return err + } + + w.endValue() + return nil +} + +func (w *writer) WriteNull() { + w.WriteNullWithType(ion.NullType) +} + +func (w *writer) WriteNullWithType(t ion.Type) { + if w.err != nil { + return + } + w.err = w.writeValue(func() string { + switch t { + case ion.NullType: + return "null" + case ion.BoolType: + return "null.bool" + case ion.IntType: + return "null.int" + case ion.FloatType: + return "null.float" + case ion.DecimalType: + return "null.decimal" + case ion.TimestampType: + return "null.timestamp" + case ion.StringType: + return "null.string" + case ion.SymbolType: + return "null.symbol" + case ion.BlobType: + return "null.blob" + case ion.ClobType: + return "null.clob" + case ion.StructType: + return "null.struct" + case ion.ListType: + return "null.list" + case ion.SexpType: + return "null.sexp" + default: + panic("invalid type") + } + }) +} + +func symbolForBool(val bool) string { + if val { + return "true" + } + return "false" +} + +func (w *writer) WriteBool(val bool) { + if w.err != nil { + return + } + w.err = w.writeValue(func() string { + if val { + return "true" + } + return "false" + }) +} + +func (w *writer) WriteInt(val int64) { + if w.err != nil { + return + } + w.err = w.writeValue(func() string { + return fmt.Sprintf("%d", val) + }) +} + +func (w *writer) WriteBigInt(val *big.Int) { + if w.err != nil { + return + } + w.err = w.writeValue(func() string { + return val.String() + }) +} + +func (w *writer) WriteFloat(val float64) { + if w.err != nil { + return + } + w.err = w.writeValue(func() string { + // Built-in go formatting isn't up to the task. :( + str := strconv.FormatFloat(val, 'e', -1, 64) + + switch str { + case "NaN": return "nan" + case "+Inf": return "+inf" + case "-Inf": return "-inf" + default: break + } + + idx := strings.Index(str, "e") + if idx < 0 { + str += "e0" + } else if idx+2 < len(str) && str[idx+2] == '0' { + str = str[:idx+2] + str[idx+3:] + } + + return str + }) +} + +func (w *writer) WriteDecimal(val *ion.Decimal) { + if w.err != nil { + return + } + w.err = w.writeValue(func() string { + return val.String() + }) +} + +func (w *writer) WriteTimestamp(val time.Time) { + if w.err != nil { + return + } + w.err = w.writeValue(func() string { + return val.Format(time.RFC3339Nano) + }) +} + +func (w *writer) WriteSymbol(val string) { + if w.err != nil { + return + } + w.err = w.writeValueStreaming(func() error { + return writeSymbol(val, w.out) + }) +} + +func (w *writer) WriteString(val string) { + if w.err != nil { + return + } + w.err = w.writeValueStreaming(func() error { + if err := writeChar('"', w.out); err != nil { + return err + } + if err := writeEscapedString(val, w.out); err != nil { + return err + } + return writeChar('"', w.out) + }) +} + +func (w *writer) WriteBlob(val []byte) { + if w.err != nil { + return + } + w.err = w.writeValueStreaming(func() error { + if err := writeString("{{", w.out); err != nil { + return err + } + + enc := base64.NewEncoder(base64.StdEncoding, w.out) + enc.Write(val) + if err := enc.Close(); err != nil { + return err + } + + return writeString("}}", w.out) + }) +} + +func (w *writer) WriteClob(val []byte) { + if w.err != nil { + return + } + w.err = w.writeValueStreaming(func() error { + if err := writeString("{{\"", w.out); err != nil { + return err + } + + for _, c := range val { + if c < 32 || c == '\\' || c == '"' || c > 0x7F { + if err := writeEscapedChar(c, w.out); err != nil { + return err + } + } else { + if err := writeChar(c, w.out); err != nil { + return err + } + } + } + + return writeString("\"}}", w.out) + }) +} + +func (w *writer) Finish() error { + if w.err != nil { + return w.err + } + if w.ctx.value != topLevelCtx { + w.err = errors.New("not at top level") + return w.err + } + + if w.err = writeChar('\n', w.out); w.err != nil { + return w.err + } + + w.fieldName = "" + w.typeAnnotations = nil + w.needsSeparator = false + return nil +} diff --git a/text/writer_test.go b/text/writer_test.go new file mode 100644 index 00000000..df4d9287 --- /dev/null +++ b/text/writer_test.go @@ -0,0 +1,339 @@ +package text + +import ( + "math" + "math/big" + "strings" + "testing" + "time" + + "github.com/fernomac/ion-go" +) + +func TestTopLevelFieldName(t *testing.T) { + writeText(func(w ion.Writer) { + w.FieldName("foo") + if w.Err() == nil { + t.Error("expected an error") + } + }) +} + +func TestEmptyStruct(t *testing.T) { + testTextWriter(t, "{}", func(w ion.Writer) { + if w.InStruct() { + t.Error("already in struct") + } + + w.BeginStruct() + if w.Err() != nil { + t.Fatal(w.Err()) + } + + if !w.InStruct() { + t.Error("not in struct after begin") + } + + w.EndStruct() + if w.Err() != nil { + t.Fatal(w.Err()) + } + + if w.InStruct() { + t.Error("still in struct after end") + } + + w.EndStruct() + if w.Err() == nil { + t.Fatal("no error from ending struct too many times") + } + }) +} + +func TestAnnotatedStruct(t *testing.T) { + testTextWriter(t, "foo::$bar::'.baz'::{}", func(w ion.Writer) { + w.TypeAnnotation("foo") + w.TypeAnnotation("$bar") + w.TypeAnnotation(".baz") + w.BeginStruct() + w.EndStruct() + + if w.Err() != nil { + t.Fatal(w.Err()) + } + }) +} + +func TestNestedStruct(t *testing.T) { + testTextWriter(t, "{foo:'true'::{},'null':{}}", func(w ion.Writer) { + w.BeginStruct() + + w.FieldName("foo") + w.TypeAnnotation("true") + w.BeginStruct() + w.EndStruct() + + w.FieldName("null") + w.BeginStruct() + w.EndStruct() + + w.EndStruct() + }) +} + +func TestEmptyList(t *testing.T) { + testTextWriter(t, "[]", func(w ion.Writer) { + w.BeginList() + if w.Err() != nil { + t.Fatal(w.Err()) + } + + if w.InStruct() { + t.Error("instruct returns true in a list") + } + + w.EndList() + if w.Err() != nil { + t.Fatal(w.Err()) + } + + w.EndList() + if w.Err() == nil { + t.Error("no error calling endlist at top level") + } + }) +} + +func TestNestedLists(t *testing.T) { + testTextWriter(t, "[{},foo::{},'null'::[]]", func(w ion.Writer) { + w.BeginList() + + w.BeginStruct() + w.EndStruct() + + w.TypeAnnotation("foo") + w.BeginStruct() + w.EndStruct() + + w.TypeAnnotation("null") + w.BeginList() + w.EndList() + + w.EndList() + }) +} + +func TestSexps(t *testing.T) { + testTextWriter(t, "()\n(())\n(() ())", func(w ion.Writer) { + w.BeginSexp() + w.EndSexp() + + w.BeginSexp() + w.BeginSexp() + w.EndSexp() + w.EndSexp() + + w.BeginSexp() + w.BeginSexp() + w.EndSexp() + w.BeginSexp() + w.EndSexp() + w.EndSexp() + }) +} + +func TestNull(t *testing.T) { + expected := "[null,foo::null,null.int,bar::null.sexp]" + testTextWriter(t, expected, func(w ion.Writer) { + w.BeginList() + + w.WriteNull() + w.TypeAnnotation("foo") + w.WriteNullWithType(ion.NullType) + w.WriteNullWithType(ion.IntType) + w.TypeAnnotation("bar") + w.WriteNullWithType(ion.SexpType) + + w.EndList() + }) +} + +func TestBool(t *testing.T) { + expected := "true\n(false '123'::true)\n'false'::false" + testTextWriter(t, expected, func(w ion.Writer) { + w.WriteBool(true) + + w.BeginSexp() + + w.WriteBool(false) + w.TypeAnnotation("123"); w.WriteBool(true) + + w.EndSexp() + + w.TypeAnnotation("false"); w.WriteBool(false) + }) +} + +func TestInt(t *testing.T) { + expected := "(zero::0 1 -1 (9223372036854775807 -9223372036854775808))" + testTextWriter(t, expected, func(w ion.Writer) { + w.BeginSexp() + + w.TypeAnnotation("zero"); w.WriteInt(0) + w.WriteInt(1) + w.WriteInt(-1) + + w.BeginSexp() + w.WriteInt(math.MaxInt64) + w.WriteInt(math.MinInt64) + w.EndSexp() + + w.EndSexp() + }) +} + +func TestBigInt(t *testing.T) { + expected := "[0,big::18446744073709551616]" + testTextWriter(t, expected, func(w ion.Writer) { + w.BeginList() + + w.WriteBigInt(big.NewInt(0)) + + var val, max, one big.Int + max.SetUint64(math.MaxUint64) + one.SetInt64(1) + val.Add(&max, &one) + + w.TypeAnnotation("big"); w.WriteBigInt(&val) + + w.EndList() + }) +} + +func TestFloat(t *testing.T) { + expected := "{z:0e+0,nz:-0e+0,s:1.234e+1,l:1.234e-55,n:nan,i:+inf,ni:-inf}" + testTextWriter(t, expected, func(w ion.Writer) { + w.BeginStruct() + + w.FieldName("z"); w.WriteFloat(0.0) + w.FieldName("nz"); w.WriteFloat(-1.0 / math.Inf(1)) + + w.FieldName("s"); w.WriteFloat(12.34) + w.FieldName("l"); w.WriteFloat(12.34e-56) + + w.FieldName("n"); w.WriteFloat(math.NaN()) + w.FieldName("i"); w.WriteFloat(math.Inf(1)) + w.FieldName("ni"); w.WriteFloat(math.Inf(-1)) + + w.EndStruct() + }) +} + +func TestDecimal(t *testing.T) { + expected := "0.\n-1.23d-98" + testTextWriter(t, expected, func(w ion.Writer) { + w.WriteDecimal(ion.NewDecimal(big.NewInt(0))) + w.WriteDecimal(ion.NewDecimalWithScale(big.NewInt(-123), 100)) + }) +} + +func TestTimestamp(t *testing.T) { + expected := "1970-01-01T00:00:00.001Z\n1970-01-01T01:23:00+01:23" + testTextWriter(t, expected, func (w ion.Writer) { + w.WriteTimestamp(time.Unix(0, 1000000).In(time.UTC)) + w.WriteTimestamp(time.Unix(0, 0).In(time.FixedZone("wtf", 4980))) + }) +} + +func TestSymbol(t *testing.T) { + expected := "{foo:bar,empty:'','null':'null',f:a::b::u::'lo🇺🇸'}" + testTextWriter(t, expected, func(w ion.Writer) { + w.BeginStruct() + + w.FieldName("foo"); w.WriteSymbol("bar") + w.FieldName("empty"); w.WriteSymbol("") + w.FieldName("null"); w.WriteSymbol("null") + + w.FieldName("f"); w.TypeAnnotation("a") + w.TypeAnnotation("b"); w.TypeAnnotation("u") + w.WriteSymbol("lo🇺🇸") + + w.EndStruct() + }) +} + +func TestString(t *testing.T) { + expected := `("hello" "" ("\\\"\n\"\\" zany::"🤪"))` + testTextWriter(t, expected, func(w ion.Writer) { + w.BeginSexp() + w.WriteString("hello") + w.WriteString("") + + w.BeginSexp() + w.WriteString("\\\"\n\"\\") + w.TypeAnnotation("zany"); w.WriteString("🤪") + w.EndSexp() + + w.EndSexp() + }) +} + +func TestBlob(t *testing.T) { + expected := "{{AAEC/f7/}}\n{{SGVsbG8gV29ybGQ=}}\nempty::{{}}" + testTextWriter(t, expected, func(w ion.Writer) { + w.WriteBlob([]byte{ 0, 1, 2, 0xFD, 0xFE, 0xFF }) + w.WriteBlob([]byte("Hello World")) + w.TypeAnnotation("empty"); w.WriteBlob(nil) + }) +} + +func TestClob(t *testing.T) { + expected := "{hello:{{\"world\"}},bits:{{\"\\0\\x01\\xFE\\xFF\"}}}" + testTextWriter(t, expected, func(w ion.Writer) { + w.BeginStruct() + w.FieldName("hello"); w.WriteClob([]byte("world")) + w.FieldName("bits"); w.WriteClob([]byte{0,1,0xFE,0xFF}) + w.EndStruct() + }) +} + +func TestFinish(t *testing.T) { + expected := "1\nfoo\n\"bar\"\n{}\n" + testTextWriter(t, expected, func(w ion.Writer) { + w.WriteInt(1) + w.WriteSymbol("foo") + w.WriteString("bar") + w.BeginStruct(); w.EndStruct() + if err := w.Finish(); err != nil { + t.Fatal(err) + } + }) +} + +func TestBadFinish(t *testing.T) { + buf := strings.Builder{} + w := NewWriter(&buf) + + w.BeginStruct() + err := w.Finish() + + if err == nil { + t.Error("should not be able to finish in the middle of a struct") + } +} + +func testTextWriter(t *testing.T, expected string, f func(ion.Writer)) { + actual := writeText(f) + if actual != expected { + t.Errorf("expected: %v, actual: %v", expected, actual) + } +} + +func writeText(f func(ion.Writer)) string { + buf := strings.Builder{} + w := NewWriter(&buf) + + f(w) + + return buf.String() +} From 7ac7e33652440fe67948754118295a8b9bb5ffc8 Mon Sep 17 00:00:00 2001 From: David Murray Date: Mon, 10 Jun 2019 17:54:35 +1000 Subject: [PATCH 02/56] go fmt --- decimal.go | 4 ++-- decimal_test.go | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/decimal.go b/decimal.go index d6c6c84f..77bb45f1 100644 --- a/decimal.go +++ b/decimal.go @@ -10,7 +10,7 @@ import ( // Decimal is an arbitrary-precision decimal value. type Decimal struct { - n *big.Int + n *big.Int scale int } @@ -23,7 +23,7 @@ func NewDecimal(n *big.Int) *Decimal { // equal to n * 10^-scale. func NewDecimalWithScale(n *big.Int, scale int) *Decimal { return &Decimal{ - n: n, + n: n, scale: scale, } } diff --git a/decimal_test.go b/decimal_test.go index d2fb5b03..8e7f9fc6 100644 --- a/decimal_test.go +++ b/decimal_test.go @@ -9,7 +9,7 @@ func TestDecimalToString(t *testing.T) { test := func(n int64, scale int, expected string) { t.Run(expected, func(t *testing.T) { d := Decimal{ - n: big.NewInt(n), + n: big.NewInt(n), scale: scale, } actual := d.String() From f8208a51854c89e5e3e37d2bbe5a3a7f9c1f0fb7 Mon Sep 17 00:00:00 2001 From: David Murray Date: Mon, 10 Jun 2019 21:17:53 +1000 Subject: [PATCH 03/56] parsing decimals and some math --- decimal.go | 99 +++++++++++++++++++++++++++++++++++++++- decimal_test.go | 117 ++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 215 insertions(+), 1 deletion(-) diff --git a/decimal.go b/decimal.go index 77bb45f1..671ff9eb 100644 --- a/decimal.go +++ b/decimal.go @@ -1,11 +1,15 @@ package ion import ( + "errors" "fmt" "math/big" + "strconv" "strings" ) +var ten = big.NewInt(10) + // TODO: Precision. // Decimal is an arbitrary-precision decimal value. @@ -14,7 +18,8 @@ type Decimal struct { scale int } -// NewDecimal creates a new (big-integer) decimal. +// NewDecimal creates a new decimal whose value is equal to the given +// (big) integer. func NewDecimal(n *big.Int) *Decimal { return NewDecimalWithScale(n, 0) } @@ -28,8 +33,100 @@ func NewDecimalWithScale(n *big.Int, scale int) *Decimal { } } +func ParseDecimal(in string) (*Decimal, error) { + if len(in) == 0 { + return nil, errors.New("empty string") + } + + shift := 0 + + d := strings.IndexAny(in, "Dd") + if d != -1 { + // There's an explicit exponent. + exp := in[d+1:] + if len(exp) == 0 { + return nil, errors.New("unexpected end of input after d") + } + + tmp, err := strconv.ParseInt(exp, 10, 32) + if err != nil { + return nil, err + } + + shift = int(tmp) + in = in[:d] + } + + d = strings.Index(in, ".") + if d != -1 { + // There's zero or more decimal places. + ipart := in[:d] + fpart := in[d+1:] + + shift -= len(fpart) + in = ipart + fpart + } + + n, ok := new(big.Int).SetString(in, 10) + if !ok { + // Unfortunately this is all we get? + return nil, errors.New("not a valid number") + } + + return NewDecimalWithScale(n, -shift), nil +} + +func (d *Decimal) Abs() *Decimal { + return &Decimal{ + n: new(big.Int).Abs(d.n), + scale: d.scale, + } +} + +func (d *Decimal) Add(o *Decimal) *Decimal { + a, b := rescale(d, o) + return &Decimal{ + n: new(big.Int).Add(a.n, b.n), + scale: a.scale, + } +} + // TODO: Maths. +func (d *Decimal) Cmp(o *Decimal) int { + a, b := rescale(d, o) + return a.n.Cmp(b.n) +} + +func (d *Decimal) Equal(o *Decimal) bool { + return d.Cmp(o) == 0 +} + +func rescale(a, b *Decimal) (*Decimal, *Decimal) { + if a.scale < b.scale { + return a.upscale(b.scale), b + } else if a.scale > b.scale { + return a, b.upscale(a.scale) + } else { + return a, b + } +} + +func (d *Decimal) upscale(scale int) *Decimal { + diff := int64(scale) - int64(d.scale) + if diff < 0 { + panic("can't upscale to a smaller scale") + } + + pow := new(big.Int).Exp(ten, big.NewInt(diff), nil) + n := new(big.Int).Mul(d.n, pow) + + return &Decimal{ + n: n, + scale: scale, + } +} + func (d *Decimal) String() string { switch { case d.scale == 0: diff --git a/decimal_test.go b/decimal_test.go index 8e7f9fc6..d67ccd4f 100644 --- a/decimal_test.go +++ b/decimal_test.go @@ -47,3 +47,120 @@ func TestDecimalToString(t *testing.T) { test(-456, 3, "-4.56d-1") test(-456, 4, "-4.56d-2") } + +func TestParseDecimal(t *testing.T) { + test := func(in string, n *big.Int, scale int) { + t.Run(in, func(t *testing.T) { + d, err := ParseDecimal(in) + if err != nil { + t.Fatal(err) + } + + if n.Cmp(d.n) != 0 { + t.Errorf("wrong n; expected %v, got %v", n, d.n) + } + if scale != d.scale { + t.Errorf("wrong scale; expected %v, got %v", scale, d.scale) + } + }) + } + + test("0", big.NewInt(0), 0) + test("-0", big.NewInt(0), 0) + test("0D0", big.NewInt(0), 0) + test("-0d-1", big.NewInt(0), 1) + + test("1.", big.NewInt(1), 0) + test("1.0", big.NewInt(10), 1) + test("0.123", big.NewInt(123), 3) + + test("1d0", big.NewInt(1), 0) + test("1d1", big.NewInt(1), -1) + test("1d+1", big.NewInt(1), -1) + test("1d-1", big.NewInt(1), 1) + + test("-0.12d4", big.NewInt(-12), -2) +} + +func TestAbs(t *testing.T) { + t.Run("0", func(t *testing.T) { + d := NewDecimal(big.NewInt(0)) + actual := d.Abs().String() + if actual != "0." { + t.Errorf("expected 0., got %v", actual) + } + }) + + t.Run("-1d100", func(t *testing.T) { + d, _ := ParseDecimal("-1d100") + actual := d.Abs().String() + if actual != "1d100" { + t.Errorf("expected 1d100, got %v", actual) + } + }) + + t.Run("-1.2d-3", func(t *testing.T) { + d, _ := ParseDecimal("-1.2d-3") + actual := d.Abs().String() + if actual != "1.2d-3" { + t.Errorf("expected 1.2d-3, got %v", actual) + } + }) +} + +func TestAdd(t *testing.T) { + test := func(a, b, expected string) { + t.Run("("+a+"+"+b+")", func(t *testing.T) { + aa, _ := ParseDecimal(a) + bb, _ := ParseDecimal(b) + ee, _ := ParseDecimal(expected) + + actual := aa.Add(bb) + if !actual.Equal(ee) { + t.Errorf("expected %v, got %v", ee, actual) + } + }) + } + + test("1", "1", "2") + test("1", "0.1", "1.1") + test("0.3", "0.06", "0.36") + test("1", "100", "101") + test("1d100", "1d98", "101d98") + test("1d-100", "1d-98", "1.01d-98") +} + +func TestCmp(t *testing.T) { + test := func(a, b string, expected int) { + t.Run("("+a+","+b+")", func(t *testing.T) { + ad, _ := ParseDecimal(a) + bd, _ := ParseDecimal(b) + actual := ad.Cmp(bd) + if actual != expected { + t.Errorf("expected %v, got %v", expected, actual) + } + }) + } + + test("0", "0", 0) + test("0", "1", -1) + test("0", "-1", 1) + + test("1d2", "100", 0) + test("100", "1d2", 0) + test("1d2", "10", 1) + test("10", "1d2", -1) + + test("0.01", "1d-2", 0) + test("1d-2", "0.01", 0) + test("0.01", "1d-3", 1) + test("1d-3", "0.01", -1) +} + +func TestUpscale(t *testing.T) { + d, _ := ParseDecimal("1d1") + actual := d.upscale(4).String() + if actual != "10.0000" { + t.Errorf("expected 10.0000, got %v", actual) + } +} From 797d9a988b16d2d7a0f37c202ab51b602b2d4139 Mon Sep 17 00:00:00 2001 From: David Murray Date: Tue, 11 Jun 2019 19:28:48 +1000 Subject: [PATCH 04/56] more decimal bikeshedding --- decimal.go | 153 +++++++++++++++++++++++++++++++++++++---- decimal_test.go | 178 ++++++++++++++++++++++++++++++++++++++++-------- 2 files changed, 289 insertions(+), 42 deletions(-) diff --git a/decimal.go b/decimal.go index 671ff9eb..ed7ceded 100644 --- a/decimal.go +++ b/decimal.go @@ -3,14 +3,13 @@ package ion import ( "errors" "fmt" + "math" "math/big" "strconv" "strings" ) -var ten = big.NewInt(10) - -// TODO: Precision. +// TODO: Explicitly track precision? // Decimal is an arbitrary-precision decimal value. type Decimal struct { @@ -33,6 +32,18 @@ func NewDecimalWithScale(n *big.Int, scale int) *Decimal { } } +// MustParseDecimal parses the given string into a decimal object, +// panicing on error. +func MustParseDecimal(in string) *Decimal { + d, err := ParseDecimal(in) + if err != nil { + panic(err) + } + return d +} + +// ParseDecimal parses the given string into a decimal object, +// returning an error on failure. func ParseDecimal(in string) (*Decimal, error) { if len(in) == 0 { return nil, errors.New("empty string") @@ -84,20 +95,84 @@ func (d *Decimal) Abs() *Decimal { } func (d *Decimal) Add(o *Decimal) *Decimal { - a, b := rescale(d, o) + // a*10^x + b*10^y = (a*10^(x-y) + b) * 10^y + dd, oo := rescale(d, o) + return &Decimal{ + n: new(big.Int).Add(dd.n, oo.n), + scale: dd.scale, + } +} + +func (d *Decimal) Sub(o *Decimal) *Decimal { + dd, oo := rescale(d, o) + return &Decimal{ + n: new(big.Int).Sub(dd.n, oo.n), + scale: dd.scale, + } +} + +func (d *Decimal) Neg() *Decimal { return &Decimal{ - n: new(big.Int).Add(a.n, b.n), - scale: a.scale, + n: new(big.Int).Neg(d.n), + scale: d.scale, } } -// TODO: Maths. +// Mul multiplies two decimals and returns the result. +func (d *Decimal) Mul(o *Decimal) *Decimal { + // a*10^x * b*10^y = (a*b) * 10^(x+y) + scale := int64(d.scale) + int64(o.scale) + if scale > math.MaxInt32 || scale < math.MinInt32 { + panic("exponent out of bounds") + } + return &Decimal{ + n: new(big.Int).Mul(d.n, o.n), + scale: int(scale), + } +} + +// ShiftL returns a new decimal shifted the given number of decimal +// places to the left. It's a computationally-cheap way to compute +// d * 10^shift. +func (d *Decimal) ShiftL(shift int) *Decimal { + scale := int64(d.scale) - int64(shift) + if scale > math.MaxInt32 || scale < math.MinInt32 { + panic("exponent out of bounds") + } + + return &Decimal{ + n: d.n, + scale: int(scale), + } +} + +// ShiftR returns a new decimal shifted the given number of decimal +// places to the right. It's a computationally-cheap way to compute +// d / 10^shift. +func (d *Decimal) ShiftR(shift int) *Decimal { + scale := int64(d.scale) + int64(shift) + if scale > math.MaxInt32 || scale < math.MinInt32 { + panic("exponent out of bounds") + } + + return &Decimal{ + n: d.n, + scale: int(scale), + } +} + +// TODO: Div, Exp, etc? + +// Cmp compares two decimals, returning -1 if d is smaller, +1 if d is +// larger, and 0 if they are equal (ignoring precision). func (d *Decimal) Cmp(o *Decimal) int { - a, b := rescale(d, o) - return a.n.Cmp(b.n) + dd, oo := rescale(d, o) + return dd.n.Cmp(oo.n) } +// Equal determines if two decimals are equal (discounting precision, +// at least for now). func (d *Decimal) Equal(o *Decimal) bool { return d.Cmp(o) == 0 } @@ -112,6 +187,12 @@ func rescale(a, b *Decimal) (*Decimal, *Decimal) { } } +var ten = big.NewInt(10) + +// Make 'n' bigger by making 'scale' smaller, since we know we can +// do that. (1d100 -> 10d99). Makes comparisons and math easier, at the +// expense of more storage space. Technically speaking implies adding +// more precision, but we're not tracking that too closely. func (d *Decimal) upscale(scale int) *Decimal { diff := int64(scale) - int64(d.scale) if diff < 0 { @@ -127,29 +208,73 @@ func (d *Decimal) upscale(scale int) *Decimal { } } +// Truncate returns a new decimal, truncated to the given number of +// decimal digits of precision. It does not round, so 19.Truncate(1) +// = 1d1. +func (d *Decimal) Truncate(precision int) *Decimal { + if precision <= 0 { + panic("precision must be positive") + } + + // Is there a better way to calculate precision? It really + // seems like there should be... + + str := d.n.String() + if str[0] == '-' { + // Cheating a bit. + precision++ + } + + diff := len(str) - precision + if diff <= 0 { + // Already small enough, nothing to truncate. + return d + } + + // Lazy man's division by a power of 10. + n, ok := new(big.Int).SetString(str[:precision], 10) + if !ok { + // Should never happen, since we started with a valid int. + panic("failed to parse integer") + } + + scale := int64(d.scale) - int64(diff) + if scale < math.MinInt32 { + panic("exponent out of range") + } + + return &Decimal{ + n: n, + scale: int(scale), + } +} + +// String formats the decimal as a string in Ion text format. func (d *Decimal) String() string { switch { case d.scale == 0: - // Value is an unscaled integer. + // Value is an unscaled integer. Just mark it as a decimal. + // TODO: If there are enough trailing zeros should we knock them + // off and do nnn'd'sss here? That'd technically erase precision. return d.n.String() + "." case d.scale < 0: - // Value is a scaled integer, nnndsss. + // Value is a upscaled integer, nn'd'ss return d.n.String() + "d" + fmt.Sprintf("%d", -d.scale) default: - // Value is a downscaled integer nn.nnd-ss + // Value is a downscaled integer nn.nn('d'-ss)? str := d.n.String() idx := len(str) - d.scale prefix := 1 if d.n.Sign() < 0 { - // Account for leading '-' + // Account for leading '-'. prefix++ } if idx >= prefix { - // Put the decimal point in the middle. + // Put the decimal point in the middle, no exponent. return str[:idx] + "." + str[idx:] } diff --git a/decimal_test.go b/decimal_test.go index d67ccd4f..f690cbdc 100644 --- a/decimal_test.go +++ b/decimal_test.go @@ -1,6 +1,7 @@ package ion import ( + "fmt" "math/big" "testing" ) @@ -82,46 +83,118 @@ func TestParseDecimal(t *testing.T) { test("-0.12d4", big.NewInt(-12), -2) } -func TestAbs(t *testing.T) { - t.Run("0", func(t *testing.T) { - d := NewDecimal(big.NewInt(0)) - actual := d.Abs().String() - if actual != "0." { - t.Errorf("expected 0., got %v", actual) +func absF(d *Decimal) *Decimal { return d.Abs() } +func negF(d *Decimal) *Decimal { return d.Neg() } + +type unaryop struct { + sym string + fun func(d *Decimal) *Decimal +} + +var abs = &unaryop{"abs", absF} +var neg = &unaryop{"neg", negF} + +func testUnaryOp(t *testing.T, a, e string, op *unaryop) { + t.Run(op.sym+"("+a+")="+e, func(t *testing.T) { + aa, _ := ParseDecimal(a) + ee, _ := ParseDecimal(e) + actual := op.fun(aa) + if !actual.Equal(ee) { + t.Errorf("expected %v, got %v", ee, actual) } }) +} + +func TestAbs(t *testing.T) { + test := func(a, e string) { + testUnaryOp(t, a, e, abs) + } + + test("0", "0") + test("1d100", "1d100") + test("-1d100", "1d100") + test("1.2d-3", "1.2d-3") + test("-1.2d-3", "1.2d-3") +} + +func TestNeg(t *testing.T) { + test := func(a, e string) { + testUnaryOp(t, a, e, neg) + } + + test("0", "0") + test("1d100", "-1d100") + test("-1d100", "1d100") + test("1.2d-3", "-1.2d-3") + test("-1.2d-3", "1.2d-3") +} + +func addF(a, b *Decimal) *Decimal { return a.Add(b) } +func subF(a, b *Decimal) *Decimal { return a.Sub(b) } +func mulF(a, b *Decimal) *Decimal { return a.Mul(b) } + +type binop struct { + sym string + fun func(a, b *Decimal) *Decimal +} - t.Run("-1d100", func(t *testing.T) { - d, _ := ParseDecimal("-1d100") - actual := d.Abs().String() - if actual != "1d100" { - t.Errorf("expected 1d100, got %v", actual) +func TestShiftL(t *testing.T) { + test := func(a string, b int, e string) { + aa, _ := ParseDecimal(a) + ee, _ := ParseDecimal(e) + actual := aa.ShiftL(b) + if !actual.Equal(ee) { + t.Errorf("expected %v, got %v", ee, actual) } - }) + } + + test("0", 10, "0") + test("1", 0, "1") + test("123", 1, "1230") + test("123", 100, "123d100") + test("1.23d-100", 102, "123") +} - t.Run("-1.2d-3", func(t *testing.T) { - d, _ := ParseDecimal("-1.2d-3") - actual := d.Abs().String() - if actual != "1.2d-3" { - t.Errorf("expected 1.2d-3, got %v", actual) +func TestShiftR(t *testing.T) { + test := func(a string, b int, e string) { + aa, _ := ParseDecimal(a) + ee, _ := ParseDecimal(e) + actual := aa.ShiftR(b) + if !actual.Equal(ee) { + t.Errorf("expected %v, got %v", ee, actual) + } + } + + test("0", 10, "0") + test("1", 0, "1") + test("123", 1, "12.3") + test("123", 100, "1.23d-98") + test("1.23d100", 98, "123") +} + +var add = &binop{"+", addF} +var sub = &binop{"-", subF} +var mul = &binop{"*", mulF} + +func testBinaryOp(t *testing.T, a, b, e string, op *binop) { + t.Run(a+op.sym+b+"="+e, func(t *testing.T) { + aa, _ := ParseDecimal(a) + bb, _ := ParseDecimal(b) + ee, _ := ParseDecimal(e) + + actual := op.fun(aa, bb) + if !actual.Equal(ee) { + t.Errorf("expected %v, got %v", ee, actual) } }) } func TestAdd(t *testing.T) { - test := func(a, b, expected string) { - t.Run("("+a+"+"+b+")", func(t *testing.T) { - aa, _ := ParseDecimal(a) - bb, _ := ParseDecimal(b) - ee, _ := ParseDecimal(expected) - - actual := aa.Add(bb) - if !actual.Equal(ee) { - t.Errorf("expected %v, got %v", ee, actual) - } - }) + test := func(a, b, e string) { + testBinaryOp(t, a, b, e, add) } + test("1", "0", "1") test("1", "1", "2") test("1", "0.1", "1.1") test("0.3", "0.06", "0.36") @@ -130,6 +203,55 @@ func TestAdd(t *testing.T) { test("1d-100", "1d-98", "1.01d-98") } +func TestSub(t *testing.T) { + test := func(a, b, e string) { + testBinaryOp(t, a, b, e, sub) + } + + test("1", "0", "1") + test("1", "1", "0") + test("1", "0.1", "0.9") + test("0.3", "0.06", "0.24") + test("1", "100", "-99") + test("1d100", "1d98", "99d98") + test("1d-100", "1d-98", "-99d-100") +} + +func TestMul(t *testing.T) { + test := func(a, b, e string) { + testBinaryOp(t, a, b, e, mul) + } + + test("1", "0", "0") + test("1", "1", "1") + test("2", "-1", "-2") + test("7", "6", "42") + test("10", "0.3", "3") + test("3d100", "2d50", "6d150") + test("3d-100", "2d-50", "6d-150") + test("2d100", "4d-98", "8d2") +} + +func TestTruncate(t *testing.T) { + test := func(a string, p int, expected string) { + t.Run(fmt.Sprintf("trunc(%v,%v)", a, p), func(t *testing.T) { + aa := MustParseDecimal(a) + actual := aa.Truncate(p).String() + if actual != expected { + t.Errorf("expected %v, got %v", expected, actual) + } + }) + } + + test("1", 1, "1.") + test("1", 10, "1.") + test("10", 1, "1d1") + test("1999", 1, "1d3") + test("1.2345", 3, "1.23") + test("100d100", 2, "10d101") + test("1.2345d-100", 2, "1.2d-100") +} + func TestCmp(t *testing.T) { test := func(a, b string, expected int) { t.Run("("+a+","+b+")", func(t *testing.T) { From 46c843fcd17b56be364666e12a43d23822c87c21 Mon Sep 17 00:00:00 2001 From: David Murray Date: Tue, 11 Jun 2019 21:26:33 +1000 Subject: [PATCH 05/56] refactorizing to avoid circular references --- text/utils_test.go | 129 ------------ text/utils.go => textutils.go | 58 +++--- textutils_test.go | 131 ++++++++++++ text/writer.go => textwriter.go | 241 +++++++++------------- text/writer_test.go => textwriter_test.go | 60 +++--- writer.go | 97 +++++++++ 6 files changed, 389 insertions(+), 327 deletions(-) delete mode 100644 text/utils_test.go rename text/utils.go => textutils.go (57%) create mode 100644 textutils_test.go rename text/writer.go => textwriter.go (51%) rename text/writer_test.go => textwriter_test.go (79%) create mode 100644 writer.go diff --git a/text/utils_test.go b/text/utils_test.go deleted file mode 100644 index 46062b63..00000000 --- a/text/utils_test.go +++ /dev/null @@ -1,129 +0,0 @@ -package text - -import ( - "strings" - "testing" -) - -func TestWriteSymbol(t *testing.T) { - test := func(t *testing.T, sym, expected string) { - t.Run(expected, func(t *testing.T) { - buf := strings.Builder{} - if err := writeSymbol(sym, &buf); err != nil { - t.Fatal(err) - } - actual := buf.String() - if actual != expected { - t.Errorf("expected \"%v\", got \"%v\"", expected, actual) - } - }) - } - - test(t, "", "''") - test(t, "null", "'null'") - test(t, "null.null", "'null.null'") - - test(t, "basic", "basic") - test(t, "_basic_", "_basic_") - test(t, "$basic$", "$basic$") - - test(t, "123", "'123'") - test(t, "$123", "'$123'") - test(t, "abc'def", "'abc\\'def'") - test(t, "abc\"def", "'abc\"def'") -} - -func TestNeedsQuoting(t *testing.T) { - test := func(t *testing.T, sym string, expected bool) { - t.Run(sym, func(t *testing.T) { - actual := needsQuoting(sym) - if actual != expected { - t.Errorf("expected %v, got %v", expected, actual) - } - }) - } - - test(t, "", true) - test(t, "null", true) - test(t, "true", true) - test(t, "false", true) - test(t, "nan", true) - - test(t, "basic", false) - test(t, "_basic_", false) - test(t, "basic$123", false) - test(t, "$", false) - test(t, "$basic", false) - - test(t, "123", true) - test(t, "$123", true) - test(t, "abc.def", true) - test(t, "abc,def", true) - test(t, "abc:def", true) - test(t, "abc{def", true) - test(t, "abc}def", true) - test(t, "abc[def", true) - test(t, "abc]def", true) - test(t, "abc'def", true) - test(t, "abc\"def", true) -} - -func TestIsSymbolRef(t *testing.T) { - testIsSymbolRef(t, "", false) - testIsSymbolRef(t, "1", false) - testIsSymbolRef(t, "a", false) - testIsSymbolRef(t, "$", false) - testIsSymbolRef(t, "$1", true) - testIsSymbolRef(t, "$1234567890", true) - testIsSymbolRef(t, "$a", false) - testIsSymbolRef(t, "$1234a567890", false) -} - -func testIsSymbolRef(t *testing.T, sym string, expected bool) { - t.Run(sym, func(t *testing.T) { - actual := isSymbolRef(sym) - if actual != expected { - t.Errorf("expected %v, got %v", expected, actual) - } - }) -} - -func TestWriteEscapedSymbol(t *testing.T) { - testWriteEscapedSymbol(t, "basic", "basic") - testWriteEscapedSymbol(t, "\"basic\"", "\"basic\"") - testWriteEscapedSymbol(t, "o'clock", "o\\'clock") - testWriteEscapedSymbol(t, "c:\\", "c:\\\\") -} - -func testWriteEscapedSymbol(t *testing.T, sym, expected string) { - t.Run(expected, func(t *testing.T) { - buf := strings.Builder{} - if err := writeEscapedSymbol(sym, &buf); err != nil { - t.Fatal(err) - } - actual := buf.String() - if actual != expected { - t.Errorf("bad encoding of \"%v\": \"%v\"", expected, actual) - } - }) -} - -func TestWriteEscapedChar(t *testing.T) { - testWriteEscapedChar(t, 0, "\\0") - testWriteEscapedChar(t, '\n', "\\n") - testWriteEscapedChar(t, 1, "\\x01") - testWriteEscapedChar(t, '\xFF', "\\xFF") -} - -func testWriteEscapedChar(t *testing.T, c byte, expected string) { - t.Run(expected, func(t *testing.T) { - buf := strings.Builder{} - if err := writeEscapedChar(c, &buf); err != nil { - t.Fatal(err) - } - actual := buf.String() - if actual != expected { - t.Errorf("bad encoding of '%v': \"%v\"", expected, actual) - } - }) -} diff --git a/text/utils.go b/textutils.go similarity index 57% rename from text/utils.go rename to textutils.go index d53c9fca..0d0973e9 100644 --- a/text/utils.go +++ b/textutils.go @@ -1,10 +1,11 @@ -package text +package ion import ( "io" ) -func needsQuoting(sym string) bool { +// Does this symbol need to be quoted in text form? +func symbolNeedsQuoting(sym string) bool { if sym == "" || sym == "null" || sym == "true" || sym == "false" || sym == "nan" { return true } @@ -26,6 +27,7 @@ func needsQuoting(sym string) bool { return false } +// Is this the text form of a symbol reference ($)? func isSymbolRef(sym string) bool { if len(sym) == 0 || sym[0] != '$' { return false @@ -44,6 +46,7 @@ func isSymbolRef(sym string) bool { return true } +// Is this a valid first character for an identifier? func isIdentifierStart(c byte) bool { if c >= 'a' && c <= 'z' { return true @@ -57,28 +60,32 @@ func isIdentifierStart(c byte) bool { return false } +// Is this a valid character for later in an identifier? func isIdentifierPart(c byte) bool { return isIdentifierStart(c) || isDigit(c) } +// Is this a digit? func isDigit(c byte) bool { return c >= '0' && c <= '9' } +// Write the given symbol out, quoting and encoding if necessary. func writeSymbol(sym string, out io.Writer) error { - if needsQuoting(sym) { - if err := writeChar('\'', out); err != nil { + if symbolNeedsQuoting(sym) { + if err := writeRawChar('\'', out); err != nil { return err } if err := writeEscapedSymbol(sym, out); err != nil { return err } - return writeChar('\'', out) + return writeRawChar('\'', out) } else { - return writeString(sym, out) + return writeRawString(sym, out) } } +// Write the given symbol out, escaping any characters that need escaping. func writeEscapedSymbol(sym string, out io.Writer) error { for i := 0; i < len(sym); i++ { c := sym[i] @@ -87,7 +94,7 @@ func writeEscapedSymbol(sym string, out io.Writer) error { return err } } else { - if err := writeChar(c, out); err != nil { + if err := writeRawChar(c, out); err != nil { return err } } @@ -95,6 +102,7 @@ func writeEscapedSymbol(sym string, out io.Writer) error { return nil } +// Write the given string out, escaping any characters that need escaping. func writeEscapedString(str string, out io.Writer) error { for i := 0; i>4)&0xF], hexChars[c&0xF]} - return writeChars(buf, out) + return writeRawChars(buf, out) } } -func writeString(s string, out io.Writer) error { +// Write out the given raw string. +func writeRawString(s string, out io.Writer) error { _, err := out.Write([]byte(s)) return err } -func writeChars(cs []byte, out io.Writer) error { +// Write out the given raw character sequence. +func writeRawChars(cs []byte, out io.Writer) error { _, err := out.Write(cs) return err } -func writeChar(c byte, out io.Writer) error { +// Write out the given raw character. +func writeRawChar(c byte, out io.Writer) error { _, err := out.Write([]byte{c}) return err } diff --git a/textutils_test.go b/textutils_test.go new file mode 100644 index 00000000..684904ec --- /dev/null +++ b/textutils_test.go @@ -0,0 +1,131 @@ +package ion + +import ( + "strings" + "testing" +) + +func TestWriteSymbol(t *testing.T) { + test := func(sym, expected string) { + t.Run(expected, func(t *testing.T) { + buf := strings.Builder{} + if err := writeSymbol(sym, &buf); err != nil { + t.Fatal(err) + } + actual := buf.String() + if actual != expected { + t.Errorf("expected \"%v\", got \"%v\"", expected, actual) + } + }) + } + + test("", "''") + test("null", "'null'") + test("null.null", "'null.null'") + + test("basic", "basic") + test("_basic_", "_basic_") + test("$basic$", "$basic$") + + test("123", "'123'") + test("$123", "'$123'") + test("abc'def", "'abc\\'def'") + test("abc\"def", "'abc\"def'") +} + +func TestSymbolNeedsQuoting(t *testing.T) { + test := func(sym string, expected bool) { + t.Run(sym, func(t *testing.T) { + actual := symbolNeedsQuoting(sym) + if actual != expected { + t.Errorf("expected %v, got %v", expected, actual) + } + }) + } + + test("", true) + test("null", true) + test("true", true) + test("false", true) + test("nan", true) + + test("basic", false) + test("_basic_", false) + test("basic$123", false) + test("$", false) + test("$basic", false) + + test("123", true) + test("$123", true) + test("abc.def", true) + test("abc,def", true) + test("abc:def", true) + test("abc{def", true) + test("abc}def", true) + test("abc[def", true) + test("abc]def", true) + test("abc'def", true) + test("abc\"def", true) +} + +func TestIsSymbolRef(t *testing.T) { + test := func(sym string, expected bool) { + t.Run(sym, func(t *testing.T) { + actual := isSymbolRef(sym) + if actual != expected { + t.Errorf("expected %v, got %v", expected, actual) + } + }) + } + + test("", false) + test("1", false) + test("a", false) + test("$", false) + test("$1", true) + test("$1234567890", true) + test("$a", false) + test("$1234a567890", false) +} + +func TestWriteEscapedSymbol(t *testing.T) { + test := func(sym, expected string) { + t.Run(expected, func(t *testing.T) { + buf := strings.Builder{} + if err := writeEscapedSymbol(sym, &buf); err != nil { + t.Fatal(err) + } + actual := buf.String() + if actual != expected { + t.Errorf("bad encoding of \"%v\": \"%v\"", + expected, actual) + } + }) + } + + test("basic", "basic") + test("\"basic\"", "\"basic\"") + test("o'clock", "o\\'clock") + test("c:\\", "c:\\\\") +} + +func TestWriteEscapedChar(t *testing.T) { + test := func(c byte, expected string) { + t.Run(expected, func(t *testing.T) { + buf := strings.Builder{} + if err := writeEscapedChar(c, &buf); err != nil { + t.Fatal(err) + } + actual := buf.String() + if actual != expected { + t.Errorf("bad encoding of '%v': \"%v\"", + expected, actual) + } + }) + } + + test(0, "\\0") + test('\n', "\\n") + test(1, "\\x01") + test('\xFF', "\\xFF") +} diff --git a/text/writer.go b/textwriter.go similarity index 51% rename from text/writer.go rename to textwriter.go index b63a33f0..0a6c4441 100644 --- a/text/writer.go +++ b/textwriter.go @@ -1,4 +1,4 @@ -package text +package ion import ( "encoding/base64" @@ -9,96 +9,30 @@ import ( "strconv" "strings" "time" - - "github.com/fernomac/ion-go" ) -type contextType uint8 - -const ( - topLevelCtx contextType = iota - inStructCtx - inListCtx - inSexpCtx -) - -type context struct { - value contextType - parent *context -} - -var topLevel = &context{value: topLevelCtx, parent: nil} - -type writer struct { - out io.Writer - ctx *context - err error - - fieldName string - typeAnnotations []string +// textWriter is a writer that writes human-readable text +type textWriter struct { + writer needsSeparator bool } -// NewWriter returns a new text writer. -func NewWriter(out io.Writer) ion.Writer { - return &writer{ - out: out, - ctx: topLevel, - } -} - -func (w *writer) push(t contextType) { - ctx := &context{ - value: t, - parent: w.ctx, - } - w.ctx = ctx -} - -func (w *writer) pop() { - if w.ctx.parent == nil { - panic("pop called at the top level") - } - w.ctx = w.ctx.parent -} - -func (w *writer) InStruct() bool { - return (w.ctx.value == inStructCtx) -} - -func (w *writer) Err() error { - return w.err -} - -func (w *writer) FieldName(val string) { - if w.err != nil { - return - } - if !w.InStruct() { - w.err = errors.New("field name called while not in a struct") - return +// NewTextWriter returns a new text writer. +func NewTextWriter(out io.Writer) Writer { + return &textWriter{ + writer: writer{ + out: out, + }, } - w.fieldName = val } -func (w *writer) TypeAnnotation(val string) { - if w.err != nil { - return - } - w.typeAnnotations = append(w.typeAnnotations, val) -} - -func (w *writer) TypeAnnotations(val ...string) { - if w.err != nil { - return - } - w.typeAnnotations = append(w.typeAnnotations, val...) -} - -func (w *writer) beginValue() error { +// beginValue begins the process of writing a value, by writing out +// a separator (if needed), field name (if in a struct), and type +// annotations (if any). +func (w *textWriter) beginValue() error { if w.needsSeparator { var sep byte - switch w.ctx.value { + switch w.ctx() { case inStructCtx, inListCtx: sep = ',' case inSexpCtx: @@ -107,7 +41,7 @@ func (w *writer) beginValue() error { sep = '\n' } - if err := writeChar(sep, w.out); err != nil { + if err := writeRawChar(sep, w.out); err != nil { return err } } @@ -122,7 +56,7 @@ func (w *writer) beginValue() error { if err := writeSymbol(name, w.out); err != nil { return err } - if err := writeChar(':', w.out); err != nil { + if err := writeRawChar(':', w.out); err != nil { return err } } @@ -135,7 +69,7 @@ func (w *writer) beginValue() error { if err := writeSymbol(a, w.out); err != nil { return err } - if err := writeString("::", w.out); err != nil { + if err := writeRawString("::", w.out); err != nil { return err } } @@ -144,11 +78,13 @@ func (w *writer) beginValue() error { return nil } -func (w *writer) endValue() { +// endValue finishes the process of writing a value. +func (w *textWriter) endValue() { w.needsSeparator = true } -func (w *writer) begin(t contextType, c byte) error { +// begin starts writing a container of the given type. +func (w *textWriter) begin(t ctxType, c byte) error { if err := w.beginValue(); err != nil { return err } @@ -156,15 +92,16 @@ func (w *writer) begin(t contextType, c byte) error { w.push(t) w.needsSeparator = false - return writeChar(c, w.out) + return writeRawChar(c, w.out) } -func (w *writer) end(t contextType, c byte) error { - if w.ctx.value != t { - return errors.New("not in an appropriate container") +// end finishes writing a container of the given type +func (w *textWriter) end(t ctxType, c byte) error { + if w.ctx() != t { + return errors.New("not in that kind of container") } - if err := writeChar(c, w.out); err != nil { + if err := writeRawChar(c, w.out); err != nil { return err } @@ -176,55 +113,63 @@ func (w *writer) end(t contextType, c byte) error { return nil } -func (w *writer) BeginStruct() { +// BeginStruct begins writing a struct. +func (w *textWriter) BeginStruct() { if w.err != nil { return } w.err = w.begin(inStructCtx, '{') } -func (w *writer) EndStruct() { +// EndStruct finishes writing a struct. +func (w *textWriter) EndStruct() { if w.err != nil { return } w.err = w.end(inStructCtx, '}') } -func (w *writer) BeginList() { +// BeginList begins writing a list. +func (w *textWriter) BeginList() { if w.err != nil { return } w.err = w.begin(inListCtx, '[') } -func (w *writer) EndList() { +// EndList finishes writing a list. +func (w *textWriter) EndList() { if w.err != nil { return } w.err = w.end(inListCtx, ']') } -func (w *writer) BeginSexp() { +// BeginSexp begins writing an s-expression. +func (w *textWriter) BeginSexp() { if w.err != nil { return } w.err = w.begin(inSexpCtx, '(') } -func (w *writer) EndSexp() { +// EndSexp finishes writing an s-expression. +func (w *textWriter) EndSexp() { if w.err != nil { return } w.err = w.end(inSexpCtx, ')') } -func (w *writer) writeValue(f func() string) error { +// writeValue writes a value whose raw encoding is produced by the +// given function. +func (w *textWriter) writeValue(f func() string) error { if err := w.beginValue(); err != nil { return err } sym := f() - if err := writeString(sym, w.out); err != nil { + if err := writeRawString(sym, w.out); err != nil { return err } @@ -232,7 +177,9 @@ func (w *writer) writeValue(f func() string) error { return nil } -func (w *writer) writeValueStreaming(f func() error) error { +// writeValue writes a value by calling the given function, which is +// expected to write the raw value to w.out. +func (w *textWriter) writeValueStreaming(f func() error) error { if err := w.beginValue(); err != nil { return err } @@ -245,41 +192,43 @@ func (w *writer) writeValueStreaming(f func() error) error { return nil } -func (w *writer) WriteNull() { - w.WriteNullWithType(ion.NullType) +// WriteNull writes an untyped null. +func (w *textWriter) WriteNull() { + w.WriteNullWithType(NullType) } -func (w *writer) WriteNullWithType(t ion.Type) { +// WriteNullWithType writes a typed null. +func (w *textWriter) WriteNullWithType(t Type) { if w.err != nil { return } w.err = w.writeValue(func() string { switch t { - case ion.NullType: + case NullType: return "null" - case ion.BoolType: + case BoolType: return "null.bool" - case ion.IntType: + case IntType: return "null.int" - case ion.FloatType: + case FloatType: return "null.float" - case ion.DecimalType: + case DecimalType: return "null.decimal" - case ion.TimestampType: + case TimestampType: return "null.timestamp" - case ion.StringType: + case StringType: return "null.string" - case ion.SymbolType: + case SymbolType: return "null.symbol" - case ion.BlobType: + case BlobType: return "null.blob" - case ion.ClobType: + case ClobType: return "null.clob" - case ion.StructType: + case StructType: return "null.struct" - case ion.ListType: + case ListType: return "null.list" - case ion.SexpType: + case SexpType: return "null.sexp" default: panic("invalid type") @@ -287,14 +236,8 @@ func (w *writer) WriteNullWithType(t ion.Type) { }) } -func symbolForBool(val bool) string { - if val { - return "true" - } - return "false" -} - -func (w *writer) WriteBool(val bool) { +// WriteBool writes a boolean value. +func (w *textWriter) WriteBool(val bool) { if w.err != nil { return } @@ -306,7 +249,8 @@ func (w *writer) WriteBool(val bool) { }) } -func (w *writer) WriteInt(val int64) { +// WriteInt writes an integer value. +func (w *textWriter) WriteInt(val int64) { if w.err != nil { return } @@ -315,7 +259,8 @@ func (w *writer) WriteInt(val int64) { }) } -func (w *writer) WriteBigInt(val *big.Int) { +// WriteBigInt writes a (big) integer value. +func (w *textWriter) WriteBigInt(val *big.Int) { if w.err != nil { return } @@ -324,12 +269,13 @@ func (w *writer) WriteBigInt(val *big.Int) { }) } -func (w *writer) WriteFloat(val float64) { +// WriteFloat writes a floating-point value. +func (w *textWriter) WriteFloat(val float64) { if w.err != nil { return } w.err = w.writeValue(func() string { - // Built-in go formatting isn't up to the task. :( + // Built-in go formatting isn't quite up to the task. :( str := strconv.FormatFloat(val, 'e', -1, 64) switch str { @@ -350,7 +296,8 @@ func (w *writer) WriteFloat(val float64) { }) } -func (w *writer) WriteDecimal(val *ion.Decimal) { +// WriteDecimal writes an arbitrary-precision decimal value. +func (w *textWriter) WriteDecimal(val *Decimal) { if w.err != nil { return } @@ -359,7 +306,8 @@ func (w *writer) WriteDecimal(val *ion.Decimal) { }) } -func (w *writer) WriteTimestamp(val time.Time) { +// WriteTimestamp writes a timestamp. +func (w *textWriter) WriteTimestamp(val time.Time) { if w.err != nil { return } @@ -368,7 +316,8 @@ func (w *writer) WriteTimestamp(val time.Time) { }) } -func (w *writer) WriteSymbol(val string) { +// WriteSymbol writes a symbol. +func (w *textWriter) WriteSymbol(val string) { if w.err != nil { return } @@ -377,27 +326,29 @@ func (w *writer) WriteSymbol(val string) { }) } -func (w *writer) WriteString(val string) { +// WriteString writes a string. +func (w *textWriter) WriteString(val string) { if w.err != nil { return } w.err = w.writeValueStreaming(func() error { - if err := writeChar('"', w.out); err != nil { + if err := writeRawChar('"', w.out); err != nil { return err } if err := writeEscapedString(val, w.out); err != nil { return err } - return writeChar('"', w.out) + return writeRawChar('"', w.out) }) } -func (w *writer) WriteBlob(val []byte) { +// WriteBlob writes a blob. +func (w *textWriter) WriteBlob(val []byte) { if w.err != nil { return } w.err = w.writeValueStreaming(func() error { - if err := writeString("{{", w.out); err != nil { + if err := writeRawString("{{", w.out); err != nil { return err } @@ -407,16 +358,17 @@ func (w *writer) WriteBlob(val []byte) { return err } - return writeString("}}", w.out) + return writeRawString("}}", w.out) }) } -func (w *writer) WriteClob(val []byte) { +// WriteClob writes a clob. +func (w *textWriter) WriteClob(val []byte) { if w.err != nil { return } w.err = w.writeValueStreaming(func() error { - if err := writeString("{{\"", w.out); err != nil { + if err := writeRawString("{{\"", w.out); err != nil { return err } @@ -426,26 +378,27 @@ func (w *writer) WriteClob(val []byte) { return err } } else { - if err := writeChar(c, w.out); err != nil { + if err := writeRawChar(c, w.out); err != nil { return err } } } - return writeString("\"}}", w.out) + return writeRawString("\"}}", w.out) }) } -func (w *writer) Finish() error { +// Finish finishes the current datagram. +func (w *textWriter) Finish() error { if w.err != nil { return w.err } - if w.ctx.value != topLevelCtx { + if w.ctx() != atTopLevelCtx { w.err = errors.New("not at top level") return w.err } - if w.err = writeChar('\n', w.out); w.err != nil { + if w.err = writeRawChar('\n', w.out); w.err != nil { return w.err } diff --git a/text/writer_test.go b/textwriter_test.go similarity index 79% rename from text/writer_test.go rename to textwriter_test.go index df4d9287..fc8bf244 100644 --- a/text/writer_test.go +++ b/textwriter_test.go @@ -1,4 +1,4 @@ -package text +package ion import ( "math" @@ -6,12 +6,10 @@ import ( "strings" "testing" "time" - - "github.com/fernomac/ion-go" ) func TestTopLevelFieldName(t *testing.T) { - writeText(func(w ion.Writer) { + writeText(func(w Writer) { w.FieldName("foo") if w.Err() == nil { t.Error("expected an error") @@ -20,7 +18,7 @@ func TestTopLevelFieldName(t *testing.T) { } func TestEmptyStruct(t *testing.T) { - testTextWriter(t, "{}", func(w ion.Writer) { + testTextWriter(t, "{}", func(w Writer) { if w.InStruct() { t.Error("already in struct") } @@ -51,7 +49,7 @@ func TestEmptyStruct(t *testing.T) { } func TestAnnotatedStruct(t *testing.T) { - testTextWriter(t, "foo::$bar::'.baz'::{}", func(w ion.Writer) { + testTextWriter(t, "foo::$bar::'.baz'::{}", func(w Writer) { w.TypeAnnotation("foo") w.TypeAnnotation("$bar") w.TypeAnnotation(".baz") @@ -65,7 +63,7 @@ func TestAnnotatedStruct(t *testing.T) { } func TestNestedStruct(t *testing.T) { - testTextWriter(t, "{foo:'true'::{},'null':{}}", func(w ion.Writer) { + testTextWriter(t, "{foo:'true'::{},'null':{}}", func(w Writer) { w.BeginStruct() w.FieldName("foo") @@ -82,7 +80,7 @@ func TestNestedStruct(t *testing.T) { } func TestEmptyList(t *testing.T) { - testTextWriter(t, "[]", func(w ion.Writer) { + testTextWriter(t, "[]", func(w Writer) { w.BeginList() if w.Err() != nil { t.Fatal(w.Err()) @@ -105,7 +103,7 @@ func TestEmptyList(t *testing.T) { } func TestNestedLists(t *testing.T) { - testTextWriter(t, "[{},foo::{},'null'::[]]", func(w ion.Writer) { + testTextWriter(t, "[{},foo::{},'null'::[]]", func(w Writer) { w.BeginList() w.BeginStruct() @@ -124,7 +122,7 @@ func TestNestedLists(t *testing.T) { } func TestSexps(t *testing.T) { - testTextWriter(t, "()\n(())\n(() ())", func(w ion.Writer) { + testTextWriter(t, "()\n(())\n(() ())", func(w Writer) { w.BeginSexp() w.EndSexp() @@ -144,15 +142,15 @@ func TestSexps(t *testing.T) { func TestNull(t *testing.T) { expected := "[null,foo::null,null.int,bar::null.sexp]" - testTextWriter(t, expected, func(w ion.Writer) { + testTextWriter(t, expected, func(w Writer) { w.BeginList() w.WriteNull() w.TypeAnnotation("foo") - w.WriteNullWithType(ion.NullType) - w.WriteNullWithType(ion.IntType) + w.WriteNullWithType(NullType) + w.WriteNullWithType(IntType) w.TypeAnnotation("bar") - w.WriteNullWithType(ion.SexpType) + w.WriteNullWithType(SexpType) w.EndList() }) @@ -160,7 +158,7 @@ func TestNull(t *testing.T) { func TestBool(t *testing.T) { expected := "true\n(false '123'::true)\n'false'::false" - testTextWriter(t, expected, func(w ion.Writer) { + testTextWriter(t, expected, func(w Writer) { w.WriteBool(true) w.BeginSexp() @@ -176,7 +174,7 @@ func TestBool(t *testing.T) { func TestInt(t *testing.T) { expected := "(zero::0 1 -1 (9223372036854775807 -9223372036854775808))" - testTextWriter(t, expected, func(w ion.Writer) { + testTextWriter(t, expected, func(w Writer) { w.BeginSexp() w.TypeAnnotation("zero"); w.WriteInt(0) @@ -194,7 +192,7 @@ func TestInt(t *testing.T) { func TestBigInt(t *testing.T) { expected := "[0,big::18446744073709551616]" - testTextWriter(t, expected, func(w ion.Writer) { + testTextWriter(t, expected, func(w Writer) { w.BeginList() w.WriteBigInt(big.NewInt(0)) @@ -212,7 +210,7 @@ func TestBigInt(t *testing.T) { func TestFloat(t *testing.T) { expected := "{z:0e+0,nz:-0e+0,s:1.234e+1,l:1.234e-55,n:nan,i:+inf,ni:-inf}" - testTextWriter(t, expected, func(w ion.Writer) { + testTextWriter(t, expected, func(w Writer) { w.BeginStruct() w.FieldName("z"); w.WriteFloat(0.0) @@ -231,15 +229,15 @@ func TestFloat(t *testing.T) { func TestDecimal(t *testing.T) { expected := "0.\n-1.23d-98" - testTextWriter(t, expected, func(w ion.Writer) { - w.WriteDecimal(ion.NewDecimal(big.NewInt(0))) - w.WriteDecimal(ion.NewDecimalWithScale(big.NewInt(-123), 100)) + testTextWriter(t, expected, func(w Writer) { + w.WriteDecimal(MustParseDecimal("0")) + w.WriteDecimal(MustParseDecimal("-123d-100")) }) } func TestTimestamp(t *testing.T) { expected := "1970-01-01T00:00:00.001Z\n1970-01-01T01:23:00+01:23" - testTextWriter(t, expected, func (w ion.Writer) { + testTextWriter(t, expected, func (w Writer) { w.WriteTimestamp(time.Unix(0, 1000000).In(time.UTC)) w.WriteTimestamp(time.Unix(0, 0).In(time.FixedZone("wtf", 4980))) }) @@ -247,7 +245,7 @@ func TestTimestamp(t *testing.T) { func TestSymbol(t *testing.T) { expected := "{foo:bar,empty:'','null':'null',f:a::b::u::'lo🇺🇸'}" - testTextWriter(t, expected, func(w ion.Writer) { + testTextWriter(t, expected, func(w Writer) { w.BeginStruct() w.FieldName("foo"); w.WriteSymbol("bar") @@ -264,7 +262,7 @@ func TestSymbol(t *testing.T) { func TestString(t *testing.T) { expected := `("hello" "" ("\\\"\n\"\\" zany::"🤪"))` - testTextWriter(t, expected, func(w ion.Writer) { + testTextWriter(t, expected, func(w Writer) { w.BeginSexp() w.WriteString("hello") w.WriteString("") @@ -280,7 +278,7 @@ func TestString(t *testing.T) { func TestBlob(t *testing.T) { expected := "{{AAEC/f7/}}\n{{SGVsbG8gV29ybGQ=}}\nempty::{{}}" - testTextWriter(t, expected, func(w ion.Writer) { + testTextWriter(t, expected, func(w Writer) { w.WriteBlob([]byte{ 0, 1, 2, 0xFD, 0xFE, 0xFF }) w.WriteBlob([]byte("Hello World")) w.TypeAnnotation("empty"); w.WriteBlob(nil) @@ -289,7 +287,7 @@ func TestBlob(t *testing.T) { func TestClob(t *testing.T) { expected := "{hello:{{\"world\"}},bits:{{\"\\0\\x01\\xFE\\xFF\"}}}" - testTextWriter(t, expected, func(w ion.Writer) { + testTextWriter(t, expected, func(w Writer) { w.BeginStruct() w.FieldName("hello"); w.WriteClob([]byte("world")) w.FieldName("bits"); w.WriteClob([]byte{0,1,0xFE,0xFF}) @@ -299,7 +297,7 @@ func TestClob(t *testing.T) { func TestFinish(t *testing.T) { expected := "1\nfoo\n\"bar\"\n{}\n" - testTextWriter(t, expected, func(w ion.Writer) { + testTextWriter(t, expected, func(w Writer) { w.WriteInt(1) w.WriteSymbol("foo") w.WriteString("bar") @@ -312,7 +310,7 @@ func TestFinish(t *testing.T) { func TestBadFinish(t *testing.T) { buf := strings.Builder{} - w := NewWriter(&buf) + w := NewTextWriter(&buf) w.BeginStruct() err := w.Finish() @@ -322,16 +320,16 @@ func TestBadFinish(t *testing.T) { } } -func testTextWriter(t *testing.T, expected string, f func(ion.Writer)) { +func testTextWriter(t *testing.T, expected string, f func(Writer)) { actual := writeText(f) if actual != expected { t.Errorf("expected: %v, actual: %v", expected, actual) } } -func writeText(f func(ion.Writer)) string { +func writeText(f func(Writer)) string { buf := strings.Builder{} - w := NewWriter(&buf) + w := NewTextWriter(&buf) f(w) diff --git a/writer.go b/writer.go new file mode 100644 index 00000000..5585db1e --- /dev/null +++ b/writer.go @@ -0,0 +1,97 @@ +package ion + +import ( + "errors" + "io" +) + +type ctxType byte + +const ( + atTopLevelCtx ctxType = iota + inStructCtx + inListCtx + inSexpCtx +) + +// writer holds shared stuff for all writers. +type writer struct { + out io.Writer + ctxArr []ctxType + err error + + fieldName string + typeAnnotations []string +} + +// InStruct returns true if we're currently writing a struct. +func (w *writer) InStruct() bool { + return w.ctx() == inStructCtx +} + +// InList returns true if we're currently writing a list. +func (w *writer) InList() bool { + return w.ctx() == inListCtx +} + +// InSexp returns true if we're currently writing an s-expression. +func (w *writer) InSexp() bool { + return w.ctx() == inSexpCtx +} + +// Err returns the current error, or nil if there are none yet. +func (w *writer) Err() error { + return w.err +} + + +// FieldName sets the field name for the next value written. +// It may only be called while writing a struct. +func (w *writer) FieldName(val string) { + if w.err != nil { + return + } + if !w.InStruct() { + w.err = errors.New("FieldName() called but not writing a struct") + return + } + w.fieldName = val +} + +// TypeAnnotation adds a type annotation to the next value written. +func (w *writer) TypeAnnotation(val string) { + if w.err != nil { + return + } + w.typeAnnotations = append(w.typeAnnotations, val) +} + +// TypeAnnotations adds one or more type annotations to the next value +// written. +func (w *writer) TypeAnnotations(val ...string) { + if w.err != nil { + return + } + w.typeAnnotations = append(w.typeAnnotations, val...) +} + +// ctx returns the current writing context +func (w *writer) ctx() ctxType { + if len(w.ctxArr) == 0 { + return atTopLevelCtx + } + return w.ctxArr[len(w.ctxArr)-1] +} + +// push pushes a new writing context when a new container is begun. +func (w *writer) push(ctx ctxType) { + w.ctxArr = append(w.ctxArr, ctx) +} + +// pop pops the writing context when a container is ended. +func (w *writer) pop() { + if len(w.ctxArr) == 0 { + panic("pop called at top level") + } + w.ctxArr = w.ctxArr[:len(w.ctxArr)-1] +} From b5b2778d713dc60b0ae32c1d86c833e441865e47 Mon Sep 17 00:00:00 2001 From: David Murray Date: Wed, 12 Jun 2019 20:38:11 +1000 Subject: [PATCH 06/56] fine, go fmt --- textutils.go | 2 +- textwriter.go | 14 +++++---- textwriter_test.go | 71 ++++++++++++++++++++++++++++++---------------- writer.go | 7 ++--- 4 files changed, 59 insertions(+), 35 deletions(-) diff --git a/textutils.go b/textutils.go index 0d0973e9..9606eeaf 100644 --- a/textutils.go +++ b/textutils.go @@ -104,7 +104,7 @@ func writeEscapedSymbol(sym string, out io.Writer) error { // Write the given string out, escaping any characters that need escaping. func writeEscapedString(str string, out io.Writer) error { - for i := 0; i Date: Sat, 15 Jun 2019 16:43:36 +1000 Subject: [PATCH 07/56] symboltable, reader interface --- api.go | 33 ++++- symboltable.go | 326 ++++++++++++++++++++++++++++++++++++++++++++ symboltable_test.go | 170 +++++++++++++++++++++++ 3 files changed, 528 insertions(+), 1 deletion(-) create mode 100644 symboltable.go create mode 100644 symboltable_test.go diff --git a/api.go b/api.go index 7c043045..22676803 100644 --- a/api.go +++ b/api.go @@ -9,8 +9,10 @@ import ( type Type uint8 const ( + // NoType is returned by a Reader that's not currently pointing at a value. + NoType Type = iota // NullType is the type of the (unqualified) null value. - NullType Type = iota + NullType // BoolType is the type of a boolean, true or false. BoolType // IntType is the type of a signed integer of arbitrary size. @@ -37,9 +39,38 @@ const ( SexpType ) +// A Reader reads Ion values from an input stream. +type Reader interface { + SymbolTable() SymbolTable + + Next() (Type, error) + Type() Type + IsNull() bool + + StepIn() error + StepOut() error + + FieldName() (string, error) + TypeAnnotations() ([]string, error) + + BoolValue() (bool, error) + IntValue() (int, error) + Int64Value() (int64, error) + BigIntValue() (*big.Int, error) + FloatValue() (float64, error) + DecimalValue() (*Decimal, error) + + TimeValue() (time.Time, error) + StringValue() (string, error) + + ByteValue() ([]byte, error) +} + // A Writer writes Ion values to an output stream. type Writer interface { InStruct() bool + InList() bool + InSexp() bool Err() error FieldName(val string) diff --git a/symboltable.go b/symboltable.go new file mode 100644 index 00000000..2f6d5cab --- /dev/null +++ b/symboltable.go @@ -0,0 +1,326 @@ +package ion + +import ( + "strings" +) + +// A SymbolTable maps binary-representation symbol IDs to +// text-representation strings and vice versa. +type SymbolTable interface { + // MaxID returns the maximum ID this symbol table defines. + MaxID() int + // FindByName finds the ID of a symbol by its name. + FindByName(symbol string) (int, bool) + // FindByID finds the name of a symbol given its ID. + FindByID(id int) (string, bool) + // WriteTo serializes the symbol table to an ion.Writer. + WriteTo(w Writer) error + // String returns an ion text representation of the symbol table. + String() string +} + +// A SharedSymbolTable is distributed out-of-band and referenced from +// a LocalSymbolTable to save space. +type SharedSymbolTable struct { + name string + version int + symbols []string + index map[string]int +} + +var _ SymbolTable = &SharedSymbolTable{} + +// NewSharedSymbolTable creates a new shared symbol table. +func NewSharedSymbolTable(name string, version int, symbols []string) *SharedSymbolTable { + if name == "" { + panic("name must be non-empty") + } + if version < 1 { + panic("version must be at least one") + } + + index, copy := buildIndex(symbols, 0) + + return &SharedSymbolTable{ + name: name, + version: version, + symbols: copy, + index: index, + } +} + +func buildIndex(symbols []string, offset int) (map[string]int, []string) { + index := map[string]int{} + copy := []string{} + + for _, sym := range symbols { + if _, ok := index[sym]; !ok { + copy = append(copy, sym) + index[sym] = offset + len(copy) + } + } + + return index, copy +} + +func (s *SharedSymbolTable) Name() string { + return s.name +} + +func (s *SharedSymbolTable) Version() int { + return s.version +} + +func (s *SharedSymbolTable) MaxID() int { + return len(s.symbols) +} + +func (s *SharedSymbolTable) FindByName(sym string) (int, bool) { + id, ok := s.index[sym] + return id, ok +} + +func (s *SharedSymbolTable) FindByID(id int) (string, bool) { + if id <= 0 || id > len(s.symbols) { + return "", false + } + return s.symbols[id-1], true +} + +func (s *SharedSymbolTable) WriteTo(w Writer) error { + w.TypeAnnotation("$ion_shared_symbol_table") + w.BeginStruct() + + w.FieldName("name") + w.WriteString(s.name) + + w.FieldName("version") + w.WriteInt(int64(s.version)) + + w.FieldName("symbols") + w.BeginList() + + for _, sym := range s.symbols { + w.WriteString(sym) + } + + w.EndList() // symbols + + w.EndStruct() + return w.Err() +} + +func (s *SharedSymbolTable) String() string { + buf := strings.Builder{} + + w := NewTextWriter(&buf) + s.WriteTo(w) + + return buf.String() +} + +// The (implied) system symbol table for Ion v1.0. +var V1SystemSymbolTable = NewSharedSymbolTable("$ion", 1, []string{ + "$ion", + "$ion_1_0", + "$ion_symbol_table", + "name", + "version", + "imports", + "symbols", + "max_id", + "$ion_shared_symbol_table", +}) + +// A LocalSymbolTable is transmitted in-band along with the binary data +// it describes. It may include SharedSymbolTables by reference. +type LocalSymbolTable struct { + imports []*SharedSymbolTable + offsets []int + maxImportID int + + symbols []string + index map[string]int +} + +var _ SymbolTable = &LocalSymbolTable{} + +// NewLocalSymbolTable creates a new local symbol table. +func NewLocalSymbolTable(imports []*SharedSymbolTable, symbols []string) *LocalSymbolTable { + imps, offsets, maxID := processImports(imports) + index, copy := buildIndex(symbols, maxID) + + return &LocalSymbolTable{ + imports: imps, + offsets: offsets, + maxImportID: maxID, + symbols: copy, + index: index, + } +} + +func processImports(imports []*SharedSymbolTable) ([]*SharedSymbolTable, []int, int) { + imps := append([]*SharedSymbolTable{}, imports...) + + // TODO: Automatically add V1SystemSymbolTable? + + maxID := 0 + offsets := make([]int, len(imps)) + for i, imp := range imps { + offsets[i] = maxID + maxID += imp.MaxID() + } + + return imps, offsets, maxID +} + +func (t *LocalSymbolTable) MaxID() int { + return t.maxImportID + len(t.symbols) +} + +func (t *LocalSymbolTable) FindByName(s string) (int, bool) { + for i, imp := range t.imports { + if id, ok := imp.FindByName(s); ok { + return t.offsets[i] + id, true + } + } + + if id, ok := t.index[s]; ok { + return id, true + } + + return 0, false +} + +func (t *LocalSymbolTable) FindByID(id int) (string, bool) { + if id <= 0 { + return "", false + } + if id <= t.maxImportID { + return t.findByIDInImports(id) + } + + // Local to this symbol table. + idx := id - t.maxImportID - 1 + if idx < len(t.symbols) { + return t.symbols[idx], true + } + + return "", false +} + +func (t *LocalSymbolTable) findByIDInImports(id int) (string, bool) { + i := 1 + off := 0 + + for ; i < len(t.imports); i++ { + if id <= t.offsets[i] { + break + } + off = t.offsets[i] + } + + return t.imports[i-1].FindByID(id - off) +} + +func (t *LocalSymbolTable) WriteTo(w Writer) error { + w.TypeAnnotation("$ion_symbol_table") + w.BeginStruct() + + if len(t.imports) > 0 { + w.FieldName("imports") + w.BeginList() + for _, imp := range t.imports { + w.BeginStruct() + + w.FieldName("name") + w.WriteString(imp.Name()) + + w.FieldName("version") + w.WriteInt(int64(imp.Version())) + + w.FieldName("max_id") + w.WriteInt(int64(imp.MaxID())) + + w.EndStruct() + } + w.EndList() + } + + if len(t.symbols) > 0 { + w.FieldName("symbols") + + w.BeginList() + for _, sym := range t.symbols { + w.WriteString(sym) + } + + w.EndList() + } + + w.EndStruct() + return w.Err() +} + +func (t *LocalSymbolTable) String() string { + buf := strings.Builder{} + + w := NewTextWriter(&buf) + t.WriteTo(w) + + return buf.String() +} + +// A SymbolTableBuilder helps you iteratively build a local symbol table. +type SymbolTableBuilder interface { + SymbolTable + + // Add adds a symbol to this symbol table. + Add(symbol string) (int, bool) + // Build creates an immutable local symbol table. + Build() *LocalSymbolTable +} + +type symbolTableBuilder struct { + LocalSymbolTable +} + +func NewSymbolTableBuilder(imports ...*SharedSymbolTable) SymbolTableBuilder { + imps, offsets, maxID := processImports(imports) + return &symbolTableBuilder{ + LocalSymbolTable{ + imports: imps, + offsets: offsets, + maxImportID: maxID, + index: make(map[string]int), + }, + } +} + +func (b *symbolTableBuilder) Add(symbol string) (int, bool) { + if id, ok := b.FindByName(symbol); ok { + return id, false + } + + b.symbols = append(b.symbols, symbol) + id := b.maxImportID + len(b.symbols) + b.index[symbol] = id + + return id, true +} + +func (b *symbolTableBuilder) Build() *LocalSymbolTable { + symbols := append([]string{}, b.symbols...) + index := make(map[string]int) + for s, i := range b.index { + index[s] = i + } + + return &LocalSymbolTable{ + imports: b.imports, + offsets: b.offsets, + maxImportID: b.maxImportID, + symbols: symbols, + index: index, + } +} diff --git a/symboltable_test.go b/symboltable_test.go new file mode 100644 index 00000000..20fd2320 --- /dev/null +++ b/symboltable_test.go @@ -0,0 +1,170 @@ +package ion + +import ( + "fmt" + "testing" +) + +func TestSharedSymbolTable(t *testing.T) { + st := NewSharedSymbolTable("test", 2, []string{ + "abc", + "def", + "foo'bar", + "null", + "def", + "ghi", + }) + + if st.Name() != "test" { + t.Errorf("wrong name: %v", st.Name()) + } + if st.Version() != 2 { + t.Errorf("wrong version: %v", st.Version()) + } + if st.MaxID() != 5 { + t.Errorf("wrong maxid: %v", st.MaxID()) + } + + testFindByName(t, st, "def", 2) + testFindByName(t, st, "null", 4) + testFindByName(t, st, "bogus", 0) + + testFindByID(t, st, 0, "") + testFindByID(t, st, 2, "def") + testFindByID(t, st, 4, "null") + testFindByID(t, st, 6, "") + + testString(t, st, `$ion_shared_symbol_table::{name:"test",version:2,symbols:["abc","def","foo'bar","null","ghi"]}`) +} + +func TestLocalSymbolTable(t *testing.T) { + st := NewLocalSymbolTable(nil, []string{"foo", "bar"}) + + if st.MaxID() != 2 { + t.Errorf("wrong maxid: %v", st.MaxID()) + } + + testFindByName(t, st, "$ion", 0) + testFindByName(t, st, "foo", 1) + testFindByName(t, st, "bar", 2) + testFindByName(t, st, "bogus", 0) + + testFindByID(t, st, 0, "") + testFindByID(t, st, 1, "foo") + testFindByID(t, st, 2, "bar") + testFindByID(t, st, 3, "") + + testString(t, st, `$ion_symbol_table::{symbols:["foo","bar"]}`) +} + +func TestLocalSymbolTableWithImports(t *testing.T) { + imports := []*SharedSymbolTable{V1SystemSymbolTable} + st := NewLocalSymbolTable(imports, []string{ + "foo", + "bar", + }) + + if st.MaxID() != 11 { // 9 from $ion.1, 2 local. + t.Errorf("wrong maxid: %v", st.MaxID()) + } + + testFindByName(t, st, "$ion", 1) + testFindByName(t, st, "$ion_shared_symbol_table", 9) + testFindByName(t, st, "foo", 10) + testFindByName(t, st, "bogus", 0) + + testFindByID(t, st, 0, "") + testFindByID(t, st, 1, "$ion") + testFindByID(t, st, 9, "$ion_shared_symbol_table") + testFindByID(t, st, 10, "foo") + testFindByID(t, st, 11, "bar") + testFindByID(t, st, 12, "") + + testString(t, st, `$ion_symbol_table::{imports:[{name:"$ion",version:1,max_id:9}],symbols:["foo","bar"]}`) +} + +func TestSymbolTableBuilder(t *testing.T) { + b := NewSymbolTableBuilder(V1SystemSymbolTable) + + id, ok := b.Add("name") + if ok { + t.Error("Add(name) returned true") + } + if id != 4 { + t.Errorf("Add(name) returned %v", id) + } + + id, ok = b.Add("foo") + if !ok { + t.Error("Add(foo) returned false") + } + if id != 10 { + t.Errorf("Add(foo) returned %v", id) + } + + id, ok = b.Add("foo") + if ok { + t.Error("Second Add(foo) returned true") + } + if id != 10 { + t.Errorf("Second Add(foo) returned %v", id) + } + + st := b.Build() + if st.MaxID() != 10 { + t.Errorf("maxid returned %v", st.MaxID()) + } + + testFindByName(t, st, "$ion", 1) + testFindByName(t, st, "foo", 10) + testFindByName(t, st, "bogus", 0) + + testFindByID(t, st, 1, "$ion") + testFindByID(t, st, 10, "foo") + testFindByID(t, st, 11, "") +} + +func testFindByName(t *testing.T, st SymbolTable, sym string, expected int) { + t.Run("FindByName("+sym+")", func(t *testing.T) { + actual, ok := st.FindByName(sym) + if expected == 0 { + if ok { + t.Fatalf("unexpectedly found: %v", actual) + } + } else { + if !ok { + t.Fatal("unexpectedly not found") + } + if actual != expected { + t.Errorf("expected %v, got %v", expected, actual) + } + } + }) +} + +func testFindByID(t *testing.T, st SymbolTable, id int, expected string) { + t.Run(fmt.Sprintf("FindByID(%v)", id), func(t *testing.T) { + actual, ok := st.FindByID(id) + if expected == "" { + if ok { + t.Fatalf("unexpectedly found: %v", actual) + } + } else { + if !ok { + t.Fatal("unexpectedly not found") + } + if actual != expected { + t.Errorf("expected %v, got %v", expected, actual) + } + } + }) +} + +func testString(t *testing.T, st SymbolTable, expected string) { + t.Run("String()", func(t *testing.T) { + actual := st.String() + if actual != expected { + t.Errorf("expected %v, got %v", expected, actual) + } + }) +} From 7b32f6c254f65605b235b9457da917ae797f0c6c Mon Sep 17 00:00:00 2001 From: David Murray Date: Fri, 21 Jun 2019 08:48:57 +0300 Subject: [PATCH 08/56] start of tokenizer --- skipper.go | 769 ++++++++++++++++++++++++++++++++++++++++++++++ skipper_test.go | 178 +++++++++++ textutils.go | 62 +++- tokenizer.go | 547 +++++++++++++++++++++++++++++++++ tokenizer_test.go | 412 +++++++++++++++++++++++++ 5 files changed, 1962 insertions(+), 6 deletions(-) create mode 100644 skipper.go create mode 100644 skipper_test.go create mode 100644 tokenizer.go create mode 100644 tokenizer_test.go diff --git a/skipper.go b/skipper.go new file mode 100644 index 00000000..f9e8722c --- /dev/null +++ b/skipper.go @@ -0,0 +1,769 @@ +package ion + +import ( + "fmt" + "io" +) + +// SkipValue skips to the end of the current value, if the caller +// didn't bother to consume it before calling Next again. +func (t *tokenizer) skipValue() (int, error) { + var c int + var err error + + switch t.token { + case tokenNumeric, tokenInt, tokenDecimal, tokenFloat: + c, err = t.skipNumber() + case tokenBinary: + c, err = t.skipBinary() + case tokenHex: + c, err = t.skipHex() + case tokenTimestamp: + c, err = t.skipTimestamp() + case tokenSymbol: + c, err = t.skipSymbol() + case tokenSymbolQuoted: + c, err = t.skipSymbolQuoted() + case tokenSymbolOperator: + c, err = t.skipSymbolOperator() + case tokenString: + c, err = t.skipString() + case tokenLongString: + c, err = t.skipLongString() + case tokenOpenDoubleBrace: + c, err = t.skipBlob() + case tokenOpenBrace: + c, err = t.skipStruct() + case tokenOpenParen: + c, err = t.skipSexp() + case tokenOpenBracket: + c, err = t.skipList() + default: + err = fmt.Errorf("skipValue called with token=%v", t.token) + } + + if err != nil { + return 0, err + } + + if isWhitespace(c) { + c, _, err = t.skipWhitespace() + if err != nil { + return 0, err + } + } + + t.unfinished = false + return c, nil +} + +// SkipNumber skips a (non-binary, non-hex) number. +func (t *tokenizer) skipNumber() (int, error) { + c, err := t.read() + if err != nil { + return 0, err + } + + if c == '-' { + c, err = t.read() + if err != nil { + return 0, err + } + } + + c, err = t.skipDigits(c) + if err != nil { + return 0, err + } + + if c == '.' { + c, err = t.read() + if err != nil { + return 0, err + } + c, err = t.skipDigits(c) + if err != nil { + return 0, err + } + } + + if c == 'd' || c == 'D' || c == 'e' || c == 'E' { + c, err = t.read() + if err != nil { + return 0, err + } + if c == '+' || c == '-' { + c, err = t.read() + if err != nil { + return 0, err + } + } + c, err = t.skipDigits(c) + if err != nil { + return 0, err + } + } + + ok, err := t.isStopChar(c) + if err != nil { + return 0, err + } + if !ok { + return 0, invalidChar(c) + } + return c, nil +} + +// SkipBinary skips a binary literal value. +func (t *tokenizer) skipBinary() (int, error) { + isB := func(c int) bool { + return c == 'b' || c == 'B' + } + isBinaryDigit := func(c int) bool { + return c == '0' || c == '1' + } + return t.skipRadix(isB, isBinaryDigit) +} + +// SkipHex skips a hex value. +func (t *tokenizer) skipHex() (int, error) { + isX := func(c int) bool { + return c == 'x' || c == 'X' + } + return t.skipRadix(isX, isHexDigit) +} + +func (t *tokenizer) skipRadix(pok, dok matcher) (int, error) { + c, err := t.read() + if err != nil { + return 0, err + } + + if c == '-' { + c, err = t.read() + if err != nil { + return 0, err + } + } + + if c != '0' { + return 0, invalidChar(c) + } + if err = t.expect(pok); err != nil { + return 0, err + } + + for { + c, err = t.read() + if err != nil { + return 0, err + } + if !dok(c) { + break + } + } + + ok, err := t.isStopChar(c) + if err != nil { + return 0, err + } + if !ok { + return 0, invalidChar(c) + } + + return c, nil +} + +// SkipTimestamp skips a timestamp value, returning the next character. +func (t *tokenizer) skipTimestamp() (int, error) { + // Read the first four digits, yyyy. + c, err := t.skipTimestampDigits(4) + if err != nil { + return 0, err + } + if c == 'T' { + // yyyyT + return t.read() + } + if c != '-' { + return 0, invalidChar(c) + } + + // Read the next two, yyyy-mm. + c, err = t.skipTimestampDigits(2) + if err != nil { + return 0, err + } + if c == 'T' { + // yyyy-mmT + return t.read() + } + if c != '-' { + return 0, invalidChar(c) + } + + // Read the day. + c, err = t.skipTimestampDigits(2) + if err != nil { + return 0, err + } + if c != 'T' { + // yyyy-mm-dd. + return t.skipTimestampFinish(c) + } + + c, err = t.read() + if err != nil { + return 0, err + } + if !isDigit(c) { + // yyyy-mm-ddT(+hh:mm)? + c, err = t.skipTimestampOffset(c) + if err != nil { + return 0, err + } + return t.skipTimestampFinish(c) + } + + // Already read the first hour digit above. + c, err = t.skipTimestampDigits(1) + if err != nil { + return 0, err + } + if c != ':' { + return 0, invalidChar(c) + } + + c, err = t.skipTimestampDigits(2) + if err != nil { + return 0, err + } + if c != ':' { + // yyyy-mm-ddThh:mmZ + c, err = t.skipTimestampOffsetOrZ(c) + if err != nil { + return 0, err + } + return t.skipTimestampFinish(c) + } + + c, err = t.skipTimestampDigits(2) + if err != nil { + return 0, err + } + if c != '.' { + // yyyy-mm-ddThh:mm:ssZ + c, err = t.skipTimestampOffsetOrZ(c) + if err != nil { + return 0, err + } + return t.skipTimestampFinish(c) + } + + // yyyy-mm-ddThh:mm:ss.ssssZ + c, err = t.read() + if err != nil { + return 0, err + } + if isDigit(c) { + c, err = t.skipDigits(c) + if err != nil { + return 0, err + } + } + + c, err = t.skipTimestampOffsetOrZ(c) + if err != nil { + return 0, err + } + return t.skipTimestampFinish(c) +} + +// SkipTimestampOffsetOrZ skips a (required) timestamp offset value or +// letter 'Z' (indicating UTC). +func (t *tokenizer) skipTimestampOffsetOrZ(c int) (int, error) { + if c == '-' || c == '+' { + return t.skipTimestampOffset(c) + } + if c == 'z' || c == 'Z' { + return t.read() + } + return 0, invalidChar(c) +} + +// SkipTimestampOffset skips an (optional) +-hh:mm timestamp zone offset +// value. +func (t *tokenizer) skipTimestampOffset(c int) (int, error) { + if c != '-' && c != '+' { + return c, nil + } + + c, err := t.skipTimestampDigits(2) + if err != nil { + return 0, err + } + if c != ':' { + return 0, invalidChar(c) + } + return t.skipTimestampDigits(2) +} + +// SkipTimestampDigits skips a bounded sequence of digits inside a +// timestamp. +func (t *tokenizer) skipTimestampDigits(n int) (int, error) { + for n > 0 { + if err := t.expect(func(c int) bool { + return isDigit(c) + }); err != nil { + return 0, err + } + n-- + } + + return t.read() +} + +// SkipTimestampFinish makes sure the character after a timestamp +// value is a valid ending point. If so, it returns it. +func (t *tokenizer) skipTimestampFinish(c int) (int, error) { + ok, err := t.isStopChar(c) + if err != nil { + return 0, err + } + if !ok { + return 0, invalidChar(c) + } + return c, nil +} + +// SkipSymbol skips a normal symbol and returns the next character. +func (t *tokenizer) skipSymbol() (int, error) { + c, err := t.read() + if err != nil { + return 0, err + } + + for isIdentifierPart(c) { + c, err = t.read() + if err != nil { + return 0, err + } + } + + return c, nil +} + +// SkipSymbolQuoted skips a quoted symbol and returns the next char. +func (t *tokenizer) skipSymbolQuoted() (int, error) { + if err := t.skipSymbolQuotedHelper(); err != nil { + return 0, err + } + return t.read() +} + +// SkipSymbolQuotedHelper skips a quoted symbol. +func (t *tokenizer) skipSymbolQuotedHelper() error { + for { + c, err := t.read() + if err != nil { + return err + } + + switch c { + case -1, '\n': + return invalidChar(c) + + case '\'': + return nil + + case '\\': + if _, err := t.read(); err != nil { + return err + } + } + } +} + +// SkipSymbolOperator skips an operator-style symbol inside an sexp. +func (t *tokenizer) skipSymbolOperator() (int, error) { + c, err := t.read() + if err != nil { + return 0, err + } + + for isOperatorChar(c) { + c, err = t.read() + if err != nil { + return 0, err + } + } + + return c, nil +} + +// SkipString skips over a "-enclosed string, returning the next char. +func (t *tokenizer) skipString() (int, error) { + if err := t.skipStringHelper(); err != nil { + return 0, err + } + return t.read() +} + +// SkipStringHelper skips over a "-enclosed string. +func (t *tokenizer) skipStringHelper() error { + for { + c, err := t.read() + if err != nil { + return err + } + + switch c { + case -1, '\n': + return invalidChar(c) + + case '"': + return nil + + case '\\': + if _, err := t.read(); err != nil { + return err + } + } + } +} + +// SkipLongString skips over a '''-enclosed string, returning the next +// character after the closing '''. +func (t *tokenizer) skipLongString() (int, error) { + if err := t.skipLongStringHelper(t.skipCommentsHandler); err != nil { + return 0, err + } + return t.read() +} + +// SkipLongStringHelper skips over a '''-enclosed string. +func (t *tokenizer) skipLongStringHelper(handler commentHandler) error { + for { + c, err := t.read() + if err != nil { + return err + } + + switch c { + case -1: + return invalidChar(c) + + case '\'': + ok, err := t.skipEndOfLongString(handler) + if err != nil { + return err + } + if ok { + return nil + } + + case '\\': + if _, err = t.read(); err != nil { + return err + } + } + } +} + +// SkipEndOfLongString is called after reading a ' to determine if we've +// hit the end of the long string.. +func (t *tokenizer) skipEndOfLongString(handler commentHandler) (bool, error) { + // We just read a ', check for two more ''s. + cs, err := t.peekN(2) + if err != nil && err != io.EOF { + return false, err + } + + // If it's not a triple-quote, keep going. + if len(cs) < 2 || cs[0] != '\'' || cs[1] != '\'' { + return false, nil + } + + // Consume the triple-quote. + if err := t.skipN(2); err != nil { + return false, err + } + + // Consume any additional whitespace/comments. + c, _, err := t.skipWhitespaceWith(handler) + if err != nil { + return false, err + } + + // Check if it's another triple-quote; if so, keep going. + if c == '\'' { + ok, err := t.isTripleQuote() + if err != nil { + return false, err + } + if ok { + return false, nil + } + } + + t.unread(c) + return true, nil +} + +// SkipBlob skips over a blob value, returning the next character. +func (t *tokenizer) skipBlob() (int, error) { + if err := t.skipBlobHelper(); err != nil { + return 0, err + } + return t.read() +} + +// SkipBlobHelper skips over a blob value, stopping after reading the +// final '}'. +func (t *tokenizer) skipBlobHelper() error { + c, _, err := t.skipLobWhitespace() + if err != nil { + return err + } + + // TODO: If this is a clob, could we potentially have an embedded + // '}' here? + for c != '}' { + c, _, err = t.skipLobWhitespace() + if err != nil { + return err + } + if c == -1 { + return invalidChar(c) + } + } + + return t.expect(func(c int) bool { + return c == '}' + }) +} + +func (t *tokenizer) skipStruct() (int, error) { + return t.skipContainer('}') +} + +func (t *tokenizer) skipSexp() (int, error) { + return t.skipContainer(')') +} + +// SkipList skips forward past a list that the caller doesn't care to +// step in to. +func (t *tokenizer) skipList() (int, error) { + return t.skipContainer(']') +} + +// SkipContainer skips a container terminated by the given char and +// returns the next character. +func (t *tokenizer) skipContainer(term int) (int, error) { + if err := t.skipContainerHelper(term); err != nil { + return 0, err + } + return t.read() +} + +// SkipContainerHelper skips over a container terminated by the given +// char. +func (t *tokenizer) skipContainerHelper(term int) error { + if term != ']' && term != ')' && term != '}' { + panic("wat") + } + + for { + c, _, err := t.skipWhitespace() + if err != nil { + return err + } + + switch c { + case -1: + return invalidChar(c) + + case term: + return nil + + case '"': + if err := t.skipStringHelper(); err != nil { + return err + } + + case '\'': + ok, err := t.isTripleQuote() + if err != nil { + return err + } + if ok { + if err = t.skipLongStringHelper(t.skipCommentsHandler); err != nil { + return err + } + } else { + if err = t.skipSymbolQuotedHelper(); err != nil { + return err + } + } + + case '(': + if err := t.skipContainerHelper(')'); err != nil { + return err + } + + case '[': + if err := t.skipContainerHelper(']'); err != nil { + return err + } + + case '{': + c, err := t.peek() + if err != nil { + return err + } + + if c == '{' { + if _, err := t.read(); err != nil { + return err + } + if err := t.skipBlobHelper(); err != nil { + return err + } + } else if c == '}' { + if _, err := t.read(); err != nil { + return err + } + } else { + if err := t.skipContainerHelper('}'); err != nil { + return err + } + } + } + } +} + +// SkipDigits skips a sequence of digits starting with the +// given character. +func (t *tokenizer) skipDigits(c int) (int, error) { + var err error + for err == nil && isDigit(c) { + c, err = t.read() + } + return c, err +} + +// SkipWhitespace skips whitespace (and comments) when we're out +// in normal parsing territory. +func (t *tokenizer) skipWhitespace() (int, bool, error) { + return t.skipWhitespaceWith(t.skipCommentsHandler) +} + +// SkipLobWhitespace skips whitespace when we're inside a large +// object ({{ ///= }} or {{ '''///=''' }}) where comments are +// not allowed. +func (t *tokenizer) skipLobWhitespace() (int, bool, error) { + // Comments are not allowed inside a lob value; if we see a '/', + // it's the start of a base64-encoded value. + return t.skipWhitespaceWith(stopForCommentsHandler) +} + +// CommentHandler is a strategy for handling comments. Returns true +// if it found and handled a comment, false if it didn't find a +// comment, and returns an error if it choked on the comment. +type commentHandler func() (bool, error) + +// SkipWhitespaceWith skips whitespace using the given strategy for +// handling comments--generally speaking, either skipping over them +// using skipCommentsHandler, or stopping with a stopForCommentsHandler. +// Returns the first non-whitespace character it reads, and whether it +// actually skipped anything to find it. +func (t *tokenizer) skipWhitespaceWith(handler commentHandler) (int, bool, error) { + skipped := false + for { + c, err := t.read() + if err != nil { + return 0, skipped, err + } + + switch c { + case ' ', '\t', '\n', '\r': + // Skipped. + + case '/': + comment, err := handler() + if err != nil { + return 0, skipped, err + } + if !comment { + return '/', skipped, nil + } + + default: + return c, skipped, nil + } + skipped = true + } +} + +// StopForCommentsHandler is a commentHandler that stops skipping +// whitespace when it finds a (potential) comment. Use it when you +// expect a '/' to be an actual '/', not a comment. +func stopForCommentsHandler() (bool, error) { + return false, nil +} + +// SkipCommentsHandler is a commentHandler that skips over any +// comments it finds. +func (t *tokenizer) skipCommentsHandler() (bool, error) { + // We've just read a '/', which might be the start of a comment. + // Peek ahead to see if it is, and if so skip over it. + c, err := t.peek() + if err != nil { + return false, err + } + + switch c { + case '/': + return true, t.skipSingleLineComment() + case '*': + return true, t.skipBlockComment() + default: + return false, nil + } +} + +// SkipSingleLineComment skips over the body of a single-line comment, +// terminated by the end of the line (or file). +func (t *tokenizer) skipSingleLineComment() error { + for { + c, err := t.read() + if err != nil { + return err + } + + if c == -1 || c == '\n' { + return nil + } + } +} + +// SkipBlockComment skips over the body of a block comment, terminated +// by a '*/' sequence. +func (t *tokenizer) skipBlockComment() error { + star := false + for { + c, err := t.read() + if err != nil { + return err + } + if c == -1 { + return invalidChar(c) + } + + if star && c == '/' { + return nil + } + + star = (c == '*') + } +} diff --git a/skipper_test.go b/skipper_test.go new file mode 100644 index 00000000..743dcddf --- /dev/null +++ b/skipper_test.go @@ -0,0 +1,178 @@ +package ion + +import ( + "testing" +) + +func TestSkipNumber(t *testing.T) { + test, testErr := testSkip(t, (*tokenizer).skipNumber) + + test("", -1) + test("0", -1) + test("-1234567890,", ',') + test("1.2 ", ' ') + test("1d45\n", '\n') + test("1.4e-12//", '/') + + testErr("1.2d3d", "unexpected char 'd'") +} + +func TestSkipBinary(t *testing.T) { + test, testErr := testSkip(t, (*tokenizer).skipBinary) + + test("0b0", -1) + test("-0b10 ", ' ') + test("0b010101,", ',') + + testErr("0b2", "unexpected char '2'") +} + +func TestSkipHex(t *testing.T) { + test, testErr := testSkip(t, (*tokenizer).skipHex) + + test("0x0", -1) + test("-0x0F ", ' ') + test("0x1234567890abcdefABCDEF,", ',') + + testErr("0x0G", "unexpected char 'G'") +} + +func TestSkipTimestamp(t *testing.T) { + test, testErr := testSkip(t, (*tokenizer).skipTimestamp) + + test("2001T", -1) + test("2001-01T,", ',') + test("2001-01-02}", '}') + test("2001-01-02T ", ' ') + test("2001-01-02T+00:00\t", '\t') + test("2001-01-02T-00:00\n", '\n') + test("2001-01-02T03:04+00:00 ", ' ') + test("2001-01-02T03:04-00:00 ", ' ') + test("2001-01-02T03:04Z ", ' ') + test("2001-01-02T03:04z ", ' ') + test("2001-01-02T03:04:05Z ", ' ') + test("2001-01-02T03:04:05+00:00 ", ' ') + test("2001-01-02T03:04:05.666Z ", ' ') + test("2001-01-02T03:04:05.666666z ", ' ') + + testErr("", "unexpected EOF") + testErr("2001", "unexpected EOF") + testErr("2001z", "unexpected char 'z'") + testErr("20011", "unexpected char '1'") + testErr("2001-0", "unexpected EOF") + testErr("2001-01", "unexpected EOF") + testErr("2001-01-02Tz", "unexpected char 'z'") + testErr("2001-01-02T03", "unexpected EOF") + testErr("2001-01-02T03z", "unexpected char 'z'") + testErr("2001-01-02T03:04x ", "unexpected char 'x'") + testErr("2001-01-02T03:04:05x ", "unexpected char 'x'") +} + +func TestSkipSymbol(t *testing.T) { + test, _ := testSkip(t, (*tokenizer).skipSymbol) + + test("f", -1) + test("foo:", ':') + test("foo,", ',') + test("foo ", ' ') + test("foo\n", '\n') + test("foo]", ']') + test("foo}", '}') + test("foo)", ')') + test("foo\\n", '\\') +} + +func TestSkipSymbolQuoted(t *testing.T) { + test, testErr := testSkip(t, (*tokenizer).skipSymbolQuoted) + + test("'", -1) + test("foo',", ',') + test("foo\\'bar':", ':') + test("foo\\\nbar',", ',') + + testErr("foo", "unexpected EOF") + testErr("foo\n", "unexpected char '\\n'") +} + +func TestSkipSymbolOperator(t *testing.T) { + test, _ := testSkip(t, (*tokenizer).skipSymbolOperator) + + test("+", -1) + test("++", -1) + test("+= ", ' ') + test("%b", 'b') +} + +func TestSkipString(t *testing.T) { + test, testErr := testSkip(t, (*tokenizer).skipString) + + test("\"", -1) + test("\",", ',') + test("foo\\\"bar\"], \"\"", ']') + test("foo\\\nbar\" \t\t\t", ' ') + + testErr("foobar", "unexpected EOF") + testErr("foobar\n", "unexpected char '\\n'") +} + +func TestSkipLongString(t *testing.T) { + test, _ := testSkip(t, (*tokenizer).skipLongString) + + test("'''", -1) + test("''',", ',') + test("abc''',", ',') + test("abc''' }", '}') + test("abc''' /*more*/ '''def'''\t//more\r\n]", ']') +} + +func TestSkipBlob(t *testing.T) { + test, testErr := testSkip(t, (*tokenizer).skipBlob) + + test("}}", -1) + test("oogboog}},{{}}", ',') + test("'''not encoded'''}}\n", '\n') + + testErr("", "unexpected EOF") + testErr("oogboog", "unexpected EOF") + testErr("oogboog}", "unexpected EOF") + testErr("oog}{boog", "unexpected char '{'") +} + +func TestSkipList(t *testing.T) { + test, testErr := testSkip(t, (*tokenizer).skipList) + + test("]", -1) + test("[]],", ',') + test("[123, \"]\", ']']] ", ' ') + + testErr("abc, def, ", "unexpected EOF") +} + +type skipFunc func(*tokenizer) (int, error) +type skipTestFunc func(string, int) +type skipTestErrFunc func(string, string) + +func testSkip(t *testing.T, f skipFunc) (skipTestFunc, skipTestErrFunc) { + test := func(str string, ec int) { + t.Run(str, func(t *testing.T) { + tok := tokenizeString(str) + c, err := f(tok) + if err != nil { + t.Fatal(err) + } + if c != ec { + t.Errorf("expected '%c', got '%c'", ec, c) + } + }) + } + testErr := func(str string, e string) { + t.Run(str, func(t *testing.T) { + tok := tokenizeString(str) + _, err := f(tok) + if err == nil || err.Error() != e { + t.Errorf("expected err=%v, got err=%v", e, err) + } + }) + } + return test, testErr +} diff --git a/textutils.go b/textutils.go index 9606eeaf..03b9accd 100644 --- a/textutils.go +++ b/textutils.go @@ -14,12 +14,12 @@ func symbolNeedsQuoting(sym string) bool { return true } - if !isIdentifierStart(sym[0]) { + if !isIdentifierStart(int(sym[0])) { return true } for i := 1; i < len(sym); i++ { - if !isIdentifierPart(sym[i]) { + if !isIdentifierPart(int(sym[i])) { return true } } @@ -38,7 +38,7 @@ func isSymbolRef(sym string) bool { } for i := 1; i < len(sym); i++ { - if !isDigit(sym[i]) { + if !isDigit(int(sym[i])) { return false } } @@ -47,7 +47,7 @@ func isSymbolRef(sym string) bool { } // Is this a valid first character for an identifier? -func isIdentifierStart(c byte) bool { +func isIdentifierStart(c int) bool { if c >= 'a' && c <= 'z' { return true } @@ -61,15 +61,65 @@ func isIdentifierStart(c byte) bool { } // Is this a valid character for later in an identifier? -func isIdentifierPart(c byte) bool { +func isIdentifierPart(c int) bool { return isIdentifierStart(c) || isDigit(c) } +// Is this a valid hex digit? +func isHexDigit(c int) bool { + if isDigit(c) { + return true + } + if c >= 'a' && c <= 'f' { + return true + } + if c >= 'A' && c <= 'F' { + return true + } + return false +} + // Is this a digit? -func isDigit(c byte) bool { +func isDigit(c int) bool { return c >= '0' && c <= '9' } +// Is this a valid part of an operator symbol? +func isOperatorChar(c int) bool { + switch c { + case '!', '#', '%', '&', '*', '+', '-', '.', '/', ';', '<', '=': + return true + case '>', '?', '@', '^', '`', '|', '~': + return true + default: + return false + } +} + +// Does this character mark the end of a normal (unquoted) value? Does +// *not* check for the start of a comment, because that requires two +// characters. Use tokenizer.isStopChar(c) or check for it yourself. +func isStopChar(c int) bool { + switch c { + case -1, '{', '}', '[', ']', '(', ')', ',', '"', '\'': + return true + case ' ', '\t', '\n', '\r': + return true + default: + return false + } +} + +// Is this character whitespace? +func isWhitespace(c int) bool { + switch c { + case ' ', '\t', '\n', '\r': + return true + default: + return false + } +} + // Write the given symbol out, quoting and encoding if necessary. func writeSymbol(sym string, out io.Writer) error { if symbolNeedsQuoting(sym) { diff --git a/tokenizer.go b/tokenizer.go new file mode 100644 index 00000000..045cb44f --- /dev/null +++ b/tokenizer.go @@ -0,0 +1,547 @@ +package ion + +import ( + "bufio" + "bytes" + "errors" + "fmt" + "io" +) + +type tokenType int + +const ( + tokenError tokenType = iota + + tokenEOF // End of input + + tokenNumeric // Haven't seen enough to know which, yet + tokenInt // [0-9]+ + tokenBinary // 0b[01]+ + tokenHex // 0x[0-9a-fA-F]+ + tokenDecimal // [0-9]+.[0-9]+d[0-9]+ + tokenFloat // [0-9]+.[0-9]+e[0-9]+ + tokenFloatInf // +inf + tokenFloatMinusInf // -inf + tokenTimestamp // 2001-01-01T00:00:00.000Z + + tokenSymbol // [a-zA-Z_]+ + tokenSymbolQuoted // '[^']+' + tokenSymbolOperator // +-/* + + tokenString // "[^"]+" + tokenLongString // '''[^']+''' + + tokenDot // . + tokenComma // , + tokenColon // : + tokenDoubleColon // :: + + tokenOpenParen // ( + tokenCloseParen // ) + tokenOpenBrace // { + tokenCloseBrace // } + tokenOpenBracket // [ + tokenCloseBracket // ] + tokenOpenDoubleBrace // {{ + tokenCloseDoubleBrace // }} +) + +func (t tokenType) String() string { + switch t { + case tokenError: + return "error" + case tokenEOF: + return "EOF" + case tokenNumeric: + return "numeric" + case tokenInt: + return "int" + case tokenBinary: + return "binary" + case tokenHex: + return "hex" + case tokenDecimal: + return "decimal" + case tokenFloat: + return "float" + case tokenFloatInf: + return "+inf" + case tokenFloatMinusInf: + return "-inf" + case tokenTimestamp: + return "timestamp" + case tokenSymbol: + return "symbol" + case tokenSymbolQuoted: + return "symbolQuoted" + case tokenSymbolOperator: + return "symbolOperator" + + case tokenString: + return "string" + case tokenLongString: + return "longstring" + + case tokenDot: + return "dot" + case tokenComma: + return "comma" + case tokenColon: + return "colon" + case tokenDoubleColon: + return "doublecolon" + + case tokenOpenParen: + return "openparen" + case tokenCloseParen: + return "closeparen" + + case tokenOpenBrace: + return "openbrace" + case tokenCloseBrace: + return "closebrace" + + case tokenOpenBracket: + return "openbracket" + case tokenCloseBracket: + return "closebracket" + case tokenOpenDoubleBrace: + return "opendoublebrace" + case tokenCloseDoubleBrace: + return "closedoublebrace" + + default: + return "" + } +} + +type tokenizer struct { + in *bufio.Reader + buffer []int + + token tokenType + unfinished bool +} + +func tokenizeString(in string) *tokenizer { + return tokenizeBytes([]byte(in)) +} + +func tokenizeBytes(in []byte) *tokenizer { + return tokenize(bytes.NewReader(in)) +} + +func tokenize(in io.Reader) *tokenizer { + return &tokenizer{ + in: bufio.NewReader(in), + } +} + +// Token returns the type of the current token. +func (t *tokenizer) Token() tokenType { + return t.token +} + +// Next advances to the next token in the input stream. +func (t *tokenizer) Next() error { + var c int + var err error + + if t.unfinished { + c, err = t.skipValue() + } else { + c, _, err = t.skipWhitespace() + } + + if err != nil { + return err + } + + switch { + case c == -1: + return t.finish(tokenEOF, true) + + case c == '/': + t.unread(c) + return t.finish(tokenSymbolOperator, true) + + case c == ':': + c2, err := t.peek() + if err != nil { + return err + } + if c2 == ':' { + t.read() + return t.finish(tokenDoubleColon, false) + } else { + return t.finish(tokenColon, false) + } + + case c == '{': + c2, err := t.peek() + if err != nil { + return err + } + if c2 == '{' { + t.read() + return t.finish(tokenOpenDoubleBrace, true) + } else { + return t.finish(tokenOpenBrace, true) + } + + case c == '}': + return t.finish(tokenCloseBrace, false) + + case c == '[': + return t.finish(tokenOpenBracket, true) + + case c == ']': + return t.finish(tokenCloseBracket, false) + + case c == '(': + return t.finish(tokenOpenParen, true) + + case c == ')': + return t.finish(tokenCloseParen, false) + + case c == ',': + return t.finish(tokenComma, false) + + case c == '.': + c2, err := t.peek() + if err != nil { + return err + } + if isOperatorChar(c2) { + t.unread(c) + return t.finish(tokenSymbolOperator, true) + } else { + return t.finish(tokenDot, false) + } + + case c == '\'': + ok, err := t.isTripleQuote() + if err != nil { + return err + } + if ok { + return t.finish(tokenLongString, true) + } else { + return t.finish(tokenSymbolQuoted, true) + } + + case c == '+': + ok, err := t.isInf(c) + if err != nil { + return err + } + if ok { + return t.finish(tokenFloatInf, false) + } else { + t.unread(c) + return t.finish(tokenSymbolOperator, true) + } + + case isOperatorChar(c): + t.unread(c) + return t.finish(tokenSymbolOperator, true) + + case c == '"': + return t.finish(tokenString, true) + + case isIdentifierStart(c): + t.unread(c) + return t.finish(tokenSymbol, true) + + case isDigit(c): + tt, err := t.scanForNumericType(c) + if err != nil { + return err + } + + t.unread(c) + return t.finish(tt, true) + + case c == '-': + c2, err := t.peek() + if err != nil { + return err + } + + if isDigit(c2) { + t.read() + tt, err := t.scanForNumericType(c2) + if err != nil { + return err + } + if tt == tokenTimestamp { + // can't have negative timestamps. + return invalidChar(c2) + } + t.unread(c2) + return t.finish(tt, true) + } + + ok, err := t.isInf(c) + if err != nil { + return err + } + if ok { + return t.finish(tokenFloatMinusInf, false) + } + + t.unread(c) + return t.finish(tokenSymbolOperator, true) + + default: + return invalidChar(c) + } +} + +func (t *tokenizer) finish(token tokenType, more bool) error { + t.token = token + t.unfinished = more + return nil +} + +// IsTripleQuote returns true if this is a triple-quote sequence ('''). +func (t *tokenizer) isTripleQuote() (bool, error) { + // We've just read a '\'', check if the next two are too. + cs, err := t.peekN(2) + if err == io.EOF { + return false, nil + } + if err != nil { + return false, err + } + + if cs[0] == '\'' && cs[1] == '\'' { + t.skipN(2) + return true, nil + } + + return false, nil +} + +// IsInf returns true if the given character begins a '+inf' or +// '-inf' keyword. +func (t *tokenizer) isInf(c int) (bool, error) { + if c != '+' && c != '-' { + return false, nil + } + + cs, err := t.peekN(5) + if err != nil && err != io.EOF { + return false, err + } + + if len(cs) < 3 || cs[0] != 'i' || cs[1] != 'n' || cs[2] != 'f' { + // Definitely not +-inf. + return false, nil + } + + if len(cs) == 3 || isStopChar(cs[3]) { + // Cleanly-terminated +-inf. + t.skipN(3) + return true, nil + } + + if cs[3] == '/' && len(cs) > 4 && (cs[4] == '/' || cs[4] == '*') { + t.skipN(3) + // +-inf followed immediately by a comment works too. + return true, nil + } + + return false, nil +} + +// ScanForNumericType attempts to determine what type of number we +// have by peeking at a fininte number of characters. We can rule +// out binary (0b...), hex (0x...), and timestamps (....-) via this +// method. There are a couple other cases where we *could* distinguish, +// but it's unclear that it's worth it. +func (t *tokenizer) scanForNumericType(c int) (tokenType, error) { + if !isDigit(c) { + panic("scanForNumericType with non-digit") + } + + cs, err := t.peekN(4) + if err != nil && err != io.EOF { + return tokenError, err + } + + if c == '0' && len(cs) > 0 { + switch { + case cs[0] == 'b' || cs[0] == 'B': + return tokenBinary, nil + + case cs[0] == 'x' || cs[0] == 'X': + return tokenHex, nil + } + } + + if len(cs) >= 4 { + if isDigit(cs[0]) && isDigit(cs[1]) && isDigit(cs[2]) { + if cs[3] == '-' || cs[3] == 'T' { + return tokenTimestamp, nil + } + } + } + + // Can't tell yet; wait until actually reading it to find out. + return tokenNumeric, nil +} + +// Is this character a valid way to end a 'normal' (unquoted) value? +// Peeks in case of '/', so don't call it with a character you've +// peeked. +func (t *tokenizer) isStopChar(c int) (bool, error) { + if isStopChar(c) { + return true, nil + } + + if c == '/' { + c2, err := t.peek() + if err != nil { + return false, err + } + if c2 == '/' || c2 == '*' { + // Comment, also all done. + return true, nil + } + } + + return false, nil +} + +type matcher func(int) bool + +// Expect reads a byte of input and asserts that it matches some +// condition, returning an error if it does not. +func (t *tokenizer) expect(f matcher) error { + c, err := t.read() + if err != nil { + return err + } + if !f(c) { + return invalidChar(c) + } + return nil +} + +// InvalidChar returns an error complaining that the given character was +// unexpected. +func invalidChar(c int) error { + if c == -1 { + return errors.New("unexpected EOF") + } + return fmt.Errorf("unexpected char %q", c) +} + +// SkipN skips over the next n bytes of input. Presumably you've +// already peeked at them, and decided they're not worth keeping. +func (t *tokenizer) skipN(n int) error { + for i := 0; i < n; i++ { + c, err := t.read() + if err != nil { + return err + } + if c == -1 { + break + } + } + return nil +} + +// PeekN peeks at the next n bytes of input. Unlike read/peek, does +// NOT return -1 to indicate EOF. If it cannot peek N bytes ahead +// because of an EOF (or other error), it returns the bytes it was +// able to peek at along with the error. +func (t *tokenizer) peekN(n int) ([]int, error) { + var ret []int + var err error + + // Read ahead. + for i := 0; i < n; i++ { + var c int + c, err = t.read() + if err != nil { + break + } + if c == -1 { + err = io.EOF + break + } + ret = append(ret, c) + } + + // Put back the ones we got. + if err == io.EOF { + t.unread(-1) + } + for i := len(ret) - 1; i >= 0; i-- { + t.unread(ret[i]) + } + + return ret, err +} + +// Peek at the next byte of input without removing it. Other conditions +// from Read all apply. +func (t *tokenizer) peek() (int, error) { + if len(t.buffer) > 0 { + // Short-circuit and peek from the buffer. + return t.buffer[len(t.buffer)-1], nil + } + + c, err := t.read() + if err != nil { + return 0, err + } + + t.unread(c) + return c, nil +} + +// Read reads a byte of input from the underlying reader. EOF is +// returned as (-1, nil) rather than (0, io.EOF), because I find it +// easier to reason about that way. Newlines are normalized to '\n'. +func (t *tokenizer) read() (int, error) { + if len(t.buffer) > 0 { + // We've already peeked ahead; read from our buffer. + c := t.buffer[len(t.buffer)-1] + t.buffer = t.buffer[:len(t.buffer)-1] + return c, nil + } + + c, err := t.in.ReadByte() + if err == io.EOF { + return -1, nil + } + if err != nil { + return 0, err + } + + // Normalize \r and \r\n to just \n. + if c == '\r' { + cs, err := t.in.Peek(1) + if err != nil && err != io.EOF { + // Not EOF, because we haven't dealt with the '\r' yet. + return 0, err + } + if len(cs) > 0 && cs[0] == '\n' { + // Skip over the '\n' as well. + t.in.ReadByte() + } + return '\n', nil + } + + return int(c), nil +} + +// Unread pushes a character (or -1) back into the input stream to +// be read again later. +func (t *tokenizer) unread(c int) { + t.buffer = append(t.buffer, c) +} diff --git a/tokenizer_test.go b/tokenizer_test.go new file mode 100644 index 00000000..37b9ff03 --- /dev/null +++ b/tokenizer_test.go @@ -0,0 +1,412 @@ +package ion + +import ( + "io" + "testing" +) + +func TestNext(t *testing.T) { + tok := tokenizeString("foo::'foo':[] 123, {})") + + next := func(tt tokenType) { + if err := tok.Next(); err != nil { + t.Fatal(err) + } + if tok.Token() != tt { + t.Fatalf("expected %v, got %v", tt, tok.Token()) + } + } + + next(tokenSymbol) + next(tokenDoubleColon) + next(tokenSymbolQuoted) + next(tokenColon) + next(tokenOpenBracket) + next(tokenNumeric) + next(tokenComma) + next(tokenOpenBrace) +} + +func TestIsTripleQuote(t *testing.T) { + test := func(str string, eok bool, next int) { + t.Run(str, func(t *testing.T) { + tok := tokenizeString(str) + + ok, err := tok.isTripleQuote() + if err != nil { + t.Fatal(err) + } + if ok != eok { + t.Errorf("expected ok=%v, got ok=%v", eok, ok) + } + + read(t, tok, next) + }) + } + + test("''string'''", true, 's') + test("'string'''", false, '\'') + test("'", false, '\'') + test("", false, -1) +} + +func TestIsInf(t *testing.T) { + test := func(str string, eok bool, next int) { + t.Run(str, func(t *testing.T) { + tok := tokenizeString(str) + c, err := tok.read() + if err != nil { + t.Fatal(err) + } + + ok, err := tok.isInf(c) + if err != nil { + t.Fatal(err) + } + + if ok != eok { + t.Errorf("expected %v, got %v", eok, ok) + } + + c, err = tok.read() + if err != nil { + t.Fatal(err) + } + if c != next { + t.Errorf("expected '%c', got '%c'", next, c) + } + }) + } + + test("+inf", true, -1) + test("-inf", true, -1) + test("+inf ", true, ' ') + test("-inf\t", true, '\t') + test("-inf\n", true, '\n') + test("+inf,", true, ',') + test("-inf}", true, '}') + test("+inf)", true, ')') + test("-inf]", true, ']') + test("+inf//", true, '/') + test("+inf/*", true, '/') + + test("+inf/", false, 'i') + test("-inf/0", false, 'i') + test("+int", false, 'i') + test("-iot", false, 'i') + test("+unf", false, 'u') + test("_inf", false, 'i') + + test("-in", false, 'i') + test("+i", false, 'i') + test("+", false, -1) + test("-", false, -1) +} + +func TestScanForNumericType(t *testing.T) { + test := func(str string, ett tokenType) { + t.Run(str, func(t *testing.T) { + tok := tokenizeString(str) + c, err := tok.read() + if err != nil { + t.Fatal(err) + } + + tt, err := tok.scanForNumericType(c) + if err != nil { + t.Fatal(err) + } + if tt != ett { + t.Errorf("expected %v, got %v", ett, tt) + } + }) + } + + test("0b0101", tokenBinary) + test("0B", tokenBinary) + test("0xABCD", tokenHex) + test("0X", tokenHex) + test("0000-00-00", tokenTimestamp) + test("0000T", tokenTimestamp) + + test("0", tokenNumeric) + test("1b0101", tokenNumeric) + test("1B", tokenNumeric) + test("1x0101", tokenNumeric) + test("1X", tokenNumeric) + test("1234", tokenNumeric) + test("12345", tokenNumeric) + test("1,23T", tokenNumeric) + test("12,3T", tokenNumeric) + test("123,T", tokenNumeric) +} + +func TestSkipWhitespace(t *testing.T) { + test := func(str string, eok bool, ec int) { + t.Run(str, func(t *testing.T) { + tok := tokenizeString(str) + c, ok, err := tok.skipWhitespace() + if err != nil { + t.Fatal(err) + } + + if ok != eok { + t.Errorf("expected ok=%v, got ok=%v", eok, ok) + } + if c != ec { + t.Errorf("expected c='%c', got c='%c'", ec, c) + } + }) + } + + test("/ 0)", false, '/') + test("xyz_", false, 'x') + test(" / 0)", true, '/') + test(" xyz_", true, 'x') + test(" \t\r\n / 0)", true, '/') + test("\t\t // comment\t\r\n\t\t x", true, 'x') + test(" \r\n /* comment *//* \r\n comment */x", true, 'x') +} + +func TestSkipLobWhitespace(t *testing.T) { + test := func(str string, eok bool, ec int) { + t.Run(str, func(t *testing.T) { + tok := tokenizeString(str) + c, ok, err := tok.skipLobWhitespace() + if err != nil { + t.Fatal(err) + } + + if ok != eok { + t.Errorf("expected ok=%v, got ok=%v", eok, ok) + } + if c != ec { + t.Errorf("expected c='%c', got c='%c'", ec, c) + } + }) + } + + test("///=", false, '/') + test("xyz_", false, 'x') + test(" ///=", true, '/') + test(" xyz_", true, 'x') + test("\r\n\t///=", true, '/') + test("\r\n\txyz_", true, 'x') +} + +func TestSkipCommentsHandler(t *testing.T) { + t.Run("SingleLine", func(t *testing.T) { + tok := tokenizeString("/comment\nok") + ok, err := tok.skipCommentsHandler() + if err != nil { + t.Fatal(err) + } + if !ok { + t.Error("expected ok=true, got ok=false") + } + + read(t, tok, 'o') + read(t, tok, 'k') + read(t, tok, -1) + }) + + t.Run("Block", func(t *testing.T) { + tok := tokenizeString("*comm\nent*/ok") + ok, err := tok.skipCommentsHandler() + if err != nil { + t.Fatal(err) + } + if !ok { + t.Error("expected ok=true, got ok=false") + } + + read(t, tok, 'o') + read(t, tok, 'k') + read(t, tok, -1) + }) + + t.Run("FalseAlarm", func(t *testing.T) { + tok := tokenizeString(" 0)") + ok, err := tok.skipCommentsHandler() + if err != nil { + t.Fatal(err) + } + if ok { + t.Error("expected ok=false, got ok=true") + } + + read(t, tok, ' ') + read(t, tok, '0') + read(t, tok, ')') + read(t, tok, -1) + }) +} + +func TestSkipSingleLineComment(t *testing.T) { + tok := tokenizeString("single-line comment\r\nok") + err := tok.skipSingleLineComment() + if err != nil { + t.Fatal(err) + } + + read(t, tok, 'o') + read(t, tok, 'k') + read(t, tok, -1) +} + +func TestSkipSingleLineCommentOnLastLine(t *testing.T) { + tok := tokenizeString("single-line comment") + err := tok.skipSingleLineComment() + if err != nil { + t.Fatal(err) + } + + read(t, tok, -1) +} + +func TestSkipBlockComment(t *testing.T) { + tok := tokenizeString("this is/ a\nmulti-line /** comment.**/ok") + err := tok.skipBlockComment() + if err != nil { + t.Fatal(err) + } + + read(t, tok, 'o') + read(t, tok, 'k') + read(t, tok, -1) +} + +func TestSkipInvalidBlockComment(t *testing.T) { + tok := tokenizeString("this is a comment that never ends") + err := tok.skipBlockComment() + if err == nil { + t.Error("did not fail on bad block comment") + } +} + +func TestPeekN(t *testing.T) { + tok := tokenizeString("abc\r\ndef") + + peekN(t, tok, 1, nil, 'a') + peekN(t, tok, 2, nil, 'a', 'b') + peekN(t, tok, 3, nil, 'a', 'b', 'c') + + read(t, tok, 'a') + read(t, tok, 'b') + + peekN(t, tok, 3, nil, 'c', '\n', 'd') + peekN(t, tok, 2, nil, 'c', '\n') + peekN(t, tok, 3, nil, 'c', '\n', 'd') + + read(t, tok, 'c') + read(t, tok, '\n') + read(t, tok, 'd') + + peekN(t, tok, 3, io.EOF, 'e', 'f') + peekN(t, tok, 3, io.EOF, 'e', 'f') + peekN(t, tok, 2, nil, 'e', 'f') + + read(t, tok, 'e') + read(t, tok, 'f') + read(t, tok, -1) + + peekN(t, tok, 10, io.EOF) +} + +func peekN(t *testing.T, tok *tokenizer, n int, ee error, ecs ...int) { + cs, err := tok.peekN(n) + if err != ee { + t.Fatalf("expected err=%v, got err=%v", ee, err) + } + if !equal(ecs, cs) { + t.Errorf("expected %v, got %v", ecs, cs) + } +} + +func equal(a, b []int) bool { + if len(a) != len(b) { + return false + } + + for i := range a { + if a[i] != b[i] { + return false + } + } + + return true +} + +func TestPeek(t *testing.T) { + tok := tokenizeString("abc") + + peek(t, tok, 'a') + peek(t, tok, 'a') + read(t, tok, 'a') + + peek(t, tok, 'b') + tok.unread('a') + + peek(t, tok, 'a') + read(t, tok, 'a') + read(t, tok, 'b') + peek(t, tok, 'c') + peek(t, tok, 'c') + + read(t, tok, 'c') + peek(t, tok, -1) + peek(t, tok, -1) + read(t, tok, -1) +} + +func peek(t *testing.T, tok *tokenizer, expected int) { + c, err := tok.peek() + if err != nil { + t.Fatal(err) + } + if c != expected { + t.Errorf("expected %v, got %v", expected, c) + } +} + +func TestReadUnread(t *testing.T) { + tok := tokenizeString("abc\rd\ne\r\n") + + read(t, tok, 'a') + tok.unread('a') + + read(t, tok, 'a') + read(t, tok, 'b') + read(t, tok, 'c') + tok.unread('c') + tok.unread('b') + + read(t, tok, 'b') + read(t, tok, 'c') + read(t, tok, '\n') + tok.unread('\n') + + read(t, tok, '\n') + read(t, tok, 'd') + read(t, tok, '\n') + read(t, tok, 'e') + read(t, tok, '\n') + read(t, tok, -1) + + tok.unread(-1) + tok.unread('\n') + + read(t, tok, '\n') + read(t, tok, -1) + read(t, tok, -1) +} + +func read(t *testing.T, tok *tokenizer, expected int) { + c, err := tok.read() + if err != nil { + t.Fatal(err) + } + if c != expected { + t.Errorf("expected %v, got %v", expected, c) + } +} From dcb9d9674be401128e2fb93c6964e21045dece50 Mon Sep 17 00:00:00 2001 From: David Murray Date: Mon, 24 Jun 2019 14:34:29 +0300 Subject: [PATCH 09/56] minor fixups --- decimal.go | 4 +++ symboltable.go | 70 ++++++++++++++++++++++++--------------------- symboltable_test.go | 2 +- textutils.go | 3 +- tokenizer.go | 17 ++++------- 5 files changed, 49 insertions(+), 47 deletions(-) diff --git a/decimal.go b/decimal.go index ed7ceded..ee4ad829 100644 --- a/decimal.go +++ b/decimal.go @@ -87,6 +87,7 @@ func ParseDecimal(in string) (*Decimal, error) { return NewDecimalWithScale(n, -shift), nil } +// Abs returns the absolute value of this Decimal. func (d *Decimal) Abs() *Decimal { return &Decimal{ n: new(big.Int).Abs(d.n), @@ -94,6 +95,7 @@ func (d *Decimal) Abs() *Decimal { } } +// Add returns the result of adding this Decimal to another Decimal. func (d *Decimal) Add(o *Decimal) *Decimal { // a*10^x + b*10^y = (a*10^(x-y) + b) * 10^y dd, oo := rescale(d, o) @@ -103,6 +105,7 @@ func (d *Decimal) Add(o *Decimal) *Decimal { } } +// Sub returns the result of substrating another Decimal from this Decimal. func (d *Decimal) Sub(o *Decimal) *Decimal { dd, oo := rescale(d, o) return &Decimal{ @@ -111,6 +114,7 @@ func (d *Decimal) Sub(o *Decimal) *Decimal { } } +// Neg returns the negative of this Decimal. func (d *Decimal) Neg() *Decimal { return &Decimal{ n: new(big.Int).Neg(d.n), diff --git a/symboltable.go b/symboltable.go index 2f6d5cab..f386c358 100644 --- a/symboltable.go +++ b/symboltable.go @@ -21,17 +21,22 @@ type SymbolTable interface { // A SharedSymbolTable is distributed out-of-band and referenced from // a LocalSymbolTable to save space. -type SharedSymbolTable struct { +type SharedSymbolTable interface { + SymbolTable + + Name() string + Version() int +} + +type sharedSymbolTable struct { name string version int symbols []string index map[string]int } -var _ SymbolTable = &SharedSymbolTable{} - // NewSharedSymbolTable creates a new shared symbol table. -func NewSharedSymbolTable(name string, version int, symbols []string) *SharedSymbolTable { +func NewSharedSymbolTable(name string, version int, symbols []string) SharedSymbolTable { if name == "" { panic("name must be non-empty") } @@ -41,7 +46,7 @@ func NewSharedSymbolTable(name string, version int, symbols []string) *SharedSym index, copy := buildIndex(symbols, 0) - return &SharedSymbolTable{ + return &sharedSymbolTable{ name: name, version: version, symbols: copy, @@ -63,31 +68,31 @@ func buildIndex(symbols []string, offset int) (map[string]int, []string) { return index, copy } -func (s *SharedSymbolTable) Name() string { +func (s *sharedSymbolTable) Name() string { return s.name } -func (s *SharedSymbolTable) Version() int { +func (s *sharedSymbolTable) Version() int { return s.version } -func (s *SharedSymbolTable) MaxID() int { +func (s *sharedSymbolTable) MaxID() int { return len(s.symbols) } -func (s *SharedSymbolTable) FindByName(sym string) (int, bool) { +func (s *sharedSymbolTable) FindByName(sym string) (int, bool) { id, ok := s.index[sym] return id, ok } -func (s *SharedSymbolTable) FindByID(id int) (string, bool) { +func (s *sharedSymbolTable) FindByID(id int) (string, bool) { if id <= 0 || id > len(s.symbols) { return "", false } return s.symbols[id-1], true } -func (s *SharedSymbolTable) WriteTo(w Writer) error { +func (s *sharedSymbolTable) WriteTo(w Writer) error { w.TypeAnnotation("$ion_shared_symbol_table") w.BeginStruct() @@ -110,7 +115,7 @@ func (s *SharedSymbolTable) WriteTo(w Writer) error { return w.Err() } -func (s *SharedSymbolTable) String() string { +func (s *sharedSymbolTable) String() string { buf := strings.Builder{} w := NewTextWriter(&buf) @@ -119,7 +124,7 @@ func (s *SharedSymbolTable) String() string { return buf.String() } -// The (implied) system symbol table for Ion v1.0. +// V1SystemSymbolTable is the (implied) system symbol table for Ion v1.0. var V1SystemSymbolTable = NewSharedSymbolTable("$ion", 1, []string{ "$ion", "$ion_1_0", @@ -134,8 +139,8 @@ var V1SystemSymbolTable = NewSharedSymbolTable("$ion", 1, []string{ // A LocalSymbolTable is transmitted in-band along with the binary data // it describes. It may include SharedSymbolTables by reference. -type LocalSymbolTable struct { - imports []*SharedSymbolTable +type localSymbolTable struct { + imports []SharedSymbolTable offsets []int maxImportID int @@ -143,14 +148,12 @@ type LocalSymbolTable struct { index map[string]int } -var _ SymbolTable = &LocalSymbolTable{} - // NewLocalSymbolTable creates a new local symbol table. -func NewLocalSymbolTable(imports []*SharedSymbolTable, symbols []string) *LocalSymbolTable { +func NewLocalSymbolTable(imports []SharedSymbolTable, symbols []string) SymbolTable { imps, offsets, maxID := processImports(imports) index, copy := buildIndex(symbols, maxID) - return &LocalSymbolTable{ + return &localSymbolTable{ imports: imps, offsets: offsets, maxImportID: maxID, @@ -159,8 +162,8 @@ func NewLocalSymbolTable(imports []*SharedSymbolTable, symbols []string) *LocalS } } -func processImports(imports []*SharedSymbolTable) ([]*SharedSymbolTable, []int, int) { - imps := append([]*SharedSymbolTable{}, imports...) +func processImports(imports []SharedSymbolTable) ([]SharedSymbolTable, []int, int) { + imps := append([]SharedSymbolTable{}, imports...) // TODO: Automatically add V1SystemSymbolTable? @@ -174,11 +177,11 @@ func processImports(imports []*SharedSymbolTable) ([]*SharedSymbolTable, []int, return imps, offsets, maxID } -func (t *LocalSymbolTable) MaxID() int { +func (t *localSymbolTable) MaxID() int { return t.maxImportID + len(t.symbols) } -func (t *LocalSymbolTable) FindByName(s string) (int, bool) { +func (t *localSymbolTable) FindByName(s string) (int, bool) { for i, imp := range t.imports { if id, ok := imp.FindByName(s); ok { return t.offsets[i] + id, true @@ -192,7 +195,7 @@ func (t *LocalSymbolTable) FindByName(s string) (int, bool) { return 0, false } -func (t *LocalSymbolTable) FindByID(id int) (string, bool) { +func (t *localSymbolTable) FindByID(id int) (string, bool) { if id <= 0 { return "", false } @@ -209,7 +212,7 @@ func (t *LocalSymbolTable) FindByID(id int) (string, bool) { return "", false } -func (t *LocalSymbolTable) findByIDInImports(id int) (string, bool) { +func (t *localSymbolTable) findByIDInImports(id int) (string, bool) { i := 1 off := 0 @@ -223,7 +226,7 @@ func (t *LocalSymbolTable) findByIDInImports(id int) (string, bool) { return t.imports[i-1].FindByID(id - off) } -func (t *LocalSymbolTable) WriteTo(w Writer) error { +func (t *localSymbolTable) WriteTo(w Writer) error { w.TypeAnnotation("$ion_symbol_table") w.BeginStruct() @@ -262,7 +265,7 @@ func (t *LocalSymbolTable) WriteTo(w Writer) error { return w.Err() } -func (t *LocalSymbolTable) String() string { +func (t *localSymbolTable) String() string { buf := strings.Builder{} w := NewTextWriter(&buf) @@ -278,17 +281,18 @@ type SymbolTableBuilder interface { // Add adds a symbol to this symbol table. Add(symbol string) (int, bool) // Build creates an immutable local symbol table. - Build() *LocalSymbolTable + Build() SymbolTable } type symbolTableBuilder struct { - LocalSymbolTable + localSymbolTable } -func NewSymbolTableBuilder(imports ...*SharedSymbolTable) SymbolTableBuilder { +// NewSymbolTableBuilder creates a new symbol table builder with the given imports. +func NewSymbolTableBuilder(imports ...SharedSymbolTable) SymbolTableBuilder { imps, offsets, maxID := processImports(imports) return &symbolTableBuilder{ - LocalSymbolTable{ + localSymbolTable{ imports: imps, offsets: offsets, maxImportID: maxID, @@ -309,14 +313,14 @@ func (b *symbolTableBuilder) Add(symbol string) (int, bool) { return id, true } -func (b *symbolTableBuilder) Build() *LocalSymbolTable { +func (b *symbolTableBuilder) Build() SymbolTable { symbols := append([]string{}, b.symbols...) index := make(map[string]int) for s, i := range b.index { index[s] = i } - return &LocalSymbolTable{ + return &localSymbolTable{ imports: b.imports, offsets: b.offsets, maxImportID: b.maxImportID, diff --git a/symboltable_test.go b/symboltable_test.go index 20fd2320..bcadca36 100644 --- a/symboltable_test.go +++ b/symboltable_test.go @@ -58,7 +58,7 @@ func TestLocalSymbolTable(t *testing.T) { } func TestLocalSymbolTableWithImports(t *testing.T) { - imports := []*SharedSymbolTable{V1SystemSymbolTable} + imports := []SharedSymbolTable{V1SystemSymbolTable} st := NewLocalSymbolTable(imports, []string{ "foo", "bar", diff --git a/textutils.go b/textutils.go index 03b9accd..1c3d11a6 100644 --- a/textutils.go +++ b/textutils.go @@ -130,9 +130,8 @@ func writeSymbol(sym string, out io.Writer) error { return err } return writeRawChar('\'', out) - } else { - return writeRawString(sym, out) } + return writeRawString(sym, out) } // Write the given symbol out, escaping any characters that need escaping. diff --git a/tokenizer.go b/tokenizer.go index 045cb44f..9aaff905 100644 --- a/tokenizer.go +++ b/tokenizer.go @@ -174,9 +174,8 @@ func (t *tokenizer) Next() error { if c2 == ':' { t.read() return t.finish(tokenDoubleColon, false) - } else { - return t.finish(tokenColon, false) } + return t.finish(tokenColon, false) case c == '{': c2, err := t.peek() @@ -186,9 +185,8 @@ func (t *tokenizer) Next() error { if c2 == '{' { t.read() return t.finish(tokenOpenDoubleBrace, true) - } else { - return t.finish(tokenOpenBrace, true) } + return t.finish(tokenOpenBrace, true) case c == '}': return t.finish(tokenCloseBrace, false) @@ -216,9 +214,8 @@ func (t *tokenizer) Next() error { if isOperatorChar(c2) { t.unread(c) return t.finish(tokenSymbolOperator, true) - } else { - return t.finish(tokenDot, false) } + return t.finish(tokenDot, false) case c == '\'': ok, err := t.isTripleQuote() @@ -227,9 +224,8 @@ func (t *tokenizer) Next() error { } if ok { return t.finish(tokenLongString, true) - } else { - return t.finish(tokenSymbolQuoted, true) } + return t.finish(tokenSymbolQuoted, true) case c == '+': ok, err := t.isInf(c) @@ -238,10 +234,9 @@ func (t *tokenizer) Next() error { } if ok { return t.finish(tokenFloatInf, false) - } else { - t.unread(c) - return t.finish(tokenSymbolOperator, true) } + t.unread(c) + return t.finish(tokenSymbolOperator, true) case isOperatorChar(c): t.unread(c) From f1d6578e9c29ed7df79019cc46fe97cfdd96eb56 Mon Sep 17 00:00:00 2001 From: David Murray Date: Thu, 27 Jun 2019 00:11:55 -0700 Subject: [PATCH 10/56] Initial textreader impl --- api.go | 45 +++++- ctx.go | 36 +++++ skipper.go | 58 ++++++- textreader.go | 391 +++++++++++++++++++++++++++++++++++++++++++++ textreader_test.go | 55 +++++++ textutils.go | 13 ++ textwriter.go | 26 +-- tokenizer.go | 212 ++++++++++++++++++++---- tokenizer_test.go | 127 +++++++++++++-- writer.go | 42 +---- 10 files changed, 905 insertions(+), 100 deletions(-) create mode 100644 ctx.go create mode 100644 textreader.go create mode 100644 textreader_test.go diff --git a/api.go b/api.go index 22676803..059a1efa 100644 --- a/api.go +++ b/api.go @@ -1,6 +1,7 @@ package ion import ( + "fmt" "math/big" "time" ) @@ -39,20 +40,56 @@ const ( SexpType ) +func (t Type) String() string { + switch t { + case NoType: + return "" + case NullType: + return "null" + case BoolType: + return "bool" + case IntType: + return "int" + case FloatType: + return "float" + case DecimalType: + return "decimal" + case TimestampType: + return "timestamp" + case StringType: + return "string" + case SymbolType: + return "symbol" + case BlobType: + return "blob" + case ClobType: + return "clob" + case StructType: + return "struct" + case ListType: + return "list" + case SexpType: + return "sexp" + default: + return fmt.Sprintf("", uint8(t)) + } +} + // A Reader reads Ion values from an input stream. type Reader interface { SymbolTable() SymbolTable - Next() (Type, error) + Next() bool Type() Type + Err() error + + FieldName() string + TypeAnnotations() []string IsNull() bool StepIn() error StepOut() error - FieldName() (string, error) - TypeAnnotations() ([]string, error) - BoolValue() (bool, error) IntValue() (int, error) Int64Value() (int64, error) diff --git a/ctx.go b/ctx.go new file mode 100644 index 00000000..f92b2077 --- /dev/null +++ b/ctx.go @@ -0,0 +1,36 @@ +package ion + +type ctxType byte + +const ( + ctxAtTopLevel ctxType = iota + ctxInStruct + ctxInList + ctxInSexp +) + +// ctx is a context stack. +type ctx struct { + stack []ctxType +} + +// peek returns the current context +func (c *ctx) peek() ctxType { + if len(c.stack) == 0 { + return ctxAtTopLevel + } + return c.stack[len(c.stack)-1] +} + +// push pushes a new context onto the stack. +func (c *ctx) push(ctx ctxType) { + c.stack = append(c.stack, ctx) +} + +// pop pops the top context off the stack. +func (c *ctx) pop() { + if len(c.stack) == 0 { + panic("pop called at top level") + } + c.stack = c.stack[:len(c.stack)-1] +} diff --git a/skipper.go b/skipper.go index f9e8722c..a1d3c20d 100644 --- a/skipper.go +++ b/skipper.go @@ -5,6 +5,20 @@ import ( "io" ) +// FinishValue skips to the end of the current value if (and only if) +// we're currently in the middle of reading it. +func (t *tokenizer) finishValue() error { + if t.unfinished { + c, err := t.skipValue() + if err != nil { + return err + } + t.unread(c) + t.unfinished = false + } + return nil +} + // SkipValue skips to the end of the current value, if the caller // didn't bother to consume it before calling Next again. func (t *tokenizer) skipValue() (int, error) { @@ -12,7 +26,7 @@ func (t *tokenizer) skipValue() (int, error) { var err error switch t.token { - case tokenNumeric, tokenInt, tokenDecimal, tokenFloat: + case tokenNumber: c, err = t.skipNumber() case tokenBinary: c, err = t.skipBinary() @@ -547,16 +561,28 @@ func (t *tokenizer) skipStruct() (int, error) { return t.skipContainer('}') } +func (t *tokenizer) skipStructHelper() error { + return t.skipContainerHelper('}') +} + func (t *tokenizer) skipSexp() (int, error) { return t.skipContainer(')') } +func (t *tokenizer) skipSexpHelper() error { + return t.skipContainerHelper(')') +} + // SkipList skips forward past a list that the caller doesn't care to // step in to. func (t *tokenizer) skipList() (int, error) { return t.skipContainer(']') } +func (t *tokenizer) skipListHelper() error { + return t.skipContainerHelper(']') +} + // SkipContainer skips a container terminated by the given char and // returns the next character. func (t *tokenizer) skipContainer(term int) (int, error) { @@ -767,3 +793,33 @@ func (t *tokenizer) skipBlockComment() error { star = (c == '*') } } + +// Peeks ahead to see if the next token is a double colon, and +// if so skips it. If not, leaves the next token unconsumed. +func (t *tokenizer) skipDoubleColon() (bool, error) { + // Read whitespace and first non-whitespace char. + c, _, err := t.skipWhitespace() + if err != nil { + return false, err + } + if c != ':' { + // Not followed by a double-colon; put it back. + t.unread(c) + return false, nil + } + + // Peek to see if it's a double colon. + c, err = t.peek() + if err != nil { + return false, err + } + if c != ':' { + // Nope; put back the first ':'. + t.unread(':') + return false, nil + } + + // Yep; eat it and return true. + t.read() + return true, nil +} diff --git a/textreader.go b/textreader.go new file mode 100644 index 00000000..98e0fd68 --- /dev/null +++ b/textreader.go @@ -0,0 +1,391 @@ +package ion + +import ( + "bufio" + "errors" + "fmt" + "io" + "math" + "math/big" + "time" +) + +type textReaderState uint8 + +const ( + trsDone textReaderState = iota + trsBeforeFieldName + trsBeforeTypeAnnotations + trsBeforeScalar + trsBeforeContainer + trsInValue + trsAfterValue +) + +type textReader struct { + tok tokenizer + state textReaderState + ctx ctx + eof bool + err error + + fieldName string + typeAnnotations []string + valueType Type + value interface{} +} + +// NewTextReader creates a new text reader. +func NewTextReader(in io.Reader) Reader { + return &textReader{ + tok: tokenizer{ + in: bufio.NewReader(in), + }, + state: trsBeforeTypeAnnotations, + } +} + +func (t *textReader) SymbolTable() SymbolTable { + // Text content doesn't have a symbol table. + return nil +} + +func (t *textReader) Next() bool { + if t.state == trsDone || t.eof { + return false + } + + err := t.finishValue() + if err != nil { + t.explode(err) + return false + } + + t.fieldName = "" + t.typeAnnotations = nil + t.valueType = NoType + t.value = nil + + if err := t.tok.Next(); err != nil { + t.explode(err) + return false + } + + for { + var f func() (bool, error) + + switch t.state { + case trsAfterValue: + f = t.nextAfterValue + case trsBeforeFieldName: + f = t.nextBeforeFieldName + case trsBeforeTypeAnnotations: + f = t.nextBeforeTypeAnnotations + default: + panic("invalid state") + } + + done, err := f() + if err != nil { + t.explode(err) + return false + } + if done { + return !t.eof + } + + if err := t.tok.Next(); err != nil { + t.explode(err) + return false + } + } +} + +func (t *textReader) nextAfterValue() (bool, error) { + tok := t.tok.Token() + switch tok { + case tokenComma: + // Another value coming; eat the comma and move to the + // appropriate next state. + switch t.ctx.peek() { + case ctxInStruct: + t.state = trsBeforeFieldName + case ctxInList: + t.state = trsBeforeTypeAnnotations + default: + panic("invalid state") + } + return false, nil + + case tokenCloseBrace: + // No more values in this struct. + if t.ctx.peek() == ctxInStruct { + t.eof = true + return true, nil + } + return false, errors.New("unexpected token '}'") + + case tokenCloseBracket: + // No more values in this list. + if t.ctx.peek() == ctxInList { + t.eof = true + return true, nil + } + return false, errors.New("unexpected token ']'") + + default: + return false, fmt.Errorf("unexpected token '%v'", tok) + } +} + +func (t *textReader) nextBeforeFieldName() (bool, error) { + tok := t.tok.Token() + switch tok { + case tokenCloseBrace: + // No more values in this struct. + t.eof = true + return true, nil + + case tokenSymbol, tokenSymbolQuoted: + // Read the field name. + val, err := t.tok.ReadValue(tok) + if err != nil { + return false, err + } + + // Skip over the following colon. + if err = t.tok.Next(); err != nil { + return false, err + } + if tok = t.tok.Token(); tok != tokenColon { + return false, fmt.Errorf("unexpected token '%v'", tok) + } + + t.fieldName = val + t.state = trsBeforeTypeAnnotations + + return false, nil + + default: + return false, fmt.Errorf("unexpected token '%v'", tok) + } +} + +func (t *textReader) nextBeforeTypeAnnotations() (bool, error) { + tok := t.tok.Token() + switch tok { + case tokenEOF: + if t.ctx.peek() == ctxAtTopLevel { + t.eof = true + return true, nil + } + return false, errors.New("unexpected EOF") + + case tokenSymbol, tokenSymbolQuoted: + val, err := t.tok.ReadValue(tok) + if err != nil { + return false, err + } + + ok, err := t.tok.skipDoubleColon() + if err != nil { + return false, err + } + if ok { + // val was a type annotation; remember it and keep going. + t.typeAnnotations = append(t.typeAnnotations, val) + return false, nil + } + + // val was a legit symbol value. + t.onSymbol(val, tok) + return true, nil + + default: + return false, fmt.Errorf("unexpected token '%v'", tok) + } +} + +func (t *textReader) onSymbol(val string, tok tokenType) { + valueType := SymbolType + var value interface{} = val + + if tok == tokenSymbol { + switch val { + case "null": + // TODO: Deal with potential '.type'. + valueType = NullType + value = nil + + case "true": + valueType = BoolType + value = true + + case "false": + valueType = BoolType + value = false + + case "nan": + valueType = FloatType + value = math.NaN() + } + } + + t.state = t.stateAfterValue() + t.valueType = valueType + t.value = value +} + +func (t *textReader) Type() Type { + return t.valueType +} + +func (t *textReader) Err() error { + return t.err +} + +func (t *textReader) FieldName() string { + return t.fieldName +} + +func (t *textReader) TypeAnnotations() []string { + return t.typeAnnotations +} + +func (t *textReader) IsNull() bool { + return false +} + +func (t *textReader) StepIn() error { + if t.state != trsBeforeContainer { + return errors.New("invalid state") + } + + var ctx ctxType + switch t.valueType { + case StructType: + ctx = ctxInStruct + case ListType: + ctx = ctxInList + case SexpType: + ctx = ctxInSexp + default: + panic("trsBeforeContainer with unexpected valueType") + } + t.ctx.push(ctx) + + if ctx == ctxInStruct { + t.state = trsBeforeFieldName + } else { + t.state = trsBeforeTypeAnnotations + } + + return nil +} + +func (t *textReader) StepOut() error { + ctx := t.ctx.peek() + if ctx == ctxAtTopLevel { + return errors.New("invalid state") + } + + err := t.tok.finishValue() + if err != nil { + t.explode(err) + return err + } + + switch t.ctx.peek() { + case ctxInStruct: + err = t.tok.skipStructHelper() + case ctxInList: + err = t.tok.skipListHelper() + case ctxInSexp: + err = t.tok.skipSexpHelper() + default: + panic("invalid ctx") + } + + if err != nil { + t.explode(err) + return err + } + + t.ctx.pop() + t.state = trsAfterValue + t.valueType = NoType + t.value = nil + + return nil +} + +func (t *textReader) BoolValue() (bool, error) { + return false, errors.New("not implemented yet") +} + +func (t *textReader) IntValue() (int, error) { + return 0, errors.New("not implemented yet") +} + +func (t *textReader) Int64Value() (int64, error) { + return 0, errors.New("not implemented yet") +} + +func (t *textReader) BigIntValue() (*big.Int, error) { + return nil, errors.New("not implemented yet") +} + +func (t *textReader) FloatValue() (float64, error) { + return 0.0, errors.New("not implemented yet") +} + +func (t *textReader) DecimalValue() (*Decimal, error) { + return nil, errors.New("not implemented yet") +} + +func (t *textReader) TimeValue() (time.Time, error) { + return time.Time{}, errors.New("not implemented yet") +} + +func (t *textReader) StringValue() (string, error) { + switch t.valueType { + case StringType, SymbolType: + return t.value.(string), nil + + default: + return "", errors.New("value is not a string") + } +} + +func (t *textReader) ByteValue() ([]byte, error) { + return nil, errors.New("not implemented yet") +} + +// FinishValue finishes reading the current value, if there is one. +func (t *textReader) finishValue() error { + err := t.tok.finishValue() + if err != nil { + return err + } + + t.state = t.stateAfterValue() + return nil +} + +func (t *textReader) stateAfterValue() textReaderState { + switch t.ctx.peek() { + case ctxInList, ctxInStruct: + return trsAfterValue + case ctxInSexp, ctxAtTopLevel: + return trsBeforeTypeAnnotations + default: + panic("invalid ctx") + } +} + +// Explode explodes the reader state when something unexpected +// happens and further calls to Next are a bad idea. +func (t *textReader) explode(err error) { + t.state = trsDone + t.err = err +} diff --git a/textreader_test.go b/textreader_test.go new file mode 100644 index 00000000..e496d59d --- /dev/null +++ b/textreader_test.go @@ -0,0 +1,55 @@ +package ion + +import ( + "strings" + "testing" +) + +func TestSymbols(t *testing.T) { + r := NewTextReader(strings.NewReader("'null'::foo bar a::b::'baz'")) + + test := func(etas []string, eval string) { + if !r.Next() { + t.Fatal("next returned false") + } + + if r.Type() != SymbolType { + t.Fatalf("expected type=symbol, got type=%v", r.Type()) + } + + if !strequals(r.TypeAnnotations(), etas) { + t.Errorf("expected tas=%v, got tas=%v", etas, r.TypeAnnotations()) + } + + val, err := r.StringValue() + if err != nil { + t.Fatal(err) + } + + if val != eval { + t.Errorf("expected val=%v, got val=%v", eval, val) + } + } + + test([]string{"null"}, "foo") + test([]string{}, "bar") + test([]string{"a", "b"}, "baz") + + if r.Next() { + t.Errorf("next unexpectedly returned true") + } +} + +func strequals(a, b []string) bool { + if len(a) != len(b) { + return false + } + + for i := 0; i < len(a); i++ { + if a[i] != b[i] { + return false + } + } + + return true +} diff --git a/textutils.go b/textutils.go index 1c3d11a6..44c46b77 100644 --- a/textutils.go +++ b/textutils.go @@ -168,6 +168,19 @@ func writeEscapedString(str string, out io.Writer) error { return nil } +func fromHex(c int) (int, error) { + if c >= '0' && c <= '9' { + return c - '0', nil + } + if c >= 'a' && c <= 'f' { + return 10 + (c - 'a'), nil + } + if c >= 'A' && c <= 'F' { + return 10 + (c - 'A'), nil + } + return 0, invalidChar(c) +} + var hexChars = []byte{'0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'A', 'B', 'C', 'D', 'E', 'F'} // Write out the given character in escaped form. diff --git a/textwriter.go b/textwriter.go index aa1d489e..ea1173e5 100644 --- a/textwriter.go +++ b/textwriter.go @@ -32,10 +32,10 @@ func NewTextWriter(out io.Writer) Writer { func (w *textWriter) beginValue() error { if w.needsSeparator { var sep byte - switch w.ctx() { - case inStructCtx, inListCtx: + switch w.ctx.peek() { + case ctxInStruct, ctxInList: sep = ',' - case inSexpCtx: + case ctxInSexp: sep = ' ' default: sep = '\n' @@ -89,7 +89,7 @@ func (w *textWriter) begin(t ctxType, c byte) error { return err } - w.push(t) + w.ctx.push(t) w.needsSeparator = false return writeRawChar(c, w.out) @@ -97,7 +97,7 @@ func (w *textWriter) begin(t ctxType, c byte) error { // end finishes writing a container of the given type func (w *textWriter) end(t ctxType, c byte) error { - if w.ctx() != t { + if w.ctx.peek() != t { return errors.New("not in that kind of container") } @@ -107,7 +107,7 @@ func (w *textWriter) end(t ctxType, c byte) error { w.fieldName = "" w.typeAnnotations = nil - w.pop() + w.ctx.pop() w.endValue() return nil @@ -118,7 +118,7 @@ func (w *textWriter) BeginStruct() { if w.err != nil { return } - w.err = w.begin(inStructCtx, '{') + w.err = w.begin(ctxInStruct, '{') } // EndStruct finishes writing a struct. @@ -126,7 +126,7 @@ func (w *textWriter) EndStruct() { if w.err != nil { return } - w.err = w.end(inStructCtx, '}') + w.err = w.end(ctxInStruct, '}') } // BeginList begins writing a list. @@ -134,7 +134,7 @@ func (w *textWriter) BeginList() { if w.err != nil { return } - w.err = w.begin(inListCtx, '[') + w.err = w.begin(ctxInList, '[') } // EndList finishes writing a list. @@ -142,7 +142,7 @@ func (w *textWriter) EndList() { if w.err != nil { return } - w.err = w.end(inListCtx, ']') + w.err = w.end(ctxInList, ']') } // BeginSexp begins writing an s-expression. @@ -150,7 +150,7 @@ func (w *textWriter) BeginSexp() { if w.err != nil { return } - w.err = w.begin(inSexpCtx, '(') + w.err = w.begin(ctxInSexp, '(') } // EndSexp finishes writing an s-expression. @@ -158,7 +158,7 @@ func (w *textWriter) EndSexp() { if w.err != nil { return } - w.err = w.end(inSexpCtx, ')') + w.err = w.end(ctxInSexp, ')') } // writeValue writes a value whose raw encoding is produced by the @@ -397,7 +397,7 @@ func (w *textWriter) Finish() error { if w.err != nil { return w.err } - if w.ctx() != atTopLevelCtx { + if w.ctx.peek() != ctxAtTopLevel { w.err = errors.New("not at top level") return w.err } diff --git a/tokenizer.go b/tokenizer.go index 9aaff905..3d0e23d0 100644 --- a/tokenizer.go +++ b/tokenizer.go @@ -6,6 +6,7 @@ import ( "errors" "fmt" "io" + "strings" ) type tokenType int @@ -15,12 +16,9 @@ const ( tokenEOF // End of input - tokenNumeric // Haven't seen enough to know which, yet - tokenInt // [0-9]+ + tokenNumber // Haven't seen enough to know which, yet tokenBinary // 0b[01]+ tokenHex // 0x[0-9a-fA-F]+ - tokenDecimal // [0-9]+.[0-9]+d[0-9]+ - tokenFloat // [0-9]+.[0-9]+e[0-9]+ tokenFloatInf // +inf tokenFloatMinusInf // -inf tokenTimestamp // 2001-01-01T00:00:00.000Z @@ -50,66 +48,61 @@ const ( func (t tokenType) String() string { switch t { case tokenError: - return "error" + return "" case tokenEOF: - return "EOF" - case tokenNumeric: - return "numeric" - case tokenInt: - return "int" + return "" + case tokenNumber: + return "" case tokenBinary: - return "binary" + return "" case tokenHex: - return "hex" - case tokenDecimal: - return "decimal" - case tokenFloat: - return "float" + return "" case tokenFloatInf: return "+inf" case tokenFloatMinusInf: return "-inf" case tokenTimestamp: - return "timestamp" + return "" case tokenSymbol: - return "symbol" + return "" case tokenSymbolQuoted: - return "symbolQuoted" + return "" case tokenSymbolOperator: - return "symbolOperator" + return "" case tokenString: - return "string" + return "" case tokenLongString: - return "longstring" + return "" case tokenDot: - return "dot" + return "." case tokenComma: - return "comma" + return "," case tokenColon: - return "colon" + return ":" case tokenDoubleColon: - return "doublecolon" + return "::" case tokenOpenParen: - return "openparen" + return "(" case tokenCloseParen: - return "closeparen" + return ")" case tokenOpenBrace: - return "openbrace" + return "{" case tokenCloseBrace: - return "closebrace" + return "}" case tokenOpenBracket: - return "openbracket" + return "[" case tokenCloseBracket: - return "closebracket" + return "]" + case tokenOpenDoubleBrace: - return "opendoublebrace" + return "{{" case tokenCloseDoubleBrace: - return "closedoublebrace" + return "}}" default: return "" @@ -300,6 +293,155 @@ func (t *tokenizer) finish(token tokenType, more bool) error { return nil } +// ReadValue reads the value of a token of the given type. +func (t *tokenizer) ReadValue(tok tokenType) (string, error) { + var str string + var err error + + switch tok { + case tokenSymbol: + str, err = t.readSymbol() + case tokenSymbolQuoted: + str, err = t.readQuotedSymbol() + default: + panic("unsupported token type") + } + + if err != nil { + return "", err + } + + t.unfinished = false + return str, nil +} + +// ReadSymbol reads an unquoted symbol value. +func (t *tokenizer) readSymbol() (string, error) { + ret := strings.Builder{} + + c, err := t.peek() + if err != nil { + return "", err + } + + for isIdentifierPart(c) { + ret.WriteByte(byte(c)) + t.read() + c, err = t.peek() + if err != nil { + return "", err + } + } + + return ret.String(), nil +} + +// ReadQuotedSymbol reads a quoted symbol. +func (t *tokenizer) readQuotedSymbol() (string, error) { + ret := strings.Builder{} + + for { + c, err := t.read() + if err != nil { + return "", err + } + + switch c { + case -1, '\n': + return "", invalidChar(c) + + case '\'': + return ret.String(), nil + + case '\\': + c, err = t.peek() + if err != nil { + return "", err + } + + if c == '\n' { + t.read() + continue + } + + r, err := t.readEscapedChar(false) + if err != nil { + return "", err + } + ret.WriteRune(r) + + default: + ret.WriteByte(byte(c)) + } + } +} + +// ReadEscapedChar reads an escaped character. +func (t *tokenizer) readEscapedChar(clob bool) (rune, error) { + // We just read the '\', grab the next char. + c, err := t.read() + if err != nil { + return 0, err + } + + switch c { + case '0': + return '\x00', nil + case 'a': + return '\a', nil + case 'b': + return '\b', nil + case 't': + return '\t', nil + case 'n': + return '\n', nil + case 'f': + return '\f', nil + case 'r': + return '\r', nil + case 'v': + return '\v', nil + case '\'': + return '\'', nil + case '"': + return '"', nil + case '\\': + return '\\', nil + case 'U': + if clob { + return 0, invalidChar('U') + } + return t.readHexEscapeSeq(8) + case 'u': + return t.readHexEscapeSeq(4) + case 'x': + return t.readHexEscapeSeq(2) + } + + return 0, fmt.Errorf("bad escape sequence '\\%q'", c) +} + +func (t *tokenizer) readHexEscapeSeq(len int) (rune, error) { + val := rune(0) + + for len > 0 { + c, err := t.read() + if err != nil { + return 0, err + } + + d, err := fromHex(c) + if err != nil { + return 0, err + } + + val = (val << 4) | rune(d) + len-- + } + + return val, nil +} + // IsTripleQuote returns true if this is a triple-quote sequence ('''). func (t *tokenizer) isTripleQuote() (bool, error) { // We've just read a '\'', check if the next two are too. @@ -385,7 +527,7 @@ func (t *tokenizer) scanForNumericType(c int) (tokenType, error) { } // Can't tell yet; wait until actually reading it to find out. - return tokenNumeric, nil + return tokenNumber, nil } // Is this character a valid way to end a 'normal' (unquoted) value? diff --git a/tokenizer_test.go b/tokenizer_test.go index 37b9ff03..09224518 100644 --- a/tokenizer_test.go +++ b/tokenizer_test.go @@ -22,11 +22,116 @@ func TestNext(t *testing.T) { next(tokenSymbolQuoted) next(tokenColon) next(tokenOpenBracket) - next(tokenNumeric) + next(tokenNumber) next(tokenComma) next(tokenOpenBrace) } +func TestReadSymbol(t *testing.T) { + test := func(str string, expected string, next tokenType) { + t.Run(str, func(t *testing.T) { + tok := tokenizeString(str) + if err := tok.Next(); err != nil { + t.Fatal(err) + } + + if tok.Token() != tokenSymbol { + t.Fatal("not a symbol") + } + + actual, err := tok.readSymbol() + if err != nil { + t.Fatal(err) + } + + if actual != expected { + t.Errorf("expected '%v', got '%v'", expected, actual) + } + + if err := tok.Next(); err != nil { + t.Fatal(err) + } + if tok.Token() != next { + t.Errorf("expected next=%v, got next=%v", next, tok.Token()) + } + }) + } + + test("a", "a", tokenEOF) + test("abc", "abc", tokenEOF) + test("null +inf", "null", tokenFloatInf) + test("false,", "false", tokenComma) + test("nan]", "nan", tokenCloseBracket) +} + +func TestReadSymbols(t *testing.T) { + tok := tokenizeString("foo bar baz beep boop null") + expected := []string{"foo", "bar", "baz", "beep", "boop", "null"} + + for i := 0; i < len(expected); i++ { + if err := tok.Next(); err != nil { + t.Fatal(err) + } + if tok.Token() != tokenSymbol { + t.Fatalf("expected %v, got %v", tokenSymbol, tok.Token()) + } + + val, err := tok.readSymbol() + if err != nil { + t.Fatal(err) + } + + if val != expected[i] { + t.Errorf("expected %v, got %v", expected[i], val) + } + } +} + +func TestReadQuotedSymbol(t *testing.T) { + test := func(str string, expected string, next int) { + t.Run(str, func(t *testing.T) { + tok := tokenizeString(str) + if err := tok.Next(); err != nil { + t.Fatal(err) + } + + if tok.Token() != tokenSymbolQuoted { + t.Fatal("not a quoted symbol") + } + + actual, err := tok.readQuotedSymbol() + if err != nil { + t.Fatal(err) + } + + if actual != expected { + t.Errorf("expected '%v', got '%v'", expected, actual) + } + + c, err := tok.read() + if err != nil { + t.Fatal(err) + } + if c != next { + t.Errorf("expected next=%q, got next=%q", next, c) + } + }) + } + + test("'a'", "a", -1) + test("'a b c'", "a b c", -1) + test("'null' ", "null", ' ') + test("'false',", "false", ',') + test("'nan']", "nan", ']') + + test("'a\\'b'", "a'b", -1) + test("'a\\\nb'", "ab", -1) + test("'a\\\\b'", "a\\b", -1) + test("'a\x20b'", "a b", -1) + test("'a\\u2248b'", "a≈b", -1) + test("'a\\U0001F44Db'", "a👍b", -1) +} + func TestIsTripleQuote(t *testing.T) { test := func(str string, eok bool, next int) { t.Run(str, func(t *testing.T) { @@ -129,16 +234,16 @@ func TestScanForNumericType(t *testing.T) { test("0000-00-00", tokenTimestamp) test("0000T", tokenTimestamp) - test("0", tokenNumeric) - test("1b0101", tokenNumeric) - test("1B", tokenNumeric) - test("1x0101", tokenNumeric) - test("1X", tokenNumeric) - test("1234", tokenNumeric) - test("12345", tokenNumeric) - test("1,23T", tokenNumeric) - test("12,3T", tokenNumeric) - test("123,T", tokenNumeric) + test("0", tokenNumber) + test("1b0101", tokenNumber) + test("1B", tokenNumber) + test("1x0101", tokenNumber) + test("1X", tokenNumber) + test("1234", tokenNumber) + test("12345", tokenNumber) + test("1,23T", tokenNumber) + test("12,3T", tokenNumber) + test("123,T", tokenNumber) } func TestSkipWhitespace(t *testing.T) { diff --git a/writer.go b/writer.go index 1eecc61b..6f43a640 100644 --- a/writer.go +++ b/writer.go @@ -5,20 +5,11 @@ import ( "io" ) -type ctxType byte - -const ( - atTopLevelCtx ctxType = iota - inStructCtx - inListCtx - inSexpCtx -) - // writer holds shared stuff for all writers. type writer struct { - out io.Writer - ctxArr []ctxType - err error + out io.Writer + ctx ctx + err error fieldName string typeAnnotations []string @@ -26,17 +17,17 @@ type writer struct { // InStruct returns true if we're currently writing a struct. func (w *writer) InStruct() bool { - return w.ctx() == inStructCtx + return w.ctx.peek() == ctxInStruct } // InList returns true if we're currently writing a list. func (w *writer) InList() bool { - return w.ctx() == inListCtx + return w.ctx.peek() == ctxInList } // InSexp returns true if we're currently writing an s-expression. func (w *writer) InSexp() bool { - return w.ctx() == inSexpCtx + return w.ctx.peek() == ctxInSexp } // Err returns the current error, or nil if there are none yet. @@ -73,24 +64,3 @@ func (w *writer) TypeAnnotations(val ...string) { } w.typeAnnotations = append(w.typeAnnotations, val...) } - -// ctx returns the current writing context -func (w *writer) ctx() ctxType { - if len(w.ctxArr) == 0 { - return atTopLevelCtx - } - return w.ctxArr[len(w.ctxArr)-1] -} - -// push pushes a new writing context when a new container is begun. -func (w *writer) push(ctx ctxType) { - w.ctxArr = append(w.ctxArr, ctx) -} - -// pop pops the writing context when a container is ended. -func (w *writer) pop() { - if len(w.ctxArr) == 0 { - panic("pop called at top level") - } - w.ctxArr = w.ctxArr[:len(w.ctxArr)-1] -} From b5f7d6bdfeb6a8948d9d5796f5d6f75fb26c9ed0 Mon Sep 17 00:00:00 2001 From: David Murray Date: Thu, 27 Jun 2019 01:27:44 -0700 Subject: [PATCH 11/56] handling for special symbols --- skipper.go | 39 ++++++++++----- textreader.go | 122 ++++++++++++++++++++++++++++++++++++++++++--- textreader_test.go | 92 +++++++++++++++++++++++++++++++++- 3 files changed, 232 insertions(+), 21 deletions(-) diff --git a/skipper.go b/skipper.go index a1d3c20d..4b189b4a 100644 --- a/skipper.go +++ b/skipper.go @@ -684,6 +684,17 @@ func (t *tokenizer) skipWhitespace() (int, bool, error) { return t.skipWhitespaceWith(t.skipCommentsHandler) } +// SkipWhitespaceHelper is a 'helper' form of SkipWhitespace that +// unreads the first non-whitespace char instead of returning it. +func (t *tokenizer) skipWhitespaceHelper() (bool, error) { + c, ok, err := t.skipWhitespace() + if err != nil { + return false, err + } + t.unread(c) + return ok, err +} + // SkipLobWhitespace skips whitespace when we're inside a large // object ({{ ///= }} or {{ '''///=''' }}) where comments are // not allowed. @@ -797,29 +808,33 @@ func (t *tokenizer) skipBlockComment() error { // Peeks ahead to see if the next token is a double colon, and // if so skips it. If not, leaves the next token unconsumed. func (t *tokenizer) skipDoubleColon() (bool, error) { - // Read whitespace and first non-whitespace char. - c, _, err := t.skipWhitespace() + cs, err := t.peekN(2) + if err == io.EOF { + return false, nil + } if err != nil { return false, err } - if c != ':' { - // Not followed by a double-colon; put it back. - t.unread(c) - return false, nil + + if cs[0] == ':' && cs[1] == ':' { + t.skipN(2) + return true, nil } - // Peek to see if it's a double colon. - c, err = t.peek() + return false, nil +} + +// Peeks ahead to see if the next token is a dot, and +// if so skips it. If not, leaves the next token unconsumed. +func (t *tokenizer) skipDot() (bool, error) { + c, err := t.peek() if err != nil { return false, err } - if c != ':' { - // Nope; put back the first ':'. - t.unread(':') + if c != '.' { return false, nil } - // Yep; eat it and return true. t.read() return true, nil } diff --git a/textreader.go b/textreader.go index 98e0fd68..98fb5f23 100644 --- a/textreader.go +++ b/textreader.go @@ -7,6 +7,7 @@ import ( "io" "math" "math/big" + "strings" "time" ) @@ -45,6 +46,11 @@ func NewTextReader(in io.Reader) Reader { } } +// NewTextReaderString creates a new text reader from a string. +func NewTextReaderString(str string) Reader { + return NewTextReader(strings.NewReader(str)) +} + func (t *textReader) SymbolTable() SymbolTable { // Text content doesn't have a symbol table. return nil @@ -101,6 +107,8 @@ func (t *textReader) Next() bool { } } +// NextAfterValue moves to the next value when we're in the +// AfterValue state. func (t *textReader) nextAfterValue() (bool, error) { tok := t.tok.Token() switch tok { @@ -138,6 +146,8 @@ func (t *textReader) nextAfterValue() (bool, error) { } } +// NextBeforeFieldName moves to the next value when we're in the +// BeforeFieldName state. func (t *textReader) nextBeforeFieldName() (bool, error) { tok := t.tok.Token() switch tok { @@ -152,6 +162,11 @@ func (t *textReader) nextBeforeFieldName() (bool, error) { if err != nil { return false, err } + if tok == tokenSymbol { + if err := verifyUnquotedSymbol(val, "field name"); err != nil { + return false, err + } + } // Skip over the following colon. if err = t.tok.Next(); err != nil { @@ -171,6 +186,8 @@ func (t *textReader) nextBeforeFieldName() (bool, error) { } } +// NextBeforeTypeAnnotations moves to the next value when we're in the +// BeforeTypeAnnotations state. func (t *textReader) nextBeforeTypeAnnotations() (bool, error) { tok := t.tok.Token() switch tok { @@ -187,18 +204,30 @@ func (t *textReader) nextBeforeTypeAnnotations() (bool, error) { return false, err } + ws, err := t.tok.skipWhitespaceHelper() + if err != nil { + return false, err + } + ok, err := t.tok.skipDoubleColon() if err != nil { return false, err } if ok { // val was a type annotation; remember it and keep going. + if tok == tokenSymbol { + if err := verifyUnquotedSymbol(val, "type annotation"); err != nil { + return false, err + } + } t.typeAnnotations = append(t.typeAnnotations, val) return false, nil } // val was a legit symbol value. - t.onSymbol(val, tok) + if err := t.onSymbol(val, tok, ws); err != nil { + return false, err + } return true, nil default: @@ -206,15 +235,26 @@ func (t *textReader) nextBeforeTypeAnnotations() (bool, error) { } } -func (t *textReader) onSymbol(val string, tok tokenType) { +func verifyUnquotedSymbol(val string, ctx string) error { + switch val { + case "null", "true", "false", "nan": + return fmt.Errorf("cannot use unquoted keyword %v as %v", val, ctx) + } + return nil +} + +func (t *textReader) onSymbol(val string, tok tokenType, ws bool) error { valueType := SymbolType var value interface{} = val if tok == tokenSymbol { switch val { case "null": - // TODO: Deal with potential '.type'. - valueType = NullType + vt, err := t.onNull(ws) + if err != nil { + return err + } + valueType = vt value = nil case "true": @@ -234,6 +274,67 @@ func (t *textReader) onSymbol(val string, tok tokenType) { t.state = t.stateAfterValue() t.valueType = valueType t.value = value + + return nil +} + +func (t *textReader) onNull(ws bool) (Type, error) { + if !ws { + ok, err := t.tok.skipDot() + if err != nil { + return NoType, err + } + if ok { + return t.readNullType() + } + } + + return NullType, nil +} + +func (t *textReader) readNullType() (Type, error) { + if err := t.tok.Next(); err != nil { + return NoType, err + } + if t.tok.Token() != tokenSymbol { + return NoType, fmt.Errorf("unexpected token %v after null", t.tok.Token()) + } + + val, err := t.tok.ReadValue(tokenSymbol) + if err != nil { + return NoType, err + } + + switch val { + case "null": + return NullType, nil + case "bool": + return BoolType, nil + case "int": + return IntType, nil + case "float": + return FloatType, nil + case "decimal": + return DecimalType, nil + case "timestamp": + return TimestampType, nil + case "symbol": + return SymbolType, nil + case "string": + return StringType, nil + case "blob": + return BlobType, nil + case "clob": + return ClobType, nil + case "list": + return ListType, nil + case "struct": + return StructType, nil + case "sexp": + return SexpType, nil + default: + return NoType, fmt.Errorf("invalid symbol null.%v", val) + } } func (t *textReader) Type() Type { @@ -253,7 +354,7 @@ func (t *textReader) TypeAnnotations() []string { } func (t *textReader) IsNull() bool { - return false + return t.value == nil } func (t *textReader) StepIn() error { @@ -320,7 +421,10 @@ func (t *textReader) StepOut() error { } func (t *textReader) BoolValue() (bool, error) { - return false, errors.New("not implemented yet") + if t.valueType == BoolType { + return t.value.(bool), nil + } + return false, errors.New("value is not a bool") } func (t *textReader) IntValue() (int, error) { @@ -336,7 +440,11 @@ func (t *textReader) BigIntValue() (*big.Int, error) { } func (t *textReader) FloatValue() (float64, error) { - return 0.0, errors.New("not implemented yet") + if t.valueType == FloatType { + return t.value.(float64), nil + } + // TODO: Cast ints/decimals? + return 0.0, errors.New("value is not a float") } func (t *textReader) DecimalValue() (*Decimal, error) { diff --git a/textreader_test.go b/textreader_test.go index e496d59d..128a7c36 100644 --- a/textreader_test.go +++ b/textreader_test.go @@ -1,12 +1,12 @@ package ion import ( - "strings" + "math" "testing" ) func TestSymbols(t *testing.T) { - r := NewTextReader(strings.NewReader("'null'::foo bar a::b::'baz'")) + r := NewTextReaderString("'null'::foo bar a::b::'baz'") test := func(etas []string, eval string) { if !r.Next() { @@ -38,6 +38,9 @@ func TestSymbols(t *testing.T) { if r.Next() { t.Errorf("next unexpectedly returned true") } + if r.Err() != nil { + t.Error(r.Err()) + } } func strequals(a, b []string) bool { @@ -53,3 +56,88 @@ func strequals(a, b []string) bool { return true } + +func TestSpecialSymbols(t *testing.T) { + r := NewTextReaderString("null\nnull.struct\ntrue\nfalse\nnan") + + // null + { + if !r.Next() { + t.Fatal("next returned false") + } + if r.Type() != NullType { + t.Errorf("expected type=NullType, got %v", r.Type()) + } + if !r.IsNull() { + t.Error("expected isNull=true, got false") + } + } + + // null.struct + { + if !r.Next() { + t.Fatal("next returned false") + } + if r.Type() != StructType { + t.Errorf("expected type=StructType, got %v", r.Type()) + } + if !r.IsNull() { + t.Error("expected isNull=true, got false") + } + } + + // true + { + if !r.Next() { + t.Fatal("next returned false") + } + if r.Type() != BoolType { + t.Errorf("expected type=BoolType, got %v", r.Type()) + } + val, err := r.BoolValue() + if err != nil { + t.Fatal(err) + } + if !val { + t.Error("expected value=true, got false") + } + } + + // false + { + if !r.Next() { + t.Fatal("next returned false") + } + if r.Type() != BoolType { + t.Errorf("expected type=BoolType, got %v", r.Type()) + } + val, err := r.BoolValue() + if err != nil { + t.Fatal(err) + } + if val { + t.Error("expected value=false, got true") + } + } + + // nan + { + if !r.Next() { + t.Fatal("next returned false") + } + if r.Type() != FloatType { + t.Errorf("expected type=FloatType, got %v", r.Type()) + } + val, err := r.FloatValue() + if err != nil { + t.Fatal(err) + } + if !math.IsNaN(val) { + t.Errorf("expected value=NaN, got %v", val) + } + } + + if r.Next() { + t.Error("next returned true") + } +} From e4a4ba22b905421ec6f8bdebb00f85057c960183 Mon Sep 17 00:00:00 2001 From: David Murray Date: Thu, 27 Jun 2019 05:41:33 -0700 Subject: [PATCH 12/56] strings and integers --- textreader.go | 127 +++++++++++++++++- textreader_test.go | 145 ++++++++++++++++++++- textutils.go | 55 ++++++++ tokenizer.go | 315 ++++++++++++++++++++++++++++++++++++++++++--- 4 files changed, 617 insertions(+), 25 deletions(-) diff --git a/textreader.go b/textreader.go index 98fb5f23..3c611df3 100644 --- a/textreader.go +++ b/textreader.go @@ -230,6 +230,23 @@ func (t *textReader) nextBeforeTypeAnnotations() (bool, error) { } return true, nil + case tokenString, tokenLongString: + val, err := t.tok.ReadValue(tok) + if err != nil { + return false, err + } + + t.state = t.stateAfterValue() + t.valueType = StringType + t.value = val + return true, nil + + case tokenBinary, tokenHex, tokenNumber: + if err := t.onNumber(tok); err != nil { + return false, err + } + return true, nil + default: return false, fmt.Errorf("unexpected token '%v'", tok) } @@ -337,6 +354,69 @@ func (t *textReader) readNullType() (Type, error) { } } +func (t *textReader) onNumber(tok tokenType) error { + var valueType Type + var value interface{} + + switch tok { + case tokenBinary: + val, err := t.tok.ReadValue(tok) + if err != nil { + return err + } + + valueType = IntType + value, err = parseInt(val, 2) + if err != nil { + return err + } + + case tokenHex: + val, err := t.tok.ReadValue(tok) + if err != nil { + return err + } + + valueType = IntType + value, err = parseInt(val, 16) + if err != nil { + return err + } + + case tokenNumber: + val, tt, err := t.tok.ReadNumber() + if err != nil { + return err + } + + valueType = tt + + switch tt { + case IntType: + value, err = parseInt(val, 10) + case FloatType: + value, err = parseFloat(val) + case DecimalType: + value, err = parseDecimal(val) + default: + panic("unexpected type") + } + + if err != nil { + return err + } + + default: + panic("unexpected token type") + } + + t.state = t.stateAfterValue() + t.valueType = valueType + t.value = value + + return nil +} + func (t *textReader) Type() Type { return t.valueType } @@ -422,25 +502,63 @@ func (t *textReader) StepOut() error { func (t *textReader) BoolValue() (bool, error) { if t.valueType == BoolType { + if t.value == nil { + return false, nil + } return t.value.(bool), nil } return false, errors.New("value is not a bool") } func (t *textReader) IntValue() (int, error) { - return 0, errors.New("not implemented yet") + i, err := t.Int64Value() + if err != nil { + return 0, err + } + if i > math.MaxInt32 || i < math.MinInt32 { + return 0, errors.New("value out of bounds") + } + return int(i), nil } func (t *textReader) Int64Value() (int64, error) { - return 0, errors.New("not implemented yet") + if t.valueType == IntType { + if t.value == nil { + return 0, nil + } + + if i, ok := t.value.(int64); ok { + return i, nil + } + + bi := t.value.(*big.Int) + if bi.IsInt64() { + return bi.Int64(), nil + } + + return 0, errors.New("value out of bounds") + } + return 0, errors.New("value is not an int") } func (t *textReader) BigIntValue() (*big.Int, error) { - return nil, errors.New("not implemented yet") + if t.valueType == IntType { + if t.value == nil { + return nil, nil + } + if i, ok := t.value.(int64); ok { + return big.NewInt(i), nil + } + return t.value.(*big.Int), nil + } + return nil, errors.New("value is not an int") } func (t *textReader) FloatValue() (float64, error) { if t.valueType == FloatType { + if t.value == nil { + return 0.0, nil + } return t.value.(float64), nil } // TODO: Cast ints/decimals? @@ -458,6 +576,9 @@ func (t *textReader) TimeValue() (time.Time, error) { func (t *textReader) StringValue() (string, error) { switch t.valueType { case StringType, SymbolType: + if t.value == nil { + return "", nil + } return t.value.(string), nil default: diff --git a/textreader_test.go b/textreader_test.go index 128a7c36..085c014b 100644 --- a/textreader_test.go +++ b/textreader_test.go @@ -1,12 +1,152 @@ package ion import ( + "errors" + "fmt" "math" + "math/big" "testing" ) +func TestInts(t *testing.T) { + test := func(str string, m func(Reader) error) { + t.Run(str, func(t *testing.T) { + r := NewTextReaderString(str) + if !r.Next() { + t.Error("next returned false") + } + if r.Type() != IntType { + t.Errorf("expected type=IntType, got %v", r.Type()) + } + + if err := m(r); err != nil { + t.Error(err) + } + + if r.Next() { + t.Error("next returned true") + } + if r.Err() != nil { + t.Error(r.Err()) + } + }) + } + + test("null.int", func(r Reader) error { + if !r.IsNull() { + return errors.New("expected isnull=true, got false") + } + + val, err := r.IntValue() + if err != nil { + return err + } + if val != 0 { + return fmt.Errorf("expected 0, got %v", val) + } + + return nil + }) + + testInt := func(str string, eval int) { + test(str, func(r Reader) error { + val, err := r.IntValue() + if err != nil { + return err + } + if val != eval { + return fmt.Errorf("expected %v, got %v", eval, val) + } + return nil + }) + } + + testInt("0", 0) + testInt("12345", 12345) + testInt("-12345", -12345) + testInt("0b000101", 5) + testInt("-0b000101", -5) + testInt("0x01020e0F", 0x01020e0f) + testInt("-0x01020e0F", -0x01020e0f) + + testInt64 := func(str string, eval int64) { + test(str, func(r Reader) error { + val, err := r.Int64Value() + if err != nil { + return err + } + if val != eval { + return fmt.Errorf("expected %v, got %v", eval, val) + } + return nil + }) + } + + testInt64("0x123FFFFFFFF", 0x123FFFFFFFF) + testInt64("-0x123FFFFFFFF", -0x123FFFFFFFF) + + testBigInt := func(str string) { + test(str, func(r Reader) error { + val, err := r.BigIntValue() + if err != nil { + return err + } + + eval, _ := (&big.Int{}).SetString(str, 0) + if eval.Cmp(val) != 0 { + return fmt.Errorf("expected %v, got %v", eval, val) + } + + return nil + }) + } + + testBigInt("0xEFFFFFFFFFFFFFFF") + testBigInt("0xFFFFFFFFFFFFFFFF") + testBigInt("-0x1FFFFFFFFFFFFFFFF") +} + +func TestStrings(t *testing.T) { + r := NewTextReaderString(`foo::"bar" "baz" 'a'::'b'::'''beep''' '''boop''' null.string`) + + test := func(etas []string, eval string) { + if !r.Next() { + t.Fatal("next returned false") + } + + if r.Type() != StringType { + t.Fatalf("expected type=string, got type=%v", r.Type()) + } + + if !strequals(r.TypeAnnotations(), etas) { + t.Errorf("expected tas=%v, got tas=%v", etas, r.TypeAnnotations()) + } + + val, err := r.StringValue() + if err != nil { + t.Fatal(err) + } + + if val != eval { + t.Errorf("expected val=%v, got val=%v", eval, val) + } + } + + test([]string{"foo"}, "bar") + test(nil, "baz") + test([]string{"a", "b"}, "beepboop") + test(nil, "") + + if r.Next() { + t.Errorf("next unexpectedly returned true") + } + if r.Err() != nil { + t.Error(r.Err()) + } +} + func TestSymbols(t *testing.T) { - r := NewTextReaderString("'null'::foo bar a::b::'baz'") + r := NewTextReaderString("'null'::foo bar a::b::'baz' null.symbol") test := func(etas []string, eval string) { if !r.Next() { @@ -32,8 +172,9 @@ func TestSymbols(t *testing.T) { } test([]string{"null"}, "foo") - test([]string{}, "bar") + test(nil, "bar") test([]string{"a", "b"}, "baz") + test(nil, "") if r.Next() { t.Errorf("next unexpectedly returned true") diff --git a/textutils.go b/textutils.go index 44c46b77..c7e51ad0 100644 --- a/textutils.go +++ b/textutils.go @@ -1,7 +1,10 @@ package ion import ( + "errors" "io" + "math/big" + "strconv" ) // Does this symbol need to be quoted in text form? @@ -231,3 +234,55 @@ func writeRawChar(c byte, out io.Writer) error { _, err := out.Write([]byte{c}) return err } + +func parseFloat(str string) (float64, error) { + return 0, errors.New("not implemented yet") +} + +func parseDecimal(str string) (*Decimal, error) { + return nil, errors.New("not implemented yet") +} + +func parseInt(str string, radix int) (interface{}, error) { + digits := str + + switch radix { + case 10: + // All set. + + case 2, 16: + neg := false + if digits[0] == '-' { + neg = true + digits = digits[1:] + } + + // Skip over the '0x' prefix. + digits = digits[2:] + if neg { + digits = "-" + digits + } + + default: + panic("unsupported radix") + } + + i, err := strconv.ParseInt(digits, radix, 64) + if err == nil { + return i, nil + } + if err.(*strconv.NumError).Err != strconv.ErrRange { + return nil, err + } + + bi, ok := (&big.Int{}).SetString(digits, radix) + if !ok { + return nil, &strconv.NumError{ + Func: "ParseInt", + Num: str, + Err: strconv.ErrSyntax, + } + } + + return bi, nil +} diff --git a/tokenizer.go b/tokenizer.go index 3d0e23d0..b405203a 100644 --- a/tokenizer.go +++ b/tokenizer.go @@ -231,26 +231,6 @@ func (t *tokenizer) Next() error { t.unread(c) return t.finish(tokenSymbolOperator, true) - case isOperatorChar(c): - t.unread(c) - return t.finish(tokenSymbolOperator, true) - - case c == '"': - return t.finish(tokenString, true) - - case isIdentifierStart(c): - t.unread(c) - return t.finish(tokenSymbol, true) - - case isDigit(c): - tt, err := t.scanForNumericType(c) - if err != nil { - return err - } - - t.unread(c) - return t.finish(tt, true) - case c == '-': c2, err := t.peek() if err != nil { @@ -268,6 +248,7 @@ func (t *tokenizer) Next() error { return invalidChar(c2) } t.unread(c2) + t.unread(c) return t.finish(tt, true) } @@ -282,6 +263,26 @@ func (t *tokenizer) Next() error { t.unread(c) return t.finish(tokenSymbolOperator, true) + case isOperatorChar(c): + t.unread(c) + return t.finish(tokenSymbolOperator, true) + + case c == '"': + return t.finish(tokenString, true) + + case isIdentifierStart(c): + t.unread(c) + return t.finish(tokenSymbol, true) + + case isDigit(c): + tt, err := t.scanForNumericType(c) + if err != nil { + return err + } + + t.unread(c) + return t.finish(tt, true) + default: return invalidChar(c) } @@ -303,6 +304,14 @@ func (t *tokenizer) ReadValue(tok tokenType) (string, error) { str, err = t.readSymbol() case tokenSymbolQuoted: str, err = t.readQuotedSymbol() + case tokenString: + str, err = t.readString() + case tokenLongString: + str, err = t.readLongString() + case tokenBinary: + str, err = t.readBinary() + case tokenHex: + str, err = t.readHex() default: panic("unsupported token type") } @@ -315,6 +324,106 @@ func (t *tokenizer) ReadValue(tok tokenType) (string, error) { return str, nil } +// ReadNumber reads a number and determines the type. +func (t *tokenizer) ReadNumber() (string, Type, error) { + w := strings.Builder{} + + c, err := t.read() + if err != nil { + return "", NoType, err + } + + if c == '-' { + w.WriteByte('-') + c, err = t.read() + if err != nil { + return "", NoType, err + } + } + + first := c + oldlen := w.Len() + + c, err = t.readDigits(c, &w) + if err != nil { + return "", NoType, err + } + + if first == '0' { + if w.Len()-oldlen > 1 { + return "", NoType, errors.New("invalid leading zeroes") + } + } + + tt := IntType + + if c == '.' { + w.WriteByte('.') + tt = DecimalType + + if c, err = t.read(); err != nil { + return "", NoType, err + } + if c, err = t.readDigits(c, &w); err != nil { + return "", NoType, err + } + } + + switch c { + case 'e', 'E': + tt = FloatType + + w.WriteByte(byte(c)) + if c, err = t.readExponent(&w); err != nil { + return "", NoType, err + } + + case 'd', 'D': + tt = DecimalType + + w.WriteByte(byte(c)) + if c, err = t.readExponent(&w); err != nil { + return "", NoType, err + } + } + + ok, err := t.isStopChar(c) + if err != nil { + return "", NoType, err + } + if !ok { + return "", NoType, invalidChar(c) + } + t.unread(c) + + return w.String(), tt, nil +} + +func (t *tokenizer) readExponent(w io.ByteWriter) (int, error) { + c, err := t.read() + if err != nil { + return 0, err + } + + if c == '+' || c == '-' { + w.WriteByte(byte(c)) + if c, err = t.read(); err != nil { + return 0, err + } + } + + return t.readDigits(c, w) +} + +func (t *tokenizer) readDigits(c int, w io.ByteWriter) (int, error) { + if !isDigit(c) { + return 0, invalidChar(c) + } + w.WriteByte(byte(c)) + + return t.readRadixDigits(isDigit, w) +} + // ReadSymbol reads an unquoted symbol value. func (t *tokenizer) readSymbol() (string, error) { ret := strings.Builder{} @@ -376,6 +485,92 @@ func (t *tokenizer) readQuotedSymbol() (string, error) { } } +// ReadString reads a quoted string. +func (t *tokenizer) readString() (string, error) { + ret := strings.Builder{} + + for { + c, err := t.read() + if err != nil { + return "", err + } + + switch c { + case -1, '\n': + return "", invalidChar(c) + + case '"': + return ret.String(), nil + + case '\\': + c, err = t.peek() + if err != nil { + return "", err + } + + if c == '\n' { + t.read() + continue + } + + r, err := t.readEscapedChar(false) + if err != nil { + return "", err + } + ret.WriteRune(r) + + default: + ret.WriteByte(byte(c)) + } + } +} + +// ReadLongString reads a triple-quoted string. +func (t *tokenizer) readLongString() (string, error) { + ret := strings.Builder{} + + for { + c, err := t.read() + if err != nil { + return "", err + } + + switch c { + case -1: + return "", invalidChar(c) + + case '\'': + ok, err := t.skipEndOfLongString(t.skipCommentsHandler) + if err != nil { + return "", err + } + if ok { + return ret.String(), nil + } + + case '\\': + c, err = t.peek() + if err != nil { + return "", err + } + + if c == '\n' { + t.read() + continue + } + + r, err := t.readEscapedChar(false) + if err != nil { + return "", err + } + ret.WriteRune(r) + + default: + ret.WriteByte(byte(c)) + } + } +} + // ReadEscapedChar reads an escaped character. func (t *tokenizer) readEscapedChar(clob bool) (rune, error) { // We just read the '\', grab the next char. @@ -442,6 +637,86 @@ func (t *tokenizer) readHexEscapeSeq(len int) (rune, error) { return val, nil } +func (t *tokenizer) readBinary() (string, error) { + isB := func(c int) bool { + return c == 'b' || c == 'B' + } + isDigit := func(c int) bool { + return c == '0' || c == '1' + } + return t.readRadix(isB, isDigit) +} + +func (t *tokenizer) readHex() (string, error) { + isX := func(c int) bool { + return c == 'x' || c == 'X' + } + return t.readRadix(isX, isHexDigit) +} + +func (t *tokenizer) readRadix(pok, dok matcher) (string, error) { + w := strings.Builder{} + + c, err := t.read() + if err != nil { + return "", err + } + + if c == '-' { + w.WriteByte('-') + c, err = t.read() + if err != nil { + return "", err + } + } + + if c != '0' { + return "", invalidChar(c) + } + w.WriteByte('0') + + c, err = t.read() + if err != nil { + return "", err + } + if !pok(c) { + return "", invalidChar(c) + } + w.WriteByte(byte(c)) + + c, err = t.readRadixDigits(dok, &w) + if err != nil { + return "", err + } + + ok, err := t.isStopChar(c) + if err != nil { + return "", err + } + if !ok { + return "", invalidChar(c) + } + t.unread(c) + + return w.String(), nil +} + +func (t *tokenizer) readRadixDigits(dok matcher, w io.ByteWriter) (int, error) { + var c int + var err error + + for { + c, err = t.read() + if err != nil { + return 0, err + } + if !dok(c) { + return c, nil + } + w.WriteByte(byte(c)) + } +} + // IsTripleQuote returns true if this is a triple-quote sequence ('''). func (t *tokenizer) isTripleQuote() (bool, error) { // We've just read a '\'', check if the next two are too. From e75397aa715f943d01f2b4931d0f7180aff5a851 Mon Sep 17 00:00:00 2001 From: David Murray Date: Tue, 2 Jul 2019 13:08:36 -0700 Subject: [PATCH 13/56] floats, doubles, timestamps --- textreader.go | 53 ++++++++++++++- textreader_test.go | 114 ++++++++++++++++++++++++++++++++ textutils.go | 72 ++++++++++++++++++++- textutils_test.go | 33 ++++++++++ tokenizer.go | 158 ++++++++++++++++++++++++++++++++++++++++++++- tokenizer_test.go | 45 +++++++++++++ 6 files changed, 468 insertions(+), 7 deletions(-) diff --git a/textreader.go b/textreader.go index 3c611df3..1dadd61c 100644 --- a/textreader.go +++ b/textreader.go @@ -241,12 +241,18 @@ func (t *textReader) nextBeforeTypeAnnotations() (bool, error) { t.value = val return true, nil - case tokenBinary, tokenHex, tokenNumber: + case tokenBinary, tokenHex, tokenNumber, tokenFloatInf, tokenFloatMinusInf: if err := t.onNumber(tok); err != nil { return false, err } return true, nil + case tokenTimestamp: + if err := t.onTimestamp(); err != nil { + return false, err + } + return true, nil + default: return false, fmt.Errorf("unexpected token '%v'", tok) } @@ -406,6 +412,14 @@ func (t *textReader) onNumber(tok tokenType) error { return err } + case tokenFloatInf: + valueType = FloatType + value = math.Inf(1) + + case tokenFloatMinusInf: + valueType = FloatType + value = math.Inf(-1) + default: panic("unexpected token type") } @@ -417,6 +431,24 @@ func (t *textReader) onNumber(tok tokenType) error { return nil } +func (t *textReader) onTimestamp() error { + val, err := t.tok.ReadValue(tokenTimestamp) + if err != nil { + return err + } + + value, err := parseTimestamp(val) + if err != nil { + return err + } + + t.state = t.stateAfterValue() + t.valueType = TimestampType + t.value = value + + return nil +} + func (t *textReader) Type() Type { return t.valueType } @@ -566,11 +598,26 @@ func (t *textReader) FloatValue() (float64, error) { } func (t *textReader) DecimalValue() (*Decimal, error) { - return nil, errors.New("not implemented yet") + switch t.valueType { + case DecimalType: + if t.value == nil { + return nil, nil + } + return t.value.(*Decimal), nil + } + // TODO: Cast floats/ints? + return nil, errors.New("value is not a decimal") } func (t *textReader) TimeValue() (time.Time, error) { - return time.Time{}, errors.New("not implemented yet") + switch t.valueType { + case TimestampType: + if t.value == nil { + return time.Time{}, nil + } + return t.value.(time.Time), nil + } + return time.Time{}, errors.New("value is not a timestamp") } func (t *textReader) StringValue() (string, error) { diff --git a/textreader_test.go b/textreader_test.go index 085c014b..02304b8a 100644 --- a/textreader_test.go +++ b/textreader_test.go @@ -6,8 +6,122 @@ import ( "math" "math/big" "testing" + "time" ) +func TestTimestamps(t *testing.T) { + test := func(str string, eval time.Time) { + t.Run(str, func(t *testing.T) { + r := NewTextReaderString(str) + if !r.Next() { + t.Error("next returned false") + t.Fatal(r.Err()) + } + if r.Type() != TimestampType { + t.Errorf("expected type=TimestampType, got %v", r.Type()) + } + + val, err := r.TimeValue() + if err != nil { + t.Fatal(err) + } + if !val.Equal(eval) { + t.Errorf("expected %v, got %v", eval, val) + } + + if r.Next() { + t.Error("next returned true") + } + if r.Err() != nil { + t.Error(r.Err()) + } + }) + } + + et := time.Date(2001, time.January, 1, 0, 0, 0, 0, time.UTC) + test("2001T", et) + test("2001-01T", et) + test("2001-01-01", et) + test("2001-01-01T", et) + test("2001-01-01T00:00Z", et) + test("2001-01-01T00:00:00Z", et) + test("2001-01-01T00:00:00.000Z", et) + test("2001-01-01T00:00:00.000+00:00", et) +} + +func TestDoubles(t *testing.T) { + test := func(str string, eval string) { + t.Run(str, func(t *testing.T) { + r := NewTextReaderString(str) + if !r.Next() { + t.Error("next returned false") + } + if r.Type() != DecimalType { + t.Errorf("expected type=DecimalType, got %v", r.Type()) + } + + ee := MustParseDecimal(eval) + + val, err := r.DecimalValue() + if err != nil { + t.Fatal(err) + } + if !ee.Equal(val) { + t.Errorf("expected %v, got %v", ee, val) + } + + if r.Next() { + t.Error("next returned true") + } + if r.Err() != nil { + t.Error(r.Err()) + } + }) + } + + test("123.", "123") + test("123.0", "123") + test("123.456", "123.456") + test("123d2", "12300") + test("123d+2", "12300") + test("123d-2", "1.23") +} + +func TestFloats(t *testing.T) { + test := func(str string, eval float64) { + t.Run(str, func(t *testing.T) { + r := NewTextReaderString(str) + if !r.Next() { + t.Error("next returned false") + } + if r.Type() != FloatType { + t.Errorf("expected type=FloatType, got %v", r.Type()) + } + + val, err := r.FloatValue() + if err != nil { + t.Error(err) + } + if val != eval { + t.Errorf("expected %v, got %v", eval, val) + } + + if r.Next() { + t.Error("next returned true") + } + if r.Err() != nil { + t.Error(r.Err()) + } + }) + } + + test("1e100\n", 1e100) + test("1.2e+0", 1.2) + test("-123.456e-78", -123.456e-78) + test("+inf", math.Inf(1)) + test("-inf", math.Inf(-1)) +} + func TestInts(t *testing.T) { test := func(str string, m func(Reader) error) { t.Run(str, func(t *testing.T) { diff --git a/textutils.go b/textutils.go index c7e51ad0..90d8d905 100644 --- a/textutils.go +++ b/textutils.go @@ -1,10 +1,11 @@ package ion import ( - "errors" + "fmt" "io" "math/big" "strconv" + "time" ) // Does this symbol need to be quoted in text form? @@ -236,11 +237,11 @@ func writeRawChar(c byte, out io.Writer) error { } func parseFloat(str string) (float64, error) { - return 0, errors.New("not implemented yet") + return strconv.ParseFloat(str, 64) } func parseDecimal(str string) (*Decimal, error) { - return nil, errors.New("not implemented yet") + return ParseDecimal(str) } func parseInt(str string, radix int) (interface{}, error) { @@ -286,3 +287,68 @@ func parseInt(str string, radix int) (interface{}, error) { return bi, nil } + +func parseTimestamp(val string) (time.Time, error) { + if len(val) < 5 { + return invalidTimestamp(val) + } + + year, err := strconv.ParseInt(val[:4], 10, 32) + if err != nil { + return invalidTimestamp(val) + } + if len(val) == 5 && (val[4] == 't' || val[4] == 'T') { + // yyyyT + return time.Date(int(year), 1, 1, 0, 0, 0, 0, time.UTC), nil + } + if val[4] != '-' { + return invalidTimestamp(val) + } + + if len(val) < 8 { + return invalidTimestamp(val) + } + + month, err := strconv.ParseInt(val[5:7], 10, 32) + if err != nil { + return invalidTimestamp(val) + } + + if len(val) == 8 && (val[7] == 't' || val[7] == 'T') { + // yyyy-mmT + return time.Date(int(year), time.Month(month), 1, 0, 0, 0, 0, time.UTC), nil + } + if val[7] != '-' { + return invalidTimestamp(val) + } + + if len(val) < 10 { + return invalidTimestamp(val) + } + + day, err := strconv.ParseInt(val[8:10], 10, 32) + if err != nil { + return invalidTimestamp(val) + } + + if len(val) == 10 || (len(val) == 11 && (val[10] == 't' || val[10] == 'T')) { + // yyyy-mm-dd or yyyy-mm-ddT + return time.Date(int(year), time.Month(month), int(day), 0, 0, 0, 0, time.UTC), nil + } + if val[10] != 't' && val[10] != 'T' { + return invalidTimestamp(val) + } + + if len(val) < 17 { + return invalidTimestamp(val) + } + if val[16] != ':' { + return time.Parse("2006-01-02T15:04Z07:00", val) + } + + return time.Parse(time.RFC3339Nano, val) +} + +func invalidTimestamp(val string) (time.Time, error) { + return time.Time{}, fmt.Errorf("invalid timestamp: %v", val) +} diff --git a/textutils_test.go b/textutils_test.go index 684904ec..988a4e71 100644 --- a/textutils_test.go +++ b/textutils_test.go @@ -3,8 +3,41 @@ package ion import ( "strings" "testing" + "time" ) +func TestParseTimestamp(t *testing.T) { + test := func(str string, eval string) { + t.Run(str, func(t *testing.T) { + val, err := parseTimestamp(str) + if err != nil { + t.Fatal(err) + } + + et, err := time.Parse(time.RFC3339Nano, eval) + if err != nil { + t.Fatal(err) + } + + if !val.Equal(et) { + t.Errorf("expected %v, got %v", eval, val) + } + }) + } + + test("1234T", "1234-01-01T00:00:00Z") + test("1234-05T", "1234-05-01T00:00:00Z") + test("1234-05-06", "1234-05-06T00:00:00Z") + test("1234-05-06T", "1234-05-06T00:00:00Z") + test("1234-05-06T07:08Z", "1234-05-06T07:08:00Z") + test("1234-05-06T07:08:09Z", "1234-05-06T07:08:09Z") + test("1234-05-06T07:08:09.100Z", "1234-05-06T07:08:09.100Z") + test("1234-05-06T07:08:09.100100Z", "1234-05-06T07:08:09.100100Z") + + test("1234-05-06T07:08+09:10", "1234-05-06T07:08:00+09:10") + test("1234-05-06T07:08:09-10:11", "1234-05-06T07:08:09-10:11") +} + func TestWriteSymbol(t *testing.T) { test := func(sym, expected string) { t.Run(expected, func(t *testing.T) { diff --git a/tokenizer.go b/tokenizer.go index b405203a..bdb5be97 100644 --- a/tokenizer.go +++ b/tokenizer.go @@ -312,6 +312,8 @@ func (t *tokenizer) ReadValue(tok tokenType) (string, error) { str, err = t.readBinary() case tokenHex: str, err = t.readHex() + case tokenTimestamp: + str, err = t.readTimestamp() default: panic("unsupported token type") } @@ -417,7 +419,7 @@ func (t *tokenizer) readExponent(w io.ByteWriter) (int, error) { func (t *tokenizer) readDigits(c int, w io.ByteWriter) (int, error) { if !isDigit(c) { - return 0, invalidChar(c) + return c, nil } w.WriteByte(byte(c)) @@ -717,6 +719,160 @@ func (t *tokenizer) readRadixDigits(dok matcher, w io.ByteWriter) (int, error) { } } +func (t *tokenizer) readTimestamp() (string, error) { + w := strings.Builder{} + + c, err := t.readTimestampDigits(4, &w) + if err != nil { + return "", err + } + if c == 'T' { + // yyyyT + w.WriteByte('T') + return w.String(), nil + } + if c != '-' { + return "", invalidChar(c) + } + w.WriteByte('-') + + if c, err = t.readTimestampDigits(2, &w); err != nil { + return "", err + } + if c == 'T' { + // yyyy-mmT + w.WriteByte('T') + return w.String(), nil + } + if c != '-' { + return "", invalidChar(c) + } + w.WriteByte('-') + + if c, err = t.readTimestampDigits(2, &w); err != nil { + return "", err + } + if c != 'T' { + // yyyy-mm-dd + return t.readTimestampFinish(c, &w) + } + w.WriteByte('T') + + if c, err = t.read(); err != nil { + return "", err + } + if !isDigit(c) { + // yyyy-mm-ddT(+hh:mm)? + if c, err = t.readTimestampOffset(c, &w); err != nil { + return "", err + } + return t.readTimestampFinish(c, &w) + } + w.WriteByte(byte(c)) + + if c, err = t.readTimestampDigits(1, &w); err != nil { + return "", err + } + if c != ':' { + return "", invalidChar(c) + } + w.WriteByte(':') + + if c, err = t.readTimestampDigits(2, &w); err != nil { + return "", err + } + if c != ':' { + // yyyy-mm-ddThh:mmZ + if c, err = t.readTimestampOffsetOrZ(c, &w); err != nil { + return "", err + } + return t.readTimestampFinish(c, &w) + } + w.WriteByte(':') + + if c, err = t.readTimestampDigits(2, &w); err != nil { + return "", err + } + if c != '.' { + // yyyy-mm-ddThh:mm:ssZ + if c, err = t.readTimestampOffsetOrZ(c, &w); err != nil { + return "", err + } + return t.readTimestampFinish(c, &w) + } + w.WriteByte('.') + + // yyyy-mm-ddThh:mm:ss.ssssZ + if c, err = t.read(); err != nil { + return "", err + } + if isDigit(c) { + if c, err = t.readDigits(c, &w); err != nil { + return "", err + } + } + + if c, err = t.readTimestampOffsetOrZ(c, &w); err != nil { + return "", err + } + return t.readTimestampFinish(c, &w) +} + +func (t *tokenizer) readTimestampOffsetOrZ(c int, w io.ByteWriter) (int, error) { + if c == '-' || c == '+' { + return t.readTimestampOffset(c, w) + } + if c == 'z' || c == 'Z' { + w.WriteByte(byte(c)) + return t.read() + } + return 0, invalidChar(c) +} + +func (t *tokenizer) readTimestampOffset(c int, w io.ByteWriter) (int, error) { + if c != '-' && c != '+' { + return c, nil + } + w.WriteByte(byte(c)) + + c, err := t.readTimestampDigits(2, w) + if err != nil { + return 0, err + } + if c != ':' { + return 0, invalidChar(c) + } + w.WriteByte(':') + return t.readTimestampDigits(2, w) +} + +func (t *tokenizer) readTimestampDigits(n int, w io.ByteWriter) (int, error) { + for n > 0 { + c, err := t.read() + if err != nil { + return 0, err + } + if !isDigit(c) { + return 0, invalidChar(c) + } + w.WriteByte(byte(c)) + n-- + } + return t.read() +} + +func (t *tokenizer) readTimestampFinish(c int, w fmt.Stringer) (string, error) { + ok, err := t.isStopChar(c) + if err != nil { + return "", err + } + if !ok { + return "", invalidChar(c) + } + t.unread(c) + return w.String(), nil +} + // IsTripleQuote returns true if this is a triple-quote sequence ('''). func (t *tokenizer) isTripleQuote() (bool, error) { // We've just read a '\'', check if the next two are too. diff --git a/tokenizer_test.go b/tokenizer_test.go index 09224518..7354b162 100644 --- a/tokenizer_test.go +++ b/tokenizer_test.go @@ -132,6 +132,51 @@ func TestReadQuotedSymbol(t *testing.T) { test("'a\\U0001F44Db'", "a👍b", -1) } +func TestReadTimestamp(t *testing.T) { + test := func(str string, eval string, next int) { + t.Run(str, func(t *testing.T) { + tok := tokenizeString(str) + if err := tok.Next(); err != nil { + t.Fatal(err) + } + if tok.Token() != tokenTimestamp { + t.Fatalf("unexpected token %v", tok.Token()) + } + + val, err := tok.ReadValue(tokenTimestamp) + if err != nil { + t.Fatal(err) + } + if val != eval { + t.Errorf("expected %v, got %v", eval, val) + } + + c, err := tok.read() + if err != nil { + t.Fatal(err) + } + if c != next { + t.Errorf("expected %q, got %q", next, c) + } + }) + } + + test("2001T", "2001T", -1) + test("2001-01T,", "2001-01T", ',') + test("2001-01-02}", "2001-01-02", '}') + test("2001-01-02T ", "2001-01-02T", ' ') + test("2001-01-02T+00:00\t", "2001-01-02T+00:00", '\t') + test("2001-01-02T-00:00\n", "2001-01-02T-00:00", '\n') + test("2001-01-02T03:04+00:00 ", "2001-01-02T03:04+00:00", ' ') + test("2001-01-02T03:04-00:00 ", "2001-01-02T03:04-00:00", ' ') + test("2001-01-02T03:04Z ", "2001-01-02T03:04Z", ' ') + test("2001-01-02T03:04z ", "2001-01-02T03:04z", ' ') + test("2001-01-02T03:04:05Z ", "2001-01-02T03:04:05Z", ' ') + test("2001-01-02T03:04:05+00:00 ", "2001-01-02T03:04:05+00:00", ' ') + test("2001-01-02T03:04:05.666Z ", "2001-01-02T03:04:05.666Z", ' ') + test("2001-01-02T03:04:05.666666z ", "2001-01-02T03:04:05.666666z", ' ') +} + func TestIsTripleQuote(t *testing.T) { test := func(str string, eok bool, next int) { t.Run(str, func(t *testing.T) { From ba3adeb5d99d59ec77e66d0760f0b8196fccac79 Mon Sep 17 00:00:00 2001 From: David Murray Date: Wed, 3 Jul 2019 12:23:03 -0700 Subject: [PATCH 14/56] blobs and clobs --- textreader.go | 83 +++++++++++++++++++++++++++++++++++++++++++++- textreader_test.go | 70 ++++++++++++++++++++++++++++++++++++++ tokenizer.go | 82 +++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 234 insertions(+), 1 deletion(-) diff --git a/textreader.go b/textreader.go index 1dadd61c..a4090193 100644 --- a/textreader.go +++ b/textreader.go @@ -2,6 +2,7 @@ package ion import ( "bufio" + "encoding/base64" "errors" "fmt" "io" @@ -253,6 +254,12 @@ func (t *textReader) nextBeforeTypeAnnotations() (bool, error) { } return true, nil + case tokenOpenDoubleBrace: + if err := t.onLob(); err != nil { + return false, err + } + return true, nil + default: return false, fmt.Errorf("unexpected token '%v'", tok) } @@ -449,6 +456,73 @@ func (t *textReader) onTimestamp() error { return nil } +func (t *textReader) onLob() error { + c, _, err := t.tok.skipLobWhitespace() + if err != nil { + return err + } + + var ( + valType Type + val []byte + ) + + // TODO: Peek for clobs. + if c == '"' { + + // Short clob. + valType = ClobType + + str, err := t.tok.ReadShortClob() + if err != nil { + return err + } + + val = []byte(str) + + } else if c == '\'' { + + // Long clob. + ok, err := t.tok.isTripleQuote() + if err != nil { + return err + } + if !ok { + return invalidChar(c) + } + + valType = ClobType + + str, err := t.tok.ReadLongClob() + if err != nil { + return err + } + + val = []byte(str) + + } else { + // Normal blob. + valType = BlobType + t.tok.unread(c) + + b64, err := t.tok.ReadBlob() + if err != nil { + return err + } + + val, err = base64.StdEncoding.DecodeString(b64) + if err != nil { + return err + } + } + + t.state = t.stateAfterValue() + t.valueType = valType + t.value = val + + return nil +} + func (t *textReader) Type() Type { return t.valueType } @@ -634,7 +708,14 @@ func (t *textReader) StringValue() (string, error) { } func (t *textReader) ByteValue() ([]byte, error) { - return nil, errors.New("not implemented yet") + switch t.valueType { + case BlobType, ClobType: + if t.value == nil { + return nil, nil + } + return t.value.([]byte), nil + } + return nil, errors.New("value is not a byte array") } // FinishValue finishes reading the current value, if there is one. diff --git a/textreader_test.go b/textreader_test.go index 02304b8a..254867a5 100644 --- a/textreader_test.go +++ b/textreader_test.go @@ -1,6 +1,7 @@ package ion import ( + "bytes" "errors" "fmt" "math" @@ -9,6 +10,75 @@ import ( "time" ) +func TestClobs(t *testing.T) { + test := func(str string, eval []byte) { + t.Run(str, func(t *testing.T) { + r := NewTextReaderString(str) + if !r.Next() { + t.Error("next returned false") + t.Fatal(r.Err()) + } + if r.Type() != ClobType { + t.Errorf("expected type=ClobType, got %v", r.Type()) + } + + val, err := r.ByteValue() + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(val, eval) { + t.Errorf("expected %v, got %v", eval, val) + } + + if r.Next() { + t.Error("next returned true") + } + if r.Err() != nil { + t.Error(r.Err()) + } + }) + } + + test("{{\"\"}}", []byte{}) + test("{{ \"hello world\" }}", []byte("hello world")) + test("{{'''hello world'''}}", []byte("hello world")) + test("{{'''hello'''\n'''world'''}}", []byte("helloworld")) +} + +func TestBlobs(t *testing.T) { + test := func(str string, eval []byte) { + t.Run(str, func(t *testing.T) { + r := NewTextReaderString(str) + if !r.Next() { + t.Error("next returned false") + t.Fatal(r.Err()) + } + if r.Type() != BlobType { + t.Errorf("expected type=BlobType, got %v", r.Type()) + } + + val, err := r.ByteValue() + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(val, eval) { + t.Errorf("expected %v, got %v", eval, val) + } + + if r.Next() { + t.Error("next returned true") + } + if r.Err() != nil { + t.Error(r.Err()) + } + }) + } + + test("{{}}", []byte{}) + test("{{AA==}}", []byte{0}) + test("{{ SGVsbG8g\r\nV29ybGQ= }}", []byte("Hello World")) +} + func TestTimestamps(t *testing.T) { test := func(str string, eval time.Time) { t.Run(str, func(t *testing.T) { diff --git a/tokenizer.go b/tokenizer.go index bdb5be97..5300616b 100644 --- a/tokenizer.go +++ b/tokenizer.go @@ -873,6 +873,88 @@ func (t *tokenizer) readTimestampFinish(c int, w fmt.Stringer) (string, error) { return w.String(), nil } +func (t *tokenizer) ReadBlob() (string, error) { + w := strings.Builder{} + + var ( + c int + err error + ) + + for { + if c, _, err = t.skipLobWhitespace(); err != nil { + return "", err + } + if c == -1 { + return "", invalidChar(c) + } + if c == '}' { + break + } + w.WriteByte(byte(c)) + } + + if c, err = t.read(); err != nil { + return "", err + } + if c != '}' { + return "", invalidChar(c) + } + + t.unfinished = false + return w.String(), nil +} + +func (t *tokenizer) ReadShortClob() (string, error) { + str, err := t.readString() + if err != nil { + return "", err + } + + c, _, err := t.skipLobWhitespace() + if err != nil { + return "", err + } + if c != '}' { + return "", invalidChar(c) + } + + if c, err = t.read(); err != nil { + return "", err + } + if c != '}' { + return "", invalidChar(c) + } + + t.unfinished = false + return str, nil +} + +func (t *tokenizer) ReadLongClob() (string, error) { + str, err := t.readLongString() + if err != nil { + return "", err + } + + c, _, err := t.skipLobWhitespace() + if err != nil { + return "", err + } + if c != '}' { + return "", invalidChar(c) + } + + if c, err = t.read(); err != nil { + return "", err + } + if c != '}' { + return "", invalidChar(c) + } + + t.unfinished = false + return str, nil +} + // IsTripleQuote returns true if this is a triple-quote sequence ('''). func (t *tokenizer) isTripleQuote() (bool, error) { // We've just read a '\'', check if the next two are too. From 8f8d5dc5d026feab54a564d433cf0f081c66560e Mon Sep 17 00:00:00 2001 From: David Murray Date: Sun, 14 Jul 2019 17:16:23 +1000 Subject: [PATCH 15/56] structs and lists --- skipper.go | 7 +- textreader.go | 102 ++++++++++++++++++++------ textreader_test.go | 175 +++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 261 insertions(+), 23 deletions(-) diff --git a/skipper.go b/skipper.go index 4b189b4a..18d09c7b 100644 --- a/skipper.go +++ b/skipper.go @@ -7,16 +7,17 @@ import ( // FinishValue skips to the end of the current value if (and only if) // we're currently in the middle of reading it. -func (t *tokenizer) finishValue() error { +func (t *tokenizer) finishValue() (bool, error) { if t.unfinished { c, err := t.skipValue() if err != nil { - return err + return true, err } t.unread(c) t.unfinished = false + return true, nil } - return nil + return false, nil } // SkipValue skips to the end of the current value, if the caller diff --git a/textreader.go b/textreader.go index a4090193..7db60cb8 100644 --- a/textreader.go +++ b/textreader.go @@ -8,6 +8,7 @@ import ( "io" "math" "math/big" + "strconv" "strings" "time" ) @@ -24,6 +25,23 @@ const ( trsAfterValue ) +func (s textReaderState) String() string { + switch s { + case trsDone: + return "" + case trsBeforeFieldName: + return "" + case trsBeforeTypeAnnotations: + return "" + case trsBeforeContainer: + return "" + case trsAfterValue: + return "" + default: + return strconv.Itoa(int(s)) + } +} + type textReader struct { tok tokenizer state textReaderState @@ -260,6 +278,37 @@ func (t *textReader) nextBeforeTypeAnnotations() (bool, error) { } return true, nil + case tokenOpenBrace: + t.state = trsBeforeContainer + t.valueType = StructType + return true, nil + + case tokenOpenBracket: + t.state = trsBeforeContainer + t.valueType = ListType + return true, nil + + case tokenOpenParen: + t.state = trsBeforeContainer + t.valueType = SexpType + return true, nil + + case tokenCloseBrace: + // No more values in this struct. + if t.ctx.peek() == ctxInStruct { + t.eof = true + return true, nil + } + return false, errors.New("unexpected token '}'") + + case tokenCloseBracket: + // No more values in this list. + if t.ctx.peek() == ctxInList { + t.eof = true + return true, nil + } + return false, errors.New("unexpected token ']'") + default: return false, fmt.Errorf("unexpected token '%v'", tok) } @@ -467,9 +516,7 @@ func (t *textReader) onLob() error { val []byte ) - // TODO: Peek for clobs. if c == '"' { - // Short clob. valType = ClobType @@ -481,7 +528,6 @@ func (t *textReader) onLob() error { val = []byte(str) } else if c == '\'' { - // Long clob. ok, err := t.tok.isTripleQuote() if err != nil { @@ -544,6 +590,9 @@ func (t *textReader) IsNull() bool { } func (t *textReader) StepIn() error { + if t.err != nil { + return t.err + } if t.state != trsBeforeContainer { return errors.New("invalid state") } @@ -567,39 +616,49 @@ func (t *textReader) StepIn() error { t.state = trsBeforeTypeAnnotations } + // TODO: Make this less hacky. + t.tok.unfinished = false return nil } func (t *textReader) StepOut() error { + if t.err != nil { + return t.err + } + ctx := t.ctx.peek() if ctx == ctxAtTopLevel { return errors.New("invalid state") } - err := t.tok.finishValue() + _, err := t.tok.finishValue() if err != nil { t.explode(err) return err } - switch t.ctx.peek() { - case ctxInStruct: - err = t.tok.skipStructHelper() - case ctxInList: - err = t.tok.skipListHelper() - case ctxInSexp: - err = t.tok.skipSexpHelper() - default: - panic("invalid ctx") - } + if !t.eof { + // Haven't seen the end of the container yet; skip until we + // find it. + switch t.ctx.peek() { + case ctxInStruct: + err = t.tok.skipStructHelper() + case ctxInList: + err = t.tok.skipListHelper() + case ctxInSexp: + err = t.tok.skipSexpHelper() + default: + panic("invalid ctx") + } - if err != nil { - t.explode(err) - return err + if err != nil { + t.explode(err) + return err + } } t.ctx.pop() - t.state = trsAfterValue + t.state = t.stateAfterValue() t.valueType = NoType t.value = nil @@ -720,12 +779,15 @@ func (t *textReader) ByteValue() ([]byte, error) { // FinishValue finishes reading the current value, if there is one. func (t *textReader) finishValue() error { - err := t.tok.finishValue() + ok, err := t.tok.finishValue() if err != nil { return err } - t.state = t.stateAfterValue() + if ok { + t.state = t.stateAfterValue() + } + return nil } diff --git a/textreader_test.go b/textreader_test.go index 254867a5..e12e64c2 100644 --- a/textreader_test.go +++ b/textreader_test.go @@ -10,6 +10,181 @@ import ( "time" ) +func TestIgnoreValues(t *testing.T) { + r := NewTextReaderString("{skip: me, please: true}\n[skip, me, please]\nfoo") + + if !r.Next() { + t.Fatal(r.Err()) + } + if r.Type() != StructType { + t.Fatalf("expected StructType, got %v", r.Type()) + } + + if !r.Next() { + t.Fatal(r.Err()) + } + if r.Type() != ListType { + t.Fatalf("expected ListType, got %v", r.Type()) + } + + if !r.Next() { + t.Fatal(r.Err()) + } + if r.Type() != SymbolType { + t.Fatalf("expected SymbolType, got %v", r.Type()) + } + + val, err := r.StringValue() + if err != nil { + t.Fatal(err) + } + if val != "foo" { + t.Errorf("expected foo, got %v", val) + } + + if r.Next() { + t.Error("next returned true") + } +} + +func TestStructs(t *testing.T) { + test := func(str string, f func(r Reader, t *testing.T)) { + t.Run(str, func(t *testing.T) { + r := NewTextReaderString(str) + if !r.Next() { + t.Fatal(r.Err()) + } + if r.Type() != StructType { + t.Errorf("expected type=StructType, got %v", r.Type()) + } + + if err := r.StepIn(); err != nil { + t.Fatal(err) + } + + f(r, t) + + if err := r.StepOut(); err != nil { + t.Fatal(err) + } + + if r.Next() { + t.Errorf("next returned true") + } + if r.Err() != nil { + t.Fatal(r.Err()) + } + }) + } + + test("{\r\n}", func(r Reader, t *testing.T) { + if r.Next() { + t.Errorf("next returned true") + } + if r.Err() != nil { + t.Fatal(r.Err()) + } + }) + + test("{foo: bar}", func(r Reader, t *testing.T) { + symbol(t, r, "bar") + if r.FieldName() != "foo" { + t.Errorf("expected foo, got %v", r.FieldName()) + } + }) + + test("{foo: a, bar: b, baz: c}", func(r Reader, t *testing.T) { + symbol(t, r, "a") + if r.FieldName() != "foo" { + t.Errorf("expected foo, got %v", r.FieldName()) + } + + symbol(t, r, "b") + if r.FieldName() != "bar" { + t.Errorf("expected bar, got %v", r.FieldName()) + } + + symbol(t, r, "c") + if r.FieldName() != "baz" { + t.Errorf("expected baz, got %v", r.FieldName()) + } + }) +} + +func TestLists(t *testing.T) { + test := func(str string, f func(r Reader, t *testing.T)) { + t.Run(str, func(t *testing.T) { + r := NewTextReaderString(str) + if !r.Next() { + t.Fatal(r.Err()) + } + if r.Type() != ListType { + t.Errorf("expected type=ListType, got %v", r.Type()) + } + + if err := r.StepIn(); err != nil { + t.Fatal(err) + } + + f(r, t) + + if err := r.StepOut(); err != nil { + t.Fatal(err) + } + + if r.Next() { + t.Errorf("next returned true") + } + if r.Err() != nil { + t.Fatal(r.Err()) + } + }) + } + + test("[ ]", func(r Reader, t *testing.T) { + if r.Next() { + t.Fatal("next returned true") + } + }) + + test("[foo]", func(r Reader, t *testing.T) { + symbol(t, r, "foo") + if r.Next() { + t.Fatal("next returned true") + } + }) + + test("[foo, bar, baz]", func(r Reader, t *testing.T) { + symbol(t, r, "foo") + symbol(t, r, "bar") + symbol(t, r, "baz") + if r.Next() { + t.Fatal("next returned true") + } + }) +} + +func symbol(t *testing.T, r Reader, eval string) { + next(t, r, SymbolType) + + val, err := r.StringValue() + if err != nil { + t.Fatal(err) + } + if val != eval { + t.Errorf("expected %v, got %v", eval, val) + } +} + +func next(t *testing.T, r Reader, et Type) { + if !r.Next() { + t.Fatal(r.Err()) + } + if r.Type() != et { + t.Fatalf("expected %v, got %v", et, r.Type()) + } +} + func TestClobs(t *testing.T) { test := func(str string, eval []byte) { t.Run(str, func(t *testing.T) { From 66de8b9035c19e333f6a551f148600d329c88f40 Mon Sep 17 00:00:00 2001 From: David Murray Date: Mon, 15 Jul 2019 14:23:36 +1000 Subject: [PATCH 16/56] sexps --- textreader.go | 8 ++++++++ textreader_test.go | 50 ++++++++++++++++++++++++++++++++++++++++++++++ textwriter_test.go | 2 +- 3 files changed, 59 insertions(+), 1 deletion(-) diff --git a/textreader.go b/textreader.go index 7db60cb8..5c017b3d 100644 --- a/textreader.go +++ b/textreader.go @@ -309,6 +309,14 @@ func (t *textReader) nextBeforeTypeAnnotations() (bool, error) { } return false, errors.New("unexpected token ']'") + case tokenCloseParen: + // No more values in this sexp. + if t.ctx.peek() == ctxInSexp { + t.eof = true + return true, nil + } + return false, errors.New("unexpected token ')'") + default: return false, fmt.Errorf("unexpected token '%v'", tok) } diff --git a/textreader_test.go b/textreader_test.go index e12e64c2..f175cc77 100644 --- a/textreader_test.go +++ b/textreader_test.go @@ -47,6 +47,56 @@ func TestIgnoreValues(t *testing.T) { } } +func TestSexps(t *testing.T) { + test := func(str string, f func(r Reader, t *testing.T)) { + t.Run(str, func(t *testing.T) { + r := NewTextReaderString(str) + if !r.Next() { + t.Fatal(r.Err()) + } + if r.Type() != SexpType { + t.Errorf("expected type=SexpType, got %v", r.Type()) + } + + if err := r.StepIn(); err != nil { + t.Fatal(err) + } + + f(r, t) + + if err := r.StepOut(); err != nil { + t.Fatal(err) + } + + if r.Next() { + t.Errorf("next returned true") + } + if r.Err() != nil { + t.Fatal(r.Err()) + } + }) + } + + test("(\t)", func(r Reader, t *testing.T) { + if r.Next() { + t.Errorf("next returned true") + } + if r.Err() != nil { + t.Fatal(r.Err()) + } + }) + + test("(foo)", func(r Reader, t *testing.T) { + symbol(t, r, "foo") + }) + + test("(foo bar baz)", func(r Reader, t *testing.T) { + symbol(t, r, "foo") + symbol(t, r, "bar") + symbol(t, r, "baz") + }) +} + func TestStructs(t *testing.T) { test := func(str string, f func(r Reader, t *testing.T)) { t.Run(str, func(t *testing.T) { diff --git a/textwriter_test.go b/textwriter_test.go index 0285e386..a90fb95e 100644 --- a/textwriter_test.go +++ b/textwriter_test.go @@ -121,7 +121,7 @@ func TestNestedLists(t *testing.T) { }) } -func TestSexps(t *testing.T) { +func TestWriteSexps(t *testing.T) { testTextWriter(t, "()\n(())\n(() ())", func(w Writer) { w.BeginSexp() w.EndSexp() From c11d8dbcd3b9e9e3fa9449c882a6015f1436cc8b Mon Sep 17 00:00:00 2001 From: David Murray Date: Mon, 15 Jul 2019 14:41:10 +1000 Subject: [PATCH 17/56] Adding test data submodule --- .gitmodules | 3 +++ ion-tests | 1 + 2 files changed, 4 insertions(+) create mode 100644 .gitmodules create mode 160000 ion-tests diff --git a/.gitmodules b/.gitmodules new file mode 100644 index 00000000..38171e66 --- /dev/null +++ b/.gitmodules @@ -0,0 +1,3 @@ +[submodule "ion-tests"] + path = ion-tests + url = https://github.com/amzn/ion-tests.git diff --git a/ion-tests b/ion-tests new file mode 160000 index 00000000..ae156785 --- /dev/null +++ b/ion-tests @@ -0,0 +1 @@ +Subproject commit ae1567858e4c215154c613a3c4e3dc25d4c0dfc6 From a69cbd373e10c5b923224b092111068118a2ff16 Mon Sep 17 00:00:00 2001 From: David Murray Date: Mon, 15 Jul 2019 16:55:27 +1000 Subject: [PATCH 18/56] standard test suite for text reader, bugfixes --- reader_test.go | 114 +++++++++++++++++++++++++++++++++++++++++++++ textreader.go | 30 +++++++++--- textreader_test.go | 78 +++++++++++++++++++++++++------ textutils.go | 11 ++++- tokenizer.go | 31 +++++++++++- 5 files changed, 242 insertions(+), 22 deletions(-) create mode 100644 reader_test.go diff --git a/reader_test.go b/reader_test.go new file mode 100644 index 00000000..3741d281 --- /dev/null +++ b/reader_test.go @@ -0,0 +1,114 @@ +package ion + +import ( + "fmt" + "io/ioutil" + "os" + "path/filepath" + "strings" + "testing" +) + +var blacklist = map[string]bool{ + "ion-tests/iontestdata/good/emptyAnnotatedInt.10n": true, + "ion-tests/iontestdata/good/subfieldVarUInt32bit.ion": true, + "ion-tests/iontestdata/good/utf16.ion": true, + "ion-tests/iontestdata/good/utf32.ion": true, + "ion-tests/iontestdata/good/whitespace.ion": true, + "ion-tests/iontestdata/good/item1.10n": true, + + // timestamps too long for time.Parse(); FIXME + "ion-tests/iontestdata/good/equivs/timestampsLargeFractionalPrecision.ion": true, + "ion-tests/iontestdata/good/timestamp/equivTimeline/timestamps.ion": true, +} + +func print(level int, obj interface{}) { + fmt.Print(" > ") + for i := 0; i < level; i++ { + fmt.Print(" ") + } + fmt.Println(obj) +} + +func drain(t *testing.T, r Reader, level int) { + for r.Next() { + // print(level, r.Type()) + + if !r.IsNull() { + switch r.Type() { + case StructType, ListType, SexpType: + if err := r.StepIn(); err != nil { + t.Fatal(err) + } + + drain(t, r, level+1) + + if err := r.StepOut(); err != nil { + t.Fatal(err) + } + } + } + } + + if r.Err() != nil { + t.Fatal(r.Err()) + } +} + +func testReaderFile(t *testing.T, path string) { + if _, ok := blacklist[path]; ok { + return + } + + // fmt.Println(path) + + file, err := os.Open(path) + if err != nil { + t.Fatal(err) + } + defer file.Close() + + var r Reader + + if strings.HasSuffix(path, ".ion") { + r = NewTextReader(file) + // r.(*textReader).debug = true + } else if strings.HasSuffix(path, ".10n") { + // Binary ion not yet supported. + return + } else { + t.Fatal("unexpected suffix on file", path) + } + + drain(t, r, 0) +} + +func testReaderDir(t *testing.T, path string) { + files, err := ioutil.ReadDir(path) + if err != nil { + t.Fatal(err) + } + + for _, file := range files { + fp := filepath.Join(path, file.Name()) + if file.IsDir() { + testReaderDir(t, fp) + } else { + t.Run(fp, func(t *testing.T) { + testReaderFile(t, fp) + }) + } + } +} + +func TestReader(t *testing.T) { + testReaderDir(t, "ion-tests/iontestdata/good") +} + +// func TestAllNulls(t *testing.T) { +// testReaderFile(t, "ion-tests/iontestdata/good/allNulls.ion") +// } + +// func TestStructWhitespace(t *testing.T) { +// testReaderFile(t, "ion-tests/iontestdata/good/equivs/structWhitespace.ion") +// } diff --git a/textreader.go b/textreader.go index 5c017b3d..1c7f5a9b 100644 --- a/textreader.go +++ b/textreader.go @@ -53,6 +53,8 @@ type textReader struct { typeAnnotations []string valueType Type value interface{} + + debug bool } // NewTextReader creates a new text reader. @@ -80,12 +82,20 @@ func (t *textReader) Next() bool { return false } + if t.debug { + fmt.Println("state:", t.state) + } + err := t.finishValue() if err != nil { t.explode(err) return false } + if t.debug { + fmt.Println("state after finish:", t.state) + } + t.fieldName = "" t.typeAnnotations = nil t.valueType = NoType @@ -96,6 +106,10 @@ func (t *textReader) Next() bool { return false } + if t.debug { + fmt.Println("read token:", t.tok.Token()) + } + for { var f func() (bool, error) @@ -107,7 +121,7 @@ func (t *textReader) Next() bool { case trsBeforeTypeAnnotations: f = t.nextBeforeTypeAnnotations default: - panic("invalid state") + panic(fmt.Sprintf("invalid state: %v", t.state)) } done, err := f() @@ -140,7 +154,7 @@ func (t *textReader) nextAfterValue() (bool, error) { case ctxInList: t.state = trsBeforeTypeAnnotations default: - panic("invalid state") + panic(fmt.Sprintf("invalid state: %v", t.ctx.peek())) } return false, nil @@ -175,7 +189,7 @@ func (t *textReader) nextBeforeFieldName() (bool, error) { t.eof = true return true, nil - case tokenSymbol, tokenSymbolQuoted: + case tokenSymbol, tokenSymbolQuoted, tokenString, tokenLongString: // Read the field name. val, err := t.tok.ReadValue(tok) if err != nil { @@ -217,7 +231,7 @@ func (t *textReader) nextBeforeTypeAnnotations() (bool, error) { } return false, errors.New("unexpected EOF") - case tokenSymbol, tokenSymbolQuoted: + case tokenSymbol, tokenSymbolQuoted, tokenSymbolOperator, tokenDot: val, err := t.tok.ReadValue(tok) if err != nil { return false, err @@ -281,16 +295,19 @@ func (t *textReader) nextBeforeTypeAnnotations() (bool, error) { case tokenOpenBrace: t.state = trsBeforeContainer t.valueType = StructType + t.value = StructType return true, nil case tokenOpenBracket: t.state = trsBeforeContainer t.valueType = ListType + t.value = ListType return true, nil case tokenOpenParen: t.state = trsBeforeContainer t.valueType = SexpType + t.value = SexpType return true, nil case tokenCloseBrace: @@ -602,7 +619,7 @@ func (t *textReader) StepIn() error { return t.err } if t.state != trsBeforeContainer { - return errors.New("invalid state") + return fmt.Errorf("stepin called in invalid state %v", t.state) } var ctx ctxType @@ -636,7 +653,7 @@ func (t *textReader) StepOut() error { ctx := t.ctx.peek() if ctx == ctxAtTopLevel { - return errors.New("invalid state") + return errors.New("stepout called at top level") } _, err := t.tok.finishValue() @@ -669,6 +686,7 @@ func (t *textReader) StepOut() error { t.state = t.stateAfterValue() t.valueType = NoType t.value = nil + t.eof = false return nil } diff --git a/textreader_test.go b/textreader_test.go index f175cc77..eea4929e 100644 --- a/textreader_test.go +++ b/textreader_test.go @@ -47,7 +47,7 @@ func TestIgnoreValues(t *testing.T) { } } -func TestSexps(t *testing.T) { +func TestReadSexps(t *testing.T) { test := func(str string, f func(r Reader, t *testing.T)) { t.Run(str, func(t *testing.T) { r := NewTextReaderString(str) @@ -161,6 +161,56 @@ func TestStructs(t *testing.T) { }) } +func TestMultipleStructs(t *testing.T) { + r := NewTextReaderString("{} {} {}") + + for i := 0; i < 3; i++ { + if !r.Next() { + t.Error("next returned false") + t.Fatal(r.Err()) + } + if r.Type() != StructType { + t.Fatalf("expected struct, got %v", r.Type()) + } + + if err := r.StepIn(); err != nil { + t.Fatal(err) + } + if r.Next() { + t.Fatal("next returned true") + } + if err := r.StepOut(); err != nil { + t.Fatal(err) + } + } + + if r.Next() { + t.Fatal("next returned true") + } +} + +func TestNullStructs(t *testing.T) { + r := NewTextReaderString("null.struct {}") + + if !r.Next() { + t.Fatal(r.Err()) + } + if !r.IsNull() { + t.Error("expected null, got not-null") + } + + if !r.Next() { + t.Fatal(r.Err()) + } + if r.IsNull() { + t.Error("expected not-null, got null") + } + + if r.Next() { + t.Fatal("next returned true") + } +} + func TestLists(t *testing.T) { test := func(str string, f func(r Reader, t *testing.T)) { t.Run(str, func(t *testing.T) { @@ -471,12 +521,12 @@ func TestInts(t *testing.T) { } testInt("0", 0) - testInt("12345", 12345) - testInt("-12345", -12345) - testInt("0b000101", 5) - testInt("-0b000101", -5) - testInt("0x01020e0F", 0x01020e0f) - testInt("-0x01020e0F", -0x01020e0f) + testInt("12_345", 12345) + testInt("-1_2_3_4_5", -12345) + testInt("0b00_0101", 5) + testInt("-0b00_0101", -5) + testInt("0x01_02_0e_0F", 0x01020e0f) + testInt("-0x0102_0e0F", -0x01020e0f) testInt64 := func(str string, eval int64) { test(str, func(r Reader) error { @@ -491,17 +541,17 @@ func TestInts(t *testing.T) { }) } - testInt64("0x123FFFFFFFF", 0x123FFFFFFFF) - testInt64("-0x123FFFFFFFF", -0x123FFFFFFFF) + testInt64("0x123_FFFF_FFFF", 0x123FFFFFFFF) + testInt64("-0x123_FFFF_FFFF", -0x123FFFFFFFF) - testBigInt := func(str string) { + testBigInt := func(str string, estr string) { test(str, func(r Reader) error { val, err := r.BigIntValue() if err != nil { return err } - eval, _ := (&big.Int{}).SetString(str, 0) + eval, _ := (&big.Int{}).SetString(estr, 0) if eval.Cmp(val) != 0 { return fmt.Errorf("expected %v, got %v", eval, val) } @@ -510,9 +560,9 @@ func TestInts(t *testing.T) { }) } - testBigInt("0xEFFFFFFFFFFFFFFF") - testBigInt("0xFFFFFFFFFFFFFFFF") - testBigInt("-0x1FFFFFFFFFFFFFFFF") + testBigInt("0xEFFF_FFFF_FFFF_FFFF", "0xEFFFFFFFFFFFFFFF") + testBigInt("0xFFFF_FFFF_FFFF_FFFF", "0xFFFFFFFFFFFFFFFF") + testBigInt("-0x1_FFFF_FFFF_FFFF_FFFF", "-0x1FFFFFFFFFFFFFFFF") } func TestStrings(t *testing.T) { diff --git a/textutils.go b/textutils.go index 90d8d905..63b77bd8 100644 --- a/textutils.go +++ b/textutils.go @@ -237,7 +237,16 @@ func writeRawChar(c byte, out io.Writer) error { } func parseFloat(str string) (float64, error) { - return strconv.ParseFloat(str, 64) + val, err := strconv.ParseFloat(str, 64) + if err != nil { + if ne, ok := err.(*strconv.NumError); ok { + if ne.Err == strconv.ErrRange { + // Ignore me, val will be +-inf which is fine. + return val, nil + } + } + } + return val, err } func parseDecimal(str string) (*Decimal, error) { diff --git a/tokenizer.go b/tokenizer.go index 5300616b..d5059980 100644 --- a/tokenizer.go +++ b/tokenizer.go @@ -304,6 +304,8 @@ func (t *tokenizer) ReadValue(tok tokenType) (string, error) { str, err = t.readSymbol() case tokenSymbolQuoted: str, err = t.readQuotedSymbol() + case tokenSymbolOperator, tokenDot: + str, err = t.readOperator() case tokenString: str, err = t.readString() case tokenLongString: @@ -487,6 +489,26 @@ func (t *tokenizer) readQuotedSymbol() (string, error) { } } +func (t *tokenizer) readOperator() (string, error) { + ret := strings.Builder{} + + c, err := t.peek() + if err != nil { + return "", err + } + + for isOperatorChar(c) { + ret.WriteByte(byte(c)) + t.read() + c, err = t.peek() + if err != nil { + return "", err + } + } + + return ret.String(), nil +} + // ReadString reads a quoted string. func (t *tokenizer) readString() (string, error) { ret := strings.Builder{} @@ -598,6 +620,10 @@ func (t *tokenizer) readEscapedChar(clob bool) (rune, error) { return '\r', nil case 'v': return '\v', nil + case '?': + return '?', nil + case '/': + return '/', nil case '\'': return '\'', nil case '"': @@ -615,7 +641,7 @@ func (t *tokenizer) readEscapedChar(clob bool) (rune, error) { return t.readHexEscapeSeq(2) } - return 0, fmt.Errorf("bad escape sequence '\\%q'", c) + return 0, fmt.Errorf("bad escape sequence '\\%c'", c) } func (t *tokenizer) readHexEscapeSeq(len int) (rune, error) { @@ -712,6 +738,9 @@ func (t *tokenizer) readRadixDigits(dok matcher, w io.ByteWriter) (int, error) { if err != nil { return 0, err } + if c == '_' { + continue + } if !dok(c) { return c, nil } From df048f479d4cf4ad1330f81aceb439861870965a Mon Sep 17 00:00:00 2001 From: David Murray Date: Tue, 16 Jul 2019 13:18:24 +1000 Subject: [PATCH 19/56] test refactoring, bugfixes --- textreader.go | 9 +- textreader_test.go | 761 ++++++++++++++++++--------------------------- tokenizer.go | 4 - 3 files changed, 312 insertions(+), 462 deletions(-) diff --git a/textreader.go b/textreader.go index 1c7f5a9b..33ead3a3 100644 --- a/textreader.go +++ b/textreader.go @@ -231,7 +231,14 @@ func (t *textReader) nextBeforeTypeAnnotations() (bool, error) { } return false, errors.New("unexpected EOF") - case tokenSymbol, tokenSymbolQuoted, tokenSymbolOperator, tokenDot: + case tokenSymbolOperator, tokenDot: + if t.ctx.peek() != ctxInSexp { + // Operators can only appear inside an sexp. + return false, fmt.Errorf("unexpected token '%v'", tok) + } + fallthrough + + case tokenSymbol, tokenSymbolQuoted: val, err := t.tok.ReadValue(tok) if err != nil { return false, err diff --git a/textreader_test.go b/textreader_test.go index eea4929e..c098804f 100644 --- a/textreader_test.go +++ b/textreader_test.go @@ -2,8 +2,6 @@ package ion import ( "bytes" - "errors" - "fmt" "math" "math/big" "testing" @@ -11,73 +9,26 @@ import ( ) func TestIgnoreValues(t *testing.T) { - r := NewTextReaderString("{skip: me, please: true}\n[skip, me, please]\nfoo") + r := NewTextReaderString("(skip ++ me / please) {skip: me, please: 0}\n[skip, me, please]\nfoo") - if !r.Next() { - t.Fatal(r.Err()) - } - if r.Type() != StructType { - t.Fatalf("expected StructType, got %v", r.Type()) - } + _next(t, r, SexpType) + _next(t, r, StructType) + _next(t, r, ListType) - if !r.Next() { - t.Fatal(r.Err()) - } - if r.Type() != ListType { - t.Fatalf("expected ListType, got %v", r.Type()) - } - - if !r.Next() { - t.Fatal(r.Err()) - } - if r.Type() != SymbolType { - t.Fatalf("expected SymbolType, got %v", r.Type()) - } - - val, err := r.StringValue() - if err != nil { - t.Fatal(err) - } - if val != "foo" { - t.Errorf("expected foo, got %v", val) - } - - if r.Next() { - t.Error("next returned true") - } + _symbol(t, r, "foo") + _eof(t, r) } func TestReadSexps(t *testing.T) { - test := func(str string, f func(r Reader, t *testing.T)) { + test := func(str string, f containerhandler) { t.Run(str, func(t *testing.T) { r := NewTextReaderString(str) - if !r.Next() { - t.Fatal(r.Err()) - } - if r.Type() != SexpType { - t.Errorf("expected type=SexpType, got %v", r.Type()) - } - - if err := r.StepIn(); err != nil { - t.Fatal(err) - } - - f(r, t) - - if err := r.StepOut(); err != nil { - t.Fatal(err) - } - - if r.Next() { - t.Errorf("next returned true") - } - if r.Err() != nil { - t.Fatal(r.Err()) - } + _sexp(t, r, f) + _eof(t, r) }) } - test("(\t)", func(r Reader, t *testing.T) { + test("(\t)", func(t *testing.T, r Reader) { if r.Next() { t.Errorf("next returned true") } @@ -86,78 +37,38 @@ func TestReadSexps(t *testing.T) { } }) - test("(foo)", func(r Reader, t *testing.T) { - symbol(t, r, "foo") + test("(foo)", func(t *testing.T, r Reader) { + _symbol(t, r, "foo") }) - test("(foo bar baz)", func(r Reader, t *testing.T) { - symbol(t, r, "foo") - symbol(t, r, "bar") - symbol(t, r, "baz") + test("(foo bar baz :: boop)", func(t *testing.T, r Reader) { + _symbol(t, r, "foo") + _symbol(t, r, "bar") + _symbolAF(t, r, "", []string{"baz"}, "boop") }) } func TestStructs(t *testing.T) { - test := func(str string, f func(r Reader, t *testing.T)) { + test := func(str string, f containerhandler) { t.Run(str, func(t *testing.T) { r := NewTextReaderString(str) - if !r.Next() { - t.Fatal(r.Err()) - } - if r.Type() != StructType { - t.Errorf("expected type=StructType, got %v", r.Type()) - } - - if err := r.StepIn(); err != nil { - t.Fatal(err) - } - - f(r, t) - - if err := r.StepOut(); err != nil { - t.Fatal(err) - } - - if r.Next() { - t.Errorf("next returned true") - } - if r.Err() != nil { - t.Fatal(r.Err()) - } + _struct(t, r, f) + _eof(t, r) }) } - test("{\r\n}", func(r Reader, t *testing.T) { - if r.Next() { - t.Errorf("next returned true") - } - if r.Err() != nil { - t.Fatal(r.Err()) - } + test("{\r\n}", func(t *testing.T, r Reader) { + _eof(t, r) }) - test("{foo: bar}", func(r Reader, t *testing.T) { - symbol(t, r, "bar") - if r.FieldName() != "foo" { - t.Errorf("expected foo, got %v", r.FieldName()) - } + test("{foo : bar :: baz}", func(t *testing.T, r Reader) { + _symbolAF(t, r, "foo", []string{"bar"}, "baz") }) - test("{foo: a, bar: b, baz: c}", func(r Reader, t *testing.T) { - symbol(t, r, "a") - if r.FieldName() != "foo" { - t.Errorf("expected foo, got %v", r.FieldName()) - } - - symbol(t, r, "b") - if r.FieldName() != "bar" { - t.Errorf("expected bar, got %v", r.FieldName()) - } - - symbol(t, r, "c") - if r.FieldName() != "baz" { - t.Errorf("expected baz, got %v", r.FieldName()) - } + test("{foo: a, bar: b, baz: c}", func(t *testing.T, r Reader) { + _symbolAF(t, r, "foo", nil, "a") + _symbolAF(t, r, "bar", nil, "b") + _symbolAF(t, r, "baz", nil, "c") }) } @@ -165,137 +76,73 @@ func TestMultipleStructs(t *testing.T) { r := NewTextReaderString("{} {} {}") for i := 0; i < 3; i++ { - if !r.Next() { - t.Error("next returned false") - t.Fatal(r.Err()) - } - if r.Type() != StructType { - t.Fatalf("expected struct, got %v", r.Type()) - } - - if err := r.StepIn(); err != nil { - t.Fatal(err) - } - if r.Next() { - t.Fatal("next returned true") - } - if err := r.StepOut(); err != nil { - t.Fatal(err) - } + _struct(t, r, func(t *testing.T, r Reader) { + _eof(t, r) + }) } - if r.Next() { - t.Fatal("next returned true") - } + _eof(t, r) } func TestNullStructs(t *testing.T) { - r := NewTextReaderString("null.struct {}") - - if !r.Next() { - t.Fatal(r.Err()) - } - if !r.IsNull() { - t.Error("expected null, got not-null") - } - - if !r.Next() { - t.Fatal(r.Err()) - } - if r.IsNull() { - t.Error("expected not-null, got null") - } + r := NewTextReaderString("null.struct 'null'::{foo:bar}") - if r.Next() { - t.Fatal("next returned true") - } + _null(t, r, StructType) + _nextAF(t, r, StructType, "", []string{"null"}) + _eof(t, r) } func TestLists(t *testing.T) { - test := func(str string, f func(r Reader, t *testing.T)) { + test := func(str string, f containerhandler) { t.Run(str, func(t *testing.T) { r := NewTextReaderString(str) - if !r.Next() { - t.Fatal(r.Err()) - } - if r.Type() != ListType { - t.Errorf("expected type=ListType, got %v", r.Type()) - } - - if err := r.StepIn(); err != nil { - t.Fatal(err) - } - - f(r, t) - - if err := r.StepOut(); err != nil { - t.Fatal(err) - } - - if r.Next() { - t.Errorf("next returned true") - } - if r.Err() != nil { - t.Fatal(r.Err()) - } + _list(t, r, f) + _eof(t, r) }) } - test("[ ]", func(r Reader, t *testing.T) { - if r.Next() { - t.Fatal("next returned true") - } + test("[ ]", func(t *testing.T, r Reader) { + _eof(t, r) }) - test("[foo]", func(r Reader, t *testing.T) { - symbol(t, r, "foo") - if r.Next() { - t.Fatal("next returned true") - } + test("[foo]", func(t *testing.T, r Reader) { + _symbol(t, r, "foo") + _eof(t, r) }) - test("[foo, bar, baz]", func(r Reader, t *testing.T) { - symbol(t, r, "foo") - symbol(t, r, "bar") - symbol(t, r, "baz") - if r.Next() { - t.Fatal("next returned true") - } + test("[foo, bar, baz::boop]", func(t *testing.T, r Reader) { + _symbol(t, r, "foo") + _symbol(t, r, "bar") + _symbolAF(t, r, "", []string{"baz"}, "boop") + _eof(t, r) }) } -func symbol(t *testing.T, r Reader, eval string) { - next(t, r, SymbolType) - - val, err := r.StringValue() - if err != nil { - t.Fatal(err) - } - if val != eval { - t.Errorf("expected %v, got %v", eval, val) +func TestReadNestedLists(t *testing.T) { + empty := func(t *testing.T, r Reader) { + _eof(t, r) } -} -func next(t *testing.T, r Reader, et Type) { - if !r.Next() { - t.Fatal(r.Err()) - } - if r.Type() != et { - t.Fatalf("expected %v, got %v", et, r.Type()) - } + r := NewTextReaderString("[[], [[]]]") + + _list(t, r, func(t *testing.T, r Reader) { + _list(t, r, empty) + + _list(t, r, func(t *testing.T, r Reader) { + _list(t, r, empty) + }) + + _eof(t, r) + }) + + _eof(t, r) } func TestClobs(t *testing.T) { test := func(str string, eval []byte) { t.Run(str, func(t *testing.T) { r := NewTextReaderString(str) - if !r.Next() { - t.Error("next returned false") - t.Fatal(r.Err()) - } - if r.Type() != ClobType { - t.Errorf("expected type=ClobType, got %v", r.Type()) - } + _next(t, r, ClobType) val, err := r.ByteValue() if err != nil { @@ -305,12 +152,7 @@ func TestClobs(t *testing.T) { t.Errorf("expected %v, got %v", eval, val) } - if r.Next() { - t.Error("next returned true") - } - if r.Err() != nil { - t.Error(r.Err()) - } + _eof(t, r) }) } @@ -324,13 +166,7 @@ func TestBlobs(t *testing.T) { test := func(str string, eval []byte) { t.Run(str, func(t *testing.T) { r := NewTextReaderString(str) - if !r.Next() { - t.Error("next returned false") - t.Fatal(r.Err()) - } - if r.Type() != BlobType { - t.Errorf("expected type=BlobType, got %v", r.Type()) - } + _next(t, r, BlobType) val, err := r.ByteValue() if err != nil { @@ -340,12 +176,7 @@ func TestBlobs(t *testing.T) { t.Errorf("expected %v, got %v", eval, val) } - if r.Next() { - t.Error("next returned true") - } - if r.Err() != nil { - t.Error(r.Err()) - } + _eof(t, r) }) } @@ -355,16 +186,10 @@ func TestBlobs(t *testing.T) { } func TestTimestamps(t *testing.T) { - test := func(str string, eval time.Time) { + testA := func(str string, etas []string, eval time.Time) { t.Run(str, func(t *testing.T) { r := NewTextReaderString(str) - if !r.Next() { - t.Error("next returned false") - t.Fatal(r.Err()) - } - if r.Type() != TimestampType { - t.Errorf("expected type=TimestampType, got %v", r.Type()) - } + _nextAF(t, r, TimestampType, "", etas) val, err := r.TimeValue() if err != nil { @@ -374,15 +199,14 @@ func TestTimestamps(t *testing.T) { t.Errorf("expected %v, got %v", eval, val) } - if r.Next() { - t.Error("next returned true") - } - if r.Err() != nil { - t.Error(r.Err()) - } + _eof(t, r) }) } + test := func(str string, eval time.Time) { + testA(str, nil, eval) + } + et := time.Date(2001, time.January, 1, 0, 0, 0, 0, time.UTC) test("2001T", et) test("2001-01T", et) @@ -392,21 +216,18 @@ func TestTimestamps(t *testing.T) { test("2001-01-01T00:00:00Z", et) test("2001-01-01T00:00:00.000Z", et) test("2001-01-01T00:00:00.000+00:00", et) + + testA("foo::'bar'::2001-01-01T00:00:00.000Z", []string{"foo", "bar"}, et) } func TestDoubles(t *testing.T) { - test := func(str string, eval string) { + testA := func(str string, etas []string, eval string) { t.Run(str, func(t *testing.T) { - r := NewTextReaderString(str) - if !r.Next() { - t.Error("next returned false") - } - if r.Type() != DecimalType { - t.Errorf("expected type=DecimalType, got %v", r.Type()) - } - ee := MustParseDecimal(eval) + r := NewTextReaderString(str) + _nextAF(t, r, DecimalType, "", etas) + val, err := r.DecimalValue() if err != nil { t.Fatal(err) @@ -415,108 +236,73 @@ func TestDoubles(t *testing.T) { t.Errorf("expected %v, got %v", ee, val) } - if r.Next() { - t.Error("next returned true") - } - if r.Err() != nil { - t.Error(r.Err()) - } + _eof(t, r) }) } + test := func(str string, eval string) { + testA(str, nil, eval) + } + test("123.", "123") test("123.0", "123") test("123.456", "123.456") test("123d2", "12300") test("123d+2", "12300") test("123d-2", "1.23") + + testA(" foo :: 'bar' :: 123. ", []string{"foo", "bar"}, "123") } func TestFloats(t *testing.T) { - test := func(str string, eval float64) { + testA := func(str string, etas []string, eval float64) { t.Run(str, func(t *testing.T) { r := NewTextReaderString(str) - if !r.Next() { - t.Error("next returned false") - } - if r.Type() != FloatType { - t.Errorf("expected type=FloatType, got %v", r.Type()) - } - - val, err := r.FloatValue() - if err != nil { - t.Error(err) - } - if val != eval { - t.Errorf("expected %v, got %v", eval, val) - } - - if r.Next() { - t.Error("next returned true") - } - if r.Err() != nil { - t.Error(r.Err()) - } + _floatAF(t, r, "", etas, eval) + _eof(t, r) }) } + test := func(str string, eval float64) { + testA(str, nil, eval) + } + test("1e100\n", 1e100) test("1.2e+0", 1.2) test("-123.456e-78", -123.456e-78) test("+inf", math.Inf(1)) test("-inf", math.Inf(-1)) + + testA("foo::'bar'::1e100", []string{"foo", "bar"}, 1e100) } func TestInts(t *testing.T) { - test := func(str string, m func(Reader) error) { + test := func(str string, f func(*testing.T, Reader)) { t.Run(str, func(t *testing.T) { r := NewTextReaderString(str) - if !r.Next() { - t.Error("next returned false") - } - if r.Type() != IntType { - t.Errorf("expected type=IntType, got %v", r.Type()) - } + _next(t, r, IntType) - if err := m(r); err != nil { - t.Error(err) - } + f(t, r) - if r.Next() { - t.Error("next returned true") - } - if r.Err() != nil { - t.Error(r.Err()) - } + _eof(t, r) }) } - test("null.int", func(r Reader) error { + test("null.int", func(t *testing.T, r Reader) { if !r.IsNull() { - return errors.New("expected isnull=true, got false") - } - - val, err := r.IntValue() - if err != nil { - return err + t.Fatal("expected isnull=true, got false") } - if val != 0 { - return fmt.Errorf("expected 0, got %v", val) - } - - return nil }) testInt := func(str string, eval int) { - test(str, func(r Reader) error { + test(str, func(t *testing.T, r Reader) { val, err := r.IntValue() if err != nil { - return err + t.Fatal(err) } if val != eval { - return fmt.Errorf("expected %v, got %v", eval, val) + t.Errorf("expected %v, got %v", eval, val) } - return nil }) } @@ -529,15 +315,14 @@ func TestInts(t *testing.T) { testInt("-0x0102_0e0F", -0x01020e0f) testInt64 := func(str string, eval int64) { - test(str, func(r Reader) error { + test(str, func(t *testing.T, r Reader) { val, err := r.Int64Value() if err != nil { - return err + t.Fatal(err) } if val != eval { - return fmt.Errorf("expected %v, got %v", eval, val) + t.Errorf("expected %v, got %v", eval, val) } - return nil }) } @@ -545,18 +330,16 @@ func TestInts(t *testing.T) { testInt64("-0x123_FFFF_FFFF", -0x123FFFFFFFF) testBigInt := func(str string, estr string) { - test(str, func(r Reader) error { + test(str, func(t *testing.T, r Reader) { val, err := r.BigIntValue() if err != nil { - return err + t.Fatal(err) } eval, _ := (&big.Int{}).SetString(estr, 0) if eval.Cmp(val) != 0 { - return fmt.Errorf("expected %v, got %v", eval, val) + t.Errorf("expected %v, got %v", eval, val) } - - return nil }) } @@ -568,176 +351,240 @@ func TestInts(t *testing.T) { func TestStrings(t *testing.T) { r := NewTextReaderString(`foo::"bar" "baz" 'a'::'b'::'''beep''' '''boop''' null.string`) - test := func(etas []string, eval string) { - if !r.Next() { - t.Fatal("next returned false") - } + _stringAF(t, r, "", []string{"foo"}, "bar") + _string(t, r, "baz") + _stringAF(t, r, "", []string{"a", "b"}, "beepboop") + _null(t, r, StringType) - if r.Type() != StringType { - t.Fatalf("expected type=string, got type=%v", r.Type()) - } + _eof(t, r) +} - if !strequals(r.TypeAnnotations(), etas) { - t.Errorf("expected tas=%v, got tas=%v", etas, r.TypeAnnotations()) - } +func TestSymbols(t *testing.T) { + r := NewTextReaderString("'null'::foo bar a::b::'baz' null.symbol") - val, err := r.StringValue() - if err != nil { - t.Fatal(err) - } + _symbolAF(t, r, "", []string{"null"}, "foo") + _symbol(t, r, "bar") + _symbolAF(t, r, "", []string{"a", "b"}, "baz") + _null(t, r, SymbolType) - if val != eval { - t.Errorf("expected val=%v, got val=%v", eval, val) - } - } + _eof(t, r) +} + +func TestSpecialSymbols(t *testing.T) { + r := NewTextReaderString("null\nnull.struct\ntrue\nfalse\nnan") + + _null(t, r, NullType) + _null(t, r, StructType) + + _bool(t, r, true) + _bool(t, r, false) + _float(t, r, math.NaN()) + _eof(t, r) +} + +func TestOperators(t *testing.T) { + r := NewTextReaderString("(a*(b+c))") + + _sexp(t, r, func(t *testing.T, r Reader) { + _symbol(t, r, "a") + _symbol(t, r, "*") + _sexp(t, r, func(t *testing.T, r Reader) { + _symbol(t, r, "b") + _symbol(t, r, "+") + _symbol(t, r, "c") + _eof(t, r) + }) + _eof(t, r) + }) +} + +func TestTopLevelOperators(t *testing.T) { + r := NewTextReaderString("a + b") - test([]string{"foo"}, "bar") - test(nil, "baz") - test([]string{"a", "b"}, "beepboop") - test(nil, "") + _symbol(t, r, "a") if r.Next() { - t.Errorf("next unexpectedly returned true") + t.Errorf("next returned true") } - if r.Err() != nil { - t.Error(r.Err()) + if r.Err() == nil { + t.Error("no error") } } -func TestSymbols(t *testing.T) { - r := NewTextReaderString("'null'::foo bar a::b::'baz' null.symbol") +type containerhandler func(t *testing.T, r Reader) - test := func(etas []string, eval string) { - if !r.Next() { - t.Fatal("next returned false") - } +func _sexp(t *testing.T, r Reader, f containerhandler) { + _sexpAF(t, r, "", nil, f) +} - if r.Type() != SymbolType { - t.Fatalf("expected type=symbol, got type=%v", r.Type()) - } +func _sexpAF(t *testing.T, r Reader, efn string, etas []string, f containerhandler) { + _containerAF(t, r, SexpType, efn, etas, f) +} - if !strequals(r.TypeAnnotations(), etas) { - t.Errorf("expected tas=%v, got tas=%v", etas, r.TypeAnnotations()) - } +func _struct(t *testing.T, r Reader, f containerhandler) { + _structAF(t, r, "", nil, f) +} - val, err := r.StringValue() - if err != nil { - t.Fatal(err) - } +func _structAF(t *testing.T, r Reader, efn string, etas []string, f containerhandler) { + _containerAF(t, r, StructType, efn, etas, f) +} + +func _list(t *testing.T, r Reader, f containerhandler) { + _listAF(t, r, "", nil, f) +} + +func _listAF(t *testing.T, r Reader, efn string, etas []string, f containerhandler) { + _containerAF(t, r, ListType, efn, etas, f) +} + +func _containerAF(t *testing.T, r Reader, et Type, efn string, etas []string, f containerhandler) { + _nextAF(t, r, et, efn, etas) + if r.IsNull() { + t.Fatalf("expected %v, got null.%v", et, et) + } + + if err := r.StepIn(); err != nil { + t.Fatal(err) + } - if val != eval { - t.Errorf("expected val=%v, got val=%v", eval, val) + f(t, r) + + if err := r.StepOut(); err != nil { + t.Fatal(err) + } +} + +func _float(t *testing.T, r Reader, eval float64) { + _floatAF(t, r, "", nil, eval) +} + +func _floatAF(t *testing.T, r Reader, efn string, etas []string, eval float64) { + _nextAF(t, r, FloatType, efn, etas) + if r.IsNull() { + t.Fatalf("expected %v, got null.float", eval) + } + + val, err := r.FloatValue() + if err != nil { + t.Fatal(err) + } + + if math.IsNaN(eval) { + if !math.IsNaN(val) { + t.Errorf("expected %v, got %v", eval, val) } + } else if eval != val { + t.Errorf("expected %v, got %v", eval, val) } +} - test([]string{"null"}, "foo") - test(nil, "bar") - test([]string{"a", "b"}, "baz") - test(nil, "") +func _string(t *testing.T, r Reader, eval string) { + _stringAF(t, r, "", nil, eval) +} - if r.Next() { - t.Errorf("next unexpectedly returned true") +func _stringAF(t *testing.T, r Reader, efn string, etas []string, eval string) { + _nextAF(t, r, StringType, efn, etas) + if r.IsNull() { + t.Fatalf("expected %v, got null.string", eval) } - if r.Err() != nil { - t.Error(r.Err()) + + val, err := r.StringValue() + if err != nil { + t.Fatal(err) + } + if val != eval { + t.Errorf("expected %v, got %v", eval, val) } } -func strequals(a, b []string) bool { - if len(a) != len(b) { - return false +func _symbol(t *testing.T, r Reader, eval string) { + _symbolAF(t, r, "", nil, eval) +} + +func _symbolAF(t *testing.T, r Reader, efn string, etas []string, eval string) { + _nextAF(t, r, SymbolType, efn, etas) + if r.IsNull() { + t.Fatalf("expected %v, got null.symbol", eval) } - for i := 0; i < len(a); i++ { - if a[i] != b[i] { - return false - } + val, err := r.StringValue() + if err != nil { + t.Fatal(err) } + if val != eval { + t.Errorf("expected %v, got %v", eval, val) + } +} - return true +func _bool(t *testing.T, r Reader, eval bool) { + _boolAF(t, r, "", nil, eval) } -func TestSpecialSymbols(t *testing.T) { - r := NewTextReaderString("null\nnull.struct\ntrue\nfalse\nnan") +func _boolAF(t *testing.T, r Reader, efn string, etas []string, eval bool) { + _nextAF(t, r, BoolType, efn, etas) + if r.IsNull() { + t.Fatalf("expected %v, got null.bool", eval) + } - // null - { - if !r.Next() { - t.Fatal("next returned false") - } - if r.Type() != NullType { - t.Errorf("expected type=NullType, got %v", r.Type()) - } - if !r.IsNull() { - t.Error("expected isNull=true, got false") - } + val, err := r.BoolValue() + if err != nil { + t.Fatal(err) + } + if val != eval { + t.Errorf("expected %v, got %v", eval, val) } +} - // null.struct - { - if !r.Next() { - t.Fatal("next returned false") - } - if r.Type() != StructType { - t.Errorf("expected type=StructType, got %v", r.Type()) - } - if !r.IsNull() { - t.Error("expected isNull=true, got false") - } +func _null(t *testing.T, r Reader, et Type) { + _nullAF(t, r, et, "", nil) +} + +func _nullAF(t *testing.T, r Reader, et Type, efn string, etas []string) { + _nextAF(t, r, et, efn, etas) + if !r.IsNull() { + t.Error("isnull returned false") } +} - // true - { - if !r.Next() { - t.Fatal("next returned false") - } - if r.Type() != BoolType { - t.Errorf("expected type=BoolType, got %v", r.Type()) - } - val, err := r.BoolValue() - if err != nil { - t.Fatal(err) - } - if !val { - t.Error("expected value=true, got false") - } +func _next(t *testing.T, r Reader, et Type) { + _nextAF(t, r, et, "", nil) +} + +func _nextAF(t *testing.T, r Reader, et Type, efn string, etas []string) { + if !r.Next() { + t.Fatal(r.Err()) + } + if r.Type() != et { + t.Fatalf("expected %v, got %v", et, r.Type()) } - // false - { - if !r.Next() { - t.Fatal("next returned false") - } - if r.Type() != BoolType { - t.Errorf("expected type=BoolType, got %v", r.Type()) - } - val, err := r.BoolValue() - if err != nil { - t.Fatal(err) - } - if val { - t.Error("expected value=false, got true") - } + if efn != r.FieldName() { + t.Errorf("expected fieldname=%v, got %v", efn, r.FieldName()) } + if !_strequals(etas, r.TypeAnnotations()) { + t.Errorf("expected type annotations=%v, got %v", etas, r.TypeAnnotations()) + } +} - // nan - { - if !r.Next() { - t.Fatal("next returned false") - } - if r.Type() != FloatType { - t.Errorf("expected type=FloatType, got %v", r.Type()) - } - val, err := r.FloatValue() - if err != nil { - t.Fatal(err) - } - if !math.IsNaN(val) { - t.Errorf("expected value=NaN, got %v", val) +func _strequals(a, b []string) bool { + if len(a) != len(b) { + return false + } + + for i := 0; i < len(a); i++ { + if a[i] != b[i] { + return false } } + return true +} + +func _eof(t *testing.T, r Reader) { if r.Next() { - t.Error("next returned true") + t.Fatal("next returned true") + } + if r.Err() != nil { + t.Fatal(r.Err()) } } diff --git a/tokenizer.go b/tokenizer.go index d5059980..7d5310e7 100644 --- a/tokenizer.go +++ b/tokenizer.go @@ -155,10 +155,6 @@ func (t *tokenizer) Next() error { case c == -1: return t.finish(tokenEOF, true) - case c == '/': - t.unread(c) - return t.finish(tokenSymbolOperator, true) - case c == ':': c2, err := t.peek() if err != nil { From 1991b6565b674e99d7309c085607cc7588ec6767 Mon Sep 17 00:00:00 2001 From: David Murray Date: Mon, 22 Jul 2019 12:24:33 +1000 Subject: [PATCH 20/56] Truncate text timestamps with more than nanosecond precision --- reader_test.go | 4 ---- textreader_test.go | 3 +++ textutils.go | 13 +++++++++++++ 3 files changed, 16 insertions(+), 4 deletions(-) diff --git a/reader_test.go b/reader_test.go index 3741d281..6aad0833 100644 --- a/reader_test.go +++ b/reader_test.go @@ -16,10 +16,6 @@ var blacklist = map[string]bool{ "ion-tests/iontestdata/good/utf32.ion": true, "ion-tests/iontestdata/good/whitespace.ion": true, "ion-tests/iontestdata/good/item1.10n": true, - - // timestamps too long for time.Parse(); FIXME - "ion-tests/iontestdata/good/equivs/timestampsLargeFractionalPrecision.ion": true, - "ion-tests/iontestdata/good/timestamp/equivTimeline/timestamps.ion": true, } func print(level int, obj interface{}) { diff --git a/textreader_test.go b/textreader_test.go index c098804f..20ff4592 100644 --- a/textreader_test.go +++ b/textreader_test.go @@ -216,6 +216,9 @@ func TestTimestamps(t *testing.T) { test("2001-01-01T00:00:00Z", et) test("2001-01-01T00:00:00.000Z", et) test("2001-01-01T00:00:00.000+00:00", et) + test("2001-01-01T00:00:00.000000Z", et) + test("2001-01-01T00:00:00.000000000Z", et) + test("2001-01-01T00:00:00.000000000999Z", et) // We truncate, at least for now. testA("foo::'bar'::2001-01-01T00:00:00.000Z", []string{"foo", "bar"}, et) } diff --git a/textutils.go b/textutils.go index 63b77bd8..edb4682d 100644 --- a/textutils.go +++ b/textutils.go @@ -355,6 +355,19 @@ func parseTimestamp(val string) (time.Time, error) { return time.Parse("2006-01-02T15:04Z07:00", val) } + if len(val) > 19 && val[19] == '.' { + i := 20 + for i < len(val) && isDigit(int(val[i])) { + i++ + } + + if i >= 29 { + // Too much precision for a go Time. + // TODO: We should probably round instead of truncating? Ah well. + return time.Parse(time.RFC3339Nano, val[:29]+val[i:]) + } + } + return time.Parse(time.RFC3339Nano, val) } From 753267261d28c642f80c2ab8a7972a7432502f71 Mon Sep 17 00:00:00 2001 From: David Murray Date: Mon, 22 Jul 2019 21:02:27 +1000 Subject: [PATCH 21/56] marshal --- marshal.go | 269 ++++++++++++++++++++++++++++++++++++++++++++++++ marshal_test.go | 60 +++++++++++ textwriter.go | 27 ++++- 3 files changed, 352 insertions(+), 4 deletions(-) create mode 100644 marshal.go create mode 100644 marshal_test.go diff --git a/marshal.go b/marshal.go new file mode 100644 index 00000000..1769197c --- /dev/null +++ b/marshal.go @@ -0,0 +1,269 @@ +package ion + +import ( + "bytes" + "fmt" + "math/big" + "reflect" + "sort" + "strings" +) + +type marshaller struct { + buf *bytes.Buffer + w Writer +} + +func newTextMarshaller() *marshaller { + buf := &bytes.Buffer{} + return &marshaller{ + buf: buf, + w: NewTextWriterOpts(buf, OptQuietFinish), + } +} + +// MarshalText marshals values to text ion. +func MarshalText(v interface{}) ([]byte, error) { + m := newTextMarshaller() + if err := m.marshal(v); err != nil { + return nil, err + } + return m.buf.Bytes(), nil +} + +func (m *marshaller) marshal(v interface{}) error { + if err := m.marshalValue(reflect.ValueOf(v)); err != nil { + return err + } + return m.w.Finish() +} + +func (m *marshaller) marshalValue(r reflect.Value) error { + if !r.IsValid() { + m.w.WriteNull() + return nil + } + + t := r.Type() + switch t.Kind() { + case reflect.Bool: + m.w.WriteBool(r.Bool()) + return m.w.Err() + + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + m.w.WriteInt(r.Int()) + return m.w.Err() + + case reflect.Uint8, reflect.Uint16, reflect.Uint32: + m.w.WriteInt(int64(r.Uint())) + return m.w.Err() + + case reflect.Uint, reflect.Uint64, reflect.Uintptr: + i := big.Int{} + i.SetUint64(r.Uint()) + m.w.WriteBigInt(&i) + return m.w.Err() + + case reflect.Float32, reflect.Float64: + m.w.WriteFloat(r.Float()) + return m.w.Err() + + // TODO: Decimal + // TODO: Time + + case reflect.String: + m.w.WriteString(r.String()) + return m.w.Err() + + case reflect.Interface, reflect.Ptr: + return m.marshalInterfaceOrPtr(r) + + case reflect.Struct: + return m.marshalStruct(r) + + case reflect.Map: + return m.marshalMap(r) + + case reflect.Slice: + return m.marshalSlice(r) + + case reflect.Array: + return m.marshalArray(r) + + default: + return fmt.Errorf("unsupported type %v", r.Type()) + } +} + +func (m *marshaller) marshalInterfaceOrPtr(r reflect.Value) error { + if r.IsNil() { + m.w.WriteNull() + return m.w.Err() + } + return m.marshalValue(r.Elem()) +} + +func (m *marshaller) marshalMap(r reflect.Value) error { + if r.IsNil() { + m.w.WriteNull() + return m.w.Err() + } + + m.w.BeginStruct() + + for _, key := range getKeys(r) { + m.w.FieldName(key.s) + value := r.MapIndex(key.v) + if err := m.marshalValue(value); err != nil { + return err + } + } + + m.w.EndStruct() + return m.w.Err() +} + +type mapkey struct { + v reflect.Value + s string +} + +func getKeys(r reflect.Value) []mapkey { + keys := r.MapKeys() + res := make([]mapkey, len(keys)) + + for i, key := range keys { + // TODO: Handle other kinds of keys. + if key.Kind() != reflect.String { + panic("unexpected map key type") + } + res[i] = mapkey{ + v: key, + s: key.String(), + } + } + + sort.Slice(res, func(i, j int) bool { return res[i].s < res[j].s }) + + return res +} + +func (m *marshaller) marshalSlice(r reflect.Value) error { + if r.Type().Elem().Kind() == reflect.Uint8 { + return m.marshalBlob(r) + } + + if r.IsNil() { + m.w.WriteNull() + return m.w.Err() + } + + return m.marshalArray(r) +} + +func (m *marshaller) marshalBlob(r reflect.Value) error { + if r.IsNil() { + m.w.WriteNull() + } else { + m.w.WriteBlob(r.Bytes()) + } + return m.w.Err() +} + +func (m *marshaller) marshalArray(r reflect.Value) error { + m.w.BeginList() + + for i := 0; i < r.Len(); i++ { + if err := m.marshalValue(r.Index(i)); err != nil { + return err + } + } + + m.w.EndList() + return m.w.Err() +} + +func (m *marshaller) marshalStruct(r reflect.Value) error { + m.w.BeginStruct() + + fields := getFields(r.Type()) + for i := range fields { + f := &fields[i] + m.w.FieldName(f.name) + if err := m.marshalValue(r.Field(f.index)); err != nil { + return err + } + } + + m.w.EndStruct() + return m.w.Err() +} + +type field struct { + name string + typ reflect.Type + index int +} + +func getFields(t reflect.Type) []field { + fields := []field{} + + // current := []reflect.Type{} + // next := []reflect.Type{t} + // visited := map[reflect.Type]bool{} + + // for len(next) > 0 { + // current, next = next, current[:0] + // for _, c := range current { + // if visited[c] { + // continue + // } + // visited[c] = true + + c := t + + for i := 0; i < c.NumField(); i++ { + f := c.Field(i) + + tag := f.Tag.Get("json") + if tag == "-" { + continue + } + name := parseTag(tag) + + fType := f.Type + if fType.Name() == "" && fType.Kind() == reflect.Ptr { + fType = fType.Elem() + } + + if name == "" && f.Anonymous && fType.Kind() == reflect.Struct { + // next = append(next, fType) + continue + } + + if name == "" { + name = f.Name + } + + fields = append(fields, field{ + name: name, + typ: fType, + index: i, + }) + } + + // } + // } + + sort.Slice(fields, func(i, j int) bool { return fields[i].index < fields[j].index }) + + return fields +} + +func parseTag(tag string) string { + if idx := strings.Index(tag, ","); idx != -1 { + // Ignore additional JSON options, at least for now. + return tag[:idx] + } + return tag +} diff --git a/marshal_test.go b/marshal_test.go new file mode 100644 index 00000000..f7d7e756 --- /dev/null +++ b/marshal_test.go @@ -0,0 +1,60 @@ +package ion + +import ( + "math" + "testing" +) + +func TestMarshalText(t *testing.T) { + test := func(v interface{}, eval string) { + t.Run(eval, func(t *testing.T) { + val, err := MarshalText(v) + if err != nil { + t.Fatal(err) + } + if string(val) != eval { + t.Errorf("expected '%v', got '%v'", eval, string(val)) + } + }) + } + + test(nil, "null") + test(true, "true") + test(false, "false") + + test(byte(42), "42") + test(-42, "-42") + test(uint64(math.MaxUint64), "18446744073709551615") + test(math.MinInt64, "-9223372036854775808") + + test(42.0, "4.2e+1") + test(math.Inf(1), "+inf") + test(math.Inf(-1), "-inf") + test(math.NaN(), "nan") + + test("hello\tworld", "\"hello\\tworld\"") + + test(struct{ A, B int }{42, 0}, "{A:42,B:0}") + test(struct { + A int `json:"val,ignoreme"` + B int `json:"-"` + }{42, 0}, "{val:42}") + + test(struct{ v interface{} }{}, "{v:null}") + test(struct{ v interface{} }{"42"}, "{v:\"42\"}") + + fourtytwo := 42 + + test(struct{ v *int }{}, "{v:null}") + test(struct{ v *int }{&fourtytwo}, "{v:42}") + + test(map[string]int{"b": 2, "a": 1}, "{a:1,b:2}") + + test(struct{ v []int }{}, "{v:null}") + test(struct{ v []int }{[]int{4, 2}}, "{v:[4,2]}") + + test(struct{ v []byte }{}, "{v:null}") + test(struct{ v []byte }{[]byte{4, 2}}, "{v:{{BAI=}}}") + + test(struct{ v [2]byte }{[2]byte{4, 2}}, "{v:[4,2]}") +} diff --git a/textwriter.go b/textwriter.go index ea1173e5..26501ad7 100644 --- a/textwriter.go +++ b/textwriter.go @@ -11,18 +11,35 @@ import ( "time" ) +// TextWriterOpts defines a set of bit flag options for text writers. +type TextWriterOpts uint8 + +const ( + // OptQuietFinish disables emiting a newline in Finish(). Convenient if you know + // you're only emiting one datagram; dangerous if there's a chance you're going to + // emit another datagram using the same Writer. + OptQuietFinish TextWriterOpts = 1 +) + // textWriter is a writer that writes human-readable text type textWriter struct { writer needsSeparator bool + opts TextWriterOpts } // NewTextWriter returns a new text writer. func NewTextWriter(out io.Writer) Writer { + return NewTextWriterOpts(out, 0) +} + +// NewTextWriterOpts returns a new text writer with the given options. +func NewTextWriterOpts(out io.Writer, opts TextWriterOpts) Writer { return &textWriter{ writer: writer{ out: out, }, + opts: opts, } } @@ -204,7 +221,7 @@ func (w *textWriter) WriteNullWithType(t Type) { } w.err = w.writeValue(func() string { switch t { - case NullType: + case NoType, NullType: return "null" case BoolType: return "null.bool" @@ -231,7 +248,7 @@ func (w *textWriter) WriteNullWithType(t Type) { case SexpType: return "null.sexp" default: - panic("invalid type") + panic(fmt.Sprintf("invalid type: %v", t)) } }) } @@ -402,8 +419,10 @@ func (w *textWriter) Finish() error { return w.err } - if w.err = writeRawChar('\n', w.out); w.err != nil { - return w.err + if w.opts&OptQuietFinish == 0 { + if w.err = writeRawChar('\n', w.out); w.err != nil { + return w.err + } } w.fieldName = "" From 84003078c5fd5183462fb0e2c16061f0a1f494c9 Mon Sep 17 00:00:00 2001 From: David Murray Date: Tue, 23 Jul 2019 18:51:23 +1000 Subject: [PATCH 22/56] make marshaller public --- marshal.go | 95 +++++++++++++++++++++++++++++++++++++----------------- 1 file changed, 66 insertions(+), 29 deletions(-) diff --git a/marshal.go b/marshal.go index 1769197c..b11b09e0 100644 --- a/marshal.go +++ b/marshal.go @@ -3,42 +3,69 @@ package ion import ( "bytes" "fmt" + "io" "math/big" "reflect" "sort" "strings" ) -type marshaller struct { - buf *bytes.Buffer - w Writer -} +type marshallerOpts uint8 -func newTextMarshaller() *marshaller { - buf := &bytes.Buffer{} - return &marshaller{ - buf: buf, - w: NewTextWriterOpts(buf, OptQuietFinish), - } -} +const ( + optSortStructs marshallerOpts = 1 +) // MarshalText marshals values to text ion. func MarshalText(v interface{}) ([]byte, error) { - m := newTextMarshaller() - if err := m.marshal(v); err != nil { + buf := bytes.Buffer{} + m := Marshaller{ + w: NewTextWriterOpts(&buf, OptQuietFinish), + opts: optSortStructs, + } + + if err := m.Marshal(v); err != nil { return nil, err } - return m.buf.Bytes(), nil + if err := m.Finish(); err != nil { + return nil, err + } + + return buf.Bytes(), nil +} + +// A Marshaller marshals golang values to Ion. +type Marshaller struct { + w Writer + opts marshallerOpts +} + +// NewMarshaller creates a new marshaller that marshals to the given writer. +func NewMarshaller(w Writer) *Marshaller { + return &Marshaller{ + w: w, + } } -func (m *marshaller) marshal(v interface{}) error { - if err := m.marshalValue(reflect.ValueOf(v)); err != nil { - return err +// NewTextMarshaller creates a new marshaller that marshals text Ion to the given writer. +func NewTextMarshaller(w io.Writer) *Marshaller { + return &Marshaller{ + w: NewTextWriter(w), + opts: optSortStructs, } +} + +// Marshal marshals the given value to Ion, writing it to the underlying writer. +func (m *Marshaller) Marshal(v interface{}) error { + return m.marshalValue(reflect.ValueOf(v)) +} + +// Finish finishes writing the current Ion datagram. +func (m *Marshaller) Finish() error { return m.w.Finish() } -func (m *marshaller) marshalValue(r reflect.Value) error { +func (m *Marshaller) marshalValue(r reflect.Value) error { if !r.IsValid() { m.w.WriteNull() return nil @@ -95,7 +122,7 @@ func (m *marshaller) marshalValue(r reflect.Value) error { } } -func (m *marshaller) marshalInterfaceOrPtr(r reflect.Value) error { +func (m *Marshaller) marshalInterfaceOrPtr(r reflect.Value) error { if r.IsNil() { m.w.WriteNull() return m.w.Err() @@ -103,7 +130,7 @@ func (m *marshaller) marshalInterfaceOrPtr(r reflect.Value) error { return m.marshalValue(r.Elem()) } -func (m *marshaller) marshalMap(r reflect.Value) error { +func (m *Marshaller) marshalMap(r reflect.Value) error { if r.IsNil() { m.w.WriteNull() return m.w.Err() @@ -111,7 +138,15 @@ func (m *marshaller) marshalMap(r reflect.Value) error { m.w.BeginStruct() - for _, key := range getKeys(r) { + keys := getKeys(r) + if m.opts&optSortStructs != 0 { + // We do this for text Ion because json.Marshal does, and it's useful for testing. + // For binary Ion, skip it and write things in whatever order they come back from + // the map. + sort.Slice(keys, func(i, j int) bool { return keys[i].s < keys[j].s }) + } + + for _, key := range keys { m.w.FieldName(key.s) value := r.MapIndex(key.v) if err := m.marshalValue(value); err != nil { @@ -143,12 +178,10 @@ func getKeys(r reflect.Value) []mapkey { } } - sort.Slice(res, func(i, j int) bool { return res[i].s < res[j].s }) - return res } -func (m *marshaller) marshalSlice(r reflect.Value) error { +func (m *Marshaller) marshalSlice(r reflect.Value) error { if r.Type().Elem().Kind() == reflect.Uint8 { return m.marshalBlob(r) } @@ -161,7 +194,7 @@ func (m *marshaller) marshalSlice(r reflect.Value) error { return m.marshalArray(r) } -func (m *marshaller) marshalBlob(r reflect.Value) error { +func (m *Marshaller) marshalBlob(r reflect.Value) error { if r.IsNil() { m.w.WriteNull() } else { @@ -170,7 +203,7 @@ func (m *marshaller) marshalBlob(r reflect.Value) error { return m.w.Err() } -func (m *marshaller) marshalArray(r reflect.Value) error { +func (m *Marshaller) marshalArray(r reflect.Value) error { m.w.BeginList() for i := 0; i < r.Len(); i++ { @@ -183,10 +216,16 @@ func (m *marshaller) marshalArray(r reflect.Value) error { return m.w.Err() } -func (m *marshaller) marshalStruct(r reflect.Value) error { +func (m *Marshaller) marshalStruct(r reflect.Value) error { m.w.BeginStruct() fields := getFields(r.Type()) + if m.opts&optSortStructs != 0 { + // We do this for text Ion because json.Marshal does, and it's useful for testing. + // For binary Ion, skip it and write things in whatever order they happen to be in. + sort.Slice(fields, func(i, j int) bool { return fields[i].index < fields[j].index }) + } + for i := range fields { f := &fields[i] m.w.FieldName(f.name) @@ -255,8 +294,6 @@ func getFields(t reflect.Type) []field { // } // } - sort.Slice(fields, func(i, j int) bool { return fields[i].index < fields[j].index }) - return fields } From e35a521084136abc610757954bb4dd7e04273160 Mon Sep 17 00:00:00 2001 From: David Murray Date: Tue, 23 Jul 2019 18:59:08 +1000 Subject: [PATCH 23/56] Writer.WriteValue --- api.go | 2 ++ textwriter.go | 8 ++++++++ textwriter_test.go | 13 +++++++++++++ 3 files changed, 23 insertions(+) diff --git a/api.go b/api.go index 059a1efa..91ad290d 100644 --- a/api.go +++ b/api.go @@ -141,5 +141,7 @@ type Writer interface { WriteBlob(val []byte) WriteClob(val []byte) + WriteValue(val interface{}) + Finish() error } diff --git a/textwriter.go b/textwriter.go index 26501ad7..a6abf67e 100644 --- a/textwriter.go +++ b/textwriter.go @@ -409,6 +409,14 @@ func (w *textWriter) WriteClob(val []byte) { }) } +func (w *textWriter) WriteValue(val interface{}) { + m := Marshaller{ + w: w, + opts: optSortStructs, + } + m.Marshal(val) +} + // Finish finishes the current datagram. func (w *textWriter) Finish() error { if w.err != nil { diff --git a/textwriter_test.go b/textwriter_test.go index a90fb95e..e48a412a 100644 --- a/textwriter_test.go +++ b/textwriter_test.go @@ -315,6 +315,19 @@ func TestClob(t *testing.T) { }) } +func TestWriteValue(t *testing.T) { + expected := "{s:{b:2,a:1}}" + testTextWriter(t, expected, func(w Writer) { + w.BeginStruct() + w.FieldName("s") + w.WriteValue(struct { + b int + a int + }{2, 1}) + w.EndStruct() + }) +} + func TestFinish(t *testing.T) { expected := "1\nfoo\n\"bar\"\n{}\n" testTextWriter(t, expected, func(w Writer) { From 9c1a18681dfcdbbf130eb93ddf738a8754d07cee Mon Sep 17 00:00:00 2001 From: David Murray Date: Tue, 23 Jul 2019 21:35:46 +1000 Subject: [PATCH 24/56] marshaller->encoder --- marshal.go | 292 +++++++++++++++++++++++++++------------------ marshal_test.go | 66 ++++++++-- textwriter.go | 8 +- textwriter_test.go | 6 +- 4 files changed, 241 insertions(+), 131 deletions(-) diff --git a/marshal.go b/marshal.go index b11b09e0..2b7b8227 100644 --- a/marshal.go +++ b/marshal.go @@ -10,21 +10,15 @@ import ( "strings" ) -type marshallerOpts uint8 - -const ( - optSortStructs marshallerOpts = 1 -) - // MarshalText marshals values to text ion. func MarshalText(v interface{}) ([]byte, error) { buf := bytes.Buffer{} - m := Marshaller{ - w: NewTextWriterOpts(&buf, OptQuietFinish), - opts: optSortStructs, + m := Encoder{ + w: NewTextWriterOpts(&buf, OptQuietFinish), + sortMaps: true, } - if err := m.Marshal(v); err != nil { + if err := m.Encode(v); err != nil { return nil, err } if err := m.Finish(); err != nil { @@ -34,112 +28,112 @@ func MarshalText(v interface{}) ([]byte, error) { return buf.Bytes(), nil } -// A Marshaller marshals golang values to Ion. -type Marshaller struct { - w Writer - opts marshallerOpts +// An Encoder writes Ion values to an output stream. +type Encoder struct { + w Writer + sortMaps bool } -// NewMarshaller creates a new marshaller that marshals to the given writer. -func NewMarshaller(w Writer) *Marshaller { - return &Marshaller{ +// NewEncoder creates a new encoder. +func NewEncoder(w Writer) *Encoder { + return &Encoder{ w: w, } } -// NewTextMarshaller creates a new marshaller that marshals text Ion to the given writer. -func NewTextMarshaller(w io.Writer) *Marshaller { - return &Marshaller{ - w: NewTextWriter(w), - opts: optSortStructs, +// NewTextEncoder creates a new Encoder that marshals text Ion to the given writer. +func NewTextEncoder(w io.Writer) *Encoder { + return &Encoder{ + w: NewTextWriter(w), + sortMaps: true, } } -// Marshal marshals the given value to Ion, writing it to the underlying writer. -func (m *Marshaller) Marshal(v interface{}) error { +// Encode marshals the given value to Ion, writing it to the underlying writer. +func (m *Encoder) Encode(v interface{}) error { return m.marshalValue(reflect.ValueOf(v)) } // Finish finishes writing the current Ion datagram. -func (m *Marshaller) Finish() error { +func (m *Encoder) Finish() error { return m.w.Finish() } -func (m *Marshaller) marshalValue(r reflect.Value) error { - if !r.IsValid() { +func (m *Encoder) marshalValue(v reflect.Value) error { + if !v.IsValid() { m.w.WriteNull() return nil } - t := r.Type() + t := v.Type() switch t.Kind() { case reflect.Bool: - m.w.WriteBool(r.Bool()) + m.w.WriteBool(v.Bool()) return m.w.Err() case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - m.w.WriteInt(r.Int()) + m.w.WriteInt(v.Int()) return m.w.Err() case reflect.Uint8, reflect.Uint16, reflect.Uint32: - m.w.WriteInt(int64(r.Uint())) + m.w.WriteInt(int64(v.Uint())) return m.w.Err() case reflect.Uint, reflect.Uint64, reflect.Uintptr: i := big.Int{} - i.SetUint64(r.Uint()) + i.SetUint64(v.Uint()) m.w.WriteBigInt(&i) return m.w.Err() case reflect.Float32, reflect.Float64: - m.w.WriteFloat(r.Float()) + m.w.WriteFloat(v.Float()) return m.w.Err() // TODO: Decimal // TODO: Time case reflect.String: - m.w.WriteString(r.String()) + m.w.WriteString(v.String()) return m.w.Err() case reflect.Interface, reflect.Ptr: - return m.marshalInterfaceOrPtr(r) + return m.marshalPtr(v) case reflect.Struct: - return m.marshalStruct(r) + return m.marshalStruct(v) case reflect.Map: - return m.marshalMap(r) + return m.marshalMap(v) case reflect.Slice: - return m.marshalSlice(r) + return m.marshalSlice(v) case reflect.Array: - return m.marshalArray(r) + return m.marshalArray(v) default: - return fmt.Errorf("unsupported type %v", r.Type()) + return fmt.Errorf("ion: unsupported type: %v", v.Type().String()) } } -func (m *Marshaller) marshalInterfaceOrPtr(r reflect.Value) error { - if r.IsNil() { +func (m *Encoder) marshalPtr(v reflect.Value) error { + if v.IsNil() { m.w.WriteNull() return m.w.Err() } - return m.marshalValue(r.Elem()) + return m.marshalValue(v.Elem()) } -func (m *Marshaller) marshalMap(r reflect.Value) error { - if r.IsNil() { +func (m *Encoder) marshalMap(v reflect.Value) error { + if v.IsNil() { m.w.WriteNull() return m.w.Err() } m.w.BeginStruct() - keys := getKeys(r) - if m.opts&optSortStructs != 0 { + keys := getKeys(v) + if m.sortMaps { // We do this for text Ion because json.Marshal does, and it's useful for testing. // For binary Ion, skip it and write things in whatever order they come back from // the map. @@ -148,7 +142,7 @@ func (m *Marshaller) marshalMap(r reflect.Value) error { for _, key := range keys { m.w.FieldName(key.s) - value := r.MapIndex(key.v) + value := v.MapIndex(key.v) if err := m.marshalValue(value); err != nil { return err } @@ -163,8 +157,8 @@ type mapkey struct { s string } -func getKeys(r reflect.Value) []mapkey { - keys := r.MapKeys() +func getKeys(v reflect.Value) []mapkey { + keys := v.MapKeys() res := make([]mapkey, len(keys)) for i, key := range keys { @@ -181,33 +175,33 @@ func getKeys(r reflect.Value) []mapkey { return res } -func (m *Marshaller) marshalSlice(r reflect.Value) error { - if r.Type().Elem().Kind() == reflect.Uint8 { - return m.marshalBlob(r) +func (m *Encoder) marshalSlice(v reflect.Value) error { + if v.Type().Elem().Kind() == reflect.Uint8 { + return m.marshalBlob(v) } - if r.IsNil() { + if v.IsNil() { m.w.WriteNull() return m.w.Err() } - return m.marshalArray(r) + return m.marshalArray(v) } -func (m *Marshaller) marshalBlob(r reflect.Value) error { - if r.IsNil() { +func (m *Encoder) marshalBlob(v reflect.Value) error { + if v.IsNil() { m.w.WriteNull() } else { - m.w.WriteBlob(r.Bytes()) + m.w.WriteBlob(v.Bytes()) } return m.w.Err() } -func (m *Marshaller) marshalArray(r reflect.Value) error { +func (m *Encoder) marshalArray(v reflect.Value) error { m.w.BeginList() - for i := 0; i < r.Len(); i++ { - if err := m.marshalValue(r.Index(i)); err != nil { + for i := 0; i < v.Len(); i++ { + if err := m.marshalValue(v.Index(i)); err != nil { return err } } @@ -216,20 +210,32 @@ func (m *Marshaller) marshalArray(r reflect.Value) error { return m.w.Err() } -func (m *Marshaller) marshalStruct(r reflect.Value) error { - m.w.BeginStruct() +func (m *Encoder) marshalStruct(v reflect.Value) error { + fields := fieldsFor(v.Type()) - fields := getFields(r.Type()) - if m.opts&optSortStructs != 0 { - // We do this for text Ion because json.Marshal does, and it's useful for testing. - // For binary Ion, skip it and write things in whatever order they happen to be in. - sort.Slice(fields, func(i, j int) bool { return fields[i].index < fields[j].index }) - } + m.w.BeginStruct() +FieldLoop: for i := range fields { f := &fields[i] + + fv := v + for _, i := range f.path { + if fv.Kind() == reflect.Ptr { + if fv.IsNil() { + continue FieldLoop + } + fv = fv.Elem() + } + fv = fv.Field(i) + } + + if f.omitEmpty && emptyValue(fv) { + continue + } + m.w.FieldName(f.name) - if err := m.marshalValue(r.Field(f.index)); err != nil { + if err := m.marshalValue(fv); err != nil { return err } } @@ -238,69 +244,127 @@ func (m *Marshaller) marshalStruct(r reflect.Value) error { return m.w.Err() } -type field struct { - name string - typ reflect.Type - index int +func emptyValue(v reflect.Value) bool { + switch v.Kind() { + case reflect.Array, reflect.Map, reflect.Slice, reflect.String: + return v.Len() == 0 + case reflect.Bool: + return !v.Bool() + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + return v.Int() == 0 + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: + return v.Uint() == 0 + case reflect.Float32, reflect.Float64: + return v.Float() == 0 + case reflect.Interface, reflect.Ptr: + return v.IsNil() + } + return false } -func getFields(t reflect.Type) []field { - fields := []field{} - - // current := []reflect.Type{} - // next := []reflect.Type{t} - // visited := map[reflect.Type]bool{} +type field struct { + name string + typ reflect.Type + path []int + omitEmpty bool +} - // for len(next) > 0 { - // current, next = next, current[:0] - // for _, c := range current { - // if visited[c] { - // continue - // } - // visited[c] = true +type fielder struct { + fields []field + index map[string]bool +} - c := t +func fieldsFor(t reflect.Type) []field { + fldr := fielder{index: map[string]bool{}} + fldr.inspect(t, nil) + return fldr.fields +} - for i := 0; i < c.NumField(); i++ { - f := c.Field(i) +func (f *fielder) inspect(t reflect.Type, path []int) { + for i := 0; i < t.NumField(); i++ { + sf := t.Field(i) + if !visible(&sf) { + // Skip non-visible fields. + continue + } - tag := f.Tag.Get("json") + tag := sf.Tag.Get("json") if tag == "-" { + // Skip fields that are explicitly hidden by tag. continue } - name := parseTag(tag) + name, opts := parseTag(tag) - fType := f.Type - if fType.Name() == "" && fType.Kind() == reflect.Ptr { - fType = fType.Elem() - } + newpath := make([]int, len(path)+1) + copy(newpath, path) + newpath[len(path)] = i - if name == "" && f.Anonymous && fType.Kind() == reflect.Struct { - // next = append(next, fType) - continue + ft := sf.Type + if ft.Name() == "" && ft.Kind() == reflect.Ptr { + ft = ft.Elem() } - if name == "" { - name = f.Name + if name == "" && sf.Anonymous && ft.Kind() == reflect.Struct { + // Dig in to the embedded struct. + f.inspect(ft, newpath) + } else { + // Add this named field. + if name == "" { + name = sf.Name + } + + if f.index[name] { + panic(fmt.Sprintf("too many fields named %v", name)) + } + f.index[name] = true + + f.fields = append(f.fields, field{ + name: name, + typ: ft, + path: newpath, + omitEmpty: omitEmpty(opts), + }) } - - fields = append(fields, field{ - name: name, - typ: fType, - index: i, - }) } +} - // } - // } - - return fields +func visible(sf *reflect.StructField) bool { + exported := sf.PkgPath == "" + if sf.Anonymous { + t := sf.Type + if t.Kind() == reflect.Ptr { + t = t.Elem() + } + if t.Kind() == reflect.Struct { + // Fields of embedded structs are visible even if the struct type itself is not. + return true + } + } + return exported } -func parseTag(tag string) string { +func parseTag(tag string) (string, string) { if idx := strings.Index(tag, ","); idx != -1 { // Ignore additional JSON options, at least for now. - return tag[:idx] + return tag[:idx], tag[idx+1:] + } + return tag, "" +} + +func omitEmpty(opts string) bool { + for opts != "" { + var o string + + i := strings.Index(opts, ",") + if i >= 0 { + o, opts = opts[:i], opts[i+1:] + } else { + o, opts = opts, "" + } + + if o == "omitempty" { + return true + } } - return tag + return false } diff --git a/marshal_test.go b/marshal_test.go index f7d7e756..8c1ae4ec 100644 --- a/marshal_test.go +++ b/marshal_test.go @@ -38,23 +38,69 @@ func TestMarshalText(t *testing.T) { test(struct { A int `json:"val,ignoreme"` B int `json:"-"` - }{42, 0}, "{val:42}") + C int `json:",omitempty"` + d int + }{42, 0, 0, 0}, "{val:42}") - test(struct{ v interface{} }{}, "{v:null}") - test(struct{ v interface{} }{"42"}, "{v:\"42\"}") + test(struct{ V interface{} }{}, "{V:null}") + test(struct{ V interface{} }{"42"}, "{V:\"42\"}") fourtytwo := 42 - test(struct{ v *int }{}, "{v:null}") - test(struct{ v *int }{&fourtytwo}, "{v:42}") + test(struct{ V *int }{}, "{V:null}") + test(struct{ V *int }{&fourtytwo}, "{V:42}") test(map[string]int{"b": 2, "a": 1}, "{a:1,b:2}") - test(struct{ v []int }{}, "{v:null}") - test(struct{ v []int }{[]int{4, 2}}, "{v:[4,2]}") + test(struct{ V []int }{}, "{V:null}") + test(struct{ V []int }{[]int{4, 2}}, "{V:[4,2]}") - test(struct{ v []byte }{}, "{v:null}") - test(struct{ v []byte }{[]byte{4, 2}}, "{v:{{BAI=}}}") + test(struct{ V []byte }{}, "{V:null}") + test(struct{ V []byte }{[]byte{4, 2}}, "{V:{{BAI=}}}") - test(struct{ v [2]byte }{[2]byte{4, 2}}, "{v:[4,2]}") + test(struct{ V [2]byte }{[2]byte{4, 2}}, "{V:[4,2]}") +} + +func TestMarshalNestedStructs(t *testing.T) { + type gp struct { + A int `json:"a"` + } + + type gp2 struct { + B int `json:"b"` + } + + type parent struct { + gp + *gp2 + C int `json:"c"` + } + + type root struct { + parent + D int `json:"d"` + } + + v := root{ + parent: parent{ + gp: gp{ + A: 1, + }, + gp2: &gp2{ + B: 2, + }, + C: 3, + }, + D: 4, + } + + val, err := MarshalText(v) + if err != nil { + t.Fatal(err) + } + + eval := "{a:1,b:2,c:3,d:4}" + if string(val) != eval { + t.Errorf("expected %v, got %v", eval, string(val)) + } } diff --git a/textwriter.go b/textwriter.go index a6abf67e..97da5eee 100644 --- a/textwriter.go +++ b/textwriter.go @@ -410,11 +410,11 @@ func (w *textWriter) WriteClob(val []byte) { } func (w *textWriter) WriteValue(val interface{}) { - m := Marshaller{ - w: w, - opts: optSortStructs, + m := Encoder{ + w: w, + sortMaps: true, } - m.Marshal(val) + m.Encode(val) } // Finish finishes the current datagram. diff --git a/textwriter_test.go b/textwriter_test.go index e48a412a..c93d8dce 100644 --- a/textwriter_test.go +++ b/textwriter_test.go @@ -316,13 +316,13 @@ func TestClob(t *testing.T) { } func TestWriteValue(t *testing.T) { - expected := "{s:{b:2,a:1}}" + expected := "{s:{B:2,A:1}}" testTextWriter(t, expected, func(w Writer) { w.BeginStruct() w.FieldName("s") w.WriteValue(struct { - b int - a int + B int + A int }{2, 1}) w.EndStruct() }) From d5ef3095a7a7e596ed63b55b2bbffd9d5af047f1 Mon Sep 17 00:00:00 2001 From: David Murray Date: Mon, 29 Jul 2019 11:46:28 +1000 Subject: [PATCH 25/56] unmarshal --- api.go | 16 ++ reader_test.go | 104 +++++--- textreader.go | 15 ++ unmarshal.go | 617 ++++++++++++++++++++++++++++++++++++++++++++++ unmarshal_test.go | 546 ++++++++++++++++++++++++++++++++++++++++ 5 files changed, 1260 insertions(+), 38 deletions(-) create mode 100644 unmarshal.go create mode 100644 unmarshal_test.go diff --git a/api.go b/api.go index 91ad290d..4e17a1c6 100644 --- a/api.go +++ b/api.go @@ -75,6 +75,20 @@ func (t Type) String() string { } } +// IntSize is the size of an integer. +type IntSize uint8 + +const ( + // NullInt is the size of null.int. + NullInt IntSize = iota + // Int32 is a 32-bit integer. + Int32 + // Int64 is a 64-bit integer. + Int64 + // BigInt is too big for a 64-bit integer. + BigInt +) + // A Reader reads Ion values from an input stream. type Reader interface { SymbolTable() SymbolTable @@ -87,6 +101,8 @@ type Reader interface { TypeAnnotations() []string IsNull() bool + IntSize() IntSize + StepIn() error StepOut() error diff --git a/reader_test.go b/reader_test.go index 6aad0833..25f40598 100644 --- a/reader_test.go +++ b/reader_test.go @@ -18,12 +18,12 @@ var blacklist = map[string]bool{ "ion-tests/iontestdata/good/item1.10n": true, } -func print(level int, obj interface{}) { - fmt.Print(" > ") - for i := 0; i < level; i++ { - fmt.Print(" ") - } - fmt.Println(obj) +type drainfunc func(t *testing.T, r Reader, f string) + +func TestReadFiles(t *testing.T) { + testReadDir(t, "ion-tests/iontestdata/good", func(t *testing.T, r Reader, f string) { + drain(t, r, 0) + }) } func drain(t *testing.T, r Reader, level int) { @@ -51,7 +51,65 @@ func drain(t *testing.T, r Reader, level int) { } } -func testReaderFile(t *testing.T, path string) { +func print(level int, obj interface{}) { + fmt.Print(" > ") + for i := 0; i < level; i++ { + fmt.Print(" ") + } + fmt.Println(obj) +} + +func TestDecodeFiles(t *testing.T) { + testReadDir(t, "ion-tests/iontestdata/good", func(t *testing.T, r Reader, f string) { + // fmt.Println(f) + d := NewDecoder(r) + for { + v, err := d.Decode() + if err == ErrNoInput { + break + } + if err != nil { + t.Fatal(err) + } + // fmt.Println(v) + _ = v + } + }) +} + +var emptyFiles = []string{ + "ion-tests/iontestdata/good/blank.ion", + "ion-tests/iontestdata/good/empty.ion", +} + +func isEmptyFile(f string) bool { + for _, s := range emptyFiles { + if f == s { + return true + } + } + return false +} + +func testReadDir(t *testing.T, path string, d drainfunc) { + files, err := ioutil.ReadDir(path) + if err != nil { + t.Fatal(err) + } + + for _, file := range files { + fp := filepath.Join(path, file.Name()) + if file.IsDir() { + testReadDir(t, fp, d) + } else { + t.Run(fp, func(t *testing.T) { + testReadFile(t, fp, d) + }) + } + } +} + +func testReadFile(t *testing.T, path string, d drainfunc) { if _, ok := blacklist[path]; ok { return } @@ -76,35 +134,5 @@ func testReaderFile(t *testing.T, path string) { t.Fatal("unexpected suffix on file", path) } - drain(t, r, 0) -} - -func testReaderDir(t *testing.T, path string) { - files, err := ioutil.ReadDir(path) - if err != nil { - t.Fatal(err) - } - - for _, file := range files { - fp := filepath.Join(path, file.Name()) - if file.IsDir() { - testReaderDir(t, fp) - } else { - t.Run(fp, func(t *testing.T) { - testReaderFile(t, fp) - }) - } - } + d(t, r, path) } - -func TestReader(t *testing.T) { - testReaderDir(t, "ion-tests/iontestdata/good") -} - -// func TestAllNulls(t *testing.T) { -// testReaderFile(t, "ion-tests/iontestdata/good/allNulls.ion") -// } - -// func TestStructWhitespace(t *testing.T) { -// testReaderFile(t, "ion-tests/iontestdata/good/equivs/structWhitespace.ion") -// } diff --git a/textreader.go b/textreader.go index 33ead3a3..37495b53 100644 --- a/textreader.go +++ b/textreader.go @@ -708,6 +708,21 @@ func (t *textReader) BoolValue() (bool, error) { return false, errors.New("value is not a bool") } +func (t *textReader) IntSize() IntSize { + if t.valueType != IntType || t.value == nil { + return NullInt + } + + if i, ok := t.value.(int64); ok { + if i > math.MaxInt32 || i < math.MinInt32 { + return Int64 + } + return Int32 + } + + return BigInt +} + func (t *textReader) IntValue() (int, error) { i, err := t.Int64Value() if err != nil { diff --git a/unmarshal.go b/unmarshal.go new file mode 100644 index 00000000..89c7ac50 --- /dev/null +++ b/unmarshal.go @@ -0,0 +1,617 @@ +package ion + +import ( + "bytes" + "errors" + "fmt" + "math/big" + "reflect" + "strings" + "time" +) + +var ( + // ErrNoInput is returned when there is no input to decode + ErrNoInput = errors.New("ion: no input to decode") +) + +// Unmarshal unmarshals Ion data to the given object. +func Unmarshal(data []byte, v interface{}) error { + // TODO: Figure out if it's text or binary instead of hardcoding text. + return NewDecoder(NewTextReader(bytes.NewReader(data))).DecodeTo(v) +} + +// A Decoder decodes go values from an Ion reader. +type Decoder struct { + r Reader +} + +// NewDecoder creates a new decoder. +func NewDecoder(r Reader) *Decoder { + return &Decoder{ + r: r, + } +} + +// Decode decodes a value from the underlying Ion reader without any expectations +// about what it's going to get. Structs become map[string]interface{}s, Lists and +// Sexps become []interface{}s. +func (d *Decoder) Decode() (interface{}, error) { + if !d.r.Next() { + if d.r.Err() != nil { + return nil, d.r.Err() + } + return nil, ErrNoInput + } + + return d.decode() +} + +// Helper form of Decode for when you've already called Next. +func (d *Decoder) decode() (interface{}, error) { + if d.r.IsNull() { + return nil, nil + } + + switch d.r.Type() { + case BoolType: + return d.r.BoolValue() + + case IntType: + return d.decodeInt() + + case FloatType: + return d.r.FloatValue() + + case DecimalType: + return d.r.DecimalValue() + + case TimestampType: + return d.r.TimeValue() + + case StringType, SymbolType: + return d.r.StringValue() + + case BlobType, ClobType: + return d.r.ByteValue() + + case StructType: + return d.decodeMap() + + case ListType, SexpType: + return d.decodeSlice() + + default: + panic("wat?") + } +} + +func (d *Decoder) decodeInt() (interface{}, error) { + switch d.r.IntSize() { + case Int32: + return d.r.IntValue() + case Int64: + return d.r.Int64Value() + default: + return d.r.BigIntValue() + } +} + +// DecodeMap decodes an Ion struct to a go map. +func (d *Decoder) decodeMap() (map[string]interface{}, error) { + if err := d.r.StepIn(); err != nil { + return nil, err + } + + result := map[string]interface{}{} + + for d.r.Next() { + name := d.r.FieldName() + value, err := d.decode() + if err != nil { + return nil, err + } + result[name] = value + } + + if err := d.r.StepOut(); err != nil { + return nil, err + } + + return result, nil +} + +// DecodeSlice decodes an Ion list or sexp to a go slice. +func (d *Decoder) decodeSlice() ([]interface{}, error) { + if err := d.r.StepIn(); err != nil { + return nil, err + } + + result := []interface{}{} + + for d.r.Next() { + value, err := d.decode() + if err != nil { + return nil, err + } + result = append(result, value) + } + + if err := d.r.StepOut(); err != nil { + return nil, err + } + + return result, nil +} + +// DecodeTo decodes an Ion value from the underlying Ion reader into the +// value provided. +func (d *Decoder) DecodeTo(v interface{}) error { + rv := reflect.ValueOf(v) + if rv.Kind() != reflect.Ptr { + return errors.New("ion: v must be a pointer") + } + if rv.IsNil() { + return errors.New("ion: v must not be nil") + } + + if !d.r.Next() { + if d.r.Err() != nil { + return d.r.Err() + } + return ErrNoInput + } + + return d.decodeTo(rv) +} + +func (d *Decoder) decodeTo(v reflect.Value) error { + if !v.IsValid() { + // Don't actually have anywhere to put this value; skip it. + return nil + } + + isNull := d.r.IsNull() + v = indirect(v, isNull) + if isNull { + v.Set(reflect.Zero(v.Type())) + return nil + } + + switch d.r.Type() { + case BoolType: + return d.decodeBoolTo(v) + + case IntType: + return d.decodeIntTo(v) + + case FloatType: + return d.decodeFloatTo(v) + + // TODO: Decimal + // TODO: Timestamp + + case TimestampType: + return d.decodeTimestampTo(v) + + case StringType, SymbolType: + return d.decodeStringTo(v) + + case BlobType, ClobType: + return d.decodeLobTo(v) + + case StructType: + return d.decodeStructTo(v) + + case ListType, SexpType: + return d.decodeSliceTo(v) + + default: + panic("wat?") + } +} + +func (d *Decoder) decodeBoolTo(v reflect.Value) error { + val, err := d.r.BoolValue() + if err != nil { + return err + } + + switch v.Kind() { + case reflect.Bool: + // Too easy. + v.SetBool(val) + return nil + + case reflect.Interface: + if v.NumMethod() == 0 { + v.Set(reflect.ValueOf(val)) + return nil + } + } + return fmt.Errorf("ion: cannot decode bool to %v", v.Type().String()) +} + +var bigIntType = reflect.TypeOf(big.Int{}) + +func (d *Decoder) decodeIntTo(v reflect.Value) error { + switch v.Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + val, err := d.r.Int64Value() + if err != nil { + return err + } + if v.OverflowInt(val) { + return fmt.Errorf("ion: value %v won't fit in type %v", val, v.Type().String()) + } + v.SetInt(val) + return nil + + case reflect.Uint8, reflect.Uint16, reflect.Uint32: + val, err := d.r.Int64Value() + if err != nil { + return err + } + if val < 0 || v.OverflowUint(uint64(val)) { + return fmt.Errorf("ion: value %v won't fit in type %v", val, v.Type().String()) + } + v.SetUint(uint64(val)) + return nil + + case reflect.Uint, reflect.Uint64, reflect.Uintptr: + val, err := d.r.BigIntValue() + if err != nil { + return err + } + if !val.IsUint64() { + return fmt.Errorf("ion: value %v won't fit in type %v", val, v.Type().String()) + } + uiv := val.Uint64() + if v.OverflowUint(uiv) { + return fmt.Errorf("ion: value %v won't fit in type %v", val, v.Type().String()) + } + v.SetUint(uiv) + return nil + + case reflect.Struct: + if v.Type() == bigIntType { + val, err := d.r.BigIntValue() + if err != nil { + return err + } + v.Set(reflect.ValueOf(*val)) + return nil + } + + case reflect.Interface: + if v.NumMethod() == 0 { + val, err := d.decodeInt() + if err != nil { + return err + } + v.Set(reflect.ValueOf(val)) + return nil + } + } + return fmt.Errorf("ion: cannot decode int to %v", v.Type().String()) +} + +func (d *Decoder) decodeFloatTo(v reflect.Value) error { + val, err := d.r.FloatValue() + if err != nil { + return err + } + + switch v.Kind() { + case reflect.Float32, reflect.Float64: + if v.OverflowFloat(val) { + return fmt.Errorf("ion: value %v won't fit in type %v", val, v.Type().String()) + } + v.SetFloat(val) + return nil + + // TODO: Decimal + + case reflect.Interface: + if v.NumMethod() == 0 { + v.Set(reflect.ValueOf(val)) + return nil + } + } + return fmt.Errorf("ion: cannot decode float to %v", v.Type().String()) +} + +var timeType = reflect.TypeOf(time.Time{}) + +func (d *Decoder) decodeTimestampTo(v reflect.Value) error { + val, err := d.r.TimeValue() + if err != nil { + return err + } + + switch v.Kind() { + case reflect.Struct: + if v.Type() == timeType { + v.Set(reflect.ValueOf(val)) + return nil + } + + case reflect.Interface: + if v.NumMethod() == 0 { + v.Set(reflect.ValueOf(val)) + return nil + } + } + return fmt.Errorf("ion: cannot decode timestamp to %v", v.Type().String()) +} + +func (d *Decoder) decodeStringTo(v reflect.Value) error { + val, err := d.r.StringValue() + if err != nil { + return err + } + + switch v.Kind() { + case reflect.String: + v.SetString(val) + return nil + + case reflect.Interface: + if v.NumMethod() == 0 { + v.Set(reflect.ValueOf(val)) + return nil + } + } + return fmt.Errorf("ion: cannot decode string to %v", v.Type().String()) +} + +func (d *Decoder) decodeLobTo(v reflect.Value) error { + val, err := d.r.ByteValue() + if err != nil { + return err + } + + switch v.Kind() { + case reflect.Slice: + if v.Type().Elem().Kind() == reflect.Uint8 { + v.SetBytes(val) + return nil + } + + case reflect.Array: + if v.Type().Elem().Kind() == reflect.Uint8 { + i := reflect.Copy(v, reflect.ValueOf(val)) + for ; i < v.Len(); i++ { + v.Index(i).SetUint(0) + } + return nil + } + + case reflect.Interface: + if v.NumMethod() == 0 { + v.Set(reflect.ValueOf(val)) + return nil + } + } + return fmt.Errorf("ion: cannot decode lob to %v", v.Type().String()) +} + +func (d *Decoder) decodeStructTo(v reflect.Value) error { + switch v.Kind() { + case reflect.Struct: + return d.decodeStructToStruct(v) + + case reflect.Map: + return d.decodeStructToMap(v) + + case reflect.Interface: + if v.NumMethod() == 0 { + m, err := d.decodeMap() + if err != nil { + return err + } + v.Set(reflect.ValueOf(m)) + return nil + } + } + return fmt.Errorf("ion: cannot decode struct to %v", v.Type().String()) +} + +func (d *Decoder) decodeStructToStruct(v reflect.Value) error { + fields := fieldsFor(v.Type()) + + if err := d.r.StepIn(); err != nil { + return err + } + + for d.r.Next() { + name := d.r.FieldName() + field := findField(fields, name) + if field != nil { + subv, err := findSubvalue(v, field) + if err != nil { + return err + } + + if err := d.decodeTo(subv); err != nil { + return err + } + } + } + + return d.r.StepOut() +} + +func findField(fields []field, name string) *field { + var f *field + for i := range fields { + ff := &fields[i] + if ff.name == name { + return ff + } + if f == nil && strings.EqualFold(ff.name, name) { + f = ff + } + } + return f +} + +func findSubvalue(v reflect.Value, f *field) (reflect.Value, error) { + for _, i := range f.path { + if v.Kind() == reflect.Ptr { + if v.IsNil() { + if !v.CanSet() { + return reflect.Value{}, fmt.Errorf("ion: cannot set embedded pointer to unexported struct: %v", v.Type().Elem()) + } + v.Set(reflect.New(v.Type().Elem())) + } + v = v.Elem() + } + v = v.Field(i) + } + return v, nil +} + +func (d *Decoder) decodeStructToMap(v reflect.Value) error { + t := v.Type() + switch t.Key().Kind() { + case reflect.String: + default: + return fmt.Errorf("ion: cannot decode struct to %v", t.String()) + } + + if v.IsNil() { + v.Set(reflect.MakeMap(t)) + } + + subv := reflect.New(t.Elem()).Elem() + + if err := d.r.StepIn(); err != nil { + return err + } + + for d.r.Next() { + name := d.r.FieldName() + if err := d.decodeTo(subv); err != nil { + return err + } + + var kv reflect.Value + switch t.Key().Kind() { + case reflect.String: + kv = reflect.ValueOf(name) + default: + panic("wat?") + } + + if kv.IsValid() { + v.SetMapIndex(kv, subv) + } + } + + return d.r.StepOut() +} + +func (d *Decoder) decodeSliceTo(v reflect.Value) error { + k := v.Kind() + + // If all we know is we need an interface{}, decode an []interface{} with + // types based on the Ion value stream. + if k == reflect.Interface && v.NumMethod() == 0 { + s, err := d.decodeSlice() + if err != nil { + return err + } + v.Set(reflect.ValueOf(s)) + return nil + } + + // Only other valid targets are arrays and slices. + if k != reflect.Array && k != reflect.Slice { + return fmt.Errorf("ion: cannot unmarshal slice to %v", v.Type().String()) + } + + if err := d.r.StepIn(); err != nil { + return err + } + + i := 0 + + // Decode values into the array or slice. + for d.r.Next() { + if v.Kind() == reflect.Slice { + // If it's a slice, we can grow it as needed. + if i >= v.Cap() { + newcap := v.Cap() + v.Cap()/2 + if newcap < 4 { + newcap = 4 + } + newv := reflect.MakeSlice(v.Type(), v.Len(), newcap) + reflect.Copy(newv, v) + v.Set(newv) + } + if i >= v.Len() { + v.SetLen(i + 1) + } + } + + if i < v.Len() { + if err := d.decodeTo(v.Index(i)); err != nil { + return err + } + } + + i++ + } + + if err := d.r.StepOut(); err != nil { + return err + } + + if i < v.Len() { + if v.Kind() == reflect.Array { + // Zero out any additional values. + z := reflect.Zero(v.Type().Elem()) + for ; i < v.Len(); i++ { + v.Index(i).Set(z) + } + } else { + v.SetLen(i) + } + } + + return nil +} + +// Dig in through any pointers to find the actual underlying value that we want +// to set. If wantPtr is false, the algorithm terminates at a non-ptr value (e.g., +// if passed an *int, it returns the int it points to, allocating such an int if the +// pointer is currently nil). If wantPtr is true, it terminates on a pointer to that +// value (allowing said pointer to be set to nil, generally). +func indirect(v reflect.Value, wantPtr bool) reflect.Value { + for { + if v.Kind() == reflect.Interface && !v.IsNil() { + e := v.Elem() + if e.Kind() == reflect.Ptr && !e.IsNil() && (!wantPtr || e.Elem().Kind() == reflect.Ptr) { + v = e + continue + } + } + + if v.Kind() != reflect.Ptr { + break + } + + if v.Elem().Kind() != reflect.Ptr && wantPtr && v.CanSet() { + break + } + + if v.IsNil() { + v.Set(reflect.New(v.Type().Elem())) + } + + v = v.Elem() + } + + return v +} diff --git a/unmarshal_test.go b/unmarshal_test.go new file mode 100644 index 00000000..47c7b76a --- /dev/null +++ b/unmarshal_test.go @@ -0,0 +1,546 @@ +package ion + +import ( + "bytes" + "math" + "math/big" + "reflect" + "testing" + "time" +) + +func TestDecodeBool(t *testing.T) { + test := func(str string, eval bool) { + t.Run(str, func(t *testing.T) { + d := NewDecoder(NewTextReaderString(str)) + + var val bool + err := d.DecodeTo(&val) + if err != nil { + t.Fatal(err) + } + + if val != eval { + t.Errorf("expected %v, got %v", eval, val) + } + }) + } + + test("null", false) + test("true", true) + test("false", false) +} +func TestDecodeBoolPtr(t *testing.T) { + test := func(str string, eval interface{}) { + t.Run(str, func(t *testing.T) { + d := NewDecoder(NewTextReaderString(str)) + + var bval bool + val := &bval + err := d.DecodeTo(&val) + if err != nil { + t.Fatal(err) + } + + if eval == nil { + if val != nil { + t.Errorf("expected , got %v", *val) + } + } else { + switch { + case val == nil: + t.Errorf("expected %v, got ", eval) + case *val != eval.(bool): + t.Errorf("expected %v, got %v", eval, *val) + } + } + }) + } + + test("null", nil) + test("null.bool", nil) + test("false", false) + test("true", true) +} + +func TestDecodeInt(t *testing.T) { + testInt8 := func(str string, eval int8) { + t.Run(str, func(t *testing.T) { + d := NewDecoder(NewTextReaderString(str)) + + var val int8 + err := d.DecodeTo(&val) + if err != nil { + t.Fatal(err) + } + + if val != eval { + t.Errorf("expected %v, got %v", eval, val) + } + }) + } + testInt8("null", 0) + testInt8("0", 0) + testInt8("0x7F", 0x7F) + testInt8("-0x80", -0x80) + + testInt16 := func(str string, eval int16) { + t.Run(str, func(t *testing.T) { + d := NewDecoder(NewTextReaderString(str)) + + var val int16 + err := d.DecodeTo(&val) + if err != nil { + t.Fatal(err) + } + + if val != eval { + t.Errorf("expected %v, got %v", eval, val) + } + }) + } + testInt16("0x7F", 0x7F) + testInt16("-0x80", -0x80) + testInt16("0x7FFF", 0x7FFF) + testInt16("-0x8000", -0x8000) + + testInt32 := func(str string, eval int32) { + t.Run(str, func(t *testing.T) { + d := NewDecoder(NewTextReaderString(str)) + + var val int32 + err := d.DecodeTo(&val) + if err != nil { + t.Fatal(err) + } + + if val != eval { + t.Errorf("expected %v, got %v", eval, val) + } + }) + } + testInt32("0x7FFF", 0x7FFF) + testInt32("-0x8000", -0x8000) + testInt32("0x7FFFFFFF", 0x7FFFFFFF) + testInt32("-0x80000000", -0x80000000) + + testInt := func(str string, eval int) { + t.Run(str, func(t *testing.T) { + d := NewDecoder(NewTextReaderString(str)) + + var val int + err := d.DecodeTo(&val) + if err != nil { + t.Fatal(err) + } + + if val != eval { + t.Errorf("expected %v, got %v", eval, val) + } + }) + } + testInt("0x7FFF", 0x7FFF) + testInt("-0x8000", -0x8000) + testInt("0x7FFFFFFF", 0x7FFFFFFF) + testInt("-0x80000000", -0x80000000) + + testInt64 := func(str string, eval int64) { + t.Run(str, func(t *testing.T) { + d := NewDecoder(NewTextReaderString(str)) + + var val int64 + err := d.DecodeTo(&val) + if err != nil { + t.Fatal(err) + } + + if val != eval { + t.Errorf("expected %v, got %v", eval, val) + } + }) + } + testInt64("0x7FFFFFFF", 0x7FFFFFFF) + testInt64("-0x80000000", -0x80000000) + testInt64("0x7FFFFFFFFFFFFFFF", 0x7FFFFFFFFFFFFFFF) + testInt64("-0x8000000000000000", -0x8000000000000000) +} + +func TestDecodeUint(t *testing.T) { + testUint8 := func(str string, eval uint8) { + t.Run(str, func(t *testing.T) { + d := NewDecoder(NewTextReaderString(str)) + + var val uint8 + err := d.DecodeTo(&val) + if err != nil { + t.Fatal(err) + } + + if val != eval { + t.Errorf("expected %v, got %v", eval, val) + } + }) + } + testUint8("null", 0) + testUint8("0", 0) + testUint8("0xFF", 0xFF) + + testUint16 := func(str string, eval uint16) { + t.Run(str, func(t *testing.T) { + d := NewDecoder(NewTextReaderString(str)) + + var val uint16 + err := d.DecodeTo(&val) + if err != nil { + t.Fatal(err) + } + + if val != eval { + t.Errorf("expected %v, got %v", eval, val) + } + }) + } + testUint16("0xFF", 0xFF) + testUint16("0xFFFF", 0xFFFF) + + testUint32 := func(str string, eval uint32) { + t.Run(str, func(t *testing.T) { + d := NewDecoder(NewTextReaderString(str)) + + var val uint32 + err := d.DecodeTo(&val) + if err != nil { + t.Fatal(err) + } + + if val != eval { + t.Errorf("expected %v, got %v", eval, val) + } + }) + } + testUint32("0xFFFF", 0xFFFF) + testUint32("0xFFFFFFFF", 0xFFFFFFFF) + + testUint := func(str string, eval uint) { + t.Run(str, func(t *testing.T) { + d := NewDecoder(NewTextReaderString(str)) + + var val uint + err := d.DecodeTo(&val) + if err != nil { + t.Fatal(err) + } + + if val != eval { + t.Errorf("expected %v, got %v", eval, val) + } + }) + } + testUint("0xFFFF", 0xFFFF) + testUint("0xFFFFFFFF", 0xFFFFFFFF) + + testUintptr := func(str string, eval uintptr) { + t.Run(str, func(t *testing.T) { + d := NewDecoder(NewTextReaderString(str)) + + var val uintptr + err := d.DecodeTo(&val) + if err != nil { + t.Fatal(err) + } + + if val != eval { + t.Errorf("expected %v, got %v", eval, val) + } + }) + } + testUintptr("0xFFFF", 0xFFFF) + testUintptr("0xFFFFFFFF", 0xFFFFFFFF) + + testUint64 := func(str string, eval uint64) { + t.Run(str, func(t *testing.T) { + d := NewDecoder(NewTextReaderString(str)) + + var val uint64 + err := d.DecodeTo(&val) + if err != nil { + t.Fatal(err) + } + + if val != eval { + t.Errorf("expected %v, got %v", eval, val) + } + }) + } + testUint64("0xFFFFFFFF", 0xFFFFFFFF) + testUint64("0xFFFFFFFFFFFFFFFF", 0xFFFFFFFFFFFFFFFF) +} + +func TestDecodeBigInt(t *testing.T) { + test := func(str string, eval *big.Int) { + t.Run(str, func(t *testing.T) { + d := NewDecoder(NewTextReaderString(str)) + + var val big.Int + err := d.DecodeTo(&val) + if err != nil { + t.Fatal(err) + } + + if val.Cmp(eval) != 0 { + t.Errorf("expected %v, got %v", eval, val) + } + }) + } + test("null", new(big.Int)) + test("1", new(big.Int).SetUint64(1)) + test("-0xFFFFFFFFFFFFFFFF", new(big.Int).Neg(new(big.Int).SetUint64(0xFFFFFFFFFFFFFFFF))) +} + +func TestDecodeFloat(t *testing.T) { + test32 := func(str string, eval float32) { + t.Run(str, func(t *testing.T) { + d := NewDecoder(NewTextReaderString(str)) + + var val float32 + err := d.DecodeTo(&val) + if err != nil { + t.Fatal(err) + } + + if val != eval { + t.Errorf("expected %v, got %v", eval, val) + } + }) + } + test32("null", 0) + test32("1e0", 1) + test32("1e38", 1e38) + test32("+inf", float32(math.Inf(1))) + + test64 := func(str string, eval float64) { + t.Run(str, func(t *testing.T) { + d := NewDecoder(NewTextReaderString(str)) + + var val float64 + err := d.DecodeTo(&val) + if err != nil { + t.Fatal(err) + } + + if val != eval { + t.Errorf("expected %v, got %v", eval, val) + } + }) + } + test64("1e0", 1) + test64("1e308", 1e308) + test64("+inf", math.Inf(1)) +} + +func TestDecodeTimeTo(t *testing.T) { + test := func(str string, eval time.Time) { + t.Run(str, func(t *testing.T) { + d := NewDecoder(NewTextReaderString(str)) + + var val time.Time + err := d.DecodeTo(&val) + if err != nil { + t.Fatal(err) + } + + if val != eval { + t.Errorf("expected %v, got %v", eval, val) + } + }) + } + test("null", time.Time{}) + test("2020T", time.Date(2020, 1, 1, 0, 0, 0, 0, time.UTC)) +} + +func TestDecodeStringTo(t *testing.T) { + test := func(str string, eval string) { + t.Run(str, func(t *testing.T) { + d := NewDecoder(NewTextReaderString(str)) + + var val string + err := d.DecodeTo(&val) + if err != nil { + t.Fatal(err) + } + + if val != eval { + t.Errorf("expected %v, got %v", eval, val) + } + }) + } + + test("null", "") + test("hello", "hello") + test("\"hello\"", "hello") +} + +func TestDecodeLobTo(t *testing.T) { + testSlice := func(str string, eval []byte) { + t.Run(str, func(t *testing.T) { + d := NewDecoder(NewTextReaderString(str)) + + var val []byte + err := d.DecodeTo(&val) + if err != nil { + t.Fatal(err) + } + + if !bytes.Equal(val, eval) { + t.Errorf("expected %v, got %v", eval, val) + } + }) + } + testSlice("null", nil) + testSlice("{{}}", []byte{}) + testSlice("{{aGVsbG8=}}", []byte("hello")) + testSlice("{{'''hello'''}}", []byte("hello")) + + testArray := func(str string, eval []byte) { + t.Run(str, func(t *testing.T) { + d := NewDecoder(NewTextReaderString(str)) + + var val [8]byte + err := d.DecodeTo(&val) + if err != nil { + t.Fatal(err) + } + + if !bytes.Equal(val[:], eval) { + t.Errorf("expected %v, got %v", eval, val) + } + }) + } + testArray("null", make([]byte, 8)) + testArray("{{aGVsbG8=}}", append([]byte("hello"), []byte{0, 0, 0}...)) +} + +func TestDecodeStructTo(t *testing.T) { + test := func(str string, val, eval interface{}) { + t.Run(str, func(t *testing.T) { + d := NewDecoder(NewTextReaderString(str)) + err := d.DecodeTo(val) + if err != nil { + t.Fatal(err) + } + if !reflect.DeepEqual(val, eval) { + t.Errorf("expected %v, got %v", eval, val) + } + }) + } + + type foo struct { + Foo string + Baz int `json:"bar"` + } + + test("{}", &struct{}{}, &struct{}{}) + test("{bogus:(ignore me)}", &foo{}, &foo{}) + test("{foo:bar}", &foo{}, &foo{"bar", 0}) + test("{bar:42}", &foo{}, &foo{"", 42}) + test("{foo:bar,bar:42,bogus:(ignore me)}", &foo{}, &foo{"bar", 42}) + + test("{}", &map[string]string{}, &map[string]string{}) + test("{foo:bar}", &map[string]string{}, &map[string]string{"foo": "bar"}) + test("{a:4,b:2}", &map[string]int{}, &map[string]int{"a": 4, "b": 2}) +} + +func TestDecodeListTo(t *testing.T) { + test := func(str string, val, eval interface{}) { + t.Run(str, func(t *testing.T) { + d := NewDecoder(NewTextReaderString(str)) + err := d.DecodeTo(val) + if err != nil { + t.Fatal(err) + } + if !reflect.DeepEqual(val, eval) { + t.Errorf("expected %v, got %v", eval, val) + } + }) + } + + f := false + pf := &f + ppf := &pf + + test("[]", &[]bool{}, &[]bool{}) + test("[]", &[]bool{true}, &[]bool{}) + + test("[false]", &[]bool{}, &[]bool{false}) + test("[false]", &[]*bool{}, &[]*bool{pf}) + test("[false,false]", &[]**bool{}, &[]**bool{ppf, ppf}) + + test("[true,false]", &[]interface{}{}, &[]interface{}{true, false}) + + var i interface{} + var ei interface{} = []interface{}{true, false} + test("[true,false]", &i, &ei) +} + +func TestDecode(t *testing.T) { + test := func(data string, eval interface{}) { + t.Run(data, func(t *testing.T) { + d := NewDecoder(NewTextReaderString(data)) + val, err := d.Decode() + if err != nil { + t.Fatal(err) + } + if !reflect.DeepEqual(val, eval) { + t.Errorf("expected %v, got %v", eval, val) + } + }) + } + + test("null", nil) + test("null.null", nil) + + test("null.bool", nil) + test("true", true) + test("false", false) + + test("null.int", nil) + test("0", int(0)) + test("2147483647", math.MaxInt32) + test("-2147483648", math.MinInt32) + test("2147483648", int64(math.MaxInt32)+1) + test("-2147483649", int64(math.MinInt32)-1) + test("9223372036854775808", new(big.Int).SetUint64(math.MaxInt64+1)) + + test("0e0", float64(0.0)) + test("1e100", float64(1e100)) + + test("0.", MustParseDecimal("0.")) + + test("2020T", time.Date(2020, 1, 1, 0, 0, 0, 0, time.UTC)) + + test("hello", "hello") + test("\"hello\"", "hello") + + test("null.blob", nil) + test("{{}}", []byte{}) + test("{{aGVsbG8=}}", []byte("hello")) + + test("null.clob", nil) + test("{{''''''}}", []byte{}) + test("{{'''hello'''}}", []byte("hello")) + + test("null.struct", nil) + test("{}", map[string]interface{}{}) + test("{a:1,b:two}", map[string]interface{}{ + "a": 1, + "b": "two", + }) + + test("null.list", nil) + test("[]", []interface{}{}) + test("[1, two]", []interface{}{1, "two"}) + + test("null.sexp", nil) + test("()", []interface{}{}) + test("(1 + two)", []interface{}{1, "+", "two"}) +} From 22e024c88a842b5f29016fa45c584e48c7139430 Mon Sep 17 00:00:00 2001 From: David Murray Date: Mon, 29 Jul 2019 12:27:21 +1000 Subject: [PATCH 26/56] tackling some todos --- decimal.go | 4 +--- marshal.go | 26 +++++++++++++++++++++++--- marshal_test.go | 4 ++++ unmarshal.go | 38 +++++++++++++++++++++++++++++++++++--- unmarshal_test.go | 21 +++++++++++++++++++++ 5 files changed, 84 insertions(+), 9 deletions(-) diff --git a/decimal.go b/decimal.go index ee4ad829..cf4cf223 100644 --- a/decimal.go +++ b/decimal.go @@ -81,7 +81,7 @@ func ParseDecimal(in string) (*Decimal, error) { n, ok := new(big.Int).SetString(in, 10) if !ok { // Unfortunately this is all we get? - return nil, errors.New("not a valid number") + return nil, fmt.Errorf("not a valid number: %v", in) } return NewDecimalWithScale(n, -shift), nil @@ -258,8 +258,6 @@ func (d *Decimal) String() string { switch { case d.scale == 0: // Value is an unscaled integer. Just mark it as a decimal. - // TODO: If there are enough trailing zeros should we knock them - // off and do nnn'd'sss here? That'd technically erase precision. return d.n.String() + "." case d.scale < 0: diff --git a/marshal.go b/marshal.go index 2b7b8227..9cc7967d 100644 --- a/marshal.go +++ b/marshal.go @@ -8,6 +8,7 @@ import ( "reflect" "sort" "strings" + "time" ) // MarshalText marshals values to text ion. @@ -89,9 +90,6 @@ func (m *Encoder) marshalValue(v reflect.Value) error { m.w.WriteFloat(v.Float()) return m.w.Err() - // TODO: Decimal - // TODO: Time - case reflect.String: m.w.WriteString(v.String()) return m.w.Err() @@ -210,7 +208,17 @@ func (m *Encoder) marshalArray(v reflect.Value) error { return m.w.Err() } +var decimalType = reflect.TypeOf(Decimal{}) + func (m *Encoder) marshalStruct(v reflect.Value) error { + t := v.Type() + if t == timeType { + return m.marshalTime(v) + } + if t == decimalType { + return m.marshalDecimal(v) + } + fields := fieldsFor(v.Type()) m.w.BeginStruct() @@ -244,6 +252,18 @@ FieldLoop: return m.w.Err() } +func (m *Encoder) marshalTime(v reflect.Value) error { + t := v.Interface().(time.Time) + m.w.WriteTimestamp(t) + return m.w.Err() +} + +func (m *Encoder) marshalDecimal(v reflect.Value) error { + d := v.Addr().Interface().(*Decimal) + m.w.WriteDecimal(d) + return m.w.Err() +} + func emptyValue(v reflect.Value) bool { switch v.Kind() { case reflect.Array, reflect.Map, reflect.Slice, reflect.String: diff --git a/marshal_test.go b/marshal_test.go index 8c1ae4ec..85703933 100644 --- a/marshal_test.go +++ b/marshal_test.go @@ -3,6 +3,7 @@ package ion import ( "math" "testing" + "time" ) func TestMarshalText(t *testing.T) { @@ -32,6 +33,9 @@ func TestMarshalText(t *testing.T) { test(math.Inf(-1), "-inf") test(math.NaN(), "nan") + test(MustParseDecimal("1.20"), "1.20") + test(time.Date(2010, 1, 1, 0, 0, 0, 0, time.UTC), "2010-01-01T00:00:00Z") + test("hello\tworld", "\"hello\\tworld\"") test(struct{ A, B int }{42, 0}, "{A:42,B:0}") diff --git a/unmarshal.go b/unmarshal.go index 89c7ac50..cda3a8a1 100644 --- a/unmarshal.go +++ b/unmarshal.go @@ -6,6 +6,7 @@ import ( "fmt" "math/big" "reflect" + "strconv" "strings" "time" ) @@ -188,8 +189,8 @@ func (d *Decoder) decodeTo(v reflect.Value) error { case FloatType: return d.decodeFloatTo(v) - // TODO: Decimal - // TODO: Timestamp + case DecimalType: + return d.decodeDecimalTo(v) case TimestampType: return d.decodeTimestampTo(v) @@ -310,7 +311,16 @@ func (d *Decoder) decodeFloatTo(v reflect.Value) error { v.SetFloat(val) return nil - // TODO: Decimal + case reflect.Struct: + if v.Type() == decimalType { + flt := strconv.FormatFloat(val, 'g', -1, 64) + dec, err := ParseDecimal(strings.Replace(flt, "e", "d", 1)) + if err != nil { + return err + } + v.Set(reflect.ValueOf(*dec)) + return nil + } case reflect.Interface: if v.NumMethod() == 0 { @@ -321,6 +331,28 @@ func (d *Decoder) decodeFloatTo(v reflect.Value) error { return fmt.Errorf("ion: cannot decode float to %v", v.Type().String()) } +func (d *Decoder) decodeDecimalTo(v reflect.Value) error { + val, err := d.r.DecimalValue() + if err != nil { + return err + } + + switch v.Kind() { + case reflect.Struct: + if v.Type() == decimalType { + v.Set(reflect.ValueOf(*val)) + return nil + } + + case reflect.Interface: + if v.NumMethod() == 0 { + v.Set(reflect.ValueOf(val)) + return nil + } + } + return fmt.Errorf("ion: cannot decode decimal to %v", v.Type().String()) +} + var timeType = reflect.TypeOf(time.Time{}) func (d *Decoder) decodeTimestampTo(v reflect.Value) error { diff --git a/unmarshal_test.go b/unmarshal_test.go index 47c7b76a..26bd7dd8 100644 --- a/unmarshal_test.go +++ b/unmarshal_test.go @@ -338,6 +338,27 @@ func TestDecodeFloat(t *testing.T) { test64("+inf", math.Inf(1)) } +func TestDecodeDecimal(t *testing.T) { + test := func(str string, eval *Decimal) { + t.Run(str, func(t *testing.T) { + d := NewDecoder(NewTextReaderString(str)) + + var val *Decimal + err := d.DecodeTo(&val) + if err != nil { + t.Fatal(err) + } + + if !val.Equal(eval) { + t.Errorf("expected %v, got %v", eval, val) + } + }) + } + + test("1e10", MustParseDecimal("1d10")) + test("1.20", MustParseDecimal("1.20")) +} + func TestDecodeTimeTo(t *testing.T) { test := func(str string, eval time.Time) { t.Run(str, func(t *testing.T) { From 7501e3ae647f28176ea6520a5143c7fcff295de3 Mon Sep 17 00:00:00 2001 From: David Murray Date: Tue, 6 Aug 2019 11:59:46 +1000 Subject: [PATCH 27/56] binary writer --- binarywriter.go | 517 +++++++++++++++++++++++++++++++++++++++++++ binarywriter_test.go | 362 ++++++++++++++++++++++++++++++ bits.go | 184 +++++++++++++++ bits_test.go | 138 ++++++++++++ bufnode.go | 94 ++++++++ bufnode_test.go | 79 +++++++ ctx.go | 2 +- decimal.go | 32 +-- symboltable.go | 17 +- symboltable_test.go | 37 ++-- textwriter_test.go | 4 +- 11 files changed, 1431 insertions(+), 35 deletions(-) create mode 100644 binarywriter.go create mode 100644 binarywriter_test.go create mode 100644 bits.go create mode 100644 bits_test.go create mode 100644 bufnode.go create mode 100644 bufnode_test.go diff --git a/binarywriter.go b/binarywriter.go new file mode 100644 index 00000000..4080718a --- /dev/null +++ b/binarywriter.go @@ -0,0 +1,517 @@ +package ion + +import ( + "bytes" + "encoding/binary" + "errors" + "fmt" + "io" + "math" + "math/big" + "time" +) + +// A cstack is a stack of containers. +type cstack struct { + arr []*container +} + +func (c *cstack) peek() *container { + if len(c.arr) == 0 { + return nil + } + return c.arr[len(c.arr)-1] +} + +func (c *cstack) push(code byte) { + c.arr = append(c.arr, &container{code: code}) +} + +func (c *cstack) pop() { + if len(c.arr) == 0 { + panic("pop called at top level") + } + c.arr = c.arr[:len(c.arr)-1] +} + +type binaryWriterLST struct { + writer + cs cstack + lst SymbolTable + + wroteLST bool +} + +// NewBinaryWriter creates a new binary writer. +func NewBinaryWriter(out io.Writer, lst SymbolTable) Writer { + return &binaryWriterLST{ + writer: writer{ + out: out, + }, + lst: lst, + } +} + +func (w *binaryWriterLST) write(c bufnode) error { + p := w.cs.peek() + if p == nil { + return c.WriteTo(w.out) + } + p.Add(c) + return nil +} + +func (w *binaryWriterLST) writeTag(code byte, len int) error { + buf := bytes.Buffer{} + writeTag(&buf, code, uint64(len)) + return w.write(atom(buf.Bytes())) +} + +func (w *binaryWriterLST) writeLST() error { + if _, err := w.out.Write([]byte{0xE0, 0x01, 0x00, 0xEA}); err != nil { + return err + } + + // Prevent recursion... + w.wroteLST = true + + return w.lst.WriteTo(w) +} + +func (w *binaryWriterLST) beginValue() error { + // Have to record/empty these before calling writeLST, which + // will end up modifying them. Ugh. + name := w.fieldName + w.fieldName = "" + tas := w.typeAnnotations + w.typeAnnotations = nil + + if !w.wroteLST { + if err := w.writeLST(); err != nil { + return err + } + } + + if w.InStruct() { + if name == "" { + return errors.New("ion: field name not set") + } + + id, ok := w.lst.FindByName(name) + if !ok { + return fmt.Errorf("ion: symbol '%v' not defined", name) + } + if id < 0 { + panic("negative id") + } + + if err := w.write(fieldname(id)); err != nil { + return err + } + } + + if len(tas) > 0 { + w.cs.push(0xE0) + + ids := make([]uint64, len(tas)) + idlen := uint64(0) + + for i, a := range tas { + id, ok := w.lst.FindByName(a) + if !ok { + return fmt.Errorf("ion: symbol '%v' not defined", a) + } + if id < 0 { + panic("negative id") + } + ids[i] = uint64(id) + idlen += varUintLen(uint64(id)) + } + + buf := bytes.Buffer{} + buf.Write(packVarUint(idlen)) + + for _, id := range ids { + buf.Write(packVarUint(id)) + } + + if err := w.write(atom(buf.Bytes())); err != nil { + return err + } + } + + return nil +} + +func (w *binaryWriterLST) endValue() error { + cur := w.cs.peek() + if cur != nil && cur.code == 0xE0 { + // If we're in an annotation container, write it up a level now that we + // know the length of the value. + w.cs.pop() + return w.write(cur) + } + return nil +} + +func (w *binaryWriterLST) writeValue(f func() []byte) error { + if err := w.beginValue(); err != nil { + return err + } + + val := f() + + if err := w.write(atom(val)); err != nil { + return err + } + + return w.endValue() +} + +func (w *binaryWriterLST) writeValueStreaming(f func() error) error { + if err := w.beginValue(); err != nil { + return err + } + + if err := f(); err != nil { + return err + } + + return w.endValue() +} + +func (w *binaryWriterLST) begin(t ctxType, code byte) error { + if err := w.beginValue(); err != nil { + return err + } + + w.ctx.push(t) + w.cs.push(code) + + return nil +} + +func (w *binaryWriterLST) end(t ctxType) error { + if w.ctx.peek() != t { + return errors.New("ion: not in that kind of container") + } + + cur := w.cs.peek() + if cur != nil { + w.cs.pop() + if err := w.write(cur); err != nil { + return err + } + } + + w.fieldName = "" + w.typeAnnotations = nil + w.ctx.pop() + + return w.endValue() +} + +func (w *binaryWriterLST) BeginStruct() { + if w.err != nil { + return + } + w.err = w.begin(ctxInStruct, 0xD0) +} + +func (w *binaryWriterLST) EndStruct() { + if w.err != nil { + return + } + w.err = w.end(ctxInStruct) +} + +func (w *binaryWriterLST) BeginList() { + if w.err != nil { + return + } + w.err = w.begin(ctxInList, 0xB0) +} + +func (w *binaryWriterLST) EndList() { + if w.err != nil { + return + } + w.err = w.end(ctxInList) +} + +func (w *binaryWriterLST) BeginSexp() { + if w.err != nil { + return + } + w.err = w.begin(ctxInSexp, 0xC0) +} + +func (w *binaryWriterLST) EndSexp() { + if w.err != nil { + return + } + w.err = w.end(ctxInSexp) +} + +func (w *binaryWriterLST) WriteNull() { + w.WriteNullWithType(NullType) +} + +func (w *binaryWriterLST) WriteNullWithType(t Type) { + if w.err != nil { + return + } + + w.err = w.writeValue(func() []byte { + var b byte + switch t { + case NoType, NullType: + b = 0x0F + case BoolType: + b = 0x1F + case IntType: + b = 0x2F + case FloatType: + b = 0x4F + case DecimalType: + b = 0x5F + case TimestampType: + b = 0x6F + case SymbolType: + b = 0x7F + case StringType: + b = 0x8F + case ClobType: + b = 0x9F + case BlobType: + b = 0xAF + case ListType: + b = 0xBF + case SexpType: + b = 0xCF + case StructType: + b = 0xDF + default: + panic(fmt.Sprintf("invalid type: %v", t)) + } + + return []byte{b} + }) +} + +func (w *binaryWriterLST) WriteBool(val bool) { + if w.err != nil { + return + } + + w.err = w.writeValue(func() []byte { + if val { + return []byte{0x11} + } + return []byte{0x10} + }) +} + +func (w *binaryWriterLST) WriteInt(val int64) { + if w.err != nil { + return + } + + w.err = w.writeValueStreaming(func() error { + if val == 0 { + return w.write(atom([]byte{0x20})) + } + + code := byte(0x20) + mag := uint64(val) + + if val < 0 { + code = 0x30 + mag = uint64(-val) + } + + bs := packUint(mag) + + if err := w.writeTag(code, len(bs)); err != nil { + return err + } + return w.write(atom(bs)) + }) +} + +func (w *binaryWriterLST) WriteBigInt(val *big.Int) { + if w.err != nil { + return + } + + w.err = w.writeValueStreaming(func() error { + sign := val.Sign() + if sign == 0 { + return w.write(atom([]byte{0x20})) + } + + code := byte(0x20) + if sign < 0 { + code = 0x30 + } + + bs := val.Bytes() + + if err := w.writeTag(code, len(bs)); err != nil { + return err + } + return w.write(atom(bs)) + }) +} + +func (w *binaryWriterLST) WriteFloat(val float64) { + if w.err != nil { + return + } + + w.err = w.writeValue(func() []byte { + if val == 0 { + return []byte{0x40} + } + + bs := make([]byte, 9) + bs[0] = 0x48 + + bits := math.Float64bits(val) + binary.BigEndian.PutUint64(bs[1:], bits) + + return bs + }) +} + +func (w *binaryWriterLST) WriteDecimal(val *Decimal) { + if w.err != nil { + return + } + + w.writeValueStreaming(func() error { + coef, exp := val.CoEx() + + ebs := []byte{} + if exp != 0 { + ebs = packVarInt(int64(exp)) + } + + cbs := packBigInt(coef) + + if err := w.writeTag(0x50, len(cbs)+len(ebs)); err != nil { + return err + } + + if len(ebs) > 0 { + if err := w.write(atom(ebs)); err != nil { + return err + } + } + + if len(cbs) > 0 { + if err := w.write(atom(cbs)); err != nil { + return err + } + } + + return nil + }) +} + +func (w *binaryWriterLST) WriteTimestamp(val time.Time) { + if w.err != nil { + return + } + + w.err = w.writeValueStreaming(func() error { + bs := packTime(val) + if err := w.writeTag(0x60, len(bs)); err != nil { + return err + } + return w.write(atom(bs)) + }) +} + +func (w *binaryWriterLST) WriteSymbol(val string) { + if w.err != nil { + return + } + + id, ok := w.lst.FindByName(val) + if !ok { + w.err = fmt.Errorf("ion: symbol '%v' not defined in local symbol table", val) + return + } + + w.err = w.writeValueStreaming(func() error { + bs := packUint(uint64(id)) + if err := w.writeTag(0x70, len(bs)); err != nil { + return err + } + return w.write(atom(bs)) + }) +} + +func (w *binaryWriterLST) WriteString(val string) { + if w.err != nil { + return + } + + w.err = w.writeValueStreaming(func() error { + if len(val) == 0 { + return w.write(atom([]byte{0x80})) + } + + bs := []byte(val) + + if err := w.writeTag(0x80, len(bs)); err != nil { + return err + } + return w.write(atom(bs)) + }) +} + +func (w *binaryWriterLST) WriteClob(val []byte) { + if w.err != nil { + return + } + + w.err = w.writeValueStreaming(func() error { + if err := w.writeTag(0x90, len(val)); err != nil { + return err + } + return w.write(atom(val)) + }) +} + +func (w *binaryWriterLST) WriteBlob(val []byte) { + if w.err != nil { + return + } + + w.err = w.writeValueStreaming(func() error { + if err := w.writeTag(0xA0, len(val)); err != nil { + return err + } + return w.write(atom(val)) + }) +} + +func (w *binaryWriterLST) WriteValue(val interface{}) { + w.err = errors.New("not yet implemented") +} + +func (w *binaryWriterLST) Finish() error { + if w.err != nil { + return w.err + } + if w.ctx.peek() != ctxAtTopLevel { + w.err = errors.New("ion: not at top level") + return w.err + } + + // TODO: Flush all them buffers mate! + + return nil +} diff --git a/binarywriter_test.go b/binarywriter_test.go new file mode 100644 index 00000000..3ff1da6b --- /dev/null +++ b/binarywriter_test.go @@ -0,0 +1,362 @@ +package ion + +import ( + "bytes" + "encoding/hex" + "fmt" + "math" + "math/big" + "strings" + "testing" + "time" +) + +func TestWriteBinaryStruct(t *testing.T) { + eval := []byte{ + 0xD0, // {} + 0xEA, 0x81, 0xEE, 0xD7, // foo::{ + 0x84, 0xE3, 0x81, 0xEF, 0x0F, // name:bar::null, + 0x88, 0x20, // max_id:0 + // } + } + testBinaryWriter(t, eval, func(w Writer) { + w.BeginStruct() + w.EndStruct() + + w.TypeAnnotation("foo") + w.BeginStruct() + { + w.FieldName("name") + w.TypeAnnotation("bar") + w.WriteNull() + + w.FieldName("max_id") + w.WriteInt(0) + } + w.EndStruct() + }) +} + +func TestWriteBinarySexp(t *testing.T) { + eval := []byte{ + 0xC0, // () + 0xE8, 0x81, 0xEE, 0xC5, // foo::( + 0xE3, 0x81, 0xEF, 0x0F, // bar::null, + 0x20, // 0 + // ) + } + testBinaryWriter(t, eval, func(w Writer) { + w.BeginSexp() + w.EndSexp() + + w.TypeAnnotation("foo") + w.BeginSexp() + { + w.TypeAnnotation("bar") + w.WriteNull() + + w.WriteInt(0) + } + w.EndSexp() + }) +} + +func TestWriteBinaryList(t *testing.T) { + eval := []byte{ + 0xB0, // [] + 0xE8, 0x81, 0xEE, 0xB5, // foo::[ + 0xE3, 0x81, 0xEF, 0x0F, // bar::null, + 0x20, // 0 + // ] + } + testBinaryWriter(t, eval, func(w Writer) { + w.BeginList() + w.EndList() + + w.TypeAnnotation("foo") + w.BeginList() + { + w.TypeAnnotation("bar") + w.WriteNull() + + w.WriteInt(0) + } + w.EndList() + }) +} + +func TestWriteBinaryBlob(t *testing.T) { + eval := []byte{ + 0xA0, + 0xAB, 'H', 'e', 'l', 'l', 'o', ' ', 'W', 'o', 'r', 'l', 'd', + } + testBinaryWriter(t, eval, func(w Writer) { + w.WriteBlob([]byte{}) + w.WriteBlob([]byte("Hello World")) + }) +} + +func TestWriteBinaryClob(t *testing.T) { + eval := []byte{ + 0x90, + 0x9B, 'H', 'e', 'l', 'l', 'o', ' ', 'W', 'o', 'r', 'l', 'd', + } + testBinaryWriter(t, eval, func(w Writer) { + w.WriteClob([]byte{}) + w.WriteClob([]byte("Hello World")) + }) +} + +func TestWriteBinaryString(t *testing.T) { + eval := []byte{ + 0x80, // "" + 0x8B, 'H', 'e', 'l', 'l', 'o', ' ', 'W', 'o', 'r', 'l', 'd', + 0x8E, 0x9B, 'H', 'e', 'l', 'l', 'o', ' ', 'W', 'o', 'r', 'l', 'd', + ' ', 'B', 'u', 't', ' ', 'E', 'v', 'e', 'n', ' ', 'L', 'o', 'n', 'g', 'e', 'r', + 0x84, 0xE0, 0x01, 0x00, 0xEA, + } + testBinaryWriter(t, eval, func(w Writer) { + w.WriteString("") + w.WriteString("Hello World") + w.WriteString("Hello World But Even Longer") + w.WriteString("\xE0\x01\x00\xEA") + }) +} + +func TestWriteBinarySymbol(t *testing.T) { + eval := []byte{ + 0x71, 0x01, // $ion + 0x71, 0x04, // name + 0x71, 0x05, // version + 0x71, 0x09, // $ion_shared_symbol_table + } + testBinaryWriter(t, eval, func(w Writer) { + w.WriteSymbol("$ion") + w.WriteSymbol("name") + w.WriteSymbol("version") + w.WriteSymbol("$ion_shared_symbol_table") + }) +} + +func TestWriteBinaryTimestamp(t *testing.T) { + eval := []byte{ + 0x67, 0x80, 0x81, 0x81, 0x81, 0x80, 0x80, 0x80, // 0001-01-01T00:00:00Z + 0x6E, 0x8E, // 0x0E-bit timestamp + 0x04, 0xD8, // offset: +600 minutes (+10:00) + 0x0F, 0xE3, // year: 2019 + 0x88, // month: 8 + 0x84, // day: 4 + 0x88, // hour: 8 utc (18 local) + 0x8F, // minute: 15 + 0xAB, // second: 43 + 0xC9, // exp: -9 + 0x33, 0x77, 0xDF, 0x70, // nsec: 863494000 + } + + nowish, _ := time.Parse(time.RFC3339Nano, "2019-08-04T18:15:43.863494+10:00") + + testBinaryWriter(t, eval, func(w Writer) { + w.WriteTimestamp(time.Time{}) + w.WriteTimestamp(nowish) + }) +} + +func TestWriteBinaryDecimal(t *testing.T) { + eval := []byte{ + 0x50, // 0. + 0x51, 0xC3, // 0.000, aka 0 x 10^-3 + 0x53, 0xC3, 0x03, 0xE8, // 1.000, aka 1000 x 10^-3 + 0x53, 0xC3, 0x83, 0xE8, // -1.000, aka -1000 x 10^-3 + 0x53, 0x00, 0xE4, 0x01, // 1d100, aka 1 * 10^100 + 0x53, 0x00, 0xE4, 0x81, // -1d100, aka -1 * 10^100 + } + + testBinaryWriter(t, eval, func(w Writer) { + w.WriteDecimal(MustParseDecimal("0.")) + w.WriteDecimal(MustParseDecimal("0.000")) + w.WriteDecimal(MustParseDecimal("1.000")) + w.WriteDecimal(MustParseDecimal("-1.000")) + w.WriteDecimal(MustParseDecimal("1d100")) + w.WriteDecimal(MustParseDecimal("-1d100")) + }) +} + +func TestWriteBinaryFloats(t *testing.T) { + eval := []byte{ + 0x40, // 0 + 0x48, 0x7F, 0xEF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, // MaxFloat64 + 0x48, 0xFF, 0xEF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, // -MaxFloat64 + 0x48, 0x7F, 0xF0, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // +inf + 0x48, 0xFF, 0xF0, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // -inf + 0x48, 0x7F, 0xF8, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, // NaN + } + testBinaryWriter(t, eval, func(w Writer) { + w.WriteFloat(0) + w.WriteFloat(math.MaxFloat64) + w.WriteFloat(-math.MaxFloat64) + w.WriteFloat(math.Inf(1)) + w.WriteFloat(math.Inf(-1)) + w.WriteFloat(math.NaN()) + }) +} + +func TestWriteBinaryBigInts(t *testing.T) { + eval := []byte{ + 0x20, // 0 + 0x21, 0xFF, // 0xFF + 0x31, 0xFF, // -0xFF + 0x2E, 0x90, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, // a really big integer + } + + testBinaryWriter(t, eval, func(w Writer) { + w.WriteBigInt(big.NewInt(0)) + w.WriteBigInt(big.NewInt(0xFF)) + w.WriteBigInt(big.NewInt(-0xFF)) + w.WriteBigInt(new(big.Int).SetBytes([]byte{1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0})) + }) +} + +func TestWriteBinaryInts(t *testing.T) { + eval := []byte{ + 0x20, // 0 + 0x21, 0xFF, // 0xFF + 0x31, 0xFF, // -0xFF + 0x22, 0xFF, 0xFF, // 0xFFFF + 0x33, 0xFF, 0xFF, 0xFF, // -0xFFFFFF + 0x28, 0x7F, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, // math.MaxInt64 + } + + testBinaryWriter(t, eval, func(w Writer) { + w.WriteInt(0) + w.WriteInt(0xFF) + w.WriteInt(-0xFF) + w.WriteInt(0xFFFF) + w.WriteInt(-0xFFFFFF) + w.WriteInt(math.MaxInt64) + }) +} + +func TestWriteBinaryBoolAnnotated(t *testing.T) { + eval := []byte{ + 0xE4, // 4-byte annotated value + 0x82, // 2 bytes of annotations + 0x84, // $4 (name) + 0x85, // $5 (version) + 0x10, // false + } + + testBinaryWriter(t, eval, func(w Writer) { + w.TypeAnnotations("name", "version") + w.WriteBool(false) + }) +} + +func TestWriteBinaryBools(t *testing.T) { + eval := []byte{ + 0x10, // false + 0x11, // true + } + + testBinaryWriter(t, eval, func(w Writer) { + w.WriteBool(false) + w.WriteBool(true) + }) +} + +func TestWriteBinaryNulls(t *testing.T) { + eval := []byte{ + 0x0F, + 0x1F, + 0x2F, + // 0x3F, // negative integer, not actually valid + 0x4F, + 0x5F, + 0x6F, + 0x7F, + 0x8F, + 0x9F, + 0xAF, + 0xBF, + 0xCF, + 0xDF, + } + + testBinaryWriter(t, eval, func(w Writer) { + w.WriteNull() + w.WriteNullWithType(BoolType) + w.WriteNullWithType(IntType) + w.WriteNullWithType(FloatType) + w.WriteNullWithType(DecimalType) + w.WriteNullWithType(TimestampType) + w.WriteNullWithType(SymbolType) + w.WriteNullWithType(StringType) + w.WriteNullWithType(ClobType) + w.WriteNullWithType(BlobType) + w.WriteNullWithType(ListType) + w.WriteNullWithType(SexpType) + w.WriteNullWithType(StructType) + }) +} + +func testBinaryWriter(t *testing.T, eval []byte, f func(w Writer)) { + val := writeBinary(t, f) + + prefix := []byte{ + 0xE0, 0x01, 0x00, 0xEA, // $ion_1_0 + 0xEE, 0x9F, 0x81, 0x83, 0xDE, 0x9B, // $ion_symbol_table::{ + 0x86, 0xBE, 0x8E, // imports:[ + 0xDD, // { + 0x84, 0x85, 'b', 'o', 'g', 'u', 's', // name: "bogus" + 0x85, 0x21, 0x2A, // version: 42 + 0x88, 0x21, 0x64, // max_id: 100 + // }] + 0x87, 0xB8, // symbols: [ + 0x83, 'f', 'o', 'o', // "foo" + 0x83, 'b', 'a', 'r', // "bar" + // ] + // } + } + eval = append(prefix, eval...) + + if !bytes.Equal(val, eval) { + t.Errorf("expected %v, got %v", fmtbytes(eval), fmtbytes(val)) + } +} + +func fmtbytes(bs []byte) string { + buf := strings.Builder{} + buf.WriteByte('[') + for i, b := range bs { + if i > 0 { + buf.WriteByte(' ') + } + buf.WriteString(hex.EncodeToString([]byte{b})) + } + buf.WriteByte(']') + return buf.String() +} + +func writeBinary(t *testing.T, f func(w Writer)) []byte { + bogusSyms := []string{} + for i := 0; i < 100; i++ { + bogusSyms = append(bogusSyms, fmt.Sprintf("bogus_sym_%v", i)) + } + + bogus := []SharedSymbolTable{ + NewSharedSymbolTable("bogus", 42, bogusSyms), + } + + buf := bytes.Buffer{} + w := NewBinaryWriter(&buf, NewLocalSymbolTable(bogus, []string{ + "foo", + "bar", + })) + + f(w) + + if w.Err() != nil { + t.Fatal(w.Err()) + } + + return buf.Bytes() +} diff --git a/bits.go b/bits.go new file mode 100644 index 00000000..2209d50a --- /dev/null +++ b/bits.go @@ -0,0 +1,184 @@ +package ion + +import ( + "bytes" + "math/big" + "time" +) + +func uintLen(v uint64) uint64 { + len := uint64(1) + v >>= 8 + + for v > 0 { + len++ + v >>= 8 + } + + return len +} + +func packUint(v uint64) []byte { + var buf [8]byte + + i := 7 + buf[i] = byte(v & 0xFF) + v >>= 8 + + for v > 0 { + i-- + buf[i] = byte(v & 0xFF) + v >>= 8 + } + + return buf[i:] +} + +func packInt(n int64) []byte { + if n == 0 { + return []byte{} + } + + neg := false + mag := uint64(n) + + if n < 0 { + neg = true + mag = uint64(-n) + } + + bits := packUint(mag) + if bits[0]&0x80 != 0 { + bits = append([]byte{0}, bits...) + } + + if neg { + bits[0] ^= 0x80 + } + + return bits +} + +func packBigInt(v *big.Int) []byte { + sign := v.Sign() + if sign == 0 { + return []byte{} + } + + bits := v.Bytes() + + if bits[0]&0x80 != 0 { + // Need to make room for the sign bit. + bits = append([]byte{0}, bits...) + } + + if sign < 0 { + bits[0] ^= 0x80 + } + + return bits +} + +func varUintLen(v uint64) uint64 { + len := uint64(1) + v >>= 7 + + for v > 0 { + len++ + v >>= 7 + } + + return len +} + +func packVarUint(v uint64) []byte { + var buf [10]byte + + i := 9 + buf[i] = 0x80 | byte(v&0x7F) + v >>= 7 + + for v > 0 { + i-- + buf[i] = byte(v & 0x7F) + v >>= 7 + } + + return buf[i:] +} + +func varIntLen(v int64) uint64 { + mag := uint64(v) + if v < 0 { + mag = uint64(-v) + } + + // Reserve one extra bit of the first byte for sign. + len := uint64(1) + mag >>= 6 + + for mag > 0 { + len++ + mag >>= 7 + } + + return len +} + +func packVarInt(v int64) []byte { + var buf [10]byte + + signbit := byte(0) + mag := uint64(v) + if v < 0 { + signbit = 0x40 + mag = uint64(-v) + } + + next := mag >> 6 + if next == 0 { + // The whole thing fits in one byte. + return []byte{0x80 | signbit | byte(mag&0x3F)} + } + + i := 9 + buf[i] = 0x80 | byte(mag&0x7F) + mag >>= 7 + next = mag >> 6 + + for next > 0 { + i-- + buf[i] = byte(mag & 0x7F) + mag >>= 7 + next = mag >> 6 + } + + i-- + buf[i] = signbit | byte(mag&0x3F) + + return buf[i:] +} + +func packTime(t time.Time) []byte { + _, offset := t.Zone() + utc := t.In(time.UTC) + + buf := bytes.Buffer{} + buf.Write(packVarInt(int64(offset / 60))) + + buf.Write(packVarUint(uint64(utc.Year()))) + buf.Write(packVarUint(uint64(utc.Month()))) + buf.Write(packVarUint(uint64(utc.Day()))) + + buf.Write(packVarUint(uint64(utc.Hour()))) + buf.Write(packVarUint(uint64(utc.Minute()))) + buf.Write(packVarUint(uint64(utc.Second()))) + + ns := utc.Nanosecond() + if ns > 0 { + buf.Write(packVarInt(-9)) + buf.Write(packInt(int64(ns))) + } + + return buf.Bytes() +} diff --git a/bits_test.go b/bits_test.go new file mode 100644 index 00000000..24a6346c --- /dev/null +++ b/bits_test.go @@ -0,0 +1,138 @@ +package ion + +import ( + "bytes" + "fmt" + "math" + "math/big" + "testing" +) + +func TestPackUint(t *testing.T) { + test := func(val uint64, elen uint64, ebits []byte) { + t.Run(fmt.Sprintf("%x", val), func(t *testing.T) { + len := uintLen(val) + if len != elen { + t.Errorf("expected len=%v, got len=%v", elen, len) + } + + bits := packUint(val) + if !bytes.Equal(bits, ebits) { + t.Errorf("expected %v, got %v", fmtbytes(ebits), fmtbytes(bits)) + } + }) + } + + test(0, 1, []byte{0}) + test(0xFF, 1, []byte{0xFF}) + test(0x1FF, 2, []byte{0x01, 0xFF}) + test(math.MaxUint64, 8, []byte{0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF}) +} + +func TestPackInt(t *testing.T) { + test := func(val int64, ebits []byte) { + t.Run(fmt.Sprintf("%x", val), func(t *testing.T) { + bits := packInt(val) + if !bytes.Equal(bits, ebits) { + t.Errorf("expected %v, got %v", fmtbytes(ebits), fmtbytes(bits)) + } + }) + } + + test(0, []byte{}) + test(0x7F, []byte{0x7F}) + test(-0x7F, []byte{0xFF}) + + test(0xFF, []byte{0x00, 0xFF}) + test(-0xFF, []byte{0x80, 0xFF}) + + test(0x7FFF, []byte{0x7F, 0xFF}) + test(-0x7FFF, []byte{0xFF, 0xFF}) + + test(math.MaxInt64, []byte{0x7F, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF}) + test(-math.MaxInt64, []byte{0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF}) + test(math.MinInt64, []byte{0x80, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}) +} + +func TestPackBigInt(t *testing.T) { + test := func(val *big.Int, ebits []byte) { + t.Run(fmt.Sprintf("%x", val), func(t *testing.T) { + bits := packBigInt(val) + if !bytes.Equal(bits, ebits) { + t.Errorf("expected %v, got %v", fmtbytes(ebits), fmtbytes(bits)) + } + }) + } + + test(big.NewInt(0), []byte{}) + test(big.NewInt(0x7F), []byte{0x7F}) + test(big.NewInt(-0x7F), []byte{0xFF}) + + test(big.NewInt(0xFF), []byte{0x00, 0xFF}) + test(big.NewInt(-0xFF), []byte{0x80, 0xFF}) + + test(big.NewInt(0x7FFF), []byte{0x7F, 0xFF}) + test(big.NewInt(-0x7FFF), []byte{0xFF, 0xFF}) +} + +func TestPackVarUint(t *testing.T) { + test := func(val uint64, elen uint64, ebits []byte) { + t.Run(fmt.Sprintf("%x", val), func(t *testing.T) { + len := varUintLen(val) + if len != elen { + t.Errorf("expected len=%v, got len=%v", elen, len) + } + + bits := packVarUint(val) + if !bytes.Equal(bits, ebits) { + t.Errorf("expected %v, got %v", fmtbytes(ebits), fmtbytes(bits)) + } + }) + } + + test(0, 1, []byte{0x80}) + test(0x7F, 1, []byte{0xFF}) + test(0xFF, 2, []byte{0x01, 0xFF}) + test(0x1FF, 2, []byte{0x03, 0xFF}) + test(0x3FFF, 2, []byte{0x7F, 0xFF}) + test(0x7FFF, 3, []byte{0x01, 0x7F, 0xFF}) + test(0x7FFFFFFFFFFFFFFF, 9, []byte{0x7F, 0x7F, 0x7F, 0x7F, 0x7F, 0x7F, 0x7F, 0x7F, 0xFF}) + test(0xFFFFFFFFFFFFFFFF, 10, []byte{0x01, 0x7F, 0x7F, 0x7F, 0x7F, 0x7F, 0x7F, 0x7F, 0x7F, 0xFF}) +} + +func TestPackVarInt(t *testing.T) { + test := func(val int64, elen uint64, ebits []byte) { + t.Run(fmt.Sprintf("%x", val), func(t *testing.T) { + len := varIntLen(val) + if len != elen { + t.Errorf("expected len=%v, got len=%v", elen, len) + } + + bits := packVarInt(val) + if !bytes.Equal(bits, ebits) { + t.Errorf("expected %v, got %v", fmtbytes(ebits), fmtbytes(bits)) + } + }) + } + + test(0, 1, []byte{0x80}) + + test(0x3F, 1, []byte{0xBF}) // 1011 1111 + test(-0x3F, 1, []byte{0xFF}) + + test(0x7F, 2, []byte{0x00, 0xFF}) + test(-0x7F, 2, []byte{0x40, 0xFF}) + + test(0x1FFF, 2, []byte{0x3F, 0xFF}) + test(-0x1FFF, 2, []byte{0x7F, 0xFF}) + + test(0x3FFF, 3, []byte{0x00, 0x7F, 0xFF}) + test(-0x3FFF, 3, []byte{0x40, 0x7F, 0xFF}) + + test(0x3FFFFFFFFFFFFFFF, 9, []byte{0x3F, 0x7F, 0x7F, 0x7F, 0x7F, 0x7F, 0x7F, 0x7F, 0xFF}) + test(-0x3FFFFFFFFFFFFFFF, 9, []byte{0x7F, 0x7F, 0x7F, 0x7F, 0x7F, 0x7F, 0x7F, 0x7F, 0xFF}) + + test(math.MaxInt64, 10, []byte{0x00, 0x7F, 0x7F, 0x7F, 0x7F, 0x7F, 0x7F, 0x7F, 0x7F, 0xFF}) + test(-math.MaxInt64, 10, []byte{0x40, 0x7F, 0x7F, 0x7F, 0x7F, 0x7F, 0x7F, 0x7F, 0x7F, 0xFF}) + test(math.MinInt64, 10, []byte{0x41, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x80}) +} diff --git a/bufnode.go b/bufnode.go new file mode 100644 index 00000000..befeabe7 --- /dev/null +++ b/bufnode.go @@ -0,0 +1,94 @@ +package ion + +import ( + "io" +) + +// Writing binary ion is a bit tricky: values are preceded by their length, +// which can be hard to predict until we've actually written out the value. +// To make matters worse, we can't predict the length of the /length/ ahead +// of time in order to reserve space for it, because it uses a variable-length +// encoding. To avoid copying bytes around all over the place, we write into +// an in-memory tree structure, which we then blast out to the actual io.Writer +// once all the relevant lengths are known. + +// A bufnode is a node in the partially-serialized tree. +type bufnode interface { + Len() uint64 + WriteTo(w io.Writer) error +} + +// An atom is a value that has been fully serialized and can be written directly. +type atom []byte + +func (a atom) Len() uint64 { + return uint64(len(a)) +} + +func (a atom) WriteTo(w io.Writer) error { + _, err := w.Write(a) + return err +} + +// A fieldname is the symbol id of a field name inside a struct. +type fieldname uint64 + +func (f fieldname) Len() uint64 { + return varUintLen(uint64(f)) +} + +func (f fieldname) WriteTo(w io.Writer) error { + _, err := w.Write(packVarUint(uint64(f))) + return err +} + +// A container holds multiple child values and serializes them together with a +// tag and length on demand. +type container struct { + code byte + len uint64 + children []bufnode +} + +func (c *container) Add(n bufnode) { + c.len += n.Len() + c.children = append(c.children, n) +} + +func (c *container) Len() uint64 { + if c.len < 0x0E { + // Short tag + return c.len + 1 + } + // Long tag. + return c.len + varUintLen(c.len) + 1 +} + +func (c *container) WriteTo(w io.Writer) error { + if err := writeTag(w, c.code, c.len); err != nil { + return nil + } + + for _, child := range c.children { + if err := child.WriteTo(w); err != nil { + return err + } + } + + return nil +} + +func writeTag(w io.Writer, code byte, len uint64) error { + if len < 0x0E { + // Short form, with length embedded in code byte. + _, err := w.Write([]byte{code | byte(len)}) + return err + } + + // Long form, with separate length. + if _, err := w.Write([]byte{code | 0x0E}); err != nil { + return err + } + _, err := w.Write(packVarUint(uint64(len))) + return err +} diff --git a/bufnode_test.go b/bufnode_test.go new file mode 100644 index 00000000..f2ffe669 --- /dev/null +++ b/bufnode_test.go @@ -0,0 +1,79 @@ +package ion + +import ( + "bytes" + "testing" +) + +func TestBufnode(t *testing.T) { + root := container{code: 0xE0} + root.Add(atom([]byte{0x81, 0x83})) + { + symtab := &container{code: 0xD0} + { + symtab.Add(fieldname(6)) + { + imps := &container{code: 0xB0} + { + imp0 := &container{code: 0xD0} + { + imp0.Add(fieldname(4)) + imp0.Add(atom([]byte{0x85, 'b', 'o', 'g', 'u', 's'})) + imp0.Add(fieldname(5)) + imp0.Add(atom([]byte{0x21, 0x2A})) + imp0.Add(fieldname(8)) + imp0.Add(atom([]byte{0x21, 0x64})) + } + imps.Add(imp0) + } + symtab.Add(imps) + } + + symtab.Add(fieldname(7)) + { + syms := &container{code: 0xB0} + { + syms.Add(atom([]byte{0x83, 'f', 'o', 'o'})) + syms.Add(atom([]byte{0x83, 'b', 'a', 'r'})) + } + symtab.Add(syms) + } + } + root.Add(symtab) + } + + buf := bytes.Buffer{} + if err := root.WriteTo(&buf); err != nil { + t.Fatal(err) + } + + val := buf.Bytes() + eval := []byte{ + // $ion_symbol_table::{ + 0xEE, 0x9F, 0x81, 0x83, 0xDE, 0x9B, + // imports:[ + 0x86, 0xBE, 0x8E, + // { + 0xDD, + // name: "bogus" + 0x84, 0x85, 'b', 'o', 'g', 'u', 's', + // version: 42 + 0x85, 0x21, 0x2A, + // max_id: 100 + 0x88, 0x21, 0x64, + // } + // ], + // symbols:[ + 0x87, 0xB8, + // "foo", + 0x83, 'f', 'o', 'o', + // "bar" + 0x83, 'b', 'a', 'r', + // ] + // } + } + + if !bytes.Equal(val, eval) { + t.Errorf("expected %v, got %v", fmtbytes(eval), fmtbytes(val)) + } +} diff --git a/ctx.go b/ctx.go index f92b2077..572e1d54 100644 --- a/ctx.go +++ b/ctx.go @@ -14,7 +14,7 @@ type ctx struct { stack []ctxType } -// peek returns the current context +// peek returns the current context. func (c *ctx) peek() ctxType { if len(c.stack) == 0 { return ctxAtTopLevel diff --git a/decimal.go b/decimal.go index cf4cf223..3f6fc68e 100644 --- a/decimal.go +++ b/decimal.go @@ -17,18 +17,11 @@ type Decimal struct { scale int } -// NewDecimal creates a new decimal whose value is equal to the given -// (big) integer. -func NewDecimal(n *big.Int) *Decimal { - return NewDecimalWithScale(n, 0) -} - -// NewDecimalWithScale creates a new scaled decimal whose value is -// equal to n * 10^-scale. -func NewDecimalWithScale(n *big.Int, scale int) *Decimal { +// NewDecimal creates a new decimal whose value is equal to n * 10^exp. +func NewDecimal(n *big.Int, exp int) *Decimal { return &Decimal{ n: n, - scale: scale, + scale: -exp, } } @@ -49,7 +42,7 @@ func ParseDecimal(in string) (*Decimal, error) { return nil, errors.New("empty string") } - shift := 0 + exponent := 0 d := strings.IndexAny(in, "Dd") if d != -1 { @@ -64,7 +57,7 @@ func ParseDecimal(in string) (*Decimal, error) { return nil, err } - shift = int(tmp) + exponent = int(tmp) in = in[:d] } @@ -74,7 +67,7 @@ func ParseDecimal(in string) (*Decimal, error) { ipart := in[:d] fpart := in[d+1:] - shift -= len(fpart) + exponent -= len(fpart) in = ipart + fpart } @@ -84,7 +77,12 @@ func ParseDecimal(in string) (*Decimal, error) { return nil, fmt.Errorf("not a valid number: %v", in) } - return NewDecimalWithScale(n, -shift), nil + return NewDecimal(n, exponent), nil +} + +// CoEx returns this decimal's coefficient and exponent. +func (d *Decimal) CoEx() (*big.Int, int) { + return d.n, -d.scale } // Abs returns the absolute value of this Decimal. @@ -168,6 +166,12 @@ func (d *Decimal) ShiftR(shift int) *Decimal { // TODO: Div, Exp, etc? +// Sign returns -1 if the value is less than 0, 0 if it is equal to zero, +// and +1 if it is greater than zero. +func (d *Decimal) Sign() int { + return d.n.Sign() +} + // Cmp compares two decimals, returning -1 if d is smaller, +1 if d is // larger, and 0 if they are equal (ignoring precision). func (d *Decimal) Cmp(o *Decimal) int { diff --git a/symboltable.go b/symboltable.go index f386c358..94403943 100644 --- a/symboltable.go +++ b/symboltable.go @@ -163,9 +163,15 @@ func NewLocalSymbolTable(imports []SharedSymbolTable, symbols []string) SymbolTa } func processImports(imports []SharedSymbolTable) ([]SharedSymbolTable, []int, int) { - imps := append([]SharedSymbolTable{}, imports...) - - // TODO: Automatically add V1SystemSymbolTable? + var imps []SharedSymbolTable + if len(imports) > 0 && imports[0].Name() == "$ion" { + imps = make([]SharedSymbolTable, len(imports)) + copy(imps, imports) + } else { + imps = make([]SharedSymbolTable, len(imports)+1) + imps[0] = V1SystemSymbolTable + copy(imps[1:], imports) + } maxID := 0 offsets := make([]int, len(imps)) @@ -230,10 +236,11 @@ func (t *localSymbolTable) WriteTo(w Writer) error { w.TypeAnnotation("$ion_symbol_table") w.BeginStruct() - if len(t.imports) > 0 { + if len(t.imports) > 1 { w.FieldName("imports") w.BeginList() - for _, imp := range t.imports { + for i := 1; i < len(t.imports); i++ { + imp := t.imports[i] w.BeginStruct() w.FieldName("name") diff --git a/symboltable_test.go b/symboltable_test.go index bcadca36..57d50669 100644 --- a/symboltable_test.go +++ b/symboltable_test.go @@ -40,37 +40,46 @@ func TestSharedSymbolTable(t *testing.T) { func TestLocalSymbolTable(t *testing.T) { st := NewLocalSymbolTable(nil, []string{"foo", "bar"}) - if st.MaxID() != 2 { + if st.MaxID() != 11 { t.Errorf("wrong maxid: %v", st.MaxID()) } - testFindByName(t, st, "$ion", 0) - testFindByName(t, st, "foo", 1) - testFindByName(t, st, "bar", 2) + testFindByName(t, st, "$ion", 1) + testFindByName(t, st, "foo", 10) + testFindByName(t, st, "bar", 11) testFindByName(t, st, "bogus", 0) testFindByID(t, st, 0, "") - testFindByID(t, st, 1, "foo") - testFindByID(t, st, 2, "bar") - testFindByID(t, st, 3, "") + testFindByID(t, st, 1, "$ion") + testFindByID(t, st, 10, "foo") + testFindByID(t, st, 11, "bar") + testFindByID(t, st, 12, "") testString(t, st, `$ion_symbol_table::{symbols:["foo","bar"]}`) } func TestLocalSymbolTableWithImports(t *testing.T) { - imports := []SharedSymbolTable{V1SystemSymbolTable} - st := NewLocalSymbolTable(imports, []string{ + shared := NewSharedSymbolTable("shared", 1, []string{ "foo", "bar", }) + imports := []SharedSymbolTable{shared} - if st.MaxID() != 11 { // 9 from $ion.1, 2 local. + st := NewLocalSymbolTable(imports, []string{ + "foo2", + "bar2", + }) + + if st.MaxID() != 13 { // 9 from $ion.1, 2 from test.1, 2 local. t.Errorf("wrong maxid: %v", st.MaxID()) } testFindByName(t, st, "$ion", 1) testFindByName(t, st, "$ion_shared_symbol_table", 9) testFindByName(t, st, "foo", 10) + testFindByName(t, st, "bar", 11) + testFindByName(t, st, "foo2", 12) + testFindByName(t, st, "bar2", 13) testFindByName(t, st, "bogus", 0) testFindByID(t, st, 0, "") @@ -78,13 +87,15 @@ func TestLocalSymbolTableWithImports(t *testing.T) { testFindByID(t, st, 9, "$ion_shared_symbol_table") testFindByID(t, st, 10, "foo") testFindByID(t, st, 11, "bar") - testFindByID(t, st, 12, "") + testFindByID(t, st, 12, "foo2") + testFindByID(t, st, 13, "bar2") + testFindByID(t, st, 14, "") - testString(t, st, `$ion_symbol_table::{imports:[{name:"$ion",version:1,max_id:9}],symbols:["foo","bar"]}`) + testString(t, st, `$ion_symbol_table::{imports:[{name:"shared",version:1,max_id:2}],symbols:["foo2","bar2"]}`) } func TestSymbolTableBuilder(t *testing.T) { - b := NewSymbolTableBuilder(V1SystemSymbolTable) + b := NewSymbolTableBuilder() id, ok := b.Add("name") if ok { diff --git a/textwriter_test.go b/textwriter_test.go index c93d8dce..9228d430 100644 --- a/textwriter_test.go +++ b/textwriter_test.go @@ -174,7 +174,7 @@ func TestBool(t *testing.T) { }) } -func TestInt(t *testing.T) { +func TestWriteTextInt(t *testing.T) { expected := "(zero::0 1 -1 (9223372036854775807 -9223372036854775808))" testTextWriter(t, expected, func(w Writer) { w.BeginSexp() @@ -193,7 +193,7 @@ func TestInt(t *testing.T) { }) } -func TestBigInt(t *testing.T) { +func TestWriteTextBigInt(t *testing.T) { expected := "[0,big::18446744073709551616]" testTextWriter(t, expected, func(w Writer) { w.BeginList() From a9ee27bdba3f7aae4f4a12f240b6ede7651242b1 Mon Sep 17 00:00:00 2001 From: David Murray Date: Tue, 6 Aug 2019 21:15:29 +1000 Subject: [PATCH 28/56] stuff and stuff --- binarywriter.go | 6 +++++- marshal.go | 23 +++++++++++++++++++++++ marshal_test.go | 19 +++++++++++++++++++ symboltable.go | 4 ++++ textwriter.go | 2 +- 5 files changed, 52 insertions(+), 2 deletions(-) diff --git a/binarywriter.go b/binarywriter.go index 4080718a..4a163391 100644 --- a/binarywriter.go +++ b/binarywriter.go @@ -499,7 +499,11 @@ func (w *binaryWriterLST) WriteBlob(val []byte) { } func (w *binaryWriterLST) WriteValue(val interface{}) { - w.err = errors.New("not yet implemented") + m := Encoder{ + w: w, + sortMaps: true, + } + w.err = m.Encode(val) } func (w *binaryWriterLST) Finish() error { diff --git a/marshal.go b/marshal.go index 9cc7967d..5e4da9c1 100644 --- a/marshal.go +++ b/marshal.go @@ -29,6 +29,21 @@ func MarshalText(v interface{}) ([]byte, error) { return buf.Bytes(), nil } +// MarshalBinary marshals values to binary ion. +func MarshalBinary(v interface{}, lst SymbolTable) ([]byte, error) { + buf := bytes.Buffer{} + m := NewBinaryEncoder(&buf, lst) + + if err := m.Encode(v); err != nil { + return nil, err + } + if err := m.Finish(); err != nil { + return nil, err + } + + return buf.Bytes(), nil +} + // An Encoder writes Ion values to an output stream. type Encoder struct { w Writer @@ -50,6 +65,14 @@ func NewTextEncoder(w io.Writer) *Encoder { } } +// NewBinaryEncoder creates a new Encoder that marshals binary Ion to the given writer. +func NewBinaryEncoder(w io.Writer, lst SymbolTable) *Encoder { + return &Encoder{ + w: NewBinaryWriter(w, lst), + sortMaps: false, + } +} + // Encode marshals the given value to Ion, writing it to the underlying writer. func (m *Encoder) Encode(v interface{}) error { return m.marshalValue(reflect.ValueOf(v)) diff --git a/marshal_test.go b/marshal_test.go index 85703933..abe80e25 100644 --- a/marshal_test.go +++ b/marshal_test.go @@ -1,6 +1,7 @@ package ion import ( + "bytes" "math" "testing" "time" @@ -65,6 +66,24 @@ func TestMarshalText(t *testing.T) { test(struct{ V [2]byte }{[2]byte{4, 2}}, "{V:[4,2]}") } +func TestMarshalBinary(t *testing.T) { + lst := NewLocalSymbolTable(nil, nil) + + test := func(v interface{}, name string, eval []byte) { + t.Run(name, func(t *testing.T) { + val, err := MarshalBinary(v, lst) + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(val, eval) { + t.Errorf("expected '%v', got '%v'", fmtbytes(eval), fmtbytes(val)) + } + }) + } + + test(nil, "null", []byte{0xE0, 0x01, 0x00, 0xEA, 0x0F}) +} + func TestMarshalNestedStructs(t *testing.T) { type gp struct { A int `json:"a"` diff --git a/symboltable.go b/symboltable.go index 94403943..b3330481 100644 --- a/symboltable.go +++ b/symboltable.go @@ -233,6 +233,10 @@ func (t *localSymbolTable) findByIDInImports(id int) (string, bool) { } func (t *localSymbolTable) WriteTo(w Writer) error { + if len(t.imports) == 1 && len(t.symbols) == 0 { + return nil + } + w.TypeAnnotation("$ion_symbol_table") w.BeginStruct() diff --git a/textwriter.go b/textwriter.go index 97da5eee..0230537e 100644 --- a/textwriter.go +++ b/textwriter.go @@ -414,7 +414,7 @@ func (w *textWriter) WriteValue(val interface{}) { w: w, sortMaps: true, } - m.Encode(val) + w.err = m.Encode(val) } // Finish finishes the current datagram. From c5e0d4ce48d275874c684feb79f99e1a4e7312cd Mon Sep 17 00:00:00 2001 From: David Murray Date: Fri, 9 Aug 2019 17:22:25 +1000 Subject: [PATCH 29/56] refactoring binary writer --- api.go | 12 +- binarywriter.go | 530 +++++++++++++++++++++++-------------------- binarywriter_test.go | 2 +- bits.go | 181 +++++++++++---- bits_test.go | 125 +++++++--- bufnode.go | 107 +++++---- bufnode_test.go | 32 +-- marshal.go | 2 +- 8 files changed, 611 insertions(+), 380 deletions(-) diff --git a/api.go b/api.go index 4e17a1c6..3b4270b6 100644 --- a/api.go +++ b/api.go @@ -24,20 +24,20 @@ const ( DecimalType // TimestampType is the type of a timestamp. TimestampType - // StringType is the type of a Unicode string. - StringType // SymbolType is the type of an interned string. SymbolType - // BlobType is the type of a binary large object. - BlobType + // StringType is the type of a Unicode string. + StringType // ClobType is the type of a character large object. ClobType - // StructType is the type of a structure. - StructType + // BlobType is the type of a binary large object. + BlobType // ListType is the type of a list. ListType // SexpType is the type of an s-expression. SexpType + // StructType is the type of a structure. + StructType ) func (t Type) String() string { diff --git a/binarywriter.go b/binarywriter.go index 4a163391..d633bb5c 100644 --- a/binarywriter.go +++ b/binarywriter.go @@ -1,7 +1,6 @@ package ion import ( - "bytes" "encoding/binary" "errors" "fmt" @@ -11,40 +10,32 @@ import ( "time" ) -// A cstack is a stack of containers. -type cstack struct { - arr []*container -} +// A binaryWriter writes binary ion. +type binaryWriter struct { + writer + bufs bufstack -func (c *cstack) peek() *container { - if len(c.arr) == 0 { - return nil - } - return c.arr[len(c.arr)-1] -} + lst SymbolTable + lstb SymbolTableBuilder -func (c *cstack) push(code byte) { - c.arr = append(c.arr, &container{code: code}) + wroteLST bool } -func (c *cstack) pop() { - if len(c.arr) == 0 { - panic("pop called at top level") +// NewBinaryWriter creates a new binary writer that will construct a +// local symbol table as it is written to. +func NewBinaryWriter(out io.Writer, sts ...SharedSymbolTable) Writer { + return &binaryWriter{ + writer: writer{ + out: out, + }, + lstb: NewSymbolTableBuilder(sts...), } - c.arr = c.arr[:len(c.arr)-1] } -type binaryWriterLST struct { - writer - cs cstack - lst SymbolTable - - wroteLST bool -} - -// NewBinaryWriter creates a new binary writer. -func NewBinaryWriter(out io.Writer, lst SymbolTable) Writer { - return &binaryWriterLST{ +// NewBinaryWriterLST creates a new binary writer with a pre-built local +// symbol table. +func NewBinaryWriterLST(out io.Writer, lst SymbolTable) Writer { + return &binaryWriter{ writer: writer{ out: out, }, @@ -52,42 +43,56 @@ func NewBinaryWriter(out io.Writer, lst SymbolTable) Writer { } } -func (w *binaryWriterLST) write(c bufnode) error { - p := w.cs.peek() - if p == nil { - return c.WriteTo(w.out) +// Emit emits the given node. If we're currently at the top level, that +// means actually emitting to the output stream. If not, we emit append +// to the current bufseq. +func (w *binaryWriter) emit(node bufnode) error { + s := w.bufs.peek() + if s == nil { + return node.EmitTo(w.out) } - p.Add(c) + s.Append(node) return nil } -func (w *binaryWriterLST) writeTag(code byte, len int) error { - buf := bytes.Buffer{} - writeTag(&buf, code, uint64(len)) - return w.write(atom(buf.Bytes())) +// Write emits the given bytes as an atom. +func (w *binaryWriter) write(bs []byte) error { + return w.emit(atom(bs)) } -func (w *binaryWriterLST) writeLST() error { - if _, err := w.out.Write([]byte{0xE0, 0x01, 0x00, 0xEA}); err != nil { - return err - } +// WriteTag writes out a type+length tag. Use me when you've already got the value to +// be written as a []byte and don't want to copy it. +func (w *binaryWriter) writeTag(code byte, len uint64) error { + tl := tagLen(len) - // Prevent recursion... - w.wroteLST = true + tag := make([]byte, tl) + tag = appendTag(tag, code, len) - return w.lst.WriteTo(w) + return w.write(tag) +} + +// WriteLST writes out a local symbol table. +func (w *binaryWriter) writeLST(lst SymbolTable) error { + if err := w.write([]byte{0xE0, 0x01, 0x00, 0xEA}); err != nil { + return err + } + return lst.WriteTo(w) } -func (w *binaryWriterLST) beginValue() error { - // Have to record/empty these before calling writeLST, which - // will end up modifying them. Ugh. +// BeginValue begins the process of writing a value by writing out +// its field name and annotations. +func (w *binaryWriter) beginValue() error { + // We have to record/empty these before calling writeLST, which + // will end up using/modifying them. Ugh. name := w.fieldName w.fieldName = "" tas := w.typeAnnotations w.typeAnnotations = nil - if !w.wroteLST { - if err := w.writeLST(); err != nil { + // If we have a local symbol table and haven't written it out yet, do that now. + if w.lst != nil && !w.wroteLST { + w.wroteLST = true + if err := w.writeLST(w.lst); err != nil { return err } } @@ -105,14 +110,14 @@ func (w *binaryWriterLST) beginValue() error { panic("negative id") } - if err := w.write(fieldname(id)); err != nil { + buf := make([]byte, 0, 10) + buf = appendVarUint(buf, uint64(id)) + if err := w.write(buf); err != nil { return err } } if len(tas) > 0 { - w.cs.push(0xE0) - ids := make([]uint64, len(tas)) idlen := uint64(0) @@ -128,14 +133,16 @@ func (w *binaryWriterLST) beginValue() error { idlen += varUintLen(uint64(id)) } - buf := bytes.Buffer{} - buf.Write(packVarUint(idlen)) + buflen := idlen + varUintLen(idlen) + buf := make([]byte, 0, buflen) + buf = appendVarUint(buf, idlen) for _, id := range ids { - buf.Write(packVarUint(id)) + buf = appendVarUint(buf, id) } - if err := w.write(atom(buf.Bytes())); err != nil { + w.bufs.push(&container{code: 0xE0}) + if err := w.write(buf); err != nil { return err } } @@ -143,32 +150,22 @@ func (w *binaryWriterLST) beginValue() error { return nil } -func (w *binaryWriterLST) endValue() error { - cur := w.cs.peek() - if cur != nil && cur.code == 0xE0 { - // If we're in an annotation container, write it up a level now that we - // know the length of the value. - w.cs.pop() - return w.write(cur) +// EndValue ends the process of writing a value by flushing it and its annotations +// up a level, if needed. +func (w *binaryWriter) endValue() error { + seq := w.bufs.peek() + if seq != nil { + if c, ok := seq.(*container); ok && c.code == 0xE0 { + w.bufs.pop() + return w.emit(seq) + } } return nil } -func (w *binaryWriterLST) writeValue(f func() []byte) error { - if err := w.beginValue(); err != nil { - return err - } - - val := f() - - if err := w.write(atom(val)); err != nil { - return err - } - - return w.endValue() -} - -func (w *binaryWriterLST) writeValueStreaming(f func() error) error { +// WriteValue writes an atomic value, invoking the given function to write the +// actual value contents. +func (w *binaryWriter) writeValue(f func() error) error { if err := w.beginValue(); err != nil { return err } @@ -180,26 +177,29 @@ func (w *binaryWriterLST) writeValueStreaming(f func() error) error { return w.endValue() } -func (w *binaryWriterLST) begin(t ctxType, code byte) error { +// BeginContainer begins writing a new container. +func (w *binaryWriter) beginContainer(t ctxType, code byte) error { if err := w.beginValue(); err != nil { return err } w.ctx.push(t) - w.cs.push(code) + w.bufs.push(&container{code: code}) return nil } -func (w *binaryWriterLST) end(t ctxType) error { +// EndContainer ends writing a container, emitting its buffered contents up +// a level in the stack. +func (w *binaryWriter) endContainer(t ctxType) error { if w.ctx.peek() != t { return errors.New("ion: not in that kind of container") } - cur := w.cs.peek() - if cur != nil { - w.cs.pop() - if err := w.write(cur); err != nil { + seq := w.bufs.peek() + if seq != nil { + w.bufs.pop() + if err := w.emit(seq); err != nil { return err } } @@ -211,115 +211,60 @@ func (w *binaryWriterLST) end(t ctxType) error { return w.endValue() } -func (w *binaryWriterLST) BeginStruct() { - if w.err != nil { - return - } - w.err = w.begin(ctxInStruct, 0xD0) -} - -func (w *binaryWriterLST) EndStruct() { - if w.err != nil { - return - } - w.err = w.end(ctxInStruct) -} - -func (w *binaryWriterLST) BeginList() { - if w.err != nil { - return - } - w.err = w.begin(ctxInList, 0xB0) -} - -func (w *binaryWriterLST) EndList() { - if w.err != nil { - return - } - w.err = w.end(ctxInList) -} - -func (w *binaryWriterLST) BeginSexp() { - if w.err != nil { - return - } - w.err = w.begin(ctxInSexp, 0xC0) -} - -func (w *binaryWriterLST) EndSexp() { - if w.err != nil { - return - } - w.err = w.end(ctxInSexp) -} - -func (w *binaryWriterLST) WriteNull() { +func (w *binaryWriter) WriteNull() { w.WriteNullWithType(NullType) } -func (w *binaryWriterLST) WriteNullWithType(t Type) { +var nulls = func() []byte { + ret := make([]byte, int(StructType)+1) + ret[NoType] = 0x0F + ret[NullType] = 0x0F + ret[BoolType] = 0x1F + ret[IntType] = 0x2F + ret[FloatType] = 0x4F + ret[DecimalType] = 0x5F + ret[TimestampType] = 0x6F + ret[SymbolType] = 0x7F + ret[StringType] = 0x8F + ret[ClobType] = 0x9F + ret[BlobType] = 0xAF + ret[ListType] = 0xBF + ret[SexpType] = 0xCF + ret[StructType] = 0xDF + return ret +}() + +func (w *binaryWriter) WriteNullWithType(t Type) { if w.err != nil { return } - w.err = w.writeValue(func() []byte { - var b byte - switch t { - case NoType, NullType: - b = 0x0F - case BoolType: - b = 0x1F - case IntType: - b = 0x2F - case FloatType: - b = 0x4F - case DecimalType: - b = 0x5F - case TimestampType: - b = 0x6F - case SymbolType: - b = 0x7F - case StringType: - b = 0x8F - case ClobType: - b = 0x9F - case BlobType: - b = 0xAF - case ListType: - b = 0xBF - case SexpType: - b = 0xCF - case StructType: - b = 0xDF - default: - panic(fmt.Sprintf("invalid type: %v", t)) - } - - return []byte{b} + w.err = w.writeValue(func() error { + return w.write([]byte{nulls[t]}) }) } -func (w *binaryWriterLST) WriteBool(val bool) { +func (w *binaryWriter) WriteBool(val bool) { if w.err != nil { return } - w.err = w.writeValue(func() []byte { + w.err = w.writeValue(func() error { if val { - return []byte{0x11} + return w.write([]byte{0x11}) } - return []byte{0x10} + return w.write([]byte{0x10}) }) } -func (w *binaryWriterLST) WriteInt(val int64) { +func (w *binaryWriter) WriteInt(val int64) { if w.err != nil { return } - w.err = w.writeValueStreaming(func() error { + w.err = w.writeValue(func() error { if val == 0 { - return w.write(atom([]byte{0x20})) + return w.write([]byte{0x20}) } code := byte(0x20) @@ -330,24 +275,26 @@ func (w *binaryWriterLST) WriteInt(val int64) { mag = uint64(-val) } - bs := packUint(mag) + len := uintLen(mag) + buflen := len + tagLen(len) - if err := w.writeTag(code, len(bs)); err != nil { - return err - } - return w.write(atom(bs)) + buf := make([]byte, 0, buflen) + buf = appendTag(buf, code, len) + buf = appendUint(buf, mag) + + return w.write(buf) }) } -func (w *binaryWriterLST) WriteBigInt(val *big.Int) { +func (w *binaryWriter) WriteBigInt(val *big.Int) { if w.err != nil { return } - w.err = w.writeValueStreaming(func() error { + w.err = w.writeValue(func() error { sign := val.Sign() if sign == 0 { - return w.write(atom([]byte{0x20})) + return w.write([]byte{0x20}) } code := byte(0x20) @@ -357,21 +304,31 @@ func (w *binaryWriterLST) WriteBigInt(val *big.Int) { bs := val.Bytes() - if err := w.writeTag(code, len(bs)); err != nil { + if bl := uint64(len(bs)); bl < 64 { + buflen := bl + tagLen(bl) + buf := make([]byte, 0, buflen) + + buf = appendTag(buf, code, bl) + buf = append(buf, bs...) + return w.write(buf) + } + + // no sense in copying, emit tag separately. + if err := w.writeTag(code, uint64(len(bs))); err != nil { return err } - return w.write(atom(bs)) + return w.write(bs) }) } -func (w *binaryWriterLST) WriteFloat(val float64) { +func (w *binaryWriter) WriteFloat(val float64) { if w.err != nil { return } - w.err = w.writeValue(func() []byte { + w.err = w.writeValue(func() error { if val == 0 { - return []byte{0x40} + return w.write([]byte{0x40}) } bs := make([]byte, 9) @@ -380,125 +337,146 @@ func (w *binaryWriterLST) WriteFloat(val float64) { bits := math.Float64bits(val) binary.BigEndian.PutUint64(bs[1:], bits) - return bs + return w.write(bs) }) } -func (w *binaryWriterLST) WriteDecimal(val *Decimal) { +func (w *binaryWriter) WriteDecimal(val *Decimal) { if w.err != nil { return } - w.writeValueStreaming(func() error { + w.writeValue(func() error { coef, exp := val.CoEx() - ebs := []byte{} + vlen := uint64(0) if exp != 0 { - ebs = packVarInt(int64(exp)) + vlen += varIntLen(int64(exp)) } - - cbs := packBigInt(coef) - - if err := w.writeTag(0x50, len(cbs)+len(ebs)); err != nil { - return err + if coef.Sign() != 0 { + vlen += bigIntLen(coef) } - if len(ebs) > 0 { - if err := w.write(atom(ebs)); err != nil { - return err - } - } + buflen := vlen + tagLen(vlen) + buf := make([]byte, 0, buflen) - if len(cbs) > 0 { - if err := w.write(atom(cbs)); err != nil { - return err - } + buf = appendTag(buf, 0x50, vlen) + if exp != 0 { + buf = appendVarInt(buf, int64(exp)) } + buf = appendBigInt(buf, coef) - return nil + return w.write(buf) }) } -func (w *binaryWriterLST) WriteTimestamp(val time.Time) { +func (w *binaryWriter) WriteTimestamp(val time.Time) { if w.err != nil { return } - w.err = w.writeValueStreaming(func() error { - bs := packTime(val) - if err := w.writeTag(0x60, len(bs)); err != nil { - return err - } - return w.write(atom(bs)) + w.err = w.writeValue(func() error { + _, offset := val.Zone() + offset /= 60 + utc := val.In(time.UTC) + + vlen := timeLen(offset, utc) + buflen := vlen + tagLen(vlen) + + buf := make([]byte, 0, buflen) + + buf = appendTag(buf, 0x60, vlen) + buf = appendTime(buf, offset, utc) + + return w.write(buf) }) } -func (w *binaryWriterLST) WriteSymbol(val string) { +func (w *binaryWriter) WriteSymbol(val string) { if w.err != nil { return } - id, ok := w.lst.FindByName(val) - if !ok { - w.err = fmt.Errorf("ion: symbol '%v' not defined in local symbol table", val) - return - } + id := 0 + ok := true - w.err = w.writeValueStreaming(func() error { - bs := packUint(uint64(id)) - if err := w.writeTag(0x70, len(bs)); err != nil { - return err + if w.lst != nil { + id, ok = w.lst.FindByName(val) + if !ok { + w.err = fmt.Errorf("ion: symbol '%v' not defined in local symbol table", val) + return } - return w.write(atom(bs)) + } else { + id, _ = w.lstb.Add(val) + } + + w.err = w.writeValue(func() error { + vlen := uintLen(uint64(id)) + buflen := vlen + tagLen(vlen) + buf := make([]byte, 0, buflen) + + buf = appendTag(buf, 0x70, vlen) + buf = appendUint(buf, uint64(id)) + + return w.write(buf) }) } -func (w *binaryWriterLST) WriteString(val string) { +func (w *binaryWriter) WriteString(val string) { if w.err != nil { return } - w.err = w.writeValueStreaming(func() error { + w.err = w.writeValue(func() error { if len(val) == 0 { - return w.write(atom([]byte{0x80})) + return w.write([]byte{0x80}) } - bs := []byte(val) + vlen := uint64(len(val)) + buflen := vlen + tagLen(vlen) + buf := make([]byte, 0, buflen) - if err := w.writeTag(0x80, len(bs)); err != nil { - return err - } - return w.write(atom(bs)) + buf = appendTag(buf, 0x80, vlen) + buf = append(buf, val...) + + return w.write(buf) }) } -func (w *binaryWriterLST) WriteClob(val []byte) { - if w.err != nil { - return - } +func (w *binaryWriter) WriteClob(val []byte) { + w.writeLob(0x90, val) +} - w.err = w.writeValueStreaming(func() error { - if err := w.writeTag(0x90, len(val)); err != nil { - return err - } - return w.write(atom(val)) - }) +func (w *binaryWriter) WriteBlob(val []byte) { + w.writeLob(0xA0, val) } -func (w *binaryWriterLST) WriteBlob(val []byte) { +func (w *binaryWriter) writeLob(code byte, val []byte) { if w.err != nil { return } - w.err = w.writeValueStreaming(func() error { - if err := w.writeTag(0xA0, len(val)); err != nil { + w.err = w.writeValue(func() error { + vlen := uint64(len(val)) + + if vlen < 64 { + buflen := vlen + tagLen(vlen) + buf := make([]byte, 0, buflen) + + buf = appendTag(buf, code, vlen) + buf = append(buf, val...) + + return w.write(buf) + } + + if err := w.writeTag(code, vlen); err != nil { return err } - return w.write(atom(val)) + return w.write(val) }) } -func (w *binaryWriterLST) WriteValue(val interface{}) { +func (w *binaryWriter) WriteValue(val interface{}) { m := Encoder{ w: w, sortMaps: true, @@ -506,7 +484,49 @@ func (w *binaryWriterLST) WriteValue(val interface{}) { w.err = m.Encode(val) } -func (w *binaryWriterLST) Finish() error { +func (w *binaryWriter) BeginList() { + if w.err != nil { + return + } + w.err = w.beginContainer(ctxInList, 0xB0) +} + +func (w *binaryWriter) EndList() { + if w.err != nil { + return + } + w.err = w.endContainer(ctxInList) +} + +func (w *binaryWriter) BeginSexp() { + if w.err != nil { + return + } + w.err = w.beginContainer(ctxInSexp, 0xC0) +} + +func (w *binaryWriter) EndSexp() { + if w.err != nil { + return + } + w.err = w.endContainer(ctxInSexp) +} + +func (w *binaryWriter) BeginStruct() { + if w.err != nil { + return + } + w.err = w.beginContainer(ctxInStruct, 0xD0) +} + +func (w *binaryWriter) EndStruct() { + if w.err != nil { + return + } + w.err = w.endContainer(ctxInStruct) +} + +func (w *binaryWriter) Finish() error { if w.err != nil { return w.err } @@ -515,7 +535,25 @@ func (w *binaryWriterLST) Finish() error { return w.err } - // TODO: Flush all them buffers mate! + w.fieldName = "" + w.typeAnnotations = nil + w.wroteLST = false + + seq := w.bufs.peek() + if seq != nil { + w.bufs.pop() + if w.bufs.peek() != nil { + panic("at top level but too many bufseqs") + } + + lst := w.lstb.Build() + if err := w.writeLST(lst); err != nil { + return err + } + if w.err = w.emit(seq); w.err != nil { + return w.err + } + } return nil } diff --git a/binarywriter_test.go b/binarywriter_test.go index 3ff1da6b..bcadadb1 100644 --- a/binarywriter_test.go +++ b/binarywriter_test.go @@ -347,7 +347,7 @@ func writeBinary(t *testing.T, f func(w Writer)) []byte { } buf := bytes.Buffer{} - w := NewBinaryWriter(&buf, NewLocalSymbolTable(bogus, []string{ + w := NewBinaryWriterLST(&buf, NewLocalSymbolTable(bogus, []string{ "foo", "bar", })) diff --git a/bits.go b/bits.go index 2209d50a..f3772702 100644 --- a/bits.go +++ b/bits.go @@ -1,11 +1,11 @@ package ion import ( - "bytes" "math/big" "time" ) +// uintLen pre-calculates the length, in bytes, of the given uint value. func uintLen(v uint64) uint64 { len := uint64(1) v >>= 8 @@ -18,7 +18,9 @@ func uintLen(v uint64) uint64 { return len } -func packUint(v uint64) []byte { +// appendUint appends a uint value to the given slice. The reader is +// expected to know how many bytes the value takes up. +func appendUint(b []byte, v uint64) []byte { var buf [8]byte i := 7 @@ -31,12 +33,36 @@ func packUint(v uint64) []byte { v >>= 8 } - return buf[i:] + return append(b, buf[i:]...) } -func packInt(n int64) []byte { +// intLen pre-calculates the length, in bytes, of the given int value. +func intLen(n int64) uint64 { if n == 0 { - return []byte{} + return 0 + } + + mag := uint64(n) + if n < 0 { + mag = uint64(-n) + } + + len := uintLen(mag) + + // If the high bit is a one, we need an extra byte to store the sign bit. + hb := mag >> ((len - 1) * 8) + if hb&0x80 != 0 { + len++ + } + + return len +} + +// appendInt appends a (signed) int to the given slice. The reader is +// expected to know how many bytes the value takes up. +func appendInt(b []byte, n int64) []byte { + if n == 0 { + return b } neg := false @@ -47,38 +73,70 @@ func packInt(n int64) []byte { mag = uint64(-n) } - bits := packUint(mag) - if bits[0]&0x80 != 0 { - bits = append([]byte{0}, bits...) + var buf [8]byte + bits := buf[:0] + bits = appendUint(bits, mag) + + if bits[0]&0x80 == 0 { + // We've got space we can use for the sign bit. + if neg { + bits[0] ^= 0x80 + } + } else { + // We need to add more space. + bit := byte(0) + if neg { + bit = 0x80 + } + b = append(b, bit) } - if neg { - bits[0] ^= 0x80 + return append(b, bits...) +} + +// bigIntLen pre-calculates the length, in bytes, of the given big.Int value. +func bigIntLen(v *big.Int) uint64 { + if v.Sign() == 0 { + return 0 } - return bits + bitl := v.BitLen() + bytel := bitl / 8 + + // Either bitl is evenly divisibly by 8, in which case we need another + // byte for the sign bit, or its not in which case we need to round up + // (but will then have room for the sign bit). + return uint64(bytel) + 1 } -func packBigInt(v *big.Int) []byte { +// appendBigInt appends a (signed) big.Int to the given slice. The reader is +// expected to know how many bytes the value takes up. +func appendBigInt(b []byte, v *big.Int) []byte { sign := v.Sign() if sign == 0 { - return []byte{} + return b } bits := v.Bytes() - if bits[0]&0x80 != 0 { - // Need to make room for the sign bit. - bits = append([]byte{0}, bits...) - } - - if sign < 0 { - bits[0] ^= 0x80 + if bits[0]&0x80 == 0 { + // We've got space we can use for the sign bit. + if sign < 0 { + bits[0] ^= 0x80 + } + } else { + // We need to add more space. + bit := byte(0) + if sign < 0 { + bit = 0x80 + } + b = append(b, bit) } - return bits + return append(b, bits...) } +// varUintLen pre-calculates the length, in bytes, of the given varUint value. func varUintLen(v uint64) uint64 { len := uint64(1) v >>= 7 @@ -91,7 +149,10 @@ func varUintLen(v uint64) uint64 { return len } -func packVarUint(v uint64) []byte { +// appendVarUint appends a variable-length-encoded uint to the given slice. +// Each byte stores seven bits of value; the high bit is a flag marking the +// last byte of the value. +func appendVarUint(b []byte, v uint64) []byte { var buf [10]byte i := 9 @@ -104,9 +165,10 @@ func packVarUint(v uint64) []byte { v >>= 7 } - return buf[i:] + return append(b, buf[i:]...) } +// varIntLen pre-calculates the length, in bytes, of the given varInt value. func varIntLen(v int64) uint64 { mag := uint64(v) if v < 0 { @@ -125,7 +187,10 @@ func varIntLen(v int64) uint64 { return len } -func packVarInt(v int64) []byte { +// appendVarInt appends a variable-length-encoded int to the given slice. +// Most bytes store seven bits of value; the high bit is a flag marking the +// last byte of the value. The first byte additionally stores a sign bit. +func appendVarInt(b []byte, v int64) []byte { var buf [10]byte signbit := byte(0) @@ -138,7 +203,7 @@ func packVarInt(v int64) []byte { next := mag >> 6 if next == 0 { // The whole thing fits in one byte. - return []byte{0x80 | signbit | byte(mag&0x3F)} + return append(b, 0x80|signbit|byte(mag&0x3F)) } i := 9 @@ -156,29 +221,65 @@ func packVarInt(v int64) []byte { i-- buf[i] = signbit | byte(mag&0x3F) - return buf[i:] + return append(b, buf[i:]...) +} + +// tagLen pre-calculates the length, in bytes, of a tag. +func tagLen(len uint64) uint64 { + if len < 0x0E { + return 1 + } + return 1 + varUintLen(len) +} + +// appendTag appends a code+len tag to the given slice. +func appendTag(b []byte, code byte, len uint64) []byte { + if len < 0x0E { + // Short form, with length embedded in the code byte. + return append(b, code|byte(len)) + } + + // Long form, with separate length. + b = append(b, code|0x0E) + return appendVarUint(b, len) } -func packTime(t time.Time) []byte { - _, offset := t.Zone() - utc := t.In(time.UTC) +// timeLen pre-calculates the length, in bytes, of the given time value. +func timeLen(offset int, utc time.Time) uint64 { + ret := varIntLen(int64(offset)) + + // Almost certainly two but let's be safe. + ret += varUintLen(uint64(utc.Year())) + + // Month, day, hour, minute, and second are all guaranteed to be one byte. + ret += 5 + + ns := utc.Nanosecond() + if ns > 0 { + ret++ // varIntLen(-9) + ret += intLen(int64(ns)) + } + + return ret +} - buf := bytes.Buffer{} - buf.Write(packVarInt(int64(offset / 60))) +// appendTime appends a timestamp value +func appendTime(b []byte, offset int, utc time.Time) []byte { + b = appendVarInt(b, int64(offset)) - buf.Write(packVarUint(uint64(utc.Year()))) - buf.Write(packVarUint(uint64(utc.Month()))) - buf.Write(packVarUint(uint64(utc.Day()))) + b = appendVarUint(b, uint64(utc.Year())) + b = appendVarUint(b, uint64(utc.Month())) + b = appendVarUint(b, uint64(utc.Day())) - buf.Write(packVarUint(uint64(utc.Hour()))) - buf.Write(packVarUint(uint64(utc.Minute()))) - buf.Write(packVarUint(uint64(utc.Second()))) + b = appendVarUint(b, uint64(utc.Hour())) + b = appendVarUint(b, uint64(utc.Minute())) + b = appendVarUint(b, uint64(utc.Second())) ns := utc.Nanosecond() if ns > 0 { - buf.Write(packVarInt(-9)) - buf.Write(packInt(int64(ns))) + b = appendVarInt(b, -9) + b = appendInt(b, int64(ns)) } - return buf.Bytes() + return b } diff --git a/bits_test.go b/bits_test.go index 24a6346c..db2c15c4 100644 --- a/bits_test.go +++ b/bits_test.go @@ -6,9 +6,10 @@ import ( "math" "math/big" "testing" + "time" ) -func TestPackUint(t *testing.T) { +func TestAppendUint(t *testing.T) { test := func(val uint64, elen uint64, ebits []byte) { t.Run(fmt.Sprintf("%x", val), func(t *testing.T) { len := uintLen(val) @@ -16,7 +17,7 @@ func TestPackUint(t *testing.T) { t.Errorf("expected len=%v, got len=%v", elen, len) } - bits := packUint(val) + bits := appendUint(nil, val) if !bytes.Equal(bits, ebits) { t.Errorf("expected %v, got %v", fmtbytes(ebits), fmtbytes(bits)) } @@ -29,53 +30,63 @@ func TestPackUint(t *testing.T) { test(math.MaxUint64, 8, []byte{0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF}) } -func TestPackInt(t *testing.T) { - test := func(val int64, ebits []byte) { +func TestAppendInt(t *testing.T) { + test := func(val int64, elen uint64, ebits []byte) { t.Run(fmt.Sprintf("%x", val), func(t *testing.T) { - bits := packInt(val) + len := intLen(val) + if len != elen { + t.Errorf("expected len=%v, got len=%v", elen, len) + } + + bits := appendInt(nil, val) if !bytes.Equal(bits, ebits) { t.Errorf("expected %v, got %v", fmtbytes(ebits), fmtbytes(bits)) } }) } - test(0, []byte{}) - test(0x7F, []byte{0x7F}) - test(-0x7F, []byte{0xFF}) + test(0, 0, []byte{}) + test(0x7F, 1, []byte{0x7F}) + test(-0x7F, 1, []byte{0xFF}) - test(0xFF, []byte{0x00, 0xFF}) - test(-0xFF, []byte{0x80, 0xFF}) + test(0xFF, 2, []byte{0x00, 0xFF}) + test(-0xFF, 2, []byte{0x80, 0xFF}) - test(0x7FFF, []byte{0x7F, 0xFF}) - test(-0x7FFF, []byte{0xFF, 0xFF}) + test(0x7FFF, 2, []byte{0x7F, 0xFF}) + test(-0x7FFF, 2, []byte{0xFF, 0xFF}) - test(math.MaxInt64, []byte{0x7F, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF}) - test(-math.MaxInt64, []byte{0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF}) - test(math.MinInt64, []byte{0x80, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}) + test(math.MaxInt64, 8, []byte{0x7F, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF}) + test(-math.MaxInt64, 8, []byte{0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF}) + test(math.MinInt64, 9, []byte{0x80, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}) } -func TestPackBigInt(t *testing.T) { - test := func(val *big.Int, ebits []byte) { +func TestAppendBigInt(t *testing.T) { + test := func(val *big.Int, elen uint64, ebits []byte) { t.Run(fmt.Sprintf("%x", val), func(t *testing.T) { - bits := packBigInt(val) + len := bigIntLen(val) + if len != elen { + t.Errorf("expected len=%v, got len=%v", elen, len) + } + + bits := appendBigInt(nil, val) if !bytes.Equal(bits, ebits) { t.Errorf("expected %v, got %v", fmtbytes(ebits), fmtbytes(bits)) } }) } - test(big.NewInt(0), []byte{}) - test(big.NewInt(0x7F), []byte{0x7F}) - test(big.NewInt(-0x7F), []byte{0xFF}) + test(big.NewInt(0), 0, []byte{}) + test(big.NewInt(0x7F), 1, []byte{0x7F}) + test(big.NewInt(-0x7F), 1, []byte{0xFF}) - test(big.NewInt(0xFF), []byte{0x00, 0xFF}) - test(big.NewInt(-0xFF), []byte{0x80, 0xFF}) + test(big.NewInt(0xFF), 2, []byte{0x00, 0xFF}) + test(big.NewInt(-0xFF), 2, []byte{0x80, 0xFF}) - test(big.NewInt(0x7FFF), []byte{0x7F, 0xFF}) - test(big.NewInt(-0x7FFF), []byte{0xFF, 0xFF}) + test(big.NewInt(0x7FFF), 2, []byte{0x7F, 0xFF}) + test(big.NewInt(-0x7FFF), 2, []byte{0xFF, 0xFF}) } -func TestPackVarUint(t *testing.T) { +func TestAppendVarUint(t *testing.T) { test := func(val uint64, elen uint64, ebits []byte) { t.Run(fmt.Sprintf("%x", val), func(t *testing.T) { len := varUintLen(val) @@ -83,7 +94,7 @@ func TestPackVarUint(t *testing.T) { t.Errorf("expected len=%v, got len=%v", elen, len) } - bits := packVarUint(val) + bits := appendVarUint(nil, val) if !bytes.Equal(bits, ebits) { t.Errorf("expected %v, got %v", fmtbytes(ebits), fmtbytes(bits)) } @@ -100,7 +111,7 @@ func TestPackVarUint(t *testing.T) { test(0xFFFFFFFFFFFFFFFF, 10, []byte{0x01, 0x7F, 0x7F, 0x7F, 0x7F, 0x7F, 0x7F, 0x7F, 0x7F, 0xFF}) } -func TestPackVarInt(t *testing.T) { +func TestAppendVarInt(t *testing.T) { test := func(val int64, elen uint64, ebits []byte) { t.Run(fmt.Sprintf("%x", val), func(t *testing.T) { len := varIntLen(val) @@ -108,7 +119,7 @@ func TestPackVarInt(t *testing.T) { t.Errorf("expected len=%v, got len=%v", elen, len) } - bits := packVarInt(val) + bits := appendVarInt(nil, val) if !bytes.Equal(bits, ebits) { t.Errorf("expected %v, got %v", fmtbytes(ebits), fmtbytes(bits)) } @@ -136,3 +147,59 @@ func TestPackVarInt(t *testing.T) { test(-math.MaxInt64, 10, []byte{0x40, 0x7F, 0x7F, 0x7F, 0x7F, 0x7F, 0x7F, 0x7F, 0x7F, 0xFF}) test(math.MinInt64, 10, []byte{0x41, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x80}) } + +func TestAppendTag(t *testing.T) { + test := func(code byte, vlen uint64, elen uint64, ebits []byte) { + t.Run(fmt.Sprintf("(%x,%v)", code, vlen), func(t *testing.T) { + len := tagLen(vlen) + if len != elen { + t.Errorf("expected len=%v, got len=%v", elen, len) + } + + bits := appendTag(nil, code, vlen) + if !bytes.Equal(bits, ebits) { + t.Errorf("expected %v, got %v", fmtbytes(ebits), fmtbytes(bits)) + } + }) + } + + test(0x20, 1, 1, []byte{0x21}) + test(0x30, 0x0D, 1, []byte{0x3D}) + test(0x40, 0x0E, 2, []byte{0x4E, 0x8E}) + test(0x50, math.MaxInt64, 10, []byte{0x5E, 0x7F, 0x7F, 0x7F, 0x7F, 0x7F, 0x7F, 0x7F, 0x7F, 0xFF}) +} + +func TestAppendTime(t *testing.T) { + test := func(val time.Time, elen uint64, ebits []byte) { + t.Run(fmt.Sprintf("%x", val), func(t *testing.T) { + _, offset := val.Zone() + offset /= 60 + utc := val.In(time.UTC) + + len := timeLen(offset, utc) + if len != elen { + t.Errorf("expected len=%v, got len=%v", elen, len) + } + + bits := appendTime(nil, offset, utc) + if !bytes.Equal(bits, ebits) { + t.Errorf("expected %v, got %v", fmtbytes(ebits), fmtbytes(bits)) + } + }) + } + + nowish, _ := time.Parse(time.RFC3339Nano, "2019-08-04T18:15:43.863494+10:00") + + test(time.Time{}, 7, []byte{0x80, 0x81, 0x81, 0x81, 0x80, 0x80, 0x80}) + test(nowish, 14, []byte{ + 0x04, 0xD8, // offset: +600 minutes (+10:00) + 0x0F, 0xE3, // year: 2019 + 0x88, // month: 8 + 0x84, // day: 4 + 0x88, // hour: 8 utc (18 local) + 0x8F, // minute: 15 + 0xAB, // second: 43 + 0xC9, // exp: -9 + 0x33, 0x77, 0xDF, 0x70, // nsec: 863494000 + }) +} diff --git a/bufnode.go b/bufnode.go index befeabe7..b4a30fe1 100644 --- a/bufnode.go +++ b/bufnode.go @@ -15,80 +15,105 @@ import ( // A bufnode is a node in the partially-serialized tree. type bufnode interface { Len() uint64 - WriteTo(w io.Writer) error + EmitTo(w io.Writer) error } -// An atom is a value that has been fully serialized and can be written directly. +// A bufseq is a bufnode that's also an appendable sequence of bufnodes. +type bufseq interface { + bufnode + Append(n bufnode) +} + +var _ bufnode = atom([]byte{}) +var _ bufseq = &datagram{} +var _ bufseq = &container{} + +// An atom is a value that has been fully serialized and can be emitted directly. type atom []byte func (a atom) Len() uint64 { return uint64(len(a)) } -func (a atom) WriteTo(w io.Writer) error { +func (a atom) EmitTo(w io.Writer) error { _, err := w.Write(a) return err } -// A fieldname is the symbol id of a field name inside a struct. -type fieldname uint64 +// A datagram is a sequence of nodes that will be emitted one +// after another. Most notably, used to buffer top-level values +// when we haven't yet finalized the local symbol table. +type datagram struct { + len uint64 + children []bufnode +} -func (f fieldname) Len() uint64 { - return varUintLen(uint64(f)) +func (d *datagram) Append(n bufnode) { + d.len += n.Len() + d.children = append(d.children, n) } -func (f fieldname) WriteTo(w io.Writer) error { - _, err := w.Write(packVarUint(uint64(f))) - return err +func (d *datagram) Len() uint64 { + return d.len } -// A container holds multiple child values and serializes them together with a -// tag and length on demand. -type container struct { - code byte - len uint64 - children []bufnode +func (d *datagram) EmitTo(w io.Writer) error { + for _, child := range d.children { + if err := child.EmitTo(w); err != nil { + return err + } + } + + return nil } -func (c *container) Add(n bufnode) { - c.len += n.Len() - c.children = append(c.children, n) +// A container is a datagram that's preceeded by a code+length tag. +type container struct { + code byte + datagram } func (c *container) Len() uint64 { if c.len < 0x0E { - // Short tag return c.len + 1 } - // Long tag. - return c.len + varUintLen(c.len) + 1 + return c.len + (varUintLen(c.len) + 1) } -func (c *container) WriteTo(w io.Writer) error { - if err := writeTag(w, c.code, c.len); err != nil { - return nil - } +func (c *container) EmitTo(w io.Writer) error { + var arr [11]byte + buf := arr[:0] + buf = appendTag(buf, c.code, c.len) - for _, child := range c.children { - if err := child.WriteTo(w); err != nil { - return err - } + if _, err := w.Write(buf); err != nil { + return err } + return c.datagram.EmitTo(w) +} - return nil +// A bufstack is a stack of bufseqs, more or less matching the +// stack of BeginList/Sexp/Struct calls made on a binaryWriter. +// The top of the stack is the sequence we're currently writing +// values into; when it's popped off, it will be appended to the +// bufseq below it. +type bufstack struct { + arr []bufseq } -func writeTag(w io.Writer, code byte, len uint64) error { - if len < 0x0E { - // Short form, with length embedded in code byte. - _, err := w.Write([]byte{code | byte(len)}) - return err +func (s *bufstack) peek() bufseq { + if len(s.arr) == 0 { + return nil } + return s.arr[len(s.arr)-1] +} - // Long form, with separate length. - if _, err := w.Write([]byte{code | 0x0E}); err != nil { - return err +func (s *bufstack) push(b bufseq) { + s.arr = append(s.arr, b) +} + +func (s *bufstack) pop() { + if len(s.arr) == 0 { + panic("pop called on an empty stack") } - _, err := w.Write(packVarUint(uint64(len))) - return err + s.arr = s.arr[:len(s.arr)-1] } diff --git a/bufnode_test.go b/bufnode_test.go index f2ffe669..d4b6a308 100644 --- a/bufnode_test.go +++ b/bufnode_test.go @@ -7,43 +7,43 @@ import ( func TestBufnode(t *testing.T) { root := container{code: 0xE0} - root.Add(atom([]byte{0x81, 0x83})) + root.Append(atom([]byte{0x81, 0x83})) { symtab := &container{code: 0xD0} { - symtab.Add(fieldname(6)) + symtab.Append(atom([]byte{0x86})) // varUint(6) { imps := &container{code: 0xB0} { imp0 := &container{code: 0xD0} { - imp0.Add(fieldname(4)) - imp0.Add(atom([]byte{0x85, 'b', 'o', 'g', 'u', 's'})) - imp0.Add(fieldname(5)) - imp0.Add(atom([]byte{0x21, 0x2A})) - imp0.Add(fieldname(8)) - imp0.Add(atom([]byte{0x21, 0x64})) + imp0.Append(atom([]byte{0x84})) // varUint(4) + imp0.Append(atom([]byte{0x85, 'b', 'o', 'g', 'u', 's'})) + imp0.Append(atom([]byte{0x85})) // varUint(5) + imp0.Append(atom([]byte{0x21, 0x2A})) + imp0.Append(atom([]byte{0x88})) // varUint(8) + imp0.Append(atom([]byte{0x21, 0x64})) } - imps.Add(imp0) + imps.Append(imp0) } - symtab.Add(imps) + symtab.Append(imps) } - symtab.Add(fieldname(7)) + symtab.Append(atom([]byte{0x87})) // varUint(7) { syms := &container{code: 0xB0} { - syms.Add(atom([]byte{0x83, 'f', 'o', 'o'})) - syms.Add(atom([]byte{0x83, 'b', 'a', 'r'})) + syms.Append(atom([]byte{0x83, 'f', 'o', 'o'})) + syms.Append(atom([]byte{0x83, 'b', 'a', 'r'})) } - symtab.Add(syms) + symtab.Append(syms) } } - root.Add(symtab) + root.Append(symtab) } buf := bytes.Buffer{} - if err := root.WriteTo(&buf); err != nil { + if err := root.EmitTo(&buf); err != nil { t.Fatal(err) } diff --git a/marshal.go b/marshal.go index 5e4da9c1..ae8b7285 100644 --- a/marshal.go +++ b/marshal.go @@ -68,7 +68,7 @@ func NewTextEncoder(w io.Writer) *Encoder { // NewBinaryEncoder creates a new Encoder that marshals binary Ion to the given writer. func NewBinaryEncoder(w io.Writer, lst SymbolTable) *Encoder { return &Encoder{ - w: NewBinaryWriter(w, lst), + w: NewBinaryWriterLST(w, lst), sortMaps: false, } } From cb4e1ba56d36e0f24606b09c3570e4af2c08b94a Mon Sep 17 00:00:00 2001 From: David Murray Date: Sun, 11 Aug 2019 08:16:17 +1000 Subject: [PATCH 30/56] changing some things around here --- api.go | 197 +++++++++++--- api_test.go | 23 ++ binarywriter.go | 138 +++++----- binarywriter_test.go | 37 ++- consts.go | 33 +++ ctx.go | 55 +++- fields.go | 122 +++++++++ marshal.go | 276 ++++++++------------ marshal_test.go | 39 ++- skipper.go | 80 +++--- symboltable.go | 4 +- textreader.go | 594 +++++++++++++++++++++---------------------- textreader_test.go | 49 ++-- textutils.go | 2 - textwriter.go | 32 +-- textwriter_test.go | 47 ++-- tokenizer.go | 96 ++++--- tokenizer_test.go | 17 +- unmarshal.go | 17 +- unmarshal_test.go | 86 +++---- writer.go | 19 +- 21 files changed, 1161 insertions(+), 802 deletions(-) create mode 100644 api_test.go create mode 100644 consts.go create mode 100644 fields.go diff --git a/api.go b/api.go index 3b4270b6..bc0da9ad 100644 --- a/api.go +++ b/api.go @@ -6,40 +6,59 @@ import ( "time" ) -// Type is the type of an Ion Value. +// A Type represents the type of an Ion Value. type Type uint8 const ( - // NoType is returned by a Reader that's not currently pointing at a value. + // NoType is returned by a Reader that is not currently pointing at a value. NoType Type = iota - // NullType is the type of the (unqualified) null value. + + // NullType is the type of the (unqualified) Ion null value. NullType - // BoolType is the type of a boolean, true or false. + + // BoolType is the type of an Ion boolean, true or false. BoolType - // IntType is the type of a signed integer of arbitrary size. + + // IntType is the type of a signed Ion integer of arbitrary size. IntType - // FloatType is the type of a 64-bit floating-point value. + + // FloatType is the type of a fixed-precision Ion floating-point value. FloatType - // DecimalType is the type of an arbitrary-precision decimal value. + + // DecimalType is the type of an arbitrary-precision Ion decimal value. DecimalType - // TimestampType is the type of a timestamp. + + // TimestampType is the type of an arbitrary-precision Ion timestamp. TimestampType - // SymbolType is the type of an interned string. + + // SymbolType is the type of an Ion symbol, mapped to an integer ID by a SymbolTable + // to (potentially) save space. SymbolType - // StringType is the type of a Unicode string. + + // StringType is the type of a non-symbol Unicode string, represented directly. StringType - // ClobType is the type of a character large object. + + // ClobType is the type of a character large object. Like a BlobType, it stores an + // arbitrary sequence of bytes, but it represents them in text form as an escaped-ASCII + // string rather than a base64-encoded string. ClobType - // BlobType is the type of a binary large object. + + // BlobType is the type of a binary large object; a sequence of arbitrary bytes. BlobType - // ListType is the type of a list. + + // ListType is the type of a list, recursively containing zero or more Ion values. ListType - // SexpType is the type of an s-expression. + + // SexpType is the type of an s-expression. Like a ListType, it contains a sequence + // of zero or more Ion values, but with a lisp-like syntax when encoded as text. SexpType - // StructType is the type of a structure. + + // StructType is the type of a structure, recursively containing a sequence of named + // (by an Ion symbol) Ion values. StructType ) +// String implements fmt.Stringer for Type. func (t Type) String() string { switch t { case NoType: @@ -75,47 +94,169 @@ func (t Type) String() string { } } -// IntSize is the size of an integer. +// IntSize returns the size of an integer, allowing you to pick the +// appropriate Reader method to call to retrieve the value without loss. type IntSize uint8 const ( - // NullInt is the size of null.int. + // NullInt is the size of null.int and other things that aren't actually ints. NullInt IntSize = iota - // Int32 is a 32-bit integer. + // Int32 is an integer that can be losslessly stored in an int32. Int32 - // Int64 is a 64-bit integer. + // Int64 is an integer that can be losslessly stored in an int64. Int64 - // BigInt is too big for a 64-bit integer. + // BigInt is an integer that can only be losslessly stored in a big.Int. BigInt ) -// A Reader reads Ion values from an input stream. +// String implements fmt.Stringer for IntSize. +func (i IntSize) String() string { + switch i { + case NullInt: + return "null.int" + case Int32: + return "int32" + case Int64: + return "int64" + case BigInt: + return "big.Int" + default: + return fmt.Sprintf("", uint8(i)) + } +} + +// A Reader reads a stream of Ion values. +// +// The Reader has a logical position within the stream of values, influencing the +// values returnedd from its methods. Initially, the Reader is positioned before the +// first value in the stream. A call to Next advances the Reader to the first value +// in the stream, with subsequent calls advancing to subsequent values. When a call to +// Next moves the Reader to the position after the final value in the stream, it returns +// false, making it easy to loop through the values in a stream. +// +// var r Reader +// for r.Next() { +// // ... +// } +// +// Next also returns false in case of error. This can be distinguished from a legitimate +// end-of-stream by calling Err after exiting the loop. +// +// When positioned on an Ion value, the type of the value can be retrieved by calling +// Type. If it has an associated field name (inside a struct) or annotations, they can +// be read by calling FieldName and Annotations respectively. +// +// For atomic values, an appropriate XxxValue method can be called to read the value. +// For lists, sexps, and structs, you should instead call StepIn to move the Reader in +// to the contained sequence of values. The Reader will initially be positioned before +// the first value in the container. Calling Next without calling StepIn will skip over +// the composite value and return the next value in the outer value stream. +// +// At any point while reading through a composite value, including when Next returns false +// to indicate the end of the contained values, you may call StepOut to move back to the +// outer sequence of values. The Reader will be positioned at the end of the composite value, +// such that a call to Next will move to the immediately-following value (if any). +// +// r := NewTextReaderStr("[foo, bar] [") +// for r.Next() { +// if err := r.StepIn(); err != nil { +// return err +// } +// for r.Next() { +// fmt.Println(r.StringValue()) +// } +// if err := r.StepOut(); err != nil { +// return err +// } +// } +// if err := r.Err(); err != nil { +// return err +// } +// type Reader interface { + + // SymbolTable returns the current symbol table, or nil if there isn't one. + // Text Readers do not, generally speaking, have an associated symbol table. + // Binary Readers do. SymbolTable() SymbolTable + // Next advances the Reader to the next position in the current value stream. + // It returns true if this is the position of an Ion value, and false if it + // is not. On error, it returns false and sets Err. Next() bool - Type() Type + + // Err returns an error if a previous call call to Next has failed. Err() error - FieldName() string - TypeAnnotations() []string + // Type returns the type of the Ion value the Reader is currently positioned on. + // It returns NoType if the Reader is positioned before or after a value. + Type() Type + + // IsNull returns true if the current value is an explicit null. This may be true + // even if the Type is not NullType (for example, null.struct has type Struct). Yes, + // that's a bit confusing. IsNull() bool - IntSize() IntSize + // FieldName returns the field name associated with the current value. It returns + // the empty string if there is no current value or the current value has no field + // name. + FieldName() string + + // Annotations returns the set of annotations associated with the current value. + // It returns nil if there is no current value or the current value has no annotations. + Annotations() []string + // StepIn steps in to the current value if it is a container. It returns an error if there + // is no current value or if the value is not a container. On success, the Reader is + // positioned before the first value in the container. StepIn() error + + // StepOut steps out of the current container value being read. It returns an error if + // this Reader is not currently stepped in to a container. On success, the Reader is + // positioned after the end of the container, but before any subsequent values in the + // stream. StepOut() error + // BoolValue returns the current value as a boolean (if that makes sense). It returns + // an error if the current value is not an Ion bool. BoolValue() (bool, error) + + // IntSize returns the size of integer needed to losslessly represent the current value + // (if that makes sense). It returns an error if the current value is not an Ion int. + IntSize() (IntSize, error) + + // IntValue returns the current value as a 32-bit integer (if that makes sense). It + // returns an error if the current value is not an Ion integer or requires more than + // 32 bits to represent losslessly. IntValue() (int, error) + + // Int64Value returns the current value as a 64-bit integer (if that makes sense). It + // returns an error if the current value is not an Ion integer or requires more than + // 64 bits to represent losslessly. Int64Value() (int64, error) + + // BigIntValue returns the current value as a big.Integer (if that makes sense). It + // returns an error if the current value is not an Ion integer. BigIntValue() (*big.Int, error) + + // FloatValue returns the current value as a 64-bit floating point number (if that + // makes sense). It returns an error if the current value is not an Ion float. FloatValue() (float64, error) + + // DecimalValue returns the current value as an arbitrary-precision Decimal (if that + // makes sense). It returns an error if the current value is not an Ion decimal. DecimalValue() (*Decimal, error) + // TimeValue returns the current value as a timestamp (if that makes sense). It returns + // an error if the current value is not an Ion timestamp. TimeValue() (time.Time, error) + + // StringValue returns the current value as a string (if that makes sense). It returns + // an error if the current value is not an Ion symbol or an Ion string. StringValue() (string, error) + // ByteValue returns the current value as a byte slice (if that makes sense). It returns + // an error if the current value is not an Ion clob or an Ion blob. ByteValue() ([]byte, error) } @@ -127,8 +268,8 @@ type Writer interface { Err() error FieldName(val string) - TypeAnnotation(val string) - TypeAnnotations(vals ...string) + Annotation(val string) + Annotations(vals ...string) BeginStruct() EndStruct() @@ -157,7 +298,5 @@ type Writer interface { WriteBlob(val []byte) WriteClob(val []byte) - WriteValue(val interface{}) - Finish() error } diff --git a/api_test.go b/api_test.go new file mode 100644 index 00000000..bbc50a82 --- /dev/null +++ b/api_test.go @@ -0,0 +1,23 @@ +package ion + +import ( + "testing" +) + +func TestTypeToString(t *testing.T) { + for i := NoType; i <= StructType+1; i++ { + str := i.String() + if str == "" { + t.Errorf("expected a non-empty string for type %v", uint8(i)) + } + } +} + +func TestIntSizeToString(t *testing.T) { + for i := NullInt; i <= BigInt+1; i++ { + str := i.String() + if str == "" { + t.Errorf("expected a non-empty string for size %v", uint8(i)) + } + } +} diff --git a/binarywriter.go b/binarywriter.go index d633bb5c..b509b289 100644 --- a/binarywriter.go +++ b/binarywriter.go @@ -24,12 +24,14 @@ type binaryWriter struct { // NewBinaryWriter creates a new binary writer that will construct a // local symbol table as it is written to. func NewBinaryWriter(out io.Writer, sts ...SharedSymbolTable) Writer { - return &binaryWriter{ + w := &binaryWriter{ writer: writer{ out: out, }, lstb: NewSymbolTableBuilder(sts...), } + w.bufs.push(&datagram{}) + return w } // NewBinaryWriterLST creates a new binary writer with a pre-built local @@ -65,7 +67,7 @@ func (w *binaryWriter) write(bs []byte) error { func (w *binaryWriter) writeTag(code byte, len uint64) error { tl := tagLen(len) - tag := make([]byte, tl) + tag := make([]byte, 0, tl) tag = appendTag(tag, code, len) return w.write(tag) @@ -86,8 +88,8 @@ func (w *binaryWriter) beginValue() error { // will end up using/modifying them. Ugh. name := w.fieldName w.fieldName = "" - tas := w.typeAnnotations - w.typeAnnotations = nil + as := w.annotations + w.annotations = nil // If we have a local symbol table and haven't written it out yet, do that now. if w.lst != nil && !w.wroteLST { @@ -102,35 +104,30 @@ func (w *binaryWriter) beginValue() error { return errors.New("ion: field name not set") } - id, ok := w.lst.FindByName(name) - if !ok { - return fmt.Errorf("ion: symbol '%v' not defined", name) - } - if id < 0 { - panic("negative id") + id, err := w.resolve(name) + if err != nil { + return err } buf := make([]byte, 0, 10) - buf = appendVarUint(buf, uint64(id)) + buf = appendVarUint(buf, id) if err := w.write(buf); err != nil { return err } } - if len(tas) > 0 { - ids := make([]uint64, len(tas)) + if len(as) > 0 { + ids := make([]uint64, len(as)) idlen := uint64(0) - for i, a := range tas { - id, ok := w.lst.FindByName(a) - if !ok { - return fmt.Errorf("ion: symbol '%v' not defined", a) - } - if id < 0 { - panic("negative id") + for i, a := range as { + id, err := w.resolve(a) + if err != nil { + return err } - ids[i] = uint64(id) - idlen += varUintLen(uint64(id)) + + ids[i] = id + idlen += varUintLen(id) } buflen := idlen + varUintLen(idlen) @@ -178,7 +175,7 @@ func (w *binaryWriter) writeValue(f func() error) error { } // BeginContainer begins writing a new container. -func (w *binaryWriter) beginContainer(t ctxType, code byte) error { +func (w *binaryWriter) beginContainer(t ctx, code byte) error { if err := w.beginValue(); err != nil { return err } @@ -191,7 +188,7 @@ func (w *binaryWriter) beginContainer(t ctxType, code byte) error { // EndContainer ends writing a container, emitting its buffered contents up // a level in the stack. -func (w *binaryWriter) endContainer(t ctxType) error { +func (w *binaryWriter) endContainer(t ctx) error { if w.ctx.peek() != t { return errors.New("ion: not in that kind of container") } @@ -205,45 +202,29 @@ func (w *binaryWriter) endContainer(t ctxType) error { } w.fieldName = "" - w.typeAnnotations = nil + w.annotations = nil w.ctx.pop() return w.endValue() } +// WriteNull writes an untyped null. func (w *binaryWriter) WriteNull() { w.WriteNullWithType(NullType) } -var nulls = func() []byte { - ret := make([]byte, int(StructType)+1) - ret[NoType] = 0x0F - ret[NullType] = 0x0F - ret[BoolType] = 0x1F - ret[IntType] = 0x2F - ret[FloatType] = 0x4F - ret[DecimalType] = 0x5F - ret[TimestampType] = 0x6F - ret[SymbolType] = 0x7F - ret[StringType] = 0x8F - ret[ClobType] = 0x9F - ret[BlobType] = 0xAF - ret[ListType] = 0xBF - ret[SexpType] = 0xCF - ret[StructType] = 0xDF - return ret -}() - +// WriteNullWithType writes a typed null. func (w *binaryWriter) WriteNullWithType(t Type) { if w.err != nil { return } w.err = w.writeValue(func() error { - return w.write([]byte{nulls[t]}) + return w.write([]byte{binaryNulls[t]}) }) } +// WriteBool writes a bool. func (w *binaryWriter) WriteBool(val bool) { if w.err != nil { return @@ -257,6 +238,7 @@ func (w *binaryWriter) WriteBool(val bool) { }) } +// WriteInt writes an integer. func (w *binaryWriter) WriteInt(val int64) { if w.err != nil { return @@ -286,6 +268,7 @@ func (w *binaryWriter) WriteInt(val int64) { }) } +// WriteBigInt writes a big integer. func (w *binaryWriter) WriteBigInt(val *big.Int) { if w.err != nil { return @@ -304,7 +287,8 @@ func (w *binaryWriter) WriteBigInt(val *big.Int) { bs := val.Bytes() - if bl := uint64(len(bs)); bl < 64 { + bl := uint64(len(bs)) + if bl < 64 { buflen := bl + tagLen(bl) buf := make([]byte, 0, buflen) @@ -314,13 +298,14 @@ func (w *binaryWriter) WriteBigInt(val *big.Int) { } // no sense in copying, emit tag separately. - if err := w.writeTag(code, uint64(len(bs))); err != nil { + if err := w.writeTag(code, bl); err != nil { return err } return w.write(bs) }) } +// WriteFloat writes a floating-point value. func (w *binaryWriter) WriteFloat(val float64) { if w.err != nil { return @@ -341,6 +326,7 @@ func (w *binaryWriter) WriteFloat(val float64) { }) } +// WriteDecimal writes a decimal value. func (w *binaryWriter) WriteDecimal(val *Decimal) { if w.err != nil { return @@ -370,6 +356,7 @@ func (w *binaryWriter) WriteDecimal(val *Decimal) { }) } +// WriteTimestamp writes a timestamp value. func (w *binaryWriter) WriteTimestamp(val time.Time) { if w.err != nil { return @@ -392,22 +379,16 @@ func (w *binaryWriter) WriteTimestamp(val time.Time) { }) } +// WriteSymbol writes a symbol value. func (w *binaryWriter) WriteSymbol(val string) { if w.err != nil { return } - id := 0 - ok := true - - if w.lst != nil { - id, ok = w.lst.FindByName(val) - if !ok { - w.err = fmt.Errorf("ion: symbol '%v' not defined in local symbol table", val) - return - } - } else { - id, _ = w.lstb.Add(val) + id, err := w.resolve(val) + if err != nil { + w.err = err + return } w.err = w.writeValue(func() error { @@ -422,6 +403,27 @@ func (w *binaryWriter) WriteSymbol(val string) { }) } +// Resolve resolves a symbol to its ID. +func (w *binaryWriter) resolve(sym string) (uint64, error) { + if w.lst != nil { + id, ok := w.lst.FindByName(sym) + if !ok { + return 0, fmt.Errorf("ion: symbol '%v' not defined in local symbol table", sym) + } + if id < 0 { + panic("negative id") + } + return uint64(id), nil + } + + id, _ := w.lstb.Add(sym) + if id < 0 { + panic("negative id") + } + return uint64(id), nil +} + +// WriteString writes a string. func (w *binaryWriter) WriteString(val string) { if w.err != nil { return @@ -443,14 +445,17 @@ func (w *binaryWriter) WriteString(val string) { }) } +// WriteClob writes a clob. func (w *binaryWriter) WriteClob(val []byte) { w.writeLob(0x90, val) } +// WriteBlob writes a blob. func (w *binaryWriter) WriteBlob(val []byte) { w.writeLob(0xA0, val) } +// WriteLob writes a [bc]lob. func (w *binaryWriter) writeLob(code byte, val []byte) { if w.err != nil { return @@ -476,14 +481,7 @@ func (w *binaryWriter) writeLob(code byte, val []byte) { }) } -func (w *binaryWriter) WriteValue(val interface{}) { - m := Encoder{ - w: w, - sortMaps: true, - } - w.err = m.Encode(val) -} - +// BeginList begins writing a list. func (w *binaryWriter) BeginList() { if w.err != nil { return @@ -491,6 +489,7 @@ func (w *binaryWriter) BeginList() { w.err = w.beginContainer(ctxInList, 0xB0) } +// EndList finishes writing a list. func (w *binaryWriter) EndList() { if w.err != nil { return @@ -498,6 +497,7 @@ func (w *binaryWriter) EndList() { w.err = w.endContainer(ctxInList) } +// BeginSexp begins writing an s-expression. func (w *binaryWriter) BeginSexp() { if w.err != nil { return @@ -505,6 +505,7 @@ func (w *binaryWriter) BeginSexp() { w.err = w.beginContainer(ctxInSexp, 0xC0) } +// EndSexp finishes writing an s-expression. func (w *binaryWriter) EndSexp() { if w.err != nil { return @@ -512,6 +513,7 @@ func (w *binaryWriter) EndSexp() { w.err = w.endContainer(ctxInSexp) } +// BeginStruct begins writing a struct. func (w *binaryWriter) BeginStruct() { if w.err != nil { return @@ -519,6 +521,7 @@ func (w *binaryWriter) BeginStruct() { w.err = w.beginContainer(ctxInStruct, 0xD0) } +// EndStruct finishes writing a struct. func (w *binaryWriter) EndStruct() { if w.err != nil { return @@ -526,6 +529,7 @@ func (w *binaryWriter) EndStruct() { w.err = w.endContainer(ctxInStruct) } +// Finish finishes writing a datagram. func (w *binaryWriter) Finish() error { if w.err != nil { return w.err @@ -536,7 +540,7 @@ func (w *binaryWriter) Finish() error { } w.fieldName = "" - w.typeAnnotations = nil + w.annotations = nil w.wroteLST = false seq := w.bufs.peek() diff --git a/binarywriter_test.go b/binarywriter_test.go index bcadadb1..eaa6c933 100644 --- a/binarywriter_test.go +++ b/binarywriter_test.go @@ -23,11 +23,11 @@ func TestWriteBinaryStruct(t *testing.T) { w.BeginStruct() w.EndStruct() - w.TypeAnnotation("foo") + w.Annotation("foo") w.BeginStruct() { w.FieldName("name") - w.TypeAnnotation("bar") + w.Annotation("bar") w.WriteNull() w.FieldName("max_id") @@ -49,10 +49,10 @@ func TestWriteBinarySexp(t *testing.T) { w.BeginSexp() w.EndSexp() - w.TypeAnnotation("foo") + w.Annotation("foo") w.BeginSexp() { - w.TypeAnnotation("bar") + w.Annotation("bar") w.WriteNull() w.WriteInt(0) @@ -73,10 +73,10 @@ func TestWriteBinaryList(t *testing.T) { w.BeginList() w.EndList() - w.TypeAnnotation("foo") + w.Annotation("foo") w.BeginList() { - w.TypeAnnotation("bar") + w.Annotation("bar") w.WriteNull() w.WriteInt(0) @@ -96,6 +96,16 @@ func TestWriteBinaryBlob(t *testing.T) { }) } +func TestWriteLargeBinaryBlob(t *testing.T) { + eval := make([]byte, 131) + eval[0] = 0xAE + eval[1] = 0x01 + eval[2] = 0x80 + testBinaryWriter(t, eval, func(w Writer) { + w.WriteBlob(make([]byte, 128)) + }) +} + func TestWriteBinaryClob(t *testing.T) { eval := []byte{ 0x90, @@ -216,6 +226,19 @@ func TestWriteBinaryBigInts(t *testing.T) { }) } +func TestWriteBinaryReallyBigInts(t *testing.T) { + eval := []byte{ + 0x2E, 0x01, 0x80, // 128-byte positive integer + 0x80, // high bit set + } + eval = append(eval, make([]byte, 127)...) + testBinaryWriter(t, eval, func(w Writer) { + i := new(big.Int) + i = i.SetBit(i, 1023, 1) + w.WriteBigInt(i) + }) +} + func TestWriteBinaryInts(t *testing.T) { eval := []byte{ 0x20, // 0 @@ -246,7 +269,7 @@ func TestWriteBinaryBoolAnnotated(t *testing.T) { } testBinaryWriter(t, eval, func(w Writer) { - w.TypeAnnotations("name", "version") + w.Annotations("name", "version") w.WriteBool(false) }) } diff --git a/consts.go b/consts.go new file mode 100644 index 00000000..e14ba0a2 --- /dev/null +++ b/consts.go @@ -0,0 +1,33 @@ +package ion + +import ( + "reflect" + "time" +) + +var binaryNulls = func() []byte { + ret := make([]byte, int(StructType)+1) + ret[NoType] = 0x0F + ret[NullType] = 0x0F + ret[BoolType] = 0x1F + ret[IntType] = 0x2F + ret[FloatType] = 0x4F + ret[DecimalType] = 0x5F + ret[TimestampType] = 0x6F + ret[SymbolType] = 0x7F + ret[StringType] = 0x8F + ret[ClobType] = 0x9F + ret[BlobType] = 0xAF + ret[ListType] = 0xBF + ret[SexpType] = 0xCF + ret[StructType] = 0xDF + return ret +}() + +var hexChars = []byte{ + '0', '1', '2', '3', '4', '5', '6', '7', + '8', '9', 'A', 'B', 'C', 'D', 'E', 'F', +} + +var timeType = reflect.TypeOf(time.Time{}) +var decimalType = reflect.TypeOf(Decimal{}) diff --git a/ctx.go b/ctx.go index 572e1d54..d3a1b9ed 100644 --- a/ctx.go +++ b/ctx.go @@ -1,36 +1,65 @@ package ion -type ctxType byte +import "fmt" + +// ctx is the current reader or writer context. +type ctx uint8 const ( - ctxAtTopLevel ctxType = iota + ctxAtTopLevel ctx = iota ctxInStruct ctxInList ctxInSexp ) -// ctx is a context stack. -type ctx struct { - stack []ctxType +func ctxToContainerType(c ctx) Type { + switch c { + case ctxInList: + return ListType + case ctxInSexp: + return SexpType + case ctxInStruct: + return StructType + default: + return NoType + } +} + +func containerTypeToCtx(t Type) ctx { + switch t { + case ListType: + return ctxInList + case SexpType: + return ctxInSexp + case StructType: + return ctxInStruct + default: + panic(fmt.Sprintf("type %v is not a container type", t)) + } +} + +// ctxstack is a context stack. +type ctxstack struct { + arr []ctx } // peek returns the current context. -func (c *ctx) peek() ctxType { - if len(c.stack) == 0 { +func (c *ctxstack) peek() ctx { + if len(c.arr) == 0 { return ctxAtTopLevel } - return c.stack[len(c.stack)-1] + return c.arr[len(c.arr)-1] } // push pushes a new context onto the stack. -func (c *ctx) push(ctx ctxType) { - c.stack = append(c.stack, ctx) +func (c *ctxstack) push(ctx ctx) { + c.arr = append(c.arr, ctx) } // pop pops the top context off the stack. -func (c *ctx) pop() { - if len(c.stack) == 0 { +func (c *ctxstack) pop() { + if len(c.arr) == 0 { panic("pop called at top level") } - c.stack = c.stack[:len(c.stack)-1] + c.arr = c.arr[:len(c.arr)-1] } diff --git a/fields.go b/fields.go new file mode 100644 index 00000000..061113d9 --- /dev/null +++ b/fields.go @@ -0,0 +1,122 @@ +package ion + +import ( + "fmt" + "reflect" + "strings" +) + +// A field is a reflectively-accessed field of a struct type. +type field struct { + name string + typ reflect.Type + path []int + omitEmpty bool +} + +// A fielder maps out the fields of a type. +type fielder struct { + fields []field + index map[string]bool +} + +// FieldsFor returns the fields of the given struct type. +// TODO: cache me. +func fieldsFor(t reflect.Type) []field { + fldr := fielder{index: map[string]bool{}} + fldr.inspect(t, nil) + return fldr.fields +} + +// Inspect recursively inspects a type to determine all of its fields. +func (f *fielder) inspect(t reflect.Type, path []int) { + for i := 0; i < t.NumField(); i++ { + sf := t.Field(i) + if !visible(&sf) { + // Skip non-visible fields. + continue + } + + tag := sf.Tag.Get("json") + if tag == "-" { + // Skip fields that are explicitly hidden by tag. + continue + } + name, opts := parseTag(tag) + + newpath := make([]int, len(path)+1) + copy(newpath, path) + newpath[len(path)] = i + + ft := sf.Type + if ft.Name() == "" && ft.Kind() == reflect.Ptr { + ft = ft.Elem() + } + + if name == "" && sf.Anonymous && ft.Kind() == reflect.Struct { + // Dig in to the embedded struct. + f.inspect(ft, newpath) + } else { + // Add this named field. + if name == "" { + name = sf.Name + } + + if f.index[name] { + panic(fmt.Sprintf("too many fields named %v", name)) + } + f.index[name] = true + + f.fields = append(f.fields, field{ + name: name, + typ: ft, + path: newpath, + omitEmpty: omitEmpty(opts), + }) + } + } +} + +// Visible returns true if the given StructField should show up in the output. +func visible(sf *reflect.StructField) bool { + exported := sf.PkgPath == "" + if sf.Anonymous { + t := sf.Type + if t.Kind() == reflect.Ptr { + t = t.Elem() + } + if t.Kind() == reflect.Struct { + // Fields of embedded structs are visible even if the struct type itself is not. + return true + } + } + return exported +} + +// ParseTag parses a `json:"..."` field tag, returning the name and opts. +func parseTag(tag string) (string, string) { + if idx := strings.Index(tag, ","); idx != -1 { + // Ignore additional JSON options, at least for now. + return tag[:idx], tag[idx+1:] + } + return tag, "" +} + +// OmitEmpty returns true if opts includes "omitempty". +func omitEmpty(opts string) bool { + for opts != "" { + var o string + + i := strings.Index(opts, ",") + if i >= 0 { + o, opts = opts[:i], opts[i+1:] + } else { + o, opts = opts, "" + } + + if o == "omitempty" { + return true + } + } + return false +} diff --git a/marshal.go b/marshal.go index ae8b7285..d0d1dcc4 100644 --- a/marshal.go +++ b/marshal.go @@ -7,37 +7,52 @@ import ( "math/big" "reflect" "sort" - "strings" "time" ) +// EncoderOpts holds bit-flag options for an Encoder. +type EncoderOpts uint + +const ( + // EncodeSortMaps instructs the encoder to write map keys in sorted order. + EncodeSortMaps EncoderOpts = 1 +) + // MarshalText marshals values to text ion. func MarshalText(v interface{}) ([]byte, error) { - buf := bytes.Buffer{} - m := Encoder{ - w: NewTextWriterOpts(&buf, OptQuietFinish), - sortMaps: true, - } + return marshal(func(w io.Writer) Writer { + return NewTextWriterOpts(w, TextWriterQuietFinish) + }, EncodeSortMaps, v) +} - if err := m.Encode(v); err != nil { - return nil, err - } - if err := m.Finish(); err != nil { - return nil, err - } +// MarshalBinary marshals values to binary ion. +func MarshalBinary(v interface{}, ssts ...SharedSymbolTable) ([]byte, error) { + return marshal(func(w io.Writer) Writer { + return NewBinaryWriter(w, ssts...) + }, 0, v) +} - return buf.Bytes(), nil +// MarshalBinaryLST marshals values to binary ion with a fixed local symbol table. +func MarshalBinaryLST(v interface{}, lst SymbolTable) ([]byte, error) { + return marshal(func(w io.Writer) Writer { + return NewBinaryWriterLST(w, lst) + }, 0, v) } -// MarshalBinary marshals values to binary ion. -func MarshalBinary(v interface{}, lst SymbolTable) ([]byte, error) { +// marshal marshals a value using the given writer type. +func marshal(wf func(io.Writer) Writer, opts EncoderOpts, v interface{}) ([]byte, error) { buf := bytes.Buffer{} - m := NewBinaryEncoder(&buf, lst) + w := wf(&buf) + + e := Encoder{ + w: w, + opts: opts, + } - if err := m.Encode(v); err != nil { + if err := e.Encode(v); err != nil { return nil, err } - if err := m.Finish(); err != nil { + if err := e.Finish(); err != nil { return nil, err } @@ -46,36 +61,60 @@ func MarshalBinary(v interface{}, lst SymbolTable) ([]byte, error) { // An Encoder writes Ion values to an output stream. type Encoder struct { - w Writer - sortMaps bool + w Writer + opts EncoderOpts } // NewEncoder creates a new encoder. func NewEncoder(w Writer) *Encoder { + return NewEncoderOpts(w, 0) +} + +// NewEncoderOpts creates a new encoder with the specified options. +func NewEncoderOpts(w Writer, opts EncoderOpts) *Encoder { return &Encoder{ - w: w, + w: w, + opts: opts, } } -// NewTextEncoder creates a new Encoder that marshals text Ion to the given writer. +// NewTextEncoder creates a new text Encoder. func NewTextEncoder(w io.Writer) *Encoder { - return &Encoder{ - w: NewTextWriter(w), - sortMaps: true, + return NewEncoder(NewTextWriter(w)) +} + +// NewBinaryEncoder creates a new binary Encoder. +func NewBinaryEncoder(w io.Writer, ssts ...SharedSymbolTable) *Encoder { + return NewEncoder(NewBinaryWriter(w, ssts...)) +} + +// NewBinaryEncoderLST creates a new binary Encoder with a fixed local symbol table. +func NewBinaryEncoderLST(w io.Writer, lst SymbolTable) *Encoder { + return NewEncoder(NewBinaryWriterLST(w, lst)) +} + +// EncodeTo encodes the given value to the given writer. It does +// not call Finish, so is suitable for encoding values inside of +// a partially-constructed Ion value. +func EncodeTo(w Writer, v interface{}) error { + e := Encoder{ + w: w, } + return e.Encode(v) } -// NewBinaryEncoder creates a new Encoder that marshals binary Ion to the given writer. -func NewBinaryEncoder(w io.Writer, lst SymbolTable) *Encoder { - return &Encoder{ - w: NewBinaryWriterLST(w, lst), - sortMaps: false, +// EncodeToOpts is like EncodeTo but accepts additional opts. +func EncodeToOpts(w Writer, opts EncoderOpts, v interface{}) error { + e := Encoder{ + w: w, + opts: opts, } + return e.Encode(v) } // Encode marshals the given value to Ion, writing it to the underlying writer. func (m *Encoder) Encode(v interface{}) error { - return m.marshalValue(reflect.ValueOf(v)) + return m.encodeValue(reflect.ValueOf(v)) } // Finish finishes writing the current Ion datagram. @@ -83,7 +122,8 @@ func (m *Encoder) Finish() error { return m.w.Finish() } -func (m *Encoder) marshalValue(v reflect.Value) error { +// EncodeValue recursively encodes a value. +func (m *Encoder) encodeValue(v reflect.Value) error { if !v.IsValid() { m.w.WriteNull() return nil @@ -118,34 +158,37 @@ func (m *Encoder) marshalValue(v reflect.Value) error { return m.w.Err() case reflect.Interface, reflect.Ptr: - return m.marshalPtr(v) + return m.encodePtr(v) case reflect.Struct: - return m.marshalStruct(v) + return m.encodeStruct(v) case reflect.Map: - return m.marshalMap(v) + return m.encodeMap(v) case reflect.Slice: - return m.marshalSlice(v) + return m.encodeSlice(v) case reflect.Array: - return m.marshalArray(v) + return m.encodeArray(v) default: return fmt.Errorf("ion: unsupported type: %v", v.Type().String()) } } -func (m *Encoder) marshalPtr(v reflect.Value) error { +// EncodePtr encodes an Ion null if the pointer is nil, and otherwise encodes the value that +// the pointer is pointing to. +func (m *Encoder) encodePtr(v reflect.Value) error { if v.IsNil() { m.w.WriteNull() return m.w.Err() } - return m.marshalValue(v.Elem()) + return m.encodeValue(v.Elem()) } -func (m *Encoder) marshalMap(v reflect.Value) error { +// EncodeMap encodes a map to the output writer as an Ion struct. +func (m *Encoder) encodeMap(v reflect.Value) error { if v.IsNil() { m.w.WriteNull() return m.w.Err() @@ -153,18 +196,15 @@ func (m *Encoder) marshalMap(v reflect.Value) error { m.w.BeginStruct() - keys := getKeys(v) - if m.sortMaps { - // We do this for text Ion because json.Marshal does, and it's useful for testing. - // For binary Ion, skip it and write things in whatever order they come back from - // the map. + keys := keysFor(v) + if m.opts&EncodeSortMaps != 0 { sort.Slice(keys, func(i, j int) bool { return keys[i].s < keys[j].s }) } for _, key := range keys { m.w.FieldName(key.s) value := v.MapIndex(key.v) - if err := m.marshalValue(value); err != nil { + if err := m.encodeValue(value); err != nil { return err } } @@ -173,12 +213,14 @@ func (m *Encoder) marshalMap(v reflect.Value) error { return m.w.Err() } +// A mapkey holds the reflective map key value as well as its stringified form. type mapkey struct { v reflect.Value s string } -func getKeys(v reflect.Value) []mapkey { +// KeysFor returns the stringified keys for the given map. +func keysFor(v reflect.Value) []mapkey { keys := v.MapKeys() res := make([]mapkey, len(keys)) @@ -196,9 +238,10 @@ func getKeys(v reflect.Value) []mapkey { return res } -func (m *Encoder) marshalSlice(v reflect.Value) error { +// EncodeSlice encodes a slice to the output writer as an appropriate Ion type. +func (m *Encoder) encodeSlice(v reflect.Value) error { if v.Type().Elem().Kind() == reflect.Uint8 { - return m.marshalBlob(v) + return m.encodeBlob(v) } if v.IsNil() { @@ -206,10 +249,11 @@ func (m *Encoder) marshalSlice(v reflect.Value) error { return m.w.Err() } - return m.marshalArray(v) + return m.encodeArray(v) } -func (m *Encoder) marshalBlob(v reflect.Value) error { +// EncodeBlob encodes a []byte to the output writer as an Ion blob. +func (m *Encoder) encodeBlob(v reflect.Value) error { if v.IsNil() { m.w.WriteNull() } else { @@ -218,11 +262,12 @@ func (m *Encoder) marshalBlob(v reflect.Value) error { return m.w.Err() } -func (m *Encoder) marshalArray(v reflect.Value) error { +// EncodeArray encodes an array to the output writer as an Ion list. +func (m *Encoder) encodeArray(v reflect.Value) error { m.w.BeginList() for i := 0; i < v.Len(); i++ { - if err := m.marshalValue(v.Index(i)); err != nil { + if err := m.encodeValue(v.Index(i)); err != nil { return err } } @@ -231,15 +276,14 @@ func (m *Encoder) marshalArray(v reflect.Value) error { return m.w.Err() } -var decimalType = reflect.TypeOf(Decimal{}) - -func (m *Encoder) marshalStruct(v reflect.Value) error { +// EncodeStruct encodes a struct to the output writer as an Ion struct. +func (m *Encoder) encodeStruct(v reflect.Value) error { t := v.Type() if t == timeType { - return m.marshalTime(v) + return m.encodeTime(v) } if t == decimalType { - return m.marshalDecimal(v) + return m.encodeDecimal(v) } fields := fieldsFor(v.Type()) @@ -266,7 +310,7 @@ FieldLoop: } m.w.FieldName(f.name) - if err := m.marshalValue(fv); err != nil { + if err := m.encodeValue(fv); err != nil { return err } } @@ -275,18 +319,21 @@ FieldLoop: return m.w.Err() } -func (m *Encoder) marshalTime(v reflect.Value) error { +// EncodeTime encodes a time.Time to the output writer as an Ion timestamp. +func (m *Encoder) encodeTime(v reflect.Value) error { t := v.Interface().(time.Time) m.w.WriteTimestamp(t) return m.w.Err() } -func (m *Encoder) marshalDecimal(v reflect.Value) error { +// EncodeDecimal encodes an ion.Decimal to the output writer as an Ion decimal. +func (m *Encoder) encodeDecimal(v reflect.Value) error { d := v.Addr().Interface().(*Decimal) m.w.WriteDecimal(d) return m.w.Err() } +// EmptyValue returns true if the given value is the empty value for its type. func emptyValue(v reflect.Value) bool { switch v.Kind() { case reflect.Array, reflect.Map, reflect.Slice, reflect.String: @@ -304,110 +351,3 @@ func emptyValue(v reflect.Value) bool { } return false } - -type field struct { - name string - typ reflect.Type - path []int - omitEmpty bool -} - -type fielder struct { - fields []field - index map[string]bool -} - -func fieldsFor(t reflect.Type) []field { - fldr := fielder{index: map[string]bool{}} - fldr.inspect(t, nil) - return fldr.fields -} - -func (f *fielder) inspect(t reflect.Type, path []int) { - for i := 0; i < t.NumField(); i++ { - sf := t.Field(i) - if !visible(&sf) { - // Skip non-visible fields. - continue - } - - tag := sf.Tag.Get("json") - if tag == "-" { - // Skip fields that are explicitly hidden by tag. - continue - } - name, opts := parseTag(tag) - - newpath := make([]int, len(path)+1) - copy(newpath, path) - newpath[len(path)] = i - - ft := sf.Type - if ft.Name() == "" && ft.Kind() == reflect.Ptr { - ft = ft.Elem() - } - - if name == "" && sf.Anonymous && ft.Kind() == reflect.Struct { - // Dig in to the embedded struct. - f.inspect(ft, newpath) - } else { - // Add this named field. - if name == "" { - name = sf.Name - } - - if f.index[name] { - panic(fmt.Sprintf("too many fields named %v", name)) - } - f.index[name] = true - - f.fields = append(f.fields, field{ - name: name, - typ: ft, - path: newpath, - omitEmpty: omitEmpty(opts), - }) - } - } -} - -func visible(sf *reflect.StructField) bool { - exported := sf.PkgPath == "" - if sf.Anonymous { - t := sf.Type - if t.Kind() == reflect.Ptr { - t = t.Elem() - } - if t.Kind() == reflect.Struct { - // Fields of embedded structs are visible even if the struct type itself is not. - return true - } - } - return exported -} - -func parseTag(tag string) (string, string) { - if idx := strings.Index(tag, ","); idx != -1 { - // Ignore additional JSON options, at least for now. - return tag[:idx], tag[idx+1:] - } - return tag, "" -} - -func omitEmpty(opts string) bool { - for opts != "" { - var o string - - i := strings.Index(opts, ",") - if i >= 0 { - o, opts = opts[:i], opts[i+1:] - } else { - o, opts = opts, "" - } - - if o == "omitempty" { - return true - } - } - return false -} diff --git a/marshal_test.go b/marshal_test.go index abe80e25..72e8dc7a 100644 --- a/marshal_test.go +++ b/marshal_test.go @@ -67,11 +67,9 @@ func TestMarshalText(t *testing.T) { } func TestMarshalBinary(t *testing.T) { - lst := NewLocalSymbolTable(nil, nil) - test := func(v interface{}, name string, eval []byte) { t.Run(name, func(t *testing.T) { - val, err := MarshalBinary(v, lst) + val, err := MarshalBinary(v) if err != nil { t.Fatal(err) } @@ -82,6 +80,41 @@ func TestMarshalBinary(t *testing.T) { } test(nil, "null", []byte{0xE0, 0x01, 0x00, 0xEA, 0x0F}) + test(struct{ A, B int }{42, 0}, "{A:42,B:0}", []byte{ + 0xE0, 0x01, 0x00, 0xEA, + 0xE9, 0x81, 0x83, 0xD6, 0x87, 0xB4, 0x81, 'A', 0x81, 'B', + 0xD5, + 0x8A, 0x21, 0x2A, + 0x8B, 0x20, + }) +} + +func TestMarshalBinaryLST(t *testing.T) { + lsta := NewLocalSymbolTable(nil, nil) + lstb := NewLocalSymbolTable(nil, []string{ + "A", "B", + }) + + test := func(v interface{}, name string, lst SymbolTable, eval []byte) { + t.Run(name, func(t *testing.T) { + val, err := MarshalBinaryLST(v, lst) + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(val, eval) { + t.Errorf("expected '%v', got '%v'", fmtbytes(eval), fmtbytes(val)) + } + }) + } + + test(nil, "null", lsta, []byte{0xE0, 0x01, 0x00, 0xEA, 0x0F}) + test(struct{ A, B int }{42, 0}, "{A:42,B:0}", lstb, []byte{ + 0xE0, 0x01, 0x00, 0xEA, + 0xE9, 0x81, 0x83, 0xD6, 0x87, 0xB4, 0x81, 'A', 0x81, 'B', + 0xD5, + 0x8A, 0x21, 0x2A, + 0x8B, 0x20, + }) } func TestMarshalNestedStructs(t *testing.T) { diff --git a/skipper.go b/skipper.go index 18d09c7b..2ce458bf 100644 --- a/skipper.go +++ b/skipper.go @@ -5,19 +5,56 @@ import ( "io" ) -// FinishValue skips to the end of the current value if (and only if) -// we're currently in the middle of reading it. -func (t *tokenizer) finishValue() (bool, error) { - if t.unfinished { - c, err := t.skipValue() - if err != nil { - return true, err - } - t.unread(c) - t.unfinished = false - return true, nil +// SkipContainerContents skips over the contents of a container of the given type. +func (t *tokenizer) SkipContainerContents(typ Type) error { + switch typ { + case StructType: + return t.skipStructHelper() + case ListType: + return t.skipListHelper() + case SexpType: + return t.skipSexpHelper() + default: + panic(fmt.Sprintf("invalid container type: %v", typ)) } - return false, nil +} + +// Skips whitespace and a double-colon token, if there is one. +func (t *tokenizer) SkipDoubleColon() (bool, bool, error) { + ws, err := t.skipWhitespaceHelper() + if err != nil { + return false, false, err + } + + ok, err := t.skipDoubleColon() + if err != nil { + return false, false, err + } + + return ok, ws, nil +} + +// Peeks ahead to see if the next token is a dot, and +// if so skips it. If not, leaves the next token unconsumed. +func (t *tokenizer) SkipDot() (bool, error) { + c, err := t.peek() + if err != nil { + return false, err + } + if c != '.' { + return false, nil + } + + t.read() + return true, nil +} + +// SkipLobWhitespace skips whitespace when we're inside a large +// object ({{ ///= }} or {{ '''///=''' }}) where comments are +// not allowed. +func (t *tokenizer) SkipLobWhitespace() (int, error) { + c, _, err := t.skipLobWhitespace() + return c, err } // SkipValue skips to the end of the current value, if the caller @@ -512,7 +549,7 @@ func (t *tokenizer) skipEndOfLongString(handler commentHandler) (bool, error) { // Check if it's another triple-quote; if so, keep going. if c == '\'' { - ok, err := t.isTripleQuote() + ok, err := t.IsTripleQuote() if err != nil { return false, err } @@ -619,7 +656,7 @@ func (t *tokenizer) skipContainerHelper(term int) error { } case '\'': - ok, err := t.isTripleQuote() + ok, err := t.IsTripleQuote() if err != nil { return err } @@ -824,18 +861,3 @@ func (t *tokenizer) skipDoubleColon() (bool, error) { return false, nil } - -// Peeks ahead to see if the next token is a dot, and -// if so skips it. If not, leaves the next token unconsumed. -func (t *tokenizer) skipDot() (bool, error) { - c, err := t.peek() - if err != nil { - return false, err - } - if c != '.' { - return false, nil - } - - t.read() - return true, nil -} diff --git a/symboltable.go b/symboltable.go index b3330481..03c18633 100644 --- a/symboltable.go +++ b/symboltable.go @@ -93,7 +93,7 @@ func (s *sharedSymbolTable) FindByID(id int) (string, bool) { } func (s *sharedSymbolTable) WriteTo(w Writer) error { - w.TypeAnnotation("$ion_shared_symbol_table") + w.Annotation("$ion_shared_symbol_table") w.BeginStruct() w.FieldName("name") @@ -237,7 +237,7 @@ func (t *localSymbolTable) WriteTo(w Writer) error { return nil } - w.TypeAnnotation("$ion_symbol_table") + w.Annotation("$ion_symbol_table") w.BeginStruct() if len(t.imports) > 1 { diff --git a/textreader.go b/textreader.go index 37495b53..8efabf9b 100644 --- a/textreader.go +++ b/textreader.go @@ -13,19 +13,18 @@ import ( "time" ) -type textReaderState uint8 +// trs is the state of the text reader. +type trs uint8 const ( - trsDone textReaderState = iota + trsDone trs = iota trsBeforeFieldName trsBeforeTypeAnnotations - trsBeforeScalar trsBeforeContainer - trsInValue trsAfterValue ) -func (s textReaderState) String() string { +func (s trs) String() string { switch s { case trsDone: return "" @@ -42,17 +41,18 @@ func (s textReaderState) String() string { } } +// A textReader is a Reader that reads text Ion. type textReader struct { tok tokenizer - state textReaderState - ctx ctx + state trs + ctx ctxstack eof bool err error - fieldName string - typeAnnotations []string - valueType Type - value interface{} + fieldName string + annotations []string + valueType Type + value interface{} debug bool } @@ -67,25 +67,28 @@ func NewTextReader(in io.Reader) Reader { } } -// NewTextReaderString creates a new text reader from a string. -func NewTextReaderString(str string) Reader { +// NewTextReaderStr creates a new text reader from a string. +func NewTextReaderStr(str string) Reader { return NewTextReader(strings.NewReader(str)) } +// SymbolTable returns the current symbol table. func (t *textReader) SymbolTable() SymbolTable { - // Text content doesn't have a symbol table. + // TODO: Include me if present in the input stream? return nil } +// Next moves the reader to the next value. func (t *textReader) Next() bool { if t.state == trsDone || t.eof { return false } if t.debug { - fmt.Println("state:", t.state) + fmt.Println("ion: state =", t.state) } + // If we haven't fully read the current value, skip over it. err := t.finishValue() if err != nil { t.explode(err) @@ -93,51 +96,268 @@ func (t *textReader) Next() bool { } if t.debug { - fmt.Println("state after finish:", t.state) + fmt.Println("ion: state after finish =", t.state) } t.fieldName = "" - t.typeAnnotations = nil + t.annotations = nil t.valueType = NoType t.value = nil - if err := t.tok.Next(); err != nil { - t.explode(err) - return false - } + // Loop until we've consumed enough tokens to know what the next value is. + for { + if err := t.tok.Next(); err != nil { + t.explode(err) + return false + } - if t.debug { - fmt.Println("read token:", t.tok.Token()) - } + if t.debug { + fmt.Println("ion: read token ", t.tok.Token()) + } - for { - var f func() (bool, error) + var done bool + var err error switch t.state { case trsAfterValue: - f = t.nextAfterValue + done, err = t.nextAfterValue() case trsBeforeFieldName: - f = t.nextBeforeFieldName + done, err = t.nextBeforeFieldName() case trsBeforeTypeAnnotations: - f = t.nextBeforeTypeAnnotations + done, err = t.nextBeforeTypeAnnotations() default: - panic(fmt.Sprintf("invalid state: %v", t.state)) + panic(fmt.Sprintf("unexpected state: %v", t.state)) } - - done, err := f() if err != nil { t.explode(err) return false } + if done { + // We're done reading tokens. If we hit the end of the current sequence, + // return false. Otherwise, we've got a value for the caller. return !t.eof } + } +} - if err := t.tok.Next(); err != nil { +// Err returns the current error. +func (t *textReader) Err() error { + return t.err +} + +// Type returns the current value's type. +func (t *textReader) Type() Type { + return t.valueType +} + +// IsNull returns true if the current value is null. +func (t *textReader) IsNull() bool { + return t.valueType != NoType && t.value == nil +} + +// FieldName returns the current value's field name. +func (t *textReader) FieldName() string { + return t.fieldName +} + +// Annotations returns the current value's annotations. +func (t *textReader) Annotations() []string { + return t.annotations +} + +// StepIn steps in to a container. +func (t *textReader) StepIn() error { + if t.err != nil { + return t.err + } + if t.state != trsBeforeContainer { + return errors.New("ion: StepIn called when not on a container") + } + + ctx := containerTypeToCtx(t.valueType) + t.ctx.push(ctx) + + if ctx == ctxInStruct { + t.state = trsBeforeFieldName + } else { + t.state = trsBeforeTypeAnnotations + } + + t.tok.SetFinished() + return nil +} + +// StepOut steps out of a container. +func (t *textReader) StepOut() error { + if t.err != nil { + return t.err + } + + ctx := t.ctx.peek() + if ctx == ctxAtTopLevel { + return errors.New("ion: StepOut called at top level") + } + ctype := ctxToContainerType(ctx) + + // Finish off whatever value *inside* the container that we're currently reading. + _, err := t.tok.FinishValue() + if err != nil { + t.explode(err) + return err + } + + // If we haven't seen the end of the container yet, skip values until we find it. + if !t.eof { + if err := t.tok.SkipContainerContents(ctype); err != nil { t.explode(err) - return false + return err + } + } + + t.ctx.pop() + t.state = t.stateAfterValue() + t.valueType = NoType + t.value = nil + t.eof = false + + return nil +} + +// BoolValue returns the current value as a bool. +func (t *textReader) BoolValue() (bool, error) { + if t.valueType == BoolType { + if t.value == nil { + return false, nil } + return t.value.(bool), nil } + return false, errors.New("ion: value is not a bool") +} + +// IntSize returns the size of the current int value. +func (t *textReader) IntSize() (IntSize, error) { + if t.valueType != IntType { + return NullInt, errors.New("ion: value is not an int") + } + if t.value == nil { + return NullInt, nil + } + + if i, ok := t.value.(int64); ok { + if i > math.MaxInt32 || i < math.MinInt32 { + return Int64, nil + } + return Int32, nil + } + + return BigInt, nil +} + +// IntValue returns the current value as an int. +func (t *textReader) IntValue() (int, error) { + i, err := t.Int64Value() + if err != nil { + return 0, err + } + if i > math.MaxInt32 || i < math.MinInt32 { + return 0, errors.New("ion: int value out of bounds") + } + return int(i), nil +} + +// Int64Value returns the current value as an int64. +func (t *textReader) Int64Value() (int64, error) { + if t.valueType == IntType { + if t.value == nil { + return 0, nil + } + + if i, ok := t.value.(int64); ok { + return i, nil + } + + bi := t.value.(*big.Int) + if bi.IsInt64() { + return bi.Int64(), nil + } + + return 0, errors.New("ion: int value out of bounds") + } + return 0, errors.New("ion: value is not an int") +} + +// BigIntValue returns the current value as a big int. +func (t *textReader) BigIntValue() (*big.Int, error) { + if t.valueType == IntType { + if t.value == nil { + return nil, nil + } + if i, ok := t.value.(int64); ok { + return big.NewInt(i), nil + } + return t.value.(*big.Int), nil + } + return nil, errors.New("ion: value is not an int") +} + +// FloatValue returns the current value as a float. +func (t *textReader) FloatValue() (float64, error) { + if t.valueType == FloatType { + if t.value == nil { + return 0.0, nil + } + return t.value.(float64), nil + } + return 0.0, errors.New("ion: value is not a float") +} + +// DecimalValue returns the current value as a Decimal. +func (t *textReader) DecimalValue() (*Decimal, error) { + switch t.valueType { + case DecimalType: + if t.value == nil { + return nil, nil + } + return t.value.(*Decimal), nil + } + return nil, errors.New("ion: value is not a decimal") +} + +// TimeValue returns the current value as a time. +func (t *textReader) TimeValue() (time.Time, error) { + switch t.valueType { + case TimestampType: + if t.value == nil { + return time.Time{}, nil + } + return t.value.(time.Time), nil + } + return time.Time{}, errors.New("ion: value is not a timestamp") +} + +// StringValue returns the current value as a string. +func (t *textReader) StringValue() (string, error) { + switch t.valueType { + case StringType, SymbolType: + if t.value == nil { + return "", nil + } + return t.value.(string), nil + } + return "", errors.New("ion: value is not a string") +} + +// ByteValue returns the current value as a byte slice. +func (t *textReader) ByteValue() ([]byte, error) { + switch t.valueType { + case BlobType, ClobType: + if t.value == nil { + return nil, nil + } + return t.value.([]byte), nil + } + return nil, errors.New("ion: value is not a byte array") } // NextAfterValue moves to the next value when we're in the @@ -146,7 +366,7 @@ func (t *textReader) nextAfterValue() (bool, error) { tok := t.tok.Token() switch tok { case tokenComma: - // Another value coming; eat the comma and move to the + // There's another value coming; eat the comma and move to the // appropriate next state. switch t.ctx.peek() { case ctxInStruct: @@ -154,7 +374,7 @@ func (t *textReader) nextAfterValue() (bool, error) { case ctxInList: t.state = trsBeforeTypeAnnotations default: - panic(fmt.Sprintf("invalid state: %v", t.ctx.peek())) + panic(fmt.Sprintf("unexpected context: %v", t.ctx.peek())) } return false, nil @@ -164,7 +384,7 @@ func (t *textReader) nextAfterValue() (bool, error) { t.eof = true return true, nil } - return false, errors.New("unexpected token '}'") + return false, errors.New("ion: unexpected token '}'") case tokenCloseBracket: // No more values in this list. @@ -172,10 +392,10 @@ func (t *textReader) nextAfterValue() (bool, error) { t.eof = true return true, nil } - return false, errors.New("unexpected token ']'") + return false, errors.New("ion: unexpected token ']'") default: - return false, fmt.Errorf("unexpected token '%v'", tok) + return false, fmt.Errorf("ion: unexpected token '%v'", tok) } } @@ -206,7 +426,7 @@ func (t *textReader) nextBeforeFieldName() (bool, error) { return false, err } if tok = t.tok.Token(); tok != tokenColon { - return false, fmt.Errorf("unexpected token '%v'", tok) + return false, fmt.Errorf("ion: unexpected token '%v'", tok) } t.fieldName = val @@ -215,7 +435,7 @@ func (t *textReader) nextBeforeFieldName() (bool, error) { return false, nil default: - return false, fmt.Errorf("unexpected token '%v'", tok) + return false, fmt.Errorf("ion: unexpected token '%v'", tok) } } @@ -229,12 +449,12 @@ func (t *textReader) nextBeforeTypeAnnotations() (bool, error) { t.eof = true return true, nil } - return false, errors.New("unexpected EOF") + return false, errors.New("ion: unexpected EOF") case tokenSymbolOperator, tokenDot: if t.ctx.peek() != ctxInSexp { // Operators can only appear inside an sexp. - return false, fmt.Errorf("unexpected token '%v'", tok) + return false, fmt.Errorf("ion: unexpected token '%v'", tok) } fallthrough @@ -244,23 +464,19 @@ func (t *textReader) nextBeforeTypeAnnotations() (bool, error) { return false, err } - ws, err := t.tok.skipWhitespaceHelper() + ok, ws, err := t.tok.SkipDoubleColon() if err != nil { return false, err } - ok, err := t.tok.skipDoubleColon() - if err != nil { - return false, err - } if ok { - // val was a type annotation; remember it and keep going. + // val was an annotation; remember it and keep going. if tok == tokenSymbol { - if err := verifyUnquotedSymbol(val, "type annotation"); err != nil { + if err := verifyUnquotedSymbol(val, "annotation"); err != nil { return false, err } } - t.typeAnnotations = append(t.typeAnnotations, val) + t.annotations = append(t.annotations, val) return false, nil } @@ -317,21 +533,13 @@ func (t *textReader) nextBeforeTypeAnnotations() (bool, error) { t.value = SexpType return true, nil - case tokenCloseBrace: - // No more values in this struct. - if t.ctx.peek() == ctxInStruct { - t.eof = true - return true, nil - } - return false, errors.New("unexpected token '}'") - case tokenCloseBracket: // No more values in this list. if t.ctx.peek() == ctxInList { t.eof = true return true, nil } - return false, errors.New("unexpected token ']'") + return false, errors.New("ion: unexpected token ']'") case tokenCloseParen: // No more values in this sexp. @@ -339,22 +547,25 @@ func (t *textReader) nextBeforeTypeAnnotations() (bool, error) { t.eof = true return true, nil } - return false, errors.New("unexpected token ')'") + return false, errors.New("ion: unexpected token ')'") default: - return false, fmt.Errorf("unexpected token '%v'", tok) + return false, fmt.Errorf("ion: unexpected token '%v'", tok) } } +// VerifyUnquotedSymbol checks for certain 'special' values that are returned from +// the tokenizer as symbols but cannot be used as field names or annotations. func verifyUnquotedSymbol(val string, ctx string) error { switch val { case "null", "true", "false", "nan": - return fmt.Errorf("cannot use unquoted keyword %v as %v", val, ctx) + return fmt.Errorf("ion: cannot use unquoted keyword %v as %v", val, ctx) } return nil } -func (t *textReader) onSymbol(val string, tok tokenType, ws bool) error { +// OnSymbol handles finding a symbol-token value. +func (t *textReader) onSymbol(val string, tok token, ws bool) error { valueType := SymbolType var value interface{} = val @@ -389,9 +600,10 @@ func (t *textReader) onSymbol(val string, tok tokenType, ws bool) error { return nil } +// OnNull handles finding a null token. func (t *textReader) onNull(ws bool) (Type, error) { if !ws { - ok, err := t.tok.skipDot() + ok, err := t.tok.SkipDot() if err != nil { return NoType, err } @@ -399,16 +611,16 @@ func (t *textReader) onNull(ws bool) (Type, error) { return t.readNullType() } } - return NullType, nil } +// readNullType reads the null.{this} type symbol. func (t *textReader) readNullType() (Type, error) { if err := t.tok.Next(); err != nil { return NoType, err } if t.tok.Token() != tokenSymbol { - return NoType, fmt.Errorf("unexpected token %v after null", t.tok.Token()) + return NoType, fmt.Errorf("ion: invalid symbol null.%v", t.tok.Token()) } val, err := t.tok.ReadValue(tokenSymbol) @@ -444,11 +656,12 @@ func (t *textReader) readNullType() (Type, error) { case "sexp": return SexpType, nil default: - return NoType, fmt.Errorf("invalid symbol null.%v", val) + return NoType, fmt.Errorf("ion: invalid symbol null.%v", val) } } -func (t *textReader) onNumber(tok tokenType) error { +// OnNumber handles finding a number token. +func (t *textReader) onNumber(tok token) error { var valueType Type var value interface{} @@ -493,7 +706,7 @@ func (t *textReader) onNumber(tok tokenType) error { case DecimalType: value, err = parseDecimal(val) default: - panic("unexpected type") + panic(fmt.Sprintf("unexpected type %v", tt)) } if err != nil { @@ -509,7 +722,7 @@ func (t *textReader) onNumber(tok tokenType) error { value = math.Inf(-1) default: - panic("unexpected token type") + panic(fmt.Sprintf("unexpected token type %v", tok)) } t.state = t.stateAfterValue() @@ -519,6 +732,7 @@ func (t *textReader) onNumber(tok tokenType) error { return nil } +// OnTimestamp handles finding a timestamp token. func (t *textReader) onTimestamp() error { val, err := t.tok.ReadValue(tokenTimestamp) if err != nil { @@ -537,8 +751,9 @@ func (t *textReader) onTimestamp() error { return nil } +// OnLob handles finding a [bc]lob token. func (t *textReader) onLob() error { - c, _, err := t.tok.skipLobWhitespace() + c, err := t.tok.SkipLobWhitespace() if err != nil { return err } @@ -561,7 +776,7 @@ func (t *textReader) onLob() error { } else if c == '\'' { // Long clob. - ok, err := t.tok.isTripleQuote() + ok, err := t.tok.IsTripleQuote() if err != nil { return err } @@ -601,233 +816,9 @@ func (t *textReader) onLob() error { return nil } -func (t *textReader) Type() Type { - return t.valueType -} - -func (t *textReader) Err() error { - return t.err -} - -func (t *textReader) FieldName() string { - return t.fieldName -} - -func (t *textReader) TypeAnnotations() []string { - return t.typeAnnotations -} - -func (t *textReader) IsNull() bool { - return t.value == nil -} - -func (t *textReader) StepIn() error { - if t.err != nil { - return t.err - } - if t.state != trsBeforeContainer { - return fmt.Errorf("stepin called in invalid state %v", t.state) - } - - var ctx ctxType - switch t.valueType { - case StructType: - ctx = ctxInStruct - case ListType: - ctx = ctxInList - case SexpType: - ctx = ctxInSexp - default: - panic("trsBeforeContainer with unexpected valueType") - } - t.ctx.push(ctx) - - if ctx == ctxInStruct { - t.state = trsBeforeFieldName - } else { - t.state = trsBeforeTypeAnnotations - } - - // TODO: Make this less hacky. - t.tok.unfinished = false - return nil -} - -func (t *textReader) StepOut() error { - if t.err != nil { - return t.err - } - - ctx := t.ctx.peek() - if ctx == ctxAtTopLevel { - return errors.New("stepout called at top level") - } - - _, err := t.tok.finishValue() - if err != nil { - t.explode(err) - return err - } - - if !t.eof { - // Haven't seen the end of the container yet; skip until we - // find it. - switch t.ctx.peek() { - case ctxInStruct: - err = t.tok.skipStructHelper() - case ctxInList: - err = t.tok.skipListHelper() - case ctxInSexp: - err = t.tok.skipSexpHelper() - default: - panic("invalid ctx") - } - - if err != nil { - t.explode(err) - return err - } - } - - t.ctx.pop() - t.state = t.stateAfterValue() - t.valueType = NoType - t.value = nil - t.eof = false - - return nil -} - -func (t *textReader) BoolValue() (bool, error) { - if t.valueType == BoolType { - if t.value == nil { - return false, nil - } - return t.value.(bool), nil - } - return false, errors.New("value is not a bool") -} - -func (t *textReader) IntSize() IntSize { - if t.valueType != IntType || t.value == nil { - return NullInt - } - - if i, ok := t.value.(int64); ok { - if i > math.MaxInt32 || i < math.MinInt32 { - return Int64 - } - return Int32 - } - - return BigInt -} - -func (t *textReader) IntValue() (int, error) { - i, err := t.Int64Value() - if err != nil { - return 0, err - } - if i > math.MaxInt32 || i < math.MinInt32 { - return 0, errors.New("value out of bounds") - } - return int(i), nil -} - -func (t *textReader) Int64Value() (int64, error) { - if t.valueType == IntType { - if t.value == nil { - return 0, nil - } - - if i, ok := t.value.(int64); ok { - return i, nil - } - - bi := t.value.(*big.Int) - if bi.IsInt64() { - return bi.Int64(), nil - } - - return 0, errors.New("value out of bounds") - } - return 0, errors.New("value is not an int") -} - -func (t *textReader) BigIntValue() (*big.Int, error) { - if t.valueType == IntType { - if t.value == nil { - return nil, nil - } - if i, ok := t.value.(int64); ok { - return big.NewInt(i), nil - } - return t.value.(*big.Int), nil - } - return nil, errors.New("value is not an int") -} - -func (t *textReader) FloatValue() (float64, error) { - if t.valueType == FloatType { - if t.value == nil { - return 0.0, nil - } - return t.value.(float64), nil - } - // TODO: Cast ints/decimals? - return 0.0, errors.New("value is not a float") -} - -func (t *textReader) DecimalValue() (*Decimal, error) { - switch t.valueType { - case DecimalType: - if t.value == nil { - return nil, nil - } - return t.value.(*Decimal), nil - } - // TODO: Cast floats/ints? - return nil, errors.New("value is not a decimal") -} - -func (t *textReader) TimeValue() (time.Time, error) { - switch t.valueType { - case TimestampType: - if t.value == nil { - return time.Time{}, nil - } - return t.value.(time.Time), nil - } - return time.Time{}, errors.New("value is not a timestamp") -} - -func (t *textReader) StringValue() (string, error) { - switch t.valueType { - case StringType, SymbolType: - if t.value == nil { - return "", nil - } - return t.value.(string), nil - - default: - return "", errors.New("value is not a string") - } -} - -func (t *textReader) ByteValue() ([]byte, error) { - switch t.valueType { - case BlobType, ClobType: - if t.value == nil { - return nil, nil - } - return t.value.([]byte), nil - } - return nil, errors.New("value is not a byte array") -} - // FinishValue finishes reading the current value, if there is one. func (t *textReader) finishValue() error { - ok, err := t.tok.finishValue() + ok, err := t.tok.FinishValue() if err != nil { return err } @@ -839,14 +830,15 @@ func (t *textReader) finishValue() error { return nil } -func (t *textReader) stateAfterValue() textReaderState { - switch t.ctx.peek() { +func (t *textReader) stateAfterValue() trs { + ctx := t.ctx.peek() + switch ctx { case ctxInList, ctxInStruct: return trsAfterValue case ctxInSexp, ctxAtTopLevel: return trsBeforeTypeAnnotations default: - panic("invalid ctx") + panic(fmt.Sprintf("invalid ctx %v", ctx)) } } diff --git a/textreader_test.go b/textreader_test.go index 20ff4592..fec0108c 100644 --- a/textreader_test.go +++ b/textreader_test.go @@ -9,7 +9,7 @@ import ( ) func TestIgnoreValues(t *testing.T) { - r := NewTextReaderString("(skip ++ me / please) {skip: me, please: 0}\n[skip, me, please]\nfoo") + r := NewTextReaderStr("(skip ++ me / please) {skip: me, please: 0}\n[skip, me, please]\nfoo") _next(t, r, SexpType) _next(t, r, StructType) @@ -22,7 +22,7 @@ func TestIgnoreValues(t *testing.T) { func TestReadSexps(t *testing.T) { test := func(str string, f containerhandler) { t.Run(str, func(t *testing.T) { - r := NewTextReaderString(str) + r := NewTextReaderStr(str) _sexp(t, r, f) _eof(t, r) }) @@ -51,7 +51,7 @@ func TestReadSexps(t *testing.T) { func TestStructs(t *testing.T) { test := func(str string, f containerhandler) { t.Run(str, func(t *testing.T) { - r := NewTextReaderString(str) + r := NewTextReaderStr(str) _struct(t, r, f) _eof(t, r) }) @@ -73,7 +73,7 @@ func TestStructs(t *testing.T) { } func TestMultipleStructs(t *testing.T) { - r := NewTextReaderString("{} {} {}") + r := NewTextReaderStr("{} {} {}") for i := 0; i < 3; i++ { _struct(t, r, func(t *testing.T, r Reader) { @@ -85,7 +85,7 @@ func TestMultipleStructs(t *testing.T) { } func TestNullStructs(t *testing.T) { - r := NewTextReaderString("null.struct 'null'::{foo:bar}") + r := NewTextReaderStr("null.struct 'null'::{foo:bar}") _null(t, r, StructType) _nextAF(t, r, StructType, "", []string{"null"}) @@ -95,7 +95,7 @@ func TestNullStructs(t *testing.T) { func TestLists(t *testing.T) { test := func(str string, f containerhandler) { t.Run(str, func(t *testing.T) { - r := NewTextReaderString(str) + r := NewTextReaderStr(str) _list(t, r, f) _eof(t, r) }) @@ -123,7 +123,7 @@ func TestReadNestedLists(t *testing.T) { _eof(t, r) } - r := NewTextReaderString("[[], [[]]]") + r := NewTextReaderStr("[[], [[]]]") _list(t, r, func(t *testing.T, r Reader) { _list(t, r, empty) @@ -141,7 +141,7 @@ func TestReadNestedLists(t *testing.T) { func TestClobs(t *testing.T) { test := func(str string, eval []byte) { t.Run(str, func(t *testing.T) { - r := NewTextReaderString(str) + r := NewTextReaderStr(str) _next(t, r, ClobType) val, err := r.ByteValue() @@ -165,7 +165,7 @@ func TestClobs(t *testing.T) { func TestBlobs(t *testing.T) { test := func(str string, eval []byte) { t.Run(str, func(t *testing.T) { - r := NewTextReaderString(str) + r := NewTextReaderStr(str) _next(t, r, BlobType) val, err := r.ByteValue() @@ -188,7 +188,7 @@ func TestBlobs(t *testing.T) { func TestTimestamps(t *testing.T) { testA := func(str string, etas []string, eval time.Time) { t.Run(str, func(t *testing.T) { - r := NewTextReaderString(str) + r := NewTextReaderStr(str) _nextAF(t, r, TimestampType, "", etas) val, err := r.TimeValue() @@ -228,7 +228,7 @@ func TestDoubles(t *testing.T) { t.Run(str, func(t *testing.T) { ee := MustParseDecimal(eval) - r := NewTextReaderString(str) + r := NewTextReaderStr(str) _nextAF(t, r, DecimalType, "", etas) val, err := r.DecimalValue() @@ -260,7 +260,7 @@ func TestDoubles(t *testing.T) { func TestFloats(t *testing.T) { testA := func(str string, etas []string, eval float64) { t.Run(str, func(t *testing.T) { - r := NewTextReaderString(str) + r := NewTextReaderStr(str) _floatAF(t, r, "", etas, eval) _eof(t, r) }) @@ -282,7 +282,7 @@ func TestFloats(t *testing.T) { func TestInts(t *testing.T) { test := func(str string, f func(*testing.T, Reader)) { t.Run(str, func(t *testing.T) { - r := NewTextReaderString(str) + r := NewTextReaderStr(str) _next(t, r, IntType) f(t, r) @@ -352,7 +352,7 @@ func TestInts(t *testing.T) { } func TestStrings(t *testing.T) { - r := NewTextReaderString(`foo::"bar" "baz" 'a'::'b'::'''beep''' '''boop''' null.string`) + r := NewTextReaderStr(`foo::"bar" "baz" 'a'::'b'::'''beep''' '''boop''' null.string`) _stringAF(t, r, "", []string{"foo"}, "bar") _string(t, r, "baz") @@ -363,7 +363,7 @@ func TestStrings(t *testing.T) { } func TestSymbols(t *testing.T) { - r := NewTextReaderString("'null'::foo bar a::b::'baz' null.symbol") + r := NewTextReaderStr("'null'::foo bar a::b::'baz' null.symbol") _symbolAF(t, r, "", []string{"null"}, "foo") _symbol(t, r, "bar") @@ -374,7 +374,7 @@ func TestSymbols(t *testing.T) { } func TestSpecialSymbols(t *testing.T) { - r := NewTextReaderString("null\nnull.struct\ntrue\nfalse\nnan") + r := NewTextReaderStr("null\nnull.struct\ntrue\nfalse\nnan") _null(t, r, NullType) _null(t, r, StructType) @@ -386,7 +386,7 @@ func TestSpecialSymbols(t *testing.T) { } func TestOperators(t *testing.T) { - r := NewTextReaderString("(a*(b+c))") + r := NewTextReaderStr("(a*(b+c))") _sexp(t, r, func(t *testing.T, r Reader) { _symbol(t, r, "a") @@ -402,7 +402,7 @@ func TestOperators(t *testing.T) { } func TestTopLevelOperators(t *testing.T) { - r := NewTextReaderString("a + b") + r := NewTextReaderStr("a + b") _symbol(t, r, "a") @@ -414,6 +414,15 @@ func TestTopLevelOperators(t *testing.T) { } } +func TestTrsToString(t *testing.T) { + for i := trsDone; i <= trsAfterValue+1; i++ { + str := i.String() + if str == "" { + t.Errorf("expected a non-empty string for trs %v", uint8(i)) + } + } +} + type containerhandler func(t *testing.T, r Reader) func _sexp(t *testing.T, r Reader, f containerhandler) { @@ -564,8 +573,8 @@ func _nextAF(t *testing.T, r Reader, et Type, efn string, etas []string) { if efn != r.FieldName() { t.Errorf("expected fieldname=%v, got %v", efn, r.FieldName()) } - if !_strequals(etas, r.TypeAnnotations()) { - t.Errorf("expected type annotations=%v, got %v", etas, r.TypeAnnotations()) + if !_strequals(etas, r.Annotations()) { + t.Errorf("expected type annotations=%v, got %v", etas, r.Annotations()) } } diff --git a/textutils.go b/textutils.go index edb4682d..e57d4ad2 100644 --- a/textutils.go +++ b/textutils.go @@ -185,8 +185,6 @@ func fromHex(c int) (int, error) { return 0, invalidChar(c) } -var hexChars = []byte{'0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'A', 'B', 'C', 'D', 'E', 'F'} - // Write out the given character in escaped form. func writeEscapedChar(c byte, out io.Writer) error { switch c { diff --git a/textwriter.go b/textwriter.go index 0230537e..85a855b4 100644 --- a/textwriter.go +++ b/textwriter.go @@ -15,10 +15,10 @@ import ( type TextWriterOpts uint8 const ( - // OptQuietFinish disables emiting a newline in Finish(). Convenient if you know - // you're only emiting one datagram; dangerous if there's a chance you're going to - // emit another datagram using the same Writer. - OptQuietFinish TextWriterOpts = 1 + // TextWriterQuietFinish disables emiting a newline in Finish(). Convenient if you + // know you're only emiting one datagram; dangerous if there's a chance you're going + // to emit another datagram using the same Writer. + TextWriterQuietFinish TextWriterOpts = 1 ) // textWriter is a writer that writes human-readable text @@ -78,9 +78,9 @@ func (w *textWriter) beginValue() error { } } - if len(w.typeAnnotations) > 0 { - as := w.typeAnnotations - w.typeAnnotations = nil + if len(w.annotations) > 0 { + as := w.annotations + w.annotations = nil for _, a := range as { if err := writeSymbol(a, w.out); err != nil { @@ -101,7 +101,7 @@ func (w *textWriter) endValue() { } // begin starts writing a container of the given type. -func (w *textWriter) begin(t ctxType, c byte) error { +func (w *textWriter) begin(t ctx, c byte) error { if err := w.beginValue(); err != nil { return err } @@ -113,7 +113,7 @@ func (w *textWriter) begin(t ctxType, c byte) error { } // end finishes writing a container of the given type -func (w *textWriter) end(t ctxType, c byte) error { +func (w *textWriter) end(t ctx, c byte) error { if w.ctx.peek() != t { return errors.New("not in that kind of container") } @@ -123,7 +123,7 @@ func (w *textWriter) end(t ctxType, c byte) error { } w.fieldName = "" - w.typeAnnotations = nil + w.annotations = nil w.ctx.pop() w.endValue() @@ -409,14 +409,6 @@ func (w *textWriter) WriteClob(val []byte) { }) } -func (w *textWriter) WriteValue(val interface{}) { - m := Encoder{ - w: w, - sortMaps: true, - } - w.err = m.Encode(val) -} - // Finish finishes the current datagram. func (w *textWriter) Finish() error { if w.err != nil { @@ -427,14 +419,14 @@ func (w *textWriter) Finish() error { return w.err } - if w.opts&OptQuietFinish == 0 { + if w.opts&TextWriterQuietFinish == 0 { if w.err = writeRawChar('\n', w.out); w.err != nil { return w.err } } w.fieldName = "" - w.typeAnnotations = nil + w.annotations = nil w.needsSeparator = false return nil } diff --git a/textwriter_test.go b/textwriter_test.go index 9228d430..a3262b1d 100644 --- a/textwriter_test.go +++ b/textwriter_test.go @@ -50,9 +50,9 @@ func TestEmptyStruct(t *testing.T) { func TestAnnotatedStruct(t *testing.T) { testTextWriter(t, "foo::$bar::'.baz'::{}", func(w Writer) { - w.TypeAnnotation("foo") - w.TypeAnnotation("$bar") - w.TypeAnnotation(".baz") + w.Annotation("foo") + w.Annotation("$bar") + w.Annotation(".baz") w.BeginStruct() w.EndStruct() @@ -67,7 +67,7 @@ func TestNestedStruct(t *testing.T) { w.BeginStruct() w.FieldName("foo") - w.TypeAnnotation("true") + w.Annotation("true") w.BeginStruct() w.EndStruct() @@ -109,11 +109,11 @@ func TestNestedLists(t *testing.T) { w.BeginStruct() w.EndStruct() - w.TypeAnnotation("foo") + w.Annotation("foo") w.BeginStruct() w.EndStruct() - w.TypeAnnotation("null") + w.Annotation("null") w.BeginList() w.EndList() @@ -146,10 +146,10 @@ func TestNull(t *testing.T) { w.BeginList() w.WriteNull() - w.TypeAnnotation("foo") + w.Annotation("foo") w.WriteNullWithType(NullType) w.WriteNullWithType(IntType) - w.TypeAnnotation("bar") + w.Annotation("bar") w.WriteNullWithType(SexpType) w.EndList() @@ -164,12 +164,12 @@ func TestBool(t *testing.T) { w.BeginSexp() w.WriteBool(false) - w.TypeAnnotation("123") + w.Annotation("123") w.WriteBool(true) w.EndSexp() - w.TypeAnnotation("false") + w.Annotation("false") w.WriteBool(false) }) } @@ -179,7 +179,7 @@ func TestWriteTextInt(t *testing.T) { testTextWriter(t, expected, func(w Writer) { w.BeginSexp() - w.TypeAnnotation("zero") + w.Annotation("zero") w.WriteInt(0) w.WriteInt(1) w.WriteInt(-1) @@ -205,7 +205,7 @@ func TestWriteTextBigInt(t *testing.T) { one.SetInt64(1) val.Add(&max, &one) - w.TypeAnnotation("big") + w.Annotation("big") w.WriteBigInt(&val) w.EndList() @@ -267,9 +267,9 @@ func TestSymbol(t *testing.T) { w.WriteSymbol("null") w.FieldName("f") - w.TypeAnnotation("a") - w.TypeAnnotation("b") - w.TypeAnnotation("u") + w.Annotation("a") + w.Annotation("b") + w.Annotation("u") w.WriteSymbol("lo🇺🇸") w.EndStruct() @@ -285,7 +285,7 @@ func TestString(t *testing.T) { w.BeginSexp() w.WriteString("\\\"\n\"\\") - w.TypeAnnotation("zany") + w.Annotation("zany") w.WriteString("🤪") w.EndSexp() @@ -298,7 +298,7 @@ func TestBlob(t *testing.T) { testTextWriter(t, expected, func(w Writer) { w.WriteBlob([]byte{0, 1, 2, 0xFD, 0xFE, 0xFF}) w.WriteBlob([]byte("Hello World")) - w.TypeAnnotation("empty") + w.Annotation("empty") w.WriteBlob(nil) }) } @@ -315,19 +315,6 @@ func TestClob(t *testing.T) { }) } -func TestWriteValue(t *testing.T) { - expected := "{s:{B:2,A:1}}" - testTextWriter(t, expected, func(w Writer) { - w.BeginStruct() - w.FieldName("s") - w.WriteValue(struct { - B int - A int - }{2, 1}) - w.EndStruct() - }) -} - func TestFinish(t *testing.T) { expected := "1\nfoo\n\"bar\"\n{}\n" testTextWriter(t, expected, func(w Writer) { diff --git a/tokenizer.go b/tokenizer.go index 7d5310e7..1831478a 100644 --- a/tokenizer.go +++ b/tokenizer.go @@ -9,10 +9,10 @@ import ( "strings" ) -type tokenType int +type token int const ( - tokenError tokenType = iota + tokenError token = iota tokenEOF // End of input @@ -45,7 +45,7 @@ const ( tokenCloseDoubleBrace // }} ) -func (t tokenType) String() string { +func (t token) String() string { switch t { case tokenError: return "" @@ -113,7 +113,7 @@ type tokenizer struct { in *bufio.Reader buffer []int - token tokenType + token token unfinished bool } @@ -132,7 +132,7 @@ func tokenize(in io.Reader) *tokenizer { } // Token returns the type of the current token. -func (t *tokenizer) Token() tokenType { +func (t *tokenizer) Token() token { return t.token } @@ -153,7 +153,7 @@ func (t *tokenizer) Next() error { switch { case c == -1: - return t.finish(tokenEOF, true) + return t.ok(tokenEOF, true) case c == ':': c2, err := t.peek() @@ -162,9 +162,9 @@ func (t *tokenizer) Next() error { } if c2 == ':' { t.read() - return t.finish(tokenDoubleColon, false) + return t.ok(tokenDoubleColon, false) } - return t.finish(tokenColon, false) + return t.ok(tokenColon, false) case c == '{': c2, err := t.peek() @@ -173,27 +173,27 @@ func (t *tokenizer) Next() error { } if c2 == '{' { t.read() - return t.finish(tokenOpenDoubleBrace, true) + return t.ok(tokenOpenDoubleBrace, true) } - return t.finish(tokenOpenBrace, true) + return t.ok(tokenOpenBrace, true) case c == '}': - return t.finish(tokenCloseBrace, false) + return t.ok(tokenCloseBrace, false) case c == '[': - return t.finish(tokenOpenBracket, true) + return t.ok(tokenOpenBracket, true) case c == ']': - return t.finish(tokenCloseBracket, false) + return t.ok(tokenCloseBracket, false) case c == '(': - return t.finish(tokenOpenParen, true) + return t.ok(tokenOpenParen, true) case c == ')': - return t.finish(tokenCloseParen, false) + return t.ok(tokenCloseParen, false) case c == ',': - return t.finish(tokenComma, false) + return t.ok(tokenComma, false) case c == '.': c2, err := t.peek() @@ -202,19 +202,19 @@ func (t *tokenizer) Next() error { } if isOperatorChar(c2) { t.unread(c) - return t.finish(tokenSymbolOperator, true) + return t.ok(tokenSymbolOperator, true) } - return t.finish(tokenDot, false) + return t.ok(tokenDot, false) case c == '\'': - ok, err := t.isTripleQuote() + ok, err := t.IsTripleQuote() if err != nil { return err } if ok { - return t.finish(tokenLongString, true) + return t.ok(tokenLongString, true) } - return t.finish(tokenSymbolQuoted, true) + return t.ok(tokenSymbolQuoted, true) case c == '+': ok, err := t.isInf(c) @@ -222,10 +222,10 @@ func (t *tokenizer) Next() error { return err } if ok { - return t.finish(tokenFloatInf, false) + return t.ok(tokenFloatInf, false) } t.unread(c) - return t.finish(tokenSymbolOperator, true) + return t.ok(tokenSymbolOperator, true) case c == '-': c2, err := t.peek() @@ -245,7 +245,7 @@ func (t *tokenizer) Next() error { } t.unread(c2) t.unread(c) - return t.finish(tt, true) + return t.ok(tt, true) } ok, err := t.isInf(c) @@ -253,22 +253,22 @@ func (t *tokenizer) Next() error { return err } if ok { - return t.finish(tokenFloatMinusInf, false) + return t.ok(tokenFloatMinusInf, false) } t.unread(c) - return t.finish(tokenSymbolOperator, true) + return t.ok(tokenSymbolOperator, true) case isOperatorChar(c): t.unread(c) - return t.finish(tokenSymbolOperator, true) + return t.ok(tokenSymbolOperator, true) case c == '"': - return t.finish(tokenString, true) + return t.ok(tokenString, true) case isIdentifierStart(c): t.unread(c) - return t.finish(tokenSymbol, true) + return t.ok(tokenSymbol, true) case isDigit(c): tt, err := t.scanForNumericType(c) @@ -277,21 +277,45 @@ func (t *tokenizer) Next() error { } t.unread(c) - return t.finish(tt, true) + return t.ok(tt, true) default: return invalidChar(c) } } -func (t *tokenizer) finish(token tokenType, more bool) error { - t.token = token +func (t *tokenizer) ok(tok token, more bool) error { + t.token = tok t.unfinished = more return nil } +// SetFinished marks the current token finished (indicating that the caller has +// chosen to step in to a list, sexp, or struct and Next should not skip over its +// contents in search of the next token). +func (t *tokenizer) SetFinished() { + t.unfinished = false +} + +// FinishValue skips to the end of the current value if (and only if) +// we're currently in the middle of reading it. +func (t *tokenizer) FinishValue() (bool, error) { + if !t.unfinished { + return false, nil + } + + c, err := t.skipValue() + if err != nil { + return true, err + } + + t.unread(c) + t.unfinished = false + return true, nil +} + // ReadValue reads the value of a token of the given type. -func (t *tokenizer) ReadValue(tok tokenType) (string, error) { +func (t *tokenizer) ReadValue(tok token) (string, error) { var str string var err error @@ -313,7 +337,7 @@ func (t *tokenizer) ReadValue(tok tokenType) (string, error) { case tokenTimestamp: str, err = t.readTimestamp() default: - panic("unsupported token type") + panic(fmt.Sprintf("unsupported token type %v", tok)) } if err != nil { @@ -981,7 +1005,7 @@ func (t *tokenizer) ReadLongClob() (string, error) { } // IsTripleQuote returns true if this is a triple-quote sequence ('''). -func (t *tokenizer) isTripleQuote() (bool, error) { +func (t *tokenizer) IsTripleQuote() (bool, error) { // We've just read a '\'', check if the next two are too. cs, err := t.peekN(2) if err == io.EOF { @@ -1036,7 +1060,7 @@ func (t *tokenizer) isInf(c int) (bool, error) { // out binary (0b...), hex (0x...), and timestamps (....-) via this // method. There are a couple other cases where we *could* distinguish, // but it's unclear that it's worth it. -func (t *tokenizer) scanForNumericType(c int) (tokenType, error) { +func (t *tokenizer) scanForNumericType(c int) (token, error) { if !isDigit(c) { panic("scanForNumericType with non-digit") } diff --git a/tokenizer_test.go b/tokenizer_test.go index 7354b162..f4af1584 100644 --- a/tokenizer_test.go +++ b/tokenizer_test.go @@ -8,7 +8,7 @@ import ( func TestNext(t *testing.T) { tok := tokenizeString("foo::'foo':[] 123, {})") - next := func(tt tokenType) { + next := func(tt token) { if err := tok.Next(); err != nil { t.Fatal(err) } @@ -28,7 +28,7 @@ func TestNext(t *testing.T) { } func TestReadSymbol(t *testing.T) { - test := func(str string, expected string, next tokenType) { + test := func(str string, expected string, next token) { t.Run(str, func(t *testing.T) { tok := tokenizeString(str) if err := tok.Next(); err != nil { @@ -182,7 +182,7 @@ func TestIsTripleQuote(t *testing.T) { t.Run(str, func(t *testing.T) { tok := tokenizeString(str) - ok, err := tok.isTripleQuote() + ok, err := tok.IsTripleQuote() if err != nil { t.Fatal(err) } @@ -254,7 +254,7 @@ func TestIsInf(t *testing.T) { } func TestScanForNumericType(t *testing.T) { - test := func(str string, ett tokenType) { + test := func(str string, ett token) { t.Run(str, func(t *testing.T) { tok := tokenizeString(str) c, err := tok.read() @@ -551,6 +551,15 @@ func TestReadUnread(t *testing.T) { read(t, tok, -1) } +func TestTokenToString(t *testing.T) { + for i := tokenError; i <= tokenCloseDoubleBrace+1; i++ { + str := i.String() + if str == "" { + t.Errorf("expected non-empty string for token %v", int(i)) + } + } +} + func read(t *testing.T, tok *tokenizer, expected int) { c, err := tok.read() if err != nil { diff --git a/unmarshal.go b/unmarshal.go index cda3a8a1..6b018fa1 100644 --- a/unmarshal.go +++ b/unmarshal.go @@ -8,7 +8,6 @@ import ( "reflect" "strconv" "strings" - "time" ) var ( @@ -22,6 +21,11 @@ func Unmarshal(data []byte, v interface{}) error { return NewDecoder(NewTextReader(bytes.NewReader(data))).DecodeTo(v) } +// UnmarshalStr unmarshals Ion data from a string to the given object. +func UnmarshalStr(data string, v interface{}) error { + return Unmarshal([]byte(data), v) +} + // A Decoder decodes go values from an Ion reader. type Decoder struct { r Reader @@ -88,7 +92,14 @@ func (d *Decoder) decode() (interface{}, error) { } func (d *Decoder) decodeInt() (interface{}, error) { - switch d.r.IntSize() { + size, err := d.r.IntSize() + if err != nil { + return nil, err + } + + switch size { + case NullInt: + return nil, nil case Int32: return d.r.IntValue() case Int64: @@ -353,8 +364,6 @@ func (d *Decoder) decodeDecimalTo(v reflect.Value) error { return fmt.Errorf("ion: cannot decode decimal to %v", v.Type().String()) } -var timeType = reflect.TypeOf(time.Time{}) - func (d *Decoder) decodeTimestampTo(v reflect.Value) error { val, err := d.r.TimeValue() if err != nil { diff --git a/unmarshal_test.go b/unmarshal_test.go index 26bd7dd8..763d8256 100644 --- a/unmarshal_test.go +++ b/unmarshal_test.go @@ -9,13 +9,11 @@ import ( "time" ) -func TestDecodeBool(t *testing.T) { +func TestUnmarshalBool(t *testing.T) { test := func(str string, eval bool) { t.Run(str, func(t *testing.T) { - d := NewDecoder(NewTextReaderString(str)) - var val bool - err := d.DecodeTo(&val) + err := UnmarshalStr(str, &val) if err != nil { t.Fatal(err) } @@ -30,14 +28,12 @@ func TestDecodeBool(t *testing.T) { test("true", true) test("false", false) } -func TestDecodeBoolPtr(t *testing.T) { +func TestUnmarshalBoolPtr(t *testing.T) { test := func(str string, eval interface{}) { t.Run(str, func(t *testing.T) { - d := NewDecoder(NewTextReaderString(str)) - var bval bool val := &bval - err := d.DecodeTo(&val) + err := UnmarshalStr(str, &val) if err != nil { t.Fatal(err) } @@ -63,13 +59,11 @@ func TestDecodeBoolPtr(t *testing.T) { test("true", true) } -func TestDecodeInt(t *testing.T) { +func TestUnmarshalInt(t *testing.T) { testInt8 := func(str string, eval int8) { t.Run(str, func(t *testing.T) { - d := NewDecoder(NewTextReaderString(str)) - var val int8 - err := d.DecodeTo(&val) + err := UnmarshalStr(str, &val) if err != nil { t.Fatal(err) } @@ -86,10 +80,8 @@ func TestDecodeInt(t *testing.T) { testInt16 := func(str string, eval int16) { t.Run(str, func(t *testing.T) { - d := NewDecoder(NewTextReaderString(str)) - var val int16 - err := d.DecodeTo(&val) + err := UnmarshalStr(str, &val) if err != nil { t.Fatal(err) } @@ -106,10 +98,8 @@ func TestDecodeInt(t *testing.T) { testInt32 := func(str string, eval int32) { t.Run(str, func(t *testing.T) { - d := NewDecoder(NewTextReaderString(str)) - var val int32 - err := d.DecodeTo(&val) + err := UnmarshalStr(str, &val) if err != nil { t.Fatal(err) } @@ -126,10 +116,8 @@ func TestDecodeInt(t *testing.T) { testInt := func(str string, eval int) { t.Run(str, func(t *testing.T) { - d := NewDecoder(NewTextReaderString(str)) - var val int - err := d.DecodeTo(&val) + err := UnmarshalStr(str, &val) if err != nil { t.Fatal(err) } @@ -146,10 +134,8 @@ func TestDecodeInt(t *testing.T) { testInt64 := func(str string, eval int64) { t.Run(str, func(t *testing.T) { - d := NewDecoder(NewTextReaderString(str)) - var val int64 - err := d.DecodeTo(&val) + err := UnmarshalStr(str, &val) if err != nil { t.Fatal(err) } @@ -165,13 +151,11 @@ func TestDecodeInt(t *testing.T) { testInt64("-0x8000000000000000", -0x8000000000000000) } -func TestDecodeUint(t *testing.T) { +func TestUnmarshalUint(t *testing.T) { testUint8 := func(str string, eval uint8) { t.Run(str, func(t *testing.T) { - d := NewDecoder(NewTextReaderString(str)) - var val uint8 - err := d.DecodeTo(&val) + err := UnmarshalStr(str, &val) if err != nil { t.Fatal(err) } @@ -187,10 +171,8 @@ func TestDecodeUint(t *testing.T) { testUint16 := func(str string, eval uint16) { t.Run(str, func(t *testing.T) { - d := NewDecoder(NewTextReaderString(str)) - var val uint16 - err := d.DecodeTo(&val) + err := UnmarshalStr(str, &val) if err != nil { t.Fatal(err) } @@ -205,10 +187,8 @@ func TestDecodeUint(t *testing.T) { testUint32 := func(str string, eval uint32) { t.Run(str, func(t *testing.T) { - d := NewDecoder(NewTextReaderString(str)) - var val uint32 - err := d.DecodeTo(&val) + err := UnmarshalStr(str, &val) if err != nil { t.Fatal(err) } @@ -223,10 +203,8 @@ func TestDecodeUint(t *testing.T) { testUint := func(str string, eval uint) { t.Run(str, func(t *testing.T) { - d := NewDecoder(NewTextReaderString(str)) - var val uint - err := d.DecodeTo(&val) + err := UnmarshalStr(str, &val) if err != nil { t.Fatal(err) } @@ -241,10 +219,8 @@ func TestDecodeUint(t *testing.T) { testUintptr := func(str string, eval uintptr) { t.Run(str, func(t *testing.T) { - d := NewDecoder(NewTextReaderString(str)) - var val uintptr - err := d.DecodeTo(&val) + err := UnmarshalStr(str, &val) if err != nil { t.Fatal(err) } @@ -259,10 +235,8 @@ func TestDecodeUint(t *testing.T) { testUint64 := func(str string, eval uint64) { t.Run(str, func(t *testing.T) { - d := NewDecoder(NewTextReaderString(str)) - var val uint64 - err := d.DecodeTo(&val) + err := UnmarshalStr(str, &val) if err != nil { t.Fatal(err) } @@ -276,13 +250,11 @@ func TestDecodeUint(t *testing.T) { testUint64("0xFFFFFFFFFFFFFFFF", 0xFFFFFFFFFFFFFFFF) } -func TestDecodeBigInt(t *testing.T) { +func TestUnmarshalBigInt(t *testing.T) { test := func(str string, eval *big.Int) { t.Run(str, func(t *testing.T) { - d := NewDecoder(NewTextReaderString(str)) - var val big.Int - err := d.DecodeTo(&val) + err := UnmarshalStr(str, &val) if err != nil { t.Fatal(err) } @@ -300,7 +272,7 @@ func TestDecodeBigInt(t *testing.T) { func TestDecodeFloat(t *testing.T) { test32 := func(str string, eval float32) { t.Run(str, func(t *testing.T) { - d := NewDecoder(NewTextReaderString(str)) + d := NewDecoder(NewTextReaderStr(str)) var val float32 err := d.DecodeTo(&val) @@ -320,7 +292,7 @@ func TestDecodeFloat(t *testing.T) { test64 := func(str string, eval float64) { t.Run(str, func(t *testing.T) { - d := NewDecoder(NewTextReaderString(str)) + d := NewDecoder(NewTextReaderStr(str)) var val float64 err := d.DecodeTo(&val) @@ -341,7 +313,7 @@ func TestDecodeFloat(t *testing.T) { func TestDecodeDecimal(t *testing.T) { test := func(str string, eval *Decimal) { t.Run(str, func(t *testing.T) { - d := NewDecoder(NewTextReaderString(str)) + d := NewDecoder(NewTextReaderStr(str)) var val *Decimal err := d.DecodeTo(&val) @@ -362,7 +334,7 @@ func TestDecodeDecimal(t *testing.T) { func TestDecodeTimeTo(t *testing.T) { test := func(str string, eval time.Time) { t.Run(str, func(t *testing.T) { - d := NewDecoder(NewTextReaderString(str)) + d := NewDecoder(NewTextReaderStr(str)) var val time.Time err := d.DecodeTo(&val) @@ -382,7 +354,7 @@ func TestDecodeTimeTo(t *testing.T) { func TestDecodeStringTo(t *testing.T) { test := func(str string, eval string) { t.Run(str, func(t *testing.T) { - d := NewDecoder(NewTextReaderString(str)) + d := NewDecoder(NewTextReaderStr(str)) var val string err := d.DecodeTo(&val) @@ -404,7 +376,7 @@ func TestDecodeStringTo(t *testing.T) { func TestDecodeLobTo(t *testing.T) { testSlice := func(str string, eval []byte) { t.Run(str, func(t *testing.T) { - d := NewDecoder(NewTextReaderString(str)) + d := NewDecoder(NewTextReaderStr(str)) var val []byte err := d.DecodeTo(&val) @@ -424,7 +396,7 @@ func TestDecodeLobTo(t *testing.T) { testArray := func(str string, eval []byte) { t.Run(str, func(t *testing.T) { - d := NewDecoder(NewTextReaderString(str)) + d := NewDecoder(NewTextReaderStr(str)) var val [8]byte err := d.DecodeTo(&val) @@ -444,7 +416,7 @@ func TestDecodeLobTo(t *testing.T) { func TestDecodeStructTo(t *testing.T) { test := func(str string, val, eval interface{}) { t.Run(str, func(t *testing.T) { - d := NewDecoder(NewTextReaderString(str)) + d := NewDecoder(NewTextReaderStr(str)) err := d.DecodeTo(val) if err != nil { t.Fatal(err) @@ -474,7 +446,7 @@ func TestDecodeStructTo(t *testing.T) { func TestDecodeListTo(t *testing.T) { test := func(str string, val, eval interface{}) { t.Run(str, func(t *testing.T) { - d := NewDecoder(NewTextReaderString(str)) + d := NewDecoder(NewTextReaderStr(str)) err := d.DecodeTo(val) if err != nil { t.Fatal(err) @@ -506,7 +478,7 @@ func TestDecodeListTo(t *testing.T) { func TestDecode(t *testing.T) { test := func(data string, eval interface{}) { t.Run(data, func(t *testing.T) { - d := NewDecoder(NewTextReaderString(data)) + d := NewDecoder(NewTextReaderStr(data)) val, err := d.Decode() if err != nil { t.Fatal(err) diff --git a/writer.go b/writer.go index 6f43a640..974f0861 100644 --- a/writer.go +++ b/writer.go @@ -8,11 +8,11 @@ import ( // writer holds shared stuff for all writers. type writer struct { out io.Writer - ctx ctx + ctx ctxstack err error - fieldName string - typeAnnotations []string + fieldName string + annotations []string } // InStruct returns true if we're currently writing a struct. @@ -48,19 +48,18 @@ func (w *writer) FieldName(val string) { w.fieldName = val } -// TypeAnnotation adds a type annotation to the next value written. -func (w *writer) TypeAnnotation(val string) { +// Annotation adds an annotation to the next value written. +func (w *writer) Annotation(val string) { if w.err != nil { return } - w.typeAnnotations = append(w.typeAnnotations, val) + w.annotations = append(w.annotations, val) } -// TypeAnnotations adds one or more type annotations to the next value -// written. -func (w *writer) TypeAnnotations(val ...string) { +// Annotations adds one or more annotations to the next value written. +func (w *writer) Annotations(val ...string) { if w.err != nil { return } - w.typeAnnotations = append(w.typeAnnotations, val...) + w.annotations = append(w.annotations, val...) } From 5abeffa4c6cbdc6c37ebeac5d7b83222131335be Mon Sep 17 00:00:00 2001 From: David Murray Date: Wed, 14 Aug 2019 17:00:40 +1000 Subject: [PATCH 31/56] start of binaryreader --- binaryreader.go | 117 +++++++++++ bitstream.go | 488 ++++++++++++++++++++++++++++++++++++++++++++++ bitstream_test.go | 102 ++++++++++ fields.go | 6 +- 4 files changed, 710 insertions(+), 3 deletions(-) create mode 100644 binaryreader.go create mode 100644 bitstream.go create mode 100644 bitstream_test.go diff --git a/binaryreader.go b/binaryreader.go new file mode 100644 index 00000000..93e89f60 --- /dev/null +++ b/binaryreader.go @@ -0,0 +1,117 @@ +package ion + +import "fmt" + +type binaryReader struct { + bits bitstream + ctx ctxstack + eof bool + err error + + lst SymbolTable + fieldName string + annotations []string + valueType Type + value interface{} +} + +func (r *binaryReader) SymbolTable() SymbolTable { + return r.lst +} + +func (r *binaryReader) Next() bool { + if r.eof || r.err != nil { + return false + } + + done := false + for !done { + done, r.err = r.next() + if r.err != nil { + return false + } + } + + return !r.eof +} + +func (r *binaryReader) next() (bool, error) { + if err := r.bits.Next(); err != nil { + return false, err + } + + if r.bits.Code() == bitcodeFieldID { + if err := r.readFieldName(); err != nil { + return false, err + } + } + + if r.bits.Code() == bitcodeAnnotation { + if err := r.readAnnotations(); err != nil { + return false, err + } + } + + switch r.bits.Code() { + case bitcodeEOF: + r.eof = true + return true, nil + + case bitcodeBVM: + if err := r.readBVM(); err != nil { + return false, err + } + return false, nil + + } + panic(fmt.Sprintf("unsupported bitcode %v", r.bits.Code())) +} + +func (r *binaryReader) readBVM() error { + major, minor, err := r.bits.ReadBVM() + if err != nil { + return err + } + + if major != 1 && minor != 0 { + return fmt.Errorf("ion: unsupported version %v.%v", major, minor) + } + + r.lst = V1SystemSymbolTable + return nil +} + +func (r *binaryReader) readFieldName() error { + id, err := r.bits.ReadFieldID() + if err != nil { + return err + } + + r.fieldName = r.resolve(id) + + return r.bits.Next() +} + +func (r *binaryReader) readAnnotations() error { + ids, err := r.bits.ReadAnnotations() + if err != nil { + return err + } + + as := make([]string, len(ids)) + for i, id := range ids { + as[i] = r.resolve(id) + } + + r.annotations = as + + return r.bits.Next() +} + +func (r *binaryReader) resolve(id uint64) string { + s, ok := r.lst.FindByID(int(id)) + if !ok { + return fmt.Sprintf("$%v", id) + } + return s +} diff --git a/bitstream.go b/bitstream.go new file mode 100644 index 00000000..d3bcbc19 --- /dev/null +++ b/bitstream.go @@ -0,0 +1,488 @@ +package ion + +import ( + "bufio" + "bytes" + "errors" + "fmt" + "io" +) + +type bss uint8 + +const ( + bssBeforeValue bss = iota + bssOnValue + bssBeforeFieldID + bssOnFieldID +) + +type bitcode uint8 + +const ( + bitcodeNone bitcode = iota + bitcodeEOF + bitcodeBVM + bitcodeNull + bitcodeBool + bitcodeInt + bitcodeNegInt + bitcodeFloat + bitcodeDecimal + bitcodeTimestamp + bitcodeSymbol + bitcodeString + bitcodeClob + bitcodeBlob + bitcodeList + bitcodeSexp + bitcodeStruct + bitcodeFieldID + bitcodeAnnotation +) + +func (b bitcode) String() string { + switch b { + case bitcodeNone: + return "none" + case bitcodeEOF: + return "eof" + case bitcodeBVM: + return "bvm" + case bitcodeBool: + return "bool" + case bitcodeInt: + return "int" + case bitcodeNegInt: + return "negint" + case bitcodeFloat: + return "float" + case bitcodeDecimal: + return "decimal" + case bitcodeTimestamp: + return "timestamp" + case bitcodeSymbol: + return "symbol" + case bitcodeString: + return "string" + case bitcodeClob: + return "clob" + case bitcodeBlob: + return "blob" + case bitcodeList: + return "list" + case bitcodeSexp: + return "sexp" + case bitcodeStruct: + return "struct" + case bitcodeFieldID: + return "fieldid" + case bitcodeAnnotation: + return "annotation" + default: + return fmt.Sprintf("", uint8(b)) + } +} + +type bitnode struct { + code bitcode + end uint64 +} + +type bitstack struct { + arr []bitnode +} + +func (b *bitstack) empty() bool { + return len(b.arr) == 0 +} + +func (b *bitstack) peek() bitnode { + if len(b.arr) == 0 { + return bitnode{} + } + return b.arr[len(b.arr)-1] +} + +func (b *bitstack) push(code bitcode, end uint64) { + b.arr = append(b.arr, bitnode{code, end}) +} + +func (b *bitstack) pop() { + if len(b.arr) == 0 { + panic("pop called on empty bitstack") + } + b.arr = b.arr[:len(b.arr)-1] +} + +type bitstream struct { + in *bufio.Reader + pos uint64 + state bss + stack bitstack + + code bitcode + null bool + len uint64 +} + +func (b *bitstream) Init(in io.Reader) { + b.in = bufio.NewReader(in) +} + +func (b *bitstream) InitBytes(in []byte) { + b.Init(bytes.NewReader(in)) +} + +func (b *bitstream) Code() bitcode { + return b.code +} + +func (b *bitstream) Null() bool { + return b.null +} + +func (b *bitstream) Len() uint64 { + return b.len +} + +func (b *bitstream) Next() error { + // If we have an unread value, skip over it to the next one. + switch b.state { + case bssOnValue, bssOnFieldID: + if err := b.SkipValue(); err != nil { + return err + } + } + + // If we're at the end of the current container, stop and make the user step out. + if !b.stack.empty() { + cur := b.stack.peek() + if b.pos == cur.end { + b.code = bitcodeEOF + return nil + } + } + + // If it's time to read a field id, do that. + if b.state == bssBeforeFieldID { + b.code = bitcodeFieldID + b.state = bssOnFieldID + return nil + } + + // Otherwise it's time to read a value. Read the tag byte. + c, err := b.read() + if err != nil { + return err + } + + // Found the actual end of the file. + if c == -1 { + b.code = bitcodeEOF + return nil + } + + code, len := parseTag(c) + if code == bitcodeNone { + return fmt.Errorf("ion: invalid tag byte: 0x%X", c) + } + + b.state = bssOnValue + + // This value is actually a BVM. It's invalid if we're not at the top level. + if code == bitcodeAnnotation && len == 0 { + if !b.stack.empty() { + return errors.New("ion: BVM in a container") + } + b.code = bitcodeBVM + b.len = 3 + return nil + } + + // This value is actually a null. + if len == 0x0F { + b.code = code + b.null = true + return nil + } + + // This value's actual len is encoded as a separate varUint. + if len == 0x0E { + len, err = b.readVarUint() + if err != nil { + return err + } + } + + b.code = code + b.len = len + return nil +} + +func (b *bitstream) SkipValue() error { + switch b.state { + case bssBeforeFieldID, bssBeforeValue: + return nil + + case bssOnFieldID: + if err := b.skipVarUint(); err != nil { + return err + } + b.state = bssBeforeValue + + case bssOnValue: + if b.len > 0 { + if err := b.skip(b.len); err != nil { + return err + } + + if b.stack.peek().code == bitcodeStruct { + b.state = bssBeforeFieldID + } else { + b.state = bssBeforeValue + } + } + } + + b.code = bitcodeNone + b.null = false + b.len = 0 + + return nil +} + +func (b *bitstream) StepIn() { + switch b.code { + case bitcodeStruct: + b.state = bssBeforeFieldID + + case bitcodeList, bitcodeSexp: + b.state = bssBeforeValue + + default: + panic(fmt.Sprintf("called StepIn with code=%v", b.code)) + } + + b.stack.push(b.code, b.pos+b.len) + b.code = bitcodeNone + b.len = 0 +} + +func (b *bitstream) StepOut() error { + if b.stack.empty() { + panic("called StepOut at top level") + } + + cur := b.stack.peek() + b.stack.pop() + + if cur.end < b.pos { + panic("end greater than b.pos") + } + + diff := cur.end - b.pos + if err := b.skip(diff); err != nil { + return err + } + + if b.stack.peek().code == bitcodeStruct { + b.state = bssBeforeFieldID + } else { + b.state = bssBeforeValue + } + + b.code = bitcodeNone + b.null = false + b.len = 0 + + return nil +} + +func (b *bitstream) ReadBVM() (byte, byte, error) { + if b.code != bitcodeBVM { + return 0, 0, errors.New("ion: not a bvm") + } + + major, err := b.read() + if err != nil { + return 0, 0, err + } + if major == -1 { + return 0, 0, errors.New("ion: unexpected end of input") + } + + minor, err := b.read() + if err != nil { + return 0, 0, err + } + if minor == -1 { + return 0, 0, errors.New("ion: unexpected end of input") + } + + end, err := b.read() + if err != nil { + return 0, 0, err + } + if end == -1 { + return 0, 0, errors.New("ion: unexpected end of input") + } + + if end != 0xEA { + return 0, 0, fmt.Errorf("ion: invalid BVM (0xE0 0x%X 0x%X 0x%X)", major, minor, end) + } + + b.state = bssBeforeValue + b.code = bitcodeNone + b.len = 0 + + return byte(major), byte(minor), nil +} + +func (b *bitstream) ReadAnnotations() ([]uint64, error) { + if b.code != bitcodeAnnotation { + return nil, errors.New("ion: not an annotation") + } + + alen, lenlen, err := b.readVarUintLen(b.len) + if err != nil { + return nil, err + } + + if b.len-lenlen <= alen { + // The size of the annotation is larger than the remaining free space inside the + // annotation container. + return nil, errors.New("ion: malformed annotation") + } + + as := []uint64{} + for alen > 0 { + id, idlen, err := b.readVarUintLen(alen) + if err != nil { + return nil, err + } + + as = append(as, id) + alen -= idlen + } + + b.state = bssBeforeValue + b.code = bitcodeNone + b.len = 0 + + return as, nil +} + +func (b *bitstream) ReadFieldID() (uint64, error) { + if b.code != bitcodeFieldID { + return 0, errors.New("ion: not a field id") + } + + id, err := b.readVarUint() + if err != nil { + return 0, err + } + + b.state = bssBeforeValue + b.code = bitcodeNone + + return id, nil +} + +func (b *bitstream) readVarUint() (uint64, error) { + r, _, err := b.readVarUintLen(10) + return r, err +} + +func (b *bitstream) readVarUintLen(max uint64) (uint64, uint64, error) { + r := uint64(0) + l := uint64(0) + + for { + c, err := b.read() + if err != nil { + return 0, 0, err + } + if c == -1 { + return 0, 0, errors.New("ion: unexpected end of input") + } + + l++ + + r = (r << 7) ^ uint64(c&0x7F) + if c&0x80 != 0 { + return r, l, nil + } + + if l == max { + return 0, 0, errors.New("ion: varuint too large") + } + } +} + +func (b *bitstream) skipVarUint() error { + for { + c, err := b.read() + if err != nil { + return err + } + if c == -1 { + return errors.New("ion: unexpected end of input") + } + if c&0x80 != 0 { + return nil + } + } +} + +var bitcodes = []bitcode{ + bitcodeNull, // 0x00 + bitcodeBool, // 0x10 + bitcodeInt, // 0x20 + bitcodeNegInt, // 0x30 + bitcodeFloat, // 0x40 + bitcodeDecimal, // 0x50 + bitcodeTimestamp, // 0x60 + bitcodeSymbol, // 0x70 + bitcodeString, // 0x80 + bitcodeClob, // 0x90 + bitcodeBlob, // 0xA0 + bitcodeList, // 0xB0 + bitcodeSexp, // 0xC0 + bitcodeStruct, // 0xD0 + bitcodeAnnotation, // 0xE0 +} + +func parseTag(c int) (bitcode, uint64) { + high := (c >> 4) & 0x0F + low := c & 0x0F + + code := bitcodeNone + if high < len(bitcodes) { + code = bitcodes[high] + } + + return code, uint64(low) +} + +func (b *bitstream) read() (int, error) { + c, err := b.in.ReadByte() + if err == io.EOF { + return -1, nil + } + if err != nil { + return 0, err + } + + b.pos++ + return int(c), nil +} + +func (b *bitstream) skip(n uint64) error { + _, err := b.in.Discard(int(n)) + if err == io.EOF { + return nil + } + b.pos += n + return err +} diff --git a/bitstream_test.go b/bitstream_test.go new file mode 100644 index 00000000..5e6f7231 --- /dev/null +++ b/bitstream_test.go @@ -0,0 +1,102 @@ +package ion + +import "testing" + +func TestBitstream(t *testing.T) { + ion := []byte{ + 0xE0, 0x01, 0x00, 0xEA, // $ion_1_0 + 0xEE, 0x9F, 0x81, 0x83, 0xDE, 0x9B, // $ion_symbol_table::{ + 0x86, 0xBE, 0x8E, // imports:[ + 0xDD, // { + 0x84, 0x85, 'b', 'o', 'g', 'u', 's', // name: "bogus" + 0x85, 0x21, 0x2A, // version: 42 + 0x88, 0x21, 0x64, // max_id: 100 + // }] + 0x87, 0xB8, // symbols: [ + 0x83, 'f', 'o', 'o', // "foo" + 0x83, 'b', 'a', 'r', // "bar" + // ] + // } + 0xD0, // {} + 0xEA, 0x81, 0xEE, 0xD7, // foo::{ + 0x84, 0xE3, 0x81, 0xEF, 0x0F, // name:bar::null, + 0x88, 0x20, // max_id:0 + // } + } + + b := bitstream{} + b.InitBytes(ion) + + next := func(code bitcode, null bool, len uint64) { + if err := b.Next(); err != nil { + t.Fatal(err) + } + if b.Code() != code { + t.Errorf("expected code=%v, got %v", code, b.Code()) + } + if b.Null() != null { + t.Errorf("expected null=%v, got %v", null, b.Null()) + } + if b.Len() != len { + t.Errorf("expected len=%v, got %v", len, b.Len()) + } + } + + fieldid := func(eid uint64) { + id, err := b.ReadFieldID() + if err != nil { + t.Fatal(err) + } + if id != eid { + t.Errorf("expected %v, got %v", eid, id) + } + } + + next(bitcodeBVM, false, 3) + maj, min, err := b.ReadBVM() + if err != nil { + t.Fatal(err) + } + if maj != 1 && min != 0 { + t.Errorf("expected $ion_1.0, got $ion_%v.%v", maj, min) + } + + next(bitcodeAnnotation, false, 31) + ids, err := b.ReadAnnotations() + if err != nil { + t.Fatal(err) + } + if len(ids) != 1 || ids[0] != 3 { // $ion_symbol_table + t.Errorf("expected [3], got %v", ids) + } + + next(bitcodeStruct, false, 27) + b.StepIn() + { + next(bitcodeFieldID, false, 0) + fieldid(6) // imports + + next(bitcodeList, false, 14) + b.StepIn() + { + next(bitcodeStruct, false, 13) + } + if err := b.StepOut(); err != nil { + t.Fatal(err) + } + + next(bitcodeFieldID, false, 0) + // fieldid(7) // symbols + + next(bitcodeList, false, 8) + next(bitcodeEOF, false, 0) + } + if err := b.StepOut(); err != nil { + t.Fatal(err) + } + + next(bitcodeStruct, false, 0) + next(bitcodeAnnotation, false, 10) + next(bitcodeEOF, false, 0) + next(bitcodeEOF, false, 0) +} diff --git a/fields.go b/fields.go index 061113d9..2f4a8f06 100644 --- a/fields.go +++ b/fields.go @@ -42,7 +42,7 @@ func (f *fielder) inspect(t reflect.Type, path []int) { // Skip fields that are explicitly hidden by tag. continue } - name, opts := parseTag(tag) + name, opts := parseJSONTag(tag) newpath := make([]int, len(path)+1) copy(newpath, path) @@ -93,8 +93,8 @@ func visible(sf *reflect.StructField) bool { return exported } -// ParseTag parses a `json:"..."` field tag, returning the name and opts. -func parseTag(tag string) (string, string) { +// ParseJSONTag parses a `json:"..."` field tag, returning the name and opts. +func parseJSONTag(tag string) (string, string) { if idx := strings.Index(tag, ","); idx != -1 { // Ignore additional JSON options, at least for now. return tag[:idx], tag[idx+1:] From 5030a04c7f58c3b08e151c7f9ba871d9ff5c8ae3 Mon Sep 17 00:00:00 2001 From: David Murray Date: Mon, 19 Aug 2019 17:18:52 +1000 Subject: [PATCH 32/56] more binaryreader --- binaryreader.go | 298 +++++++++++++++++++++++++++++++++---- binaryreader_test.go | 123 ++++++++++++++++ bitstream.go | 343 +++++++++++++++++++++++++++++++++++-------- catalog.go | 20 +++ decimal.go | 5 + reader.go | 184 +++++++++++++++++++++++ symboltable.go | 93 +++++++++--- textreader.go | 293 ++++++++---------------------------- textreader_test.go | 103 ++++++++++++- 9 files changed, 1123 insertions(+), 339 deletions(-) create mode 100644 binaryreader_test.go create mode 100644 catalog.go create mode 100644 reader.go diff --git a/binaryreader.go b/binaryreader.go index 93e89f60..266db139 100644 --- a/binaryreader.go +++ b/binaryreader.go @@ -1,18 +1,35 @@ package ion -import "fmt" +import ( + "errors" + "fmt" + "io" +) type binaryReader struct { bits bitstream - ctx ctxstack - eof bool - err error + cat *Catalog + lst SymbolTable - lst SymbolTable - fieldName string - annotations []string - valueType Type - value interface{} + reader +} + +// NewBinaryReader creates a new binary reader. +func NewBinaryReader(in io.Reader, cat *Catalog) Reader { + r := &binaryReader{ + cat: cat, + } + r.bits.Init(in) + return r +} + +// NewBinaryReaderBytes creates a new binary reader for the given bytes. +func NewBinaryReaderBytes(in []byte, cat *Catalog) Reader { + r := &binaryReader{ + cat: cat, + } + r.bits.InitBytes(in) + return r } func (r *binaryReader) SymbolTable() SymbolTable { @@ -24,6 +41,8 @@ func (r *binaryReader) Next() bool { return false } + r.clear() + done := false for !done { done, r.err = r.next() @@ -40,31 +59,97 @@ func (r *binaryReader) next() (bool, error) { return false, err } - if r.bits.Code() == bitcodeFieldID { - if err := r.readFieldName(); err != nil { - return false, err - } - } - - if r.bits.Code() == bitcodeAnnotation { - if err := r.readAnnotations(); err != nil { - return false, err - } - } - switch r.bits.Code() { case bitcodeEOF: r.eof = true return true, nil case bitcodeBVM: - if err := r.readBVM(); err != nil { + err := r.readBVM() + return false, err + + case bitcodeFieldID: + err := r.readFieldName() + return false, err + + case bitcodeAnnotation: + err := r.readAnnotations() + return false, err + + case bitcodeFalse, bitcodeTrue: + r.valueType = BoolType + if !r.bits.Null() { + r.value = (r.bits.Code() == bitcodeTrue) + } + return true, nil + + case bitcodeInt, bitcodeNegInt: + r.valueType = IntType + if !r.bits.Null() { + val, err := r.bits.ReadInt() + if err != nil { + return false, err + } + r.value = val + } + return true, nil + + case bitcodeFloat: + r.valueType = FloatType + if !r.bits.Null() { + val, err := r.bits.ReadFloat() + if err != nil { + return false, err + } + r.value = val + } + return true, nil + + case bitcodeDecimal: + r.valueType = DecimalType + if !r.bits.Null() { + val, err := r.bits.ReadDecimal() + if err != nil { + return false, err + } + r.value = val + } + return true, nil + + case bitcodeString: + r.valueType = StringType + if !r.bits.Null() { + val, err := r.bits.ReadString() + if err != nil { + return false, err + } + r.value = val + } + return true, nil + + case bitcodeList: + r.valueType = ListType + if !r.bits.Null() { + r.value = ListType + } + return true, nil + + case bitcodeStruct: + r.valueType = StructType + if !r.bits.Null() { + r.value = StructType + } + + if len(r.annotations) > 0 && r.annotations[0] == "$ion_symbol_table" { + err := r.readLocalSymbolTable() return false, err } - return false, nil + return true, nil + + default: + panic(fmt.Sprintf("unsupported bitcode %v", r.bits.Code())) } - panic(fmt.Sprintf("unsupported bitcode %v", r.bits.Code())) } func (r *binaryReader) readBVM() error { @@ -88,8 +173,7 @@ func (r *binaryReader) readFieldName() error { } r.fieldName = r.resolve(id) - - return r.bits.Next() + return nil } func (r *binaryReader) readAnnotations() error { @@ -104,8 +188,7 @@ func (r *binaryReader) readAnnotations() error { } r.annotations = as - - return r.bits.Next() + return nil } func (r *binaryReader) resolve(id uint64) string { @@ -115,3 +198,162 @@ func (r *binaryReader) resolve(id uint64) string { } return s } + +func (r *binaryReader) readLocalSymbolTable() error { + if err := r.StepIn(); err != nil { + return err + } + + imps := []SharedSymbolTable{} + syms := []string{} + + for r.Next() { + var err error + switch r.FieldName() { + case "imports": + imps, err = r.readImports() + case "symbols": + syms, err = r.readSymbols() + } + if err != nil { + return err + } + } + + if err := r.StepOut(); err != nil { + return err + } + + r.lst = NewLocalSymbolTable(imps, syms) + return nil +} + +func (r *binaryReader) readImports() ([]SharedSymbolTable, error) { + if r.Type() != ListType { + return nil, nil + } + if err := r.StepIn(); err != nil { + return nil, err + } + + imps := []SharedSymbolTable{} + for r.Next() { + imp, err := r.readImport() + if err != nil { + return nil, err + } + imps = append(imps, imp) + } + + err := r.StepOut() + return imps, err +} + +func (r *binaryReader) readImport() (SharedSymbolTable, error) { + if r.Type() != StructType { + return nil, nil + } + if err := r.StepIn(); err != nil { + return nil, err + } + + name := "" + version := 0 + maxID := 0 + + for r.Next() { + var err error + switch r.FieldName() { + case "name": + name, err = r.StringValue() + case "version": + version, err = r.IntValue() + case "max_id": + maxID, err = r.IntValue() + } + if err != nil { + return nil, err + } + } + + if err := r.StepOut(); err != nil { + return nil, err + } + + if name == "" || version == 0 || maxID == 0 { + return nil, errors.New("ion: invalid import in local symbol table") + } + + var imp SharedSymbolTable + if r.cat != nil { + imp = r.cat.Find(name, version) + if imp != nil && imp.MaxID() != maxID { + // TODO: Better error. + return nil, errors.New("ion: maxID mismatch in imported symbol table") + } + } + + if imp == nil { + imp = &bogusSST{ + name: name, + version: version, + maxID: maxID, + } + } + + return imp, nil +} + +func (r *binaryReader) readSymbols() ([]string, error) { + if r.Type() != ListType { + return nil, nil + } + if err := r.StepIn(); err != nil { + return nil, err + } + + syms := []string{} + for r.Next() { + if r.Type() == StringType { + sym, err := r.StringValue() + if err != nil { + return nil, err + } + syms = append(syms, sym) + } + } + + err := r.StepOut() + + return syms, err +} + +func (r *binaryReader) StepIn() error { + if r.err != nil { + return r.err + } + switch r.valueType { + case ListType, SexpType, StructType: + default: + return errors.New("ion: StepIn called when not on a container") + } + + ctx := containerTypeToCtx(r.valueType) + r.ctx.push(ctx) + + r.clear() + r.bits.StepIn() + + return nil +} + +func (r *binaryReader) StepOut() error { + if err := r.bits.StepOut(); err != nil { + return err + } + + r.clear() + r.eof = false + + return nil +} diff --git a/binaryreader_test.go b/binaryreader_test.go new file mode 100644 index 00000000..e7310ba8 --- /dev/null +++ b/binaryreader_test.go @@ -0,0 +1,123 @@ +package ion + +import ( + "math" + "math/big" + "testing" +) + +func TestReadBinaryStructs(t *testing.T) { + r := readBinary([]byte{ + 0xD0, // {} + 0xDF, // null.struct + 0xEA, 0x81, 0xEE, 0xD7, // foo::{ + 0x84, 0xE3, 0x81, 0xEF, 0x0F, // name:bar::null, + 0x88, 0x20, // max_id:0 + // } + }) + + _next(t, r, StructType) + _null(t, r, StructType) + _nextAF(t, r, StructType, "", []string{"foo"}) + _eof(t, r) +} + +func TestReadBinaryDecimals(t *testing.T) { + r := readBinary([]byte{ + 0x50, // 0. + 0x5F, // null.decimal + 0x51, 0xC3, // 0.000, aka 0 x 10^-3 + 0x53, 0xC3, 0x03, 0xE8, // 1.000, aka 1000 x 10^-3 + 0x53, 0xC3, 0x83, 0xE8, // -1.000, aka -1000 x 10^-3 + 0x53, 0x00, 0xE4, 0x01, // 1d100, aka 1 * 10^100 + 0x53, 0x00, 0xE4, 0x81, // -1d100, aka -1 * 10^100 + }) + + _decimal(t, r, MustParseDecimal("0.")) + _null(t, r, DecimalType) + _decimal(t, r, MustParseDecimal("0.000")) + _decimal(t, r, MustParseDecimal("1.000")) + _decimal(t, r, MustParseDecimal("-1.000")) + _decimal(t, r, MustParseDecimal("1d100")) + _decimal(t, r, MustParseDecimal("-1d100")) + _eof(t, r) +} + +func TestReadBinaryFloats(t *testing.T) { + r := readBinary([]byte{ + 0x40, // 0 + 0x4F, // null.float + 0x48, 0x7F, 0xEF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, // MaxFloat64 + 0x48, 0xFF, 0xEF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, // -MaxFloat64 + 0x48, 0x7F, 0xF0, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // +inf + 0x48, 0xFF, 0xF0, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // -inf + 0x48, 0x7F, 0xF8, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, // NaN + }) + + _float(t, r, 0) + _null(t, r, FloatType) + _float(t, r, math.MaxFloat64) + _float(t, r, -math.MaxFloat64) + _float(t, r, math.Inf(1)) + _float(t, r, math.Inf(-1)) + _float(t, r, math.NaN()) + _eof(t, r) +} + +func TestReadBinaryInts(t *testing.T) { + r := readBinary([]byte{ + 0x20, // 0 + 0x2F, // null.int + 0x21, 0x01, // 1 + 0x31, 0x01, // -1 + 0x28, 0x7F, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, // 0x7FFFFFFFFFFFFFFF + 0x38, 0x7F, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, // -0x7FFFFFFFFFFFFFFF + 0x28, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // 0x8000000000000000 + 0x38, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // 0x8000000000000000 + }) + + _int(t, r, 0) + _null(t, r, IntType) + _int(t, r, 1) + _int(t, r, -1) + _int64(t, r, math.MaxInt64) + _int64(t, r, -math.MaxInt64) + + i := new(big.Int).SetUint64(math.MaxInt64 + 1) + _bigInt(t, r, i) + _bigInt(t, r, new(big.Int).Neg(i)) + + _eof(t, r) +} + +func TestReadBinaryBools(t *testing.T) { + r := readBinary([]byte{ + 0x10, // false + 0x11, // true + 0x1F, // null.bool + }) + + _bool(t, r, false) + _bool(t, r, true) + _null(t, r, BoolType) + _eof(t, r) +} + +func readBinary(ion []byte) Reader { + prefix := []byte{ + 0xE0, 0x01, 0x00, 0xEA, // $ion_1_0 + 0xEE, 0x9F, 0x81, 0x83, 0xDE, 0x9B, // $ion_symbol_table::{ + 0x86, 0xBE, 0x8E, // imports:[ + 0xDD, // { + 0x84, 0x85, 'b', 'o', 'g', 'u', 's', // name: "bogus" + 0x85, 0x21, 0x2A, // version: 42 + 0x88, 0x21, 0x64, // max_id: 100 + // }] + 0x87, 0xB8, // symbols: [ + 0x83, 'f', 'o', 'o', // "foo" + 0x83, 'b', 'a', 'r', // "bar" + // ] + // } + } + return NewBinaryReaderBytes(append(prefix, ion...), nil) +} diff --git a/bitstream.go b/bitstream.go index d3bcbc19..a9f4cf3b 100644 --- a/bitstream.go +++ b/bitstream.go @@ -3,9 +3,12 @@ package ion import ( "bufio" "bytes" + "encoding/binary" "errors" "fmt" "io" + "math" + "math/big" ) type bss uint8 @@ -24,7 +27,8 @@ const ( bitcodeEOF bitcodeBVM bitcodeNull - bitcodeBool + bitcodeFalse + bitcodeTrue bitcodeInt bitcodeNegInt bitcodeFloat @@ -49,8 +53,10 @@ func (b bitcode) String() string { return "eof" case bitcodeBVM: return "bvm" - case bitcodeBool: - return "bool" + case bitcodeFalse: + return "false" + case bitcodeTrue: + return "true" case bitcodeInt: return "int" case bitcodeNegInt: @@ -84,37 +90,6 @@ func (b bitcode) String() string { } } -type bitnode struct { - code bitcode - end uint64 -} - -type bitstack struct { - arr []bitnode -} - -func (b *bitstack) empty() bool { - return len(b.arr) == 0 -} - -func (b *bitstack) peek() bitnode { - if len(b.arr) == 0 { - return bitnode{} - } - return b.arr[len(b.arr)-1] -} - -func (b *bitstack) push(code bitcode, end uint64) { - b.arr = append(b.arr, bitnode{code, end}) -} - -func (b *bitstack) pop() { - if len(b.arr) == 0 { - panic("pop called on empty bitstack") - } - b.arr = b.arr[:len(b.arr)-1] -} - type bitstream struct { in *bufio.Reader pos uint64 @@ -190,18 +165,39 @@ func (b *bitstream) Next() error { b.state = bssOnValue - // This value is actually a BVM. It's invalid if we're not at the top level. - if code == bitcodeAnnotation && len == 0 { - if !b.stack.empty() { - return errors.New("ion: BVM in a container") + if code == bitcodeAnnotation { + switch len { + case 0: + // This value is actually a BVM. It's invalid if we're not at the top level. + if !b.stack.empty() { + return errors.New("ion: BVM in a container") + } + b.code = bitcodeBVM + b.len = 3 + return nil + + case 0x0F: + // No such thing as a null annotation. + return fmt.Errorf("ion: invalid tag byte: 0x%X", c) + } + } + + // Booleans are a bit special. + if code == bitcodeFalse { + switch len { + case 0, 0x0F: + break + case 1: + code = bitcodeTrue + len = 0 + default: + // Other forms of bool are invalid. + return fmt.Errorf("ion: invalid tag byte: 0x%X", c) } - b.code = bitcodeBVM - b.len = 3 - return nil } - // This value is actually a null. if len == 0x0F { + // This value is actually a null. b.code = code b.null = true return nil @@ -236,13 +232,8 @@ func (b *bitstream) SkipValue() error { if err := b.skip(b.len); err != nil { return err } - - if b.stack.peek().code == bitcodeStruct { - b.state = bssBeforeFieldID - } else { - b.state = bssBeforeValue - } } + b.state = b.stateAfterValue() } b.code = bitcodeNone @@ -282,16 +273,13 @@ func (b *bitstream) StepOut() error { } diff := cur.end - b.pos - if err := b.skip(diff); err != nil { - return err - } - - if b.stack.peek().code == bitcodeStruct { - b.state = bssBeforeFieldID - } else { - b.state = bssBeforeValue + if diff > 0 { + if err := b.skip(diff); err != nil { + return err + } } + b.state = b.stateAfterValue() b.code = bitcodeNone b.null = false b.len = 0 @@ -389,6 +377,144 @@ func (b *bitstream) ReadFieldID() (uint64, error) { return id, nil } +func (b *bitstream) ReadInt() (interface{}, error) { + switch b.code { + case bitcodeInt, bitcodeNegInt: + default: + return "", errors.New("ion: not an integer") + } + + bs, err := b.readN(b.len) + if err != nil { + return "", err + } + + var ret interface{} + switch { + case len(bs) == 0: + // Special case for zero. + ret = int64(0) + + case len(bs) < 8, (len(bs) == 8 && bs[0]&0x80 == 0): + // It'll fit in an int64. + i := int64(0) + for _, b := range bs { + i <<= 8 + i |= int64(b) + } + if b.code == bitcodeNegInt { + i = -i + } + ret = i + + default: + // Need to go big.Int. + i := new(big.Int).SetBytes(bs) + if b.code == bitcodeNegInt { + i = i.Neg(i) + } + ret = i + } + + b.state = b.stateAfterValue() + b.code = bitcodeNone + b.len = 0 + + return ret, nil +} + +func (b *bitstream) ReadFloat() (float64, error) { + if b.code != bitcodeFloat { + return 0, errors.New("ion: not a float") + } + + bs, err := b.readN(b.len) + if err != nil { + return 0, err + } + + var ret float64 + switch len(bs) { + case 0: + ret = 0 + + case 4: + ui := binary.BigEndian.Uint32(bs) + ret = float64(math.Float32frombits(ui)) + + case 8: + ui := binary.BigEndian.Uint64(bs) + ret = math.Float64frombits(ui) + + default: + return 0, errors.New("ion: invalid float size") + } + + b.state = b.stateAfterValue() + b.code = bitcodeNone + b.len = 0 + + return ret, nil +} + +func (b *bitstream) ReadDecimal() (*Decimal, error) { + if b.code != bitcodeDecimal { + return nil, errors.New("ion: not a decimal") + } + if b.len == 0 { + return NewDecimalInt(0), nil + } + + exp, explen, err := b.readVarIntLen(b.len) + if err != nil { + return nil, err + } + + coef := new(big.Int) + + coeflen := b.len - explen + if coeflen > 0 { + bs, err := b.readN(coeflen) + if err != nil { + return nil, err + } + + neg := (bs[0]&0x80 != 0) + bs[0] &= 0x7F + if bs[0] == 0 { + bs = bs[1:] + } + + coef.SetBytes(bs) + if neg { + coef.Neg(coef) + } + } + + b.state = b.stateAfterValue() + b.code = bitcodeNone + b.len = 0 + + return NewDecimal(coef, int(exp)), nil +} + +func (b *bitstream) ReadString() (string, error) { + if b.code != bitcodeString { + return "", errors.New("ion: not a string") + } + + bs, err := b.readN(b.len) + if err != nil { + return "", err + } + + b.state = b.stateAfterValue() + b.code = bitcodeNone + b.len = 0 + + return string(bs), nil +} + func (b *bitstream) readVarUint() (uint64, error) { r, _, err := b.readVarUintLen(10) return r, err @@ -407,9 +533,10 @@ func (b *bitstream) readVarUintLen(max uint64) (uint64, uint64, error) { return 0, 0, errors.New("ion: unexpected end of input") } + r <<= 7 + r ^= uint64(c & 0x7F) l++ - r = (r << 7) ^ uint64(c&0x7F) if c&0x80 != 0 { return r, l, nil } @@ -420,6 +547,50 @@ func (b *bitstream) readVarUintLen(max uint64) (uint64, uint64, error) { } } +func (b *bitstream) readVarIntLen(max uint64) (int64, uint64, error) { + c, err := b.read() + if err != nil { + return 0, 0, err + } + if c == -1 { + return 0, 0, errors.New("ion: unexpected end of input") + } + + sign := int64(1) + if c&0x40 != 0 { + sign = -1 + } + + r := int64(c & 0x3F) + l := uint64(1) + + if c&0x80 != 0 { + return r * sign, l, nil + } + + for { + c, err := b.read() + if err != nil { + return 0, 0, err + } + if c == -1 { + return 0, 0, errors.New("ion: unexpected end of input") + } + + r <<= 7 + r ^= int64(c & 0x7F) + l++ + + if c&0x80 != 0 { + return r * sign, l, nil + } + + if l == max { + return 0, 0, errors.New("ion: varint too large") + } + } +} + func (b *bitstream) skipVarUint() error { for { c, err := b.read() @@ -435,9 +606,16 @@ func (b *bitstream) skipVarUint() error { } } +func (b *bitstream) stateAfterValue() bss { + if b.stack.peek().code == bitcodeStruct { + return bssBeforeFieldID + } + return bssBeforeValue +} + var bitcodes = []bitcode{ bitcodeNull, // 0x00 - bitcodeBool, // 0x10 + bitcodeFalse, // 0x10 bitcodeInt, // 0x20 bitcodeNegInt, // 0x30 bitcodeFloat, // 0x40 @@ -465,6 +643,24 @@ func parseTag(c int) (bitcode, uint64) { return code, uint64(low) } +func (b *bitstream) readN(n uint64) ([]byte, error) { + if n == 0 { + return nil, nil + } + + bs := make([]byte, n) + _, err := b.in.Read(bs) + if err == io.EOF { + return nil, errors.New("ion: unexpected end of input") + } + if err != nil { + return nil, err + } + + b.pos += n + return bs, nil +} + func (b *bitstream) read() (int, error) { c, err := b.in.ReadByte() if err == io.EOF { @@ -486,3 +682,34 @@ func (b *bitstream) skip(n uint64) error { b.pos += n return err } + +type bitnode struct { + code bitcode + end uint64 +} + +type bitstack struct { + arr []bitnode +} + +func (b *bitstack) empty() bool { + return len(b.arr) == 0 +} + +func (b *bitstack) peek() bitnode { + if len(b.arr) == 0 { + return bitnode{} + } + return b.arr[len(b.arr)-1] +} + +func (b *bitstack) push(code bitcode, end uint64) { + b.arr = append(b.arr, bitnode{code, end}) +} + +func (b *bitstack) pop() { + if len(b.arr) == 0 { + panic("pop called on empty bitstack") + } + b.arr = b.arr[:len(b.arr)-1] +} diff --git a/catalog.go b/catalog.go new file mode 100644 index 00000000..315d0926 --- /dev/null +++ b/catalog.go @@ -0,0 +1,20 @@ +package ion + +import "fmt" + +// A Catalog stores shared symbol tables. +type Catalog struct { + ssts map[string]SharedSymbolTable +} + +// Add adds a shared symbol table to the catalog. +func (c *Catalog) Add(sst SharedSymbolTable) { + key := fmt.Sprintf("%v/%v", sst.Name(), sst.Version()) + c.ssts[key] = sst +} + +// Find attempts to find a shared symbol table with the given name and version. +func (c *Catalog) Find(name string, version int) SharedSymbolTable { + key := fmt.Sprintf("%v/%v", name, version) + return c.ssts[key] +} diff --git a/decimal.go b/decimal.go index 3f6fc68e..2499a403 100644 --- a/decimal.go +++ b/decimal.go @@ -25,6 +25,11 @@ func NewDecimal(n *big.Int, exp int) *Decimal { } } +// NewDecimalInt creates a new decimal whose value is equal to n. +func NewDecimalInt(n int64) *Decimal { + return NewDecimal(big.NewInt(n), 0) +} + // MustParseDecimal parses the given string into a decimal object, // panicing on error. func MustParseDecimal(in string) *Decimal { diff --git a/reader.go b/reader.go new file mode 100644 index 00000000..252c13a9 --- /dev/null +++ b/reader.go @@ -0,0 +1,184 @@ +package ion + +import ( + "errors" + "math" + "math/big" + "time" +) + +// A reader holds common implementation stuff to both the text and binary readers. +type reader struct { + ctx ctxstack + eof bool + err error + + fieldName string + annotations []string + valueType Type + value interface{} +} + +// Err returns the current error. +func (r *reader) Err() error { + return r.err +} + +// Type returns the current value's type. +func (r *reader) Type() Type { + return r.valueType +} + +// IsNull returns true if the current value is null. +func (r *reader) IsNull() bool { + return r.valueType != NoType && r.value == nil +} + +// FieldName returns the current value's field name. +func (r *reader) FieldName() string { + return r.fieldName +} + +// Annotations returns the current value's annotations. +func (r *reader) Annotations() []string { + return r.annotations +} + +// BoolValue returns the current value as a bool. +func (r *reader) BoolValue() (bool, error) { + if r.valueType == BoolType { + if r.value == nil { + return false, nil + } + return r.value.(bool), nil + } + return false, errors.New("ion: value is not a bool") +} + +// IntSize returns the size of the current int value. +func (r *reader) IntSize() (IntSize, error) { + if r.valueType != IntType { + return NullInt, errors.New("ion: value is not an int") + } + if r.value == nil { + return NullInt, nil + } + + if i, ok := r.value.(int64); ok { + if i > math.MaxInt32 || i < math.MinInt32 { + return Int64, nil + } + return Int32, nil + } + + return BigInt, nil +} + +// IntValue returns the current value as an int. +func (r *reader) IntValue() (int, error) { + i, err := r.Int64Value() + if err != nil { + return 0, err + } + if i > math.MaxInt32 || i < math.MinInt32 { + return 0, errors.New("ion: int value out of bounds") + } + return int(i), nil +} + +// Int64Value returns the current value as an int64. +func (r *reader) Int64Value() (int64, error) { + if r.valueType == IntType { + if r.value == nil { + return 0, nil + } + + if i, ok := r.value.(int64); ok { + return i, nil + } + + bi := r.value.(*big.Int) + if bi.IsInt64() { + return bi.Int64(), nil + } + + return 0, errors.New("ion: int value out of bounds") + } + return 0, errors.New("ion: value is not an int") +} + +// BigIntValue returns the current value as a big int. +func (r *reader) BigIntValue() (*big.Int, error) { + if r.valueType == IntType { + if r.value == nil { + return nil, nil + } + if i, ok := r.value.(int64); ok { + return big.NewInt(i), nil + } + return r.value.(*big.Int), nil + } + return nil, errors.New("ion: value is not an int") +} + +// FloatValue returns the current value as a float. +func (r *reader) FloatValue() (float64, error) { + if r.valueType == FloatType { + if r.value == nil { + return 0.0, nil + } + return r.value.(float64), nil + } + return 0.0, errors.New("ion: value is not a float") +} + +// DecimalValue returns the current value as a Decimal. +func (r *reader) DecimalValue() (*Decimal, error) { + if r.valueType == DecimalType { + if r.value == nil { + return nil, nil + } + return r.value.(*Decimal), nil + } + return nil, errors.New("ion: value is not a decimal") +} + +// TimeValue returns the current value as a time. +func (r *reader) TimeValue() (time.Time, error) { + if r.valueType == TimestampType { + if r.value == nil { + return time.Time{}, nil + } + return r.value.(time.Time), nil + } + return time.Time{}, errors.New("ion: value is not a timestamp") +} + +// StringValue returns the current value as a string. +func (r *reader) StringValue() (string, error) { + if r.valueType == StringType || r.valueType == SymbolType { + if r.value == nil { + return "", nil + } + return r.value.(string), nil + } + return "", errors.New("ion: value is not a string") +} + +// ByteValue returns the current value as a byte slice. +func (r *reader) ByteValue() ([]byte, error) { + if r.valueType == BlobType || r.valueType == ClobType { + if r.value == nil { + return nil, nil + } + return r.value.([]byte), nil + } + return nil, errors.New("ion: value is not a byte array") +} + +func (r *reader) clear() { + r.fieldName = "" + r.annotations = nil + r.valueType = NoType + r.value = nil +} diff --git a/symboltable.go b/symboltable.go index 03c18633..284f4311 100644 --- a/symboltable.go +++ b/symboltable.go @@ -1,6 +1,7 @@ package ion import ( + "errors" "strings" ) @@ -28,7 +29,7 @@ type SharedSymbolTable interface { Version() int } -type sharedSymbolTable struct { +type sst struct { name string version int symbols []string @@ -46,7 +47,7 @@ func NewSharedSymbolTable(name string, version int, symbols []string) SharedSymb index, copy := buildIndex(symbols, 0) - return &sharedSymbolTable{ + return &sst{ name: name, version: version, symbols: copy, @@ -68,31 +69,31 @@ func buildIndex(symbols []string, offset int) (map[string]int, []string) { return index, copy } -func (s *sharedSymbolTable) Name() string { +func (s *sst) Name() string { return s.name } -func (s *sharedSymbolTable) Version() int { +func (s *sst) Version() int { return s.version } -func (s *sharedSymbolTable) MaxID() int { +func (s *sst) MaxID() int { return len(s.symbols) } -func (s *sharedSymbolTable) FindByName(sym string) (int, bool) { +func (s *sst) FindByName(sym string) (int, bool) { id, ok := s.index[sym] return id, ok } -func (s *sharedSymbolTable) FindByID(id int) (string, bool) { +func (s *sst) FindByID(id int) (string, bool) { if id <= 0 || id > len(s.symbols) { return "", false } return s.symbols[id-1], true } -func (s *sharedSymbolTable) WriteTo(w Writer) error { +func (s *sst) WriteTo(w Writer) error { w.Annotation("$ion_shared_symbol_table") w.BeginStruct() @@ -115,7 +116,7 @@ func (s *sharedSymbolTable) WriteTo(w Writer) error { return w.Err() } -func (s *sharedSymbolTable) String() string { +func (s *sst) String() string { buf := strings.Builder{} w := NewTextWriter(&buf) @@ -137,9 +138,61 @@ var V1SystemSymbolTable = NewSharedSymbolTable("$ion", 1, []string{ "$ion_shared_symbol_table", }) +// A BogusSST represents an SST imported by an LST that cannot be found in the +// local catalog. It exists to reserve some part of the symbol ID space so other +// symbol tables get mapped to the right IDs. +type bogusSST struct { + name string + version int + maxID int +} + +func (s *bogusSST) Name() string { + return s.name +} + +func (s *bogusSST) Version() int { + return s.version +} + +func (s *bogusSST) MaxID() int { + return s.maxID +} + +func (s *bogusSST) FindByName(sym string) (int, bool) { + return 0, false +} + +func (s *bogusSST) FindByID(id int) (string, bool) { + return "", false +} + +func (s *bogusSST) WriteTo(w Writer) error { + return errors.New("ion: bogusSST does not implement WriteTo") +} + +func (s *bogusSST) String() string { + buf := strings.Builder{} + w := NewTextWriter(&buf) + w.Annotations("$ion_shared_symbol_table", "bogus") + w.BeginStruct() + + w.FieldName("name") + w.WriteString(s.name) + + w.FieldName("version") + w.WriteInt(int64(s.version)) + + w.FieldName("max_id") + w.WriteInt(int64(s.maxID)) + + w.EndStruct() + return buf.String() +} + // A LocalSymbolTable is transmitted in-band along with the binary data // it describes. It may include SharedSymbolTables by reference. -type localSymbolTable struct { +type lst struct { imports []SharedSymbolTable offsets []int maxImportID int @@ -153,7 +206,7 @@ func NewLocalSymbolTable(imports []SharedSymbolTable, symbols []string) SymbolTa imps, offsets, maxID := processImports(imports) index, copy := buildIndex(symbols, maxID) - return &localSymbolTable{ + return &lst{ imports: imps, offsets: offsets, maxImportID: maxID, @@ -183,11 +236,11 @@ func processImports(imports []SharedSymbolTable) ([]SharedSymbolTable, []int, in return imps, offsets, maxID } -func (t *localSymbolTable) MaxID() int { +func (t *lst) MaxID() int { return t.maxImportID + len(t.symbols) } -func (t *localSymbolTable) FindByName(s string) (int, bool) { +func (t *lst) FindByName(s string) (int, bool) { for i, imp := range t.imports { if id, ok := imp.FindByName(s); ok { return t.offsets[i] + id, true @@ -201,7 +254,7 @@ func (t *localSymbolTable) FindByName(s string) (int, bool) { return 0, false } -func (t *localSymbolTable) FindByID(id int) (string, bool) { +func (t *lst) FindByID(id int) (string, bool) { if id <= 0 { return "", false } @@ -218,7 +271,7 @@ func (t *localSymbolTable) FindByID(id int) (string, bool) { return "", false } -func (t *localSymbolTable) findByIDInImports(id int) (string, bool) { +func (t *lst) findByIDInImports(id int) (string, bool) { i := 1 off := 0 @@ -232,7 +285,7 @@ func (t *localSymbolTable) findByIDInImports(id int) (string, bool) { return t.imports[i-1].FindByID(id - off) } -func (t *localSymbolTable) WriteTo(w Writer) error { +func (t *lst) WriteTo(w Writer) error { if len(t.imports) == 1 && len(t.symbols) == 0 { return nil } @@ -276,7 +329,7 @@ func (t *localSymbolTable) WriteTo(w Writer) error { return w.Err() } -func (t *localSymbolTable) String() string { +func (t *lst) String() string { buf := strings.Builder{} w := NewTextWriter(&buf) @@ -296,14 +349,14 @@ type SymbolTableBuilder interface { } type symbolTableBuilder struct { - localSymbolTable + lst } // NewSymbolTableBuilder creates a new symbol table builder with the given imports. func NewSymbolTableBuilder(imports ...SharedSymbolTable) SymbolTableBuilder { imps, offsets, maxID := processImports(imports) return &symbolTableBuilder{ - localSymbolTable{ + lst{ imports: imps, offsets: offsets, maxImportID: maxID, @@ -331,7 +384,7 @@ func (b *symbolTableBuilder) Build() SymbolTable { index[s] = i } - return &localSymbolTable{ + return &lst{ imports: b.imports, offsets: b.offsets, maxImportID: b.maxImportID, diff --git a/textreader.go b/textreader.go index 8efabf9b..3e36898c 100644 --- a/textreader.go +++ b/textreader.go @@ -7,10 +7,8 @@ import ( "fmt" "io" "math" - "math/big" "strconv" "strings" - "time" ) // trs is the state of the text reader. @@ -45,14 +43,8 @@ func (s trs) String() string { type textReader struct { tok tokenizer state trs - ctx ctxstack - eof bool - err error - fieldName string - annotations []string - valueType Type - value interface{} + reader debug bool } @@ -99,10 +91,7 @@ func (t *textReader) Next() bool { fmt.Println("ion: state after finish =", t.state) } - t.fieldName = "" - t.annotations = nil - t.valueType = NoType - t.value = nil + t.clear() // Loop until we've consumed enough tokens to know what the next value is. for { @@ -141,225 +130,6 @@ func (t *textReader) Next() bool { } } -// Err returns the current error. -func (t *textReader) Err() error { - return t.err -} - -// Type returns the current value's type. -func (t *textReader) Type() Type { - return t.valueType -} - -// IsNull returns true if the current value is null. -func (t *textReader) IsNull() bool { - return t.valueType != NoType && t.value == nil -} - -// FieldName returns the current value's field name. -func (t *textReader) FieldName() string { - return t.fieldName -} - -// Annotations returns the current value's annotations. -func (t *textReader) Annotations() []string { - return t.annotations -} - -// StepIn steps in to a container. -func (t *textReader) StepIn() error { - if t.err != nil { - return t.err - } - if t.state != trsBeforeContainer { - return errors.New("ion: StepIn called when not on a container") - } - - ctx := containerTypeToCtx(t.valueType) - t.ctx.push(ctx) - - if ctx == ctxInStruct { - t.state = trsBeforeFieldName - } else { - t.state = trsBeforeTypeAnnotations - } - - t.tok.SetFinished() - return nil -} - -// StepOut steps out of a container. -func (t *textReader) StepOut() error { - if t.err != nil { - return t.err - } - - ctx := t.ctx.peek() - if ctx == ctxAtTopLevel { - return errors.New("ion: StepOut called at top level") - } - ctype := ctxToContainerType(ctx) - - // Finish off whatever value *inside* the container that we're currently reading. - _, err := t.tok.FinishValue() - if err != nil { - t.explode(err) - return err - } - - // If we haven't seen the end of the container yet, skip values until we find it. - if !t.eof { - if err := t.tok.SkipContainerContents(ctype); err != nil { - t.explode(err) - return err - } - } - - t.ctx.pop() - t.state = t.stateAfterValue() - t.valueType = NoType - t.value = nil - t.eof = false - - return nil -} - -// BoolValue returns the current value as a bool. -func (t *textReader) BoolValue() (bool, error) { - if t.valueType == BoolType { - if t.value == nil { - return false, nil - } - return t.value.(bool), nil - } - return false, errors.New("ion: value is not a bool") -} - -// IntSize returns the size of the current int value. -func (t *textReader) IntSize() (IntSize, error) { - if t.valueType != IntType { - return NullInt, errors.New("ion: value is not an int") - } - if t.value == nil { - return NullInt, nil - } - - if i, ok := t.value.(int64); ok { - if i > math.MaxInt32 || i < math.MinInt32 { - return Int64, nil - } - return Int32, nil - } - - return BigInt, nil -} - -// IntValue returns the current value as an int. -func (t *textReader) IntValue() (int, error) { - i, err := t.Int64Value() - if err != nil { - return 0, err - } - if i > math.MaxInt32 || i < math.MinInt32 { - return 0, errors.New("ion: int value out of bounds") - } - return int(i), nil -} - -// Int64Value returns the current value as an int64. -func (t *textReader) Int64Value() (int64, error) { - if t.valueType == IntType { - if t.value == nil { - return 0, nil - } - - if i, ok := t.value.(int64); ok { - return i, nil - } - - bi := t.value.(*big.Int) - if bi.IsInt64() { - return bi.Int64(), nil - } - - return 0, errors.New("ion: int value out of bounds") - } - return 0, errors.New("ion: value is not an int") -} - -// BigIntValue returns the current value as a big int. -func (t *textReader) BigIntValue() (*big.Int, error) { - if t.valueType == IntType { - if t.value == nil { - return nil, nil - } - if i, ok := t.value.(int64); ok { - return big.NewInt(i), nil - } - return t.value.(*big.Int), nil - } - return nil, errors.New("ion: value is not an int") -} - -// FloatValue returns the current value as a float. -func (t *textReader) FloatValue() (float64, error) { - if t.valueType == FloatType { - if t.value == nil { - return 0.0, nil - } - return t.value.(float64), nil - } - return 0.0, errors.New("ion: value is not a float") -} - -// DecimalValue returns the current value as a Decimal. -func (t *textReader) DecimalValue() (*Decimal, error) { - switch t.valueType { - case DecimalType: - if t.value == nil { - return nil, nil - } - return t.value.(*Decimal), nil - } - return nil, errors.New("ion: value is not a decimal") -} - -// TimeValue returns the current value as a time. -func (t *textReader) TimeValue() (time.Time, error) { - switch t.valueType { - case TimestampType: - if t.value == nil { - return time.Time{}, nil - } - return t.value.(time.Time), nil - } - return time.Time{}, errors.New("ion: value is not a timestamp") -} - -// StringValue returns the current value as a string. -func (t *textReader) StringValue() (string, error) { - switch t.valueType { - case StringType, SymbolType: - if t.value == nil { - return "", nil - } - return t.value.(string), nil - } - return "", errors.New("ion: value is not a string") -} - -// ByteValue returns the current value as a byte slice. -func (t *textReader) ByteValue() ([]byte, error) { - switch t.valueType { - case BlobType, ClobType: - if t.value == nil { - return nil, nil - } - return t.value.([]byte), nil - } - return nil, errors.New("ion: value is not a byte array") -} - // NextAfterValue moves to the next value when we're in the // AfterValue state. func (t *textReader) nextAfterValue() (bool, error) { @@ -554,6 +324,65 @@ func (t *textReader) nextBeforeTypeAnnotations() (bool, error) { } } +// StepIn steps in to a container. +func (t *textReader) StepIn() error { + if t.err != nil { + return t.err + } + if t.state != trsBeforeContainer { + return errors.New("ion: StepIn called when not on a container") + } + + ctx := containerTypeToCtx(t.valueType) + t.ctx.push(ctx) + + if ctx == ctxInStruct { + t.state = trsBeforeFieldName + } else { + t.state = trsBeforeTypeAnnotations + } + + t.clear() + + t.tok.SetFinished() + return nil +} + +// StepOut steps out of a container. +func (t *textReader) StepOut() error { + if t.err != nil { + return t.err + } + + ctx := t.ctx.peek() + if ctx == ctxAtTopLevel { + return errors.New("ion: StepOut called at top level") + } + ctype := ctxToContainerType(ctx) + + // Finish off whatever value *inside* the container that we're currently reading. + _, err := t.tok.FinishValue() + if err != nil { + t.explode(err) + return err + } + + // If we haven't seen the end of the container yet, skip values until we find it. + if !t.eof { + if err := t.tok.SkipContainerContents(ctype); err != nil { + t.explode(err) + return err + } + } + + t.ctx.pop() + t.state = t.stateAfterValue() + t.clear() + t.eof = false + + return nil +} + // VerifyUnquotedSymbol checks for certain 'special' values that are returned from // the tokenizer as symbols but cannot be used as field names or annotations. func verifyUnquotedSymbol(val string, ctx string) error { diff --git a/textreader_test.go b/textreader_test.go index fec0108c..334c5f1b 100644 --- a/textreader_test.go +++ b/textreader_test.go @@ -223,7 +223,7 @@ func TestTimestamps(t *testing.T) { testA("foo::'bar'::2001-01-01T00:00:00.000Z", []string{"foo", "bar"}, et) } -func TestDoubles(t *testing.T) { +func TestDecimals(t *testing.T) { testA := func(str string, etas []string, eval string) { t.Run(str, func(t *testing.T) { ee := MustParseDecimal(eval) @@ -466,6 +466,87 @@ func _containerAF(t *testing.T, r Reader, et Type, efn string, etas []string, f } } +func _int(t *testing.T, r Reader, eval int) { + _intAF(t, r, "", nil, eval) +} + +func _intAF(t *testing.T, r Reader, efn string, etas []string, eval int) { + _nextAF(t, r, IntType, efn, etas) + if r.IsNull() { + t.Fatalf("expected %v, got null.int", eval) + } + + size, err := r.IntSize() + if err != nil { + t.Fatal(err) + } + if size != Int32 { + t.Errorf("expected size=Int32, got %v", size) + } + + val, err := r.IntValue() + if err != nil { + t.Fatal(err) + } + if val != eval { + t.Errorf("expected %v, got %v", eval, val) + } +} + +func _int64(t *testing.T, r Reader, eval int64) { + _int64AF(t, r, "", nil, eval) +} + +func _int64AF(t *testing.T, r Reader, efn string, etas []string, eval int64) { + _nextAF(t, r, IntType, efn, etas) + if r.IsNull() { + t.Fatalf("expected %v, got null.int", eval) + } + + size, err := r.IntSize() + if err != nil { + t.Fatal(err) + } + if size != Int64 { + t.Errorf("expected size=Int64, got %v", size) + } + + val, err := r.Int64Value() + if err != nil { + t.Fatal(err) + } + if val != eval { + t.Errorf("expected %v, got %v", eval, val) + } +} + +func _bigInt(t *testing.T, r Reader, eval *big.Int) { + _bigIntAF(t, r, "", nil, eval) +} + +func _bigIntAF(t *testing.T, r Reader, efn string, etas []string, eval *big.Int) { + _nextAF(t, r, IntType, efn, etas) + if r.IsNull() { + t.Fatalf("expected %v, got null.int", eval) + } + + size, err := r.IntSize() + if err != nil { + t.Fatal(err) + } + if size != BigInt { + t.Errorf("expected size=BigInt, got %v", size) + } + + val, err := r.BigIntValue() + if err != nil { + t.Fatal(err) + } + if val.Cmp(eval) != 0 { + t.Errorf("expected %v, got %v", eval, val) + } +} + func _float(t *testing.T, r Reader, eval float64) { _floatAF(t, r, "", nil, eval) } @@ -490,6 +571,26 @@ func _floatAF(t *testing.T, r Reader, efn string, etas []string, eval float64) { } } +func _decimal(t *testing.T, r Reader, eval *Decimal) { + _decimalAF(t, r, "", nil, eval) +} + +func _decimalAF(t *testing.T, r Reader, efn string, etas []string, eval *Decimal) { + _nextAF(t, r, DecimalType, efn, etas) + if r.IsNull() { + t.Fatalf("expected %v, got null.decimal", eval) + } + + val, err := r.DecimalValue() + if err != nil { + t.Fatal(err) + } + + if !eval.Equal(val) { + t.Errorf("expected %v, got %v", eval, val) + } +} + func _string(t *testing.T, r Reader, eval string) { _stringAF(t, r, "", nil, eval) } From 6996d989819cea6348f2e00ff56c90cffa22a2e6 Mon Sep 17 00:00:00 2001 From: David Murray Date: Mon, 26 Aug 2019 13:33:45 +1000 Subject: [PATCH 33/56] more binary reading --- binaryreader.go | 51 ++++++++++++++ binaryreader_test.go | 162 +++++++++++++++++++++++++++++++++++++++++-- bitstream.go | 133 +++++++++++++++++++++++++++++++---- decimal.go | 18 +++++ decimal_test.go | 24 +++++++ textreader_test.go | 58 ++++++++++++++++ 6 files changed, 429 insertions(+), 17 deletions(-) diff --git a/binaryreader.go b/binaryreader.go index 266db139..63cbb27e 100644 --- a/binaryreader.go +++ b/binaryreader.go @@ -116,6 +116,28 @@ func (r *binaryReader) next() (bool, error) { } return true, nil + case bitcodeTimestamp: + r.valueType = TimestampType + if !r.bits.Null() { + val, err := r.bits.ReadTimestamp() + if err != nil { + return false, err + } + r.value = val + } + return true, nil + + case bitcodeSymbol: + r.valueType = SymbolType + if !r.bits.Null() { + id, err := r.bits.ReadSymbol() + if err != nil { + return false, err + } + r.value = r.resolve(id) + } + return true, nil + case bitcodeString: r.valueType = StringType if !r.bits.Null() { @@ -127,6 +149,28 @@ func (r *binaryReader) next() (bool, error) { } return true, nil + case bitcodeClob: + r.valueType = ClobType + if !r.bits.Null() { + val, err := r.bits.ReadBytes() + if err != nil { + return false, err + } + r.value = val + } + return true, nil + + case bitcodeBlob: + r.valueType = BlobType + if !r.bits.Null() { + val, err := r.bits.ReadBytes() + if err != nil { + return false, err + } + r.value = val + } + return true, nil + case bitcodeList: r.valueType = ListType if !r.bits.Null() { @@ -134,6 +178,13 @@ func (r *binaryReader) next() (bool, error) { } return true, nil + case bitcodeSexp: + r.valueType = SexpType + if !r.bits.Null() { + r.value = SexpType + } + return true, nil + case bitcodeStruct: r.valueType = StructType if !r.bits.Null() { diff --git a/binaryreader_test.go b/binaryreader_test.go index e7310ba8..a838c344 100644 --- a/binaryreader_test.go +++ b/binaryreader_test.go @@ -1,24 +1,178 @@ package ion import ( + "fmt" "math" "math/big" "testing" + "time" ) func TestReadBinaryStructs(t *testing.T) { r := readBinary([]byte{ - 0xD0, // {} 0xDF, // null.struct + 0xD0, // {} 0xEA, 0x81, 0xEE, 0xD7, // foo::{ - 0x84, 0xE3, 0x81, 0xEF, 0x0F, // name:bar::null, + 0x84, 0xE3, 0x81, 0xEF, 0xD0, // name:bar::{}, 0x88, 0x20, // max_id:0 // } }) - _next(t, r, StructType) _null(t, r, StructType) - _nextAF(t, r, StructType, "", []string{"foo"}) + _struct(t, r, func(t *testing.T, r Reader) { + _eof(t, r) + }) + _structAF(t, r, "", []string{"foo"}, func(t *testing.T, r Reader) { + _structAF(t, r, "name", []string{"bar"}, func(t *testing.T, r Reader) { + _eof(t, r) + }) + _intAF(t, r, "max_id", nil, 0) + }) + _eof(t, r) +} + +func TestReadBinarySexps(t *testing.T) { + r := readBinary([]byte{ + 0xCF, + 0xC3, 0xC1, 0xC0, 0xC0, + }) + + _null(t, r, SexpType) + _sexp(t, r, func(t *testing.T, r Reader) { + _sexp(t, r, func(t *testing.T, r Reader) { + _sexp(t, r, func(t *testing.T, r Reader) { + _eof(t, r) + }) + }) + _sexp(t, r, func(t *testing.T, r Reader) { + _eof(t, r) + }) + _eof(t, r) + }) + _eof(t, r) +} + +func TestReadBinaryLists(t *testing.T) { + r := readBinary([]byte{ + 0xBF, + 0xB3, 0xB1, 0xB0, 0xB0, + }) + + _null(t, r, ListType) + _list(t, r, func(t *testing.T, r Reader) { + _list(t, r, func(t *testing.T, r Reader) { + _list(t, r, func(t *testing.T, r Reader) { + _eof(t, r) + }) + }) + _list(t, r, func(t *testing.T, r Reader) { + _eof(t, r) + }) + _eof(t, r) + }) + _eof(t, r) +} + +func TestReadBinaryBlobs(t *testing.T) { + r := readBinary([]byte{ + 0xAF, + 0xA0, + 0xA1, 'a', + 0xAE, 0x96, + 'h', 'e', 'l', 'l', 'o', ' ', 'w', 'o', 'r', 'l', 'd', ' ', 'b', 'u', 't', + ' ', 'l', 'o', 'n', 'g', 'e', 'r', + }) + + _null(t, r, BlobType) + _blob(t, r, []byte("")) + _blob(t, r, []byte("a")) + _blob(t, r, []byte("hello world but longer")) + _eof(t, r) +} + +func TestReadBinaryClobs(t *testing.T) { + r := readBinary([]byte{ + 0x9F, + 0x90, // {{}} + 0x91, 'a', // {{a}} + 0x9E, 0x96, + 'h', 'e', 'l', 'l', 'o', ' ', 'w', 'o', 'r', 'l', 'd', ' ', 'b', 'u', 't', + ' ', 'l', 'o', 'n', 'g', 'e', 'r', + }) + + _null(t, r, ClobType) + _clob(t, r, []byte("")) + _clob(t, r, []byte("a")) + _clob(t, r, []byte("hello world but longer")) + _eof(t, r) +} + +func TestReadBinaryStrings(t *testing.T) { + r := readBinary([]byte{ + 0x8F, + 0x80, // "" + 0x81, 'a', // "a" + 0x8E, 0x96, + 'h', 'e', 'l', 'l', 'o', ' ', 'w', 'o', 'r', 'l', 'd', ' ', 'b', 'u', 't', + ' ', 'l', 'o', 'n', 'g', 'e', 'r', + }) + + _null(t, r, StringType) + _string(t, r, "") + _string(t, r, "a") + _string(t, r, "hello world but longer") + _eof(t, r) +} + +func TestReadBinarySymbols(t *testing.T) { + r := readBinary([]byte{ + 0x7F, + 0x70, // $0 + 0x71, 0x01, // $ion + 0x71, 0x0A, // $10 + 0x71, 0x6E, // foo + 0xE4, 0x81, 0xEE, 0x71, 0x6F, // foo::bar + 0x78, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, // ${maxint64} + }) + + _null(t, r, SymbolType) + _symbol(t, r, "$0") + _symbol(t, r, "$ion") + _symbol(t, r, "$10") + _symbol(t, r, "foo") + _symbolAF(t, r, "", []string{"foo"}, "bar") + _symbol(t, r, fmt.Sprintf("$%v", uint64(math.MaxUint64))) + _eof(t, r) +} + +func TestReadBinaryTimestamps(t *testing.T) { + r := readBinary([]byte{ + 0x6F, + 0x62, 0x80, 0x81, // 0001T + 0x63, 0x80, 0x81, 0x81, // 0001-01T + 0x64, 0x80, 0x81, 0x81, 0x81, // 0001-01-01T + 0x66, 0x80, 0x81, 0x81, 0x81, 0x80, 0x80, // 0001-01-01T00:00Z + 0x67, 0x80, 0x81, 0x81, 0x81, 0x80, 0x80, 0x80, // 0001-01-01T00:00:00Z + 0x6E, 0x8E, // 0x0E-bit timestamp + 0x04, 0xD8, // offset: +600 minutes (+10:00) + 0x0F, 0xE3, // year: 2019 + 0x88, // month: 8 + 0x84, // day: 4 + 0x88, // hour: 8 utc (18 local) + 0x8F, // minute: 15 + 0xAB, // second: 43 + 0xC9, // exp: -9 + 0x33, 0x77, 0xDF, 0x70, // nsec: 863494000 + }) + + _null(t, r, TimestampType) + + for i := 0; i < 5; i++ { + _timestamp(t, r, time.Time{}) + } + + nowish, _ := time.Parse(time.RFC3339Nano, "2019-08-04T18:15:43.863494+10:00") + _timestamp(t, r, nowish) _eof(t, r) } diff --git a/bitstream.go b/bitstream.go index a9f4cf3b..5b64a3f1 100644 --- a/bitstream.go +++ b/bitstream.go @@ -9,6 +9,7 @@ import ( "io" "math" "math/big" + "time" ) type bss uint8 @@ -465,37 +466,106 @@ func (b *bitstream) ReadDecimal() (*Decimal, error) { return NewDecimalInt(0), nil } - exp, explen, err := b.readVarIntLen(b.len) + d, err := b.readDecimal(b.len) if err != nil { return nil, err } + b.state = b.stateAfterValue() + b.code = bitcodeNone + b.len = 0 + + return d, nil +} + +func (b *bitstream) ReadTimestamp() (time.Time, error) { + if b.code != bitcodeTimestamp { + return time.Time{}, errors.New("ion: not a timestamp") + } + + offset, olen, err := b.readVarIntLen(b.len) + if err != nil { + return time.Time{}, err + } + b.len -= olen + + ts := []int{1, 1, 1, 0, 0, 0} + for i := 0; b.len > 0 && i < 6; i++ { + val, vlen, err := b.readVarUintLen(b.len) + if err != nil { + return time.Time{}, err + } + b.len -= vlen + ts[i] = int(val) + } + + nsecs, err := b.readNsecs() + if err != nil { + return time.Time{}, err + } + + utc := time.Date(ts[0], time.Month(ts[1]), ts[2], ts[3], ts[4], ts[5], int(nsecs), time.UTC) + + return utc.In(time.FixedZone("fixed", int(offset)*60)), nil +} + +func (b *bitstream) readNsecs() (int64, error) { + d, err := b.readDecimal(b.len) + if err != nil { + return 0, err + } + return d.ShiftL(9).Trunc() +} + +func (b *bitstream) readDecimal(len uint64) (*Decimal, error) { + exp := int64(0) coef := new(big.Int) - coeflen := b.len - explen - if coeflen > 0 { - bs, err := b.readN(coeflen) + if len > 0 { + val, vlen, err := b.readVarIntLen(len) if err != nil { return nil, err } + exp = val + len -= vlen + } - neg := (bs[0]&0x80 != 0) - bs[0] &= 0x7F - if bs[0] == 0 { - bs = bs[1:] + if len > 0 { + if err := b.readIntTo(len, coef); err != nil { + return nil, err } + } - coef.SetBytes(bs) - if neg { - coef.Neg(coef) - } + return NewDecimal(coef, int(exp)), nil +} + +func (b *bitstream) ReadSymbol() (uint64, error) { + if b.code != bitcodeSymbol { + return 0, errors.New("ion: not a symbol") + } + + bs, err := b.readN(b.len) + if err != nil { + return 0, err } b.state = b.stateAfterValue() b.code = bitcodeNone b.len = 0 - return NewDecimal(coef, int(exp)), nil + if len(bs) == 0 { + return 0, nil + } + if len(bs) > 8 { + return 0, errors.New("ion: symbol id out of range") + } + + ret := uint64(0) + for _, b := range bs { + ret <<= 8 + ret |= uint64(b) + } + return ret, nil } func (b *bitstream) ReadString() (string, error) { @@ -515,6 +585,43 @@ func (b *bitstream) ReadString() (string, error) { return string(bs), nil } +func (b *bitstream) ReadBytes() ([]byte, error) { + if b.code != bitcodeClob && b.code != bitcodeBlob { + return nil, errors.New("ion: not a lob") + } + + bs, err := b.readN(b.len) + if err != nil { + return nil, err + } + + b.state = b.stateAfterValue() + b.code = bitcodeNone + b.len = 0 + + return bs, nil +} + +func (b *bitstream) readIntTo(len uint64, ret *big.Int) error { + bs, err := b.readN(len) + if err != nil { + return err + } + + neg := (bs[0]&0x80 != 0) + bs[0] &= 0x7F + if bs[0] == 0 { + bs = bs[1:] + } + + ret.SetBytes(bs) + if neg { + ret.Neg(ret) + } + + return nil +} + func (b *bitstream) readVarUint() (uint64, error) { r, _, err := b.readVarUintLen(10) return r, err diff --git a/decimal.go b/decimal.go index 2499a403..f8ad6cd5 100644 --- a/decimal.go +++ b/decimal.go @@ -221,6 +221,24 @@ func (d *Decimal) upscale(scale int) *Decimal { } } +// Trunc attempts to truncate this decimal to an int64. Use at your own risk. +func (d *Decimal) Trunc() (int64, error) { + if d.scale < 0 { + // TODO: safety in case scale is very small? + d = d.upscale(0) + } + + str := d.n.String() + + want := len(str) - d.scale + if want <= 0 { + return 0, nil + } + + trunc := str[:want] + return strconv.ParseInt(trunc, 10, 64) +} + // Truncate returns a new decimal, truncated to the given number of // decimal digits of precision. It does not round, so 19.Truncate(1) // = 1d1. diff --git a/decimal_test.go b/decimal_test.go index f690cbdc..6db5a2c5 100644 --- a/decimal_test.go +++ b/decimal_test.go @@ -129,6 +129,30 @@ func TestNeg(t *testing.T) { test("-1.2d-3", "1.2d-3") } +func TestTrunc(t *testing.T) { + test := func(a string, eval int64) { + t.Run(fmt.Sprintf("trunc(%v)=%v", a, eval), func(t *testing.T) { + aa := MustParseDecimal(a) + val, err := aa.Trunc() + if err != nil { + t.Fatal(err) + } + if val != eval { + t.Errorf("expected %v, got %v", eval, val) + } + }) + } + + test("0.", 0) + test("0.01", 0) + test("1.", 1) + test("-1.", -1) + test("1.01", 1) + test("-1.01", -1) + test("101", 101) + test("1d3", 1000) +} + func addF(a, b *Decimal) *Decimal { return a.Add(b) } func subF(a, b *Decimal) *Decimal { return a.Sub(b) } func mulF(a, b *Decimal) *Decimal { return a.Mul(b) } diff --git a/textreader_test.go b/textreader_test.go index 334c5f1b..7c79ba3a 100644 --- a/textreader_test.go +++ b/textreader_test.go @@ -591,6 +591,26 @@ func _decimalAF(t *testing.T, r Reader, efn string, etas []string, eval *Decimal } } +func _timestamp(t *testing.T, r Reader, eval time.Time) { + _timestampAF(t, r, "", nil, eval) +} + +func _timestampAF(t *testing.T, r Reader, efn string, etas []string, eval time.Time) { + _nextAF(t, r, TimestampType, efn, etas) + if r.IsNull() { + t.Fatalf("expected %v, got null.timestamp", eval) + } + + val, err := r.TimeValue() + if err != nil { + t.Fatal(err) + } + + if !val.Equal(eval) { + t.Errorf("expected %v, got %v", eval, val) + } +} + func _string(t *testing.T, r Reader, eval string) { _stringAF(t, r, "", nil, eval) } @@ -648,6 +668,44 @@ func _boolAF(t *testing.T, r Reader, efn string, etas []string, eval bool) { } } +func _clob(t *testing.T, r Reader, eval []byte) { + _clobAF(t, r, "", nil, eval) +} + +func _clobAF(t *testing.T, r Reader, efn string, etas []string, eval []byte) { + _nextAF(t, r, ClobType, efn, etas) + if r.IsNull() { + t.Fatalf("expected %v, got null.clob", eval) + } + + val, err := r.ByteValue() + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(val, eval) { + t.Errorf("expected %v, got %v", eval, val) + } +} + +func _blob(t *testing.T, r Reader, eval []byte) { + _blobAF(t, r, "", nil, eval) +} + +func _blobAF(t *testing.T, r Reader, efn string, etas []string, eval []byte) { + _nextAF(t, r, BlobType, efn, etas) + if r.IsNull() { + t.Fatalf("expected %v, got null.blob", eval) + } + + val, err := r.ByteValue() + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(val, eval) { + t.Errorf("expected %v, got %v", eval, val) + } +} + func _null(t *testing.T, r Reader, et Type) { _nullAF(t, r, et, "", nil) } From 94e39d6129a8a0bfddd281506e536db538a389ec Mon Sep 17 00:00:00 2001 From: David Murray Date: Mon, 26 Aug 2019 16:05:27 +1000 Subject: [PATCH 34/56] bugfixes, docs, and refactoring --- api.go | 136 ++++++++++----- binaryreader.go | 23 ++- binaryreader_test.go | 26 ++- binarywriter.go | 382 +++++++++++++++++++++---------------------- binarywriter_test.go | 28 ++-- bitstream.go | 11 +- bitstream_test.go | 9 + marshal.go | 48 ++---- reader.go | 15 ++ reader_test.go | 13 +- symboltable.go | 33 ++-- textreader.go | 16 +- textwriter.go | 351 +++++++++++++++++++-------------------- textwriter_test.go | 76 ++++----- writer.go | 24 +-- 15 files changed, 636 insertions(+), 555 deletions(-) diff --git a/api.go b/api.go index bc0da9ad..59d95372 100644 --- a/api.go +++ b/api.go @@ -157,7 +157,7 @@ func (i IntSize) String() string { // outer sequence of values. The Reader will be positioned at the end of the composite value, // such that a call to Next will move to the immediately-following value (if any). // -// r := NewTextReaderStr("[foo, bar] [") +// r := NewTextReaderStr("[foo, bar] [baz]") // for r.Next() { // if err := r.StepIn(); err != nil { // return err @@ -260,43 +260,105 @@ type Reader interface { ByteValue() ([]byte, error) } -// A Writer writes Ion values to an output stream. +// A Writer writes a stream of Ion values. +// +// The various Write methods write atomic values to the current output stream. The +// Begin methods begin writing a list, sexp, or struct respectively. Subsequent +// calls to Write will write values inside of the container until a matching +// End method is called. +// +// var w Writer +// w.BeginSexp() +// { +// w.WriteInt(1) +// w.WriteSymbol("+") +// w.WriteInt(1) +// } +// w.EndSexp() +// +// When writing values inside a struct, the FieldName method must be called before +// each value to set the value's field name. The Annotation method may likewise +// be called before writing any value to add an annotation to the value. +// +// var w Writer +// w.Annotation("user") +// w.BeginStruct() +// { +// w.FieldName("id") +// w.WriteString("qu33nb33") +// w.FieldName("name") +// w.WriteString("Beyoncé") +// } +// w.EndStruct() +// +// When you're done writing values, you should call Finish to ensure everything has +// been flushed from in-memory buffers. While individual methods all return an error +// on failure, implementations will remember any errors, no-op subsequent calls, and +// return the previous error. This lets you keep code a bit cleaner by only checking +// the return value of the final method call (generally Finish). +// +// var w Writer +// writeSomeStuff(w) +// if err := w.Finish(); err != nil { +// return err +// } +// type Writer interface { - InStruct() bool - InList() bool - InSexp() bool - Err() error - - FieldName(val string) - Annotation(val string) - Annotations(vals ...string) - - BeginStruct() - EndStruct() - - BeginList() - EndList() - - BeginSexp() - EndSexp() - - WriteNull() - WriteNullWithType(t Type) - - WriteBool(val bool) - - WriteInt(val int64) - WriteBigInt(val *big.Int) - WriteFloat(val float64) - WriteDecimal(val *Decimal) - - WriteTimestamp(val time.Time) - - WriteSymbol(val string) - WriteString(val string) - - WriteBlob(val []byte) - WriteClob(val []byte) + // FieldName sets the field name for the next value written. + FieldName(val string) error + + // Annotation adds a single annotation to the next value written. + Annotation(val string) error + + // Annotations adds multiple annotations to the next value written. + Annotations(vals ...string) error + + // WriteNull writes an untyped null value. + WriteNull() error + // WriteNullType writes a null value with a type qualifier, e.g. null.bool. + WriteNullType(t Type) error + + // WriteBool writes a boolean value. + WriteBool(val bool) error + + // WriteInt writes an integer value. + WriteInt(val int64) error + // WriteBigInt writes a big integer value. + WriteBigInt(val *big.Int) error + // WriteFloat writes a floating-point value. + WriteFloat(val float64) error + // WriteDecimal writes an arbitrary-precision decimal value. + WriteDecimal(val *Decimal) error + + // WriteTimestamp writes a timestamp value. + WriteTimestamp(val time.Time) error + + // WriteSymbol writes a symbol value. + WriteSymbol(val string) error + // WriteString writes a string value. + WriteString(val string) error + + // WriteClob writes a clob value. + WriteClob(val []byte) error + // WriteBlob writes a blob value. + WriteBlob(val []byte) error + + // BeginList begins writing a list value. + BeginList() error + // EndList finishes writing a list value. + EndList() error + + // BeginSexp begins writing an s-expression value. + BeginSexp() error + // EndSexp finishes writing an s-expression value. + EndSexp() error + + // BeginStruct begins writing a struct value. + BeginStruct() error + // EndStruct finishes writing a struct value. + EndStruct() error + + // Finish finishes writing values and flushes any buffered data. Finish() error } diff --git a/binaryreader.go b/binaryreader.go index 63cbb27e..ab749d96 100644 --- a/binaryreader.go +++ b/binaryreader.go @@ -1,6 +1,8 @@ package ion import ( + "bufio" + "bytes" "errors" "fmt" "io" @@ -16,19 +18,19 @@ type binaryReader struct { // NewBinaryReader creates a new binary reader. func NewBinaryReader(in io.Reader, cat *Catalog) Reader { - r := &binaryReader{ - cat: cat, - } - r.bits.Init(in) - return r + return newBinaryReaderBuf(bufio.NewReader(in), cat) } // NewBinaryReaderBytes creates a new binary reader for the given bytes. func NewBinaryReaderBytes(in []byte, cat *Catalog) Reader { + return NewBinaryReader(bytes.NewReader(in), cat) +} + +func newBinaryReaderBuf(in *bufio.Reader, cat *Catalog) Reader { r := &binaryReader{ cat: cat, } - r.bits.InitBytes(in) + r.bits.Init(in) return r } @@ -76,6 +78,15 @@ func (r *binaryReader) next() (bool, error) { err := r.readAnnotations() return false, err + case bitcodeNull: + if !r.bits.Null() { + // NOP padding; skip it and keep going. + err := r.bits.SkipValue() + return false, err + } + r.valueType = NullType + return true, nil + case bitcodeFalse, bitcodeTrue: r.valueType = BoolType if !r.bits.Null() { diff --git a/binaryreader_test.go b/binaryreader_test.go index a838c344..8a87c395 100644 --- a/binaryreader_test.go +++ b/binaryreader_test.go @@ -199,8 +199,10 @@ func TestReadBinaryDecimals(t *testing.T) { func TestReadBinaryFloats(t *testing.T) { r := readBinary([]byte{ - 0x40, // 0 - 0x4F, // null.float + 0x40, // 0 + 0x4F, // null.float + 0x44, 0x7F, 0x7F, 0xFF, 0xFF, // MaxFloat32 + 0x44, 0xFF, 0x7F, 0xFF, 0xFF, // -MaxFloat32 0x48, 0x7F, 0xEF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, // MaxFloat64 0x48, 0xFF, 0xEF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, // -MaxFloat64 0x48, 0x7F, 0xF0, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // +inf @@ -210,6 +212,8 @@ func TestReadBinaryFloats(t *testing.T) { _float(t, r, 0) _null(t, r, FloatType) + _float(t, r, math.MaxFloat32) + _float(t, r, -math.MaxFloat32) _float(t, r, math.MaxFloat64) _float(t, r, -math.MaxFloat64) _float(t, r, math.Inf(1)) @@ -257,6 +261,24 @@ func TestReadBinaryBools(t *testing.T) { _eof(t, r) } +func TestReadBinaryNulls(t *testing.T) { + r := readBinary([]byte{ + 0x00, // 1-byte NOP + 0x0F, // null + 0x01, 0xFF, // 2-byte NOP + 0xE3, 0x81, 0x81, 0x0F, // $ion::null + 0x0E, 0x8F, // 16-byte NOP + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0xE4, 0x82, 0xEE, 0xEF, 0x0F, // foo::bar::null + }) + + _null(t, r, NullType) + _nullAF(t, r, NullType, "", []string{"$ion"}) + _nullAF(t, r, NullType, "", []string{"foo", "bar"}) + _eof(t, r) +} + func readBinary(ion []byte) Reader { prefix := []byte{ 0xE0, 0x01, 0x00, 0xEA, // $ion_1_0 diff --git a/binarywriter.go b/binarywriter.go index b509b289..8171ad49 100644 --- a/binarywriter.go +++ b/binarywriter.go @@ -209,186 +209,179 @@ func (w *binaryWriter) endContainer(t ctx) error { } // WriteNull writes an untyped null. -func (w *binaryWriter) WriteNull() { - w.WriteNullWithType(NullType) +func (w *binaryWriter) WriteNull() error { + return w.WriteNullType(NoType) } -// WriteNullWithType writes a typed null. -func (w *binaryWriter) WriteNullWithType(t Type) { - if w.err != nil { - return +// WriteNullType writes a typed null. +func (w *binaryWriter) WriteNullType(t Type) error { + if w.err == nil { + w.err = w.writeValue(func() error { + return w.write([]byte{binaryNulls[t]}) + }) } - - w.err = w.writeValue(func() error { - return w.write([]byte{binaryNulls[t]}) - }) + return w.err } // WriteBool writes a bool. -func (w *binaryWriter) WriteBool(val bool) { - if w.err != nil { - return +func (w *binaryWriter) WriteBool(val bool) error { + if w.err == nil { + w.err = w.writeValue(func() error { + if val { + return w.write([]byte{0x11}) + } + return w.write([]byte{0x10}) + }) } - - w.err = w.writeValue(func() error { - if val { - return w.write([]byte{0x11}) - } - return w.write([]byte{0x10}) - }) + return w.err } // WriteInt writes an integer. -func (w *binaryWriter) WriteInt(val int64) { - if w.err != nil { - return - } - - w.err = w.writeValue(func() error { - if val == 0 { - return w.write([]byte{0x20}) - } +func (w *binaryWriter) WriteInt(val int64) error { + if w.err == nil { + w.err = w.writeValue(func() error { + if val == 0 { + return w.write([]byte{0x20}) + } - code := byte(0x20) - mag := uint64(val) + code := byte(0x20) + mag := uint64(val) - if val < 0 { - code = 0x30 - mag = uint64(-val) - } + if val < 0 { + code = 0x30 + mag = uint64(-val) + } - len := uintLen(mag) - buflen := len + tagLen(len) + len := uintLen(mag) + buflen := len + tagLen(len) - buf := make([]byte, 0, buflen) - buf = appendTag(buf, code, len) - buf = appendUint(buf, mag) + buf := make([]byte, 0, buflen) + buf = appendTag(buf, code, len) + buf = appendUint(buf, mag) - return w.write(buf) - }) + return w.write(buf) + }) + } + return w.err } // WriteBigInt writes a big integer. -func (w *binaryWriter) WriteBigInt(val *big.Int) { - if w.err != nil { - return - } - - w.err = w.writeValue(func() error { - sign := val.Sign() - if sign == 0 { - return w.write([]byte{0x20}) - } +func (w *binaryWriter) WriteBigInt(val *big.Int) error { + if w.err == nil { + w.err = w.writeValue(func() error { + sign := val.Sign() + if sign == 0 { + return w.write([]byte{0x20}) + } - code := byte(0x20) - if sign < 0 { - code = 0x30 - } + code := byte(0x20) + if sign < 0 { + code = 0x30 + } - bs := val.Bytes() + bs := val.Bytes() - bl := uint64(len(bs)) - if bl < 64 { - buflen := bl + tagLen(bl) - buf := make([]byte, 0, buflen) + bl := uint64(len(bs)) + if bl < 64 { + buflen := bl + tagLen(bl) + buf := make([]byte, 0, buflen) - buf = appendTag(buf, code, bl) - buf = append(buf, bs...) - return w.write(buf) - } + buf = appendTag(buf, code, bl) + buf = append(buf, bs...) + return w.write(buf) + } - // no sense in copying, emit tag separately. - if err := w.writeTag(code, bl); err != nil { - return err - } - return w.write(bs) - }) + // no sense in copying, emit tag separately. + if err := w.writeTag(code, bl); err != nil { + return err + } + return w.write(bs) + }) + } + return w.err } // WriteFloat writes a floating-point value. -func (w *binaryWriter) WriteFloat(val float64) { - if w.err != nil { - return - } - - w.err = w.writeValue(func() error { - if val == 0 { - return w.write([]byte{0x40}) - } +func (w *binaryWriter) WriteFloat(val float64) error { + if w.err == nil { + w.err = w.writeValue(func() error { + if val == 0 { + return w.write([]byte{0x40}) + } - bs := make([]byte, 9) - bs[0] = 0x48 + bs := make([]byte, 9) + bs[0] = 0x48 - bits := math.Float64bits(val) - binary.BigEndian.PutUint64(bs[1:], bits) + bits := math.Float64bits(val) + binary.BigEndian.PutUint64(bs[1:], bits) - return w.write(bs) - }) + return w.write(bs) + }) + } + return w.err } // WriteDecimal writes a decimal value. -func (w *binaryWriter) WriteDecimal(val *Decimal) { - if w.err != nil { - return - } - - w.writeValue(func() error { - coef, exp := val.CoEx() - - vlen := uint64(0) - if exp != 0 { - vlen += varIntLen(int64(exp)) - } - if coef.Sign() != 0 { - vlen += bigIntLen(coef) - } +func (w *binaryWriter) WriteDecimal(val *Decimal) error { + if w.err == nil { + w.err = w.writeValue(func() error { + coef, exp := val.CoEx() + + vlen := uint64(0) + if exp != 0 { + vlen += varIntLen(int64(exp)) + } + if coef.Sign() != 0 { + vlen += bigIntLen(coef) + } - buflen := vlen + tagLen(vlen) - buf := make([]byte, 0, buflen) + buflen := vlen + tagLen(vlen) + buf := make([]byte, 0, buflen) - buf = appendTag(buf, 0x50, vlen) - if exp != 0 { - buf = appendVarInt(buf, int64(exp)) - } - buf = appendBigInt(buf, coef) + buf = appendTag(buf, 0x50, vlen) + if exp != 0 { + buf = appendVarInt(buf, int64(exp)) + } + buf = appendBigInt(buf, coef) - return w.write(buf) - }) + return w.write(buf) + }) + } + return w.err } // WriteTimestamp writes a timestamp value. -func (w *binaryWriter) WriteTimestamp(val time.Time) { - if w.err != nil { - return - } - - w.err = w.writeValue(func() error { - _, offset := val.Zone() - offset /= 60 - utc := val.In(time.UTC) - - vlen := timeLen(offset, utc) - buflen := vlen + tagLen(vlen) +func (w *binaryWriter) WriteTimestamp(val time.Time) error { + if w.err == nil { + w.err = w.writeValue(func() error { + _, offset := val.Zone() + offset /= 60 + utc := val.In(time.UTC) + + vlen := timeLen(offset, utc) + buflen := vlen + tagLen(vlen) - buf := make([]byte, 0, buflen) + buf := make([]byte, 0, buflen) - buf = appendTag(buf, 0x60, vlen) - buf = appendTime(buf, offset, utc) + buf = appendTag(buf, 0x60, vlen) + buf = appendTime(buf, offset, utc) - return w.write(buf) - }) + return w.write(buf) + }) + } + return w.err } // WriteSymbol writes a symbol value. -func (w *binaryWriter) WriteSymbol(val string) { +func (w *binaryWriter) WriteSymbol(val string) error { if w.err != nil { - return + return w.err } id, err := w.resolve(val) if err != nil { w.err = err - return + return w.err } w.err = w.writeValue(func() error { @@ -401,6 +394,8 @@ func (w *binaryWriter) WriteSymbol(val string) { return w.write(buf) }) + + return w.err } // Resolve resolves a symbol to its ID. @@ -424,109 +419,107 @@ func (w *binaryWriter) resolve(sym string) (uint64, error) { } // WriteString writes a string. -func (w *binaryWriter) WriteString(val string) { - if w.err != nil { - return - } - - w.err = w.writeValue(func() error { - if len(val) == 0 { - return w.write([]byte{0x80}) - } +func (w *binaryWriter) WriteString(val string) error { + if w.err == nil { + w.err = w.writeValue(func() error { + if len(val) == 0 { + return w.write([]byte{0x80}) + } - vlen := uint64(len(val)) - buflen := vlen + tagLen(vlen) - buf := make([]byte, 0, buflen) + vlen := uint64(len(val)) + buflen := vlen + tagLen(vlen) + buf := make([]byte, 0, buflen) - buf = appendTag(buf, 0x80, vlen) - buf = append(buf, val...) + buf = appendTag(buf, 0x80, vlen) + buf = append(buf, val...) - return w.write(buf) - }) + return w.write(buf) + }) + } + return w.err } // WriteClob writes a clob. -func (w *binaryWriter) WriteClob(val []byte) { - w.writeLob(0x90, val) +func (w *binaryWriter) WriteClob(val []byte) error { + return w.writeLob(0x90, val) } // WriteBlob writes a blob. -func (w *binaryWriter) WriteBlob(val []byte) { - w.writeLob(0xA0, val) +func (w *binaryWriter) WriteBlob(val []byte) error { + return w.writeLob(0xA0, val) } // WriteLob writes a [bc]lob. -func (w *binaryWriter) writeLob(code byte, val []byte) { - if w.err != nil { - return - } +func (w *binaryWriter) writeLob(code byte, val []byte) error { + if w.err == nil { + w.err = w.writeValue(func() error { + vlen := uint64(len(val)) - w.err = w.writeValue(func() error { - vlen := uint64(len(val)) + if vlen < 64 { + buflen := vlen + tagLen(vlen) + buf := make([]byte, 0, buflen) - if vlen < 64 { - buflen := vlen + tagLen(vlen) - buf := make([]byte, 0, buflen) + buf = appendTag(buf, code, vlen) + buf = append(buf, val...) - buf = appendTag(buf, code, vlen) - buf = append(buf, val...) - - return w.write(buf) - } + return w.write(buf) + } - if err := w.writeTag(code, vlen); err != nil { - return err - } - return w.write(val) - }) + if err := w.writeTag(code, vlen); err != nil { + return err + } + return w.write(val) + }) + } + return w.err } // BeginList begins writing a list. -func (w *binaryWriter) BeginList() { - if w.err != nil { - return +func (w *binaryWriter) BeginList() error { + if w.err == nil { + w.err = w.beginContainer(ctxInList, 0xB0) } - w.err = w.beginContainer(ctxInList, 0xB0) + return w.err } // EndList finishes writing a list. -func (w *binaryWriter) EndList() { - if w.err != nil { - return +func (w *binaryWriter) EndList() error { + if w.err == nil { + w.err = w.endContainer(ctxInList) } - w.err = w.endContainer(ctxInList) + return w.err } // BeginSexp begins writing an s-expression. -func (w *binaryWriter) BeginSexp() { - if w.err != nil { - return +func (w *binaryWriter) BeginSexp() error { + if w.err == nil { + w.err = w.beginContainer(ctxInSexp, 0xC0) } - w.err = w.beginContainer(ctxInSexp, 0xC0) + return w.err } // EndSexp finishes writing an s-expression. -func (w *binaryWriter) EndSexp() { - if w.err != nil { - return +func (w *binaryWriter) EndSexp() error { + if w.err == nil { + w.err = w.endContainer(ctxInSexp) } - w.err = w.endContainer(ctxInSexp) + return w.err } // BeginStruct begins writing a struct. -func (w *binaryWriter) BeginStruct() { - if w.err != nil { - return +func (w *binaryWriter) BeginStruct() error { + if w.err == nil { + w.err = w.beginContainer(ctxInStruct, 0xD0) } - w.err = w.beginContainer(ctxInStruct, 0xD0) + return w.err } // EndStruct finishes writing a struct. -func (w *binaryWriter) EndStruct() { - if w.err != nil { - return +func (w *binaryWriter) EndStruct() error { + if w.err == nil { + w.err = w.endContainer(ctxInStruct) } - w.err = w.endContainer(ctxInStruct) + return w.err } // Finish finishes writing a datagram. @@ -535,8 +528,7 @@ func (w *binaryWriter) Finish() error { return w.err } if w.ctx.peek() != ctxAtTopLevel { - w.err = errors.New("ion: not at top level") - return w.err + return errors.New("ion: not at top level") } w.fieldName = "" diff --git a/binarywriter_test.go b/binarywriter_test.go index eaa6c933..db0f51ca 100644 --- a/binarywriter_test.go +++ b/binarywriter_test.go @@ -306,18 +306,18 @@ func TestWriteBinaryNulls(t *testing.T) { testBinaryWriter(t, eval, func(w Writer) { w.WriteNull() - w.WriteNullWithType(BoolType) - w.WriteNullWithType(IntType) - w.WriteNullWithType(FloatType) - w.WriteNullWithType(DecimalType) - w.WriteNullWithType(TimestampType) - w.WriteNullWithType(SymbolType) - w.WriteNullWithType(StringType) - w.WriteNullWithType(ClobType) - w.WriteNullWithType(BlobType) - w.WriteNullWithType(ListType) - w.WriteNullWithType(SexpType) - w.WriteNullWithType(StructType) + w.WriteNullType(BoolType) + w.WriteNullType(IntType) + w.WriteNullType(FloatType) + w.WriteNullType(DecimalType) + w.WriteNullType(TimestampType) + w.WriteNullType(SymbolType) + w.WriteNullType(StringType) + w.WriteNullType(ClobType) + w.WriteNullType(BlobType) + w.WriteNullType(ListType) + w.WriteNullType(SexpType) + w.WriteNullType(StructType) }) } @@ -377,8 +377,8 @@ func writeBinary(t *testing.T, f func(w Writer)) []byte { f(w) - if w.Err() != nil { - t.Fatal(w.Err()) + if err := w.Finish(); err != nil { + t.Fatal(err) } return buf.Bytes() diff --git a/bitstream.go b/bitstream.go index 5b64a3f1..34ddc21c 100644 --- a/bitstream.go +++ b/bitstream.go @@ -102,12 +102,12 @@ type bitstream struct { len uint64 } -func (b *bitstream) Init(in io.Reader) { - b.in = bufio.NewReader(in) +func (b *bitstream) Init(in *bufio.Reader) { + b.in = in } func (b *bitstream) InitBytes(in []byte) { - b.Init(bytes.NewReader(in)) + b.in = bufio.NewReader(bytes.NewReader(in)) } func (b *bitstream) Code() bitcode { @@ -504,8 +504,11 @@ func (b *bitstream) ReadTimestamp() (time.Time, error) { return time.Time{}, err } - utc := time.Date(ts[0], time.Month(ts[1]), ts[2], ts[3], ts[4], ts[5], int(nsecs), time.UTC) + b.state = b.stateAfterValue() + b.code = bitcodeNone + b.len = 0 + utc := time.Date(ts[0], time.Month(ts[1]), ts[2], ts[3], ts[4], ts[5], int(nsecs), time.UTC) return utc.In(time.FixedZone("fixed", int(offset)*60)), nil } diff --git a/bitstream_test.go b/bitstream_test.go index 5e6f7231..2b8dd8db 100644 --- a/bitstream_test.go +++ b/bitstream_test.go @@ -100,3 +100,12 @@ func TestBitstream(t *testing.T) { next(bitcodeEOF, false, 0) next(bitcodeEOF, false, 0) } + +func TestBitcodeString(t *testing.T) { + for i := bitcodeNone; i <= bitcodeAnnotation+1; i++ { + str := i.String() + if str == "" { + t.Errorf("expected non-empty string for bitcode %v", uint8(i)) + } + } +} diff --git a/marshal.go b/marshal.go index d0d1dcc4..11ab52e9 100644 --- a/marshal.go +++ b/marshal.go @@ -132,30 +132,24 @@ func (m *Encoder) encodeValue(v reflect.Value) error { t := v.Type() switch t.Kind() { case reflect.Bool: - m.w.WriteBool(v.Bool()) - return m.w.Err() + return m.w.WriteBool(v.Bool()) case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - m.w.WriteInt(v.Int()) - return m.w.Err() + return m.w.WriteInt(v.Int()) case reflect.Uint8, reflect.Uint16, reflect.Uint32: - m.w.WriteInt(int64(v.Uint())) - return m.w.Err() + return m.w.WriteInt(int64(v.Uint())) case reflect.Uint, reflect.Uint64, reflect.Uintptr: i := big.Int{} i.SetUint64(v.Uint()) - m.w.WriteBigInt(&i) - return m.w.Err() + return m.w.WriteBigInt(&i) case reflect.Float32, reflect.Float64: - m.w.WriteFloat(v.Float()) - return m.w.Err() + return m.w.WriteFloat(v.Float()) case reflect.String: - m.w.WriteString(v.String()) - return m.w.Err() + return m.w.WriteString(v.String()) case reflect.Interface, reflect.Ptr: return m.encodePtr(v) @@ -181,8 +175,7 @@ func (m *Encoder) encodeValue(v reflect.Value) error { // the pointer is pointing to. func (m *Encoder) encodePtr(v reflect.Value) error { if v.IsNil() { - m.w.WriteNull() - return m.w.Err() + return m.w.WriteNull() } return m.encodeValue(v.Elem()) } @@ -190,8 +183,7 @@ func (m *Encoder) encodePtr(v reflect.Value) error { // EncodeMap encodes a map to the output writer as an Ion struct. func (m *Encoder) encodeMap(v reflect.Value) error { if v.IsNil() { - m.w.WriteNull() - return m.w.Err() + return m.w.WriteNull() } m.w.BeginStruct() @@ -209,8 +201,7 @@ func (m *Encoder) encodeMap(v reflect.Value) error { } } - m.w.EndStruct() - return m.w.Err() + return m.w.EndStruct() } // A mapkey holds the reflective map key value as well as its stringified form. @@ -245,8 +236,7 @@ func (m *Encoder) encodeSlice(v reflect.Value) error { } if v.IsNil() { - m.w.WriteNull() - return m.w.Err() + return m.w.WriteNull() } return m.encodeArray(v) @@ -255,11 +245,9 @@ func (m *Encoder) encodeSlice(v reflect.Value) error { // EncodeBlob encodes a []byte to the output writer as an Ion blob. func (m *Encoder) encodeBlob(v reflect.Value) error { if v.IsNil() { - m.w.WriteNull() - } else { - m.w.WriteBlob(v.Bytes()) + return m.w.WriteNull() } - return m.w.Err() + return m.w.WriteBlob(v.Bytes()) } // EncodeArray encodes an array to the output writer as an Ion list. @@ -272,8 +260,7 @@ func (m *Encoder) encodeArray(v reflect.Value) error { } } - m.w.EndList() - return m.w.Err() + return m.w.EndList() } // EncodeStruct encodes a struct to the output writer as an Ion struct. @@ -315,22 +302,19 @@ FieldLoop: } } - m.w.EndStruct() - return m.w.Err() + return m.w.EndStruct() } // EncodeTime encodes a time.Time to the output writer as an Ion timestamp. func (m *Encoder) encodeTime(v reflect.Value) error { t := v.Interface().(time.Time) - m.w.WriteTimestamp(t) - return m.w.Err() + return m.w.WriteTimestamp(t) } // EncodeDecimal encodes an ion.Decimal to the output writer as an Ion decimal. func (m *Encoder) encodeDecimal(v reflect.Value) error { d := v.Addr().Interface().(*Decimal) - m.w.WriteDecimal(d) - return m.w.Err() + return m.w.WriteDecimal(d) } // EmptyValue returns true if the given value is the empty value for its type. diff --git a/reader.go b/reader.go index 252c13a9..e8a05dd7 100644 --- a/reader.go +++ b/reader.go @@ -1,12 +1,27 @@ package ion import ( + "bufio" "errors" + "io" "math" "math/big" "time" ) +// NewReader creates a new Ion reader of the appropriate type by peeking +// at the first several bytes of input for a binary version marker. +func NewReader(in io.Reader) Reader { + br := bufio.NewReader(in) + + bs, err := br.Peek(4) + if err == nil && bs[0] == 0xE0 && bs[3] == 0xEA { + return newBinaryReaderBuf(br, nil) + } + + return newTextReaderBuf(br) +} + // A reader holds common implementation stuff to both the text and binary readers. type reader struct { ctx ctxstack diff --git a/reader_test.go b/reader_test.go index 25f40598..2e204de1 100644 --- a/reader_test.go +++ b/reader_test.go @@ -5,7 +5,6 @@ import ( "io/ioutil" "os" "path/filepath" - "strings" "testing" ) @@ -122,17 +121,7 @@ func testReadFile(t *testing.T, path string, d drainfunc) { } defer file.Close() - var r Reader - - if strings.HasSuffix(path, ".ion") { - r = NewTextReader(file) - // r.(*textReader).debug = true - } else if strings.HasSuffix(path, ".10n") { - // Binary ion not yet supported. - return - } else { - t.Fatal("unexpected suffix on file", path) - } + r := NewReader(file) d(t, r, path) } diff --git a/symboltable.go b/symboltable.go index 284f4311..e87b239d 100644 --- a/symboltable.go +++ b/symboltable.go @@ -96,24 +96,23 @@ func (s *sst) FindByID(id int) (string, bool) { func (s *sst) WriteTo(w Writer) error { w.Annotation("$ion_shared_symbol_table") w.BeginStruct() + { + w.FieldName("name") + w.WriteString(s.name) - w.FieldName("name") - w.WriteString(s.name) - - w.FieldName("version") - w.WriteInt(int64(s.version)) - - w.FieldName("symbols") - w.BeginList() + w.FieldName("version") + w.WriteInt(int64(s.version)) - for _, sym := range s.symbols { - w.WriteString(sym) + w.FieldName("symbols") + w.BeginList() + { + for _, sym := range s.symbols { + w.WriteString(sym) + } + } + w.EndList() } - - w.EndList() // symbols - - w.EndStruct() - return w.Err() + return w.EndStruct() } func (s *sst) String() string { @@ -321,12 +320,10 @@ func (t *lst) WriteTo(w Writer) error { for _, sym := range t.symbols { w.WriteString(sym) } - w.EndList() } - w.EndStruct() - return w.Err() + return w.EndStruct() } func (t *lst) String() string { diff --git a/textreader.go b/textreader.go index 3e36898c..782538e6 100644 --- a/textreader.go +++ b/textreader.go @@ -51,12 +51,7 @@ type textReader struct { // NewTextReader creates a new text reader. func NewTextReader(in io.Reader) Reader { - return &textReader{ - tok: tokenizer{ - in: bufio.NewReader(in), - }, - state: trsBeforeTypeAnnotations, - } + return newTextReaderBuf(bufio.NewReader(in)) } // NewTextReaderStr creates a new text reader from a string. @@ -64,6 +59,15 @@ func NewTextReaderStr(str string) Reader { return NewTextReader(strings.NewReader(str)) } +func newTextReaderBuf(in *bufio.Reader) Reader { + return &textReader{ + tok: tokenizer{ + in: in, + }, + state: trsBeforeTypeAnnotations, + } +} + // SymbolTable returns the current symbol table. func (t *textReader) SymbolTable() SymbolTable { // TODO: Include me if present in the input stream? diff --git a/textwriter.go b/textwriter.go index 85a855b4..0b1aa47a 100644 --- a/textwriter.go +++ b/textwriter.go @@ -131,51 +131,51 @@ func (w *textWriter) end(t ctx, c byte) error { } // BeginStruct begins writing a struct. -func (w *textWriter) BeginStruct() { - if w.err != nil { - return +func (w *textWriter) BeginStruct() error { + if w.err == nil { + w.err = w.begin(ctxInStruct, '{') } - w.err = w.begin(ctxInStruct, '{') + return w.err } // EndStruct finishes writing a struct. -func (w *textWriter) EndStruct() { - if w.err != nil { - return +func (w *textWriter) EndStruct() error { + if w.err == nil { + w.err = w.end(ctxInStruct, '}') } - w.err = w.end(ctxInStruct, '}') + return w.err } // BeginList begins writing a list. -func (w *textWriter) BeginList() { - if w.err != nil { - return +func (w *textWriter) BeginList() error { + if w.err == nil { + w.err = w.begin(ctxInList, '[') } - w.err = w.begin(ctxInList, '[') + return w.err } // EndList finishes writing a list. -func (w *textWriter) EndList() { - if w.err != nil { - return +func (w *textWriter) EndList() error { + if w.err == nil { + w.err = w.end(ctxInList, ']') } - w.err = w.end(ctxInList, ']') + return w.err } // BeginSexp begins writing an s-expression. -func (w *textWriter) BeginSexp() { - if w.err != nil { - return +func (w *textWriter) BeginSexp() error { + if w.err == nil { + w.err = w.begin(ctxInSexp, '(') } - w.err = w.begin(ctxInSexp, '(') + return w.err } // EndSexp finishes writing an s-expression. -func (w *textWriter) EndSexp() { - if w.err != nil { - return +func (w *textWriter) EndSexp() error { + if w.err == nil { + w.err = w.end(ctxInSexp, ')') } - w.err = w.end(ctxInSexp, ')') + return w.err } // writeValue writes a value whose raw encoding is produced by the @@ -210,203 +210,205 @@ func (w *textWriter) writeValueStreaming(f func() error) error { } // WriteNull writes an untyped null. -func (w *textWriter) WriteNull() { - w.WriteNullWithType(NullType) +func (w *textWriter) WriteNull() error { + return w.WriteNullType(NoType) } -// WriteNullWithType writes a typed null. -func (w *textWriter) WriteNullWithType(t Type) { - if w.err != nil { - return +// WriteNullType writes a typed null. +func (w *textWriter) WriteNullType(t Type) error { + if w.err == nil { + w.err = w.writeValue(func() string { + switch t { + case NoType: + return "null" + case NullType: + return "null.null" + case BoolType: + return "null.bool" + case IntType: + return "null.int" + case FloatType: + return "null.float" + case DecimalType: + return "null.decimal" + case TimestampType: + return "null.timestamp" + case StringType: + return "null.string" + case SymbolType: + return "null.symbol" + case BlobType: + return "null.blob" + case ClobType: + return "null.clob" + case StructType: + return "null.struct" + case ListType: + return "null.list" + case SexpType: + return "null.sexp" + default: + panic(fmt.Sprintf("invalid type: %v", t)) + } + }) } - w.err = w.writeValue(func() string { - switch t { - case NoType, NullType: - return "null" - case BoolType: - return "null.bool" - case IntType: - return "null.int" - case FloatType: - return "null.float" - case DecimalType: - return "null.decimal" - case TimestampType: - return "null.timestamp" - case StringType: - return "null.string" - case SymbolType: - return "null.symbol" - case BlobType: - return "null.blob" - case ClobType: - return "null.clob" - case StructType: - return "null.struct" - case ListType: - return "null.list" - case SexpType: - return "null.sexp" - default: - panic(fmt.Sprintf("invalid type: %v", t)) - } - }) + return w.err } // WriteBool writes a boolean value. -func (w *textWriter) WriteBool(val bool) { - if w.err != nil { - return +func (w *textWriter) WriteBool(val bool) error { + if w.err == nil { + w.err = w.writeValue(func() string { + if val { + return "true" + } + return "false" + }) } - w.err = w.writeValue(func() string { - if val { - return "true" - } - return "false" - }) + return w.err } // WriteInt writes an integer value. -func (w *textWriter) WriteInt(val int64) { - if w.err != nil { - return +func (w *textWriter) WriteInt(val int64) error { + if w.err == nil { + w.err = w.writeValue(func() string { + return fmt.Sprintf("%d", val) + }) } - w.err = w.writeValue(func() string { - return fmt.Sprintf("%d", val) - }) + return w.err } // WriteBigInt writes a (big) integer value. -func (w *textWriter) WriteBigInt(val *big.Int) { - if w.err != nil { - return +func (w *textWriter) WriteBigInt(val *big.Int) error { + if w.err == nil { + w.err = w.writeValue(func() string { + return val.String() + }) } - w.err = w.writeValue(func() string { - return val.String() - }) + return w.err } // WriteFloat writes a floating-point value. -func (w *textWriter) WriteFloat(val float64) { - if w.err != nil { - return - } - w.err = w.writeValue(func() string { - // Built-in go formatting isn't quite up to the task. :( - str := strconv.FormatFloat(val, 'e', -1, 64) - - switch str { - case "NaN": - return "nan" - case "+Inf": - return "+inf" - case "-Inf": - return "-inf" - default: - break - } +func (w *textWriter) WriteFloat(val float64) error { + if w.err == nil { + w.err = w.writeValue(func() string { + // Built-in go formatting isn't quite up to the task. :( + str := strconv.FormatFloat(val, 'e', -1, 64) + + switch str { + case "NaN": + return "nan" + case "+Inf": + return "+inf" + case "-Inf": + return "-inf" + default: + break + } - idx := strings.Index(str, "e") - if idx < 0 { - str += "e0" - } else if idx+2 < len(str) && str[idx+2] == '0' { - str = str[:idx+2] + str[idx+3:] - } + idx := strings.Index(str, "e") + if idx < 0 { + str += "e0" + } else if idx+2 < len(str) && str[idx+2] == '0' { + str = str[:idx+2] + str[idx+3:] + } - return str - }) + return str + }) + } + return w.err } // WriteDecimal writes an arbitrary-precision decimal value. -func (w *textWriter) WriteDecimal(val *Decimal) { - if w.err != nil { - return +func (w *textWriter) WriteDecimal(val *Decimal) error { + if w.err == nil { + w.err = w.writeValue(func() string { + return val.String() + }) } - w.err = w.writeValue(func() string { - return val.String() - }) + return w.err } // WriteTimestamp writes a timestamp. -func (w *textWriter) WriteTimestamp(val time.Time) { - if w.err != nil { - return +func (w *textWriter) WriteTimestamp(val time.Time) error { + if w.err == nil { + w.err = w.writeValue(func() string { + return val.Format(time.RFC3339Nano) + }) } - w.err = w.writeValue(func() string { - return val.Format(time.RFC3339Nano) - }) + return w.err } // WriteSymbol writes a symbol. -func (w *textWriter) WriteSymbol(val string) { - if w.err != nil { - return +func (w *textWriter) WriteSymbol(val string) error { + if w.err == nil { + w.err = w.writeValueStreaming(func() error { + return writeSymbol(val, w.out) + }) } - w.err = w.writeValueStreaming(func() error { - return writeSymbol(val, w.out) - }) + return w.err } // WriteString writes a string. -func (w *textWriter) WriteString(val string) { - if w.err != nil { - return +func (w *textWriter) WriteString(val string) error { + if w.err == nil { + w.err = w.writeValueStreaming(func() error { + if err := writeRawChar('"', w.out); err != nil { + return err + } + if err := writeEscapedString(val, w.out); err != nil { + return err + } + return writeRawChar('"', w.out) + }) } - w.err = w.writeValueStreaming(func() error { - if err := writeRawChar('"', w.out); err != nil { - return err - } - if err := writeEscapedString(val, w.out); err != nil { - return err - } - return writeRawChar('"', w.out) - }) + return w.err } // WriteBlob writes a blob. -func (w *textWriter) WriteBlob(val []byte) { - if w.err != nil { - return - } - w.err = w.writeValueStreaming(func() error { - if err := writeRawString("{{", w.out); err != nil { - return err - } +func (w *textWriter) WriteBlob(val []byte) error { + if w.err == nil { + w.err = w.writeValueStreaming(func() error { + if err := writeRawString("{{", w.out); err != nil { + return err + } - enc := base64.NewEncoder(base64.StdEncoding, w.out) - enc.Write(val) - if err := enc.Close(); err != nil { - return err - } + enc := base64.NewEncoder(base64.StdEncoding, w.out) + enc.Write(val) + if err := enc.Close(); err != nil { + return err + } - return writeRawString("}}", w.out) - }) + return writeRawString("}}", w.out) + }) + } + return w.err } // WriteClob writes a clob. -func (w *textWriter) WriteClob(val []byte) { - if w.err != nil { - return - } - w.err = w.writeValueStreaming(func() error { - if err := writeRawString("{{\"", w.out); err != nil { - return err - } +func (w *textWriter) WriteClob(val []byte) error { + if w.err == nil { + w.err = w.writeValueStreaming(func() error { + if err := writeRawString("{{\"", w.out); err != nil { + return err + } - for _, c := range val { - if c < 32 || c == '\\' || c == '"' || c > 0x7F { - if err := writeEscapedChar(c, w.out); err != nil { - return err - } - } else { - if err := writeRawChar(c, w.out); err != nil { - return err + for _, c := range val { + if c < 32 || c == '\\' || c == '"' || c > 0x7F { + if err := writeEscapedChar(c, w.out); err != nil { + return err + } + } else { + if err := writeRawChar(c, w.out); err != nil { + return err + } } } - } - return writeRawString("\"}}", w.out) - }) + return writeRawString("\"}}", w.out) + }) + } + return w.err } // Finish finishes the current datagram. @@ -415,8 +417,7 @@ func (w *textWriter) Finish() error { return w.err } if w.ctx.peek() != ctxAtTopLevel { - w.err = errors.New("not at top level") - return w.err + return errors.New("not at top level") } if w.opts&TextWriterQuietFinish == 0 { diff --git a/textwriter_test.go b/textwriter_test.go index a3262b1d..22383145 100644 --- a/textwriter_test.go +++ b/textwriter_test.go @@ -10,8 +10,7 @@ import ( func TestTopLevelFieldName(t *testing.T) { writeText(func(w Writer) { - w.FieldName("foo") - if w.Err() == nil { + if err := w.FieldName("foo"); err == nil { t.Error("expected an error") } }) @@ -19,30 +18,15 @@ func TestTopLevelFieldName(t *testing.T) { func TestEmptyStruct(t *testing.T) { testTextWriter(t, "{}", func(w Writer) { - if w.InStruct() { - t.Error("already in struct") - } - - w.BeginStruct() - if w.Err() != nil { - t.Fatal(w.Err()) - } - - if !w.InStruct() { - t.Error("not in struct after begin") - } - - w.EndStruct() - if w.Err() != nil { - t.Fatal(w.Err()) + if err := w.BeginStruct(); err != nil { + t.Fatal(err) } - if w.InStruct() { - t.Error("still in struct after end") + if err := w.EndStruct(); err != nil { + t.Fatal(err) } - w.EndStruct() - if w.Err() == nil { + if err := w.EndStruct(); err == nil { t.Fatal("no error from ending struct too many times") } }) @@ -54,10 +38,10 @@ func TestAnnotatedStruct(t *testing.T) { w.Annotation("$bar") w.Annotation(".baz") w.BeginStruct() - w.EndStruct() + err := w.EndStruct() - if w.Err() != nil { - t.Fatal(w.Err()) + if err != nil { + t.Fatal(err) } }) } @@ -81,22 +65,15 @@ func TestNestedStruct(t *testing.T) { func TestEmptyList(t *testing.T) { testTextWriter(t, "[]", func(w Writer) { - w.BeginList() - if w.Err() != nil { - t.Fatal(w.Err()) - } - - if w.InStruct() { - t.Error("instruct returns true in a list") + if err := w.BeginList(); err != nil { + t.Fatal(err) } - w.EndList() - if w.Err() != nil { - t.Fatal(w.Err()) + if err := w.EndList(); err != nil { + t.Fatal(err) } - w.EndList() - if w.Err() == nil { + if err := w.EndList(); err == nil { t.Error("no error calling endlist at top level") } }) @@ -140,17 +117,30 @@ func TestWriteSexps(t *testing.T) { }) } -func TestNull(t *testing.T) { - expected := "[null,foo::null,null.int,bar::null.sexp]" +func TestNulls(t *testing.T) { + expected := "[null,foo::null.null,null.bool,null.int,null.float,null.decimal," + + "null.timestamp,null.symbol,null.string,null.clob,null.blob," + + "null.list,'null'::null.sexp,null.struct]" + testTextWriter(t, expected, func(w Writer) { w.BeginList() w.WriteNull() w.Annotation("foo") - w.WriteNullWithType(NullType) - w.WriteNullWithType(IntType) - w.Annotation("bar") - w.WriteNullWithType(SexpType) + w.WriteNullType(NullType) + w.WriteNullType(BoolType) + w.WriteNullType(IntType) + w.WriteNullType(FloatType) + w.WriteNullType(DecimalType) + w.WriteNullType(TimestampType) + w.WriteNullType(SymbolType) + w.WriteNullType(StringType) + w.WriteNullType(ClobType) + w.WriteNullType(BlobType) + w.WriteNullType(ListType) + w.Annotation("null") + w.WriteNullType(SexpType) + w.WriteNullType(StructType) w.EndList() }) diff --git a/writer.go b/writer.go index 974f0861..19d20cfb 100644 --- a/writer.go +++ b/writer.go @@ -37,29 +37,31 @@ func (w *writer) Err() error { // FieldName sets the field name for the next value written. // It may only be called while writing a struct. -func (w *writer) FieldName(val string) { +func (w *writer) FieldName(val string) error { if w.err != nil { - return + return w.err } if !w.InStruct() { w.err = errors.New("FieldName() called but not writing a struct") - return + return w.err } + w.fieldName = val + return nil } // Annotation adds an annotation to the next value written. -func (w *writer) Annotation(val string) { - if w.err != nil { - return +func (w *writer) Annotation(val string) error { + if w.err == nil { + w.annotations = append(w.annotations, val) } - w.annotations = append(w.annotations, val) + return w.err } // Annotations adds one or more annotations to the next value written. -func (w *writer) Annotations(val ...string) { - if w.err != nil { - return +func (w *writer) Annotations(val ...string) error { + if w.err == nil { + w.annotations = append(w.annotations, val...) } - w.annotations = append(w.annotations, val...) + return w.err } From aa2c30274d2f844b510b22aa05ad1cd00aed2672 Mon Sep 17 00:00:00 2001 From: David Murray Date: Sun, 1 Sep 2019 13:46:13 +1000 Subject: [PATCH 35/56] Refactoring things --- api.go | 364 ------------------------- binaryreader.go | 13 +- binaryreader_test.go | 2 +- binarywriter.go | 2 +- catalog.go | 18 +- consts.go | 21 +- reader.go | 153 ++++++++++- textreader.go | 12 - textreader_test.go | 36 +-- textutils.go | 56 ++-- textwriter.go | 519 ++++++++++++++++-------------------- tokenizer.go | 13 + type.go | 121 +++++++++ api_test.go => type_test.go | 4 +- unmarshal.go | 3 +- unmarshal_test.go | 20 +- writer.go | 142 ++++++++-- 17 files changed, 744 insertions(+), 755 deletions(-) delete mode 100644 api.go create mode 100644 type.go rename api_test.go => type_test.go (94%) diff --git a/api.go b/api.go deleted file mode 100644 index 59d95372..00000000 --- a/api.go +++ /dev/null @@ -1,364 +0,0 @@ -package ion - -import ( - "fmt" - "math/big" - "time" -) - -// A Type represents the type of an Ion Value. -type Type uint8 - -const ( - // NoType is returned by a Reader that is not currently pointing at a value. - NoType Type = iota - - // NullType is the type of the (unqualified) Ion null value. - NullType - - // BoolType is the type of an Ion boolean, true or false. - BoolType - - // IntType is the type of a signed Ion integer of arbitrary size. - IntType - - // FloatType is the type of a fixed-precision Ion floating-point value. - FloatType - - // DecimalType is the type of an arbitrary-precision Ion decimal value. - DecimalType - - // TimestampType is the type of an arbitrary-precision Ion timestamp. - TimestampType - - // SymbolType is the type of an Ion symbol, mapped to an integer ID by a SymbolTable - // to (potentially) save space. - SymbolType - - // StringType is the type of a non-symbol Unicode string, represented directly. - StringType - - // ClobType is the type of a character large object. Like a BlobType, it stores an - // arbitrary sequence of bytes, but it represents them in text form as an escaped-ASCII - // string rather than a base64-encoded string. - ClobType - - // BlobType is the type of a binary large object; a sequence of arbitrary bytes. - BlobType - - // ListType is the type of a list, recursively containing zero or more Ion values. - ListType - - // SexpType is the type of an s-expression. Like a ListType, it contains a sequence - // of zero or more Ion values, but with a lisp-like syntax when encoded as text. - SexpType - - // StructType is the type of a structure, recursively containing a sequence of named - // (by an Ion symbol) Ion values. - StructType -) - -// String implements fmt.Stringer for Type. -func (t Type) String() string { - switch t { - case NoType: - return "" - case NullType: - return "null" - case BoolType: - return "bool" - case IntType: - return "int" - case FloatType: - return "float" - case DecimalType: - return "decimal" - case TimestampType: - return "timestamp" - case StringType: - return "string" - case SymbolType: - return "symbol" - case BlobType: - return "blob" - case ClobType: - return "clob" - case StructType: - return "struct" - case ListType: - return "list" - case SexpType: - return "sexp" - default: - return fmt.Sprintf("", uint8(t)) - } -} - -// IntSize returns the size of an integer, allowing you to pick the -// appropriate Reader method to call to retrieve the value without loss. -type IntSize uint8 - -const ( - // NullInt is the size of null.int and other things that aren't actually ints. - NullInt IntSize = iota - // Int32 is an integer that can be losslessly stored in an int32. - Int32 - // Int64 is an integer that can be losslessly stored in an int64. - Int64 - // BigInt is an integer that can only be losslessly stored in a big.Int. - BigInt -) - -// String implements fmt.Stringer for IntSize. -func (i IntSize) String() string { - switch i { - case NullInt: - return "null.int" - case Int32: - return "int32" - case Int64: - return "int64" - case BigInt: - return "big.Int" - default: - return fmt.Sprintf("", uint8(i)) - } -} - -// A Reader reads a stream of Ion values. -// -// The Reader has a logical position within the stream of values, influencing the -// values returnedd from its methods. Initially, the Reader is positioned before the -// first value in the stream. A call to Next advances the Reader to the first value -// in the stream, with subsequent calls advancing to subsequent values. When a call to -// Next moves the Reader to the position after the final value in the stream, it returns -// false, making it easy to loop through the values in a stream. -// -// var r Reader -// for r.Next() { -// // ... -// } -// -// Next also returns false in case of error. This can be distinguished from a legitimate -// end-of-stream by calling Err after exiting the loop. -// -// When positioned on an Ion value, the type of the value can be retrieved by calling -// Type. If it has an associated field name (inside a struct) or annotations, they can -// be read by calling FieldName and Annotations respectively. -// -// For atomic values, an appropriate XxxValue method can be called to read the value. -// For lists, sexps, and structs, you should instead call StepIn to move the Reader in -// to the contained sequence of values. The Reader will initially be positioned before -// the first value in the container. Calling Next without calling StepIn will skip over -// the composite value and return the next value in the outer value stream. -// -// At any point while reading through a composite value, including when Next returns false -// to indicate the end of the contained values, you may call StepOut to move back to the -// outer sequence of values. The Reader will be positioned at the end of the composite value, -// such that a call to Next will move to the immediately-following value (if any). -// -// r := NewTextReaderStr("[foo, bar] [baz]") -// for r.Next() { -// if err := r.StepIn(); err != nil { -// return err -// } -// for r.Next() { -// fmt.Println(r.StringValue()) -// } -// if err := r.StepOut(); err != nil { -// return err -// } -// } -// if err := r.Err(); err != nil { -// return err -// } -// -type Reader interface { - - // SymbolTable returns the current symbol table, or nil if there isn't one. - // Text Readers do not, generally speaking, have an associated symbol table. - // Binary Readers do. - SymbolTable() SymbolTable - - // Next advances the Reader to the next position in the current value stream. - // It returns true if this is the position of an Ion value, and false if it - // is not. On error, it returns false and sets Err. - Next() bool - - // Err returns an error if a previous call call to Next has failed. - Err() error - - // Type returns the type of the Ion value the Reader is currently positioned on. - // It returns NoType if the Reader is positioned before or after a value. - Type() Type - - // IsNull returns true if the current value is an explicit null. This may be true - // even if the Type is not NullType (for example, null.struct has type Struct). Yes, - // that's a bit confusing. - IsNull() bool - - // FieldName returns the field name associated with the current value. It returns - // the empty string if there is no current value or the current value has no field - // name. - FieldName() string - - // Annotations returns the set of annotations associated with the current value. - // It returns nil if there is no current value or the current value has no annotations. - Annotations() []string - - // StepIn steps in to the current value if it is a container. It returns an error if there - // is no current value or if the value is not a container. On success, the Reader is - // positioned before the first value in the container. - StepIn() error - - // StepOut steps out of the current container value being read. It returns an error if - // this Reader is not currently stepped in to a container. On success, the Reader is - // positioned after the end of the container, but before any subsequent values in the - // stream. - StepOut() error - - // BoolValue returns the current value as a boolean (if that makes sense). It returns - // an error if the current value is not an Ion bool. - BoolValue() (bool, error) - - // IntSize returns the size of integer needed to losslessly represent the current value - // (if that makes sense). It returns an error if the current value is not an Ion int. - IntSize() (IntSize, error) - - // IntValue returns the current value as a 32-bit integer (if that makes sense). It - // returns an error if the current value is not an Ion integer or requires more than - // 32 bits to represent losslessly. - IntValue() (int, error) - - // Int64Value returns the current value as a 64-bit integer (if that makes sense). It - // returns an error if the current value is not an Ion integer or requires more than - // 64 bits to represent losslessly. - Int64Value() (int64, error) - - // BigIntValue returns the current value as a big.Integer (if that makes sense). It - // returns an error if the current value is not an Ion integer. - BigIntValue() (*big.Int, error) - - // FloatValue returns the current value as a 64-bit floating point number (if that - // makes sense). It returns an error if the current value is not an Ion float. - FloatValue() (float64, error) - - // DecimalValue returns the current value as an arbitrary-precision Decimal (if that - // makes sense). It returns an error if the current value is not an Ion decimal. - DecimalValue() (*Decimal, error) - - // TimeValue returns the current value as a timestamp (if that makes sense). It returns - // an error if the current value is not an Ion timestamp. - TimeValue() (time.Time, error) - - // StringValue returns the current value as a string (if that makes sense). It returns - // an error if the current value is not an Ion symbol or an Ion string. - StringValue() (string, error) - - // ByteValue returns the current value as a byte slice (if that makes sense). It returns - // an error if the current value is not an Ion clob or an Ion blob. - ByteValue() ([]byte, error) -} - -// A Writer writes a stream of Ion values. -// -// The various Write methods write atomic values to the current output stream. The -// Begin methods begin writing a list, sexp, or struct respectively. Subsequent -// calls to Write will write values inside of the container until a matching -// End method is called. -// -// var w Writer -// w.BeginSexp() -// { -// w.WriteInt(1) -// w.WriteSymbol("+") -// w.WriteInt(1) -// } -// w.EndSexp() -// -// When writing values inside a struct, the FieldName method must be called before -// each value to set the value's field name. The Annotation method may likewise -// be called before writing any value to add an annotation to the value. -// -// var w Writer -// w.Annotation("user") -// w.BeginStruct() -// { -// w.FieldName("id") -// w.WriteString("qu33nb33") -// w.FieldName("name") -// w.WriteString("Beyoncé") -// } -// w.EndStruct() -// -// When you're done writing values, you should call Finish to ensure everything has -// been flushed from in-memory buffers. While individual methods all return an error -// on failure, implementations will remember any errors, no-op subsequent calls, and -// return the previous error. This lets you keep code a bit cleaner by only checking -// the return value of the final method call (generally Finish). -// -// var w Writer -// writeSomeStuff(w) -// if err := w.Finish(); err != nil { -// return err -// } -// -type Writer interface { - - // FieldName sets the field name for the next value written. - FieldName(val string) error - - // Annotation adds a single annotation to the next value written. - Annotation(val string) error - - // Annotations adds multiple annotations to the next value written. - Annotations(vals ...string) error - - // WriteNull writes an untyped null value. - WriteNull() error - // WriteNullType writes a null value with a type qualifier, e.g. null.bool. - WriteNullType(t Type) error - - // WriteBool writes a boolean value. - WriteBool(val bool) error - - // WriteInt writes an integer value. - WriteInt(val int64) error - // WriteBigInt writes a big integer value. - WriteBigInt(val *big.Int) error - // WriteFloat writes a floating-point value. - WriteFloat(val float64) error - // WriteDecimal writes an arbitrary-precision decimal value. - WriteDecimal(val *Decimal) error - - // WriteTimestamp writes a timestamp value. - WriteTimestamp(val time.Time) error - - // WriteSymbol writes a symbol value. - WriteSymbol(val string) error - // WriteString writes a string value. - WriteString(val string) error - - // WriteClob writes a clob value. - WriteClob(val []byte) error - // WriteBlob writes a blob value. - WriteBlob(val []byte) error - - // BeginList begins writing a list value. - BeginList() error - // EndList finishes writing a list value. - EndList() error - - // BeginSexp begins writing an s-expression value. - BeginSexp() error - // EndSexp finishes writing an s-expression value. - EndSexp() error - - // BeginStruct begins writing a struct value. - BeginStruct() error - // EndStruct finishes writing a struct value. - EndStruct() error - - // Finish finishes writing values and flushes any buffered data. - Finish() error -} diff --git a/binaryreader.go b/binaryreader.go index ab749d96..7df471be 100644 --- a/binaryreader.go +++ b/binaryreader.go @@ -2,12 +2,11 @@ package ion import ( "bufio" - "bytes" "errors" "fmt" - "io" ) +// A binaryReader reads binary Ion. type binaryReader struct { bits bitstream cat *Catalog @@ -16,16 +15,6 @@ type binaryReader struct { reader } -// NewBinaryReader creates a new binary reader. -func NewBinaryReader(in io.Reader, cat *Catalog) Reader { - return newBinaryReaderBuf(bufio.NewReader(in), cat) -} - -// NewBinaryReaderBytes creates a new binary reader for the given bytes. -func NewBinaryReaderBytes(in []byte, cat *Catalog) Reader { - return NewBinaryReader(bytes.NewReader(in), cat) -} - func newBinaryReaderBuf(in *bufio.Reader, cat *Catalog) Reader { r := &binaryReader{ cat: cat, diff --git a/binaryreader_test.go b/binaryreader_test.go index 8a87c395..9e064dbf 100644 --- a/binaryreader_test.go +++ b/binaryreader_test.go @@ -295,5 +295,5 @@ func readBinary(ion []byte) Reader { // ] // } } - return NewBinaryReaderBytes(append(prefix, ion...), nil) + return NewReaderBytes(append(prefix, ion...)) } diff --git a/binarywriter.go b/binarywriter.go index 8171ad49..a8c6fd34 100644 --- a/binarywriter.go +++ b/binarywriter.go @@ -99,7 +99,7 @@ func (w *binaryWriter) beginValue() error { } } - if w.InStruct() { + if w.inStruct() { if name == "" { return errors.New("ion: field name not set") } diff --git a/catalog.go b/catalog.go index 315d0926..a0657f8f 100644 --- a/catalog.go +++ b/catalog.go @@ -1,8 +1,12 @@ package ion -import "fmt" +import ( + "bytes" + "fmt" + "io" +) -// A Catalog stores shared symbol tables. +// A Catalog stores shared symbol tables and serves as a reader factory. type Catalog struct { ssts map[string]SharedSymbolTable } @@ -18,3 +22,13 @@ func (c *Catalog) Find(name string, version int) SharedSymbolTable { key := fmt.Sprintf("%v/%v", name, version) return c.ssts[key] } + +// NewReader creates a new reader using this catalog. +func (c *Catalog) NewReader(in io.Reader) Reader { + return newReader(in, c) +} + +// NewReaderBytes creates a new reader using this catalog. +func (c *Catalog) NewReaderBytes(in []byte) Reader { + return newReader(bytes.NewReader(in), c) +} diff --git a/consts.go b/consts.go index e14ba0a2..1a49dec4 100644 --- a/consts.go +++ b/consts.go @@ -6,7 +6,7 @@ import ( ) var binaryNulls = func() []byte { - ret := make([]byte, int(StructType)+1) + ret := make([]byte, StructType+1) ret[NoType] = 0x0F ret[NullType] = 0x0F ret[BoolType] = 0x1F @@ -24,6 +24,25 @@ var binaryNulls = func() []byte { return ret }() +var textNulls []string = func() []string { + ret := make([]string, StructType+1) + ret[NoType] = "null" + ret[NullType] = "null.null" + ret[BoolType] = "null.bool" + ret[IntType] = "null.int" + ret[FloatType] = "null.float" + ret[DecimalType] = "null.decimal" + ret[TimestampType] = "null.timestamp" + ret[SymbolType] = "null.symbol" + ret[StringType] = "null.string" + ret[ClobType] = "null.clob" + ret[BlobType] = "null.blob" + ret[ListType] = "null.list" + ret[SexpType] = "null.sexp" + ret[StructType] = "null.struct" + return ret +}() + var hexChars = []byte{ '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'A', 'B', 'C', 'D', 'E', 'F', diff --git a/reader.go b/reader.go index e8a05dd7..f0216fc8 100644 --- a/reader.go +++ b/reader.go @@ -2,21 +2,172 @@ package ion import ( "bufio" + "bytes" "errors" "io" "math" "math/big" + "strings" "time" ) +// A Reader reads a stream of Ion values. +// +// The Reader has a logical position within the stream of values, influencing the +// values returnedd from its methods. Initially, the Reader is positioned before the +// first value in the stream. A call to Next advances the Reader to the first value +// in the stream, with subsequent calls advancing to subsequent values. When a call to +// Next moves the Reader to the position after the final value in the stream, it returns +// false, making it easy to loop through the values in a stream. +// +// var r Reader +// for r.Next() { +// // ... +// } +// +// Next also returns false in case of error. This can be distinguished from a legitimate +// end-of-stream by calling Err after exiting the loop. +// +// When positioned on an Ion value, the type of the value can be retrieved by calling +// Type. If it has an associated field name (inside a struct) or annotations, they can +// be read by calling FieldName and Annotations respectively. +// +// For atomic values, an appropriate XxxValue method can be called to read the value. +// For lists, sexps, and structs, you should instead call StepIn to move the Reader in +// to the contained sequence of values. The Reader will initially be positioned before +// the first value in the container. Calling Next without calling StepIn will skip over +// the composite value and return the next value in the outer value stream. +// +// At any point while reading through a composite value, including when Next returns false +// to indicate the end of the contained values, you may call StepOut to move back to the +// outer sequence of values. The Reader will be positioned at the end of the composite value, +// such that a call to Next will move to the immediately-following value (if any). +// +// r := NewTextReaderStr("[foo, bar] [baz]") +// for r.Next() { +// if err := r.StepIn(); err != nil { +// return err +// } +// for r.Next() { +// fmt.Println(r.StringValue()) +// } +// if err := r.StepOut(); err != nil { +// return err +// } +// } +// if err := r.Err(); err != nil { +// return err +// } +// +type Reader interface { + + // SymbolTable returns the current symbol table, or nil if there isn't one. + // Text Readers do not, generally speaking, have an associated symbol table. + // Binary Readers do. + SymbolTable() SymbolTable + + // Next advances the Reader to the next position in the current value stream. + // It returns true if this is the position of an Ion value, and false if it + // is not. On error, it returns false and sets Err. + Next() bool + + // Err returns an error if a previous call call to Next has failed. + Err() error + + // Type returns the type of the Ion value the Reader is currently positioned on. + // It returns NoType if the Reader is positioned before or after a value. + Type() Type + + // IsNull returns true if the current value is an explicit null. This may be true + // even if the Type is not NullType (for example, null.struct has type Struct). Yes, + // that's a bit confusing. + IsNull() bool + + // FieldName returns the field name associated with the current value. It returns + // the empty string if there is no current value or the current value has no field + // name. + FieldName() string + + // Annotations returns the set of annotations associated with the current value. + // It returns nil if there is no current value or the current value has no annotations. + Annotations() []string + + // StepIn steps in to the current value if it is a container. It returns an error if there + // is no current value or if the value is not a container. On success, the Reader is + // positioned before the first value in the container. + StepIn() error + + // StepOut steps out of the current container value being read. It returns an error if + // this Reader is not currently stepped in to a container. On success, the Reader is + // positioned after the end of the container, but before any subsequent values in the + // stream. + StepOut() error + + // BoolValue returns the current value as a boolean (if that makes sense). It returns + // an error if the current value is not an Ion bool. + BoolValue() (bool, error) + + // IntSize returns the size of integer needed to losslessly represent the current value + // (if that makes sense). It returns an error if the current value is not an Ion int. + IntSize() (IntSize, error) + + // IntValue returns the current value as a 32-bit integer (if that makes sense). It + // returns an error if the current value is not an Ion integer or requires more than + // 32 bits to represent losslessly. + IntValue() (int, error) + + // Int64Value returns the current value as a 64-bit integer (if that makes sense). It + // returns an error if the current value is not an Ion integer or requires more than + // 64 bits to represent losslessly. + Int64Value() (int64, error) + + // BigIntValue returns the current value as a big.Integer (if that makes sense). It + // returns an error if the current value is not an Ion integer. + BigIntValue() (*big.Int, error) + + // FloatValue returns the current value as a 64-bit floating point number (if that + // makes sense). It returns an error if the current value is not an Ion float. + FloatValue() (float64, error) + + // DecimalValue returns the current value as an arbitrary-precision Decimal (if that + // makes sense). It returns an error if the current value is not an Ion decimal. + DecimalValue() (*Decimal, error) + + // TimeValue returns the current value as a timestamp (if that makes sense). It returns + // an error if the current value is not an Ion timestamp. + TimeValue() (time.Time, error) + + // StringValue returns the current value as a string (if that makes sense). It returns + // an error if the current value is not an Ion symbol or an Ion string. + StringValue() (string, error) + + // ByteValue returns the current value as a byte slice (if that makes sense). It returns + // an error if the current value is not an Ion clob or an Ion blob. + ByteValue() ([]byte, error) +} + // NewReader creates a new Ion reader of the appropriate type by peeking // at the first several bytes of input for a binary version marker. func NewReader(in io.Reader) Reader { + return newReader(in, nil) +} + +// NewReaderStr creates a new reader from a string. +func NewReaderStr(str string) Reader { + return NewReader(strings.NewReader(str)) +} + +// NewReaderBytes creates a new reader for the given bytes. +func NewReaderBytes(in []byte) Reader { + return NewReader(bytes.NewReader(in)) +} + +func newReader(in io.Reader, cat *Catalog) Reader { br := bufio.NewReader(in) bs, err := br.Peek(4) if err == nil && bs[0] == 0xE0 && bs[3] == 0xEA { - return newBinaryReaderBuf(br, nil) + return newBinaryReaderBuf(br, cat) } return newTextReaderBuf(br) diff --git a/textreader.go b/textreader.go index 782538e6..2ea18b72 100644 --- a/textreader.go +++ b/textreader.go @@ -5,10 +5,8 @@ import ( "encoding/base64" "errors" "fmt" - "io" "math" "strconv" - "strings" ) // trs is the state of the text reader. @@ -49,16 +47,6 @@ type textReader struct { debug bool } -// NewTextReader creates a new text reader. -func NewTextReader(in io.Reader) Reader { - return newTextReaderBuf(bufio.NewReader(in)) -} - -// NewTextReaderStr creates a new text reader from a string. -func NewTextReaderStr(str string) Reader { - return NewTextReader(strings.NewReader(str)) -} - func newTextReaderBuf(in *bufio.Reader) Reader { return &textReader{ tok: tokenizer{ diff --git a/textreader_test.go b/textreader_test.go index 7c79ba3a..629ac69d 100644 --- a/textreader_test.go +++ b/textreader_test.go @@ -9,7 +9,7 @@ import ( ) func TestIgnoreValues(t *testing.T) { - r := NewTextReaderStr("(skip ++ me / please) {skip: me, please: 0}\n[skip, me, please]\nfoo") + r := NewReaderStr("(skip ++ me / please) {skip: me, please: 0}\n[skip, me, please]\nfoo") _next(t, r, SexpType) _next(t, r, StructType) @@ -22,7 +22,7 @@ func TestIgnoreValues(t *testing.T) { func TestReadSexps(t *testing.T) { test := func(str string, f containerhandler) { t.Run(str, func(t *testing.T) { - r := NewTextReaderStr(str) + r := NewReaderStr(str) _sexp(t, r, f) _eof(t, r) }) @@ -51,7 +51,7 @@ func TestReadSexps(t *testing.T) { func TestStructs(t *testing.T) { test := func(str string, f containerhandler) { t.Run(str, func(t *testing.T) { - r := NewTextReaderStr(str) + r := NewReaderStr(str) _struct(t, r, f) _eof(t, r) }) @@ -73,7 +73,7 @@ func TestStructs(t *testing.T) { } func TestMultipleStructs(t *testing.T) { - r := NewTextReaderStr("{} {} {}") + r := NewReaderStr("{} {} {}") for i := 0; i < 3; i++ { _struct(t, r, func(t *testing.T, r Reader) { @@ -85,7 +85,7 @@ func TestMultipleStructs(t *testing.T) { } func TestNullStructs(t *testing.T) { - r := NewTextReaderStr("null.struct 'null'::{foo:bar}") + r := NewReaderStr("null.struct 'null'::{foo:bar}") _null(t, r, StructType) _nextAF(t, r, StructType, "", []string{"null"}) @@ -95,7 +95,7 @@ func TestNullStructs(t *testing.T) { func TestLists(t *testing.T) { test := func(str string, f containerhandler) { t.Run(str, func(t *testing.T) { - r := NewTextReaderStr(str) + r := NewReaderStr(str) _list(t, r, f) _eof(t, r) }) @@ -123,7 +123,7 @@ func TestReadNestedLists(t *testing.T) { _eof(t, r) } - r := NewTextReaderStr("[[], [[]]]") + r := NewReaderStr("[[], [[]]]") _list(t, r, func(t *testing.T, r Reader) { _list(t, r, empty) @@ -141,7 +141,7 @@ func TestReadNestedLists(t *testing.T) { func TestClobs(t *testing.T) { test := func(str string, eval []byte) { t.Run(str, func(t *testing.T) { - r := NewTextReaderStr(str) + r := NewReaderStr(str) _next(t, r, ClobType) val, err := r.ByteValue() @@ -165,7 +165,7 @@ func TestClobs(t *testing.T) { func TestBlobs(t *testing.T) { test := func(str string, eval []byte) { t.Run(str, func(t *testing.T) { - r := NewTextReaderStr(str) + r := NewReaderStr(str) _next(t, r, BlobType) val, err := r.ByteValue() @@ -188,7 +188,7 @@ func TestBlobs(t *testing.T) { func TestTimestamps(t *testing.T) { testA := func(str string, etas []string, eval time.Time) { t.Run(str, func(t *testing.T) { - r := NewTextReaderStr(str) + r := NewReaderStr(str) _nextAF(t, r, TimestampType, "", etas) val, err := r.TimeValue() @@ -228,7 +228,7 @@ func TestDecimals(t *testing.T) { t.Run(str, func(t *testing.T) { ee := MustParseDecimal(eval) - r := NewTextReaderStr(str) + r := NewReaderStr(str) _nextAF(t, r, DecimalType, "", etas) val, err := r.DecimalValue() @@ -260,7 +260,7 @@ func TestDecimals(t *testing.T) { func TestFloats(t *testing.T) { testA := func(str string, etas []string, eval float64) { t.Run(str, func(t *testing.T) { - r := NewTextReaderStr(str) + r := NewReaderStr(str) _floatAF(t, r, "", etas, eval) _eof(t, r) }) @@ -282,7 +282,7 @@ func TestFloats(t *testing.T) { func TestInts(t *testing.T) { test := func(str string, f func(*testing.T, Reader)) { t.Run(str, func(t *testing.T) { - r := NewTextReaderStr(str) + r := NewReaderStr(str) _next(t, r, IntType) f(t, r) @@ -352,7 +352,7 @@ func TestInts(t *testing.T) { } func TestStrings(t *testing.T) { - r := NewTextReaderStr(`foo::"bar" "baz" 'a'::'b'::'''beep''' '''boop''' null.string`) + r := NewReaderStr(`foo::"bar" "baz" 'a'::'b'::'''beep''' '''boop''' null.string`) _stringAF(t, r, "", []string{"foo"}, "bar") _string(t, r, "baz") @@ -363,7 +363,7 @@ func TestStrings(t *testing.T) { } func TestSymbols(t *testing.T) { - r := NewTextReaderStr("'null'::foo bar a::b::'baz' null.symbol") + r := NewReaderStr("'null'::foo bar a::b::'baz' null.symbol") _symbolAF(t, r, "", []string{"null"}, "foo") _symbol(t, r, "bar") @@ -374,7 +374,7 @@ func TestSymbols(t *testing.T) { } func TestSpecialSymbols(t *testing.T) { - r := NewTextReaderStr("null\nnull.struct\ntrue\nfalse\nnan") + r := NewReaderStr("null\nnull.struct\ntrue\nfalse\nnan") _null(t, r, NullType) _null(t, r, StructType) @@ -386,7 +386,7 @@ func TestSpecialSymbols(t *testing.T) { } func TestOperators(t *testing.T) { - r := NewTextReaderStr("(a*(b+c))") + r := NewReaderStr("(a*(b+c))") _sexp(t, r, func(t *testing.T, r Reader) { _symbol(t, r, "a") @@ -402,7 +402,7 @@ func TestOperators(t *testing.T) { } func TestTopLevelOperators(t *testing.T) { - r := NewTextReaderStr("a + b") + r := NewReaderStr("a + b") _symbol(t, r, "a") diff --git a/textutils.go b/textutils.go index e57d4ad2..4b2a0e83 100644 --- a/textutils.go +++ b/textutils.go @@ -5,12 +5,14 @@ import ( "io" "math/big" "strconv" + "strings" "time" ) // Does this symbol need to be quoted in text form? func symbolNeedsQuoting(sym string) bool { - if sym == "" || sym == "null" || sym == "true" || sym == "false" || sym == "nan" { + switch sym { + case "", "null", "true", "false", "nan": return true } @@ -91,9 +93,8 @@ func isDigit(c int) bool { // Is this a valid part of an operator symbol? func isOperatorChar(c int) bool { switch c { - case '!', '#', '%', '&', '*', '+', '-', '.', '/', ';', '<', '=': - return true - case '>', '?', '@', '^', '`', '|', '~': + case '!', '#', '%', '&', '*', '+', '-', '.', '/', ';', '<', '=', + '>', '?', '@', '^', '`', '|', '~': return true default: return false @@ -105,9 +106,8 @@ func isOperatorChar(c int) bool { // characters. Use tokenizer.isStopChar(c) or check for it yourself. func isStopChar(c int) bool { switch c { - case -1, '{', '}', '[', ']', '(', ')', ',', '"', '\'': - return true - case ' ', '\t', '\n', '\r': + case -1, '{', '}', '[', ']', '(', ')', ',', '"', '\'', + ' ', '\t', '\n', '\r': return true default: return false @@ -119,9 +119,34 @@ func isWhitespace(c int) bool { switch c { case ' ', '\t', '\n', '\r': return true - default: - return false } + return false +} + +// Formats a float64 in Ion text style. +func formatFloat(val float64) string { + str := strconv.FormatFloat(val, 'e', -1, 64) + + // Ion uses lower case for special values. + switch str { + case "NaN": + return "nan" + case "+Inf": + return "+inf" + case "-Inf": + return "-inf" + } + + idx := strings.Index(str, "e") + if idx < 0 { + // We need to add an 'e' or it will get interpreted as an Ion decimal. + str += "e0" + } else if idx+2 < len(str) && str[idx+2] == '0' { + // FormatFloat returns exponents with a leading ±0 in some cases; strip it. + str = str[:idx+2] + str[idx+3:] + } + + return str } // Write the given symbol out, quoting and encoding if necessary. @@ -172,19 +197,6 @@ func writeEscapedString(str string, out io.Writer) error { return nil } -func fromHex(c int) (int, error) { - if c >= '0' && c <= '9' { - return c - '0', nil - } - if c >= 'a' && c <= 'f' { - return 10 + (c - 'a'), nil - } - if c >= 'A' && c <= 'F' { - return 10 + (c - 'A'), nil - } - return 0, invalidChar(c) -} - // Write out the given character in escaped form. func writeEscapedChar(c byte, out io.Writer) error { switch c { diff --git a/textwriter.go b/textwriter.go index 0b1aa47a..add2cdce 100644 --- a/textwriter.go +++ b/textwriter.go @@ -6,8 +6,6 @@ import ( "fmt" "io" "math/big" - "strconv" - "strings" "time" ) @@ -43,107 +41,177 @@ func NewTextWriterOpts(out io.Writer, opts TextWriterOpts) Writer { } } -// beginValue begins the process of writing a value, by writing out -// a separator (if needed), field name (if in a struct), and type -// annotations (if any). -func (w *textWriter) beginValue() error { - if w.needsSeparator { - var sep byte - switch w.ctx.peek() { - case ctxInStruct, ctxInList: - sep = ',' - case ctxInSexp: - sep = ' ' - default: - sep = '\n' - } +// // writeValue writes a value whose raw encoding is produced by the +// // given function. +// func (w *textWriter) writeValue(f func() string) error { +// if err := w.beginValue(); err != nil { +// return err +// } + +// sym := f() +// if err := writeRawString(sym, w.out); err != nil { +// return err +// } + +// w.endValue() +// return nil +// } + +// // writeValue writes a value by calling the given function, which is +// // expected to write the raw value to w.out. +// func (w *textWriter) writeValueStreaming(f func() error) error { +// if err := w.beginValue(); err != nil { +// return err +// } + +// if err := f(); err != nil { +// return err +// } + +// w.endValue() +// return nil +// } - if err := writeRawChar(sep, w.out); err != nil { - return err - } - } +// WriteNull writes an untyped null. +func (w *textWriter) WriteNull() error { + return w.WriteNullType(NoType) +} - if w.InStruct() { - if w.fieldName == "" { - return errors.New("field name not set") - } - name := w.fieldName - w.fieldName = "" +// WriteNullType writes a typed null. +func (w *textWriter) WriteNullType(t Type) error { + return w.writeValue(textNulls[t]) +} - if err := writeSymbol(name, w.out); err != nil { - return err - } - if err := writeRawChar(':', w.out); err != nil { - return err - } +// WriteBool writes a boolean value. +func (w *textWriter) WriteBool(val bool) error { + str := "false" + if val { + str = "true" } + return w.writeValue(str) +} - if len(w.annotations) > 0 { - as := w.annotations - w.annotations = nil +// WriteInt writes an integer value. +func (w *textWriter) WriteInt(val int64) error { + return w.writeValue(fmt.Sprintf("%d", val)) +} - for _, a := range as { - if err := writeSymbol(a, w.out); err != nil { - return err - } - if err := writeRawString("::", w.out); err != nil { - return err - } - } - } +// WriteBigInt writes a (big) integer value. +func (w *textWriter) WriteBigInt(val *big.Int) error { + return w.writeValue(val.String()) +} - return nil +// WriteFloat writes a floating-point value. +func (w *textWriter) WriteFloat(val float64) error { + return w.writeValue(formatFloat(val)) } -// endValue finishes the process of writing a value. -func (w *textWriter) endValue() { - w.needsSeparator = true +// WriteDecimal writes an arbitrary-precision decimal value. +func (w *textWriter) WriteDecimal(val *Decimal) error { + return w.writeValue(val.String()) } -// begin starts writing a container of the given type. -func (w *textWriter) begin(t ctx, c byte) error { - if err := w.beginValue(); err != nil { - return err +// WriteTimestamp writes a timestamp. +func (w *textWriter) WriteTimestamp(val time.Time) error { + return w.writeValue(val.Format(time.RFC3339Nano)) +} + +// WriteSymbol writes a symbol. +func (w *textWriter) WriteSymbol(val string) error { + if w.err != nil { + return w.err + } + if w.err = w.beginValue(); w.err != nil { + return w.err } - w.ctx.push(t) - w.needsSeparator = false + if w.err = writeSymbol(val, w.out); w.err != nil { + return w.err + } - return writeRawChar(c, w.out) + w.endValue() + return nil } -// end finishes writing a container of the given type -func (w *textWriter) end(t ctx, c byte) error { - if w.ctx.peek() != t { - return errors.New("not in that kind of container") +// WriteString writes a string. +func (w *textWriter) WriteString(val string) error { + if w.err != nil { + return w.err + } + if w.err = w.beginValue(); w.err != nil { + return w.err } - if err := writeRawChar(c, w.out); err != nil { - return err + if w.err = writeRawChar('"', w.out); w.err != nil { + return w.err + } + if w.err = writeEscapedString(val, w.out); w.err != nil { + return w.err + } + if w.err = writeRawChar('"', w.out); w.err != nil { + return w.err } - w.fieldName = "" - w.annotations = nil - w.ctx.pop() w.endValue() - return nil } -// BeginStruct begins writing a struct. -func (w *textWriter) BeginStruct() error { - if w.err == nil { - w.err = w.begin(ctxInStruct, '{') +// WriteClob writes a clob. +func (w *textWriter) WriteClob(val []byte) error { + if w.err != nil { + return w.err } - return w.err + if w.err = w.beginValue(); w.err != nil { + return w.err + } + + if w.err = writeRawString("{{\"", w.out); w.err != nil { + return w.err + } + for _, c := range val { + if c < 32 || c == '\\' || c == '"' || c > 0x7F { + if err := writeEscapedChar(c, w.out); err != nil { + return err + } + } else { + if err := writeRawChar(c, w.out); err != nil { + return err + } + } + } + if w.err = writeRawString("\"}}", w.out); w.err != nil { + return w.err + } + + w.endValue() + return nil } -// EndStruct finishes writing a struct. -func (w *textWriter) EndStruct() error { - if w.err == nil { - w.err = w.end(ctxInStruct, '}') +// WriteBlob writes a blob. +func (w *textWriter) WriteBlob(val []byte) error { + if w.err != nil { + return w.err } - return w.err + if w.err = w.beginValue(); w.err != nil { + return w.err + } + + if w.err = writeRawString("{{", w.out); w.err != nil { + return w.err + } + + enc := base64.NewEncoder(base64.StdEncoding, w.out) + enc.Write(val) + if w.err = enc.Close(); w.err != nil { + return w.err + } + + if w.err = writeRawString("}}", w.out); w.err != nil { + return w.err + } + + w.endValue() + return nil } // BeginList begins writing a list. @@ -178,256 +246,141 @@ func (w *textWriter) EndSexp() error { return w.err } -// writeValue writes a value whose raw encoding is produced by the -// given function. -func (w *textWriter) writeValue(f func() string) error { - if err := w.beginValue(); err != nil { - return err +// BeginStruct begins writing a struct. +func (w *textWriter) BeginStruct() error { + if w.err == nil { + w.err = w.begin(ctxInStruct, '{') } + return w.err +} - sym := f() - if err := writeRawString(sym, w.out); err != nil { - return err +// EndStruct finishes writing a struct. +func (w *textWriter) EndStruct() error { + if w.err == nil { + w.err = w.end(ctxInStruct, '}') } - - w.endValue() - return nil + return w.err } -// writeValue writes a value by calling the given function, which is -// expected to write the raw value to w.out. -func (w *textWriter) writeValueStreaming(f func() error) error { - if err := w.beginValue(); err != nil { - return err +// Finish finishes writing the current datagram. +func (w *textWriter) Finish() error { + if w.err != nil { + return w.err + } + if w.ctx.peek() != ctxAtTopLevel { + return errors.New("ion: Finish not at top level") } - if err := f(); err != nil { - return err + if w.opts&TextWriterQuietFinish == 0 { + if w.err = writeRawChar('\n', w.out); w.err != nil { + return w.err + } + w.needsSeparator = false } - w.endValue() + w.clear() return nil } -// WriteNull writes an untyped null. -func (w *textWriter) WriteNull() error { - return w.WriteNullType(NoType) -} - -// WriteNullType writes a typed null. -func (w *textWriter) WriteNullType(t Type) error { - if w.err == nil { - w.err = w.writeValue(func() string { - switch t { - case NoType: - return "null" - case NullType: - return "null.null" - case BoolType: - return "null.bool" - case IntType: - return "null.int" - case FloatType: - return "null.float" - case DecimalType: - return "null.decimal" - case TimestampType: - return "null.timestamp" - case StringType: - return "null.string" - case SymbolType: - return "null.symbol" - case BlobType: - return "null.blob" - case ClobType: - return "null.clob" - case StructType: - return "null.struct" - case ListType: - return "null.list" - case SexpType: - return "null.sexp" - default: - panic(fmt.Sprintf("invalid type: %v", t)) - } - }) +// writeValue writes a stringified value to the output stream. +func (w *textWriter) writeValue(val string) error { + if w.err != nil { + return w.err } - return w.err -} - -// WriteBool writes a boolean value. -func (w *textWriter) WriteBool(val bool) error { - if w.err == nil { - w.err = w.writeValue(func() string { - if val { - return "true" - } - return "false" - }) + if w.err = w.beginValue(); w.err != nil { + return w.err } - return w.err -} -// WriteInt writes an integer value. -func (w *textWriter) WriteInt(val int64) error { - if w.err == nil { - w.err = w.writeValue(func() string { - return fmt.Sprintf("%d", val) - }) + if w.err = writeRawString(val, w.out); w.err != nil { + return w.err } - return w.err -} -// WriteBigInt writes a (big) integer value. -func (w *textWriter) WriteBigInt(val *big.Int) error { - if w.err == nil { - w.err = w.writeValue(func() string { - return val.String() - }) - } - return w.err + w.endValue() + return nil } -// WriteFloat writes a floating-point value. -func (w *textWriter) WriteFloat(val float64) error { - if w.err == nil { - w.err = w.writeValue(func() string { - // Built-in go formatting isn't quite up to the task. :( - str := strconv.FormatFloat(val, 'e', -1, 64) - - switch str { - case "NaN": - return "nan" - case "+Inf": - return "+inf" - case "-Inf": - return "-inf" - default: - break - } - - idx := strings.Index(str, "e") - if idx < 0 { - str += "e0" - } else if idx+2 < len(str) && str[idx+2] == '0' { - str = str[:idx+2] + str[idx+3:] - } +// beginValue begins the process of writing a value, by writing out +// a separator (if needed), field name (if in a struct), and type +// annotations (if any). +func (w *textWriter) beginValue() error { + if w.needsSeparator { + var sep byte + switch w.ctx.peek() { + case ctxInStruct, ctxInList: + sep = ',' + case ctxInSexp: + sep = ' ' + default: + sep = '\n' + } - return str - }) + if err := writeRawChar(sep, w.out); err != nil { + return err + } } - return w.err -} -// WriteDecimal writes an arbitrary-precision decimal value. -func (w *textWriter) WriteDecimal(val *Decimal) error { - if w.err == nil { - w.err = w.writeValue(func() string { - return val.String() - }) - } - return w.err -} + if w.inStruct() { + if w.fieldName == "" { + return errors.New("ion: field name not set") + } + name := w.fieldName + w.fieldName = "" -// WriteTimestamp writes a timestamp. -func (w *textWriter) WriteTimestamp(val time.Time) error { - if w.err == nil { - w.err = w.writeValue(func() string { - return val.Format(time.RFC3339Nano) - }) + if err := writeSymbol(name, w.out); err != nil { + return err + } + if err := writeRawChar(':', w.out); err != nil { + return err + } } - return w.err -} -// WriteSymbol writes a symbol. -func (w *textWriter) WriteSymbol(val string) error { - if w.err == nil { - w.err = w.writeValueStreaming(func() error { - return writeSymbol(val, w.out) - }) - } - return w.err -} + if len(w.annotations) > 0 { + as := w.annotations + w.annotations = nil -// WriteString writes a string. -func (w *textWriter) WriteString(val string) error { - if w.err == nil { - w.err = w.writeValueStreaming(func() error { - if err := writeRawChar('"', w.out); err != nil { + for _, a := range as { + if err := writeSymbol(a, w.out); err != nil { return err } - if err := writeEscapedString(val, w.out); err != nil { + if err := writeRawString("::", w.out); err != nil { return err } - return writeRawChar('"', w.out) - }) + } } - return w.err -} - -// WriteBlob writes a blob. -func (w *textWriter) WriteBlob(val []byte) error { - if w.err == nil { - w.err = w.writeValueStreaming(func() error { - if err := writeRawString("{{", w.out); err != nil { - return err - } - enc := base64.NewEncoder(base64.StdEncoding, w.out) - enc.Write(val) - if err := enc.Close(); err != nil { - return err - } + return nil +} - return writeRawString("}}", w.out) - }) - } - return w.err +// endValue finishes the process of writing a value. +func (w *textWriter) endValue() { + w.needsSeparator = true } -// WriteClob writes a clob. -func (w *textWriter) WriteClob(val []byte) error { - if w.err == nil { - w.err = w.writeValueStreaming(func() error { - if err := writeRawString("{{\"", w.out); err != nil { - return err - } +// begin starts writing a container of the given type. +func (w *textWriter) begin(t ctx, c byte) error { + if err := w.beginValue(); err != nil { + return err + } - for _, c := range val { - if c < 32 || c == '\\' || c == '"' || c > 0x7F { - if err := writeEscapedChar(c, w.out); err != nil { - return err - } - } else { - if err := writeRawChar(c, w.out); err != nil { - return err - } - } - } + w.ctx.push(t) + w.needsSeparator = false - return writeRawString("\"}}", w.out) - }) - } - return w.err + return writeRawChar(c, w.out) } -// Finish finishes the current datagram. -func (w *textWriter) Finish() error { - if w.err != nil { - return w.err - } - if w.ctx.peek() != ctxAtTopLevel { - return errors.New("not at top level") +// end finishes writing a container of the given type +func (w *textWriter) end(t ctx, c byte) error { + if w.ctx.peek() != t { + return errors.New("ion: End called with wrong container type") } - if w.opts&TextWriterQuietFinish == 0 { - if w.err = writeRawChar('\n', w.out); w.err != nil { - return w.err - } + if err := writeRawChar(c, w.out); err != nil { + return err } - w.fieldName = "" - w.annotations = nil - w.needsSeparator = false + w.clear() + w.ctx.pop() + w.endValue() + return nil } diff --git a/tokenizer.go b/tokenizer.go index 1831478a..f545328f 100644 --- a/tokenizer.go +++ b/tokenizer.go @@ -685,6 +685,19 @@ func (t *tokenizer) readHexEscapeSeq(len int) (rune, error) { return val, nil } +func fromHex(c int) (int, error) { + if c >= '0' && c <= '9' { + return c - '0', nil + } + if c >= 'a' && c <= 'f' { + return 10 + (c - 'a'), nil + } + if c >= 'A' && c <= 'F' { + return 10 + (c - 'A'), nil + } + return 0, invalidChar(c) +} + func (t *tokenizer) readBinary() (string, error) { isB := func(c int) bool { return c == 'b' || c == 'B' diff --git a/type.go b/type.go new file mode 100644 index 00000000..0371a818 --- /dev/null +++ b/type.go @@ -0,0 +1,121 @@ +package ion + +import "fmt" + +// A Type represents the type of an Ion Value. +type Type uint8 + +const ( + // NoType is returned by a Reader that is not currently pointing at a value. + NoType Type = iota + + // NullType is the type of the (unqualified) Ion null value. + NullType + + // BoolType is the type of an Ion boolean, true or false. + BoolType + + // IntType is the type of a signed Ion integer of arbitrary size. + IntType + + // FloatType is the type of a fixed-precision Ion floating-point value. + FloatType + + // DecimalType is the type of an arbitrary-precision Ion decimal value. + DecimalType + + // TimestampType is the type of an arbitrary-precision Ion timestamp. + TimestampType + + // SymbolType is the type of an Ion symbol, mapped to an integer ID by a SymbolTable + // to (potentially) save space. + SymbolType + + // StringType is the type of a non-symbol Unicode string, represented directly. + StringType + + // ClobType is the type of a character large object. Like a BlobType, it stores an + // arbitrary sequence of bytes, but it represents them in text form as an escaped-ASCII + // string rather than a base64-encoded string. + ClobType + + // BlobType is the type of a binary large object; a sequence of arbitrary bytes. + BlobType + + // ListType is the type of a list, recursively containing zero or more Ion values. + ListType + + // SexpType is the type of an s-expression. Like a ListType, it contains a sequence + // of zero or more Ion values, but with a lisp-like syntax when encoded as text. + SexpType + + // StructType is the type of a structure, recursively containing a sequence of named + // (by an Ion symbol) Ion values. + StructType +) + +// String implements fmt.Stringer for Type. +func (t Type) String() string { + switch t { + case NoType: + return "" + case NullType: + return "null" + case BoolType: + return "bool" + case IntType: + return "int" + case FloatType: + return "float" + case DecimalType: + return "decimal" + case TimestampType: + return "timestamp" + case StringType: + return "string" + case SymbolType: + return "symbol" + case BlobType: + return "blob" + case ClobType: + return "clob" + case StructType: + return "struct" + case ListType: + return "list" + case SexpType: + return "sexp" + default: + return fmt.Sprintf("", uint8(t)) + } +} + +// IntSize represents the size of an integer. +type IntSize uint8 + +const ( + // NullInt is the size of null.int and other things that aren't actually ints. + NullInt IntSize = iota + // Int32 is the size of an Ion integer that can be losslessly stored in an int32. + Int32 + // Int64 is the size of an Ion integer that can be losslessly stored in an int64. + Int64 + // BigInt is the size of an Ion integer that can only be losslessly stored in a big.Int. + BigInt +) + +// String implements fmt.Stringer for IntSize. +func (i IntSize) String() string { + switch i { + case NullInt: + return "null.int" + case Int32: + return "int32" + case Int64: + return "int64" + case BigInt: + return "big.Int" + default: + return fmt.Sprintf("", uint8(i)) + } +} diff --git a/api_test.go b/type_test.go similarity index 94% rename from api_test.go rename to type_test.go index bbc50a82..e1702baa 100644 --- a/api_test.go +++ b/type_test.go @@ -1,8 +1,6 @@ package ion -import ( - "testing" -) +import "testing" func TestTypeToString(t *testing.T) { for i := NoType; i <= StructType+1; i++ { diff --git a/unmarshal.go b/unmarshal.go index 6b018fa1..825238cb 100644 --- a/unmarshal.go +++ b/unmarshal.go @@ -17,8 +17,7 @@ var ( // Unmarshal unmarshals Ion data to the given object. func Unmarshal(data []byte, v interface{}) error { - // TODO: Figure out if it's text or binary instead of hardcoding text. - return NewDecoder(NewTextReader(bytes.NewReader(data))).DecodeTo(v) + return NewDecoder(NewReader(bytes.NewReader(data))).DecodeTo(v) } // UnmarshalStr unmarshals Ion data from a string to the given object. diff --git a/unmarshal_test.go b/unmarshal_test.go index 763d8256..e3e2d2d6 100644 --- a/unmarshal_test.go +++ b/unmarshal_test.go @@ -272,7 +272,7 @@ func TestUnmarshalBigInt(t *testing.T) { func TestDecodeFloat(t *testing.T) { test32 := func(str string, eval float32) { t.Run(str, func(t *testing.T) { - d := NewDecoder(NewTextReaderStr(str)) + d := NewDecoder(NewReaderStr(str)) var val float32 err := d.DecodeTo(&val) @@ -292,7 +292,7 @@ func TestDecodeFloat(t *testing.T) { test64 := func(str string, eval float64) { t.Run(str, func(t *testing.T) { - d := NewDecoder(NewTextReaderStr(str)) + d := NewDecoder(NewReaderStr(str)) var val float64 err := d.DecodeTo(&val) @@ -313,7 +313,7 @@ func TestDecodeFloat(t *testing.T) { func TestDecodeDecimal(t *testing.T) { test := func(str string, eval *Decimal) { t.Run(str, func(t *testing.T) { - d := NewDecoder(NewTextReaderStr(str)) + d := NewDecoder(NewReaderStr(str)) var val *Decimal err := d.DecodeTo(&val) @@ -334,7 +334,7 @@ func TestDecodeDecimal(t *testing.T) { func TestDecodeTimeTo(t *testing.T) { test := func(str string, eval time.Time) { t.Run(str, func(t *testing.T) { - d := NewDecoder(NewTextReaderStr(str)) + d := NewDecoder(NewReaderStr(str)) var val time.Time err := d.DecodeTo(&val) @@ -354,7 +354,7 @@ func TestDecodeTimeTo(t *testing.T) { func TestDecodeStringTo(t *testing.T) { test := func(str string, eval string) { t.Run(str, func(t *testing.T) { - d := NewDecoder(NewTextReaderStr(str)) + d := NewDecoder(NewReaderStr(str)) var val string err := d.DecodeTo(&val) @@ -376,7 +376,7 @@ func TestDecodeStringTo(t *testing.T) { func TestDecodeLobTo(t *testing.T) { testSlice := func(str string, eval []byte) { t.Run(str, func(t *testing.T) { - d := NewDecoder(NewTextReaderStr(str)) + d := NewDecoder(NewReaderStr(str)) var val []byte err := d.DecodeTo(&val) @@ -396,7 +396,7 @@ func TestDecodeLobTo(t *testing.T) { testArray := func(str string, eval []byte) { t.Run(str, func(t *testing.T) { - d := NewDecoder(NewTextReaderStr(str)) + d := NewDecoder(NewReaderStr(str)) var val [8]byte err := d.DecodeTo(&val) @@ -416,7 +416,7 @@ func TestDecodeLobTo(t *testing.T) { func TestDecodeStructTo(t *testing.T) { test := func(str string, val, eval interface{}) { t.Run(str, func(t *testing.T) { - d := NewDecoder(NewTextReaderStr(str)) + d := NewDecoder(NewReaderStr(str)) err := d.DecodeTo(val) if err != nil { t.Fatal(err) @@ -446,7 +446,7 @@ func TestDecodeStructTo(t *testing.T) { func TestDecodeListTo(t *testing.T) { test := func(str string, val, eval interface{}) { t.Run(str, func(t *testing.T) { - d := NewDecoder(NewTextReaderStr(str)) + d := NewDecoder(NewReaderStr(str)) err := d.DecodeTo(val) if err != nil { t.Fatal(err) @@ -478,7 +478,7 @@ func TestDecodeListTo(t *testing.T) { func TestDecode(t *testing.T) { test := func(data string, eval interface{}) { t.Run(data, func(t *testing.T) { - d := NewDecoder(NewTextReaderStr(data)) + d := NewDecoder(NewReaderStr(data)) val, err := d.Decode() if err != nil { t.Fatal(err) diff --git a/writer.go b/writer.go index 19d20cfb..232db8d1 100644 --- a/writer.go +++ b/writer.go @@ -3,9 +3,114 @@ package ion import ( "errors" "io" + "math/big" + "time" ) -// writer holds shared stuff for all writers. +// A Writer writes a stream of Ion values. +// +// The various Write methods write atomic values to the current output stream. The +// Begin methods begin writing a list, sexp, or struct respectively. Subsequent +// calls to Write will write values inside of the container until a matching +// End method is called. +// +// var w Writer +// w.BeginSexp() +// { +// w.WriteInt(1) +// w.WriteSymbol("+") +// w.WriteInt(1) +// } +// w.EndSexp() +// +// When writing values inside a struct, the FieldName method must be called before +// each value to set the value's field name. The Annotation method may likewise +// be called before writing any value to add an annotation to the value. +// +// var w Writer +// w.Annotation("user") +// w.BeginStruct() +// { +// w.FieldName("id") +// w.WriteString("qu33nb33") +// w.FieldName("name") +// w.WriteString("Beyoncé") +// } +// w.EndStruct() +// +// When you're done writing values, you should call Finish to ensure everything has +// been flushed from in-memory buffers. While individual methods all return an error +// on failure, implementations will remember any errors, no-op subsequent calls, and +// return the previous error. This lets you keep code a bit cleaner by only checking +// the return value of the final method call (generally Finish). +// +// var w Writer +// writeSomeStuff(w) +// if err := w.Finish(); err != nil { +// return err +// } +// +type Writer interface { + + // FieldName sets the field name for the next value written. + FieldName(val string) error + + // Annotation adds a single annotation to the next value written. + Annotation(val string) error + + // Annotations adds multiple annotations to the next value written. + Annotations(vals ...string) error + + // WriteNull writes an untyped null value. + WriteNull() error + // WriteNullType writes a null value with a type qualifier, e.g. null.bool. + WriteNullType(t Type) error + + // WriteBool writes a boolean value. + WriteBool(val bool) error + + // WriteInt writes an integer value. + WriteInt(val int64) error + // WriteBigInt writes a big integer value. + WriteBigInt(val *big.Int) error + // WriteFloat writes a floating-point value. + WriteFloat(val float64) error + // WriteDecimal writes an arbitrary-precision decimal value. + WriteDecimal(val *Decimal) error + + // WriteTimestamp writes a timestamp value. + WriteTimestamp(val time.Time) error + + // WriteSymbol writes a symbol value. + WriteSymbol(val string) error + // WriteString writes a string value. + WriteString(val string) error + + // WriteClob writes a clob value. + WriteClob(val []byte) error + // WriteBlob writes a blob value. + WriteBlob(val []byte) error + + // BeginList begins writing a list value. + BeginList() error + // EndList finishes writing a list value. + EndList() error + + // BeginSexp begins writing an s-expression value. + BeginSexp() error + // EndSexp finishes writing an s-expression value. + EndSexp() error + + // BeginStruct begins writing a struct value. + BeginStruct() error + // EndStruct finishes writing a struct value. + EndStruct() error + + // Finish finishes writing values and flushes any buffered data. + Finish() error +} + +// A writer holds shared stuff for all writers. type writer struct { out io.Writer ctx ctxstack @@ -15,34 +120,14 @@ type writer struct { annotations []string } -// InStruct returns true if we're currently writing a struct. -func (w *writer) InStruct() bool { - return w.ctx.peek() == ctxInStruct -} - -// InList returns true if we're currently writing a list. -func (w *writer) InList() bool { - return w.ctx.peek() == ctxInList -} - -// InSexp returns true if we're currently writing an s-expression. -func (w *writer) InSexp() bool { - return w.ctx.peek() == ctxInSexp -} - -// Err returns the current error, or nil if there are none yet. -func (w *writer) Err() error { - return w.err -} - // FieldName sets the field name for the next value written. // It may only be called while writing a struct. func (w *writer) FieldName(val string) error { if w.err != nil { return w.err } - if !w.InStruct() { - w.err = errors.New("FieldName() called but not writing a struct") + if !w.inStruct() { + w.err = errors.New("ion: Writer.FieldName called when not writing a struct") return w.err } @@ -65,3 +150,14 @@ func (w *writer) Annotations(val ...string) error { } return w.err } + +// InStruct returns true if we're currently writing a struct. +func (w *writer) inStruct() bool { + return w.ctx.peek() == ctxInStruct +} + +// Clear clears field name and annotations after writing a value. +func (w *writer) clear() { + w.fieldName = "" + w.annotations = nil +} From 879367005bd01edb4dab11076d01aa47e12f9d7f Mon Sep 17 00:00:00 2001 From: David Murray Date: Sun, 1 Sep 2019 14:10:49 +1000 Subject: [PATCH 36/56] More refactoring --- binarywriter.go | 706 ++++++++++++++++------------ bufnode.go => bufstack.go | 0 bufnode_test.go => bufstack_test.go | 0 textwriter.go | 31 -- textwriter_test.go | 36 +- 5 files changed, 418 insertions(+), 355 deletions(-) rename bufnode.go => bufstack.go (100%) rename bufnode_test.go => bufstack_test.go (100%) diff --git a/binarywriter.go b/binarywriter.go index a8c6fd34..6633cbcb 100644 --- a/binarywriter.go +++ b/binarywriter.go @@ -45,331 +45,244 @@ func NewBinaryWriterLST(out io.Writer, lst SymbolTable) Writer { } } -// Emit emits the given node. If we're currently at the top level, that -// means actually emitting to the output stream. If not, we emit append -// to the current bufseq. -func (w *binaryWriter) emit(node bufnode) error { - s := w.bufs.peek() - if s == nil { - return node.EmitTo(w.out) - } - s.Append(node) - return nil -} - -// Write emits the given bytes as an atom. -func (w *binaryWriter) write(bs []byte) error { - return w.emit(atom(bs)) +// WriteNull writes an untyped null. +func (w *binaryWriter) WriteNull() error { + return w.WriteNullType(NoType) } -// WriteTag writes out a type+length tag. Use me when you've already got the value to -// be written as a []byte and don't want to copy it. -func (w *binaryWriter) writeTag(code byte, len uint64) error { - tl := tagLen(len) +// WriteNullType writes a typed null. +func (w *binaryWriter) WriteNullType(t Type) error { + if w.err != nil { + return w.err + } + if w.err = w.beginValue(); w.err != nil { + return w.err + } - tag := make([]byte, 0, tl) - tag = appendTag(tag, code, len) + if w.err = w.write([]byte{binaryNulls[t]}); w.err != nil { + return w.err + } - return w.write(tag) + w.err = w.endValue() + return w.err } -// WriteLST writes out a local symbol table. -func (w *binaryWriter) writeLST(lst SymbolTable) error { - if err := w.write([]byte{0xE0, 0x01, 0x00, 0xEA}); err != nil { - return err +// WriteBool writes a bool. +func (w *binaryWriter) WriteBool(val bool) error { + if w.err != nil { + return w.err } - return lst.WriteTo(w) -} - -// BeginValue begins the process of writing a value by writing out -// its field name and annotations. -func (w *binaryWriter) beginValue() error { - // We have to record/empty these before calling writeLST, which - // will end up using/modifying them. Ugh. - name := w.fieldName - w.fieldName = "" - as := w.annotations - w.annotations = nil - - // If we have a local symbol table and haven't written it out yet, do that now. - if w.lst != nil && !w.wroteLST { - w.wroteLST = true - if err := w.writeLST(w.lst); err != nil { - return err - } + if w.err = w.beginValue(); w.err != nil { + return w.err } - if w.inStruct() { - if name == "" { - return errors.New("ion: field name not set") - } - - id, err := w.resolve(name) - if err != nil { - return err - } - - buf := make([]byte, 0, 10) - buf = appendVarUint(buf, id) - if err := w.write(buf); err != nil { - return err - } + b := byte(0x10) + if val { + b = 0x11 } - if len(as) > 0 { - ids := make([]uint64, len(as)) - idlen := uint64(0) - - for i, a := range as { - id, err := w.resolve(a) - if err != nil { - return err - } - - ids[i] = id - idlen += varUintLen(id) - } - - buflen := idlen + varUintLen(idlen) - buf := make([]byte, 0, buflen) - - buf = appendVarUint(buf, idlen) - for _, id := range ids { - buf = appendVarUint(buf, id) - } - - w.bufs.push(&container{code: 0xE0}) - if err := w.write(buf); err != nil { - return err - } + if w.err = w.write([]byte{b}); w.err != nil { + return w.err } - return nil + w.err = w.endValue() + return w.err } -// EndValue ends the process of writing a value by flushing it and its annotations -// up a level, if needed. -func (w *binaryWriter) endValue() error { - seq := w.bufs.peek() - if seq != nil { - if c, ok := seq.(*container); ok && c.code == 0xE0 { - w.bufs.pop() - return w.emit(seq) - } +// WriteInt writes an integer. +func (w *binaryWriter) WriteInt(val int64) error { + if w.err != nil { + return w.err } - return nil -} - -// WriteValue writes an atomic value, invoking the given function to write the -// actual value contents. -func (w *binaryWriter) writeValue(f func() error) error { - if err := w.beginValue(); err != nil { - return err + if w.err = w.beginValue(); w.err != nil { + return w.err } - if err := f(); err != nil { - return err + if w.err = w.writeInt(val); w.err != nil { + return w.err } - return w.endValue() + w.err = w.endValue() + return w.err } -// BeginContainer begins writing a new container. -func (w *binaryWriter) beginContainer(t ctx, code byte) error { - if err := w.beginValue(); err != nil { - return err +// WriteInt writes the actual integer value. +func (w *binaryWriter) writeInt(val int64) error { + if val == 0 { + return w.write([]byte{0x20}) } - w.ctx.push(t) - w.bufs.push(&container{code: code}) - - return nil -} + code := byte(0x20) + mag := uint64(val) -// EndContainer ends writing a container, emitting its buffered contents up -// a level in the stack. -func (w *binaryWriter) endContainer(t ctx) error { - if w.ctx.peek() != t { - return errors.New("ion: not in that kind of container") + if val < 0 { + code = 0x30 + mag = uint64(-val) } - seq := w.bufs.peek() - if seq != nil { - w.bufs.pop() - if err := w.emit(seq); err != nil { - return err - } - } + len := uintLen(mag) + buflen := len + tagLen(len) - w.fieldName = "" - w.annotations = nil - w.ctx.pop() + buf := make([]byte, 0, buflen) + buf = appendTag(buf, code, len) + buf = appendUint(buf, mag) - return w.endValue() + return w.write(buf) } -// WriteNull writes an untyped null. -func (w *binaryWriter) WriteNull() error { - return w.WriteNullType(NoType) -} +// WriteBigInt writes a big integer. +func (w *binaryWriter) WriteBigInt(val *big.Int) error { + if w.err != nil { + return w.err + } + if w.err = w.beginValue(); w.err != nil { + return w.err + } -// WriteNullType writes a typed null. -func (w *binaryWriter) WriteNullType(t Type) error { - if w.err == nil { - w.err = w.writeValue(func() error { - return w.write([]byte{binaryNulls[t]}) - }) + if w.err = w.writeBigInt(val); w.err != nil { + return w.err } + + w.err = w.endValue() return w.err } -// WriteBool writes a bool. -func (w *binaryWriter) WriteBool(val bool) error { - if w.err == nil { - w.err = w.writeValue(func() error { - if val { - return w.write([]byte{0x11}) - } - return w.write([]byte{0x10}) - }) +// WriteBigInt writes the actual big integer value. +func (w *binaryWriter) writeBigInt(val *big.Int) error { + sign := val.Sign() + if sign == 0 { + return w.write([]byte{0x20}) } - return w.err -} -// WriteInt writes an integer. -func (w *binaryWriter) WriteInt(val int64) error { - if w.err == nil { - w.err = w.writeValue(func() error { - if val == 0 { - return w.write([]byte{0x20}) - } + code := byte(0x20) + if sign < 0 { + code = 0x30 + } - code := byte(0x20) - mag := uint64(val) + bs := val.Bytes() - if val < 0 { - code = 0x30 - mag = uint64(-val) - } + bl := uint64(len(bs)) + if bl < 64 { + buflen := bl + tagLen(bl) + buf := make([]byte, 0, buflen) - len := uintLen(mag) - buflen := len + tagLen(len) + buf = appendTag(buf, code, bl) + buf = append(buf, bs...) + return w.write(buf) + } + + // no sense in copying, emit tag separately. + if err := w.writeTag(code, bl); err != nil { + return err + } + return w.write(bs) +} - buf := make([]byte, 0, buflen) - buf = appendTag(buf, code, len) - buf = appendUint(buf, mag) +// WriteFloat writes a floating-point value. +func (w *binaryWriter) WriteFloat(val float64) error { + if w.err != nil { + return w.err + } + if w.err = w.beginValue(); w.err != nil { + return w.err + } - return w.write(buf) - }) + if w.err = w.writeFloat(val); w.err != nil { + return w.err } + + w.err = w.endValue() return w.err + } -// WriteBigInt writes a big integer. -func (w *binaryWriter) WriteBigInt(val *big.Int) error { - if w.err == nil { - w.err = w.writeValue(func() error { - sign := val.Sign() - if sign == 0 { - return w.write([]byte{0x20}) - } +// WriteFloat writes the actual float value. +func (w *binaryWriter) writeFloat(val float64) error { + if val == 0 { + return w.write([]byte{0x40}) + } - code := byte(0x20) - if sign < 0 { - code = 0x30 - } + bs := make([]byte, 9) + bs[0] = 0x48 - bs := val.Bytes() + bits := math.Float64bits(val) + binary.BigEndian.PutUint64(bs[1:], bits) - bl := uint64(len(bs)) - if bl < 64 { - buflen := bl + tagLen(bl) - buf := make([]byte, 0, buflen) + return w.write(bs) +} - buf = appendTag(buf, code, bl) - buf = append(buf, bs...) - return w.write(buf) - } +// WriteDecimal writes a decimal value. +func (w *binaryWriter) WriteDecimal(val *Decimal) error { + if w.err != nil { + return w.err + } + if w.err = w.beginValue(); w.err != nil { + return w.err + } - // no sense in copying, emit tag separately. - if err := w.writeTag(code, bl); err != nil { - return err - } - return w.write(bs) - }) + if w.err = w.writeDecimal(val); w.err != nil { + return w.err } + + w.err = w.endValue() return w.err } -// WriteFloat writes a floating-point value. -func (w *binaryWriter) WriteFloat(val float64) error { - if w.err == nil { - w.err = w.writeValue(func() error { - if val == 0 { - return w.write([]byte{0x40}) - } - - bs := make([]byte, 9) - bs[0] = 0x48 +// WriteDecimal writes the actual decimal value. +func (w *binaryWriter) writeDecimal(val *Decimal) error { + coef, exp := val.CoEx() - bits := math.Float64bits(val) - binary.BigEndian.PutUint64(bs[1:], bits) - - return w.write(bs) - }) + vlen := uint64(0) + if exp != 0 { + vlen += varIntLen(int64(exp)) + } + if coef.Sign() != 0 { + vlen += bigIntLen(coef) } - return w.err -} -// WriteDecimal writes a decimal value. -func (w *binaryWriter) WriteDecimal(val *Decimal) error { - if w.err == nil { - w.err = w.writeValue(func() error { - coef, exp := val.CoEx() + buflen := vlen + tagLen(vlen) + buf := make([]byte, 0, buflen) - vlen := uint64(0) - if exp != 0 { - vlen += varIntLen(int64(exp)) - } - if coef.Sign() != 0 { - vlen += bigIntLen(coef) - } + buf = appendTag(buf, 0x50, vlen) + if exp != 0 { + buf = appendVarInt(buf, int64(exp)) + } + buf = appendBigInt(buf, coef) - buflen := vlen + tagLen(vlen) - buf := make([]byte, 0, buflen) + return w.write(buf) +} - buf = appendTag(buf, 0x50, vlen) - if exp != 0 { - buf = appendVarInt(buf, int64(exp)) - } - buf = appendBigInt(buf, coef) +// WriteTimestamp writes a timestamp value. +func (w *binaryWriter) WriteTimestamp(val time.Time) error { + if w.err != nil { + return w.err + } + if w.err = w.beginValue(); w.err != nil { + return w.err + } - return w.write(buf) - }) + if w.err = w.writeTimestamp(val); w.err != nil { + return w.err } + + w.err = w.endValue() return w.err } -// WriteTimestamp writes a timestamp value. -func (w *binaryWriter) WriteTimestamp(val time.Time) error { - if w.err == nil { - w.err = w.writeValue(func() error { - _, offset := val.Zone() - offset /= 60 - utc := val.In(time.UTC) +func (w *binaryWriter) writeTimestamp(val time.Time) error { + _, offset := val.Zone() + offset /= 60 + utc := val.In(time.UTC) - vlen := timeLen(offset, utc) - buflen := vlen + tagLen(vlen) + vlen := timeLen(offset, utc) + buflen := vlen + tagLen(vlen) - buf := make([]byte, 0, buflen) + buf := make([]byte, 0, buflen) - buf = appendTag(buf, 0x60, vlen) - buf = appendTime(buf, offset, utc) + buf = appendTag(buf, 0x60, vlen) + buf = appendTime(buf, offset, utc) - return w.write(buf) - }) - } - return w.err + return w.write(buf) } // WriteSymbol writes a symbol value. @@ -377,25 +290,32 @@ func (w *binaryWriter) WriteSymbol(val string) error { if w.err != nil { return w.err } + if w.err = w.beginValue(); w.err != nil { + return w.err + } - id, err := w.resolve(val) - if err != nil { - w.err = err + if w.err = w.writeSymbol(val); w.err != nil { return w.err } - w.err = w.writeValue(func() error { - vlen := uintLen(uint64(id)) - buflen := vlen + tagLen(vlen) - buf := make([]byte, 0, buflen) + w.err = w.endValue() + return w.err +} - buf = appendTag(buf, 0x70, vlen) - buf = appendUint(buf, uint64(id)) +func (w *binaryWriter) writeSymbol(val string) error { + id, err := w.resolve(val) + if err != nil { + return err + } - return w.write(buf) - }) + vlen := uintLen(uint64(id)) + buflen := vlen + tagLen(vlen) + buf := make([]byte, 0, buflen) - return w.err + buf = appendTag(buf, 0x70, vlen) + buf = appendUint(buf, uint64(id)) + + return w.write(buf) } // Resolve resolves a symbol to its ID. @@ -420,64 +340,93 @@ func (w *binaryWriter) resolve(sym string) (uint64, error) { // WriteString writes a string. func (w *binaryWriter) WriteString(val string) error { - if w.err == nil { - w.err = w.writeValue(func() error { - if len(val) == 0 { - return w.write([]byte{0x80}) - } + if w.err != nil { + return w.err + } + if w.err = w.beginValue(); w.err != nil { + return w.err + } - vlen := uint64(len(val)) - buflen := vlen + tagLen(vlen) - buf := make([]byte, 0, buflen) + if w.err = w.writeString(val); w.err != nil { + return w.err + } - buf = appendTag(buf, 0x80, vlen) - buf = append(buf, val...) + w.err = w.endValue() + return w.err +} - return w.write(buf) - }) +func (w *binaryWriter) writeString(val string) error { + if len(val) == 0 { + return w.write([]byte{0x80}) } - return w.err + + vlen := uint64(len(val)) + buflen := vlen + tagLen(vlen) + buf := make([]byte, 0, buflen) + + buf = appendTag(buf, 0x80, vlen) + buf = append(buf, val...) + + return w.write(buf) } // WriteClob writes a clob. func (w *binaryWriter) WriteClob(val []byte) error { - return w.writeLob(0x90, val) + if w.err != nil { + return w.err + } + if w.err = w.beginValue(); w.err != nil { + return w.err + } + + if w.err = w.writeLob(0x90, val); w.err != nil { + return w.err + } + + w.err = w.endValue() + return w.err } // WriteBlob writes a blob. func (w *binaryWriter) WriteBlob(val []byte) error { - return w.writeLob(0xA0, val) + if w.err != nil { + return w.err + } + if w.err = w.beginValue(); w.err != nil { + return w.err + } + + if w.err = w.writeLob(0xA0, val); w.err != nil { + return w.err + } + + w.err = w.endValue() + return w.err } -// WriteLob writes a [bc]lob. func (w *binaryWriter) writeLob(code byte, val []byte) error { - if w.err == nil { - w.err = w.writeValue(func() error { - vlen := uint64(len(val)) + vlen := uint64(len(val)) - if vlen < 64 { - buflen := vlen + tagLen(vlen) - buf := make([]byte, 0, buflen) + if vlen < 64 { + buflen := vlen + tagLen(vlen) + buf := make([]byte, 0, buflen) - buf = appendTag(buf, code, vlen) - buf = append(buf, val...) + buf = appendTag(buf, code, vlen) + buf = append(buf, val...) - return w.write(buf) - } + return w.write(buf) + } - if err := w.writeTag(code, vlen); err != nil { - return err - } - return w.write(val) - }) + if err := w.writeTag(code, vlen); err != nil { + return err } - return w.err + return w.write(val) } // BeginList begins writing a list. func (w *binaryWriter) BeginList() error { if w.err == nil { - w.err = w.beginContainer(ctxInList, 0xB0) + w.err = w.begin(ctxInList, 0xB0) } return w.err } @@ -485,7 +434,7 @@ func (w *binaryWriter) BeginList() error { // EndList finishes writing a list. func (w *binaryWriter) EndList() error { if w.err == nil { - w.err = w.endContainer(ctxInList) + w.err = w.end(ctxInList) } return w.err } @@ -493,7 +442,7 @@ func (w *binaryWriter) EndList() error { // BeginSexp begins writing an s-expression. func (w *binaryWriter) BeginSexp() error { if w.err == nil { - w.err = w.beginContainer(ctxInSexp, 0xC0) + w.err = w.begin(ctxInSexp, 0xC0) } return w.err } @@ -501,7 +450,7 @@ func (w *binaryWriter) BeginSexp() error { // EndSexp finishes writing an s-expression. func (w *binaryWriter) EndSexp() error { if w.err == nil { - w.err = w.endContainer(ctxInSexp) + w.err = w.end(ctxInSexp) } return w.err } @@ -509,7 +458,7 @@ func (w *binaryWriter) EndSexp() error { // BeginStruct begins writing a struct. func (w *binaryWriter) BeginStruct() error { if w.err == nil { - w.err = w.beginContainer(ctxInStruct, 0xD0) + w.err = w.begin(ctxInStruct, 0xD0) } return w.err } @@ -517,7 +466,7 @@ func (w *binaryWriter) BeginStruct() error { // EndStruct finishes writing a struct. func (w *binaryWriter) EndStruct() error { if w.err == nil { - w.err = w.endContainer(ctxInStruct) + w.err = w.end(ctxInStruct) } return w.err } @@ -531,8 +480,7 @@ func (w *binaryWriter) Finish() error { return errors.New("ion: not at top level") } - w.fieldName = "" - w.annotations = nil + w.clear() w.wroteLST = false seq := w.bufs.peek() @@ -553,3 +501,149 @@ func (w *binaryWriter) Finish() error { return nil } + +// Emit emits the given node. If we're currently at the top level, that +// means actually emitting to the output stream. If not, we emit append +// to the current bufseq. +func (w *binaryWriter) emit(node bufnode) error { + s := w.bufs.peek() + if s == nil { + return node.EmitTo(w.out) + } + s.Append(node) + return nil +} + +// Write emits the given bytes as an atom. +func (w *binaryWriter) write(bs []byte) error { + return w.emit(atom(bs)) +} + +// WriteTag writes out a type+length tag. Use me when you've already got the value to +// be written as a []byte and don't want to copy it. +func (w *binaryWriter) writeTag(code byte, len uint64) error { + tl := tagLen(len) + + tag := make([]byte, 0, tl) + tag = appendTag(tag, code, len) + + return w.write(tag) +} + +// WriteLST writes out a local symbol table. +func (w *binaryWriter) writeLST(lst SymbolTable) error { + if err := w.write([]byte{0xE0, 0x01, 0x00, 0xEA}); err != nil { + return err + } + return lst.WriteTo(w) +} + +// BeginValue begins the process of writing a value by writing out +// its field name and annotations. +func (w *binaryWriter) beginValue() error { + // We have to record/empty these before calling writeLST, which + // will end up using/modifying them. Ugh. + name := w.fieldName + as := w.annotations + w.clear() + + // If we have a local symbol table and haven't written it out yet, do that now. + if w.lst != nil && !w.wroteLST { + w.wroteLST = true + if err := w.writeLST(w.lst); err != nil { + return err + } + } + + if w.inStruct() { + if name == "" { + return errors.New("ion: field name not set") + } + + id, err := w.resolve(name) + if err != nil { + return err + } + + buf := make([]byte, 0, 10) + buf = appendVarUint(buf, id) + if err := w.write(buf); err != nil { + return err + } + } + + if len(as) > 0 { + ids := make([]uint64, len(as)) + idlen := uint64(0) + + for i, a := range as { + id, err := w.resolve(a) + if err != nil { + return err + } + + ids[i] = id + idlen += varUintLen(id) + } + + buflen := idlen + varUintLen(idlen) + buf := make([]byte, 0, buflen) + + buf = appendVarUint(buf, idlen) + for _, id := range ids { + buf = appendVarUint(buf, id) + } + + w.bufs.push(&container{code: 0xE0}) + if err := w.write(buf); err != nil { + return err + } + } + + return nil +} + +// EndValue ends the process of writing a value by flushing it and its annotations +// up a level, if needed. +func (w *binaryWriter) endValue() error { + seq := w.bufs.peek() + if seq != nil { + if c, ok := seq.(*container); ok && c.code == 0xE0 { + w.bufs.pop() + return w.emit(seq) + } + } + return nil +} + +// Begin begins writing a new container. +func (w *binaryWriter) begin(t ctx, code byte) error { + if err := w.beginValue(); err != nil { + return err + } + + w.ctx.push(t) + w.bufs.push(&container{code: code}) + + return nil +} + +// End ends writing a container, emitting its buffered contents up a level in the stack. +func (w *binaryWriter) end(t ctx) error { + if w.ctx.peek() != t { + return errors.New("ion: not in that kind of container") + } + + seq := w.bufs.peek() + if seq != nil { + w.bufs.pop() + if err := w.emit(seq); err != nil { + return err + } + } + + w.clear() + w.ctx.pop() + + return w.endValue() +} diff --git a/bufnode.go b/bufstack.go similarity index 100% rename from bufnode.go rename to bufstack.go diff --git a/bufnode_test.go b/bufstack_test.go similarity index 100% rename from bufnode_test.go rename to bufstack_test.go diff --git a/textwriter.go b/textwriter.go index add2cdce..f7648873 100644 --- a/textwriter.go +++ b/textwriter.go @@ -41,37 +41,6 @@ func NewTextWriterOpts(out io.Writer, opts TextWriterOpts) Writer { } } -// // writeValue writes a value whose raw encoding is produced by the -// // given function. -// func (w *textWriter) writeValue(f func() string) error { -// if err := w.beginValue(); err != nil { -// return err -// } - -// sym := f() -// if err := writeRawString(sym, w.out); err != nil { -// return err -// } - -// w.endValue() -// return nil -// } - -// // writeValue writes a value by calling the given function, which is -// // expected to write the raw value to w.out. -// func (w *textWriter) writeValueStreaming(f func() error) error { -// if err := w.beginValue(); err != nil { -// return err -// } - -// if err := f(); err != nil { -// return err -// } - -// w.endValue() -// return nil -// } - // WriteNull writes an untyped null. func (w *textWriter) WriteNull() error { return w.WriteNullType(NoType) diff --git a/textwriter_test.go b/textwriter_test.go index 22383145..a2a6c329 100644 --- a/textwriter_test.go +++ b/textwriter_test.go @@ -8,7 +8,7 @@ import ( "time" ) -func TestTopLevelFieldName(t *testing.T) { +func TestWriteTextTopLevelFieldName(t *testing.T) { writeText(func(w Writer) { if err := w.FieldName("foo"); err == nil { t.Error("expected an error") @@ -16,7 +16,7 @@ func TestTopLevelFieldName(t *testing.T) { }) } -func TestEmptyStruct(t *testing.T) { +func TestWriteTextEmptyStruct(t *testing.T) { testTextWriter(t, "{}", func(w Writer) { if err := w.BeginStruct(); err != nil { t.Fatal(err) @@ -32,7 +32,7 @@ func TestEmptyStruct(t *testing.T) { }) } -func TestAnnotatedStruct(t *testing.T) { +func TestWriteTextAnnotatedStruct(t *testing.T) { testTextWriter(t, "foo::$bar::'.baz'::{}", func(w Writer) { w.Annotation("foo") w.Annotation("$bar") @@ -46,7 +46,7 @@ func TestAnnotatedStruct(t *testing.T) { }) } -func TestNestedStruct(t *testing.T) { +func TestWriteTextNestedStruct(t *testing.T) { testTextWriter(t, "{foo:'true'::{},'null':{}}", func(w Writer) { w.BeginStruct() @@ -63,7 +63,7 @@ func TestNestedStruct(t *testing.T) { }) } -func TestEmptyList(t *testing.T) { +func TestWriteTextEmptyList(t *testing.T) { testTextWriter(t, "[]", func(w Writer) { if err := w.BeginList(); err != nil { t.Fatal(err) @@ -79,7 +79,7 @@ func TestEmptyList(t *testing.T) { }) } -func TestNestedLists(t *testing.T) { +func TestWriteTextNestedLists(t *testing.T) { testTextWriter(t, "[{},foo::{},'null'::[]]", func(w Writer) { w.BeginList() @@ -98,7 +98,7 @@ func TestNestedLists(t *testing.T) { }) } -func TestWriteSexps(t *testing.T) { +func TestWriteTextSexps(t *testing.T) { testTextWriter(t, "()\n(())\n(() ())", func(w Writer) { w.BeginSexp() w.EndSexp() @@ -117,7 +117,7 @@ func TestWriteSexps(t *testing.T) { }) } -func TestNulls(t *testing.T) { +func TestWriteTextNulls(t *testing.T) { expected := "[null,foo::null.null,null.bool,null.int,null.float,null.decimal," + "null.timestamp,null.symbol,null.string,null.clob,null.blob," + "null.list,'null'::null.sexp,null.struct]" @@ -146,7 +146,7 @@ func TestNulls(t *testing.T) { }) } -func TestBool(t *testing.T) { +func TestWriteTextBool(t *testing.T) { expected := "true\n(false '123'::true)\n'false'::false" testTextWriter(t, expected, func(w Writer) { w.WriteBool(true) @@ -202,7 +202,7 @@ func TestWriteTextBigInt(t *testing.T) { }) } -func TestFloat(t *testing.T) { +func TestWriteTextFloat(t *testing.T) { expected := "{z:0e+0,nz:-0e+0,s:1.234e+1,l:1.234e-55,n:nan,i:+inf,ni:-inf}" testTextWriter(t, expected, func(w Writer) { w.BeginStruct() @@ -228,7 +228,7 @@ func TestFloat(t *testing.T) { }) } -func TestDecimal(t *testing.T) { +func TestWriteTextDecimal(t *testing.T) { expected := "0.\n-1.23d-98" testTextWriter(t, expected, func(w Writer) { w.WriteDecimal(MustParseDecimal("0")) @@ -236,7 +236,7 @@ func TestDecimal(t *testing.T) { }) } -func TestTimestamp(t *testing.T) { +func TestWriteTextTimestamp(t *testing.T) { expected := "1970-01-01T00:00:00.001Z\n1970-01-01T01:23:00+01:23" testTextWriter(t, expected, func(w Writer) { w.WriteTimestamp(time.Unix(0, 1000000).In(time.UTC)) @@ -244,7 +244,7 @@ func TestTimestamp(t *testing.T) { }) } -func TestSymbol(t *testing.T) { +func TestWriteTextSymbol(t *testing.T) { expected := "{foo:bar,empty:'','null':'null',f:a::b::u::'lo🇺🇸'}" testTextWriter(t, expected, func(w Writer) { w.BeginStruct() @@ -266,7 +266,7 @@ func TestSymbol(t *testing.T) { }) } -func TestString(t *testing.T) { +func TestWriteTextString(t *testing.T) { expected := `("hello" "" ("\\\"\n\"\\" zany::"🤪"))` testTextWriter(t, expected, func(w Writer) { w.BeginSexp() @@ -283,7 +283,7 @@ func TestString(t *testing.T) { }) } -func TestBlob(t *testing.T) { +func TestWriteTextBlob(t *testing.T) { expected := "{{AAEC/f7/}}\n{{SGVsbG8gV29ybGQ=}}\nempty::{{}}" testTextWriter(t, expected, func(w Writer) { w.WriteBlob([]byte{0, 1, 2, 0xFD, 0xFE, 0xFF}) @@ -293,7 +293,7 @@ func TestBlob(t *testing.T) { }) } -func TestClob(t *testing.T) { +func TestWriteTextClob(t *testing.T) { expected := "{hello:{{\"world\"}},bits:{{\"\\0\\x01\\xFE\\xFF\"}}}" testTextWriter(t, expected, func(w Writer) { w.BeginStruct() @@ -305,7 +305,7 @@ func TestClob(t *testing.T) { }) } -func TestFinish(t *testing.T) { +func TestWriteTextFinish(t *testing.T) { expected := "1\nfoo\n\"bar\"\n{}\n" testTextWriter(t, expected, func(w Writer) { w.WriteInt(1) @@ -319,7 +319,7 @@ func TestFinish(t *testing.T) { }) } -func TestBadFinish(t *testing.T) { +func TestWriteTextBadFinish(t *testing.T) { buf := strings.Builder{} w := NewTextWriter(&buf) From b53948dda6a1b449e7c7ae007e6fb03fad2da9b4 Mon Sep 17 00:00:00 2001 From: David Murray Date: Sun, 1 Sep 2019 14:25:24 +1000 Subject: [PATCH 37/56] Even more refactoring --- binaryreader.go | 4 +-- bufstack.go => buf.go | 0 bufstack_test.go => buf_test.go | 0 catalog.go | 7 +++++ marshal.go | 48 ++++++++++++++++++++------------- textreader.go | 18 ++----------- 6 files changed, 40 insertions(+), 37 deletions(-) rename bufstack.go => buf.go (100%) rename bufstack_test.go => buf_test.go (100%) diff --git a/binaryreader.go b/binaryreader.go index 7df471be..eaee044e 100644 --- a/binaryreader.go +++ b/binaryreader.go @@ -8,11 +8,11 @@ import ( // A binaryReader reads binary Ion. type binaryReader struct { + reader + bits bitstream cat *Catalog lst SymbolTable - - reader } func newBinaryReaderBuf(in *bufio.Reader, cat *Catalog) Reader { diff --git a/bufstack.go b/buf.go similarity index 100% rename from bufstack.go rename to buf.go diff --git a/bufstack_test.go b/buf_test.go similarity index 100% rename from bufstack_test.go rename to buf_test.go diff --git a/catalog.go b/catalog.go index a0657f8f..b867d0e3 100644 --- a/catalog.go +++ b/catalog.go @@ -32,3 +32,10 @@ func (c *Catalog) NewReader(in io.Reader) Reader { func (c *Catalog) NewReaderBytes(in []byte) Reader { return newReader(bytes.NewReader(in), c) } + +// Unmarshal unmarshals Ion data using this catalog. +func (c *Catalog) Unmarshal(data []byte, v interface{}) error { + r := c.NewReader(bytes.NewReader(data)) + d := NewDecoder(r) + return d.DecodeTo(v) +} diff --git a/marshal.go b/marshal.go index 11ab52e9..0e89562f 100644 --- a/marshal.go +++ b/marshal.go @@ -20,34 +20,44 @@ const ( // MarshalText marshals values to text ion. func MarshalText(v interface{}) ([]byte, error) { - return marshal(func(w io.Writer) Writer { - return NewTextWriterOpts(w, TextWriterQuietFinish) - }, EncodeSortMaps, v) + buf := bytes.Buffer{} + w := NewTextWriterOpts(&buf, TextWriterQuietFinish) + e := Encoder{ + w: w, + opts: EncodeSortMaps, + } + + if err := e.Encode(v); err != nil { + return nil, err + } + if err := e.Finish(); err != nil { + return nil, err + } + + return buf.Bytes(), nil } // MarshalBinary marshals values to binary ion. func MarshalBinary(v interface{}, ssts ...SharedSymbolTable) ([]byte, error) { - return marshal(func(w io.Writer) Writer { - return NewBinaryWriter(w, ssts...) - }, 0, v) + buf := bytes.Buffer{} + w := NewBinaryWriter(&buf, ssts...) + e := Encoder{w: w} + + if err := e.Encode(v); err != nil { + return nil, err + } + if err := e.Finish(); err != nil { + return nil, err + } + + return buf.Bytes(), nil } // MarshalBinaryLST marshals values to binary ion with a fixed local symbol table. func MarshalBinaryLST(v interface{}, lst SymbolTable) ([]byte, error) { - return marshal(func(w io.Writer) Writer { - return NewBinaryWriterLST(w, lst) - }, 0, v) -} - -// marshal marshals a value using the given writer type. -func marshal(wf func(io.Writer) Writer, opts EncoderOpts, v interface{}) ([]byte, error) { buf := bytes.Buffer{} - w := wf(&buf) - - e := Encoder{ - w: w, - opts: opts, - } + w := NewBinaryWriterLST(&buf, lst) + e := Encoder{w: w} if err := e.Encode(v); err != nil { return nil, err diff --git a/textreader.go b/textreader.go index 2ea18b72..262eccea 100644 --- a/textreader.go +++ b/textreader.go @@ -39,12 +39,10 @@ func (s trs) String() string { // A textReader is a Reader that reads text Ion. type textReader struct { - tok tokenizer - state trs - reader - debug bool + tok tokenizer + state trs } func newTextReaderBuf(in *bufio.Reader) Reader { @@ -68,10 +66,6 @@ func (t *textReader) Next() bool { return false } - if t.debug { - fmt.Println("ion: state =", t.state) - } - // If we haven't fully read the current value, skip over it. err := t.finishValue() if err != nil { @@ -79,10 +73,6 @@ func (t *textReader) Next() bool { return false } - if t.debug { - fmt.Println("ion: state after finish =", t.state) - } - t.clear() // Loop until we've consumed enough tokens to know what the next value is. @@ -92,10 +82,6 @@ func (t *textReader) Next() bool { return false } - if t.debug { - fmt.Println("ion: read token ", t.tok.Token()) - } - var done bool var err error From c863cd28602dca5d05420dfd45d6b1576eeaeebc Mon Sep 17 00:00:00 2001 From: David Murray Date: Sun, 1 Sep 2019 17:02:16 +1000 Subject: [PATCH 38/56] README, LICENSE, etc --- LICENSE | 177 +++++++++++++++++++++++++++++++++++++ NOTICE | 2 + README.md | 226 ++++++++++++++++++++++++++++++++++++++++++++++++ binaryreader.go | 4 +- catalog.go | 58 ++++++++++--- catalog_test.go | 62 +++++++++++++ marshal.go | 29 +++---- reader.go | 5 +- unmarshal.go | 16 ++++ 9 files changed, 543 insertions(+), 36 deletions(-) create mode 100644 LICENSE create mode 100644 NOTICE create mode 100644 README.md create mode 100644 catalog_test.go diff --git a/LICENSE b/LICENSE new file mode 100644 index 00000000..f13a8433 --- /dev/null +++ b/LICENSE @@ -0,0 +1,177 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + \ No newline at end of file diff --git a/NOTICE b/NOTICE new file mode 100644 index 00000000..48a14bab --- /dev/null +++ b/NOTICE @@ -0,0 +1,2 @@ +Amazon Ion Go +Copyright 2019 David Murray diff --git a/README.md b/README.md new file mode 100644 index 00000000..525d615b --- /dev/null +++ b/README.md @@ -0,0 +1,226 @@ +# Ion Go +A Golang implementation of Amazon's [Ion data notation](https://amzn.github.io/ion-docs/). + +## Using the Library +Import `github.com/fernomac/ion-go` and you're off to the races. + +### Marshaling and Unmarshaling +Similar to Golang's built-in [json](https://golang.org/pkg/encoding/json/) package, +you can marshal and unmarshal go types to Ion. Marshaling requires you to specify +whether you'd like text or binary Ion. Unmarshaling is smart enough to do the right +thing. Both respect json name tags, and `Marshal` honors omitempty. +```Go +type T struct { + A string + B struct { + RenamedC int `json:"C"` + D []int `json:",omitempty"` + } +} + +func main() { + t := T{} + + err := ion.Unmarshal([]byte("{A:\"Ion!\",B:{C:2,D:[3,4]}}"), &t) + if err != nil { + panic(err) + } + fmt.Printf("--- t:\n%v\n\n", t) + + text, err := ion.MarshalText(&t) + if err != nil { + panic(err) + } + fmt.Printf("--- text:\n%s\n\n", string(text)) + + binary, err := ion.MarshalBinary(&t) + if err != nil { + panic(err) + } + fmt.Printf("--- binary:\n%X\n\n", binary) +} +``` + +### Encoding and Decoding +To read or write multiple values at once, use an `Encoder` or `Decoder`: +```Go +func main() { + dec := ion.NewTextDecoder(os.Stdin) + enc := ion.NewBinaryEncoder(os.Stdout) + + for { + // Decode one Ion whole value from stdin. + val, err := dec.Decode() + if err == ion.ErrNoInput { + break + } else if err != nil { + panic(err) + } + + // Encode it to stdout. + if err := enc.Encode(val); err != nil { + panic(err) + } + } + + if err := enc.Finish(); err != nil { + panic(err) + } +} +``` + +### Reading and Writing +For low-level streaming read and write access, use a `Reader` or `Writer`. +```Go +func copy(in ion.Reader, out ion.Writer) { + for in.Next() { + name := in.FieldName() + if name != "" { + out.FieldName(name) + } + + annos := in.Annotations() + if len(annos) > 0 { + out.Annotations(annos...) + } + + switch in.Type() { + case ion.BoolType: + val, err := in.BoolValue() + if err != nil { + panic(err) + } + out.WriteBool(val) + + case ion.IntType: + val, err := in.Int64Value() + if err != nil { + panic(err) + } + out.WriteInt(val) + + case ion.StringType: + val, err := in.StringValue() + if err != nil { + panic(err) + } + out.WriteString(val) + + case ion.ListType: + in.StepIn() + out.BeginList() + copy(in, out) + in.StepOut() + out.EndList() + + case ion.StructType: + in.StepIn() + out.BeginStruct() + copy(in, out) + in.StepOut() + out.EndStruct() + } + } + + if in.Err() != nil { + panic(in.Err()) + } +} + +func main() { + in := ion.NewReader(os.Stdin) + out := ion.NewBinaryWriter(os.Stdout) + + copy(in, out) + + if err := out.Finish(); err != nil { + panic(err) + } +} +``` + +### Symbol Tables +By default, when writing binary Ion, a local symbol table is built as you write +values (which are buffered in memory until you call `Finish` so the symbol table +can be written out first). You can optionally provide one or more +`SharedSymbolTable`s to the writer, which it will reference as needed rather +than directly including those symbols in the local symbol table. +```Go +type Item struct { + ID string `json:"id"` + Name string `json:"name"` + Description string `json:"description"` +} + +var ItemSharedSymbols = ion.NewSharedSymbolTable("item", 1, []string{ + "item", + "id", + "name", + "description", +}) + +type SpicyItem struct { + Item + Spiciness int `json:"spiciness"` +} + +func WriteSpicyItemsTo(out io.Writer, items []SpicyItem) error { + writer := ion.NewBinaryWriter(out, ItemSharedSymbols) + + for _, item := range items { + writer.Annotation("item") + if err := ion.EncodeTo(writer, item); err != nil { + return err + } + } + + return writer.Finish() +} +``` + +You can alternatively provide the writer with a complete, pre-built local symbol table. +This allows values to be written without buffering, however any attempt to write a +symbol that is not included in the symbol table will result in an error: +```Go +func WriteItemsToLST(out io.Writer, items []SpicyItem) error { + lst := ion.NewLocalSymbolTable([]SharedSymbolTable{ItemSharedSymbols}, []string{ + "spiciness", + }) + + writer := ion.NewBinaryWriterLST(out, lst) + + for _, item := range items { + writer.Annotation("item") + if err := ion.EncodeTo(writer, item); err != nil { + return err + } + } + + return writer.Finish() +} +``` + +When reading binary Ion, shared symbol tables are provided by a `Catalog`. A basic +catalog can be constructed by calling `NewCatalog`; a smarter implementation may +load shared symbol tables from a database on demand. +```Go +func ReadItemsFrom(in io.Reader) ([]Item, error) { + item := Item{} + items := []Item{} + + cat := ion.NewCatalog(ItemSharedSymbols) + dec := ion.NewDecoder(ion.NewReaderCat(in, cat)) + + for { + err := dec.DecodeTo(&item) + if err == ion.ErrNoInput { + return items, nil + } + if err != nil { + return nil, err + } + + items = append(items, item) + } +} +``` diff --git a/binaryreader.go b/binaryreader.go index eaee044e..df3fc6ed 100644 --- a/binaryreader.go +++ b/binaryreader.go @@ -11,11 +11,11 @@ type binaryReader struct { reader bits bitstream - cat *Catalog + cat Catalog lst SymbolTable } -func newBinaryReaderBuf(in *bufio.Reader, cat *Catalog) Reader { +func newBinaryReaderBuf(in *bufio.Reader, cat Catalog) Reader { r := &binaryReader{ cat: cat, } diff --git a/catalog.go b/catalog.go index b867d0e3..60a96d18 100644 --- a/catalog.go +++ b/catalog.go @@ -4,38 +4,70 @@ import ( "bytes" "fmt" "io" + "strings" ) -// A Catalog stores shared symbol tables and serves as a reader factory. -type Catalog struct { +// A Catalog provides access to shared symbol tables. +type Catalog interface { + Find(name string, version int) SharedSymbolTable +} + +// A basicCatalog wraps an in-memory collection of shared symbol tables. +type basicCatalog struct { ssts map[string]SharedSymbolTable } +// NewCatalog creates a new basic catalog containing the given symbol tables. +func NewCatalog(ssts ...SharedSymbolTable) Catalog { + cat := &basicCatalog{make(map[string]SharedSymbolTable)} + for _, sst := range ssts { + cat.add(sst) + } + return cat +} + // Add adds a shared symbol table to the catalog. -func (c *Catalog) Add(sst SharedSymbolTable) { +func (c *basicCatalog) add(sst SharedSymbolTable) { key := fmt.Sprintf("%v/%v", sst.Name(), sst.Version()) c.ssts[key] = sst } // Find attempts to find a shared symbol table with the given name and version. -func (c *Catalog) Find(name string, version int) SharedSymbolTable { +func (c *basicCatalog) Find(name string, version int) SharedSymbolTable { key := fmt.Sprintf("%v/%v", name, version) return c.ssts[key] } -// NewReader creates a new reader using this catalog. -func (c *Catalog) NewReader(in io.Reader) Reader { - return newReader(in, c) +// A System is a reader factory wrapping a catalog. +type System struct { + Catalog Catalog +} + +// NewReader creates a new reader using this system's catalog. +func (s System) NewReader(in io.Reader) Reader { + return NewReaderCat(in, s.Catalog) } -// NewReaderBytes creates a new reader using this catalog. -func (c *Catalog) NewReaderBytes(in []byte) Reader { - return newReader(bytes.NewReader(in), c) +// NewReaderStr creates a new reader using this system's catalog. +func (s System) NewReaderStr(in string) Reader { + return NewReaderCat(strings.NewReader(in), s.Catalog) +} + +// NewReaderBytes creates a new reader using this system's catalog. +func (s System) NewReaderBytes(in []byte) Reader { + return NewReaderCat(bytes.NewReader(in), s.Catalog) +} + +// Unmarshal unmarshals Ion data using this system's catalog. +func (s System) Unmarshal(data []byte, v interface{}) error { + r := s.NewReaderBytes(data) + d := NewDecoder(r) + return d.DecodeTo(v) } -// Unmarshal unmarshals Ion data using this catalog. -func (c *Catalog) Unmarshal(data []byte, v interface{}) error { - r := c.NewReader(bytes.NewReader(data)) +// UnmarshalStr unmarshals Ion data using this system's catalog. +func (s System) UnmarshalStr(data string, v interface{}) error { + r := s.NewReaderStr(data) d := NewDecoder(r) return d.DecodeTo(v) } diff --git a/catalog_test.go b/catalog_test.go new file mode 100644 index 00000000..2e3bd2e0 --- /dev/null +++ b/catalog_test.go @@ -0,0 +1,62 @@ +package ion + +import ( + "bytes" + "fmt" + "testing" +) + +type Item struct { + ID int `json:"id"` + Name string `json:"name"` + Description string `json:"description"` +} + +func TestCatalog(t *testing.T) { + sst := NewSharedSymbolTable("item", 1, []string{ + "item", + "id", + "name", + "description", + }) + + buf := bytes.Buffer{} + out := NewBinaryWriter(&buf, sst) + + for i := 0; i < 10; i++ { + out.Annotation("item") + MarshalTo(out, &Item{ + ID: i, + Name: fmt.Sprintf("Item %v", i), + Description: fmt.Sprintf("The %vth test item", i), + }) + } + if err := out.Finish(); err != nil { + t.Fatal(err) + } + + bs := buf.Bytes() + + sys := System{Catalog: NewCatalog(sst)} + in := sys.NewReaderBytes(bs) + + i := 0 + for ; ; i++ { + item := Item{} + err := UnmarshalFrom(in, &item) + if err == ErrNoInput { + break + } + if err != nil { + t.Fatal(err) + } + + if item.ID != i { + t.Errorf("expected id=%v, got %v", i, item.ID) + } + } + + if i != 10 { + t.Errorf("expected i=10, got %v", i) + } +} diff --git a/marshal.go b/marshal.go index 0e89562f..7a5529d4 100644 --- a/marshal.go +++ b/marshal.go @@ -69,6 +69,16 @@ func MarshalBinaryLST(v interface{}, lst SymbolTable) ([]byte, error) { return buf.Bytes(), nil } +// MarshalTo marshals the given value to the given writer. It does +// not call Finish, so is suitable for encoding values inside of +// a partially-constructed Ion value. +func MarshalTo(w Writer, v interface{}) error { + e := Encoder{ + w: w, + } + return e.Encode(v) +} + // An Encoder writes Ion values to an output stream. type Encoder struct { w Writer @@ -103,25 +113,6 @@ func NewBinaryEncoderLST(w io.Writer, lst SymbolTable) *Encoder { return NewEncoder(NewBinaryWriterLST(w, lst)) } -// EncodeTo encodes the given value to the given writer. It does -// not call Finish, so is suitable for encoding values inside of -// a partially-constructed Ion value. -func EncodeTo(w Writer, v interface{}) error { - e := Encoder{ - w: w, - } - return e.Encode(v) -} - -// EncodeToOpts is like EncodeTo but accepts additional opts. -func EncodeToOpts(w Writer, opts EncoderOpts, v interface{}) error { - e := Encoder{ - w: w, - opts: opts, - } - return e.Encode(v) -} - // Encode marshals the given value to Ion, writing it to the underlying writer. func (m *Encoder) Encode(v interface{}) error { return m.encodeValue(reflect.ValueOf(v)) diff --git a/reader.go b/reader.go index f0216fc8..2463c8db 100644 --- a/reader.go +++ b/reader.go @@ -149,7 +149,7 @@ type Reader interface { // NewReader creates a new Ion reader of the appropriate type by peeking // at the first several bytes of input for a binary version marker. func NewReader(in io.Reader) Reader { - return newReader(in, nil) + return NewReaderCat(in, nil) } // NewReaderStr creates a new reader from a string. @@ -162,7 +162,8 @@ func NewReaderBytes(in []byte) Reader { return NewReader(bytes.NewReader(in)) } -func newReader(in io.Reader, cat *Catalog) Reader { +// NewReaderCat creates a new reader with the given catalog. +func NewReaderCat(in io.Reader, cat Catalog) Reader { br := bufio.NewReader(in) bs, err := br.Peek(4) diff --git a/unmarshal.go b/unmarshal.go index 825238cb..e5e61090 100644 --- a/unmarshal.go +++ b/unmarshal.go @@ -4,6 +4,7 @@ import ( "bytes" "errors" "fmt" + "io" "math/big" "reflect" "strconv" @@ -25,6 +26,14 @@ func UnmarshalStr(data string, v interface{}) error { return Unmarshal([]byte(data), v) } +// UnmarshalFrom unmarshal Ion data from a reader to the given object. +func UnmarshalFrom(r Reader, v interface{}) error { + d := Decoder{ + r: r, + } + return d.DecodeTo(v) +} + // A Decoder decodes go values from an Ion reader. type Decoder struct { r Reader @@ -37,6 +46,13 @@ func NewDecoder(r Reader) *Decoder { } } +// NewTextDecoder creates a new text decoder. Well, a decoder that uses a reader with +// no shared symbol tables, it'll work to read binary too if the binary doesn't reference +// any shared symbol tables. +func NewTextDecoder(in io.Reader) *Decoder { + return NewDecoder(NewReader(in)) +} + // Decode decodes a value from the underlying Ion reader without any expectations // about what it's going to get. Structs become map[string]interface{}s, Lists and // Sexps become []interface{}s. From 02beac68ebf1009a67abebd2fb62908f51861174 Mon Sep 17 00:00:00 2001 From: David Murray Date: Mon, 2 Sep 2019 14:46:21 +1000 Subject: [PATCH 39/56] improved error handling, refactoring --- binaryreader.go | 229 ++++++++++++++-------- binaryreader_test.go | 157 ++++++++++++++- binarywriter.go | 267 +++++++++---------------- binarywriter_test.go | 2 + bitstream.go | 452 +++++++++++++++++++++++++++---------------- bitstream_test.go | 6 +- catalog.go | 25 ++- decimal.go | 61 +++--- decimal_test.go | 4 +- err.go | 88 +++++++++ reader.go | 157 +++++++++------ skipper.go | 32 +-- skipper_test.go | 46 ++--- symboltable.go | 238 +++++++++++++++-------- symboltable_test.go | 10 +- textreader.go | 41 ++-- textreader_test.go | 27 +++ textutils.go | 6 +- textutils_test.go | 4 +- textwriter.go | 60 +++--- textwriter_test.go | 5 +- tokenizer.go | 76 ++++---- type.go | 4 + writer.go | 2 + 24 files changed, 1270 insertions(+), 729 deletions(-) create mode 100644 err.go diff --git a/binaryreader.go b/binaryreader.go index df3fc6ed..9a64a7c2 100644 --- a/binaryreader.go +++ b/binaryreader.go @@ -2,7 +2,6 @@ package ion import ( "bufio" - "errors" "fmt" ) @@ -23,10 +22,12 @@ func newBinaryReaderBuf(in *bufio.Reader, cat Catalog) Reader { return r } +// SymbolTable returns the current symbol table. func (r *binaryReader) SymbolTable() SymbolTable { return r.lst } +// Next moves the reader to the next value. func (r *binaryReader) Next() bool { if r.eof || r.err != nil { return false @@ -45,12 +46,15 @@ func (r *binaryReader) Next() bool { return !r.eof } +// Next consumes the next raw value from the stream, returning true if it +// represents a user-facing value and false if it does not. func (r *binaryReader) next() (bool, error) { if err := r.bits.Next(); err != nil { return false, err } - switch r.bits.Code() { + code := r.bits.Code() + switch code { case bitcodeEOF: r.eof = true return true, nil @@ -68,7 +72,7 @@ func (r *binaryReader) next() (bool, error) { return false, err case bitcodeNull: - if !r.bits.Null() { + if !r.bits.IsNull() { // NOP padding; skip it and keep going. err := r.bits.SkipValue() return false, err @@ -78,14 +82,14 @@ func (r *binaryReader) next() (bool, error) { case bitcodeFalse, bitcodeTrue: r.valueType = BoolType - if !r.bits.Null() { + if !r.bits.IsNull() { r.value = (r.bits.Code() == bitcodeTrue) } return true, nil case bitcodeInt, bitcodeNegInt: r.valueType = IntType - if !r.bits.Null() { + if !r.bits.IsNull() { val, err := r.bits.ReadInt() if err != nil { return false, err @@ -96,7 +100,7 @@ func (r *binaryReader) next() (bool, error) { case bitcodeFloat: r.valueType = FloatType - if !r.bits.Null() { + if !r.bits.IsNull() { val, err := r.bits.ReadFloat() if err != nil { return false, err @@ -107,7 +111,7 @@ func (r *binaryReader) next() (bool, error) { case bitcodeDecimal: r.valueType = DecimalType - if !r.bits.Null() { + if !r.bits.IsNull() { val, err := r.bits.ReadDecimal() if err != nil { return false, err @@ -118,7 +122,7 @@ func (r *binaryReader) next() (bool, error) { case bitcodeTimestamp: r.valueType = TimestampType - if !r.bits.Null() { + if !r.bits.IsNull() { val, err := r.bits.ReadTimestamp() if err != nil { return false, err @@ -129,8 +133,8 @@ func (r *binaryReader) next() (bool, error) { case bitcodeSymbol: r.valueType = SymbolType - if !r.bits.Null() { - id, err := r.bits.ReadSymbol() + if !r.bits.IsNull() { + id, err := r.bits.ReadSymbolID() if err != nil { return false, err } @@ -140,7 +144,7 @@ func (r *binaryReader) next() (bool, error) { case bitcodeString: r.valueType = StringType - if !r.bits.Null() { + if !r.bits.IsNull() { val, err := r.bits.ReadString() if err != nil { return false, err @@ -151,7 +155,7 @@ func (r *binaryReader) next() (bool, error) { case bitcodeClob: r.valueType = ClobType - if !r.bits.Null() { + if !r.bits.IsNull() { val, err := r.bits.ReadBytes() if err != nil { return false, err @@ -162,7 +166,7 @@ func (r *binaryReader) next() (bool, error) { case bitcodeBlob: r.valueType = BlobType - if !r.bits.Null() { + if !r.bits.IsNull() { val, err := r.bits.ReadBytes() if err != nil { return false, err @@ -173,84 +177,70 @@ func (r *binaryReader) next() (bool, error) { case bitcodeList: r.valueType = ListType - if !r.bits.Null() { + if !r.bits.IsNull() { r.value = ListType } return true, nil case bitcodeSexp: r.valueType = SexpType - if !r.bits.Null() { + if !r.bits.IsNull() { r.value = SexpType } return true, nil case bitcodeStruct: r.valueType = StructType - if !r.bits.Null() { + if !r.bits.IsNull() { r.value = StructType } - if len(r.annotations) > 0 && r.annotations[0] == "$ion_symbol_table" { + // If it's a local symbol table, install it and keep going. + if r.ctx.peek() == ctxAtTopLevel && isIonSymbolTable(r.annotations) { err := r.readLocalSymbolTable() return false, err } return true, nil - - default: - panic(fmt.Sprintf("unsupported bitcode %v", r.bits.Code())) } + panic(fmt.Sprintf("invalid bitcode %v", code)) +} + +func isIonSymbolTable(as []string) bool { + return len(as) > 0 && as[0] == "$ion_symbol_table" } +// ReadBVM reads a BVM, validates it, and resets the local symbol table. func (r *binaryReader) readBVM() error { major, minor, err := r.bits.ReadBVM() if err != nil { return err } - if major != 1 && minor != 0 { - return fmt.Errorf("ion: unsupported version %v.%v", major, minor) + switch major { + case 1: + switch minor { + case 0: + r.lst = V1SystemSymbolTable + return nil + } } - r.lst = V1SystemSymbolTable - return nil -} - -func (r *binaryReader) readFieldName() error { - id, err := r.bits.ReadFieldID() - if err != nil { - return err + return &UnsupportedVersionError{ + int(major), + int(minor), + r.bits.Pos() - 4, } - - r.fieldName = r.resolve(id) - return nil } -func (r *binaryReader) readAnnotations() error { - ids, err := r.bits.ReadAnnotations() - if err != nil { - return err - } - - as := make([]string, len(ids)) - for i, id := range ids { - as[i] = r.resolve(id) - } - - r.annotations = as - return nil -} - -func (r *binaryReader) resolve(id uint64) string { - s, ok := r.lst.FindByID(int(id)) - if !ok { - return fmt.Sprintf("$%v", id) +// ReadLocalSymbolTable reads and installs a new local symbol table. +func (r *binaryReader) readLocalSymbolTable() error { + if r.IsNull() { + r.clear() + r.lst = V1SystemSymbolTable + return nil } - return s -} -func (r *binaryReader) readLocalSymbolTable() error { if err := r.StepIn(); err != nil { return err } @@ -279,8 +269,20 @@ func (r *binaryReader) readLocalSymbolTable() error { return nil } +// ReadImports reads the imports field of a local symbol table. func (r *binaryReader) readImports() ([]SharedSymbolTable, error) { - if r.Type() != ListType { + if r.valueType == SymbolType && r.value == "$ion_symbol_table" { + // Special case that imports the current local symbol table. + if r.lst == nil || r.lst == V1SystemSymbolTable { + return nil, nil + } + + imps := r.lst.Imports() + lsst := NewSharedSymbolTable("", 0, r.lst.Symbols()) + return append(imps, lsst), nil + } + + if r.Type() != ListType || r.IsNull() { return nil, nil } if err := r.StepIn(); err != nil { @@ -293,15 +295,18 @@ func (r *binaryReader) readImports() ([]SharedSymbolTable, error) { if err != nil { return nil, err } - imps = append(imps, imp) + if imp != nil { + imps = append(imps, imp) + } } err := r.StepOut() return imps, err } +// ReadImport reads an import definition. func (r *binaryReader) readImport() (SharedSymbolTable, error) { - if r.Type() != StructType { + if r.Type() != StructType || r.IsNull() { return nil, nil } if err := r.StepIn(); err != nil { @@ -310,17 +315,28 @@ func (r *binaryReader) readImport() (SharedSymbolTable, error) { name := "" version := 0 - maxID := 0 + maxID := uint64(0) for r.Next() { var err error switch r.FieldName() { case "name": - name, err = r.StringValue() + if r.Type() == StringType { + name, err = r.StringValue() + } case "version": - version, err = r.IntValue() + if r.Type() == IntType { + version, err = r.IntValue() + } case "max_id": - maxID, err = r.IntValue() + if r.Type() == IntType { + var i int64 + i, err = r.Int64Value() + if i < 0 { + i = 0 + } + maxID = uint64(i) + } } if err != nil { return nil, err @@ -331,17 +347,27 @@ func (r *binaryReader) readImport() (SharedSymbolTable, error) { return nil, err } - if name == "" || version == 0 || maxID == 0 { - return nil, errors.New("ion: invalid import in local symbol table") + if name == "" || name == "$ion" { + return nil, nil + } + if version < 1 { + version = 1 } var imp SharedSymbolTable if r.cat != nil { - imp = r.cat.Find(name, version) - if imp != nil && imp.MaxID() != maxID { - // TODO: Better error. - return nil, errors.New("ion: maxID mismatch in imported symbol table") + imp = r.cat.FindExact(name, version) + if imp == nil { + imp = r.cat.FindLatest(name) + } + } + + if maxID == 0 { + if imp == nil || version != imp.Version() { + return nil, fmt.Errorf("ion: import of shared table %v/%v lacks a valid max_id, but an exact "+ + "match was not found in the catalog", name, version) } + maxID = imp.MaxID() } if imp == nil { @@ -350,11 +376,14 @@ func (r *binaryReader) readImport() (SharedSymbolTable, error) { version: version, maxID: maxID, } + } else { + imp = imp.Adjust(maxID) } return imp, nil } +// ReadSymbols reads the symbols from a symbol table. func (r *binaryReader) readSymbols() ([]string, error) { if r.Type() != ListType { return nil, nil @@ -371,39 +400,87 @@ func (r *binaryReader) readSymbols() ([]string, error) { return nil, err } syms = append(syms, sym) + } else { + syms = append(syms, "") } } err := r.StepOut() - return syms, err } +// ReadFieldName reads and resolves a field name. +func (r *binaryReader) readFieldName() error { + id, err := r.bits.ReadFieldID() + if err != nil { + return err + } + + r.fieldName = r.resolve(id) + return nil +} + +// ReadAnnotations reads and resolves a set of annotations. +func (r *binaryReader) readAnnotations() error { + ids, err := r.bits.ReadAnnotationIDs() + if err != nil { + return err + } + + as := make([]string, len(ids)) + for i, id := range ids { + as[i] = r.resolve(id) + } + + r.annotations = as + return nil +} + +// Resolve resolves a symbol ID to a symbol value (possibly ${id} if we're +// missing the appropriate symbol table). +func (r *binaryReader) resolve(id uint64) string { + s, ok := r.lst.FindByID(id) + if !ok { + return fmt.Sprintf("$%v", id) + } + return s +} + +// StepIn steps in to a container-type value func (r *binaryReader) StepIn() error { if r.err != nil { return r.err } - switch r.valueType { - case ListType, SexpType, StructType: - default: - return errors.New("ion: StepIn called when not on a container") - } - ctx := containerTypeToCtx(r.valueType) - r.ctx.push(ctx) + if r.valueType != ListType && r.valueType != SexpType && r.valueType != StructType { + return &UsageError{"Reader.StepIn", fmt.Sprintf("cannot step in to a %v", r.valueType)} + } + if r.value == nil { + return &UsageError{"Reader.StepIn", "cannot step in to a null container"} + } + r.ctx.push(containerTypeToCtx(r.valueType)) r.clear() r.bits.StepIn() return nil } +// StepOut steps out of a container-type value. func (r *binaryReader) StepOut() error { + if r.err != nil { + return r.err + } + if r.ctx.peek() == ctxAtTopLevel { + return &UsageError{"Reader.StepOut", "cannot step out of top-level datagram"} + } + if err := r.bits.StepOut(); err != nil { return err } r.clear() + r.ctx.pop() r.eof = false return nil diff --git a/binaryreader_test.go b/binaryreader_test.go index 9e064dbf..4f0901b5 100644 --- a/binaryreader_test.go +++ b/binaryreader_test.go @@ -8,6 +8,154 @@ import ( "time" ) +func TestReadBadBVMs(t *testing.T) { + t.Run("E00200E9", func(t *testing.T) { + // Need a good first one or we'll get sent to the text reader. + r := NewReaderBytes([]byte{0xE0, 0x01, 0x00, 0xEA, 0xE0, 0x02, 0x00, 0xE9}) + if r.Next() { + t.Errorf("next returned true") + } + if r.Err() == nil { + t.Fatal("err is nil") + } + }) + + t.Run("E00200EA", func(t *testing.T) { + r := NewReaderBytes([]byte{0xE0, 0x02, 0x00, 0xEA}) + if r.Next() { + t.Errorf("next returned true") + } + if r.Err() == nil { + t.Fatal("err is nil") + } + + uve, ok := r.Err().(*UnsupportedVersionError) + if !ok { + t.Fatal("err is not an UnsupportedVersionError") + } + if uve.Major != 2 { + t.Errorf("expected major=2, got %v", uve.Major) + } + if uve.Minor != 0 { + t.Errorf("expected minor=0, got %v", uve.Minor) + } + }) +} + +func TestReadNullLST(t *testing.T) { + ion := []byte{ + 0xE0, 0x01, 0x00, 0xEA, + 0xE4, 0x82, 0x83, 0x87, 0xDF, + 0x71, 0x09, + } + r := NewReaderBytes(ion) + _symbol(t, r, "$ion_shared_symbol_table") + _eof(t, r) +} + +func TestReadEmptyLST(t *testing.T) { + ion := []byte{ + 0xE0, 0x01, 0x00, 0xEA, + 0xE4, 0x82, 0x83, 0x87, 0xD0, + 0x71, 0x09, + } + r := NewReaderBytes(ion) + _symbol(t, r, "$ion_shared_symbol_table") + _eof(t, r) +} + +func TestReadBadLST(t *testing.T) { + ion := []byte{ + 0xE0, 0x01, 0x00, 0xEA, + 0xE3, 0x81, 0x83, 0xD9, + 0x86, 0xB7, 0xD6, // imports:[{ + 0x84, 0x81, 'a', // name: "a", + 0x85, 0x21, 0x01, // version: 1}]} + 0x0F, // null + } + r := NewReaderBytes(ion) + if r.Next() { + t.Fatal("next returned true") + } + if r.Err() == nil { + t.Fatal("err is nil") + } +} + +func TestReadMultipleLSTs(t *testing.T) { + r := readBinary([]byte{ + 0x71, 0x0B, // $11 + 0x71, 0x6F, // bar + 0xE3, 0x81, 0x83, 0xDF, // $ion_symbol_table::null.struct + 0xEE, 0x8F, 0x81, 0x83, 0xDD, // $ion_symbol_table::{ + 0x86, 0x71, 0x03, // imports: $ion_symbol_table, + 0x87, 0xB8, // symbols:[ + 0x83, 'f', 'o', 'o', // "foo" + 0x83, 'b', 'a', 'r', // "bar" ]} + 0x71, 0x0B, // bar + 0x71, 0x0C, // $12 + 0x71, 0x6F, // $111 + 0xEC, 0x81, 0x83, 0xD9, // $ion_symbol_table::{ + 0x86, 0x71, 0x03, // imports: $ion_symbol_table + 0x87, 0xB4, // symbols:[ + 0x83, 'b', 'a', 'z', // "baz" ]} + 0x71, 0x0B, // bar + 0x71, 0x0C, // baz + }) + _symbol(t, r, "$11") + _symbol(t, r, "bar") + + _symbol(t, r, "bar") + _symbol(t, r, "$12") + _symbol(t, r, "$111") + + _symbol(t, r, "bar") + _symbol(t, r, "baz") + _eof(t, r) +} + +func TestReadBinaryLST(t *testing.T) { + r := readBinary([]byte{0x0F}) + _next(t, r, NullType) + + lst := r.SymbolTable() + if lst == nil { + t.Fatal("symboltable is nil") + } + + if lst.MaxID() != 111 { + t.Errorf("expected maxid=111, got %v", lst.MaxID()) + } + + if _, ok := lst.FindByID(109); ok { + t.Error("found a symbol for $109") + } + + sym, ok := lst.FindByID(111) + if !ok { + t.Fatal("no symbol defined for $111") + } + if sym != "bar" { + t.Errorf("expected $111=bar, got %v", sym) + } + + id, ok := lst.FindByName("foo") + if !ok { + t.Fatal("no id defined for foo") + } + if id != 110 { + t.Errorf("expected foo=$110, got $%v", id) + } + + if _, ok := lst.FindByID(112); ok { + t.Error("found a symbol for $112") + } + + if _, ok := lst.FindByName("bogus"); ok { + t.Error("found a symbol for bogus") + } +} + func TestReadBinaryStructs(t *testing.T) { r := readBinary([]byte{ 0xDF, // null.struct @@ -241,8 +389,9 @@ func TestReadBinaryInts(t *testing.T) { _int64(t, r, math.MaxInt64) _int64(t, r, -math.MaxInt64) + _uint(t, r, math.MaxInt64+1) + i := new(big.Int).SetUint64(math.MaxInt64 + 1) - _bigInt(t, r, i) _bigInt(t, r, new(big.Int).Neg(i)) _eof(t, r) @@ -279,6 +428,12 @@ func TestReadBinaryNulls(t *testing.T) { _eof(t, r) } +func TestReadEmptyBinary(t *testing.T) { + r := NewReaderBytes([]byte{0xE0, 0x01, 0x00, 0xEA}) + _eof(t, r) + _eof(t, r) +} + func readBinary(ion []byte) Reader { prefix := []byte{ 0xE0, 0x01, 0x00, 0xEA, // $ion_1_0 diff --git a/binarywriter.go b/binarywriter.go index 6633cbcb..0115faf6 100644 --- a/binarywriter.go +++ b/binarywriter.go @@ -2,11 +2,12 @@ package ion import ( "encoding/binary" - "errors" "fmt" "io" "math" "math/big" + "strconv" + "strings" "time" ) @@ -47,69 +48,27 @@ func NewBinaryWriterLST(out io.Writer, lst SymbolTable) Writer { // WriteNull writes an untyped null. func (w *binaryWriter) WriteNull() error { - return w.WriteNullType(NoType) + return w.writeValue("Writer.WriteNull", []byte{0x0F}) } // WriteNullType writes a typed null. func (w *binaryWriter) WriteNullType(t Type) error { - if w.err != nil { - return w.err - } - if w.err = w.beginValue(); w.err != nil { - return w.err - } - - if w.err = w.write([]byte{binaryNulls[t]}); w.err != nil { - return w.err - } - - w.err = w.endValue() - return w.err + return w.writeValue("Writer.WriteNullType", []byte{binaryNulls[t]}) } // WriteBool writes a bool. func (w *binaryWriter) WriteBool(val bool) error { - if w.err != nil { - return w.err - } - if w.err = w.beginValue(); w.err != nil { - return w.err - } - b := byte(0x10) if val { b = 0x11 } - - if w.err = w.write([]byte{b}); w.err != nil { - return w.err - } - - w.err = w.endValue() - return w.err + return w.writeValue("Writer.WriteBool", []byte{b}) } // WriteInt writes an integer. func (w *binaryWriter) WriteInt(val int64) error { - if w.err != nil { - return w.err - } - if w.err = w.beginValue(); w.err != nil { - return w.err - } - - if w.err = w.writeInt(val); w.err != nil { - return w.err - } - - w.err = w.endValue() - return w.err -} - -// WriteInt writes the actual integer value. -func (w *binaryWriter) writeInt(val int64) error { if val == 0 { - return w.write([]byte{0x20}) + return w.writeValue("Writer.WriteInt", []byte{0x20}) } code := byte(0x20) @@ -127,7 +86,23 @@ func (w *binaryWriter) writeInt(val int64) error { buf = appendTag(buf, code, len) buf = appendUint(buf, mag) - return w.write(buf) + return w.writeValue("Writer.WriteInt", buf) +} + +// WriteUint writes an unsigned integer. +func (w *binaryWriter) WriteUint(val uint64) error { + if val == 0 { + return w.writeValue("Writer.WriteUint", []byte{0x20}) + } + + len := uintLen(val) + buflen := len + tagLen(len) + + buf := make([]byte, 0, buflen) + buf = appendTag(buf, 0x20, len) + buf = appendUint(buf, val) + + return w.writeValue("Writer.WriteUint", buf) } // WriteBigInt writes a big integer. @@ -135,7 +110,7 @@ func (w *binaryWriter) WriteBigInt(val *big.Int) error { if w.err != nil { return w.err } - if w.err = w.beginValue(); w.err != nil { + if w.err = w.beginValue("Writer.WriteBigInt"); w.err != nil { return w.err } @@ -180,26 +155,8 @@ func (w *binaryWriter) writeBigInt(val *big.Int) error { // WriteFloat writes a floating-point value. func (w *binaryWriter) WriteFloat(val float64) error { - if w.err != nil { - return w.err - } - if w.err = w.beginValue(); w.err != nil { - return w.err - } - - if w.err = w.writeFloat(val); w.err != nil { - return w.err - } - - w.err = w.endValue() - return w.err - -} - -// WriteFloat writes the actual float value. -func (w *binaryWriter) writeFloat(val float64) error { if val == 0 { - return w.write([]byte{0x40}) + return w.writeValue("Writer.WriteFloat", []byte{0x40}) } bs := make([]byte, 9) @@ -208,28 +165,11 @@ func (w *binaryWriter) writeFloat(val float64) error { bits := math.Float64bits(val) binary.BigEndian.PutUint64(bs[1:], bits) - return w.write(bs) + return w.writeValue("Writer.WriteFloat", bs) } // WriteDecimal writes a decimal value. func (w *binaryWriter) WriteDecimal(val *Decimal) error { - if w.err != nil { - return w.err - } - if w.err = w.beginValue(); w.err != nil { - return w.err - } - - if w.err = w.writeDecimal(val); w.err != nil { - return w.err - } - - w.err = w.endValue() - return w.err -} - -// WriteDecimal writes the actual decimal value. -func (w *binaryWriter) writeDecimal(val *Decimal) error { coef, exp := val.CoEx() vlen := uint64(0) @@ -249,27 +189,11 @@ func (w *binaryWriter) writeDecimal(val *Decimal) error { } buf = appendBigInt(buf, coef) - return w.write(buf) + return w.writeValue("Writer.WriteDecimal", buf) } // WriteTimestamp writes a timestamp value. func (w *binaryWriter) WriteTimestamp(val time.Time) error { - if w.err != nil { - return w.err - } - if w.err = w.beginValue(); w.err != nil { - return w.err - } - - if w.err = w.writeTimestamp(val); w.err != nil { - return w.err - } - - w.err = w.endValue() - return w.err -} - -func (w *binaryWriter) writeTimestamp(val time.Time) error { _, offset := val.Zone() offset /= 60 utc := val.In(time.UTC) @@ -282,29 +206,14 @@ func (w *binaryWriter) writeTimestamp(val time.Time) error { buf = appendTag(buf, 0x60, vlen) buf = appendTime(buf, offset, utc) - return w.write(buf) + return w.writeValue("Writer.WriteTimestamp", buf) } // WriteSymbol writes a symbol value. func (w *binaryWriter) WriteSymbol(val string) error { - if w.err != nil { - return w.err - } - if w.err = w.beginValue(); w.err != nil { - return w.err - } - - if w.err = w.writeSymbol(val); w.err != nil { - return w.err - } - - w.err = w.endValue() - return w.err -} - -func (w *binaryWriter) writeSymbol(val string) error { - id, err := w.resolve(val) + id, err := w.resolve("Writer.WriteSymbol", val) if err != nil { + w.err = err return err } @@ -315,49 +224,13 @@ func (w *binaryWriter) writeSymbol(val string) error { buf = appendTag(buf, 0x70, vlen) buf = appendUint(buf, uint64(id)) - return w.write(buf) -} - -// Resolve resolves a symbol to its ID. -func (w *binaryWriter) resolve(sym string) (uint64, error) { - if w.lst != nil { - id, ok := w.lst.FindByName(sym) - if !ok { - return 0, fmt.Errorf("ion: symbol '%v' not defined in local symbol table", sym) - } - if id < 0 { - panic("negative id") - } - return uint64(id), nil - } - - id, _ := w.lstb.Add(sym) - if id < 0 { - panic("negative id") - } - return uint64(id), nil + return w.writeValue("Writer.WriteSymbol", buf) } // WriteString writes a string. func (w *binaryWriter) WriteString(val string) error { - if w.err != nil { - return w.err - } - if w.err = w.beginValue(); w.err != nil { - return w.err - } - - if w.err = w.writeString(val); w.err != nil { - return w.err - } - - w.err = w.endValue() - return w.err -} - -func (w *binaryWriter) writeString(val string) error { if len(val) == 0 { - return w.write([]byte{0x80}) + return w.writeValue("Writer.WriteString", []byte{0x80}) } vlen := uint64(len(val)) @@ -367,7 +240,7 @@ func (w *binaryWriter) writeString(val string) error { buf = appendTag(buf, 0x80, vlen) buf = append(buf, val...) - return w.write(buf) + return w.writeValue("Writer.WriteString", buf) } // WriteClob writes a clob. @@ -375,7 +248,7 @@ func (w *binaryWriter) WriteClob(val []byte) error { if w.err != nil { return w.err } - if w.err = w.beginValue(); w.err != nil { + if w.err = w.beginValue("Writer.WriteClob"); w.err != nil { return w.err } @@ -392,7 +265,7 @@ func (w *binaryWriter) WriteBlob(val []byte) error { if w.err != nil { return w.err } - if w.err = w.beginValue(); w.err != nil { + if w.err = w.beginValue("Writer.WriteBlob"); w.err != nil { return w.err } @@ -426,7 +299,7 @@ func (w *binaryWriter) writeLob(code byte, val []byte) error { // BeginList begins writing a list. func (w *binaryWriter) BeginList() error { if w.err == nil { - w.err = w.begin(ctxInList, 0xB0) + w.err = w.begin("Writer.BeginList", ctxInList, 0xB0) } return w.err } @@ -434,7 +307,7 @@ func (w *binaryWriter) BeginList() error { // EndList finishes writing a list. func (w *binaryWriter) EndList() error { if w.err == nil { - w.err = w.end(ctxInList) + w.err = w.end("Writer.EndList", ctxInList) } return w.err } @@ -442,7 +315,7 @@ func (w *binaryWriter) EndList() error { // BeginSexp begins writing an s-expression. func (w *binaryWriter) BeginSexp() error { if w.err == nil { - w.err = w.begin(ctxInSexp, 0xC0) + w.err = w.begin("Writer.BeginSexp", ctxInSexp, 0xC0) } return w.err } @@ -450,7 +323,7 @@ func (w *binaryWriter) BeginSexp() error { // EndSexp finishes writing an s-expression. func (w *binaryWriter) EndSexp() error { if w.err == nil { - w.err = w.end(ctxInSexp) + w.err = w.end("Writer.EndSexp", ctxInSexp) } return w.err } @@ -458,7 +331,7 @@ func (w *binaryWriter) EndSexp() error { // BeginStruct begins writing a struct. func (w *binaryWriter) BeginStruct() error { if w.err == nil { - w.err = w.begin(ctxInStruct, 0xD0) + w.err = w.begin("Writer.BeginStruct", ctxInStruct, 0xD0) } return w.err } @@ -466,7 +339,7 @@ func (w *binaryWriter) BeginStruct() error { // EndStruct finishes writing a struct. func (w *binaryWriter) EndStruct() error { if w.err == nil { - w.err = w.end(ctxInStruct) + w.err = w.end("Writer.EndStruct", ctxInStruct) } return w.err } @@ -477,7 +350,7 @@ func (w *binaryWriter) Finish() error { return w.err } if w.ctx.peek() != ctxAtTopLevel { - return errors.New("ion: not at top level") + return &UsageError{"Writer.Finish", "not at top level"} } w.clear() @@ -519,6 +392,23 @@ func (w *binaryWriter) write(bs []byte) error { return w.emit(atom(bs)) } +// WriteValue writes a serialized value to the output stream. +func (w *binaryWriter) writeValue(api string, val []byte) error { + if w.err != nil { + return w.err + } + if w.err = w.beginValue(api); w.err != nil { + return w.err + } + + if w.err = w.write(val); w.err != nil { + return w.err + } + + w.err = w.endValue() + return w.err +} + // WriteTag writes out a type+length tag. Use me when you've already got the value to // be written as a []byte and don't want to copy it. func (w *binaryWriter) writeTag(code byte, len uint64) error { @@ -540,7 +430,7 @@ func (w *binaryWriter) writeLST(lst SymbolTable) error { // BeginValue begins the process of writing a value by writing out // its field name and annotations. -func (w *binaryWriter) beginValue() error { +func (w *binaryWriter) beginValue(api string) error { // We have to record/empty these before calling writeLST, which // will end up using/modifying them. Ugh. name := w.fieldName @@ -557,10 +447,10 @@ func (w *binaryWriter) beginValue() error { if w.inStruct() { if name == "" { - return errors.New("ion: field name not set") + return &UsageError{api, "field name not set"} } - id, err := w.resolve(name) + id, err := w.resolve(api, name) if err != nil { return err } @@ -577,7 +467,7 @@ func (w *binaryWriter) beginValue() error { idlen := uint64(0) for i, a := range as { - id, err := w.resolve(a) + id, err := w.resolve(api, a) if err != nil { return err } @@ -594,6 +484,8 @@ func (w *binaryWriter) beginValue() error { buf = appendVarUint(buf, id) } + // TODO: We could theoretically write the actual tag here if we know the + // length of the value ahead of time. w.bufs.push(&container{code: 0xE0}) if err := w.write(buf); err != nil { return err @@ -617,8 +509,8 @@ func (w *binaryWriter) endValue() error { } // Begin begins writing a new container. -func (w *binaryWriter) begin(t ctx, code byte) error { - if err := w.beginValue(); err != nil { +func (w *binaryWriter) begin(api string, t ctx, code byte) error { + if err := w.beginValue(api); err != nil { return err } @@ -629,9 +521,9 @@ func (w *binaryWriter) begin(t ctx, code byte) error { } // End ends writing a container, emitting its buffered contents up a level in the stack. -func (w *binaryWriter) end(t ctx) error { +func (w *binaryWriter) end(api string, t ctx) error { if w.ctx.peek() != t { - return errors.New("ion: not in that kind of container") + return &UsageError{api, "not in that kind of container"} } seq := w.bufs.peek() @@ -647,3 +539,24 @@ func (w *binaryWriter) end(t ctx) error { return w.endValue() } + +// Resolve resolves a symbol to its ID. +func (w *binaryWriter) resolve(api, sym string) (uint64, error) { + if strings.HasPrefix(sym, "$") { + id, err := strconv.ParseUint(sym[1:], 10, 64) + if err == nil { + return id, nil + } + } + + if w.lst != nil { + id, ok := w.lst.FindByName(sym) + if !ok { + return 0, &UsageError{api, fmt.Sprintf("symbol '%v' not defined", sym)} + } + return id, nil + } + + id, _ := w.lstb.Add(sym) + return id, nil +} diff --git a/binarywriter_test.go b/binarywriter_test.go index db0f51ca..c1cbaaaa 100644 --- a/binarywriter_test.go +++ b/binarywriter_test.go @@ -139,12 +139,14 @@ func TestWriteBinarySymbol(t *testing.T) { 0x71, 0x04, // name 0x71, 0x05, // version 0x71, 0x09, // $ion_shared_symbol_table + 0x74, 0xFF, 0xFF, 0xFF, 0xFF, // $4294967295 } testBinaryWriter(t, eval, func(w Writer) { w.WriteSymbol("$ion") w.WriteSymbol("name") w.WriteSymbol("version") w.WriteSymbol("$ion_shared_symbol_table") + w.WriteSymbol("$4294967295") }) } diff --git a/bitstream.go b/bitstream.go index 34ddc21c..512aa109 100644 --- a/bitstream.go +++ b/bitstream.go @@ -4,7 +4,6 @@ import ( "bufio" "bytes" "encoding/binary" - "errors" "fmt" "io" "math" @@ -91,6 +90,7 @@ func (b bitcode) String() string { } } +// A bitstream is a low-level parser for binary Ion values. type bitstream struct { in *bufio.Reader pos uint64 @@ -102,28 +102,39 @@ type bitstream struct { len uint64 } +// Init initializes this stream with the given bufio.Reader. func (b *bitstream) Init(in *bufio.Reader) { b.in = in } +// InitBytes initializes this stream with the given bytes. func (b *bitstream) InitBytes(in []byte) { b.in = bufio.NewReader(bytes.NewReader(in)) } +// Code returns the typecode of the current value. func (b *bitstream) Code() bitcode { return b.code } -func (b *bitstream) Null() bool { +// IsNull returns true if the current value is null. +func (b *bitstream) IsNull() bool { return b.null } +// Pos returns the current position. +func (b *bitstream) Pos() uint64 { + return b.pos +} + +// Len returns the length of the current value. func (b *bitstream) Len() uint64 { return b.len } +// Next advances the stream to the next value. func (b *bitstream) Next() error { - // If we have an unread value, skip over it to the next one. + // If we have an unread value, skip over it to get to the next one. switch b.state { case bssOnValue, bssOnFieldID: if err := b.SkipValue(); err != nil { @@ -153,15 +164,16 @@ func (b *bitstream) Next() error { return err } - // Found the actual end of the file. + // Found the end of the file. if c == -1 { b.code = bitcodeEOF return nil } + // Parse the tag. code, len := parseTag(c) if code == bitcodeNone { - return fmt.Errorf("ion: invalid tag byte: 0x%X", c) + return &InvalidTagByteError{byte(c), b.pos - 1} } b.state = bssOnValue @@ -171,7 +183,7 @@ func (b *bitstream) Next() error { case 0: // This value is actually a BVM. It's invalid if we're not at the top level. if !b.stack.empty() { - return errors.New("ion: BVM in a container") + return &SyntaxError{"invalid BVM in a container", b.pos - 1} } b.code = bitcodeBVM b.len = 3 @@ -179,11 +191,11 @@ func (b *bitstream) Next() error { case 0x0F: // No such thing as a null annotation. - return fmt.Errorf("ion: invalid tag byte: 0x%X", c) + return &InvalidTagByteError{byte(c), b.pos - 1} } } - // Booleans are a bit special. + // Booleans are a bit special; the 'length' stores the value. if code == bitcodeFalse { switch len { case 0, 0x0F: @@ -192,8 +204,8 @@ func (b *bitstream) Next() error { code = bitcodeTrue len = 0 default: - // Other forms of bool are invalid. - return fmt.Errorf("ion: invalid tag byte: 0x%X", c) + // Other forms are invalid. + return &InvalidTagByteError{byte(c), b.pos - 1} } } @@ -204,12 +216,22 @@ func (b *bitstream) Next() error { return nil } + pos := b.pos + rem := b.remaining() + // This value's actual len is encoded as a separate varUint. if len == 0x0E { - len, err = b.readVarUint() + var lenlen uint64 + len, lenlen, err = b.readVarUintLen(rem) if err != nil { return err } + rem -= lenlen + } + + if len > rem { + msg := fmt.Sprintf("value overruns its container: %v vs %v", len, rem) + return &SyntaxError{msg, pos - 1} } b.code = code @@ -217,9 +239,11 @@ func (b *bitstream) Next() error { return nil } +// SkipValue skips over the current value. func (b *bitstream) SkipValue() error { switch b.state { case bssBeforeFieldID, bssBeforeValue: + // No current value to skip yet. return nil case bssOnFieldID: @@ -235,15 +259,16 @@ func (b *bitstream) SkipValue() error { } } b.state = b.stateAfterValue() - } - b.code = bitcodeNone - b.null = false - b.len = 0 + default: + panic(fmt.Sprintf("invalid state %v", b.state)) + } + b.clear() return nil } +// StepIn steps in to a container. func (b *bitstream) StepIn() { switch b.code { case bitcodeStruct: @@ -253,27 +278,28 @@ func (b *bitstream) StepIn() { b.state = bssBeforeValue default: - panic(fmt.Sprintf("called StepIn with code=%v", b.code)) + panic(fmt.Sprintf("StepIn called with b.code=%v", b.code)) } b.stack.push(b.code, b.pos+b.len) - b.code = bitcodeNone - b.len = 0 + b.clear() } +// StepOut steps out of a container. func (b *bitstream) StepOut() error { if b.stack.empty() { - panic("called StepOut at top level") + panic("StepOut called at top level") } cur := b.stack.peek() b.stack.pop() if cur.end < b.pos { - panic("end greater than b.pos") + panic(fmt.Sprintf("end (%v) greater than b.pos (%v)", cur.end, b.pos)) } - diff := cur.end - b.pos + + // Skip over anything left in the container we're stepping out of. if diff > 0 { if err := b.skip(diff); err != nil { return err @@ -281,56 +307,64 @@ func (b *bitstream) StepOut() error { } b.state = b.stateAfterValue() - b.code = bitcodeNone - b.null = false - b.len = 0 + b.clear() return nil } +// ReadBVM reads a binary version marker, returning its major and minor version. func (b *bitstream) ReadBVM() (byte, byte, error) { if b.code != bitcodeBVM { - return 0, 0, errors.New("ion: not a bvm") + panic("not a BVM") } - major, err := b.read() + major, err := b.read1() if err != nil { return 0, 0, err } - if major == -1 { - return 0, 0, errors.New("ion: unexpected end of input") - } - minor, err := b.read() + minor, err := b.read1() if err != nil { return 0, 0, err } - if minor == -1 { - return 0, 0, errors.New("ion: unexpected end of input") - } - end, err := b.read() + end, err := b.read1() if err != nil { return 0, 0, err } - if end == -1 { - return 0, 0, errors.New("ion: unexpected end of input") - } if end != 0xEA { - return 0, 0, fmt.Errorf("ion: invalid BVM (0xE0 0x%X 0x%X 0x%X)", major, minor, end) + msg := fmt.Sprintf("invalid BVM: 0xE0 0x%02X 0x%02X 0x%02X", major, minor, end) + return 0, 0, &SyntaxError{msg, b.pos - 4} } b.state = bssBeforeValue - b.code = bitcodeNone - b.len = 0 + b.clear() return byte(major), byte(minor), nil } -func (b *bitstream) ReadAnnotations() ([]uint64, error) { +// ReadFieldID reads a field ID. +func (b *bitstream) ReadFieldID() (uint64, error) { + if b.code != bitcodeFieldID { + panic("not a field ID") + } + + id, err := b.readVarUint() + if err != nil { + return 0, err + } + + b.state = bssBeforeValue + b.code = bitcodeNone + + return id, nil +} + +// ReadAnnotationIDs reads a set of annotation IDs. +func (b *bitstream) ReadAnnotationIDs() ([]uint64, error) { if b.code != bitcodeAnnotation { - return nil, errors.New("ion: not an annotation") + panic("not an annotation") } alen, lenlen, err := b.readVarUintLen(b.len) @@ -339,9 +373,9 @@ func (b *bitstream) ReadAnnotations() ([]uint64, error) { } if b.len-lenlen <= alen { - // The size of the annotation is larger than the remaining free space inside the + // The size of the annotations is larger than the remaining free space inside the // annotation container. - return nil, errors.New("ion: malformed annotation") + return nil, &SyntaxError{"malformed annotation", b.pos - lenlen} } as := []uint64{} @@ -356,33 +390,15 @@ func (b *bitstream) ReadAnnotations() ([]uint64, error) { } b.state = bssBeforeValue - b.code = bitcodeNone - b.len = 0 + b.clear() return as, nil } -func (b *bitstream) ReadFieldID() (uint64, error) { - if b.code != bitcodeFieldID { - return 0, errors.New("ion: not a field id") - } - - id, err := b.readVarUint() - if err != nil { - return 0, err - } - - b.state = bssBeforeValue - b.code = bitcodeNone - - return id, nil -} - +// ReadInt reads an integer value. func (b *bitstream) ReadInt() (interface{}, error) { - switch b.code { - case bitcodeInt, bitcodeNegInt: - default: - return "", errors.New("ion: not an integer") + if b.code != bitcodeInt && b.code != bitcodeNegInt { + panic("not an integer") } bs, err := b.readN(b.len) @@ -392,16 +408,16 @@ func (b *bitstream) ReadInt() (interface{}, error) { var ret interface{} switch { - case len(bs) == 0: + case b.len == 0: // Special case for zero. ret = int64(0) - case len(bs) < 8, (len(bs) == 8 && bs[0]&0x80 == 0): + case b.len < 8, (b.len == 8 && bs[0]&0x80 == 0): // It'll fit in an int64. i := int64(0) for _, b := range bs { i <<= 8 - i |= int64(b) + i ^= int64(b) } if b.code == bitcodeNegInt { i = -i @@ -418,15 +434,15 @@ func (b *bitstream) ReadInt() (interface{}, error) { } b.state = b.stateAfterValue() - b.code = bitcodeNone - b.len = 0 + b.clear() return ret, nil } +// ReadFloat reads a float value. func (b *bitstream) ReadFloat() (float64, error) { if b.code != bitcodeFloat { - return 0, errors.New("ion: not a float") + panic("not a float") } bs, err := b.readN(b.len) @@ -448,22 +464,19 @@ func (b *bitstream) ReadFloat() (float64, error) { ret = math.Float64frombits(ui) default: - return 0, errors.New("ion: invalid float size") + return 0, &SyntaxError{"invalid float size", b.pos - b.len} } b.state = b.stateAfterValue() - b.code = bitcodeNone - b.len = 0 + b.clear() return ret, nil } +// ReadDecimal reads a decimal value. func (b *bitstream) ReadDecimal() (*Decimal, error) { if b.code != bitcodeDecimal { - return nil, errors.New("ion: not a decimal") - } - if b.len == 0 { - return NewDecimalInt(0), nil + panic("not a decimal") } d, err := b.readDecimal(b.len) @@ -472,54 +485,65 @@ func (b *bitstream) ReadDecimal() (*Decimal, error) { } b.state = b.stateAfterValue() - b.code = bitcodeNone - b.len = 0 + b.clear() return d, nil } +// ReadTimestamp reads a timestamp value. func (b *bitstream) ReadTimestamp() (time.Time, error) { if b.code != bitcodeTimestamp { - return time.Time{}, errors.New("ion: not a timestamp") + panic("not a timestamp") } - offset, olen, err := b.readVarIntLen(b.len) + len := b.len + + offset, olen, err := b.readVarIntLen(len) if err != nil { return time.Time{}, err } - b.len -= olen + len -= olen ts := []int{1, 1, 1, 0, 0, 0} - for i := 0; b.len > 0 && i < 6; i++ { - val, vlen, err := b.readVarUintLen(b.len) + for i := 0; len > 0 && i < 6; i++ { + val, vlen, err := b.readVarUintLen(len) if err != nil { return time.Time{}, err } - b.len -= vlen + len -= vlen ts[i] = int(val) } - nsecs, err := b.readNsecs() + nsecs, err := b.readNsecs(len) if err != nil { return time.Time{}, err } b.state = b.stateAfterValue() - b.code = bitcodeNone - b.len = 0 + b.clear() utc := time.Date(ts[0], time.Month(ts[1]), ts[2], ts[3], ts[4], ts[5], int(nsecs), time.UTC) return utc.In(time.FixedZone("fixed", int(offset)*60)), nil } -func (b *bitstream) readNsecs() (int64, error) { - d, err := b.readDecimal(b.len) +// ReadNsecs reads the fraction part of a timestamp and truncates it to nanoseconds. +func (b *bitstream) readNsecs(len uint64) (int, error) { + d, err := b.readDecimal(len) if err != nil { return 0, err } - return d.ShiftL(9).Trunc() + + nsec, err := d.ShiftL(9).Trunc() + if err != nil || nsec < 0 || nsec > 999999999 { + msg := fmt.Sprintf("invalid timestamp fraction: %v", d) + return 0, &SyntaxError{msg, b.pos} + } + + return int(nsec), nil } +// ReadDecimal reads a decimal value of the given length: an exponent encoded as a +// varInt, followed by an integer coefficient taking up the remaining bytes. func (b *bitstream) readDecimal(len uint64) (*Decimal, error) { exp := int64(0) coef := new(big.Int) @@ -529,22 +553,33 @@ func (b *bitstream) readDecimal(len uint64) (*Decimal, error) { if err != nil { return nil, err } + + if val > math.MaxInt32 || val < math.MinInt32 { + msg := fmt.Sprintf("decimal exponent out of range: %v", val) + return nil, &SyntaxError{msg, b.pos - vlen} + } + exp = val len -= vlen } if len > 0 { - if err := b.readIntTo(len, coef); err != nil { + if err := b.readBigInt(len, coef); err != nil { return nil, err } } - return NewDecimal(coef, int(exp)), nil + return NewDecimal(coef, int32(exp)), nil } -func (b *bitstream) ReadSymbol() (uint64, error) { +// ReadSymbolID reads a symbol value. +func (b *bitstream) ReadSymbolID() (uint64, error) { if b.code != bitcodeSymbol { - return 0, errors.New("ion: not a symbol") + panic("not a symbol") + } + + if b.len > 8 { + return 0, &SyntaxError{"symbol id too large", b.pos} } bs, err := b.readN(b.len) @@ -553,27 +588,20 @@ func (b *bitstream) ReadSymbol() (uint64, error) { } b.state = b.stateAfterValue() - b.code = bitcodeNone - b.len = 0 - - if len(bs) == 0 { - return 0, nil - } - if len(bs) > 8 { - return 0, errors.New("ion: symbol id out of range") - } + b.clear() ret := uint64(0) for _, b := range bs { ret <<= 8 - ret |= uint64(b) + ret ^= uint64(b) } return ret, nil } +// ReadString reads a string value. func (b *bitstream) ReadString() (string, error) { if b.code != bitcodeString { - return "", errors.New("ion: not a string") + panic("not a string") } bs, err := b.readN(b.len) @@ -582,15 +610,15 @@ func (b *bitstream) ReadString() (string, error) { } b.state = b.stateAfterValue() - b.code = bitcodeNone - b.len = 0 + b.clear() return string(bs), nil } +// ReadBytes reads a blob or clob value. func (b *bitstream) ReadBytes() ([]byte, error) { if b.code != bitcodeClob && b.code != bitcodeBlob { - return nil, errors.New("ion: not a lob") + panic("not a lob") } bs, err := b.readN(b.len) @@ -599,13 +627,21 @@ func (b *bitstream) ReadBytes() ([]byte, error) { } b.state = b.stateAfterValue() - b.code = bitcodeNone - b.len = 0 + b.clear() return bs, nil } -func (b *bitstream) readIntTo(len uint64, ret *big.Int) error { +// Clear clears the current code and len. +func (b *bitstream) clear() { + b.code = bitcodeNone + b.null = false + b.len = 0 +} + +// ReadBigInt reads a fixed-length integer of the given length and stores +// the value in the given big.Int. +func (b *bitstream) readBigInt(len uint64, ret *big.Int) error { bs, err := b.readN(len) if err != nil { return err @@ -625,97 +661,137 @@ func (b *bitstream) readIntTo(len uint64, ret *big.Int) error { return nil } +// ReadVarUint reads a variable-length-encoded uint. func (b *bitstream) readVarUint() (uint64, error) { - r, _, err := b.readVarUintLen(10) - return r, err + val, _, err := b.readVarUintLen(b.remaining()) + return val, err } +// ReadVarUintLen reads a variable-length-encoded uint of at most max bytes, +// returning the value and its actual length in bytes. func (b *bitstream) readVarUintLen(max uint64) (uint64, uint64, error) { - r := uint64(0) - l := uint64(0) + if max > 10 { + max = 10 + } + + val := uint64(0) + len := uint64(0) for { - c, err := b.read() + if len >= max { + return 0, 0, &SyntaxError{"varuint too large", b.pos} + } + + c, err := b.read1() if err != nil { return 0, 0, err } - if c == -1 { - return 0, 0, errors.New("ion: unexpected end of input") - } - r <<= 7 - r ^= uint64(c & 0x7F) - l++ + val <<= 7 + val ^= uint64(c & 0x7F) + len++ if c&0x80 != 0 { - return r, l, nil + return val, len, nil } + } +} + +// SkipVarUint skips over a variable-length-encoded uint. +func (b *bitstream) skipVarUint() error { + _, err := b.skipVarUintLen(b.remaining()) + return err +} - if l == max { - return 0, 0, errors.New("ion: varuint too large") +// SkipVarUintLen skips over a variable-length-encoded uint of at most max bytes. +func (b *bitstream) skipVarUintLen(max uint64) (uint64, error) { + if max > 10 { + max = 10 + } + + len := uint64(0) + for { + if len >= max { + return 0, &SyntaxError{"varuint too large", b.pos - len} + } + + c, err := b.read1() + if err != nil { + return 0, err + } + + len++ + + if c&0x80 != 0 { + return len, nil } } } +// Remaining returns the number of bytes remaining in the current container. +func (b *bitstream) remaining() uint64 { + if b.stack.empty() { + return math.MaxUint64 + } + + end := b.stack.peek().end + if b.pos > end { + panic(fmt.Sprintf("pos (%v) > end (%v)", b.pos, end)) + } + + return end - b.pos +} + +// ReadVarIntLen reads a variable-length-encoded int of at most max bytes, +// returning the value and its actual length in bytes func (b *bitstream) readVarIntLen(max uint64) (int64, uint64, error) { - c, err := b.read() + if max == 0 { + return 0, 0, &SyntaxError{"varint too large", b.pos} + } + if max > 10 { + max = 10 + } + + // Read the first byte, which contains the sign bit. + c, err := b.read1() if err != nil { return 0, 0, err } - if c == -1 { - return 0, 0, errors.New("ion: unexpected end of input") - } sign := int64(1) if c&0x40 != 0 { sign = -1 } - r := int64(c & 0x3F) - l := uint64(1) + val := int64(c & 0x3F) + len := uint64(1) + // Check if that was the last (only) byte. if c&0x80 != 0 { - return r * sign, l, nil + return val * sign, len, nil } for { - c, err := b.read() - if err != nil { - return 0, 0, err + if len >= max { + return 0, 0, &SyntaxError{"varint too large", b.pos - len} } - if c == -1 { - return 0, 0, errors.New("ion: unexpected end of input") - } - - r <<= 7 - r ^= int64(c & 0x7F) - l++ - if c&0x80 != 0 { - return r * sign, l, nil + c, err := b.read1() + if err != nil { + return 0, 0, err } - if l == max { - return 0, 0, errors.New("ion: varint too large") - } - } -} + val <<= 7 + val ^= int64(c & 0x7F) + len++ -func (b *bitstream) skipVarUint() error { - for { - c, err := b.read() - if err != nil { - return err - } - if c == -1 { - return errors.New("ion: unexpected end of input") - } if c&0x80 != 0 { - return nil + return val * sign, len, nil } } } +// StateAfterValue returns the state this stream is in after reading a value. func (b *bitstream) stateAfterValue() bss { if b.stack.peek().code == bitcodeStruct { return bssBeforeFieldID @@ -741,6 +817,7 @@ var bitcodes = []bitcode{ bitcodeAnnotation, // 0xE0 } +// ParseTag parses a tag byte into a typecode and a length. func parseTag(c int) (bitcode, uint64) { high := (c >> 4) & 0x0F low := c & 0x0F @@ -753,59 +830,90 @@ func parseTag(c int) (bitcode, uint64) { return code, uint64(low) } +// ReadN reads the next n bytes of input from the underlying stream. func (b *bitstream) readN(n uint64) ([]byte, error) { if n == 0 { return nil, nil } bs := make([]byte, n) - _, err := b.in.Read(bs) + actual, err := b.in.Read(bs) + b.pos += uint64(actual) + if err == io.EOF { - return nil, errors.New("ion: unexpected end of input") + return nil, &UnexpectedEOFError{b.pos} } if err != nil { - return nil, err + return nil, &IOError{err} } - b.pos += n return bs, nil } +// Read1 reads the next byte of input from the underlying stream, returning +// an UnexpectedEOFError if it's an EOF. +func (b *bitstream) read1() (int, error) { + c, err := b.read() + if err != nil { + return 0, err + } + if c == -1 { + return 0, &UnexpectedEOFError{b.pos} + } + return c, nil +} + +// Read reads the next byte of input from the underlying stream. It returns +// -1 instead of io.EOF if we've hit the end of the stream, because I find +// that easier to reason about. func (b *bitstream) read() (int, error) { c, err := b.in.ReadByte() + b.pos++ + if err == io.EOF { return -1, nil } if err != nil { - return 0, err + return 0, &IOError{err} } - b.pos++ return int(c), nil } +// Skip skips n bytes of input from the underlying stream. func (b *bitstream) skip(n uint64) error { - _, err := b.in.Discard(int(n)) + actual, err := b.in.Discard(int(n)) + b.pos += uint64(actual) + if err == io.EOF { return nil } - b.pos += n - return err + if err != nil { + return &IOError{err} + } + + return nil } +// A bitnode represents a container value, including its typecode and +// the offset at which it (supposedly) ends. type bitnode struct { code bitcode end uint64 } +// A stack of bitnodes representing container values that we're currently +// stepped in to. type bitstack struct { arr []bitnode } +// Empty returns true if this bitstack is empty. func (b *bitstack) empty() bool { return len(b.arr) == 0 } +// Peek peeks at the top bitnode on the stack. func (b *bitstack) peek() bitnode { if len(b.arr) == 0 { return bitnode{} @@ -813,10 +921,12 @@ func (b *bitstack) peek() bitnode { return b.arr[len(b.arr)-1] } +// Push pushes a bitnode onto the stack. func (b *bitstack) push(code bitcode, end uint64) { b.arr = append(b.arr, bitnode{code, end}) } +// Pop pops a bitnode from the stack. func (b *bitstack) pop() { if len(b.arr) == 0 { panic("pop called on empty bitstack") diff --git a/bitstream_test.go b/bitstream_test.go index 2b8dd8db..8cbf57ab 100644 --- a/bitstream_test.go +++ b/bitstream_test.go @@ -34,8 +34,8 @@ func TestBitstream(t *testing.T) { if b.Code() != code { t.Errorf("expected code=%v, got %v", code, b.Code()) } - if b.Null() != null { - t.Errorf("expected null=%v, got %v", null, b.Null()) + if b.IsNull() != null { + t.Errorf("expected null=%v, got %v", null, b.IsNull()) } if b.Len() != len { t.Errorf("expected len=%v, got %v", len, b.Len()) @@ -62,7 +62,7 @@ func TestBitstream(t *testing.T) { } next(bitcodeAnnotation, false, 31) - ids, err := b.ReadAnnotations() + ids, err := b.ReadAnnotationIDs() if err != nil { t.Fatal(err) } diff --git a/catalog.go b/catalog.go index 60a96d18..65d4f88f 100644 --- a/catalog.go +++ b/catalog.go @@ -9,17 +9,22 @@ import ( // A Catalog provides access to shared symbol tables. type Catalog interface { - Find(name string, version int) SharedSymbolTable + FindExact(name string, version int) SharedSymbolTable + FindLatest(name string) SharedSymbolTable } // A basicCatalog wraps an in-memory collection of shared symbol tables. type basicCatalog struct { - ssts map[string]SharedSymbolTable + ssts map[string]SharedSymbolTable + latest map[string]SharedSymbolTable } // NewCatalog creates a new basic catalog containing the given symbol tables. func NewCatalog(ssts ...SharedSymbolTable) Catalog { - cat := &basicCatalog{make(map[string]SharedSymbolTable)} + cat := &basicCatalog{ + ssts: make(map[string]SharedSymbolTable), + latest: make(map[string]SharedSymbolTable), + } for _, sst := range ssts { cat.add(sst) } @@ -30,14 +35,24 @@ func NewCatalog(ssts ...SharedSymbolTable) Catalog { func (c *basicCatalog) add(sst SharedSymbolTable) { key := fmt.Sprintf("%v/%v", sst.Name(), sst.Version()) c.ssts[key] = sst + + cur, ok := c.latest[sst.Name()] + if !ok || sst.Version() > cur.Version() { + c.latest[sst.Name()] = sst + } } -// Find attempts to find a shared symbol table with the given name and version. -func (c *basicCatalog) Find(name string, version int) SharedSymbolTable { +// FindExact attempts to find a shared symbol table with the given name and version. +func (c *basicCatalog) FindExact(name string, version int) SharedSymbolTable { key := fmt.Sprintf("%v/%v", name, version) return c.ssts[key] } +// FindLatest finds the shared symbol table with the given name and largest version. +func (c *basicCatalog) FindLatest(name string) SharedSymbolTable { + return c.latest[name] +} + // A System is a reader factory wrapping a catalog. type System struct { Catalog Catalog diff --git a/decimal.go b/decimal.go index f8ad6cd5..1b522286 100644 --- a/decimal.go +++ b/decimal.go @@ -1,7 +1,6 @@ package ion import ( - "errors" "fmt" "math" "math/big" @@ -9,16 +8,27 @@ import ( "strings" ) +// A ParseError is returned if ParseDecimal is called with a parameter that +// cannot be parsed as a Decimal. +type ParseError struct { + Num string + Msg string +} + +func (e *ParseError) Error() string { + return fmt.Sprintf("ion: ParseDecimal(%v): %v", e.Num, e.Msg) +} + // TODO: Explicitly track precision? // Decimal is an arbitrary-precision decimal value. type Decimal struct { n *big.Int - scale int + scale int32 } // NewDecimal creates a new decimal whose value is equal to n * 10^exp. -func NewDecimal(n *big.Int, exp int) *Decimal { +func NewDecimal(n *big.Int, exp int32) *Decimal { return &Decimal{ n: n, scale: -exp, @@ -44,25 +54,25 @@ func MustParseDecimal(in string) *Decimal { // returning an error on failure. func ParseDecimal(in string) (*Decimal, error) { if len(in) == 0 { - return nil, errors.New("empty string") + return nil, &ParseError{in, "empty string"} } - exponent := 0 + exponent := int32(0) d := strings.IndexAny(in, "Dd") if d != -1 { // There's an explicit exponent. exp := in[d+1:] if len(exp) == 0 { - return nil, errors.New("unexpected end of input after d") + return nil, &ParseError{in, "unexpected end of input after d"} } tmp, err := strconv.ParseInt(exp, 10, 32) if err != nil { - return nil, err + return nil, &ParseError{in, err.Error()} } - exponent = int(tmp) + exponent = int32(tmp) in = in[:d] } @@ -72,21 +82,21 @@ func ParseDecimal(in string) (*Decimal, error) { ipart := in[:d] fpart := in[d+1:] - exponent -= len(fpart) + exponent -= int32(len(fpart)) in = ipart + fpart } n, ok := new(big.Int).SetString(in, 10) if !ok { // Unfortunately this is all we get? - return nil, fmt.Errorf("not a valid number: %v", in) + return nil, &ParseError{in, "cannot parse coefficient"} } return NewDecimal(n, exponent), nil } // CoEx returns this decimal's coefficient and exponent. -func (d *Decimal) CoEx() (*big.Int, int) { +func (d *Decimal) CoEx() (*big.Int, int32) { return d.n, -d.scale } @@ -135,7 +145,7 @@ func (d *Decimal) Mul(o *Decimal) *Decimal { return &Decimal{ n: new(big.Int).Mul(d.n, o.n), - scale: int(scale), + scale: int32(scale), } } @@ -150,7 +160,7 @@ func (d *Decimal) ShiftL(shift int) *Decimal { return &Decimal{ n: d.n, - scale: int(scale), + scale: int32(scale), } } @@ -165,7 +175,7 @@ func (d *Decimal) ShiftR(shift int) *Decimal { return &Decimal{ n: d.n, - scale: int(scale), + scale: int32(scale), } } @@ -206,7 +216,7 @@ var ten = big.NewInt(10) // do that. (1d100 -> 10d99). Makes comparisons and math easier, at the // expense of more storage space. Technically speaking implies adding // more precision, but we're not tracking that too closely. -func (d *Decimal) upscale(scale int) *Decimal { +func (d *Decimal) upscale(scale int32) *Decimal { diff := int64(scale) - int64(d.scale) if diff < 0 { panic("can't upscale to a smaller scale") @@ -221,22 +231,29 @@ func (d *Decimal) upscale(scale int) *Decimal { } } -// Trunc attempts to truncate this decimal to an int64. Use at your own risk. +// Trunc attempts to truncate this decimal to an int64, dropping any fractional bits. func (d *Decimal) Trunc() (int64, error) { if d.scale < 0 { - // TODO: safety in case scale is very small? + // Don't even bother trying this with numbers that *definitely* too big to represent + // as an int64, because upscale(0) will consume a bunch of memory. + if d.scale < -20 { + return 0, &strconv.NumError{ + Func: "ParseInt", + Num: d.String(), + Err: strconv.ErrRange, + } + } d = d.upscale(0) } str := d.n.String() - want := len(str) - d.scale + want := len(str) - int(d.scale) if want <= 0 { return 0, nil } - trunc := str[:want] - return strconv.ParseInt(trunc, 10, 64) + return strconv.ParseInt(str[:want], 10, 64) } // Truncate returns a new decimal, truncated to the given number of @@ -276,7 +293,7 @@ func (d *Decimal) Truncate(precision int) *Decimal { return &Decimal{ n: n, - scale: int(scale), + scale: int32(scale), } } @@ -294,7 +311,7 @@ func (d *Decimal) String() string { default: // Value is a downscaled integer nn.nn('d'-ss)? str := d.n.String() - idx := len(str) - d.scale + idx := len(str) - int(d.scale) prefix := 1 if d.n.Sign() < 0 { diff --git a/decimal_test.go b/decimal_test.go index 6db5a2c5..21b72c9e 100644 --- a/decimal_test.go +++ b/decimal_test.go @@ -7,7 +7,7 @@ import ( ) func TestDecimalToString(t *testing.T) { - test := func(n int64, scale int, expected string) { + test := func(n int64, scale int32, expected string) { t.Run(expected, func(t *testing.T) { d := Decimal{ n: big.NewInt(n), @@ -50,7 +50,7 @@ func TestDecimalToString(t *testing.T) { } func TestParseDecimal(t *testing.T) { - test := func(in string, n *big.Int, scale int) { + test := func(in string, n *big.Int, scale int32) { t.Run(in, func(t *testing.T) { d, err := ParseDecimal(in) if err != nil { diff --git a/err.go b/err.go new file mode 100644 index 00000000..737076a9 --- /dev/null +++ b/err.go @@ -0,0 +1,88 @@ +package ion + +import "fmt" + +// A UsageError is returned when you use a Reader or Writer in an inappropriate way. +type UsageError struct { + API string + Msg string +} + +func (e *UsageError) Error() string { + return fmt.Sprintf("ion: usage error in %v: %v", e.API, e.Msg) +} + +// An IOError is returned when there is an error reading from or writing to an +// underlying io.Reader or io.Writer. +type IOError struct { + Err error +} + +func (e *IOError) Error() string { + return fmt.Sprintf("ion: i/o error: %v", e.Err) +} + +// A SyntaxError is returned when a Reader encounters invalid input for which no more +// specific error type is defined. +type SyntaxError struct { + Msg string + Offset uint64 +} + +func (e *SyntaxError) Error() string { + return fmt.Sprintf("ion: syntax error: %v (offset %v)", e.Msg, e.Offset) +} + +// An UnexpectedEOFError is returned when a Reader unexpectedly encounters an +// io.EOF error. +type UnexpectedEOFError struct { + Offset uint64 +} + +func (e *UnexpectedEOFError) Error() string { + return fmt.Sprintf("ion: unexpected end of input (offset %v)", e.Offset) +} + +// An UnsupportedVersionError is returned when a Reader encounters a binary version +// marker with a version that this library does not understand. +type UnsupportedVersionError struct { + Major int + Minor int + Offset uint64 +} + +func (e *UnsupportedVersionError) Error() string { + return fmt.Sprintf("ion: unsupported version %v.%v (offset %v)", e.Major, e.Minor, e.Offset) +} + +// An InvalidTagByteError is returned when a binary Reader encounters an invalid +// tag byte. +type InvalidTagByteError struct { + Byte byte + Offset uint64 +} + +func (e *InvalidTagByteError) Error() string { + return fmt.Sprintf("ion: invalid tag byte 0x%02X (offset %v)", e.Byte, e.Offset) +} + +// An UnexpectedRuneError is returned when a text Reader encounters an unexpected rune. +type UnexpectedRuneError struct { + Rune rune + Offset uint64 +} + +func (e *UnexpectedRuneError) Error() string { + return fmt.Sprintf("ion: unexpected rune %q (offset %v)", e.Rune, e.Offset) +} + +// An UnexpectedTokenError is returned when a text Reader encounters an unexpected +// token. +type UnexpectedTokenError struct { + Token string + Offset uint64 +} + +func (e *UnexpectedTokenError) Error() string { + return fmt.Sprintf("ion: unexpected token '%v' (offset %v)", e.Token, e.Offset) +} diff --git a/reader.go b/reader.go index 2463c8db..9f16c5af 100644 --- a/reader.go +++ b/reader.go @@ -3,7 +3,6 @@ package ion import ( "bufio" "bytes" - "errors" "io" "math" "math/big" @@ -121,6 +120,11 @@ type Reader interface { // 64 bits to represent losslessly. Int64Value() (int64, error) + // Uint64Value returns the current value as an unsigned 64-bit integer (if that makes + // sense). It returns an error if the current value is not an Ion integer, is negative, + // or requires more than 64 bits to represent losslessly. + Uint64Value() (uint64, error) + // BigIntValue returns the current value as a big.Integer (if that makes sense). It // returns an error if the current value is not an Ion integer. BigIntValue() (*big.Int, error) @@ -213,19 +217,19 @@ func (r *reader) Annotations() []string { // BoolValue returns the current value as a bool. func (r *reader) BoolValue() (bool, error) { - if r.valueType == BoolType { - if r.value == nil { - return false, nil - } - return r.value.(bool), nil + if r.valueType != BoolType { + return false, &UsageError{"Reader.BoolValue", "value is not a bool"} + } + if r.value == nil { + return false, nil } - return false, errors.New("ion: value is not a bool") + return r.value.(bool), nil } // IntSize returns the size of the current int value. func (r *reader) IntSize() (IntSize, error) { if r.valueType != IntType { - return NullInt, errors.New("ion: value is not an int") + return NullInt, &UsageError{"Reader.IntSize", "value is not a int"} } if r.value == nil { return NullInt, nil @@ -238,6 +242,11 @@ func (r *reader) IntSize() (IntSize, error) { return Int32, nil } + i := r.value.(*big.Int) + if i.IsUint64() { + return Uint64, nil + } + return BigInt, nil } @@ -248,101 +257,129 @@ func (r *reader) IntValue() (int, error) { return 0, err } if i > math.MaxInt32 || i < math.MinInt32 { - return 0, errors.New("ion: int value out of bounds") + return 0, &UsageError{"Reader.IntValue", "value too large for an int32"} } return int(i), nil } // Int64Value returns the current value as an int64. func (r *reader) Int64Value() (int64, error) { - if r.valueType == IntType { - if r.value == nil { - return 0, nil - } + if r.valueType != IntType { + return 0, &UsageError{"Reader.Int64Value", "value is not an int"} + } + if r.value == nil { + return 0, nil + } - if i, ok := r.value.(int64); ok { - return i, nil - } + if i, ok := r.value.(int64); ok { + return i, nil + } + + bi := r.value.(*big.Int) + if bi.IsInt64() { + return bi.Int64(), nil + } + + return 0, &UsageError{"Reader.Int64Value", "value too large for an int64"} +} + +// Uint64Value returns the current value as a uint64. +func (r *reader) Uint64Value() (uint64, error) { + if r.valueType != IntType { + return 0, &UsageError{"Reader.Uint64Value", "value is not an int"} + } + if r.value == nil { + return 0, nil + } - bi := r.value.(*big.Int) - if bi.IsInt64() { - return bi.Int64(), nil + if i, ok := r.value.(int64); ok { + if i >= 0 { + return uint64(i), nil } + return 0, &UsageError{"Reader.Uint64Value", "value is negative"} + } - return 0, errors.New("ion: int value out of bounds") + bi := r.value.(*big.Int) + if bi.Sign() < 0 { + return 0, &UsageError{"Reader.Uint64Value", "value is negative"} } - return 0, errors.New("ion: value is not an int") + if !bi.IsUint64() { + return 0, &UsageError{"Reader.Uint64Value", "value too large for a uint64"} + } + return bi.Uint64(), nil } // BigIntValue returns the current value as a big int. func (r *reader) BigIntValue() (*big.Int, error) { - if r.valueType == IntType { - if r.value == nil { - return nil, nil - } - if i, ok := r.value.(int64); ok { - return big.NewInt(i), nil - } - return r.value.(*big.Int), nil + if r.valueType != IntType { + return nil, &UsageError{"Reader.BigIntValue", "value is not an int"} + } + if r.value == nil { + return nil, nil } - return nil, errors.New("ion: value is not an int") + + if i, ok := r.value.(int64); ok { + return big.NewInt(i), nil + } + return r.value.(*big.Int), nil } // FloatValue returns the current value as a float. func (r *reader) FloatValue() (float64, error) { - if r.valueType == FloatType { - if r.value == nil { - return 0.0, nil - } - return r.value.(float64), nil + if r.valueType != FloatType { + return 0, &UsageError{"Reader.FloatValue", "value is not a float"} + } + if r.value == nil { + return 0.0, nil } - return 0.0, errors.New("ion: value is not a float") + return r.value.(float64), nil } // DecimalValue returns the current value as a Decimal. func (r *reader) DecimalValue() (*Decimal, error) { - if r.valueType == DecimalType { - if r.value == nil { - return nil, nil - } - return r.value.(*Decimal), nil + if r.valueType != DecimalType { + return nil, &UsageError{"Reader.DecimalValue", "value is not a decimal"} + } + if r.value == nil { + return nil, nil } - return nil, errors.New("ion: value is not a decimal") + return r.value.(*Decimal), nil } // TimeValue returns the current value as a time. func (r *reader) TimeValue() (time.Time, error) { - if r.valueType == TimestampType { - if r.value == nil { - return time.Time{}, nil - } - return r.value.(time.Time), nil + if r.valueType != TimestampType { + return time.Time{}, &UsageError{"Reader.TimestampValue", "value is not a timestamp"} + } + if r.value == nil { + return time.Time{}, nil } - return time.Time{}, errors.New("ion: value is not a timestamp") + return r.value.(time.Time), nil } // StringValue returns the current value as a string. func (r *reader) StringValue() (string, error) { - if r.valueType == StringType || r.valueType == SymbolType { - if r.value == nil { - return "", nil - } - return r.value.(string), nil + if r.valueType != StringType && r.valueType != SymbolType { + return "", &UsageError{"Reader.StringValue", "value is not a string"} + } + if r.value == nil { + return "", nil } - return "", errors.New("ion: value is not a string") + return r.value.(string), nil } // ByteValue returns the current value as a byte slice. func (r *reader) ByteValue() ([]byte, error) { - if r.valueType == BlobType || r.valueType == ClobType { - if r.value == nil { - return nil, nil - } - return r.value.([]byte), nil + if r.valueType != BlobType && r.valueType != ClobType { + return nil, &UsageError{"Reader.ByteValue", "value is not a lob"} + } + if r.value == nil { + return nil, nil } - return nil, errors.New("ion: value is not a byte array") + return r.value.([]byte), nil } +// Clear clears the current value from the reader. func (r *reader) clear() { r.fieldName = "" r.annotations = nil diff --git a/skipper.go b/skipper.go index 2ce458bf..cf74ebb7 100644 --- a/skipper.go +++ b/skipper.go @@ -91,7 +91,7 @@ func (t *tokenizer) skipValue() (int, error) { case tokenOpenBracket: c, err = t.skipList() default: - err = fmt.Errorf("skipValue called with token=%v", t.token) + panic(fmt.Sprintf("skipValue called with token=%v", t.token)) } if err != nil { @@ -161,7 +161,7 @@ func (t *tokenizer) skipNumber() (int, error) { return 0, err } if !ok { - return 0, invalidChar(c) + return 0, t.invalidChar(c) } return c, nil } @@ -199,7 +199,7 @@ func (t *tokenizer) skipRadix(pok, dok matcher) (int, error) { } if c != '0' { - return 0, invalidChar(c) + return 0, t.invalidChar(c) } if err = t.expect(pok); err != nil { return 0, err @@ -220,7 +220,7 @@ func (t *tokenizer) skipRadix(pok, dok matcher) (int, error) { return 0, err } if !ok { - return 0, invalidChar(c) + return 0, t.invalidChar(c) } return c, nil @@ -238,7 +238,7 @@ func (t *tokenizer) skipTimestamp() (int, error) { return t.read() } if c != '-' { - return 0, invalidChar(c) + return 0, t.invalidChar(c) } // Read the next two, yyyy-mm. @@ -251,7 +251,7 @@ func (t *tokenizer) skipTimestamp() (int, error) { return t.read() } if c != '-' { - return 0, invalidChar(c) + return 0, t.invalidChar(c) } // Read the day. @@ -283,7 +283,7 @@ func (t *tokenizer) skipTimestamp() (int, error) { return 0, err } if c != ':' { - return 0, invalidChar(c) + return 0, t.invalidChar(c) } c, err = t.skipTimestampDigits(2) @@ -340,7 +340,7 @@ func (t *tokenizer) skipTimestampOffsetOrZ(c int) (int, error) { if c == 'z' || c == 'Z' { return t.read() } - return 0, invalidChar(c) + return 0, t.invalidChar(c) } // SkipTimestampOffset skips an (optional) +-hh:mm timestamp zone offset @@ -355,7 +355,7 @@ func (t *tokenizer) skipTimestampOffset(c int) (int, error) { return 0, err } if c != ':' { - return 0, invalidChar(c) + return 0, t.invalidChar(c) } return t.skipTimestampDigits(2) } @@ -383,7 +383,7 @@ func (t *tokenizer) skipTimestampFinish(c int) (int, error) { return 0, err } if !ok { - return 0, invalidChar(c) + return 0, t.invalidChar(c) } return c, nil } @@ -423,7 +423,7 @@ func (t *tokenizer) skipSymbolQuotedHelper() error { switch c { case -1, '\n': - return invalidChar(c) + return t.invalidChar(c) case '\'': return nil @@ -471,7 +471,7 @@ func (t *tokenizer) skipStringHelper() error { switch c { case -1, '\n': - return invalidChar(c) + return t.invalidChar(c) case '"': return nil @@ -503,7 +503,7 @@ func (t *tokenizer) skipLongStringHelper(handler commentHandler) error { switch c { case -1: - return invalidChar(c) + return t.invalidChar(c) case '\'': ok, err := t.skipEndOfLongString(handler) @@ -586,7 +586,7 @@ func (t *tokenizer) skipBlobHelper() error { return err } if c == -1 { - return invalidChar(c) + return t.invalidChar(c) } } @@ -645,7 +645,7 @@ func (t *tokenizer) skipContainerHelper(term int) error { switch c { case -1: - return invalidChar(c) + return t.invalidChar(c) case term: return nil @@ -832,7 +832,7 @@ func (t *tokenizer) skipBlockComment() error { return err } if c == -1 { - return invalidChar(c) + return t.invalidChar(c) } if star && c == '/' { diff --git a/skipper_test.go b/skipper_test.go index 743dcddf..afb48874 100644 --- a/skipper_test.go +++ b/skipper_test.go @@ -14,7 +14,7 @@ func TestSkipNumber(t *testing.T) { test("1d45\n", '\n') test("1.4e-12//", '/') - testErr("1.2d3d", "unexpected char 'd'") + testErr("1.2d3d", "ion: unexpected rune 'd' (offset 5)") } func TestSkipBinary(t *testing.T) { @@ -24,7 +24,7 @@ func TestSkipBinary(t *testing.T) { test("-0b10 ", ' ') test("0b010101,", ',') - testErr("0b2", "unexpected char '2'") + testErr("0b2", "ion: unexpected rune '2' (offset 2)") } func TestSkipHex(t *testing.T) { @@ -34,7 +34,7 @@ func TestSkipHex(t *testing.T) { test("-0x0F ", ' ') test("0x1234567890abcdefABCDEF,", ',') - testErr("0x0G", "unexpected char 'G'") + testErr("0x0G", "ion: unexpected rune 'G' (offset 3)") } func TestSkipTimestamp(t *testing.T) { @@ -55,17 +55,17 @@ func TestSkipTimestamp(t *testing.T) { test("2001-01-02T03:04:05.666Z ", ' ') test("2001-01-02T03:04:05.666666z ", ' ') - testErr("", "unexpected EOF") - testErr("2001", "unexpected EOF") - testErr("2001z", "unexpected char 'z'") - testErr("20011", "unexpected char '1'") - testErr("2001-0", "unexpected EOF") - testErr("2001-01", "unexpected EOF") - testErr("2001-01-02Tz", "unexpected char 'z'") - testErr("2001-01-02T03", "unexpected EOF") - testErr("2001-01-02T03z", "unexpected char 'z'") - testErr("2001-01-02T03:04x ", "unexpected char 'x'") - testErr("2001-01-02T03:04:05x ", "unexpected char 'x'") + testErr("", "ion: unexpected end of input (offset 0)") + testErr("2001", "ion: unexpected end of input (offset 4)") + testErr("2001z", "ion: unexpected rune 'z' (offset 4)") + testErr("20011", "ion: unexpected rune '1' (offset 4)") + testErr("2001-0", "ion: unexpected end of input (offset 6)") + testErr("2001-01", "ion: unexpected end of input (offset 7)") + testErr("2001-01-02Tz", "ion: unexpected rune 'z' (offset 11)") + testErr("2001-01-02T03", "ion: unexpected end of input (offset 13)") + testErr("2001-01-02T03z", "ion: unexpected rune 'z' (offset 13)") + testErr("2001-01-02T03:04x ", "ion: unexpected rune 'x' (offset 16)") + testErr("2001-01-02T03:04:05x ", "ion: unexpected rune 'x' (offset 19)") } func TestSkipSymbol(t *testing.T) { @@ -90,8 +90,8 @@ func TestSkipSymbolQuoted(t *testing.T) { test("foo\\'bar':", ':') test("foo\\\nbar',", ',') - testErr("foo", "unexpected EOF") - testErr("foo\n", "unexpected char '\\n'") + testErr("foo", "ion: unexpected end of input (offset 3)") + testErr("foo\n", "ion: unexpected rune '\\n' (offset 3)") } func TestSkipSymbolOperator(t *testing.T) { @@ -111,8 +111,8 @@ func TestSkipString(t *testing.T) { test("foo\\\"bar\"], \"\"", ']') test("foo\\\nbar\" \t\t\t", ' ') - testErr("foobar", "unexpected EOF") - testErr("foobar\n", "unexpected char '\\n'") + testErr("foobar", "ion: unexpected end of input (offset 6)") + testErr("foobar\n", "ion: unexpected rune '\\n' (offset 6)") } func TestSkipLongString(t *testing.T) { @@ -132,10 +132,10 @@ func TestSkipBlob(t *testing.T) { test("oogboog}},{{}}", ',') test("'''not encoded'''}}\n", '\n') - testErr("", "unexpected EOF") - testErr("oogboog", "unexpected EOF") - testErr("oogboog}", "unexpected EOF") - testErr("oog}{boog", "unexpected char '{'") + testErr("", "ion: unexpected end of input (offset 1)") + testErr("oogboog", "ion: unexpected end of input (offset 7)") + testErr("oogboog}", "ion: unexpected end of input (offset 8)") + testErr("oog}{boog", "ion: unexpected rune '{' (offset 4)") } func TestSkipList(t *testing.T) { @@ -145,7 +145,7 @@ func TestSkipList(t *testing.T) { test("[]],", ',') test("[123, \"]\", ']']] ", ' ') - testErr("abc, def, ", "unexpected EOF") + testErr("abc, def, ", "ion: unexpected end of input (offset 10)") } type skipFunc func(*tokenizer) (int, error) diff --git a/symboltable.go b/symboltable.go index e87b239d..922a1dc3 100644 --- a/symboltable.go +++ b/symboltable.go @@ -1,19 +1,23 @@ package ion import ( - "errors" "strings" ) // A SymbolTable maps binary-representation symbol IDs to // text-representation strings and vice versa. type SymbolTable interface { + // Imports returns the symbol tables this table imports. + Imports() []SharedSymbolTable + // Symbols returns the symbols this symbol table defines. + Symbols() []string // MaxID returns the maximum ID this symbol table defines. - MaxID() int + MaxID() uint64 + // FindByName finds the ID of a symbol by its name. - FindByName(symbol string) (int, bool) + FindByName(symbol string) (uint64, bool) // FindByID finds the name of a symbol given its ID. - FindByID(id int) (string, bool) + FindByID(id uint64) (string, bool) // WriteTo serializes the symbol table to an ion.Writer. WriteTo(w Writer) error // String returns an ion text representation of the symbol table. @@ -21,54 +25,42 @@ type SymbolTable interface { } // A SharedSymbolTable is distributed out-of-band and referenced from -// a LocalSymbolTable to save space. +// a local SymbolTable to save space. type SharedSymbolTable interface { SymbolTable + // Name returns the name of this shared symbol table. Name() string + // Version returns the version of this shared symbol table. Version() int + // Adjust returns a new shared symbol table limited or extended to the given max ID. + Adjust(maxID uint64) SharedSymbolTable } type sst struct { name string version int symbols []string - index map[string]int + index map[string]uint64 + maxID uint64 } // NewSharedSymbolTable creates a new shared symbol table. func NewSharedSymbolTable(name string, version int, symbols []string) SharedSymbolTable { - if name == "" { - panic("name must be non-empty") - } - if version < 1 { - panic("version must be at least one") - } + syms := make([]string, len(symbols)) + copy(syms, symbols) - index, copy := buildIndex(symbols, 0) + index := buildIndex(syms, 1) return &sst{ name: name, version: version, - symbols: copy, + symbols: syms, index: index, + maxID: uint64(len(syms)), } } -func buildIndex(symbols []string, offset int) (map[string]int, []string) { - index := map[string]int{} - copy := []string{} - - for _, sym := range symbols { - if _, ok := index[sym]; !ok { - copy = append(copy, sym) - index[sym] = offset + len(copy) - } - } - - return index, copy -} - func (s *sst) Name() string { return s.name } @@ -77,17 +69,57 @@ func (s *sst) Version() int { return s.version } -func (s *sst) MaxID() int { - return len(s.symbols) +func (s *sst) Imports() []SharedSymbolTable { + return nil } -func (s *sst) FindByName(sym string) (int, bool) { +func (s *sst) Symbols() []string { + syms := make([]string, s.maxID) + copy(syms, s.symbols) + return syms +} + +func (s *sst) MaxID() uint64 { + return uint64(s.maxID) +} + +func (s *sst) Adjust(maxID uint64) SharedSymbolTable { + if maxID == s.maxID { + // Nothing needs to change. + return s + } + + if maxID > uint64(len(s.symbols)) { + // Old index will work fine, just adjust the maxID. + return &sst{ + name: s.name, + version: s.version, + symbols: s.symbols, + index: s.index, + maxID: maxID, + } + } + + // Slice the symbols down to size and reindex. + symbols := s.symbols[:maxID] + index := buildIndex(symbols, 1) + + return &sst{ + name: s.name, + version: s.version, + symbols: symbols, + index: index, + maxID: maxID, + } +} + +func (s *sst) FindByName(sym string) (uint64, bool) { id, ok := s.index[sym] - return id, ok + return uint64(id), ok } -func (s *sst) FindByID(id int) (string, bool) { - if id <= 0 || id > len(s.symbols) { +func (s *sst) FindByID(id uint64) (string, bool) { + if id <= 0 || id > uint64(len(s.symbols)) { return "", false } return s.symbols[id-1], true @@ -143,9 +175,11 @@ var V1SystemSymbolTable = NewSharedSymbolTable("$ion", 1, []string{ type bogusSST struct { name string version int - maxID int + maxID uint64 } +var _ SharedSymbolTable = &bogusSST{} + func (s *bogusSST) Name() string { return s.name } @@ -154,20 +188,36 @@ func (s *bogusSST) Version() int { return s.version } -func (s *bogusSST) MaxID() int { +func (s *bogusSST) Imports() []SharedSymbolTable { + return nil +} + +func (s *bogusSST) Symbols() []string { + return nil +} + +func (s *bogusSST) MaxID() uint64 { return s.maxID } -func (s *bogusSST) FindByName(sym string) (int, bool) { +func (s *bogusSST) Adjust(maxID uint64) SharedSymbolTable { + return &bogusSST{ + name: s.name, + version: s.version, + maxID: maxID, + } +} + +func (s *bogusSST) FindByName(sym string) (uint64, bool) { return 0, false } -func (s *bogusSST) FindByID(id int) (string, bool) { +func (s *bogusSST) FindByID(id uint64) (string, bool) { return "", false } func (s *bogusSST) WriteTo(w Writer) error { - return errors.New("ion: bogusSST does not implement WriteTo") + return &UsageError{"SharedSymbolTable.WriteTo", "bogus symbol table should never be written"} } func (s *bogusSST) String() string { @@ -183,7 +233,7 @@ func (s *bogusSST) String() string { w.WriteInt(int64(s.version)) w.FieldName("max_id") - w.WriteInt(int64(s.maxID)) + w.WriteUint(s.maxID) w.EndStruct() return buf.String() @@ -193,53 +243,47 @@ func (s *bogusSST) String() string { // it describes. It may include SharedSymbolTables by reference. type lst struct { imports []SharedSymbolTable - offsets []int - maxImportID int + offsets []uint64 + maxImportID uint64 symbols []string - index map[string]int + index map[string]uint64 } // NewLocalSymbolTable creates a new local symbol table. func NewLocalSymbolTable(imports []SharedSymbolTable, symbols []string) SymbolTable { imps, offsets, maxID := processImports(imports) - index, copy := buildIndex(symbols, maxID) + syms := make([]string, len(symbols)) + copy(syms, symbols) + + index := buildIndex(syms, maxID+1) return &lst{ imports: imps, offsets: offsets, maxImportID: maxID, - symbols: copy, + symbols: syms, index: index, } } -func processImports(imports []SharedSymbolTable) ([]SharedSymbolTable, []int, int) { - var imps []SharedSymbolTable - if len(imports) > 0 && imports[0].Name() == "$ion" { - imps = make([]SharedSymbolTable, len(imports)) - copy(imps, imports) - } else { - imps = make([]SharedSymbolTable, len(imports)+1) - imps[0] = V1SystemSymbolTable - copy(imps[1:], imports) - } - - maxID := 0 - offsets := make([]int, len(imps)) - for i, imp := range imps { - offsets[i] = maxID - maxID += imp.MaxID() - } +func (t *lst) Imports() []SharedSymbolTable { + imps := make([]SharedSymbolTable, len(t.imports)) + copy(imps, t.imports) + return imps +} - return imps, offsets, maxID +func (t *lst) Symbols() []string { + syms := make([]string, len(t.symbols)) + copy(syms, t.symbols) + return syms } -func (t *lst) MaxID() int { - return t.maxImportID + len(t.symbols) +func (t *lst) MaxID() uint64 { + return t.maxImportID + uint64(len(t.symbols)) } -func (t *lst) FindByName(s string) (int, bool) { +func (t *lst) FindByName(s string) (uint64, bool) { for i, imp := range t.imports { if id, ok := imp.FindByName(s); ok { return t.offsets[i] + id, true @@ -253,7 +297,7 @@ func (t *lst) FindByName(s string) (int, bool) { return 0, false } -func (t *lst) FindByID(id int) (string, bool) { +func (t *lst) FindByID(id uint64) (string, bool) { if id <= 0 { return "", false } @@ -263,16 +307,16 @@ func (t *lst) FindByID(id int) (string, bool) { // Local to this symbol table. idx := id - t.maxImportID - 1 - if idx < len(t.symbols) { + if idx < uint64(len(t.symbols)) { return t.symbols[idx], true } return "", false } -func (t *lst) findByIDInImports(id int) (string, bool) { +func (t *lst) findByIDInImports(id uint64) (string, bool) { i := 1 - off := 0 + off := uint64(0) for ; i < len(t.imports); i++ { if id <= t.offsets[i] { @@ -306,7 +350,7 @@ func (t *lst) WriteTo(w Writer) error { w.WriteInt(int64(imp.Version())) w.FieldName("max_id") - w.WriteInt(int64(imp.MaxID())) + w.WriteUint(imp.MaxID()) w.EndStruct() } @@ -340,7 +384,7 @@ type SymbolTableBuilder interface { SymbolTable // Add adds a symbol to this symbol table. - Add(symbol string) (int, bool) + Add(symbol string) (uint64, bool) // Build creates an immutable local symbol table. Build() SymbolTable } @@ -357,18 +401,18 @@ func NewSymbolTableBuilder(imports ...SharedSymbolTable) SymbolTableBuilder { imports: imps, offsets: offsets, maxImportID: maxID, - index: make(map[string]int), + index: make(map[string]uint64), }, } } -func (b *symbolTableBuilder) Add(symbol string) (int, bool) { +func (b *symbolTableBuilder) Add(symbol string) (uint64, bool) { if id, ok := b.FindByName(symbol); ok { return id, false } b.symbols = append(b.symbols, symbol) - id := b.maxImportID + len(b.symbols) + id := b.maxImportID + uint64(len(b.symbols)) b.index[symbol] = id return id, true @@ -376,9 +420,9 @@ func (b *symbolTableBuilder) Add(symbol string) (int, bool) { func (b *symbolTableBuilder) Build() SymbolTable { symbols := append([]string{}, b.symbols...) - index := make(map[string]int) + index := make(map[string]uint64) for s, i := range b.index { - index[s] = i + index[s] = uint64(i) } return &lst{ @@ -389,3 +433,43 @@ func (b *symbolTableBuilder) Build() SymbolTable { index: index, } } + +// ProcessImports processes a slice of imports, returning an (augmented) copy, a set of +// offsets for each import, and the overall max ID. +func processImports(imports []SharedSymbolTable) ([]SharedSymbolTable, []uint64, uint64) { + // Add in V1SystemSymbolTable at the head of the list if it's not already included. + var imps []SharedSymbolTable + if len(imports) > 0 && imports[0].Name() == "$ion" { + imps = make([]SharedSymbolTable, len(imports)) + copy(imps, imports) + } else { + imps = make([]SharedSymbolTable, len(imports)+1) + imps[0] = V1SystemSymbolTable + copy(imps[1:], imports) + } + + // Calculate offsets. + maxID := uint64(0) + offsets := make([]uint64, len(imps)) + for i, imp := range imps { + offsets[i] = maxID + maxID += imp.MaxID() + } + + return imps, offsets, maxID +} + +// BuildIndex builds an index from symbol name to symbol ID. +func buildIndex(symbols []string, offset uint64) map[string]uint64 { + index := make(map[string]uint64) + + for i, sym := range symbols { + if sym != "" { + if _, ok := index[sym]; !ok { + index[sym] = offset + uint64(i) + } + } + } + + return index +} diff --git a/symboltable_test.go b/symboltable_test.go index 57d50669..4fe1aaed 100644 --- a/symboltable_test.go +++ b/symboltable_test.go @@ -21,7 +21,7 @@ func TestSharedSymbolTable(t *testing.T) { if st.Version() != 2 { t.Errorf("wrong version: %v", st.Version()) } - if st.MaxID() != 5 { + if st.MaxID() != 6 { t.Errorf("wrong maxid: %v", st.MaxID()) } @@ -32,9 +32,9 @@ func TestSharedSymbolTable(t *testing.T) { testFindByID(t, st, 0, "") testFindByID(t, st, 2, "def") testFindByID(t, st, 4, "null") - testFindByID(t, st, 6, "") + testFindByID(t, st, 7, "") - testString(t, st, `$ion_shared_symbol_table::{name:"test",version:2,symbols:["abc","def","foo'bar","null","ghi"]}`) + testString(t, st, `$ion_shared_symbol_table::{name:"test",version:2,symbols:["abc","def","foo'bar","null","def","ghi"]}`) } func TestLocalSymbolTable(t *testing.T) { @@ -135,7 +135,7 @@ func TestSymbolTableBuilder(t *testing.T) { testFindByID(t, st, 11, "") } -func testFindByName(t *testing.T, st SymbolTable, sym string, expected int) { +func testFindByName(t *testing.T, st SymbolTable, sym string, expected uint64) { t.Run("FindByName("+sym+")", func(t *testing.T) { actual, ok := st.FindByName(sym) if expected == 0 { @@ -153,7 +153,7 @@ func testFindByName(t *testing.T, st SymbolTable, sym string, expected int) { }) } -func testFindByID(t *testing.T, st SymbolTable, id int, expected string) { +func testFindByID(t *testing.T, st SymbolTable, id uint64, expected string) { t.Run(fmt.Sprintf("FindByID(%v)", id), func(t *testing.T) { actual, ok := st.FindByID(id) if expected == "" { diff --git a/textreader.go b/textreader.go index 262eccea..d9cf4077 100644 --- a/textreader.go +++ b/textreader.go @@ -3,7 +3,6 @@ package ion import ( "bufio" "encoding/base64" - "errors" "fmt" "math" "strconv" @@ -132,7 +131,7 @@ func (t *textReader) nextAfterValue() (bool, error) { t.eof = true return true, nil } - return false, errors.New("ion: unexpected token '}'") + return false, &UnexpectedTokenError{"}", t.tok.Pos() - 1} case tokenCloseBracket: // No more values in this list. @@ -140,10 +139,10 @@ func (t *textReader) nextAfterValue() (bool, error) { t.eof = true return true, nil } - return false, errors.New("ion: unexpected token ']'") + return false, &UnexpectedTokenError{"]", t.tok.Pos() - 1} default: - return false, fmt.Errorf("ion: unexpected token '%v'", tok) + return false, &UnexpectedTokenError{tok.String(), t.tok.Pos() - 1} } } @@ -164,7 +163,7 @@ func (t *textReader) nextBeforeFieldName() (bool, error) { return false, err } if tok == tokenSymbol { - if err := verifyUnquotedSymbol(val, "field name"); err != nil { + if err := t.verifyUnquotedSymbol(val, "field name"); err != nil { return false, err } } @@ -174,7 +173,7 @@ func (t *textReader) nextBeforeFieldName() (bool, error) { return false, err } if tok = t.tok.Token(); tok != tokenColon { - return false, fmt.Errorf("ion: unexpected token '%v'", tok) + return false, &UnexpectedTokenError{tok.String(), t.tok.Pos() - 1} } t.fieldName = val @@ -183,7 +182,7 @@ func (t *textReader) nextBeforeFieldName() (bool, error) { return false, nil default: - return false, fmt.Errorf("ion: unexpected token '%v'", tok) + return false, &UnexpectedTokenError{tok.String(), t.tok.Pos() - 1} } } @@ -197,12 +196,12 @@ func (t *textReader) nextBeforeTypeAnnotations() (bool, error) { t.eof = true return true, nil } - return false, errors.New("ion: unexpected EOF") + return false, &UnexpectedEOFError{t.tok.Pos() - 1} case tokenSymbolOperator, tokenDot: if t.ctx.peek() != ctxInSexp { // Operators can only appear inside an sexp. - return false, fmt.Errorf("ion: unexpected token '%v'", tok) + return false, &UnexpectedTokenError{tok.String(), t.tok.Pos() - 1} } fallthrough @@ -220,7 +219,7 @@ func (t *textReader) nextBeforeTypeAnnotations() (bool, error) { if ok { // val was an annotation; remember it and keep going. if tok == tokenSymbol { - if err := verifyUnquotedSymbol(val, "annotation"); err != nil { + if err := t.verifyUnquotedSymbol(val, "annotation"); err != nil { return false, err } } @@ -287,7 +286,7 @@ func (t *textReader) nextBeforeTypeAnnotations() (bool, error) { t.eof = true return true, nil } - return false, errors.New("ion: unexpected token ']'") + return false, &UnexpectedTokenError{"]", t.tok.Pos() - 1} case tokenCloseParen: // No more values in this sexp. @@ -295,10 +294,10 @@ func (t *textReader) nextBeforeTypeAnnotations() (bool, error) { t.eof = true return true, nil } - return false, errors.New("ion: unexpected token ')'") + return false, &UnexpectedTokenError{")", t.tok.Pos() - 1} default: - return false, fmt.Errorf("ion: unexpected token '%v'", tok) + return false, &UnexpectedTokenError{tok.String(), t.tok.Pos() - 1} } } @@ -308,7 +307,7 @@ func (t *textReader) StepIn() error { return t.err } if t.state != trsBeforeContainer { - return errors.New("ion: StepIn called when not on a container") + return &UsageError{"Reader.StepIn", fmt.Sprintf("cannot step in to a %v", t.valueType)} } ctx := containerTypeToCtx(t.valueType) @@ -334,7 +333,7 @@ func (t *textReader) StepOut() error { ctx := t.ctx.peek() if ctx == ctxAtTopLevel { - return errors.New("ion: StepOut called at top level") + return &UsageError{"Reader.StepOut", "cannot step out of top-level datagram"} } ctype := ctxToContainerType(ctx) @@ -363,10 +362,10 @@ func (t *textReader) StepOut() error { // VerifyUnquotedSymbol checks for certain 'special' values that are returned from // the tokenizer as symbols but cannot be used as field names or annotations. -func verifyUnquotedSymbol(val string, ctx string) error { +func (t *textReader) verifyUnquotedSymbol(val string, ctx string) error { switch val { case "null", "true", "false", "nan": - return fmt.Errorf("ion: cannot use unquoted keyword %v as %v", val, ctx) + return &SyntaxError{fmt.Sprintf("unquoted keyword '%v' as %v", val, ctx), t.tok.Pos() - 1} } return nil } @@ -427,7 +426,8 @@ func (t *textReader) readNullType() (Type, error) { return NoType, err } if t.tok.Token() != tokenSymbol { - return NoType, fmt.Errorf("ion: invalid symbol null.%v", t.tok.Token()) + msg := fmt.Sprintf("invalid symbol null.%v", t.tok.Token()) + return NoType, &SyntaxError{msg, t.tok.Pos() - 1} } val, err := t.tok.ReadValue(tokenSymbol) @@ -463,7 +463,8 @@ func (t *textReader) readNullType() (Type, error) { case "sexp": return SexpType, nil default: - return NoType, fmt.Errorf("ion: invalid symbol null.%v", val) + msg := fmt.Sprintf("invalid symbol null.%v", t.tok.Token()) + return NoType, &SyntaxError{msg, t.tok.Pos() - 1} } } @@ -588,7 +589,7 @@ func (t *textReader) onLob() error { return err } if !ok { - return invalidChar(c) + return t.tok.invalidChar(c) } valType = ClobType diff --git a/textreader_test.go b/textreader_test.go index 629ac69d..68a84ba8 100644 --- a/textreader_test.go +++ b/textreader_test.go @@ -520,6 +520,33 @@ func _int64AF(t *testing.T, r Reader, efn string, etas []string, eval int64) { } } +func _uint(t *testing.T, r Reader, eval uint64) { + _uintAF(t, r, "", nil, eval) +} + +func _uintAF(t *testing.T, r Reader, efn string, etas []string, eval uint64) { + _nextAF(t, r, IntType, efn, etas) + if r.IsNull() { + t.Fatalf("expected %v, got null.int", eval) + } + + size, err := r.IntSize() + if err != nil { + t.Fatal(err) + } + if size != Uint64 { + t.Errorf("expected size=Uint, got %v", size) + } + + val, err := r.Uint64Value() + if err != nil { + t.Fatal(err) + } + if val != eval { + t.Errorf("expected %v, got %v", eval, val) + } +} + func _bigInt(t *testing.T, r Reader, eval *big.Int) { _bigIntAF(t, r, "", nil, eval) } diff --git a/textutils.go b/textutils.go index 4b2a0e83..e4a03de4 100644 --- a/textutils.go +++ b/textutils.go @@ -16,10 +16,6 @@ func symbolNeedsQuoting(sym string) bool { return true } - if isSymbolRef(sym) { - return true - } - if !isIdentifierStart(int(sym[0])) { return true } @@ -382,5 +378,5 @@ func parseTimestamp(val string) (time.Time, error) { } func invalidTimestamp(val string) (time.Time, error) { - return time.Time{}, fmt.Errorf("invalid timestamp: %v", val) + return time.Time{}, fmt.Errorf("ion: invalid timestamp: %v", val) } diff --git a/textutils_test.go b/textutils_test.go index 988a4e71..67144a82 100644 --- a/textutils_test.go +++ b/textutils_test.go @@ -59,9 +59,9 @@ func TestWriteSymbol(t *testing.T) { test("basic", "basic") test("_basic_", "_basic_") test("$basic$", "$basic$") + test("$123", "$123") test("123", "'123'") - test("$123", "'$123'") test("abc'def", "'abc\\'def'") test("abc\"def", "'abc\"def'") } @@ -87,9 +87,9 @@ func TestSymbolNeedsQuoting(t *testing.T) { test("basic$123", false) test("$", false) test("$basic", false) + test("$123", false) test("123", true) - test("$123", true) test("abc.def", true) test("abc,def", true) test("abc:def", true) diff --git a/textwriter.go b/textwriter.go index f7648873..2ae82c1f 100644 --- a/textwriter.go +++ b/textwriter.go @@ -2,7 +2,6 @@ package ion import ( "encoding/base64" - "errors" "fmt" "io" "math/big" @@ -43,12 +42,12 @@ func NewTextWriterOpts(out io.Writer, opts TextWriterOpts) Writer { // WriteNull writes an untyped null. func (w *textWriter) WriteNull() error { - return w.WriteNullType(NoType) + return w.writeValue("Writer.WriteNull", textNulls[NoType]) } // WriteNullType writes a typed null. func (w *textWriter) WriteNullType(t Type) error { - return w.writeValue(textNulls[t]) + return w.writeValue("Writer.WriteNullType", textNulls[t]) } // WriteBool writes a boolean value. @@ -57,32 +56,37 @@ func (w *textWriter) WriteBool(val bool) error { if val { str = "true" } - return w.writeValue(str) + return w.writeValue("Writer.WriteBool", str) } // WriteInt writes an integer value. func (w *textWriter) WriteInt(val int64) error { - return w.writeValue(fmt.Sprintf("%d", val)) + return w.writeValue("Writer.WriteInt", fmt.Sprintf("%d", val)) +} + +// WriteUint writes an unsigned integer value. +func (w *textWriter) WriteUint(val uint64) error { + return w.writeValue("Writer.WriteUint", fmt.Sprintf("%d", val)) } // WriteBigInt writes a (big) integer value. func (w *textWriter) WriteBigInt(val *big.Int) error { - return w.writeValue(val.String()) + return w.writeValue("Writer.WriteBigInt", val.String()) } // WriteFloat writes a floating-point value. func (w *textWriter) WriteFloat(val float64) error { - return w.writeValue(formatFloat(val)) + return w.writeValue("Writer.WriteFloat", formatFloat(val)) } // WriteDecimal writes an arbitrary-precision decimal value. func (w *textWriter) WriteDecimal(val *Decimal) error { - return w.writeValue(val.String()) + return w.writeValue("Writer.WriteDecimal", val.String()) } // WriteTimestamp writes a timestamp. func (w *textWriter) WriteTimestamp(val time.Time) error { - return w.writeValue(val.Format(time.RFC3339Nano)) + return w.writeValue("Writer.WriteTimestamp", val.Format(time.RFC3339Nano)) } // WriteSymbol writes a symbol. @@ -90,7 +94,7 @@ func (w *textWriter) WriteSymbol(val string) error { if w.err != nil { return w.err } - if w.err = w.beginValue(); w.err != nil { + if w.err = w.beginValue("Writer.WriteSymbol"); w.err != nil { return w.err } @@ -107,7 +111,7 @@ func (w *textWriter) WriteString(val string) error { if w.err != nil { return w.err } - if w.err = w.beginValue(); w.err != nil { + if w.err = w.beginValue("Writer.WriteString"); w.err != nil { return w.err } @@ -130,7 +134,7 @@ func (w *textWriter) WriteClob(val []byte) error { if w.err != nil { return w.err } - if w.err = w.beginValue(); w.err != nil { + if w.err = w.beginValue("Writer.WriteBlob"); w.err != nil { return w.err } @@ -161,7 +165,7 @@ func (w *textWriter) WriteBlob(val []byte) error { if w.err != nil { return w.err } - if w.err = w.beginValue(); w.err != nil { + if w.err = w.beginValue("Writer.WriteBlob"); w.err != nil { return w.err } @@ -186,7 +190,7 @@ func (w *textWriter) WriteBlob(val []byte) error { // BeginList begins writing a list. func (w *textWriter) BeginList() error { if w.err == nil { - w.err = w.begin(ctxInList, '[') + w.err = w.begin("Writer.BeginList", ctxInList, '[') } return w.err } @@ -194,7 +198,7 @@ func (w *textWriter) BeginList() error { // EndList finishes writing a list. func (w *textWriter) EndList() error { if w.err == nil { - w.err = w.end(ctxInList, ']') + w.err = w.end("Writer.EndList", ctxInList, ']') } return w.err } @@ -202,7 +206,7 @@ func (w *textWriter) EndList() error { // BeginSexp begins writing an s-expression. func (w *textWriter) BeginSexp() error { if w.err == nil { - w.err = w.begin(ctxInSexp, '(') + w.err = w.begin("Writer.BeginSexp", ctxInSexp, '(') } return w.err } @@ -210,7 +214,7 @@ func (w *textWriter) BeginSexp() error { // EndSexp finishes writing an s-expression. func (w *textWriter) EndSexp() error { if w.err == nil { - w.err = w.end(ctxInSexp, ')') + w.err = w.end("Writer.EndSexp", ctxInSexp, ')') } return w.err } @@ -218,7 +222,7 @@ func (w *textWriter) EndSexp() error { // BeginStruct begins writing a struct. func (w *textWriter) BeginStruct() error { if w.err == nil { - w.err = w.begin(ctxInStruct, '{') + w.err = w.begin("Writer.BeginStruct", ctxInStruct, '{') } return w.err } @@ -226,7 +230,7 @@ func (w *textWriter) BeginStruct() error { // EndStruct finishes writing a struct. func (w *textWriter) EndStruct() error { if w.err == nil { - w.err = w.end(ctxInStruct, '}') + w.err = w.end("Writer.EndStruct", ctxInStruct, '}') } return w.err } @@ -237,7 +241,7 @@ func (w *textWriter) Finish() error { return w.err } if w.ctx.peek() != ctxAtTopLevel { - return errors.New("ion: Finish not at top level") + return &UsageError{"Writer.Finish", "not at top level"} } if w.opts&TextWriterQuietFinish == 0 { @@ -252,11 +256,11 @@ func (w *textWriter) Finish() error { } // writeValue writes a stringified value to the output stream. -func (w *textWriter) writeValue(val string) error { +func (w *textWriter) writeValue(api string, val string) error { if w.err != nil { return w.err } - if w.err = w.beginValue(); w.err != nil { + if w.err = w.beginValue(api); w.err != nil { return w.err } @@ -271,7 +275,7 @@ func (w *textWriter) writeValue(val string) error { // beginValue begins the process of writing a value, by writing out // a separator (if needed), field name (if in a struct), and type // annotations (if any). -func (w *textWriter) beginValue() error { +func (w *textWriter) beginValue(api string) error { if w.needsSeparator { var sep byte switch w.ctx.peek() { @@ -290,7 +294,7 @@ func (w *textWriter) beginValue() error { if w.inStruct() { if w.fieldName == "" { - return errors.New("ion: field name not set") + return &UsageError{api, "field name not set"} } name := w.fieldName w.fieldName = "" @@ -326,8 +330,8 @@ func (w *textWriter) endValue() { } // begin starts writing a container of the given type. -func (w *textWriter) begin(t ctx, c byte) error { - if err := w.beginValue(); err != nil { +func (w *textWriter) begin(api string, t ctx, c byte) error { + if err := w.beginValue(api); err != nil { return err } @@ -338,9 +342,9 @@ func (w *textWriter) begin(t ctx, c byte) error { } // end finishes writing a container of the given type -func (w *textWriter) end(t ctx, c byte) error { +func (w *textWriter) end(api string, t ctx, c byte) error { if w.ctx.peek() != t { - return errors.New("ion: End called with wrong container type") + return &UsageError{api, "not in that kind of container"} } if err := writeRawChar(c, w.out); err != nil { diff --git a/textwriter_test.go b/textwriter_test.go index a2a6c329..b3124c32 100644 --- a/textwriter_test.go +++ b/textwriter_test.go @@ -245,7 +245,7 @@ func TestWriteTextTimestamp(t *testing.T) { } func TestWriteTextSymbol(t *testing.T) { - expected := "{foo:bar,empty:'','null':'null',f:a::b::u::'lo🇺🇸'}" + expected := "{foo:bar,empty:'','null':'null',f:a::b::u::'lo🇺🇸',$123:$456}" testTextWriter(t, expected, func(w Writer) { w.BeginStruct() @@ -262,6 +262,9 @@ func TestWriteTextSymbol(t *testing.T) { w.Annotation("u") w.WriteSymbol("lo🇺🇸") + w.FieldName("$123") + w.WriteSymbol("$456") + w.EndStruct() }) } diff --git a/tokenizer.go b/tokenizer.go index f545328f..3413de7f 100644 --- a/tokenizer.go +++ b/tokenizer.go @@ -3,7 +3,6 @@ package ion import ( "bufio" "bytes" - "errors" "fmt" "io" "strings" @@ -115,6 +114,7 @@ type tokenizer struct { token token unfinished bool + pos uint64 } func tokenizeString(in string) *tokenizer { @@ -136,6 +136,10 @@ func (t *tokenizer) Token() token { return t.token } +func (t *tokenizer) Pos() uint64 { + return t.pos +} + // Next advances to the next token in the input stream. func (t *tokenizer) Next() error { var c int @@ -241,7 +245,7 @@ func (t *tokenizer) Next() error { } if tt == tokenTimestamp { // can't have negative timestamps. - return invalidChar(c2) + return t.invalidChar(c2) } t.unread(c2) t.unread(c) @@ -280,7 +284,7 @@ func (t *tokenizer) Next() error { return t.ok(tt, true) default: - return invalidChar(c) + return t.invalidChar(c) } } @@ -375,7 +379,7 @@ func (t *tokenizer) ReadNumber() (string, Type, error) { if first == '0' { if w.Len()-oldlen > 1 { - return "", NoType, errors.New("invalid leading zeroes") + return "", NoType, &SyntaxError{"invalid leading zeroes", t.pos - 1} } } @@ -416,7 +420,7 @@ func (t *tokenizer) ReadNumber() (string, Type, error) { return "", NoType, err } if !ok { - return "", NoType, invalidChar(c) + return "", NoType, t.invalidChar(c) } t.unread(c) @@ -481,7 +485,7 @@ func (t *tokenizer) readQuotedSymbol() (string, error) { switch c { case -1, '\n': - return "", invalidChar(c) + return "", t.invalidChar(c) case '\'': return ret.String(), nil @@ -541,7 +545,7 @@ func (t *tokenizer) readString() (string, error) { switch c { case -1, '\n': - return "", invalidChar(c) + return "", t.invalidChar(c) case '"': return ret.String(), nil @@ -581,7 +585,7 @@ func (t *tokenizer) readLongString() (string, error) { switch c { case -1: - return "", invalidChar(c) + return "", t.invalidChar(c) case '\'': ok, err := t.skipEndOfLongString(t.skipCommentsHandler) @@ -652,7 +656,7 @@ func (t *tokenizer) readEscapedChar(clob bool) (rune, error) { return '\\', nil case 'U': if clob { - return 0, invalidChar('U') + return 0, t.invalidChar('U') } return t.readHexEscapeSeq(8) case 'u': @@ -661,7 +665,7 @@ func (t *tokenizer) readEscapedChar(clob bool) (rune, error) { return t.readHexEscapeSeq(2) } - return 0, fmt.Errorf("bad escape sequence '\\%c'", c) + return 0, &SyntaxError{fmt.Sprintf("bad escape sequence '\\%c'", c), t.pos - 2} } func (t *tokenizer) readHexEscapeSeq(len int) (rune, error) { @@ -673,7 +677,7 @@ func (t *tokenizer) readHexEscapeSeq(len int) (rune, error) { return 0, err } - d, err := fromHex(c) + d, err := t.fromHex(c) if err != nil { return 0, err } @@ -685,7 +689,7 @@ func (t *tokenizer) readHexEscapeSeq(len int) (rune, error) { return val, nil } -func fromHex(c int) (int, error) { +func (t *tokenizer) fromHex(c int) (int, error) { if c >= '0' && c <= '9' { return c - '0', nil } @@ -695,7 +699,7 @@ func fromHex(c int) (int, error) { if c >= 'A' && c <= 'F' { return 10 + (c - 'A'), nil } - return 0, invalidChar(c) + return 0, t.invalidChar(c) } func (t *tokenizer) readBinary() (string, error) { @@ -732,7 +736,7 @@ func (t *tokenizer) readRadix(pok, dok matcher) (string, error) { } if c != '0' { - return "", invalidChar(c) + return "", t.invalidChar(c) } w.WriteByte('0') @@ -741,7 +745,7 @@ func (t *tokenizer) readRadix(pok, dok matcher) (string, error) { return "", err } if !pok(c) { - return "", invalidChar(c) + return "", t.invalidChar(c) } w.WriteByte(byte(c)) @@ -755,7 +759,7 @@ func (t *tokenizer) readRadix(pok, dok matcher) (string, error) { return "", err } if !ok { - return "", invalidChar(c) + return "", t.invalidChar(c) } t.unread(c) @@ -794,7 +798,7 @@ func (t *tokenizer) readTimestamp() (string, error) { return w.String(), nil } if c != '-' { - return "", invalidChar(c) + return "", t.invalidChar(c) } w.WriteByte('-') @@ -807,7 +811,7 @@ func (t *tokenizer) readTimestamp() (string, error) { return w.String(), nil } if c != '-' { - return "", invalidChar(c) + return "", t.invalidChar(c) } w.WriteByte('-') @@ -836,7 +840,7 @@ func (t *tokenizer) readTimestamp() (string, error) { return "", err } if c != ':' { - return "", invalidChar(c) + return "", t.invalidChar(c) } w.WriteByte(':') @@ -888,7 +892,7 @@ func (t *tokenizer) readTimestampOffsetOrZ(c int, w io.ByteWriter) (int, error) w.WriteByte(byte(c)) return t.read() } - return 0, invalidChar(c) + return 0, t.invalidChar(c) } func (t *tokenizer) readTimestampOffset(c int, w io.ByteWriter) (int, error) { @@ -902,7 +906,7 @@ func (t *tokenizer) readTimestampOffset(c int, w io.ByteWriter) (int, error) { return 0, err } if c != ':' { - return 0, invalidChar(c) + return 0, t.invalidChar(c) } w.WriteByte(':') return t.readTimestampDigits(2, w) @@ -915,7 +919,7 @@ func (t *tokenizer) readTimestampDigits(n int, w io.ByteWriter) (int, error) { return 0, err } if !isDigit(c) { - return 0, invalidChar(c) + return 0, t.invalidChar(c) } w.WriteByte(byte(c)) n-- @@ -929,7 +933,7 @@ func (t *tokenizer) readTimestampFinish(c int, w fmt.Stringer) (string, error) { return "", err } if !ok { - return "", invalidChar(c) + return "", t.invalidChar(c) } t.unread(c) return w.String(), nil @@ -948,7 +952,7 @@ func (t *tokenizer) ReadBlob() (string, error) { return "", err } if c == -1 { - return "", invalidChar(c) + return "", t.invalidChar(c) } if c == '}' { break @@ -960,7 +964,7 @@ func (t *tokenizer) ReadBlob() (string, error) { return "", err } if c != '}' { - return "", invalidChar(c) + return "", t.invalidChar(c) } t.unfinished = false @@ -978,14 +982,14 @@ func (t *tokenizer) ReadShortClob() (string, error) { return "", err } if c != '}' { - return "", invalidChar(c) + return "", t.invalidChar(c) } if c, err = t.read(); err != nil { return "", err } if c != '}' { - return "", invalidChar(c) + return "", t.invalidChar(c) } t.unfinished = false @@ -1003,14 +1007,14 @@ func (t *tokenizer) ReadLongClob() (string, error) { return "", err } if c != '}' { - return "", invalidChar(c) + return "", t.invalidChar(c) } if c, err = t.read(); err != nil { return "", err } if c != '}' { - return "", invalidChar(c) + return "", t.invalidChar(c) } t.unfinished = false @@ -1137,18 +1141,18 @@ func (t *tokenizer) expect(f matcher) error { return err } if !f(c) { - return invalidChar(c) + return t.invalidChar(c) } return nil } // InvalidChar returns an error complaining that the given character was // unexpected. -func invalidChar(c int) error { +func (t *tokenizer) invalidChar(c int) error { if c == -1 { - return errors.New("unexpected EOF") + return &UnexpectedEOFError{t.pos - 1} } - return fmt.Errorf("unexpected char %q", c) + return &UnexpectedRuneError{rune(c), t.pos - 1} } // SkipN skips over the next n bytes of input. Presumably you've @@ -1220,6 +1224,7 @@ func (t *tokenizer) peek() (int, error) { // returned as (-1, nil) rather than (0, io.EOF), because I find it // easier to reason about that way. Newlines are normalized to '\n'. func (t *tokenizer) read() (int, error) { + t.pos++ if len(t.buffer) > 0 { // We've already peeked ahead; read from our buffer. c := t.buffer[len(t.buffer)-1] @@ -1232,7 +1237,7 @@ func (t *tokenizer) read() (int, error) { return -1, nil } if err != nil { - return 0, err + return 0, &IOError{err} } // Normalize \r and \r\n to just \n. @@ -1240,7 +1245,7 @@ func (t *tokenizer) read() (int, error) { cs, err := t.in.Peek(1) if err != nil && err != io.EOF { // Not EOF, because we haven't dealt with the '\r' yet. - return 0, err + return 0, &IOError{err} } if len(cs) > 0 && cs[0] == '\n' { // Skip over the '\n' as well. @@ -1255,5 +1260,6 @@ func (t *tokenizer) read() (int, error) { // Unread pushes a character (or -1) back into the input stream to // be read again later. func (t *tokenizer) unread(c int) { + t.pos-- t.buffer = append(t.buffer, c) } diff --git a/type.go b/type.go index 0371a818..f0165909 100644 --- a/type.go +++ b/type.go @@ -100,6 +100,8 @@ const ( Int32 // Int64 is the size of an Ion integer that can be losslessly stored in an int64. Int64 + // Uint64 is the size of an Ion integer that can be losslessly stored in a uint64. + Uint64 // BigInt is the size of an Ion integer that can only be losslessly stored in a big.Int. BigInt ) @@ -113,6 +115,8 @@ func (i IntSize) String() string { return "int32" case Int64: return "int64" + case Uint64: + return "uint64" case BigInt: return "big.Int" default: diff --git a/writer.go b/writer.go index 232db8d1..2a85216a 100644 --- a/writer.go +++ b/writer.go @@ -71,6 +71,8 @@ type Writer interface { // WriteInt writes an integer value. WriteInt(val int64) error + // WriteUint writes an unsigned integer value. + WriteUint(val uint64) error // WriteBigInt writes a big integer value. WriteBigInt(val *big.Int) error // WriteFloat writes a floating-point value. From 97c266013167a8cf2b634cbaed3f723bfaea1ced Mon Sep 17 00:00:00 2001 From: David Murray Date: Wed, 6 May 2020 21:02:25 -0500 Subject: [PATCH 40/56] Point at amzn/ion-go, which exists now (#1) --- README.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/README.md b/README.md index 525d615b..bf8f470c 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,9 @@ # Ion Go A Golang implementation of Amazon's [Ion data notation](https://amzn.github.io/ion-docs/). +| ❗ | You should probably use [amzn/ion-go](https://github.com/amzn/ion-go) now instead of this 😊 | +|---|----| + ## Using the Library Import `github.com/fernomac/ion-go` and you're off to the races. From 823a1994df95679a23763c8af94badf7b58a9864 Mon Sep 17 00:00:00 2001 From: Therapon Skoteiniotis Date: Thu, 7 May 2020 19:08:12 -0700 Subject: [PATCH 41/56] Removed older code --- internal/lex/lex.go | 344 ---------- internal/lex/lex_binary.go | 272 -------- internal/lex/lex_container.go | 79 --- internal/lex/lex_numeric.go | 332 ---------- internal/lex/lex_strings.go | 314 ---------- internal/lex/lex_test.go | 753 ---------------------- internal/lex/lexitem.go | 138 ---- internal/lex/lexitem_test.go | 58 -- ion/cmp_test.go | 253 -------- ion/doc.go | 29 - ion/parse_binary.go | 380 ----------- ion/parse_binary_container.go | 183 ------ ion/parse_binary_numeric.go | 337 ---------- ion/parse_binary_simple.go | 142 ----- ion/parse_binary_test.go | 1109 --------------------------------- ion/parse_text.go | 202 ------ ion/parse_text_container.go | 125 ---- ion/parse_text_numeric.go | 44 -- ion/parse_text_simple.go | 327 ---------- ion/parse_text_test.go | 555 ----------------- ion/symbols_system.go | 66 -- ion/symbols_table.go | 127 ---- ion/symbols_table_test.go | 270 -------- ion/symbols_token.go | 74 --- ion/symbols_token_test.go | 69 -- ion/types.go | 237 ------- ion/types_binary.go | 144 ----- ion/types_character.go | 107 ---- ion/types_container.go | 181 ------ ion/types_numeric.go | 347 ----------- ion/types_test.go | 110 ---- ion/types_timestamp.go | 286 --------- 32 files changed, 7994 deletions(-) delete mode 100644 internal/lex/lex.go delete mode 100644 internal/lex/lex_binary.go delete mode 100644 internal/lex/lex_container.go delete mode 100644 internal/lex/lex_numeric.go delete mode 100644 internal/lex/lex_strings.go delete mode 100644 internal/lex/lex_test.go delete mode 100644 internal/lex/lexitem.go delete mode 100644 internal/lex/lexitem_test.go delete mode 100644 ion/cmp_test.go delete mode 100644 ion/doc.go delete mode 100644 ion/parse_binary.go delete mode 100644 ion/parse_binary_container.go delete mode 100644 ion/parse_binary_numeric.go delete mode 100644 ion/parse_binary_simple.go delete mode 100644 ion/parse_binary_test.go delete mode 100644 ion/parse_text.go delete mode 100644 ion/parse_text_container.go delete mode 100644 ion/parse_text_numeric.go delete mode 100644 ion/parse_text_simple.go delete mode 100644 ion/parse_text_test.go delete mode 100644 ion/symbols_system.go delete mode 100644 ion/symbols_table.go delete mode 100644 ion/symbols_table_test.go delete mode 100644 ion/symbols_token.go delete mode 100644 ion/symbols_token_test.go delete mode 100644 ion/types.go delete mode 100644 ion/types_binary.go delete mode 100644 ion/types_character.go delete mode 100644 ion/types_container.go delete mode 100644 ion/types_numeric.go delete mode 100644 ion/types_test.go delete mode 100644 ion/types_timestamp.go diff --git a/internal/lex/lex.go b/internal/lex/lex.go deleted file mode 100644 index 2df8d01b..00000000 --- a/internal/lex/lex.go +++ /dev/null @@ -1,344 +0,0 @@ -/* - * Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"). - * You may not use this file except in compliance with the License. - * A copy of the License is located at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * or in the "license" file accompanying this file. This file is distributed - * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either - * express or implied. See the License for the specific language governing - * permissions and limitations under the License. - */ - -package lex - -import ( - "bytes" - "fmt" - "strings" - "unicode/utf8" -) - -const ( - eof = -1 - - // Dot is not included and must be checked individually. - operatorRunes = "!#%&*+-/;<=>?@^`|~" - - // \v is a vertical tab - whitespaceRunes = " \t\n\r\f\v" -) - -// The state of the scanner as a function that returns the next state. -type stateFn func(*Lexer) stateFn - -// Lexer represents the state of scanning the input text. -type Lexer struct { - input []byte // the data being scanned - state stateFn // the next lexing function to enter - pos int // current position in the input - itemStart int // start position of the current item - width int // width of last rune read from input - lastPos int // position of most recent item returned by NextItem - items chan Item // channel of scanned items - containers []byte // keep track of container starts and ends -} - -// New creates a new scanner for the input data. This is the lexing half of the -// Lexer / parser. Basic validation is done in the Lexer for a loose sense of -// correctness, but the rigid correctness is enforced in the parser. -func New(input []byte) *Lexer { - x := &Lexer{ - input: input, - items: make(chan Item), - } - go x.run() - return x -} - -// NextItem returns the next item from the input. -func (x *Lexer) NextItem() Item { - item := <-x.items - x.lastPos = item.Pos - return item -} - -// LineNumber returns the line number that the Lexer last stopped at. -func (x *Lexer) LineNumber() int { - // Count the number of newlines, then add 1 for the line we're currently on. - return bytes.Count(x.input[:x.lastPos], []byte("\n")) + 1 -} - -// run the state machine for the Lexer. -func (x *Lexer) run() { - for x.state = lexValue; x.state != nil; { - x.state = x.state(x) - } -} - -// next returns the next rune in the input. If there is a problem decoding -// the rune, then utf8.RuneError is returned. -func (x *Lexer) next() rune { - if x.pos >= len(x.input) { - x.width = 0 - return eof - } - r, w := utf8.DecodeRune(x.input[x.pos:]) - x.width = w - x.pos += x.width - return r -} - -// peek returns, but does not consume, the next rune from the input. -func (x *Lexer) peek() rune { - if x.pos >= len(x.input) { - return eof - } - r, _ := utf8.DecodeRune(x.input[x.pos:]) - return r -} - -// backup steps back one rune. Can only be called once per call of next(). -func (x *Lexer) backup() { - x.pos -= x.width -} - -// emit sends an item representing the current Lexer state and the given type -// onto the items channel. -func (x *Lexer) emit(it itemType) { - x.items <- Item{Type: it, Pos: x.itemStart, Val: x.input[x.itemStart:x.pos]} - x.itemStart = x.pos -} - -// ignore sets the itemStart point to the current position, thereby "ignoring" any -// input between the two points. -func (x *Lexer) ignore() { - x.itemStart = x.pos -} - -// emitAndIgnoreTripleQuoteEnd backs up three spots (a triple quote), emits the given -// itemType, then goes forward three spots to ignore the triple quote -func (x *Lexer) emitAndIgnoreTripleQuoteEnd(itemType itemType) { - x.width = 1 - x.backup() - x.backup() - x.backup() - x.emit(itemType) - - x.next() - x.next() - x.next() - x.ignore() -} - -// errorf emits an error token and returns nil to stop lexing. -func (x *Lexer) errorf(format string, args ...interface{}) stateFn { - x.items <- Item{Type: IonError, Pos: x.itemStart, Val: []byte(fmt.Sprintf(format, args...))} - return nil -} - -// error emits an error token and returns nil to stop lexing. -func (x *Lexer) error(message string) stateFn { - x.items <- Item{Type: IonError, Pos: x.itemStart, Val: []byte(message)} - return nil -} - -// lexValue scans for a value, which can be an annotation, number, symbol, list, -// struct, or s-expression. -func lexValue(x *Lexer) stateFn { - switch ch := x.next(); { - case ch == eof: - x.emit(IonEOF) - return nil - case isWhitespace(ch): - x.ignore() - return lexValue - case ch == ':': - return lexColons - case ch == '\'': - return lexSingleQuote - case ch == '"': - return lexString - case ch == ',': - x.emit(IonComma) - return lexValue - case ch == '[': - return lexList - case ch == ']': - return lexListEnd - case ch == '(': - return lexSExp - case ch == ')': - return lexSExpEnd - case ch == '{': - if x.peek() == '{' { - x.next() - return lexBinary - } - return lexStruct - case ch == '}': - return lexStructEnd - case ch == '/': - // Comment handling needs to come before operator handling because the - // start of a comment is also an operator. Treat it as an operator if - // the following character doesn't adhere to one of the comment standards. - switch x.peek() { - case '/': - x.next() - return lexLineComment - case '*': - x.next() - return lexBlockComment - } - x.emit(IonOperator) - return lexValue - case isOperator(ch) || ch == '.': - // - is both an operator and a signal that a number is starting. Since - // infinity is represented as +inf or -inf, we need to take that into - // account as well. - if (ch == '+' && x.peek() == 'i') || (ch == '-' && (isNumber(x.peek()) || x.peek() == 'i' || x.peek() == '_')) { - x.backup() - return lexNumber - } - // An operator can consist of multiple characters. - for next := x.peek(); isOperator(next); next = x.peek() { - x.next() - } - x.emit(IonOperator) - return lexValue - case isIdentifierSymbolStart(ch): - x.backup() - return lexSymbol - case isNumericStart(ch): - x.backup() - return lexNumber - default: - return x.errorf("invalid start of a value: %#U", ch) - } -} - -// lexColons expects one colon to be scanned and checks to see if there is -// a second before emitting. Returns lexValue. -func lexColons(x *Lexer) stateFn { - if x.peek() == ':' { - x.next() - x.emit(IonDoubleColon) - } else { - x.emit(IonColon) - } - - return lexValue -} - -// lexLineComment scans a comment while parsing values. The comment is -// terminated by a newline. lexValue is returned. -func lexLineComment(x *Lexer) stateFn { - // Ignore the preceding "//" characters. - x.ignore() - for { - ch := x.next() - if ch == utf8.RuneError { - return x.error("error parsing rune") - } - if isEndOfLine(ch) || ch == eof { - x.backup() - break - } - } - x.emit(IonCommentLine) - return lexValue -} - -// lexBlockComment scans a block comment. The comment is terminated by */ -// lexTopLevel is returned since we don't know what is going to come next. -func lexBlockComment(x *Lexer) stateFn { - // Ignore the preceding "/*" characters. - x.ignore() - for { - ch := x.next() - if ch == eof { - return x.error("unexpected end of file while lexing block comment") - } - if ch == utf8.RuneError { - return x.error("error parsing rune") - } - if ch == '*' && x.peek() == '/' { - x.backup() - break - } - } - - x.emit(IonCommentBlock) - // Ignore the trailing "*/" characters. - x.next() - x.next() - x.ignore() - - return lexValue -} - -// eatWhitespace eats up all of the text until a non-whitespace character is encountered. -func eatWhitespace(x *Lexer) { - for isWhitespace(x.peek()) { - x.next() - } - x.ignore() -} - -// isWhitespace returns if the given rune is considered to be a form of whitespace. -func isWhitespace(ch rune) bool { - return bytes.ContainsRune([]byte(whitespaceRunes), ch) -} - -// isEndOfLine returns true if the given rune is an end-of-line character. -func isEndOfLine(ch rune) bool { - return ch == '\r' || ch == '\n' -} - -// isOperator returns true if the given rune is one of the operator chars (not including dot). -func isOperator(ch rune) bool { - return bytes.ContainsRune([]byte(operatorRunes), ch) -} - -// accept consumes the next rune if it's from the given set of valid runes. -func (x *Lexer) accept(valid string) bool { - if strings.IndexRune(valid, x.peek()) >= 0 { - x.next() - return true - } - return false -} - -// acceptString consumes the as many runes from the given string as possible. -// If it hits a rune it can't accept, then it backs up and returns false. -func (x *Lexer) acceptString(valid string) bool { - for _, ch := range valid { - if x.peek() != ch { - return false - } - x.next() - } - return true -} - -// acceptRun consumes as many runes as possible from the given set set of valid runes. -// Stops at either an unacceptable rune, EOF, or if any of the noRepeat runes are encountered -// twice consecutively. -func (x *Lexer) acceptRun(valid string, noRepeat string) int { - inRepeat := false - count := 0 - // Use peek so that we can still back up if the rune we fail on is EOF. - for ch := x.peek(); strings.IndexRune(valid, ch) >= 0; ch = x.peek() { - x.next() - count++ - isRepeatRune := strings.IndexRune(noRepeat, ch) >= 0 - if isRepeatRune && inRepeat { - break - } - inRepeat = isRepeatRune - } - return count -} diff --git a/internal/lex/lex_binary.go b/internal/lex/lex_binary.go deleted file mode 100644 index 0dcf6571..00000000 --- a/internal/lex/lex_binary.go +++ /dev/null @@ -1,272 +0,0 @@ -/* - * Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"). - * You may not use this file except in compliance with the License. - * A copy of the License is located at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * or in the "license" file accompanying this file. This file is distributed - * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either - * express or implied. See the License for the specific language governing - * permissions and limitations under the License. - */ - -package lex - -import ( - "bytes" -) - -// This file contains the state functions for lexing Blobs and Clobs. - -// lexBinary emits IonBinaryStart, determines if the contained text is a Blob or Clob -// and emits the corresponding function. -func lexBinary(x *Lexer) stateFn { - x.emit(IonBinaryStart) - eatWhitespace(x) - // We can tell the difference between a Blob and a Clob by the presence of an - // opening quote character. - if ch := x.peek(); ch == '\'' || ch == '"' { - return lexClob - } - return lexBlob -} - -// lexBlob reads the base64-encoded blob, keeping all whitespace. If the blob does -// not have the correct number of padding characters, then an error is returned. -func lexBlob(x *Lexer) stateFn { - eqCount := 0 - charCount := 0 - -Loop: - for { - switch ch := x.next(); { - case ch == eof: - return x.error("unterminated blob") - case ch == '}': - if ch = x.peek(); ch != '}' { - return x.errorf("invalid end to blob, expected } but found: %#U", ch) - } - // We can back up again since we know for a fact that the previous - // character is the same width as the character we just peeked. - x.backup() - break Loop - case isBlobText(ch): - if eqCount > 0 { - return x.error("base64 character found after padding character") - } - charCount++ - case ch == '=': - eqCount++ - case isWhitespace(ch): - default: - return x.errorf("invalid rune as part of blob string: %#U", ch) - } - } - - if eqCount > 2 { - return x.error("too much padding for base64 encoding") - } - - if (charCount+eqCount)%4 != 0 { - return x.error("invalid base64 encoding") - } - - x.emit(IonBlob) - x.next() - x.next() - x.emit(IonBinaryEnd) - - return lexValue -} - -// lexClob determines whether the clob uses the long or short string format then -// returns the corresponding stateFn. -func lexClob(x *Lexer) stateFn { - // Determine if we are a short or long string. If we are a long string then - // we need to keep looking for long strings before we end. - if ch := x.next(); ch == '"' { - return lexClobShort - } - x.backup() - return lexClobLong -} - -// lexClobShort consumes Clob text between double-quotes, similar to a String but is -// limited in what characters are legal. -func lexClobShort(x *Lexer) stateFn { - // Ignore the opening quote. - x.ignore() - -Loop: - for { - switch ch := x.next(); { - case isClobShortText(ch): - case ch == '\\': - switch r := x.next(); { - case r == eof: - return x.error("unterminated short clob") - case !isEscapeAble(r): - return x.errorf("invalid character after escape: %#U", r) - case r == '\r': - // check for CR LF - if x.peek() == '\n' { - x.next() - } - case r == 'x' || r == 'X': - // If what is being escaped is a hex character, then we still - // need to make sure that escaped character is allowed. - if !bytes.ContainsRune([]byte(hexDigits), x.next()) || !bytes.ContainsRune([]byte(hexDigits), x.next()) { - x.backup() - return x.errorf("invalid character as part of hex escape: %#U", x.peek()) - } - case r == 'u' || r == 'U': - return x.error("unicode escape is not valid in clob") - } - case isEndOfLine(ch) || ch == eof: - return x.error("unterminated short clob") - case ch == '"': - x.backup() - break Loop - default: - return x.errorf("invalid rune as part of short clob string: %#U", ch) - } - } - - x.emit(IonClobShort) - // Ignore the closing quote. - x.next() - x.ignore() - - eatWhitespace(x) - if ch := x.next(); ch != '}' { - return x.errorf("invalid end to short clob, expected } but found: %#U", ch) - } - if ch := x.next(); ch != '}' { - return x.errorf("invalid end to short clob, expected second } but found: %#U", ch) - } - - x.emit(IonBinaryEnd) - return lexValue -} - -// lexClobLong consumes Clob text between one or more sets of triple single-quotes, similar -// to a Long String but is limited in what characters are legal. -func lexClobLong(x *Lexer) stateFn { - // emitSingleLongClob returns true as long as it is able to process a long - // string. It returns false if it cannot, e.g. it encounters the end of the - // Clob. If it encounters an error, then a state function is returned. - for lexed, errFn := emitSingleLongClob(x); lexed || errFn != nil; lexed, errFn = emitSingleLongClob(x) { - if errFn != nil { - return errFn - } - } - - // emitSingleLongClob returns an error if we have an invalid ending of the - // Clob, so we can safely eat the ending "}}". - x.next() - x.next() - x.emit(IonBinaryEnd) - - return lexValue -} - -// emitSingleLongClob eats optional whitespace then expects either the end of a Clob -// or the opening of a long string. If we encountered the end of a Clob, then false and a nil -// state function are returned. We then read the Clob version of a long string and return -// true and a nil state function if through the end of the long string is read without issue. -// If there is an issue, then false and an error state function are returned. -func emitSingleLongClob(x *Lexer) (bool, stateFn) { - // eat any whitespace. - eatWhitespace(x) - - // If the next character is the closing of a binary blob, then check it out - // and consume it if it is. - if x.peek() == '}' { - x.next() - if ch := x.peek(); ch != '}' { - return false, x.errorf("expected a second } but found: %#U", ch) - } - // We can back up again since we know for a fact that the previous - // character is the same width as the character we just peeked. - x.backup() - return false, nil - } - - // Ensure that we have three single quotes to start the clob, then ignore them. - if x.next() != '\'' || x.next() != '\'' || x.next() != '\'' { - x.backup() - return false, x.errorf("expected end of a Clob or start of a long string but found: %#U", x.next()) - } - x.ignore() - - for { - switch ch := x.next(); { - case isClobLongText(ch): - case ch == '\\': - // Eat whatever is after the escape character unless it's an EOF. - if r := x.next(); r == eof { - return false, x.error("unterminated long clob") - } - case ch == eof: - return false, x.error("unterminated long clob") - case ch == '\'': - count := 1 - for next := x.next(); next == '\''; next = x.next() { - count++ - } - // We have reached a character after the end of our long string, which is - // three single quotes. Need to back up over both that character and the - // three single quotes, emit the long string, then eat the single quotes. - if count >= 3 { - x.backup() - x.emitAndIgnoreTripleQuoteEnd(IonClobLong) - - return true, nil - } - default: - return false, x.errorf("invalid rune as part of long clob string: %#U", ch) - } - } - -} - -// isBlobText returns if the given rune is a valid Blob rune. Note that whitespace is not included -// since there are rules for when certain non-blob-text characters can occur. -func isBlobText(ch rune) bool { - return (ch >= '0' && ch <= '9') || (ch >= 'a' && ch <= 'z') || (ch >= 'A' && ch <= 'Z') || ch == '+' || ch == '/' -} - -// isClobLongText returns if the given rune is a valid part of a long-quoted Clob. -// CLOB_LONG_TEXT_ALLOWED -// : '\u0020'..'\u0026' // no U+0027 single quote -// | '\u0028'..'\u005B' // no U+005C blackslash -// | '\u005D'..'\u007F' -// | WS -// ; -func isClobLongText(ch rune) bool { - return (ch >= 0x0020 && ch <= 0x0026) || (ch >= 0x0028 && ch <= 0x005B) || (ch >= 0x005D && ch <= 0x07F) || isWhitespace(ch) -} - -// isClobShortText returns if the given rune is a valid part of a short-quoted Clob. -// CLOB_SHORT_TEXT_ALLOWED -// : '\u0020'..'\u0021' // no U+0022 double quote -// | '\u0023'..'\u005B' // no U+005C backslash -// | '\u005D'..'\u007F' -// | WS_NOT_NL -// ; -func isClobShortText(ch rune) bool { - return (ch >= 0x0020 && ch <= 0x0021) || (ch >= 0x0023 && ch <= 0x005B) || (ch >= 0x005D && ch <= 0x007F) || isSpace(ch) -} - -// isSpace returns if the given rune is a valid space character (not newline). -// WS_NOT_NL -// : '\u0009' // tab -// | '\u000B' // vertical tab -// | '\u000C' // form feed -// | '\u0020' // space -func isSpace(ch rune) bool { - return ch == 0x09 || ch == 0x0B || ch == 0x0C || ch == 0x20 -} diff --git a/internal/lex/lex_container.go b/internal/lex/lex_container.go deleted file mode 100644 index e0500b3e..00000000 --- a/internal/lex/lex_container.go +++ /dev/null @@ -1,79 +0,0 @@ -/* - * Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"). - * You may not use this file except in compliance with the License. - * A copy of the License is located at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * or in the "license" file accompanying this file. This file is distributed - * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either - * express or implied. See the License for the specific language governing - * permissions and limitations under the License. - */ - -package lex - -// This file contains the state functions for container types: -// List, Struct, and S-Expression. - -// lexList emits the start of a list. -func lexList(x *Lexer) stateFn { - x.emit(IonListStart) - x.containers = append(x.containers, '[') - return lexValue -} - -// lexListEnd emits the end of a list. -func lexListEnd(x *Lexer) stateFn { - return containerEnd(x, IonListEnd) -} - -// lexSExp emits the start of an s-expression. -func lexSExp(x *Lexer) stateFn { - x.emit(IonSExpStart) - x.containers = append(x.containers, '(') - return lexValue -} - -// lexSExpEnd emits the end of an s-expression. -func lexSExpEnd(x *Lexer) stateFn { - return containerEnd(x, IonSExpEnd) -} - -// lexStruct emits the start of a structure. -func lexStruct(x *Lexer) stateFn { - x.emit(IonStructStart) - x.containers = append(x.containers, '{') - return lexValue -} - -// lexStructEnd ensures that ending the struct corresponds to a struct start and -// returns lexValue since we don't know what will come next. Inappropriate ending -// of the struct will be handled by the parser. -func lexStructEnd(x *Lexer) stateFn { - return containerEnd(x, IonStructEnd) -} - -// containerEnd makes sure that the container being ended matches the last one -// opened. It emits the given itemType if everything is fine. -func containerEnd(x *Lexer, it itemType) stateFn { - if len(x.containers) == 0 { - return x.error("unexpected closing of container") - } - - switch ch := x.containers[len(x.containers)-1]; { - case ch == '(' && it != IonSExpEnd: - return x.errorf("expected closing of s-expression but found %s", it) - case ch == '{' && it != IonStructEnd: - return x.errorf("expected closing of struct but found %s", it) - case ch == '[' && it != IonListEnd: - return x.errorf("expected closing of list but found %s", it) - } - - x.containers = x.containers[:len(x.containers)-1] - x.emit(it) - - return lexValue -} diff --git a/internal/lex/lex_numeric.go b/internal/lex/lex_numeric.go deleted file mode 100644 index b197cf09..00000000 --- a/internal/lex/lex_numeric.go +++ /dev/null @@ -1,332 +0,0 @@ -/* - * Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"). - * You may not use this file except in compliance with the License. - * A copy of the License is located at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * or in the "license" file accompanying this file. This file is distributed - * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either - * express or implied. See the License for the specific language governing - * permissions and limitations under the License. - */ - -package lex - -import ( - "bytes" -) - -const ( - // One of these runes must follow a decimal, float, int, or timestamp. - numericStopRunes = ", \t\n\r{}[]()\"'\v\f" - - binaryDigits = "_01" - decimalDigits = "_0123456789" - hexDigits = "_0123456789abcdefABCDEF" -) - -// This file contains the state functions for lexing all numeric types: decimal, float, -// int, and timestamp. - -// lexNumber scans a number: decimal, float, int (base 10, 16, or 2), or timestamp. -// Returns lexValue. -func lexNumber(x *Lexer) stateFn { - // Optional leading sign. - hasSign := x.accept("-") - runeCount := 0 - hasDot := false - - // Handle infinity: "+inf" or "-inf". - if x.acceptString("+inf") || (hasSign && x.acceptString("inf")) { - x.emit(IonInfinity) - return lexValue - } - - if hasSign && x.accept("_") { - return x.error("underscore must not be after negative sign") - } - - // Default to base10, but look for different potential valid character sets based - // on whether or not the first number is a 0. - validRunes := decimalDigits - it := IonInt - if x.accept("0") { - runeCount++ - if x.accept("xX") { - validRunes = hexDigits - it = IonIntHex - runeCount++ - } else if x.accept("bB") { - validRunes = binaryDigits - it = IonIntBinary - runeCount++ - } - } - - // We started with 0x or some variant, so the next character must not be an underscore. - if validRunes != decimalDigits && x.accept("_") { - return x.error("underscore must not be at start of hex or binary number") - } - - // Hit the first stop point. Make sure that the end of the stop point wasn't - // an underscore. If we get to this point then we have read at least one rune - // so it is safe to backup. - runeCount += x.acceptRun(validRunes, "_") - - x.backup() - if ch := x.next(); ch == '_' { - return x.error("number span cannot end with an underscore") - } - - // We can only continue on to further stop points if we are dealing with decimal digits. - if validRunes == decimalDigits { - // We can have a period and then E or D, but we can't have E or D then a period. - if x.accept(".") { - hasDot = true - runeCount++ - it = IonDecimal - if x.peek() == '_' { - return x.error("underscore may not follow a period") - } - x.acceptRun(validRunes, "_") - } - // Finally attempt to pull in everything after a float or decimal designator. - if x.accept("eE") { - it = IonFloat - // Exponents are allowed to have a no sign, a plus sign, or a minus sign. - x.accept("+-") - x.acceptRun(validRunes, "_") - } else if x.accept("dD") { - it = IonDecimal - // Exponents are allowed to have a no sign, a plus sign, or a minus sign. - x.accept("+-") - x.acceptRun(validRunes, "_") - } - } - - // If you're an int, use decimal digits, don't have a sign or a dot and you have - // exactly four numbers before hitting a stop character, you might be a timestamp. - mightBeTimestamp := it == IonInt && validRunes == decimalDigits && runeCount == 4 && !hasSign && !hasDot - switch ch := x.next(); { - case ch == 'T' && mightBeTimestamp: - // Four numbers followed by a T is most likely a year-only timestamp. - if x.input[x.pos-5] == '0' && x.input[x.pos-4] == '0' && x.input[x.pos-3] == '0' && x.input[x.pos-2] == '0' { - return x.error("year must be greater than zero") - } - x.emit(IonTimestamp) - return lexValue - case ch == '-' && mightBeTimestamp: - // Four numbers followed by a - is most likely a timestamp. - if x.input[x.pos-5] == '0' && x.input[x.pos-4] == '0' && x.input[x.pos-3] == '0' && x.input[x.pos-2] == '0' { - return x.error("year must be greater than zero") - } - return lexTimestamp - case isNumericStop(ch) || ch == eof: - // Do nothing, number terminated as expected. - default: - return x.errorf("invalid numeric stop character: %#U", ch) - } - - // We consumed one character past the number, so back up. - x.backup() - - if x.input[x.pos-1] == '_' { - return x.error("numbers cannot end with an underscore") - } - if it == IonInt && x.input[x.itemStart] == '0' && x.itemStart < x.pos-1 { - return x.error("leading zeros are not allowed for decimal integers") - } - if it == IonFloat && x.input[x.itemStart] == '0' && !bytes.ContainsRune([]byte(".eE"), rune(x.input[x.itemStart+1])) { - return x.error("leading zeros are not allowed for floats") - } - if it == IonDecimal && x.input[x.itemStart] == '0' && !bytes.ContainsRune([]byte(".dD"), rune(x.input[x.itemStart+1])) { - return x.error("leading zeros are not allowed for decimals") - } - - x.emit(it) - return lexValue -} - -// lexTimestamp lexes everything past the first "-" in a timestamp. It is assumed that -// the year and dash have been consumed. -func lexTimestamp(x *Lexer) stateFn { - // Set defaults for all our our values so that we can safely check them all - // later without worrying about which ones were set. - year := [4]byte{x.input[x.pos-5], x.input[x.pos-4], x.input[x.pos-3], x.input[x.pos-2]} - month := [2]byte{'0', '1'} - day, hour, hourOffset := month, month, month - - // Overall form can be some subset of yyyy-mm-ddThh:mm:ss.sssTZD. The "yyyy-" - // has already been lexed. Comments will show the progression in parsing. - // yyyy-mm - if !isMonthStart(x.next()) || !isNumber(x.next()) { - x.backup() - return x.errorf("invalid character as month part of timestamp: %#U", x.next()) - } - month[0], month[1] = x.input[x.pos-2], x.input[x.pos-1] - - // yyyy-mmT or yyyy-mm- - switch ch := x.next(); { - case ch == 'T': - if pk := x.peek(); !isNumericStop(pk) && pk != eof { - return x.errorf("invalid timestamp stop character: %#U", pk) - } - return validateDateAndEmit(x, year, month, day, hour, hourOffset) - case ch == '-': - // Do nothing, a dash means we're going into days. - default: - return x.errorf("invalid character after month part of timestamp: %#U", ch) - } - - // yyyy-mm-dd - if !isDayStart(x.next()) || !isNumber(x.next()) { - x.backup() - return x.errorf("invalid character as day part of timestamp: %#U", x.next()) - } - day[0], day[1] = x.input[x.pos-2], x.input[x.pos-1] - - // The day portion does not need to be terminated by a 'T' to be a valid timestamp. - // yyyy-mm-dd or yyyy-mm-ddT - switch ch := x.next(); { - case ch == 'T': - if pk := x.peek(); isNumericStop(pk) || pk == eof { - return validateDateAndEmit(x, year, month, day, hour, hourOffset) - } - case isNumericStop(ch): - x.backup() - return validateDateAndEmit(x, year, month, day, hour, hourOffset) - default: - return x.errorf("invalid character after day part of timestamp: %#U", ch) - } - - // yyyy-mm-ddThh:mm - if !isHourStart(x.next()) || !isNumber(x.next()) || x.next() != ':' || !isMinuteStart(x.next()) || !isNumber(x.next()) { - x.backup() - return x.errorf("invalid character as hour/minute part of timestamp: %#U", x.next()) - } - hour[0], hour[1] = x.input[x.pos-5], x.input[x.pos-4] - - // yyyy-mm-ddThh:mm:ss(.sss)? - if x.peek() == ':' { - x.next() - // yyyy-mm-ddThh:mm:ss - if !isMinuteStart(x.next()) || !isNumber(x.next()) { - x.backup() - return x.errorf("invalid character as seconds part of timestamp: %#U", x.next()) - } - // yyyy-mm-ddThh:mm:ss.sss (can be any number of digits) - if x.peek() == '.' { - x.next() - // There must be at least one number after the period. - if !isNumber(x.peek()) { - return x.error("missing fractional seconds value") - } - for isNumber(x.peek()) { - x.next() - } - } - } - - // If the time is included, then there must be a timezone component. - // https://www.w3.org/TR/NOTE-datetime - // yyyy-mm-ddThh:mm:ss(.sss)?TZD - switch ch := x.next(); { - case ch == '+' || ch == '-': - // TZD == +hh:mm or -hh:mm - if !isHourStart(x.next()) || !isNumber(x.next()) || x.next() != ':' || !isMinuteStart(x.next()) || !isNumber(x.next()) { - x.backup() - return x.errorf("invalid character as hour/minute part of timezone: %#U", x.next()) - } - hourOffset[0], hourOffset[1] = x.input[x.pos-5], x.input[x.pos-4] - case ch == 'Z': - // Do nothing. 'Z' is a great way to end a timestamp. - default: - return x.errorf("invalid character as timezone part of timestamp: %#U", ch) - } - - if ch := x.peek(); !isNumericStop(ch) && ch != eof { - return x.errorf("invalid timestamp stop character: %#U", ch) - } - - return validateDateAndEmit(x, year, month, day, hour, hourOffset) -} - -func validateDateAndEmit(x *Lexer, year [4]byte, month, day, hour, hourOffset [2]byte) stateFn { - monthInt := ((month[0] - '0') * 10) + (month[1] - '0') - if monthInt > 12 { - return x.errorf("invalid month %d", monthInt) - } - if monthInt == 0 { - return x.error("month must be greater than zero") - } - - dayInt := ((day[0] - '0') * 10) + (day[1] - '0') - if dayInt > 31 { - return x.errorf("invalid day %d", dayInt) - } - if dayInt == 0 { - return x.error("day must be greater than zero") - } - if (monthInt == 4 || monthInt == 6 || monthInt == 9 || monthInt == 11) && dayInt == 31 { - return x.errorf("invalid day %d for month %d", dayInt, monthInt) - } - // Only care about the year if we are in February so that we can calculate whether - // or not it is a leap year. - if monthInt == 2 { - yearInt := (int(year[0]-'0') * 1000) + (int(year[1]-'0') * 100) + (int(year[2]-'0') * 10) + int(year[3]-'0') - isLeapYear := (yearInt%4 == 0 && yearInt%100 != 0) || (yearInt%400 == 0) - if (isLeapYear && dayInt >= 30) || (!isLeapYear && dayInt >= 29) { - return x.errorf("invalid day %d for month %d in year %d", dayInt, monthInt, yearInt) - } - } - - if hour[0] == '2' && hour[1] > '3' { - return x.errorf("invalid hour %s", hour) - } - if hourOffset[0] == '2' && hourOffset[1] > '3' { - return x.errorf("invalid hour offset %s", hourOffset) - } - - x.emit(IonTimestamp) - - return lexValue -} - -// isNumericStart returns if the given rune is a valid start of an decimal, -// float, int, or timestamp -func isNumericStart(ch rune) bool { - return isNumber(ch) || ch == '-' || ch == '+' -} - -// isNumber returns if the given rune is a number between 0-9. -func isNumber(ch rune) bool { - return '0' <= ch && ch <= '9' -} - -// isMonthStart returns if the given rune is a 0 or a 1. -func isMonthStart(ch rune) bool { - return ch == '0' || ch == '1' -} - -// isHourStart returns if the given rune is a number between 0-2. -func isHourStart(ch rune) bool { - return '0' <= ch && ch <= '2' -} - -// isDayStart returns if the given rune is a number between 0-3. -func isDayStart(ch rune) bool { - return '0' <= ch && ch <= '3' -} - -// isMinuteStart returns if the given rune is a number between 0-5. -func isMinuteStart(ch rune) bool { - return '0' <= ch && ch <= '5' -} - -// isNumericStop returns true if the given rune is one of the numeric/timestamp stop chars. -func isNumericStop(ch rune) bool { - return bytes.ContainsRune([]byte(numericStopRunes), ch) -} diff --git a/internal/lex/lex_strings.go b/internal/lex/lex_strings.go deleted file mode 100644 index f128a3e2..00000000 --- a/internal/lex/lex_strings.go +++ /dev/null @@ -1,314 +0,0 @@ -/* - * Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"). - * You may not use this file except in compliance with the License. - * A copy of the License is located at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * or in the "license" file accompanying this file. This file is distributed - * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either - * express or implied. See the License for the specific language governing - * permissions and limitations under the License. - */ - -package lex - -import ( - "bytes" -) - -const ( - // Runes that can follow a slash as part of an escape sequence. - escapeAbleRunes = "abtnfrv?0xuU'\"/\r\n\\" -) - -// This file contains the state functions for lexing string and symbol types. It -// does not contain state functions for the binary text types Blob and Clob. - -// lexSymbol scans an annotation or symbol and returns lexValue. -func lexSymbol(x *Lexer) stateFn { - isNull := false - for { - switch ch := x.next(); { - case isIdentifierSymbolPart(ch): - case ch == '.' && string(x.input[x.itemStart:x.pos]) == "null.": - // There is a special case where a dot is okay within an identifier symbol, - // and that is when it is part of one of the null types. - isNull = true - case isIdentifierSymbolEnd(ch): - x.backup() - if isNull { - x.emit(IonNull) - } else { - x.emit(IonSymbol) - } - return lexValue - default: - return x.errorf("bad character as part of symbol: %#U", ch) - } - } -} - -// lexString scans a quoted string. -func lexString(x *Lexer) stateFn { - // Ignore the opening quote. - x.ignore() - -Loop: - for { - switch ch := x.next(); { - case ch == '\\': - if fn := handleEscapedRune(x); fn != nil { - return fn - } - case isEndOfLine(ch) || ch == eof: - return x.error("unterminated quoted string") - case ch == '"': - x.backup() - break Loop - case isStringPart(ch): - // Yay, happy string character. - default: - return x.errorf("bad character as part of string: %#U", ch) - } - } - - x.emit(IonString) - // Ignore the closing quote. - x.next() - x.ignore() - return lexValue -} - -// lexSingleQuote determines whether the single quote is for the start of a quoted -// symbol, an empty quotedSymbol, or the first of three single quotes that denotes -// a long string -func lexSingleQuote(x *Lexer) stateFn { - // Need to distinguish between an empty symbol, e.g. '', and the - // start of a "long string", e.g. ''' - if x.peek() == '\'' { - x.next() - if x.peek() == '\'' { - // Triple quote! Dive into lexing a quoted long string. - x.next() - return lexLongString - } else { - // Ignore the opening and closing quotes. - x.ignore() - // Empty quoted symbol. Emit it and move on. - x.emit(IonSymbolQuoted) - return lexValue - } - } - return lexQuotedSymbol -} - -// lexLongString scans a long string. It is up to the parser to join multiple -// long strings together. -func lexLongString(x *Lexer) stateFn { - // Ignore the initial triple quote. - x.ignore() - - count := 0 - -Loop: - for { - // Keep consuming single quotes until they stop. If there was a run - // of three or more then the string was ended and we can break the loop. - switch ch := x.next(); { - case ch == '\\': - if fn := handleEscapedRune(x); fn != nil { - return fn - } - case ch == eof: - if count >= 3 { - break Loop - } - return x.error("unterminated long string") - case ch == '\'': - count++ - case isLongStringPart(ch): - if count >= 3 { - x.backup() - break Loop - } - count = 0 - default: - if count >= 3 { - x.backup() - break Loop - } - return x.errorf("bad character as part of long string: %#U", ch) - } - } - - x.emitAndIgnoreTripleQuoteEnd(IonStringLong) - return lexValue -} - -// lexQuotedSymbol scans an annotation or symbol that is surrounded -// in quotes and returns lexValue. -func lexQuotedSymbol(x *Lexer) stateFn { - // Ignore the opening quote. - x.ignore() - -Loop: - for { - switch ch := x.next(); { - case isQuotedSymbolPart(ch): - case ch == '\\': - if fn := handleEscapedRune(x); fn != nil { - return fn - } - case ch == eof: - return x.error("unterminated quoted symbol") - case ch == '\'': - x.backup() - break Loop - default: - return x.errorf("bad character as part of quoted symbol: %#U", ch) - } - } - - x.emit(IonSymbolQuoted) - // Ignore the closing quote. - x.next() - x.ignore() - return lexValue -} - -// handleEscapedRune checks the character after an escape character '\' within -// a string. If the escaped character is not valid, e.g. EOF, then an error -// stateFn is returned. -func handleEscapedRune(x *Lexer) stateFn { - // Escaping the EOF isn't cool. - ch := x.next() - if ch == eof { - return x.error("unterminated sequence") - } - if !isEscapeAble(ch) { - return x.errorf("invalid character after escape: %#U", ch) - } - - // Both the \x and \u escapes should be followed by a specific number - // of unicode characters (2 and 4 respectively). - switch ch { - case 'x': - // '\x' HEX_DIGIT HEX_DIGIT - if !bytes.ContainsRune([]byte(hexDigits), x.next()) || !bytes.ContainsRune([]byte(hexDigits), x.next()) { - x.backup() - return x.errorf("invalid character as part of hex escape: %#U", x.peek()) - } - case 'u': - // '\u' HEX_DIGIT_QUARTET - if !bytes.ContainsRune([]byte(hexDigits), x.next()) || !bytes.ContainsRune([]byte(hexDigits), x.next()) || - !bytes.ContainsRune([]byte(hexDigits), x.next()) || !bytes.ContainsRune([]byte(hexDigits), x.next()) { - return unicodeEscapeError(x) - } - case 'U': - // '\U000' HEX_DIGIT_QUARTET HEX_DIGIT or - // '\U0010' HEX_DIGIT_QUARTET - if x.next() != '0' || x.next() != '0' { - return unicodeEscapeError(x) - } - switch next := x.next(); next { - case '0': - // Eat the hex digit that is expected to be a 0 in the other case. - if !bytes.ContainsRune([]byte(hexDigits), x.next()) { - return unicodeEscapeError(x) - } - case '1': - if x.next() != '0' { - return unicodeEscapeError(x) - } - default: - return unicodeEscapeError(x) - } - if !bytes.ContainsRune([]byte(hexDigits), x.next()) || !bytes.ContainsRune([]byte(hexDigits), x.next()) || - !bytes.ContainsRune([]byte(hexDigits), x.next()) || !bytes.ContainsRune([]byte(hexDigits), x.next()) { - return unicodeEscapeError(x) - } - } - - // If the file is from Windows then the escape character may be - // trying to escape both /r and /n. - if pk := x.peek(); ch == '\r' && pk == '\n' { - x.next() - } - return nil -} - -// unicodeEscapeError is a convenience function that backs up the lexer and emits an invalid -// character error with that character. -func unicodeEscapeError(x *Lexer) stateFn { - x.backup() - return x.errorf("invalid character as part of unicode escape: %#U", x.peek()) -} - -// isIdentifierSymbolEnd returns if the given rune is one of the container end characters. -func isContainerEnd(ch rune) bool { - return ch == ')' || ch == ']' || ch == '}' -} - -// isIdentifierSymbolStart returns if the given rune is a valid start of a symbol. -// IDENTIFIER_SYMBOL: [$_a-zA-Z] ([$_a-zA-Z] | DEC_DIGIT)* -func isIdentifierSymbolStart(ch rune) bool { - return ch == '$' || ch == '_' || ('a' <= ch && ch <= 'z') || ('A' <= ch && ch <= 'Z') -} - -// isIdentifierSymbolPart returns if the given rune is a valid part of an -// identifier symbol. -// IDENTIFIER_SYMBOL: [$_a-zA-Z] ([$_a-zA-Z] | DEC_DIGIT)* -func isIdentifierSymbolPart(ch rune) bool { - return ch == '$' || ch == '_' || (ch >= 'a' && ch <= 'z') || (ch >= 'A' && ch <= 'Z') || isNumber(ch) -} - -// isIdentifierSymbolEnd returns if the given rune is a valid end character -// for an identifier symbol. -func isIdentifierSymbolEnd(ch rune) bool { - return ch == ':' || ch == ',' || ch == '.' || isContainerEnd(ch) || isWhitespace(ch) || isOperator(ch) || ch == eof -} - -// isQuotedSymbolPart returns if the given rune is a valid part of a quoted symbol. -// SYMBOL_TEXT: (TEXT_ESCAPE | SYMBOL_TEXT_ALLOWED)* -// SYMBOL_TEXT_ALLOWED -// : '\u0020'..'\u0026' // no C1 control characters and no U+0027 single quote -// | '\u0028'..'\u005B' // no U+005C backslash -// | '\u005D'..'\u10FFFF' -// | WS_NOT_NL -// Note: The backslash character is a valid escape-able character, so it is valid within -// a quoted symbol even though it isn't allowed here. -func isQuotedSymbolPart(ch rune) bool { - return (ch >= 0x0020 && ch <= 0x0026) || (ch >= 0x0028 && ch <= 0x005B) || (ch >= 0x005D && ch <= 0x10FFFF) || isSpace(ch) -} - -// isStringPart returns if the given rune is a valid part of a double-quoted string. -// STRING_SHORT_TEXT_ALLOWED -// : '\u0020'..'\u0021' // no C1 control characters and no U+0022 double quote -// | '\u0023'..'\u005B' // no U+005C backslash -// | '\u005D'..'\u10FFFF' -// | WS_NOT_NL -// Note: The backslash character is a valid escape-able character, so it is valid within -// a double-quoted string even though it isn't allowed here. -func isStringPart(ch rune) bool { - return (ch >= 0x0020 && ch <= 0x0021) || (ch >= 0x0023 && ch <= 0x005B) || (ch >= 0x005D && ch <= 0x10FFFF) || isSpace(ch) -} - -// isLongStringPart returns if the given rune is a valid part of a long string. -// STRING_LONG_TEXT_ALLOWED -// : '\u0020'..'\u005B' // no C1 control characters and no U+005C backslash -// | '\u005D'..'\u10FFFF' -// | WS -// Note: The backslash character is a valid escape-able character, so it is valid within -// a long string even though it isn't allowed here. -func isLongStringPart(ch rune) bool { - return (ch >= 0x0020 && ch <= 0x005B) || (ch >= 0x005D && ch <= 0x10FFFF) || isWhitespace(ch) -} - -// isEscapeAble returns if the given rune is a character that is allowed to follow the -// escape character U+005C backslash. -func isEscapeAble(ch rune) bool { - return bytes.ContainsRune([]byte(escapeAbleRunes), ch) -} diff --git a/internal/lex/lex_test.go b/internal/lex/lex_test.go deleted file mode 100644 index 1b53abcb..00000000 --- a/internal/lex/lex_test.go +++ /dev/null @@ -1,753 +0,0 @@ -/* - * Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"). - * You may not use this file except in compliance with the License. - * A copy of the License is located at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * or in the "license" file accompanying this file. This file is distributed - * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either - * express or implied. See the License for the specific language governing - * permissions and limitations under the License. - */ - -package lex - -import ( - "testing" - - "github.com/google/go-cmp/cmp" - "github.com/google/go-cmp/cmp/cmpopts" -) - -var ( - tEOF = Item{Type: IonEOF} - tBinaryStart = Item{Type: IonBinaryStart, Val: []byte("{{")} - tBinaryEnd = Item{Type: IonBinaryEnd, Val: []byte("}}")} - tColon = Item{Type: IonColon, Val: []byte(":")} - tComma = Item{Type: IonComma, Val: []byte(",")} - tDoubleColon = Item{Type: IonDoubleColon, Val: []byte("::")} - tListStart = Item{Type: IonListStart, Val: []byte("[")} - tListEnd = Item{Type: IonListEnd, Val: []byte("]")} - tSExpStart = Item{Type: IonSExpStart, Val: []byte("(")} - tSExpEnd = Item{Type: IonSExpEnd, Val: []byte(")")} - tStructStart = Item{Type: IonStructStart, Val: []byte("{")} - tStructEnd = Item{Type: IonStructEnd, Val: []byte("}")} -) - -func TestLex(t *testing.T) { - binInt := func(val string) Item { return Item{Type: IonIntBinary, Val: []byte(val)} } - blob := func(val string) Item { return Item{Type: IonBlob, Val: []byte(val)} } - blockComment := func(val string) Item { return Item{Type: IonCommentBlock, Val: []byte(val)} } - clobLong := func(val string) Item { return Item{Type: IonClobLong, Val: []byte(val)} } - clobShort := func(val string) Item { return Item{Type: IonClobShort, Val: []byte(val)} } - decimal := func(val string) Item { return Item{Type: IonDecimal, Val: []byte(val)} } - doubleQuote := func(val string) Item { return Item{Type: IonString, Val: []byte(val)} } - err := func(val string) Item { return Item{Type: IonError, Val: []byte(val)} } - float := func(val string) Item { return Item{Type: IonFloat, Val: []byte(val)} } - hexInt := func(val string) Item { return Item{Type: IonIntHex, Val: []byte(val)} } - integer := func(val string) Item { return Item{Type: IonInt, Val: []byte(val)} } - lineComment := func(val string) Item { return Item{Type: IonCommentLine, Val: []byte(val)} } - longString := func(val string) Item { return Item{Type: IonStringLong, Val: []byte(val)} } - nullItem := func(val string) Item { return Item{Type: IonNull, Val: []byte(val)} } - operator := func(val string) Item { return Item{Type: IonOperator, Val: []byte(val)} } - quotedSym := func(val string) Item { return Item{Type: IonSymbolQuoted, Val: []byte(val)} } - symbol := func(val string) Item { return Item{Type: IonSymbol, Val: []byte(val)} } - timestamp := func(val string) Item { return Item{Type: IonTimestamp, Val: []byte(val)} } - - tests := []struct { - name string - input []byte - expected []Item - }{ - // Empty things. - - { - name: "nil input", - expected: []Item{tEOF}, - }, - { - name: "only whitespace", - input: []byte(" \t\r\n\f\v"), - expected: []Item{tEOF}, - }, - { - name: "empty list", - input: []byte("[]"), - expected: []Item{tListStart, tListEnd, tEOF}, - }, - { - name: "empty struct", - input: []byte("{}"), - expected: []Item{tStructStart, tStructEnd, tEOF}, - }, - { - name: "empty s-expression", - input: []byte("()"), - expected: []Item{tSExpStart, tSExpEnd, tEOF}, - }, - { - name: "empty single quote", - input: []byte("''"), - expected: []Item{quotedSym(""), tEOF}, - }, - { - name: "empty double quote", - input: []byte(`""`), - expected: []Item{doubleQuote(""), tEOF}, - }, - { - name: "empty triple single quote", - input: []byte("''''''"), - expected: []Item{longString(""), tEOF}, - }, - { - name: "empty symbol to symbol", - input: []byte("'': abc"), - expected: []Item{quotedSym(""), tColon, symbol("abc"), tEOF}, - }, - { - name: "empty annotation to symbol", - input: []byte("'':: abc"), - expected: []Item{quotedSym(""), tDoubleColon, symbol("abc"), tEOF}, - }, - - // Simple symbols, strings, nulls, and comments. - - { - name: "symbol to symbol", - input: []byte("'':\nabc"), - expected: []Item{quotedSym(""), tColon, symbol("abc"), tEOF}, - }, - { - name: "symbol to quoted symbol", - input: []byte("'':\n'abc'"), - expected: []Item{quotedSym(""), tColon, quotedSym("abc"), tEOF}, - }, - { - name: "symbol to double quoted symbol", - input: []byte(`'':"abc"`), - expected: []Item{quotedSym(""), tColon, doubleQuote("abc"), tEOF}, - }, - { - name: "long string", - input: []byte("'''\"'''"), - expected: []Item{longString(`"`), tEOF}, - }, - { - name: "long string with single quote", - input: []byte("'''''''"), - expected: []Item{longString("'"), tEOF}, - }, - { - name: "quoted string", - input: []byte(`"\\"`), - expected: []Item{doubleQuote(`\\`), tEOF}, - }, - { - name: "long string with single quotes", - input: []byte("''' ' '' '''"), - expected: []Item{longString(" ' '' "), tEOF}, - }, - { - name: "some nulls", - input: []byte("null null.bool"), - // Since the first null doesn't have a period we don't know that it is a null until we parse. - expected: []Item{symbol("null"), nullItem("null.bool"), tEOF}, - }, - { - name: "some nulls in a list", - input: []byte("[\n\tnull,\n\tnull.bool]"), - // Since the first null doesn't have a period we don't know that it is a null until we parse. - expected: []Item{tListStart, symbol("null"), tComma, nullItem("null.bool"), tListEnd, tEOF}, - }, - { - name: "quoted null to null", - input: []byte("'null.bool':null.bool"), - expected: []Item{quotedSym("null.bool"), tColon, nullItem("null.bool"), tEOF}, - }, - { - name: "we don't know that these are boolean", - input: []byte("true false"), - expected: []Item{symbol("true"), symbol("false"), tEOF}, - }, - { - name: "line comment", - input: []byte("// Line Comment"), - expected: []Item{lineComment(" Line Comment"), tEOF}, - }, - { - name: "block comment", - input: []byte("/* Block\n Comment*/"), - expected: []Item{blockComment(" Block\n Comment"), tEOF}, - }, - - // Numeric. - - { - name: "infinity", - input: []byte("inf +inf -inf"), - // "inf" must have a plus or minus on it to be considered a number. - expected: []Item{symbol("inf"), {Type: IonInfinity, Val: []byte("+inf")}, {Type: IonInfinity, Val: []byte("-inf")}, tEOF}, - }, - { - name: "integers", - input: []byte("0 -1 1_2_3 0xFf 0Xe_d 0b10 0B1_0"), - expected: []Item{integer("0"), integer("-1"), integer("1_2_3"), hexInt("0xFf"), hexInt("0Xe_d"), binInt("0b10"), binInt("0B1_0"), tEOF}, - }, - { - name: "decimals", - input: []byte("0. 0.123 -0.12d4 0D-0 0d+0 12_34.56_78"), - expected: []Item{decimal("0."), decimal("0.123"), decimal("-0.12d4"), decimal("0D-0"), decimal("0d+0"), decimal("12_34.56_78"), tEOF}, - }, - { - name: "floats", - input: []byte("0E0 0.12e-4 -0e+0"), - expected: []Item{float("0E0"), float("0.12e-4"), float("-0e+0"), tEOF}, - }, - { - name: "dates", - input: []byte("2019T 2019-10T 2019-10-30 2019-10-30T"), - expected: []Item{timestamp("2019T"), timestamp("2019-10T"), timestamp("2019-10-30"), timestamp("2019-10-30T"), tEOF}, - }, - { - name: "times", - input: []byte("2019-10-30T22:30Z 2019-10-30T12:30:59+02:30 2019-10-30T12:30:59.999-02:30"), - expected: []Item{timestamp("2019-10-30T22:30Z"), timestamp("2019-10-30T12:30:59+02:30"), timestamp("2019-10-30T12:30:59.999-02:30"), tEOF}, - }, - - // Binary. - - { - name: "short blob", - input: []byte("{{+AB/}}"), - expected: []Item{tBinaryStart, blob("+AB/"), tBinaryEnd, tEOF}, - }, - { - name: "padded blob with whitespace", - input: []byte("{{ + A\nB\t/abc= }}"), - expected: []Item{tBinaryStart, blob("+ A\nB\t/abc= "), tBinaryEnd, tEOF}, - }, - { - name: "short clob", - input: []byte(`{{ "A\n" }}`), - expected: []Item{tBinaryStart, clobShort(`A\n`), tBinaryEnd, tEOF}, - }, - { - name: "symbol to short clob", - input: []byte(`abc : {{ "A\n" }}`), - expected: []Item{symbol("abc"), tColon, tBinaryStart, clobShort(`A\n`), tBinaryEnd, tEOF}, - }, - { - name: "symbol with comments to short clob", - input: []byte("abc : // Line\n/* Block */ {{ \"A\\n\" }}"), - expected: []Item{symbol("abc"), tColon, lineComment(" Line"), blockComment(" Block "), tBinaryStart, clobShort(`A\n`), tBinaryEnd, tEOF}, - }, - { - name: "long clob", - input: []byte("{{ '''+AB/''' }}"), - expected: []Item{tBinaryStart, clobLong("+AB/"), tBinaryEnd, tEOF}, - }, - { - name: "multiple long clobs", - input: []byte("{{ '''A\\nB'''\n'''foo''' }}"), - expected: []Item{tBinaryStart, clobLong("A\\nB"), clobLong("foo"), tBinaryEnd, tEOF}, - }, - { - name: "quotes withing a long clob", - input: []byte("{{ ''' ' '' ''' }}"), - expected: []Item{tBinaryStart, clobLong(" ' '' "), tBinaryEnd, tEOF}, - }, - - // Containers - - { - name: "symbol to empty list", - input: []byte("abc\t:[]"), - expected: []Item{symbol("abc"), tColon, tListStart, tListEnd, tEOF}, - }, - { - name: "list of things", - input: []byte("[a, 1, ' ', {}, () /* comment */ ]"), - expected: []Item{ - tListStart, symbol("a"), tComma, - integer("1"), tComma, - quotedSym(" "), tComma, - tStructStart, tStructEnd, tComma, - tSExpStart, tSExpEnd, - blockComment(" comment "), - tListEnd, tEOF}, - }, - { - name: "symbol to empty struct", - input: []byte("abc:\t{ // comment\n}"), - expected: []Item{symbol("abc"), tColon, tStructStart, lineComment(" comment"), tStructEnd, tEOF}, - }, - { - name: "struct of things", - input: []byte("{'a' : 1 , s:'', 'st': {}, '''lngstr''': 1,\nlst:[],\"sexp\":()}"), - expected: []Item{tStructStart, - quotedSym("a"), tColon, integer("1"), tComma, - symbol("s"), tColon, quotedSym(""), tComma, - quotedSym("st"), tColon, tStructStart, tStructEnd, tComma, - longString("lngstr"), tColon, integer("1"), tComma, - symbol("lst"), tColon, tListStart, tListEnd, tComma, - doubleQuote("sexp"), tColon, tSExpStart, tSExpEnd, - tStructEnd, tEOF}, - }, - { - name: "symbol to empty s-expression", - input: []byte("abc:\r\n()"), - expected: []Item{symbol("abc"), tColon, tSExpStart, tSExpEnd, tEOF}, - }, - { - name: "s-expression of things", - input: []byte("(a+b/c--( j * k))"), - expected: []Item{tSExpStart, - symbol("a"), operator("+"), symbol("b"), operator("/"), symbol("c"), operator("--"), - tSExpStart, symbol("j"), operator("*"), symbol("k"), tSExpEnd, - tSExpEnd, tEOF}, - }, - - // Error cases - - { - name: "invalid start", - input: []byte(" 世界"), - expected: []Item{err("invalid start of a value: U+4E16 '世'")}, - }, - { - name: "invalid symbol value", - input: []byte("a:世界"), - expected: []Item{symbol("a"), tColon, err("invalid start of a value: U+4E16 '世'")}, - }, - { - name: "unterminated block comment", - input: []byte("/* "), - expected: []Item{err("unexpected end of file while lexing block comment")}, - }, - { - name: "rune error in line comment", - input: []byte("// \uFFFD"), - expected: []Item{err("error parsing rune")}, - }, - { - name: "rune error in block comment", - input: []byte("/* \uFFFD */"), - expected: []Item{err("error parsing rune")}, - }, - { - name: "double struct end", - input: []byte("{} a }"), - expected: []Item{tStructStart, tStructEnd, symbol("a"), err("unexpected closing of container")}, - }, - { - name: "double list end", - input: []byte("[] a ]"), - expected: []Item{tListStart, tListEnd, symbol("a"), err("unexpected closing of container")}, - }, - { - name: "double sexp end", - input: []byte("() a )"), - expected: []Item{tSExpStart, tSExpEnd, symbol("a"), err("unexpected closing of container")}, - }, - { - name: "mismatch: struct list", - input: []byte("{]"), - expected: []Item{tStructStart, err("expected closing of struct but found ]")}, - }, - { - name: "mismatch: list sexp", - input: []byte("[)"), - expected: []Item{tListStart, err("expected closing of list but found )")}, - }, - { - name: "mismatch: sexp struct", - input: []byte("(}"), - expected: []Item{tSExpStart, err("expected closing of s-expression but found }")}, - }, - { - name: "invalid escaped char in long string", - input: []byte("'''\\c'''"), - expected: []Item{err("invalid character after escape: U+0063 'c'")}, - }, - { - name: "invalid escaped char in short string", - input: []byte("\"\\c\""), - expected: []Item{err("invalid character after escape: U+0063 'c'")}, - }, - { - name: "invalid escaped char in quoted symbol", - input: []byte("'\\c'"), - expected: []Item{err("invalid character after escape: U+0063 'c'")}, - }, - { - name: "unterminated long string", - input: []byte("'''"), - expected: []Item{err("unterminated long string")}, - }, - { - name: "unterminated string", - input: []byte(`"`), - expected: []Item{err("unterminated quoted string")}, - }, - { - name: "unterminated quoted symbol", - input: []byte(`'`), - expected: []Item{err("unterminated quoted symbol")}, - }, - { - name: "escaping EOF in string", - input: []byte(`"a\`), - expected: []Item{err("unterminated sequence")}, - }, - { - name: "escaping EOF in quoted symbol", - input: []byte(`'a\`), - expected: []Item{err("unterminated sequence")}, - }, - { - name: "escaping a non-hex character for hex escape", - input: []byte(`'\xAG'`), - expected: []Item{err("invalid character as part of hex escape: U+0047 'G'")}, - }, - { - name: "escaping a non-hex character for unicode escape", - input: []byte(`'\u000G'`), - expected: []Item{err("invalid character as part of unicode escape: U+0047 'G'")}, - }, - { - name: "invalid start for a \\U escape", - input: []byte(`'\U1000'`), - expected: []Item{err("invalid character as part of unicode escape: U+0031 '1'")}, - }, - { - name: "invalid \\U000 escape", - input: []byte(`'\U000G'`), - expected: []Item{err("invalid character as part of unicode escape: U+0047 'G'")}, - }, - { - name: "invalid start for a \\U0010 escape", - input: []byte(`'\U001G'`), - expected: []Item{err("invalid character as part of unicode escape: U+0047 'G'")}, - }, - { - name: "invalid \\U escape", - input: []byte(`'\U0010G000'`), - expected: []Item{err("invalid character as part of unicode escape: U+0047 'G'")}, - }, - { - name: "invalid character in symbol", - input: []byte("null世int"), - expected: []Item{err("bad character as part of symbol: U+4E16 '世'")}, - }, - { - name: "invalid character in quoted symbol", - input: []byte("'null\u0007'"), - expected: []Item{err("bad character as part of quoted symbol: U+0007")}, - }, - { - name: "invalid character in quoted string", - input: []byte("\"null\u0007\""), - expected: []Item{err("bad character as part of string: U+0007")}, - }, - { - name: "invalid character in long string", - input: []byte("'''null\u0007'''"), - expected: []Item{err("bad character as part of long string: U+0007")}, - }, - { - name: "invalid character after long string", - input: []byte("'''null'''\u0007"), - expected: []Item{longString("null"), err("invalid start of a value: U+0007")}, - }, - { - name: "int with leading zeros", - input: []byte("007"), - expected: []Item{err("leading zeros are not allowed for decimal integers")}, - }, - { - name: "decimal with leading zeros", - input: []byte("03.4"), - expected: []Item{err("leading zeros are not allowed for decimals")}, - }, - { - name: "float with leading zeros", - input: []byte("03.4e0"), - expected: []Item{err("leading zeros are not allowed for floats")}, - }, - { - name: "decimal with trailing underscore", - input: []byte("123.456_"), - expected: []Item{err("numbers cannot end with an underscore")}, - }, - { - name: "hex designator followed by underscore", - input: []byte("0x_0"), - expected: []Item{err("underscore must not be at start of hex or binary number")}, - }, - { - name: "hex followed by underscore", - input: []byte("0x0_"), - expected: []Item{err("number span cannot end with an underscore")}, - }, - { - name: "underscore before period", - input: []byte("1_."), - expected: []Item{err("number span cannot end with an underscore")}, - }, - { - name: "underscore after period", - input: []byte("1._"), - expected: []Item{err("underscore may not follow a period")}, - }, - { - name: "underscore after negative sign binary", - input: []byte("-_0b1010"), - expected: []Item{err("underscore must not be after negative sign")}, - }, - { - name: "repeated underscores", - input: []byte("1__0"), - expected: []Item{err("number span cannot end with an underscore")}, - }, - { - name: "not a numeric stop character", - input: []byte("1a"), - expected: []Item{err("invalid numeric stop character: U+0061 'a'")}, - }, - { - name: "year 0000", - input: []byte("0000T"), - expected: []Item{err("year must be greater than zero")}, - }, - { - name: "year 0000 with month", - input: []byte("0000T-01"), - expected: []Item{err("year must be greater than zero")}, - }, - { - name: "month 20", - input: []byte("2019-20T"), - expected: []Item{err("invalid character as month part of timestamp: U+0032 '2'")}, - }, - { - name: "month 13", - input: []byte("2019-13T"), - expected: []Item{err("invalid month 13")}, - }, - { - name: "month 0", - input: []byte("2019-00T"), - expected: []Item{err("month must be greater than zero")}, - }, - { - name: "year and month must have a T", - input: []byte("2019-12 "), - expected: []Item{err("invalid character after month part of timestamp: U+0020 ' '")}, - }, - { - name: "not a numeric stop character after year and month", - input: []byte("2019-12Ta"), - expected: []Item{err("invalid timestamp stop character: U+0061 'a'")}, - }, - { - name: "day 40", - input: []byte("2019-12-40T"), - expected: []Item{err("invalid character as day part of timestamp: U+0034 '4'")}, - }, - { - name: "day 32", - input: []byte("2019-12-32T"), - expected: []Item{err("invalid day 32")}, - }, - { - name: "April 31", - input: []byte("2019-04-31T"), - expected: []Item{err("invalid day 31 for month 4")}, - }, - { - name: "day 0", - input: []byte("2019-12-00T"), - expected: []Item{err("day must be greater than zero")}, - }, - { - name: "not a numeric stop character after year month and day", - input: []byte("2019-12-30a"), - expected: []Item{err("invalid character after day part of timestamp: U+0061 'a'")}, - }, - { - name: "not a numeric character after year month and day", - input: []byte("2019-12-30Ta"), - expected: []Item{err("invalid character as hour/minute part of timestamp: U+0061 'a'")}, - }, - { - name: "hour 30", - input: []byte("2019-12-30T30:00Z"), - expected: []Item{err("invalid character as hour/minute part of timestamp: U+0033 '3'")}, - }, - { - name: "hour 24", - input: []byte("2019-12-30T24:00Z"), - expected: []Item{err("invalid hour 24")}, - }, - { - name: "minute 60", - input: []byte("2019-12-30T12:60Z"), - expected: []Item{err("invalid character as hour/minute part of timestamp: U+0036 '6'")}, - }, - { - name: "second 60", - input: []byte("2019-12-30T12:34:60Z"), - expected: []Item{err("invalid character as seconds part of timestamp: U+0036 '6'")}, - }, - { - name: "no fractional seconds", - input: []byte("2019-12-30T12:34:00.Z"), - expected: []Item{err("missing fractional seconds value")}, - }, - { - name: "timezone offset hour 30", - input: []byte("2019-12-30T12:34:59+30:00"), - expected: []Item{err("invalid character as hour/minute part of timezone: U+0033 '3'")}, - }, - { - name: "timezone offset hour 24", - input: []byte("2019-12-30T12:34:59+24:00"), - expected: []Item{err("invalid hour offset 24")}, - }, - { - name: "timezone offset minute 60", - input: []byte("2019-12-30T12:34:59+10:60"), - expected: []Item{err("invalid character as hour/minute part of timezone: U+0036 '6'")}, - }, - { - name: "invalid timezone offset", - input: []byte("2019-12-30T12:34:59a"), - expected: []Item{err("invalid character as timezone part of timestamp: U+0061 'a'")}, - }, - { - name: "invalid character after timezone offset", - input: []byte("2019-12-30T12:34:59Za"), - expected: []Item{err("invalid timestamp stop character: U+0061 'a'")}, - }, - { - name: "unterminated blob", - input: []byte("{{abcd"), - expected: []Item{tBinaryStart, err("unterminated blob")}, - }, - { - name: "blob with only one ending brace", - input: []byte("{{abcd}a"), - expected: []Item{tBinaryStart, err("invalid end to blob, expected } but found: U+0061 'a'")}, - }, - { - name: "blob with padding in the middle", - input: []byte("{{ab=cd}}"), - expected: []Item{tBinaryStart, err("base64 character found after padding character")}, - }, - { - name: "invalid blob character", - input: []byte("{{abc.}}"), - expected: []Item{tBinaryStart, err("invalid rune as part of blob string: U+002E '.'")}, - }, - { - name: "invalid base64 encoding", - input: []byte("{{ab=}}"), - expected: []Item{tBinaryStart, err("invalid base64 encoding")}, - }, - { - name: "unterminated short clob escaping EOF", - input: []byte(`{{ "ab\`), - expected: []Item{tBinaryStart, err("unterminated short clob")}, - }, - { - name: "clob escaping c", - input: []byte(`{{ "ab\c" }}`), - expected: []Item{tBinaryStart, err("invalid character after escape: U+0063 'c'")}, - }, - { - name: "clob invalid hex escape", - input: []byte(`{{ "ab\x0g" }}`), - expected: []Item{tBinaryStart, err("invalid character as part of hex escape: U+0067 'g'")}, - }, - { - name: "clob unicode escape", - input: []byte(`{{ "ab\u0067" }}`), - expected: []Item{tBinaryStart, err("unicode escape is not valid in clob")}, - }, - { - name: "unterminated short clob no closing quote", - input: []byte(`{{ "ab`), - expected: []Item{tBinaryStart, err("unterminated short clob")}, - }, - { - name: "unterminated short clob no closing brace", - input: []byte(`{{ "ab" a`), - expected: []Item{tBinaryStart, clobShort("ab"), err("invalid end to short clob, expected } but found: U+0061 'a'")}, - }, - { - name: "unterminated short clob only one closing brace", - input: []byte(`{{ "ab" }a`), - expected: []Item{tBinaryStart, clobShort("ab"), err("invalid end to short clob, expected second } but found: U+0061 'a'")}, - }, - { - name: "invalid short clob text", - input: []byte("{{ \"ab\u0007\" }}"), - expected: []Item{tBinaryStart, err("invalid rune as part of short clob string: U+0007")}, - }, - { - name: "unterminated long clob no closing brace", - input: []byte(`{{ '''ab''' a`), - expected: []Item{tBinaryStart, clobLong("ab"), err("expected end of a Clob or start of a long string but found: U+0061 'a'")}, - }, - { - name: "unterminated long clob only one closing brace", - input: []byte(`{{ '''ab''' }a`), - expected: []Item{tBinaryStart, clobLong("ab"), err("expected a second } but found: U+0061 'a'")}, - }, - { - name: "unterminated long clob escaping EOF", - input: []byte(`{{ '''ab\`), - expected: []Item{tBinaryStart, err("unterminated long clob")}, - }, - { - name: "unterminated long clob no closing quotes", - input: []byte(`{{ '''ab`), - expected: []Item{tBinaryStart, err("unterminated long clob")}, - }, - { - name: "invalid long clob text", - input: []byte("{{ '''ab\u0007''' }}"), - expected: []Item{tBinaryStart, err("invalid rune as part of long clob string: U+0007")}, - }, - } - for _, tst := range tests { - test := tst - t.Run(test.name, func(t *testing.T) { - out := runLexer(test.input) - // Only focusing on the type and value for these tests. - if diff := cmp.Diff(test.expected, out, cmpopts.EquateEmpty(), cmpopts.IgnoreFields(Item{}, "Pos")); diff != "" { - t.Log("Expected:", test.expected) - t.Log("Found: ", out) - t.Error("(-expected, +found)", diff) - } - }) - } -} - -// Gather the items emitted from the Lexer into a slice. -func runLexer(input []byte) []Item { - x := New(input) - var items []Item - for { - item := x.NextItem() - items = append(items, item) - if item.Type == IonEOF || item.Type == IonError { - break - } - } - return items -} diff --git a/internal/lex/lexitem.go b/internal/lex/lexitem.go deleted file mode 100644 index 061225a8..00000000 --- a/internal/lex/lexitem.go +++ /dev/null @@ -1,138 +0,0 @@ -/* - * Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"). - * You may not use this file except in compliance with the License. - * A copy of the License is located at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * or in the "license" file accompanying this file. This file is distributed - * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either - * express or implied. See the License for the specific language governing - * permissions and limitations under the License. - */ - -package lex - -import ( - "fmt" - "strconv" -) - -// A token returned from the Lexer. -type Item struct { - Type itemType // The type of this Item. - Pos int // The starting position, in bytes, of this Item in the input. - Val []byte // The value of this Item. -} - -// String satisfies Stringer. -func (i Item) String() string { - _, typeKnown := itemTypeMap[i.Type] - switch { - case i.Type == IonEOF: - return "EOF" - case i.Type == IonIllegal: - return string(i.Val) - case len(i.Val) > 75: - return fmt.Sprintf("<%.75s>...", i.Val) - case !typeKnown: - return fmt.Sprintf("%s <%s>", i.Type, i.Val) - } - // We use the '<' and '>' characters because both single and double quotes are - // used extensively as are brackets, braces, and parens. - return fmt.Sprintf("<%s>", i.Val) -} - -const ( - IonIllegal itemType = iota - IonError - IonEOF - - IonBlob - IonClobLong - IonClobShort - IonCommentBlock - IonCommentLine - IonDecimal - IonFloat - IonInfinity - IonInt - IonIntBinary - IonIntHex - IonList - IonNull - IonSExp - IonString - IonStringLong - IonStruct - IonSymbol - IonSymbolQuoted - IonTimestamp - - IonBinaryStart // {{ - IonBinaryEnd // }} - IonColon // : - IonDoubleColon // :: - IonComma // , - IonDot // . - IonOperator // One of !#%&*+\\-/;<=>?@^`|~ - IonStructStart // { - IonStructEnd // } - IonListStart // [ - IonListEnd // ] - IonSExpStart // ( - IonSExpEnd // ) -) - -// Type of the lex Item. -type itemType int - -var itemTypeMap = map[itemType]string{ - IonIllegal: "Illegal", - IonError: "Error", - IonEOF: "EOF", - IonNull: "Null", - - IonBlob: "Blob", - IonClobLong: "ClobLong", - IonClobShort: "ClobShort", - IonCommentBlock: "BlockComment", - IonCommentLine: "LineComment", - IonDecimal: "Decimal", - IonInfinity: "Infinity", - IonInt: "Int", - IonIntBinary: "BinaryInt", - IonIntHex: "HexInt", - IonFloat: "Float", - IonList: "List", - IonSExp: "SExp", - IonString: "String", - IonStringLong: "LongString", - IonStruct: "Struct", - IonSymbol: "Symbol", - IonSymbolQuoted: "QuotedSymbol", - IonTimestamp: "Timestamp", - - IonBinaryStart: "{{", - IonBinaryEnd: "}}", - IonColon: ":", - IonDoubleColon: "::", - IonComma: ",", - IonDot: ".", - IonOperator: "Operator", - IonStructStart: "{", - IonStructEnd: "}", - IonListStart: "[", - IonListEnd: "]", - IonSExpStart: "(", - IonSExpEnd: ")", -} - -func (i itemType) String() string { - if s, ok := itemTypeMap[i]; ok { - return s - } - return "Unknown itemType " + strconv.Itoa(int(i)) -} diff --git a/internal/lex/lexitem_test.go b/internal/lex/lexitem_test.go deleted file mode 100644 index cef7dd03..00000000 --- a/internal/lex/lexitem_test.go +++ /dev/null @@ -1,58 +0,0 @@ -/* - * Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"). - * You may not use this file except in compliance with the License. - * A copy of the License is located at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * or in the "license" file accompanying this file. This file is distributed - * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either - * express or implied. See the License for the specific language governing - * permissions and limitations under the License. - */ - -package lex - -import ( - "testing" - - "github.com/google/go-cmp/cmp" -) - -func TestLexItem_String(t *testing.T) { - tests := []struct { - item Item - expected string - }{ - {}, - {item: Item{Type: IonEOF}, expected: "EOF"}, - {item: Item{Type: itemType(100)}, expected: "Unknown itemType 100 <>"}, - {item: Item{Type: IonIllegal, Val: []byte("illegal")}, expected: "illegal"}, - {item: Item{Type: IonCommentLine, Val: []byte("comment")}, expected: ``}, - { - item: Item{ - Type: IonString, - Val: []byte("12345678901234567890123456789012345678901234567890123456789012345678901234567890"), - }, - expected: `<123456789012345678901234567890123456789012345678901234567890123456789012345>...`, - }, - { - item: Item{ - Type: IonString, - Val: []byte("123456789012345678901234567890123456789012345678901234567890123456789012345"), - }, - expected: `<123456789012345678901234567890123456789012345678901234567890123456789012345>`, - }, - } - - for _, tst := range tests { - test := tst - t.Run(test.expected, func(t *testing.T) { - if diff := cmp.Diff(test.expected, test.item.String()); diff != "" { - t.Error("(-expected, +found)", diff) - } - }) - } -} diff --git a/ion/cmp_test.go b/ion/cmp_test.go deleted file mode 100644 index c0eda81a..00000000 --- a/ion/cmp_test.go +++ /dev/null @@ -1,253 +0,0 @@ -/* - * Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"). - * You may not use this file except in compliance with the License. - * A copy of the License is located at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * or in the "license" file accompanying this file. This file is distributed - * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either - * express or implied. See the License for the specific language governing - * permissions and limitations under the License. - */ - -package ion - -import ( - "bytes" - "fmt" - "testing" - - "github.com/google/go-cmp/cmp" -) - -// This file contains all of the specialized comparison functions we use for tests. - -func assertEquivalentValues(values []Value, t *testing.T) { - t.Helper() - - for i, value := range values { - next := values[i+1] - if diff := cmpValueResults(value, next); diff != "" { - t.Logf("value %d: %#v", i, value) - t.Logf("value %d: %#v", i+1, next) - t.Error("values", i, "and", i+1, ": (-expected, +found)", diff) - } - if i >= len(values)-2 { - break - } - } -} - -func assertNonEquivalentValues(values []Value, t *testing.T) { - t.Helper() - - for i, value := range values { - for n := i + 1; n < len(values); n++ { - next := values[n] - fmt.Println("comparing values", i, "and", n) - if diff := cmpValueResults(value, next); diff == "" { - t.Logf("value %d: %#v", i, value) - t.Logf("value %d: %#v", n, next) - t.Error("values", i, "and", n, "are equivalent") - } - } - - if i >= len(values)-2 { - break - } - } -} - -// cmpDigests compares nil-ness of the Digests, then calls cmpValueSlices on the two -// if they are not nil. -func cmpDigests(expected, found *Digest) string { - if (expected == nil) != (found == nil) { - return fmt.Sprintf("nil mis-match: expected is %v and found is %v", expected, found) - } - if expected == nil { - return "" - } - - return cmpValueSlices(expected.values, found.values) -} - -// cmpValueSlices compares the results of the calls to Binary(), Text(), and Value() of each -// element of the given Value slices. -func cmpValueSlices(expected, found []Value) string { - if len(expected) != len(found) { - return fmt.Sprintf("length mis-match: expected number of values to be %v, but found %v", len(expected), len(found)) - } - - for i, exp := range expected { - fnd := found[i] - if diff := cmp.Diff(exp.Type(), fnd.Type()); diff != "" { - return diff - } - - if exp.IsNull() != fnd.IsNull() { - return fmt.Sprintf("item %d expected IsNull %v but found %v", i, exp.IsNull(), fnd.IsNull()) - } - - if diff := cmp.Diff(exp.Binary(), fnd.Binary()); diff != "" { - return fmt.Sprintf("item %d of Type %s Binary() %s", i, fnd.Type(), diff) - } - - if diff := cmp.Diff(string(exp.Text()), string(fnd.Text())); diff != "" { - return fmt.Sprintf("item %d of type %s Text() %s", i, fnd.Type(), diff) - } - - if diff := cmpAnnotations(exp.Annotations(), fnd.Annotations()); diff != "" { - return diff - } - - if diff := cmpValueResults(exp, fnd); diff != "" { - return fmt.Sprintf("item %d of type %s: %s", i, exp.Type(), diff) - } - } - return "" -} - -// cmpValueResults compares the results of calling Value() on the two given Values. -// If there is a difference, then that difference is returned. Otherwise the empty -// string is returned. -func cmpValueResults(expected, found Value) string { - if expected.IsNull() != found.IsNull() { - return fmt.Sprintf("expected is null %v and found is null %v", expected.IsNull(), found.IsNull()) - } - - if expected.Type() != found.Type() { - return fmt.Sprintf("expected type is %s and found type is %s", expected.Type(), found.Type()) - } - - if diff := cmpAnnotations(expected.Annotations(), found.Annotations()); diff != "" { - return diff - } - - switch expected.(type) { - case Blob: - return cmp.Diff(string(expected.(Blob).Value()), string(found.(Blob).Value())) - case Bool: - return cmp.Diff(expected.(Bool).Value(), found.(Bool).Value()) - case Clob: - return cmp.Diff(expected.(Clob).Value(), found.(Clob).Value()) - case Decimal: - return cmpDecimals(expected.(Decimal), found.(Decimal)) - case Float: - return cmp.Diff(expected.(Float).Value(), found.(Float).Value()) - case Int: - return cmpInts(expected.(Int), found.(Int)) - case List: - return cmpValueSlices(expected.(List).values, found.(List).values) - case SExp: - return cmpValueSlices(expected.(SExp).values, found.(SExp).values) - case String: - return cmp.Diff(expected.(String).Value(), found.(String).Value()) - case Struct: - return cmpStructFields(expected.(Struct).fields, found.(Struct).fields) - case Symbol: - return cmp.Diff(expected.(Symbol).Value(), found.(Symbol).Value()) - case Timestamp: - return cmpTimestamps(expected.(Timestamp), found.(Timestamp)) - } - - return "" -} - -func cmpAnnotations(expected, found []Symbol) string { - if len(expected) != len(found) { - return fmt.Sprintf("length mis-match: expected annotation length is %v and found is %v", len(expected), len(found)) - } - - for i, exp := range expected { - fnd := found[i] - if exp.id != fnd.id { - return fmt.Sprintf("expected annotation at index %d to have id %d but found %d", i, exp.id, fnd.id) - } - // TODO: Support symbol tables. - expText, fndText := exp.Text(), fnd.Text() - if bytes.HasPrefix(expText, []byte{'$'}) || bytes.HasPrefix(fndText, []byte{'$'}) { - continue - } - if diff := cmp.Diff(expText, fndText); diff != "" { - return fmt.Sprintf("expected annotation at index %d to have text %q but found %q", i, expText, fndText) - } - } - - return "" -} - -func cmpDecimals(expected, found Decimal) string { - expVal, fndVal := expected.Value(), found.Value() - if !expVal.Equal(fndVal) { - return fmt.Sprintf("value differs: %q %q", expVal, fndVal) - } - - // TODO: Do a comparison that tracks precision. - - return "" -} - -// cmpErrs calls cmp.Diff on a string representation of the two errors. -func cmpErrs(expected, found error) string { - expectedStr := "nil" - if expected != nil { - expectedStr = expected.Error() - } - foundStr := "nil" - if found != nil { - foundStr = found.Error() - } - return cmp.Diff(expectedStr, foundStr) -} - -func cmpInts(expected, found Int) string { - exp, fnd := expected.Value(), found.Value() - if (exp == nil) != (fnd == nil) { - return fmt.Sprintf("nil Value() mis-match for Int: expected is %v and found is %v", exp, fnd) - } - if exp == nil { - return "" - } - - if exp.Cmp(fnd) != 0 { - return fmt.Sprintf("int values differ: %q %q", exp.String(), fnd.String()) - } - - return "" -} - -func cmpStructFields(expected, found []StructField) string { - if len(expected) != len(found) { - return fmt.Sprintf("length mis-match: expected struct field length is %v and found is %v", len(expected), len(found)) - } - - for i, exp := range expected { - fnd := found[i] - if exp.Symbol.id != fnd.Symbol.id { - return fmt.Sprintf("field %d: expected symbolID %d but found %d", i, exp.Symbol.id, fnd.Symbol.id) - } - expText, fndText := exp.Symbol.Text(), fnd.Symbol.Text() - if diff := cmp.Diff(expText, fndText); diff != "" { - return fmt.Sprintf("field %d: diff %s", i, diff) - } - if diff := cmpValueSlices([]Value{exp.Value}, []Value{fnd.Value}); diff != "" { - return diff - } - } - - return "" -} - -func cmpTimestamps(expected, found Timestamp) string { - exp, fnd := expected.Value(), found.Value() - if diff := exp.Sub(fnd); diff != 0 { - return fmt.Sprintf("timestamps differ by %v: %q %q", diff, exp, fnd) - } - if diff := cmp.Diff(expected.Precision(), found.Precision()); diff != "" { - return diff - } - return "" -} diff --git a/ion/doc.go b/ion/doc.go deleted file mode 100644 index 5f476ace..00000000 --- a/ion/doc.go +++ /dev/null @@ -1,29 +0,0 @@ -/* - * Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"). - * You may not use this file except in compliance with the License. - * A copy of the License is located at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * or in the "license" file accompanying this file. This file is distributed - * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either - * express or implied. See the License for the specific language governing - * permissions and limitations under the License. - */ - -// ion is a data format that is comprised of three parts: -// * A set of data types -// * A textual notation for values of those types -// * A binary notation for values of those types -// -// There are many considerations that go into an Ion implementation -// that expand past those basic representations. This includes but -// is not limited to a customizable Symbol Catalog to aid in efficient -// binary decoding and a System Symbol Catalog for symbols defined in -// the specification. -// -// More information can be found from these links -// * http://amzn.github.io/ion-docs/docs/spec.html -package ion diff --git a/ion/parse_binary.go b/ion/parse_binary.go deleted file mode 100644 index 737efcc8..00000000 --- a/ion/parse_binary.go +++ /dev/null @@ -1,380 +0,0 @@ -/* - * Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"). - * You may not use this file except in compliance with the License. - * A copy of the License is located at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * or in the "license" file accompanying this file. This file is distributed - * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either - * express or implied. See the License for the specific language governing - * permissions and limitations under the License. - */ - -package ion - -import ( - "bytes" - "io" - "time" - - "github.com/pkg/errors" -) - -const ( - binaryTypePadding = 0 - binaryTypeBool = 1 - binaryTypeInt = 2 - binaryTypeNegInt = 3 - binaryTypeFloat = 4 - binaryTypeDecimal = 5 - binaryTypeTimestamp = 6 - binaryTypeSymbol = 7 - binaryTypeString = 8 - binaryTypeClob = 9 - binaryTypeBlob = 0xA - binaryTypeList = 0xB - binaryTypeSExp = 0xC - binaryTypeStruct = 0xD - binaryTypeAnnotation = 0xE -) - -var ( - // From http://amzn.github.io/ion-docs/docs/binary.html#value-streams : - // The only valid BVM, identifying Ion 1.0, is 0xE0 0x01 0x00 0xEA - ion10BVM = []byte{0xE0, 0x01, 0x00, 0xEA} - - // Map of the type portion (high nibble) of a value header byte to - // the corresponding Ion Type. Padding itself does not have a type mapping, - // but it shares a high nibble with Null. That means that the type - // is determined by the low nibble when the high nibble is 0. - binaryTypeMap = map[byte]Type{ - binaryTypePadding: TypeNull, - binaryTypeBool: TypeBool, - binaryTypeInt: TypeInt, - binaryTypeNegInt: TypeInt, - binaryTypeFloat: TypeFloat, - binaryTypeDecimal: TypeDecimal, - binaryTypeTimestamp: TypeTimestamp, - binaryTypeSymbol: TypeSymbol, - binaryTypeString: TypeString, - binaryTypeClob: TypeClob, - binaryTypeBlob: TypeBlob, - binaryTypeList: TypeList, - binaryTypeSExp: TypeSExp, - binaryTypeStruct: TypeStruct, - binaryTypeAnnotation: TypeAnnotation, - } -) - -// parseBinaryBlob decodes a single blob of bytes into a Digest. -func parseBinaryBlob(blob []byte) (*Digest, error) { - // Read a single item from the binary stream. Timeout after - // five seconds because it shouldn't take that long to parse - // a blob that is already loaded into memory. - ch := parseBinaryStream(bytes.NewReader(blob)) - select { - case out := <-ch: - return out.Digest, out.Error - case <-time.After(5 * time.Second): - return nil, errors.New("timed out waiting for parser to finish") - } -} - -type streamItem struct { - Digest *Digest - Error error -} - -// parseBinaryStream reads from the given reader until either an error -// is encountered or the reader returns (0, io.EOF), at which point -// the returned channel is closed. If an error occurred then the error -// is sent on the channel before it is closed. Reading from the stream -// is not buffered. -func parseBinaryStream(r io.Reader) <-chan streamItem { - itemChannel := make(chan streamItem) - - go func() { - // First four bytes of the stream must be the version marker. - if err := verifyByteVersionMarker(r); err != nil { - itemChannel <- streamItem{Error: err} - return - } - - var values []Value - for { - switch val, err := parseNextBinaryValue(nil, r); { - case err == io.EOF: - itemChannel <- streamItem{Digest: &Digest{values: values}} - // Signal that there isn't any more data coming. - close(itemChannel) - return - case err != nil: - itemChannel <- streamItem{Error: err} - return - case val != nil: - values = append(values, val) - default: - itemChannel <- streamItem{Digest: &Digest{values: values}} - values = nil - } - } - }() - - return itemChannel -} - -// parseNextBinaryValue parses the next binary value from the stream. It returns -// io.EOF as the error if the first read shows that the end of the stream has been -// reached. It returns a nil value and nil error if a new ByteVersionMarker has -// been reached. ann is an optional list of annotations to associate with the next -// value that is parsed. -func parseNextBinaryValue(ann []Symbol, r io.Reader) (Value, error) { - switch high, low, err := readNibblesHighAndLow(r); { - case err != nil: - return nil, err - case low == 0xF: - return parseBinaryNull(high) - case high == binaryTypePadding: - return parseBinaryPadding(low, r) - case high == binaryTypeBool: - return parseBinaryBool(ann, low) - case high == binaryTypeInt || high == binaryTypeNegInt: - // 2 = positive integer, 3 = negative integer. - return parseBinaryInt(ann, high == binaryTypeNegInt, low, r) - case high == binaryTypeFloat: - return parseBinaryFloat(ann, low, r) - case high == binaryTypeDecimal: - return parseBinaryDecimal(ann, low, r) - case high == binaryTypeTimestamp: - return parseBinaryTimestamp(ann, low, r) - case high == binaryTypeSymbol: - return parseBinarySymbol(ann, low, r) - case high == binaryTypeString: - return parseBinaryString(ann, low, r) - case high == binaryTypeBlob || high == binaryTypeClob: - return parseBinaryBytes(ann, high, low, r) - case high == binaryTypeList || high == binaryTypeSExp: - return parseBinaryList(ann, high, low, r) - case high == binaryTypeStruct: - return parseBinaryStruct(ann, low, r) - case high == binaryTypeAnnotation: - if len(ann) != 0 { - return nil, errors.New("nesting annotations is not legal") - } - return parseBinaryAnnotation(low, r) - default: - return nil, errors.Errorf("invalid header combination - high: %d low: %d", high, low) - } -} - -// parseBinaryVersionMarker verifies that what is read next is a valid BVM. If it is -// then a nil Value and error are returned. It is assumed that the first byte has already -// been read and that it's value is 0xE0. -func parseBinaryVersionMarker(r io.Reader) (Value, error) { - numBytes := len(ion10BVM) - bvm := make([]byte, numBytes) - bvm[0] = ion10BVM[0] - if n, err := r.Read(bvm[1:]); err != nil || n != numBytes-1 { - return nil, errors.Errorf("unable to read binary version marker - read %d bytes of %d with err: %v", n, numBytes-1, err) - } - - if err := verifyByteVersionMarker(bytes.NewReader(bvm)); err != nil { - return nil, err - } - - return nil, nil -} - -// verifyByteVersionMarker reads the BVM from the stream and ensures that it matches -// what is expected for Ion 1.0. -func verifyByteVersionMarker(r io.Reader) error { - buf := make([]byte, 4) - // First four bytes must be the version marker. - if n, err := r.Read(buf); err != nil || n != 4 { - return errors.Errorf("read %d bytes of binary version marker with err: %v", n, err) - } - if bytes.Compare(buf, ion10BVM) != 0 { - return errors.Errorf("invalid binary version marker: %0 #x", buf) - } - return nil -} - -// determineLength16 takes in the length nibble from the header byte and determines -// whether or not there is a Length portion to the value. If there is, it then -// reads the length portion. -func determineLength16(lengthByte byte, r io.Reader) (uint16, error) { - // "If the representation is at least 14 bytes long, then L is set to 14, and - // the length field is set to the representation length, encoded as a VarUInt field." - if lengthByte != 14 { - return uint16(lengthByte), nil - } - return readVarUInt16(r) -} - -// determineLength32 takes in the length nibble from the header byte and determines -// whether or not there is a Length portion to the value. If there is, it then -// reads the length portion. -func determineLength32(lengthByte byte, r io.Reader) (uint32, error) { - // "If the representation is at least 14 bytes long, then L is set to 14, and - // the length field is set to the representation length, encoded as a VarUInt field." - if lengthByte != 14 { - return uint32(lengthByte), nil - } - return readVarUInt32(r) -} - -// readVarUInt8 reads a variable-length number, but assumes that variable-length number is -// only one byte. It then converts that byte into an uint8 for return. -func readVarUInt8(r io.Reader) (uint8, error) { - bits, err := readVarNumber(1, r) - if err != nil { - return 0, err - } - - // Ignore the stop bit. - return bits[0] & 0x7F, nil -} - -// readVarUInt16 reads until the high bit is set, which signals the end of the -// variable-length number, or we hit the maximum number of bytes for uint16. -// We compress the number it into a uint16. When used to express a length in -// bytes, the max value would signal a size of 63KB. -func readVarUInt16(r io.Reader) (uint16, error) { - // Since we are being given seven bits per byte, we can fit 2 1/4 bytes - // of input into our two bytes of value, so don't read more than 3 bytes. - bits, err := readVarNumber(3, r) - if err != nil { - return 0, err - } - - // 0xFC == 0b1111_1100. - if (len(bits) == 3) && (bits[2]&0xFC != 0) { - return 0, errors.Errorf("number is too big to fit into uint16: % #x", bits) - } - - var ret uint16 - // Compact all of the bits into a uint16, ignoring the stop bit. - // Turn [0111 0001] [1110 0001] into [0011 1000] [1110 0001]. - for _, b := range bits { - ret <<= 7 - ret |= uint16(b & 0x7F) - } - - return ret, nil -} - -// readVarUInt32 reads until the high bit is set, which signals the end of the -// variable-length number, or we hit the maximum number of bytes for uint32. -// We compress the number into a uint32. When used to express a length in -// bytes, the max value would signal a size of 3GB. -func readVarUInt32(r io.Reader) (uint32, error) { - // Since we are being given seven bits per byte, we can fit 4 1/2 bytes - // of input into our four bytes of value, so don't read more than 5 bytes. - bits, err := readVarNumber(5, r) - if err != nil { - return 0, err - } - - // 0xF0 == 0b1111_0000. - if (len(bits) == 5) && (bits[4]&0xF0 != 0) { - return 0, errors.Errorf("number is too big to fit into uint32: % #x", bits) - } - - var ret uint32 - // Compact all of the bits into a uint32, ignoring the stop bit. - // Turn [0111 1111] [1110 1111] into [0011 1111] [1110 1111]. - for _, b := range bits { - ret <<= 7 - ret |= uint32(b & 0x7F) - } - - return ret, nil -} - -// readVarInt64 reads until the high bit is set, which signals the end of the -// variable-length number, or we hit the maximum number of bytes for int64. -// We compress the number into an int64. -func readVarInt64(r io.Reader) (int64, error) { - // Since we are being given seven bits per byte, we can fit 9 1/8 bytes - // of input into our eight bytes of value, so don't read more than 10 bytes. - bits, err := readVarNumber(10, r) - if err != nil { - return 0, err - } - - // 0xFE == 0b1111_1110. - if (len(bits) == 10) && (bits[9]&0xFE != 0) { - return 0, errors.Errorf("number is too big to fit into int64: % #x", bits) - } - - var ret int64 - // Compact all of the bits into an int64, ignoring the stop bit. - // Turn [0111 1111] [1110 1111] into [0011 1111] [1110 1111]. - for i, b := range bits { - ret <<= 7 - // Need to ignore the sign bit. We add the sign later. - if i == 0 { - ret |= int64(b & 0x3F) - } else { - ret |= int64(b & 0x7F) - } - } - - // The second bit of the number is the sign bit. - if bits[0]&0x40 != 0 { - ret *= -1 - } - - return ret, nil -} - -// readVarNumber reads until the high bit is set, which signals the end of the -// variable-length number, or maxBytes is hit. If maxBytes is reached without -// the number being terminated, then an error is returned. The bits are not modified. -func readVarNumber(maxBytes uint16, r io.Reader) ([]byte, error) { - buf := make([]byte, 1) - var bits []byte - for { - if n, err := r.Read(buf); err != nil || n != 1 { - return nil, errors.Errorf("read %d bytes (wanted one) of number with err: %v", n, err) - } - bits = append(bits, buf[0]) - if (buf[0] & 0x80) != 0 { - break - } - if uint16(len(bits)) >= maxBytes { - return nil, errors.Errorf("number not terminated after %d bytes", maxBytes) - } - } - - return bits, nil -} - -// readNibblesHighAndLow reads one byte from the given reader then returns -// the high nibble and the low nibble of that byte. If read encounters the -// error io.EOF, then that error is returned. -func readNibblesHighAndLow(r io.Reader) (byte, byte, error) { - buf := make([]byte, 1) - if n, err := r.Read(buf); err != nil || n != 1 { - if err == io.EOF { - return 0, 0, err - } - return 0, 0, errors.Wrapf(err, "read %d bytes when wanted to read the one byte header", n) - } - return highNibble(buf[0]), lowNibble(buf[0]), nil -} - -// highNibble returns a byte representation of the high-order nibble -// (half a byte) of the given byte. -func highNibble(b byte) byte { - return (b >> 4) & 0x0F -} - -// lowNibble returns a byte representation of the low-order nibble -// (half a byte) of the given byte. -func lowNibble(b byte) byte { - return b & 0x0F -} diff --git a/ion/parse_binary_container.go b/ion/parse_binary_container.go deleted file mode 100644 index bfbab983..00000000 --- a/ion/parse_binary_container.go +++ /dev/null @@ -1,183 +0,0 @@ -/* - * Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"). - * You may not use this file except in compliance with the License. - * A copy of the License is located at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * or in the "license" file accompanying this file. This file is distributed - * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either - * express or implied. See the License for the specific language governing - * permissions and limitations under the License. - */ - -package ion - -import ( - "bytes" - "io" - - "github.com/pkg/errors" -) - -// This file contains binary parsers for List, SExp, Struct, and Annotation. - -// parseBinaryList attempts to read and parse the entirety of the list whether -// it be a List (high == binaryTypeList) or SExp (high == binaryTypeSExp). -func parseBinaryList(ann []Symbol, high byte, lengthByte byte, r io.Reader) (Value, error) { - if lengthByte == 0 && high == binaryTypeList { - return List{annotations: ann, values: []Value{}}, nil - } - if lengthByte == 0 && high == binaryTypeSExp { - return SExp{annotations: ann, values: []Value{}}, nil - } - - numBytes, errLength := determineLength32(lengthByte, r) - if errLength != nil { - return nil, errors.WithMessage(errLength, "unable to parse length of list") - } - - data := make([]byte, numBytes) - if n, err := r.Read(data); err != nil || n != int(numBytes) { - return nil, errors.Errorf("unable to read list - read %d bytes of %d with err: %v", n, numBytes, err) - } - - var values []Value - dataReader := bytes.NewReader(data) - for dataReader.Len() > 0 { - value, err := parseNextBinaryValue(nil, dataReader) - if err != nil { - return nil, errors.WithMessage(err, "unable to parse list") - } - values = append(values, value) - } - - if high == binaryTypeList { - return List{annotations: ann, values: values}, nil - } - return SExp{annotations: ann, values: values}, nil -} - -// parseBinaryStruct reads all of the symbol / value pairs and puts them -// into a Struct. -func parseBinaryStruct(ann []Symbol, lengthByte byte, r io.Reader) (Value, error) { - if lengthByte == 0 { - return Struct{annotations: ann, fields: []StructField{}}, nil - } - - var numBytes uint32 - var errLength error - // "When L is 1, the struct has at least one symbol/value pair, the length - // field exists, and the field name integers are sorted in increasing order." - if lengthByte == 1 { - numBytes, errLength = readVarUInt32(r) - } else { - numBytes, errLength = determineLength32(lengthByte, r) - } - - if errLength != nil { - return nil, errors.WithMessage(errLength, "unable to parse length of struct") - } - - data := make([]byte, numBytes) - if n, err := r.Read(data); err != nil || n != int(numBytes) { - return nil, errors.Errorf("unable to read struct - read %d bytes of %d with err: %v", n, numBytes, err) - } - - // Not having any fields isn't the same as being null, so differentiate - // between the two by ensuring that fields isn't nil even if it's empty. - fields := []StructField{} - dataReader := bytes.NewReader(data) - for dataReader.Len() > 0 { - symbol, errSymbol := readVarUInt32(dataReader) - if errSymbol != nil { - return nil, errors.WithMessage(errSymbol, "unable to read struct field symbol") - } - value, errValue := parseNextBinaryValue(nil, dataReader) - if errValue != nil { - return nil, errors.WithMessage(errValue, "unable to read struct field value") - } - - // Ignore padding. - if value.Type() == TypePadding { - continue - } - - fields = append(fields, StructField{ - Symbol: Symbol{id: int32(symbol)}, - Value: value, - }) - } - - return Struct{annotations: ann, fields: fields}, nil -} - -// parseBinaryAnnotation reads the annotation and the value that it is -// annotating. If the lengthByte is zero, then this is treated as the -// first byte of a Binary Version Marker. -func parseBinaryAnnotation(lengthByte byte, r io.Reader) (Value, error) { - // 0xE as the high byte has two potential uses, one for annotations and one for the - // start of the binary version marker. We are going to be optimistic and assume that - // 0xE0 is for the BVM and all other values for the low nibble is for annotations. - if lengthByte == 0 { - return parseBinaryVersionMarker(r) - } - - if lengthByte < 3 { - return nil, errors.Errorf("length must be at least 3 for an annotation wrapper, found %d", lengthByte) - } - - numBytes, errLength := determineLength32(lengthByte, r) - if errLength != nil { - return nil, errors.WithMessage(errLength, "unable to parse length of annotation") - } - - data := make([]byte, numBytes) - if n, err := r.Read(data); err != nil || n != int(numBytes) { - return nil, errors.Errorf("unable to read annotation - read %d bytes of %d with err: %v", n, numBytes, err) - } - - dataReader := bytes.NewReader(data) - annLen, errAnnLen := readVarUInt16(dataReader) - if errAnnLen != nil { - return nil, errors.WithMessage(errAnnLen, "unable to determine annotation symbol length") - } - - if annLen == 0 || uint32(annLen) >= numBytes { - return nil, errors.Errorf("invalid lengths for annotation - field length is %d while annotation symbols length is %d", numBytes, annLen) - } - - annData := make([]byte, annLen) - // We've already verified lengths and are basically performing a copy to - // a pre-allocated byte slice. There is no error to catch. - _, _ = dataReader.Read(annData) - - annReader := bytes.NewReader(annData) - var annotations []Symbol - for annReader.Len() > 0 { - symbol, errSymbol := readVarUInt32(annReader) - if errSymbol != nil { - return nil, errors.WithMessage(errSymbol, "unable to read annotation symbol") - } - annotations = append(annotations, Symbol{id: int32(symbol)}) - } - - // Since an annotation is a container for a single value there isn't a need to - // pre-read the contents so that we know when to stop. - value, errValue := parseNextBinaryValue(annotations, dataReader) - if errValue != nil { - return nil, errors.WithMessage(errValue, "unable to read annotation value") - } - - if dataReader.Len() > 0 { - return nil, errors.Errorf("annotation declared %d bytes but there are %d bytes left", numBytes, dataReader.Len()) - } - - if _, ok := value.(padding); ok { - return nil, errors.New("annotation on padding is not legal") - } - - return value, nil -} diff --git a/ion/parse_binary_numeric.go b/ion/parse_binary_numeric.go deleted file mode 100644 index 060d01fe..00000000 --- a/ion/parse_binary_numeric.go +++ /dev/null @@ -1,337 +0,0 @@ -/* - * Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"). - * You may not use this file except in compliance with the License. - * A copy of the License is located at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * or in the "license" file accompanying this file. This file is distributed - * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either - * express or implied. See the License for the specific language governing - * permissions and limitations under the License. - */ - -package ion - -import ( - "bytes" - "io" - "math" - "math/big" - "time" - - "github.com/pkg/errors" -) - -// This file contains binary parsers for Int, Float, Decimal, and Timestamp. - -// parseBinaryInt parses the magnitude and optional length portion of an an Int. -// The magnitude is a UInt so we need to be told what the sign is. -func parseBinaryInt(ann []Symbol, isNegative bool, lengthByte byte, r io.Reader) (Value, error) { - if lengthByte == 0 { - if isNegative { - return nil, errors.New("negative zero is invalid") - } - return Int{annotations: ann, isSet: true, binary: []byte{}, value: &big.Int{}}, nil - } - - numBytes, errLength := determineLength32(lengthByte, r) - if errLength != nil { - return nil, errors.WithMessage(errLength, "unable to parse length of int") - } - - buf := make([]byte, numBytes) - if n, err := r.Read(buf); err != nil || n != int(numBytes) { - return nil, errors.Errorf("unable to read int - read %d bytes of %d with err: %v", n, numBytes, err) - } - - // Negative zero is not valid. - if isNegative { - isZero := true - for _, b := range buf { - if b != 0 { - isZero = false - break - } - } - if isZero { - return nil, errors.Errorf("negative zero is invalid") - } - } - - return Int{annotations: ann, isSet: true, isNegative: isNegative, binary: buf}, nil -} - -// parseBinaryFloat parses either the 32-bit or 64-bit version of of an IEEE-754 floating -// point number. -func parseBinaryFloat(ann []Symbol, numBytes byte, r io.Reader) (Value, error) { - // Represents 0e0. - if numBytes == 0 { - return Float{annotations: ann, isSet: true, binary: []byte{}}, nil - } - if numBytes != 4 && numBytes != 8 { - return nil, errors.Errorf("invalid float length %d", numBytes) - } - buf := make([]byte, numBytes) - if n, err := r.Read(buf); err != nil || n != int(numBytes) { - return nil, errors.Errorf("unable to read float - read %d bytes of %d with err: %v", n, numBytes, err) - } - return Float{annotations: ann, isSet: true, binary: buf}, nil -} - -// parseBinaryDecimal parses a variable length Decimal with exponent and coefficient components. -func parseBinaryDecimal(ann []Symbol, lengthByte byte, r io.Reader) (Value, error) { - // Represents 0d0. - if lengthByte == 0 { - return Decimal{annotations: ann, isSet: true, binary: []byte{}}, nil - } - - numBytes, errLength := determineLength16(lengthByte, r) - if errLength != nil { - return nil, errors.WithMessage(errLength, "unable to parse length of decimal") - } - - // Read in the entirety of the Decimal value from the stream, then farm out those - // bytes to read the exponent and coefficient to ensure that we have a valid decimal. - data := make([]byte, numBytes) - if n, err := r.Read(data); err != nil || n != int(numBytes) { - return nil, errors.Errorf("unable to read decimal - read %d bytes of %d with err: %v", n, numBytes, err) - } - - dataReader := bytes.NewReader(data) - expBytes, errExp := readVarNumber(numBytes, dataReader) - if errExp != nil { - return nil, errors.WithMessage(errExp, "unable to read exponent part of decimal") - } - - coefficientLength := numBytes - uint16(len(expBytes)) - if coefficientLength <= 0 { - return nil, errors.Errorf("invalid decimal - total length %d with exponent length %d", numBytes, len(expBytes)) - } - - return Decimal{annotations: ann, isSet: true, binary: data}, nil -} - -// parseBinaryTimestamp parses a timestamp comprised of a required year and offset with -// optional month, day, hour, minute, second, and fractional sub-second components. -func parseBinaryTimestamp(ann []Symbol, lengthByte byte, r io.Reader) (Value, error) { - numBytes, errLength := determineLength16(lengthByte, r) - if errLength != nil { - return nil, errors.WithMessage(errLength, "unable to parse length of timestamp") - } - - /* - Timestamp value | 6 | L | - +---------+---------+========+ - : length [VarUInt] : - +----------------------------+ - | offset [VarInt] | - +----------------------------+ - | year [VarUInt] | - +----------------------------+ - : month [VarUInt] : - +============================+ - : day [VarUInt] : - +============================+ - : hour [VarUInt] : - +==== ====+ - : minute [VarUInt] : - +============================+ - : second [VarUInt] : - +============================+ - : fraction_exponent [VarInt] : - +============================+ - : fraction_coefficient [Int] : - +============================+ - */ - // Sanity check. Don't try to parse a timestamp of unreasonable length. - // offset (2) year (2) month (1) day (1) hour (1) minute (1) second (1) exponent (1) coefficient (2). - maxLength := 2 + 2 + 1 + 1 + 1 + 1 + 1 + 1 + 2 - if numBytes > uint16(maxLength) { - return nil, errors.Errorf("timestamp length of %d exceeds expected maximum of %d", numBytes, maxLength) - } - - // Offset = at least 1 byte, Year = at least 1 byte. - if numBytes < 2 { - return nil, errors.Errorf("timestamp must have a length of at least two bytes") - } - - // Read in the entirety of the Timestamp value from the stream, then farm out those - // bytes to read the constituent parts to ensure that we have a valid timestamp. - data := make([]byte, numBytes) - if n, err := r.Read(data); err != nil || n != int(numBytes) { - return nil, errors.Errorf("unable to read timestamp - read %d bytes of %d with err: %v", n, numBytes, err) - } - - dataReader := bytes.NewReader(data) - offset, errOffset := readVarInt64(dataReader) - if errOffset != nil { - return nil, errors.WithMessage(errOffset, "unable to determine timestamp offset") - } - if offset >= 1440 || offset <= -1440 { - return nil, errors.Errorf("invalid timestamp offset %d", offset) - } - - year, errYear := readVarUInt16(dataReader) - if errYear != nil { - return nil, errors.WithMessage(errYear, "unable to determine timestamp year") - } - if year > 9999 { - return nil, errors.Errorf("invalid year %d", year) - } - precision := TimestampPrecisionYear - month, day, hour, minute, sec, nsec := uint8(1), uint8(1), uint8(0), uint8(0), uint8(0), uint32(0) - - var err error - if dataReader.Len() > 0 { - precision = TimestampPrecisionMonth - if month, err = readVarUInt8(dataReader); err != nil { - return nil, errors.WithMessage(err, "unable to determine timestamp month") - } - if month > 12 { - return nil, errors.Errorf("invalid month %d", month) - } - } - - if dataReader.Len() > 0 { - precision = TimestampPrecisionDay - if day, err = readVarUInt8(dataReader); err != nil { - return nil, errors.WithMessage(err, "unable to determine timestamp day") - } - if day > 31 { - return nil, errors.Errorf("invalid day %d", day) - } - } - - if dataReader.Len() == 1 { - // "The hour and minute is considered as a single component, that is, it is illegal - // to have hour but not minute (and vice versa)." - return nil, errors.New("invalid timestamp - cannot specify hours without minutes") - } - - if dataReader.Len() > 0 { - precision = TimestampPrecisionMinute - if hour, err = readVarUInt8(dataReader); err != nil { - return nil, errors.WithMessage(err, "unable to determine timestamp hour") - } - if hour > 23 { - return nil, errors.Errorf("invalid hour %d", hour) - } - - if minute, err = readVarUInt8(dataReader); err != nil { - return nil, errors.WithMessage(err, "unable to determine timestamp minute") - } - if minute > 59 { - return nil, errors.Errorf("invalid minute %d", minute) - } - } - - if dataReader.Len() > 0 { - precision = TimestampPrecisionSecond - if sec, err = readVarUInt8(dataReader); err != nil || sec > 59 { - return nil, errors.WithMessage(err, "unable to determine timestamp second") - } - if sec > 59 { - return nil, errors.Errorf("invalid second %d", sec) - } - } - - var exponent int8 - if dataReader.Len() > 0 { - // "The fraction_exponent and fraction_coefficient denote the fractional seconds - // of the timestamp as a decimal value. The fractional seconds’ value is - // coefficient * 10 ^ exponent. It must be greater than or equal to zero and less - // than 1. A missing coefficient defaults to zero. Fractions whose coefficient is - // zero and exponent is greater than -1 are ignored." - // - // We expect the exponent to be a single byte. That is able to cover a precision - // of up to 63 digits which is excessive. - exp, errExp := readVarUInt8(dataReader) - if errExp != nil { - return nil, errors.WithMessage(errExp, "unable to determine timestamp fractional second exponent") - } - exponent = int8(exp) & 0x3F - if exp&0x40 != 0 { - exponent *= -1 - } - - switch { - case exponent > -1: - // "Fractions whose coefficient is zero and exponent is greater than -1 are ignored." - if dataReader.Len() == 0 { - precision = TimestampPrecisionSecond - } - case exponent == -1: - precision = TimestampPrecisionMillisecond1 - case exponent == -2: - precision = TimestampPrecisionMillisecond2 - case exponent == -3: - precision = TimestampPrecisionMillisecond3 - case exponent == -4: - precision = TimestampPrecisionMillisecond4 - case exponent == -5: - precision = TimestampPrecisionMicrosecond1 - case exponent == -6: - precision = TimestampPrecisionMicrosecond2 - case exponent == -7: - precision = TimestampPrecisionMicrosecond3 - case exponent == -8: - precision = TimestampPrecisionMicrosecond4 - default: - return nil, errors.Errorf("invalid exponent for timestamp fractional second: %#x", exp) - } - } - - if dataReader.Len() > 0 { - coBytes := make([]byte, dataReader.Len()) - // We've already verified lengths and are basically performing a copy to - // a pre-allocated byte slice. There is no error to catch. - _, _ = dataReader.Read(coBytes) - - var coefficient int16 - switch len(coBytes) { - case 1: - coefficient = int16(coBytes[0] & 0x7F) - case 2: - coefficient = int16(coBytes[0] & 0x7F) - coefficient <<= 8 - coefficient |= int16(coBytes[1]) - } - if coBytes[0]&0x80 != 0 { - coefficient *= -1 - } - - switch { - case coefficient < 0: - // "It must be greater than or equal to zero and less than 1." - // A negative coefficient can't be greater than or equal to zero. - return nil, errors.Errorf("negative coefficient is not legal") - case coefficient == 0 && exponent > -1: - // "Fractions whose coefficient is zero and exponent is greater than -1 are ignored." - precision = TimestampPrecisionSecond - default: - fraction := math.Pow10(int(exponent)) * float64(coefficient) - if fraction >= 1 || fraction < 0 { - return nil, errors.Errorf("invalid fractional seconds: %F", fraction) - } - nsec = uint32(fraction * float64(time.Second)) - } - } - - // Ignore the offset if we don't have a time component. - if precision <= TimestampPrecisionDay { - offset = 0 - } - - loc := time.FixedZone("", int(offset)) - timestamp := time.Date(int(year), time.Month(month), int(day), int(hour), int(minute), int(sec), int(nsec), loc) - // time.Date does a translation, with invalid month/day combinations (e.g. April 31) being - // reflected as overflows into the next month (e.g. May 1). - if timestamp.Month() != time.Month(month) { - return nil, errors.Errorf("invalid year / month / day combination: %d %d %d", year, month, day) - } - - return Timestamp{annotations: ann, precision: precision, binary: data, value: timestamp}, nil -} diff --git a/ion/parse_binary_simple.go b/ion/parse_binary_simple.go deleted file mode 100644 index b3a701e0..00000000 --- a/ion/parse_binary_simple.go +++ /dev/null @@ -1,142 +0,0 @@ -/* - * Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"). - * You may not use this file except in compliance with the License. - * A copy of the License is located at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * or in the "license" file accompanying this file. This file is distributed - * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either - * express or implied. See the License for the specific language governing - * permissions and limitations under the License. - */ - -package ion - -import ( - "io" - - "github.com/pkg/errors" -) - -// This file contains binary parsers for Null, Padding, Bool, Symbol, String, Blob, and Clob. - -// parseBinaryNull returns the null value for the given type. -func parseBinaryNull(typ byte) (Value, error) { - if it, ok := binaryTypeMap[typ]; ok { - return Null{typ: it}, nil - } - return nil, errors.Errorf("invalid type value for null: %d", typ) -} - -// parseBinaryPadding returns a padding Value while consuming the padding -// value number of bytes. -func parseBinaryPadding(lengthByte byte, r io.Reader) (Value, error) { - // Special case the "zero" length padding since we don't read anything for it. - // The zero is in quotes since the marker of this is itself a byte of padding. - if lengthByte == 0 { - return padding{}, nil - } - - numBytes, errLength := determineLength16(lengthByte, r) - if errLength != nil { - return nil, errors.WithMessage(errLength, "unable to parse length of padding") - } - - buf := make([]byte, numBytes) - if n, err := r.Read(buf); err != nil || n != int(numBytes) { - return nil, errors.Errorf("read %d of expected %d padding bytes with err: %v", n, numBytes, err) - } - return padding{binary: buf}, nil -} - -// parseBinaryBool returns the Bool Value for the given representation. -// 1 == true and 0 == false. Null is handled by parseBinaryNull. -func parseBinaryBool(ann []Symbol, rep byte) (Value, error) { - switch rep { - case 0: - return Bool{annotations: ann, isSet: true, value: false}, nil - case 1: - return Bool{annotations: ann, isSet: true, value: true}, nil - default: - return nil, errors.Errorf("invalid bool representation %#x", rep) - } -} - -// parseBinarySymbol parses an integer Symbol ID. -func parseBinarySymbol(ann []Symbol, lengthByte byte, r io.Reader) (Value, error) { - if lengthByte == 0 { - return Symbol{annotations: ann}, nil - } - - numBytes, errLength := determineLength16(lengthByte, r) - if errLength != nil { - return nil, errors.WithMessage(errLength, "unable to parse length of symbol") - } - - // Sanity check the number of bytes expected for the Symbol ID. - // If it takes more than 4 bytes of UInt, then it won't fit into int32. - if numBytes > 4 { - return nil, errors.Errorf("symbol ID length of %d bytes exceeds expected maximum of 4", numBytes) - } - - buf := make([]byte, numBytes) - if n, err := r.Read(buf); err != nil || n != int(numBytes) { - return nil, errors.Errorf("unable to read symbol ID - read %d bytes of %d with err: %v", n, numBytes, err) - } - - var symbolID uint32 - for _, b := range buf { - symbolID <<= 8 - symbolID |= uint32(b) - } - - // math.MaxInt32 = 0x7F_FF_FF_FF - if (symbolID & 0x80000000) != 0 { - return nil, errors.Errorf("uint32 value %d overflows int32", symbolID) - } - - return Symbol{annotations: ann, id: int32(symbolID)}, nil -} - -// parseBinaryString reads the UTF-8 encoded string. -func parseBinaryString(ann []Symbol, lengthByte byte, r io.Reader) (Value, error) { - if lengthByte == 0 { - return String{annotations: ann, text: []byte{}}, nil - } - - numBytes, errLength := determineLength16(lengthByte, r) - if errLength != nil { - return nil, errors.WithMessage(errLength, "unable to parse length of string") - } - - buf := make([]byte, numBytes) - if n, err := r.Read(buf); err != nil || n != int(numBytes) { - return nil, errors.Errorf("unable to read string - read %d bytes of %d with err: %v", n, numBytes, err) - } - - return String{annotations: ann, text: buf}, nil -} - -// parseBinaryBytes reads the unencoded bytes, whether it be a Blob (high == binaryTypeBlob) -// or Clob (high == binaryTypeClob). -func parseBinaryBytes(ann []Symbol, high byte, lengthByte byte, r io.Reader) (Value, error) { - numBytes, errLength := determineLength32(lengthByte, r) - if errLength != nil { - return nil, errors.WithMessage(errLength, "unable to parse length of bytes") - } - - buf := make([]byte, numBytes) - if numBytes != 0 { - if n, err := r.Read(buf); err != nil || n != int(numBytes) { - return nil, errors.Errorf("unable to read bytes - read %d bytes of %d with err: %v", n, numBytes, err) - } - } - - if high == binaryTypeClob { - return Clob{annotations: ann, text: buf}, nil - } - return Blob{annotations: ann, binary: buf}, nil -} diff --git a/ion/parse_binary_test.go b/ion/parse_binary_test.go deleted file mode 100644 index a960934e..00000000 --- a/ion/parse_binary_test.go +++ /dev/null @@ -1,1109 +0,0 @@ -/* - * Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"). - * You may not use this file except in compliance with the License. - * A copy of the License is located at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * or in the "license" file accompanying this file. This file is distributed - * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either - * express or implied. See the License for the specific language governing - * permissions and limitations under the License. - */ - -package ion - -import ( - "bytes" - "errors" - "fmt" - "io/ioutil" - "math" - "os" - "path/filepath" - "strings" - "testing" - "time" -) - -func TestParseBinaryBlob(t *testing.T) { - tests := []struct { - name string - blob []byte - expected *Digest - expectedErr error - }{ - { - name: "no bytes", - blob: []byte{}, - expectedErr: errors.New("read 0 bytes of binary version marker with err: EOF"), - }, - { - name: "only three bytes", - blob: ion10BVM[:3], - expectedErr: errors.New("read 3 bytes of binary version marker with err: "), - }, - { - name: "byte version marker for Ion 2.0", - blob: []byte{0xE0, 0x02, 0x00, 0xEA}, - expectedErr: errors.New("invalid binary version marker: 0xe0 0x02 0x00 0xea"), - }, - { - name: "no data after BVM", - blob: ion10BVM, - expected: &Digest{}, - }, - { - name: "unsupported type", - blob: append(ion10BVM, 0xF0), - expectedErr: errors.New("invalid header combination - high: 15 low: 0"), - }, - - // Null. - - { - name: "null.null", - blob: append(ion10BVM, 0x0F), - expected: &Digest{values: []Value{Null{}}}, - }, - { - name: "null.bool", - blob: append(ion10BVM, 0x1F), - expected: &Digest{values: []Value{Null{typ: TypeBool}}}, - }, - { - name: "null.int", - blob: append(ion10BVM, 0x2F), - expected: &Digest{values: []Value{Null{typ: TypeInt}}}, - }, - { - name: "null.int (negative)", - blob: append(ion10BVM, 0x3F), - expected: &Digest{values: []Value{Null{typ: TypeInt}}}, - }, - { - name: "null.float", - blob: append(ion10BVM, 0x4F), - expected: &Digest{values: []Value{Null{typ: TypeFloat}}}, - }, - { - name: "null.decimal", - blob: append(ion10BVM, 0x5F), - expected: &Digest{values: []Value{Null{typ: TypeDecimal}}}, - }, - { - name: "null.timestamp", - blob: append(ion10BVM, 0x6F), - expected: &Digest{values: []Value{Null{typ: TypeTimestamp}}}, - }, - { - name: "null.symbol", - blob: append(ion10BVM, 0x7F), - expected: &Digest{values: []Value{Null{typ: TypeSymbol}}}, - }, - { - name: "null.string", - blob: append(ion10BVM, 0x8F), - expected: &Digest{values: []Value{Null{typ: TypeString}}}, - }, - { - name: "null.clob", - blob: append(ion10BVM, 0x9F), - expected: &Digest{values: []Value{Null{typ: TypeClob}}}, - }, - { - name: "null.blob", - blob: append(ion10BVM, 0xAF), - expected: &Digest{values: []Value{Null{typ: TypeBlob}}}, - }, - { - name: "null.list", - blob: append(ion10BVM, 0xBF), - expected: &Digest{values: []Value{Null{typ: TypeList}}}, - }, - { - name: "null.sexp", - blob: append(ion10BVM, 0xCF), - expected: &Digest{values: []Value{Null{typ: TypeSExp}}}, - }, - { - name: "null.struct", - blob: append(ion10BVM, 0xDF), - expected: &Digest{values: []Value{Null{typ: TypeStruct}}}, - }, - - // Padding and Bool. - - { - name: "zero length padding", - blob: append(ion10BVM, 0x00), - expected: &Digest{values: []Value{padding{}}}, - }, - { - name: "two bytes of padding", - blob: append(ion10BVM, 0x01, 0xFF), - expected: &Digest{values: []Value{padding{[]byte{0xFF}}}}, - }, - { - name: "sixteen bytes of padding", - blob: append(ion10BVM, 0x0E, 0x8E, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF), - expected: &Digest{values: []Value{padding{[]byte{0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF}}}}, - }, - { - name: "bool false", - blob: append(ion10BVM, 0x10), - expected: &Digest{values: []Value{Bool{isSet: true}}}, - }, - { - name: "bool true", - blob: append(ion10BVM, 0x11), - expected: &Digest{values: []Value{Bool{isSet: true, value: true}}}, - }, - { - name: "bool invalid representation", - blob: append(ion10BVM, 0x12), - expectedErr: errors.New("invalid bool representation 0x2"), - }, - - // Symbol and String. - - { - name: "zero length symbol", - blob: append(ion10BVM, 0x70), - expected: &Digest{values: []Value{Symbol{}}}, - }, - { - name: "symbol - length is too big", - blob: append(ion10BVM, 0x75), - expectedErr: errors.New("symbol ID length of 5 bytes exceeds expected maximum of 4"), - }, - { - name: "symbolID too large", - blob: append(ion10BVM, 0x74, 0x80, 0x00, 0x00, 0x00), - expectedErr: errors.New("uint32 value 2147483648 overflows int32"), - }, - { - name: "symbolID max value", - blob: append(ion10BVM, 0x74, 0x7F, 0xFF, 0xFF, 0xFF), - expected: &Digest{values: []Value{Symbol{id: math.MaxInt32}}}, - }, - { - name: "zero length string", - blob: append(ion10BVM, 0x80), - expected: &Digest{values: []Value{String{text: []byte{}}}}, - }, - { - name: "short string", - blob: append(ion10BVM, 0x83, 0x62, 0x6f, 0x6f), - expected: &Digest{values: []Value{String{text: []byte{'b', 'o', 'o'}}}}, - }, - { - name: "long string", - blob: append(ion10BVM, 0x8E, 0x8E, 0x62, 0x6f, 0x6f, 0x6f, 0x6f, 0x6f, 0x6f, 0x6f, 0x6f, 0x6f, 0x6f, 0x6f, 0x6f, 0x6f), - expected: &Digest{values: []Value{String{text: []byte{'b', 'o', 'o', 'o', 'o', 'o', 'o', 'o', 'o', 'o', 'o', 'o', 'o', 'o'}}}}, - }, - - // Blob and Clob - - { - name: "zero length blob", - blob: append(ion10BVM, 0xA0), - expected: &Digest{values: []Value{Blob{binary: []byte{}}}}, - }, - { - name: "zero length clob", - blob: append(ion10BVM, 0x90), - expected: &Digest{values: []Value{Clob{text: []byte{}}}}, - }, - { - name: "short blob", - blob: append(ion10BVM, 0xA3, 0x62, 0x6f, 0x6f), - expected: &Digest{values: []Value{Blob{binary: []byte{'b', 'o', 'o'}}}}, - }, - { - name: "short clob", - blob: append(ion10BVM, 0x93, 0x62, 0x6f, 0x6f), - expected: &Digest{values: []Value{Clob{text: []byte{'b', 'o', 'o'}}}}, - }, - { - name: "long blob", - blob: append(ion10BVM, 0xAE, 0x8E, 0x62, 0x6f, 0x6f, 0x6f, 0x6f, 0x6f, 0x6f, 0x6f, 0x6f, 0x6f, 0x6f, 0x6f, 0x6f, 0x6f), - expected: &Digest{values: []Value{Blob{binary: []byte{'b', 'o', 'o', 'o', 'o', 'o', 'o', 'o', 'o', 'o', 'o', 'o', 'o', 'o'}}}}, - }, - { - name: "long clob", - blob: append(ion10BVM, 0x9E, 0x8E, 0x62, 0x6f, 0x6f, 0x6f, 0x6f, 0x6f, 0x6f, 0x6f, 0x6f, 0x6f, 0x6f, 0x6f, 0x6f, 0x6f), - expected: &Digest{values: []Value{Clob{text: []byte{'b', 'o', 'o', 'o', 'o', 'o', 'o', 'o', 'o', 'o', 'o', 'o', 'o', 'o'}}}}, - }, - - // Positive and negative Int. - - { - name: "zero length positive int", - blob: append(ion10BVM, 0x20), - // Expected text value is not verified. - expected: &Digest{values: []Value{Int{isSet: true, binary: []byte{}, text: []byte{0x30}}}}, - }, - { - name: "zero length negative int", - blob: append(ion10BVM, 0x30), - expectedErr: errors.New("negative zero is invalid"), - }, - { - name: "int - negative zero", - blob: append(ion10BVM, 0x31, 0x00), - expectedErr: errors.New("negative zero is invalid"), - }, - { - name: "int - length is too high", - blob: append(ion10BVM, 0x2E, 0x10, 0x00, 0x00, 0x00, 0x80), - expectedErr: errors.New("unable to parse length of int: number is too big to fit into uint32: 0x10 0x00 0x00 0x00 0x80"), - }, - { - name: "short positive int", - blob: append(ion10BVM, 0x21, 0x42), - expected: &Digest{values: []Value{Int{isSet: true, binary: []byte{0x42}}}}, - }, - { - name: "short negative int", - blob: append(ion10BVM, 0x31, 0x42), - expected: &Digest{values: []Value{Int{isSet: true, isNegative: true, binary: []byte{0x42}}}}, - }, - { - name: "long positive int", - blob: append(ion10BVM, 0x2E, 0x8E, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e), - expected: &Digest{values: []Value{Int{isSet: true, binary: []byte{0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e}}}}, - }, - { - name: "long negative int", - blob: append(ion10BVM, 0x3E, 0x8E, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e), - expected: &Digest{values: []Value{Int{isSet: true, isNegative: true, binary: []byte{0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e}}}}, - }, - - // Float and Decimal. - - { - name: "zero length float", - blob: append(ion10BVM, 0x40), - expected: &Digest{values: []Value{Float{isSet: true, binary: []byte{}}}}, - }, - { - name: "float - invalid length", - blob: append(ion10BVM, 0x42), - expectedErr: errors.New("invalid float length 2"), - }, - { - name: "float - 4 byte value", - blob: append(ion10BVM, 0x44, 0x01, 0x02, 0x03, 0x04), - expected: &Digest{values: []Value{Float{isSet: true, binary: []byte{0x01, 0x02, 0x03, 0x04}}}}, - }, - { - name: "float - 8 byte value", - blob: append(ion10BVM, 0x48, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08), - expected: &Digest{values: []Value{Float{isSet: true, binary: []byte{0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08}}}}, - }, - { - name: "zero length decimal", - blob: append(ion10BVM, 0x50), - expected: &Digest{values: []Value{Decimal{isSet: true, binary: []byte{}}}}, - }, - { - name: "decimal - exponent takes up all the bytes", - blob: append(ion10BVM, 0x52, 0x00, 0x80), - expectedErr: errors.New("invalid decimal - total length 2 with exponent length 2"), - }, - { - name: "decimal - exponent isn't terminated", - blob: append(ion10BVM, 0x5E, 0x8E, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e), - expectedErr: errors.New("unable to read exponent part of decimal: number not terminated after 14 bytes"), - }, - { - name: "decimal - length is too high", - blob: append(ion10BVM, 0x5E, 0x04, 0x00, 0x80), - expectedErr: errors.New("unable to parse length of decimal: number is too big to fit into uint16: 0x04 0x00 0x80"), - }, - { - name: "short decimal", - blob: append(ion10BVM, 0x52, 0x80, 0x08), - expected: &Digest{values: []Value{Decimal{isSet: true, binary: []byte{0x80, 0x08}}}}, - }, - { - name: "long decimal", - blob: append(ion10BVM, 0x5E, 0x8E, 0x01, 0x02, 0x03, 0x80, 0x05, 0x06, 0x07, 0x80, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e), - expected: &Digest{values: []Value{Decimal{isSet: true, binary: []byte{0x01, 0x02, 0x03, 0x80, 0x05, 0x06, 0x07, 0x80, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e}}}}, - }, - - // Timestamp. - - { - name: "timestamp - too short", - blob: append(ion10BVM, 0x61), - expectedErr: errors.New("timestamp must have a length of at least two bytes"), - }, - { - name: "timestamp - length exceeds maximum", - blob: append(ion10BVM, 0x6D), - expectedErr: errors.New("timestamp length of 13 exceeds expected maximum of 12"), - }, - { - name: "timestamp - offset isn't terminated", - blob: append(ion10BVM, 0x62, 0x00, 0x00), - expectedErr: errors.New("unable to determine timestamp offset: read 0 bytes (wanted one) of number with err: EOF"), - }, - { - name: "timestamp - year only", - blob: append(ion10BVM, 0x63, 0x80, 0x0F, 0xD0), - expected: &Digest{values: []Value{Timestamp{ - precision: TimestampPrecisionYear, - binary: []byte{0x80, 0x0F, 0xD0}, - value: time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC), - }}}, - }, - { - name: "timestamp - year isn't terminated", - blob: append(ion10BVM, 0x63, 0x80, 0x00, 0x00), - expectedErr: errors.New("unable to determine timestamp year: read 0 bytes (wanted one) of number with err: EOF"), - }, - { - name: "timestamp - year and month", - blob: append(ion10BVM, 0x64, 0x80, 0x0F, 0xD0, 0x81), - expected: &Digest{values: []Value{Timestamp{ - precision: TimestampPrecisionMonth, - binary: []byte{0x80, 0x0F, 0xD0, 0x81}, - value: time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC), - }}}, - }, - { - name: "timestamp - month isn't terminated", - blob: append(ion10BVM, 0x64, 0x80, 0x0F, 0xD0, 0x01), - expectedErr: errors.New("unable to determine timestamp month: number not terminated after 1 bytes"), - }, - { - name: "timestamp - year, month, and day", - blob: append(ion10BVM, 0x65, 0x80, 0x0F, 0xD0, 0x81, 0x82), - expected: &Digest{values: []Value{Timestamp{ - precision: TimestampPrecisionDay, - binary: []byte{0x80, 0x0F, 0xD0, 0x81, 0x82}, - value: time.Date(2000, 1, 2, 0, 0, 0, 0, time.UTC), - }}}, - }, - { - name: "timestamp - day isn't terminated", - blob: append(ion10BVM, 0x65, 0x80, 0x0F, 0xD0, 0x81, 0x01), - expectedErr: errors.New("unable to determine timestamp day: number not terminated after 1 bytes"), - }, - { - name: "timestamp - hour without minutes", - blob: append(ion10BVM, 0x66, 0x80, 0x0F, 0xD0, 0x81, 0x81, 0x81), - expectedErr: errors.New("invalid timestamp - cannot specify hours without minutes"), - }, - { - name: "timestamp - year, month, day, hour, and minute", - blob: append(ion10BVM, 0x67, 0x80, 0x0F, 0xD0, 0x81, 0x82, 0x83, 0x84), - expected: &Digest{values: []Value{Timestamp{ - precision: TimestampPrecisionMinute, - binary: []byte{0x80, 0x0F, 0xD0, 0x81, 0x82, 0x83, 0x84}, - value: time.Date(2000, 1, 2, 3, 4, 0, 0, time.UTC), - }}}, - }, - { - name: "timestamp - hour isn't terminated", - blob: append(ion10BVM, 0x67, 0x80, 0x0F, 0xD0, 0x81, 0x82, 0x03, 0x84), - expectedErr: errors.New("unable to determine timestamp hour: number not terminated after 1 bytes"), - }, - { - name: "timestamp - minute isn't terminated", - blob: append(ion10BVM, 0x67, 0x80, 0x0F, 0xD0, 0x81, 0x82, 0x83, 0x04), - expectedErr: errors.New("unable to determine timestamp minute: number not terminated after 1 bytes"), - }, - { - name: "timestamp - year, month, day, hour, minute, and second", - blob: append(ion10BVM, 0x68, 0x80, 0x0F, 0xD0, 0x81, 0x82, 0x83, 0x84, 0x85), - expected: &Digest{values: []Value{Timestamp{ - precision: TimestampPrecisionSecond, - binary: []byte{0x80, 0x0F, 0xD0, 0x81, 0x82, 0x83, 0x84, 0x85}, - value: time.Date(2000, 1, 2, 3, 4, 5, 0, time.UTC), - }}}, - }, - { - name: "timestamp - second isn't terminated", - blob: append(ion10BVM, 0x68, 0x80, 0x0F, 0xD0, 0x81, 0x82, 0x83, 0x84, 0x05), - expectedErr: errors.New("unable to determine timestamp second: number not terminated after 1 bytes"), - }, - { - name: "timestamp - year, month, day, hour, minute, second, and millisecond1", - blob: append(ion10BVM, 0x69, 0x80, 0x0F, 0xD0, 0x81, 0x82, 0x83, 0x84, 0x85, 0xC1), - expected: &Digest{values: []Value{Timestamp{ - precision: TimestampPrecisionMillisecond1, - binary: []byte{0x80, 0x0F, 0xD0, 0x81, 0x82, 0x83, 0x84, 0x85, 0xC1}, - value: time.Date(2000, 1, 2, 3, 4, 5, 0, time.UTC), - }}}, - }, - { - name: "timestamp - year, month, day, hour, minute, second, and millisecond2", - blob: append(ion10BVM, 0x69, 0x80, 0x0F, 0xD0, 0x81, 0x82, 0x83, 0x84, 0x85, 0xC2), - expected: &Digest{values: []Value{Timestamp{ - precision: TimestampPrecisionMillisecond2, - binary: []byte{0x80, 0x0F, 0xD0, 0x81, 0x82, 0x83, 0x84, 0x85, 0xC2}, - value: time.Date(2000, 1, 2, 3, 4, 5, 0, time.UTC), - }}}, - }, - { - name: "timestamp - year, month, day, hour, minute, second, and millisecond3", - blob: append(ion10BVM, 0x69, 0x80, 0x0F, 0xD0, 0x81, 0x82, 0x83, 0x84, 0x85, 0xC3), - expected: &Digest{values: []Value{Timestamp{ - precision: TimestampPrecisionMillisecond3, - binary: []byte{0x80, 0x0F, 0xD0, 0x81, 0x82, 0x83, 0x84, 0x85, 0xC3}, - value: time.Date(2000, 1, 2, 3, 4, 5, 0, time.UTC), - }}}, - }, - { - name: "timestamp - year, month, day, hour, minute, second, and millisecond4", - blob: append(ion10BVM, 0x69, 0x80, 0x0F, 0xD0, 0x81, 0x82, 0x83, 0x84, 0x85, 0xC4), - expected: &Digest{values: []Value{Timestamp{ - precision: TimestampPrecisionMillisecond4, - binary: []byte{0x80, 0x0F, 0xD0, 0x81, 0x82, 0x83, 0x84, 0x85, 0xC4}, - value: time.Date(2000, 1, 2, 3, 4, 5, 0, time.UTC), - }}}, - }, - { - name: "timestamp - year, month, day, hour, minute, second, and microsecond1", - blob: append(ion10BVM, 0x69, 0x80, 0x0F, 0xD0, 0x81, 0x82, 0x83, 0x84, 0x85, 0xC5), - expected: &Digest{values: []Value{Timestamp{ - precision: TimestampPrecisionMicrosecond1, - binary: []byte{0x80, 0x0F, 0xD0, 0x81, 0x82, 0x83, 0x84, 0x85, 0xC5}, - value: time.Date(2000, 1, 2, 3, 4, 5, 0, time.UTC), - }}}, - }, - { - name: "timestamp - year, month, day, hour, minute, second, and microsecond2", - blob: append(ion10BVM, 0x69, 0x80, 0x0F, 0xD0, 0x81, 0x82, 0x83, 0x84, 0x85, 0xC6), - expected: &Digest{values: []Value{Timestamp{ - precision: TimestampPrecisionMicrosecond2, - binary: []byte{0x80, 0x0F, 0xD0, 0x81, 0x82, 0x83, 0x84, 0x85, 0xC6}, - value: time.Date(2000, 1, 2, 3, 4, 5, 0, time.UTC), - }}}, - }, - { - name: "timestamp - year, month, day, hour, minute, second, and microsecond3", - blob: append(ion10BVM, 0x69, 0x80, 0x0F, 0xD0, 0x81, 0x82, 0x83, 0x84, 0x85, 0xC7), - expected: &Digest{values: []Value{Timestamp{ - precision: TimestampPrecisionMicrosecond3, - binary: []byte{0x80, 0x0F, 0xD0, 0x81, 0x82, 0x83, 0x84, 0x85, 0xC7}, - value: time.Date(2000, 1, 2, 3, 4, 5, 0, time.UTC), - }}}, - }, - { - name: "timestamp - year, month, day, hour, minute, second, and microsecond4", - blob: append(ion10BVM, 0x69, 0x80, 0x0F, 0xD0, 0x81, 0x82, 0x83, 0x84, 0x85, 0xC8), - expected: &Digest{values: []Value{Timestamp{ - precision: TimestampPrecisionMicrosecond4, - binary: []byte{0x80, 0x0F, 0xD0, 0x81, 0x82, 0x83, 0x84, 0x85, 0xC8}, - value: time.Date(2000, 1, 2, 3, 4, 5, 0, time.UTC), - }}}, - }, - { - name: "timestamp - year, month, day, hour, minute, second, and microsecond5", - blob: append(ion10BVM, 0x69, 0x80, 0x0F, 0xD0, 0x81, 0x82, 0x83, 0x84, 0x85, 0xC9), - // Reading the exponent removes the stop bit of the variable int which turns - // the C into a 4. - expectedErr: errors.New("invalid exponent for timestamp fractional second: 0x49"), - }, - { - name: "timestamp - exponent isn't terminated", - blob: append(ion10BVM, 0x69, 0x80, 0x0F, 0xD0, 0x81, 0x82, 0x83, 0x84, 0x85, 0x04), - expectedErr: errors.New("unable to determine timestamp fractional second exponent: number not terminated after 1 bytes"), - }, - { - name: "timestamp - year, month, day, hour, minute, second, exponent, and coefficient", - blob: append(ion10BVM, 0x6A, 0x80, 0x0F, 0xD0, 0x81, 0x82, 0x83, 0x84, 0x85, 0xC1, 0x06), - expected: &Digest{values: []Value{Timestamp{ - precision: TimestampPrecisionMillisecond1, - binary: []byte{0x80, 0x0F, 0xD0, 0x81, 0x82, 0x83, 0x84, 0x85, 0xC1, 0x06}, - value: time.Date(2000, 1, 2, 3, 4, 5, int(600*time.Millisecond), time.UTC), - }}}, - }, - { - name: "timestamp - year, month, day, hour, minute, second, exponent, and 2 byte coefficient", - blob: append(ion10BVM, 0x6B, 0x80, 0x0F, 0xD0, 0x81, 0x82, 0x83, 0x84, 0x85, 0xC1, 0x00, 0x06), - expected: &Digest{values: []Value{Timestamp{ - precision: TimestampPrecisionMillisecond1, - binary: []byte{0x80, 0x0F, 0xD0, 0x81, 0x82, 0x83, 0x84, 0x85, 0xC1, 0x00, 0x06}, - value: time.Date(2000, 1, 2, 3, 4, 5, int(600*time.Millisecond), time.UTC), - }}}, - }, - { - name: "timestamp - year, month, day, hour, minute, second, and exponent1 - exponent is < C1", - blob: append(ion10BVM, 0x69, 0x80, 0x0F, 0xD0, 0x81, 0x82, 0x83, 0x84, 0x85, 0xC0), - expected: &Digest{values: []Value{Timestamp{ - precision: TimestampPrecisionSecond, - binary: []byte{0x80, 0x0F, 0xD0, 0x81, 0x82, 0x83, 0x84, 0x85, 0xC0}, - value: time.Date(2000, 1, 2, 3, 4, 5, 0, time.UTC), - }}}, - }, - { - name: "timestamp - year, month, day, hour, minute, second, exponent, and coefficient - exponent is < C1 and coefficient is 0", - blob: append(ion10BVM, 0x6A, 0x80, 0x0F, 0xD0, 0x81, 0x82, 0x83, 0x84, 0x85, 0xC0, 0x00), - expected: &Digest{values: []Value{Timestamp{ - precision: TimestampPrecisionSecond, - binary: []byte{0x80, 0x0F, 0xD0, 0x81, 0x82, 0x83, 0x84, 0x85, 0xC0, 0x00}, - value: time.Date(2000, 1, 2, 3, 4, 5, 0, time.UTC), - }}}, - }, - { - name: "timestamp2011-02-20T19_30_59_100-08_00.10n", - blob: append(ion10BVM, 0x6B, 0x43, 0xE0, 0x0F, 0xDB, 0x82, 0x94, 0x93, 0x9E, 0xBB, 0xC3, 0x64), - expected: &Digest{values: []Value{Timestamp{ - precision: TimestampPrecisionMillisecond3, - binary: []byte{0x43, 0xE0, 0x0F, 0xDB, 0x82, 0x94, 0x93, 0x9E, 0xBB, 0xC3, 0x64}, - value: time.Date(2011, 2, 20, 19, 30, 59, int(100*time.Millisecond), time.FixedZone("-8:00", -480)), - }}}, - }, - - // List and S-Expression. - - { - name: "zero length list", - blob: append(ion10BVM, 0xB0), - expected: &Digest{values: []Value{List{values: []Value{}}}}, - }, - { - name: "zero length sexp", - blob: append(ion10BVM, 0xC0), - expected: &Digest{values: []Value{SExp{values: []Value{}}}}, - }, - { - name: "list - invalid bool", - blob: append(ion10BVM, 0xB1, 0x12), - expectedErr: errors.New("unable to parse list: invalid bool representation 0x2"), - }, - { - name: "sexp - invalid bool", - blob: append(ion10BVM, 0xC1, 0x12), - expectedErr: errors.New("unable to parse list: invalid bool representation 0x2"), - }, - { - name: "list - valid bool, invalid float", - blob: append(ion10BVM, 0xB2, 0x11, 0x42), - expectedErr: errors.New("unable to parse list: invalid float length 2"), - }, - { - name: "sexp - valid bool, invalid float", - blob: append(ion10BVM, 0xC2, 0x11, 0x42), - expectedErr: errors.New("unable to parse list: invalid float length 2"), - }, - { - name: "nested list - invalid bool", - blob: append(ion10BVM, 0xB2, 0xB1, 0x12), - expectedErr: errors.New("unable to parse list: unable to parse list: invalid bool representation 0x2"), - }, - { - name: "nested sexp - invalid bool", - blob: append(ion10BVM, 0xC2, 0xB1, 0x12), - expectedErr: errors.New("unable to parse list: unable to parse list: invalid bool representation 0x2"), - }, - { - name: "list - valid bool", - blob: append(ion10BVM, 0xB1, 0x11), - expected: &Digest{ - values: []Value{List{ - values: []Value{Bool{isSet: true, value: true}}, - }}, - }, - }, - { - name: "sexp - valid bool", - blob: append(ion10BVM, 0xC1, 0x11), - expected: &Digest{ - values: []Value{SExp{ - values: []Value{Bool{isSet: true, value: true}}, - }}, - }, - }, - { - name: "sexp in list - valid bool", - blob: append(ion10BVM, 0xB2, 0xC1, 0x11), - expected: &Digest{ - values: []Value{List{ - values: []Value{SExp{ - values: []Value{Bool{isSet: true, value: true}}, - }}, - }}, - }, - }, - { - name: "list in sexp - valid bool", - blob: append(ion10BVM, 0xC2, 0xB1, 0x11), - expected: &Digest{ - values: []Value{SExp{ - values: []Value{List{ - values: []Value{Bool{isSet: true, value: true}}, - }}, - }}, - }, - }, - - // Struct. - - { - name: "zero length struct", - blob: append(ion10BVM, 0xD0), - expected: &Digest{values: []Value{Struct{fields: []StructField{}}}}, - }, - { - name: "struct - symbol isn't terminated", - blob: append(ion10BVM, 0xD2, 0x00, 0x00), - expectedErr: errors.New("unable to read struct field symbol: read 0 bytes (wanted one) of number with err: EOF"), - }, - { - name: "struct - invalid bool field value", - blob: append(ion10BVM, 0xD2, 0x80, 0x12), - expectedErr: errors.New("unable to read struct field value: invalid bool representation 0x2"), - }, - { - name: "struct - no value", - blob: append(ion10BVM, 0xD2, 0x00, 0x84), - expectedErr: errors.New("unable to read struct field value: EOF"), - }, - { - name: "struct - valid bool", - blob: append(ion10BVM, 0xD2, 0x80, 0x11), - expected: &Digest{values: []Value{Struct{fields: []StructField{ - {Symbol: Symbol{}, Value: Bool{isSet: true, value: true}}, - }}}}, - }, - { - name: "struct - valid padding", - blob: append(ion10BVM, 0xD3, 0x80, 0x01, 0xFF), - expected: &Digest{values: []Value{Struct{fields: []StructField{}}}}, - }, - - // Annotation. - - { - name: "annotation - length too short", - blob: append(ion10BVM, 0xE1), - expectedErr: errors.New("length must be at least 3 for an annotation wrapper, found 1"), - }, - { - name: "annotation - annotation length not terminated", - blob: append(ion10BVM, 0xE3, 0x00, 0x00, 0x11), - expectedErr: errors.New("unable to determine annotation symbol length: number not terminated after 3 bytes"), - }, - { - name: "annotation - annotation length equal to numBytes", - blob: append(ion10BVM, 0xE3, 0x83, 0x84, 0x11), - expectedErr: errors.New("invalid lengths for annotation - field length is 3 while annotation symbols length is 3"), - }, - { - name: "annotation - annotation not terminated", - blob: append(ion10BVM, 0xE3, 0x81, 0x00, 0x11), - expectedErr: errors.New("unable to read annotation symbol: read 0 bytes (wanted one) of number with err: EOF"), - }, - { - name: "annotation - invalid bool", - blob: append(ion10BVM, 0xE3, 0x81, 0x84, 0x12), - expectedErr: errors.New("unable to read annotation value: invalid bool representation 0x2"), - }, - { - name: "annotation - length exceeds single value", - blob: append(ion10BVM, 0xE4, 0x81, 0x84, 0x11, 0x11), - expectedErr: errors.New("annotation declared 4 bytes but there are 1 bytes left"), - }, - { - name: "annotation on noop padding", - blob: append(ion10BVM, 0xE3, 0x81, 0x84, 0x00), - expectedErr: errors.New("annotation on padding is not legal"), - }, - { - name: "annotation - valid bool", - blob: append(ion10BVM, 0xE3, 0x81, 0x84, 0x11), - expected: &Digest{values: []Value{Bool{ - annotations: []Symbol{{id: 4}}, - isSet: true, - value: true, - }}}, - }, - { - name: "two annotations - valid bool", - blob: append(ion10BVM, 0xE4, 0x82, 0x84, 0x87, 0x11), - expected: &Digest{values: []Value{Bool{ - annotations: []Symbol{{id: 4}, {id: 7}}, - isSet: true, - value: true, - }}}, - }, - - // Error - EOF while reading length. - - { - name: "padding - EOF while reading length", - blob: append(ion10BVM, 0x0E, 0x0E), - expectedErr: errors.New("unable to parse length of padding: read 0 bytes (wanted one) of number with err: EOF"), - }, - { - name: "pos int - EOF while reading length", - blob: append(ion10BVM, 0x2E, 0x0E), - expectedErr: errors.New("unable to parse length of int: read 0 bytes (wanted one) of number with err: EOF"), - }, - { - name: "neg int - EOF while reading length", - blob: append(ion10BVM, 0x3E, 0x0E), - expectedErr: errors.New("unable to parse length of int: read 0 bytes (wanted one) of number with err: EOF"), - }, - { - name: "decimal - EOF while reading length", - blob: append(ion10BVM, 0x5E, 0x0E), - expectedErr: errors.New("unable to parse length of decimal: read 0 bytes (wanted one) of number with err: EOF"), - }, - { - name: "timestamp - EOF while reading length", - blob: append(ion10BVM, 0x6E, 0x0E), - expectedErr: errors.New("unable to parse length of timestamp: read 0 bytes (wanted one) of number with err: EOF"), - }, - { - name: "symbol - EOF while reading length", - blob: append(ion10BVM, 0x7E, 0x0E), - expectedErr: errors.New("unable to parse length of symbol: read 0 bytes (wanted one) of number with err: EOF"), - }, - { - name: "string - EOF while reading length", - blob: append(ion10BVM, 0x8E, 0x0E), - expectedErr: errors.New("unable to parse length of string: read 0 bytes (wanted one) of number with err: EOF"), - }, - { - name: "clob - EOF while reading length", - blob: append(ion10BVM, 0x9E, 0x0E), - expectedErr: errors.New("unable to parse length of bytes: read 0 bytes (wanted one) of number with err: EOF"), - }, - { - name: "blob - EOF while reading length", - blob: append(ion10BVM, 0xAE, 0x0E), - expectedErr: errors.New("unable to parse length of bytes: read 0 bytes (wanted one) of number with err: EOF"), - }, - { - name: "list - EOF while reading length", - blob: append(ion10BVM, 0xBE, 0x0E), - expectedErr: errors.New("unable to parse length of list: read 0 bytes (wanted one) of number with err: EOF"), - }, - { - name: "sexp - EOF while reading length", - blob: append(ion10BVM, 0xCE, 0x0E), - expectedErr: errors.New("unable to parse length of list: read 0 bytes (wanted one) of number with err: EOF"), - }, - { - name: "struct - EOF while reading length", - blob: append(ion10BVM, 0xDE, 0x0E), - expectedErr: errors.New("unable to parse length of struct: read 0 bytes (wanted one) of number with err: EOF"), - }, - { - name: "struct - EOF while reading length special length 1 case", - blob: append(ion10BVM, 0xD1, 0x0E), - expectedErr: errors.New("unable to parse length of struct: read 0 bytes (wanted one) of number with err: EOF"), - }, - { - name: "annotation - EOF while reading length", - blob: append(ion10BVM, 0xEE, 0x0E), - expectedErr: errors.New("unable to parse length of annotation: read 0 bytes (wanted one) of number with err: EOF"), - }, - - // Error - EOF while reading value. - - { - name: "padding - EOF while reading padding", - blob: append(ion10BVM, 0x0E, 0x8E, 0xFF), - expectedErr: errors.New("read 1 of expected 14 padding bytes with err: "), - }, - { - name: "pos int - EOF while reading value", - blob: append(ion10BVM, 0x22, 0x08), - expectedErr: errors.New("unable to read int - read 1 bytes of 2 with err: "), - }, - { - name: "neg int - EOF while reading value", - blob: append(ion10BVM, 0x32, 0x08), - expectedErr: errors.New("unable to read int - read 1 bytes of 2 with err: "), - }, - { - name: "float - EOF while reading value", - blob: append(ion10BVM, 0x44, 0x08), - expectedErr: errors.New("unable to read float - read 1 bytes of 4 with err: "), - }, - { - name: "decimal - EOF while reading value", - blob: append(ion10BVM, 0x52, 0x08), - expectedErr: errors.New("unable to read decimal - read 1 bytes of 2 with err: "), - }, - { - name: "timestamp - EOF while reading value", - blob: append(ion10BVM, 0x62, 0x08), - expectedErr: errors.New("unable to read timestamp - read 1 bytes of 2 with err: "), - }, - { - name: "symbol - EOF while reading value", - blob: append(ion10BVM, 0x72, 0x08), - expectedErr: errors.New("unable to read symbol ID - read 1 bytes of 2 with err: "), - }, - { - name: "string - EOF while reading value", - blob: append(ion10BVM, 0x82, 0x08), - expectedErr: errors.New("unable to read string - read 1 bytes of 2 with err: "), - }, - { - name: "clob - EOF while reading bytes", - blob: append(ion10BVM, 0x92, 0x08), - expectedErr: errors.New("unable to read bytes - read 1 bytes of 2 with err: "), - }, - { - name: "blob - EOF while reading bytes", - blob: append(ion10BVM, 0xA2, 0x08), - expectedErr: errors.New("unable to read bytes - read 1 bytes of 2 with err: "), - }, - { - name: "list - EOF while reading values", - blob: append(ion10BVM, 0xB2, 0x08), - expectedErr: errors.New("unable to read list - read 1 bytes of 2 with err: "), - }, - { - name: "sexp - EOF while reading values", - blob: append(ion10BVM, 0xC2, 0x08), - expectedErr: errors.New("unable to read list - read 1 bytes of 2 with err: "), - }, - { - name: "struct - EOF while reading values", - blob: append(ion10BVM, 0xD2, 0x08), - expectedErr: errors.New("unable to read struct - read 1 bytes of 2 with err: "), - }, - { - name: "annotation - EOF while reading values", - blob: append(ion10BVM, 0xE3, 0x81), - expectedErr: errors.New("unable to read annotation - read 1 bytes of 3 with err: "), - }, - } - - for _, tst := range tests { - test := tst - t.Run(test.name, func(t *testing.T) { - out, err := parseBinaryBlob(test.blob) - if diff := cmpDigests(test.expected, out); diff != "" { - t.Error("out: (-expected, +found)", diff) - } - if diff := cmpErrs(test.expectedErr, err); diff != "" { - t.Error("err: (-expected, +found)", diff) - } - }) - } -} - -func TestBinaryStream(t *testing.T) { - // Test the cases that aren't covered by TestParseBinaryBlob. - tests := []struct { - name string - blob []byte - expected []Value - expectedErr error - }{ - { - name: "two digests, two booleans", - blob: []byte{0xE0, 0x01, 0x00, 0xEA, 0x11, 0xE0, 0x01, 0x00, 0xEA, 0x11}, - expected: []Value{Bool{isSet: true, value: true}, Bool{isSet: true, value: true}}, - }, - { - name: "EOF reading second BVM", - blob: []byte{0xE0, 0x01, 0x00, 0xEA, 0x11, 0xE0, 0x01, 0x00}, - expectedErr: errors.New("unable to read binary version marker - read 2 bytes of 3 with err: "), - }, - { - name: "invalid second BVM", - blob: []byte{0xE0, 0x01, 0x00, 0xEA, 0x11, 0xE0, 0x02, 0x00, 0xEA, 0x11}, - expectedErr: errors.New("invalid binary version marker: 0xe0 0x02 0x00 0xea"), - }, - { - name: "invalid boolean in second Digest", - blob: []byte{0xE0, 0x01, 0x00, 0xEA, 0x11, 0xE0, 0x01, 0x00, 0xEA, 0x12}, - expected: []Value{Bool{isSet: true, value: true}}, - expectedErr: errors.New("invalid bool representation 0x2"), - }, - } - - for _, tst := range tests { - test := tst - t.Run(test.name, func(t *testing.T) { - ch := parseBinaryStream(bytes.NewReader(test.blob)) - - var values []Value - var err error - Loop: - for { - select { - case item, ok := <-ch: - if !ok { - break Loop - } - if item.Error != nil { - err = item.Error - break Loop - } - values = append(values, item.Digest.values...) - case <-time.After(1 * time.Second): - t.Fatal("timed out") - } - } - - if diff := cmpValueSlices(test.expected, values); diff != "" { - t.Error("(-expected, +found)", diff) - } - if diff := cmpErrs(test.expectedErr, err); diff != "" { - t.Error("err: (-expected, +found)", diff) - } - }) - } -} - -func TestIonTests_Binary_Good(t *testing.T) { - filesToSkip := map[string]bool{ - // TODO amzn/ion-go#4 (these test edge cases around various type codes) - "T14.10n": true, - "T15.10n": true, - "T5.10n": true, - "T6-large.10n": true, - "T7-large.10n": true, - } - - testFilePath := "../ion-tests/iontestdata/good" - walkFn := func(path string, info os.FileInfo, err error) error { - if info.IsDir() || !strings.HasSuffix(path, ".10n") { - return nil - } - - name := info.Name() - if _, ok := filesToSkip[name]; ok { - t.Log("skipping", name) - return nil - } - - t.Run(strings.TrimPrefix(path, testFilePath), func(t *testing.T) { - data, err := ioutil.ReadFile(path) - if err != nil { - t.Fatal(err) - } - out, err := parseBinaryBlob(data) - if err != nil { - t.Fatal(err) - } - if out == nil || len(out.Value()) == 0 { - t.Error("expected out to have at least one value") - } - - // TODO: If we are in the equivs directory, then verify that each top-level - // Value in the Digest is comprised of equivalent sub-elements. Need - // the Value() functions to be able to pull that off. - }) - return nil - } - if err := filepath.Walk(testFilePath, walkFn); err != nil { - t.Fatal(err) - } -} - -func TestIonTests_Binary_Equivalents(t *testing.T) { - testFilePath := "../ion-tests/iontestdata/good/equivs" - walkFn := func(path string, info os.FileInfo, err error) error { - if info.IsDir() || !strings.HasSuffix(path, ".10n") { - return nil - } - - t.Run(strings.TrimPrefix(path, testFilePath), func(t *testing.T) { - data, err := ioutil.ReadFile(path) - if err != nil { - t.Fatal(err) - } - fmt.Println("parsing file", path) - out, err := parseBinaryBlob(data) - if err != nil { - t.Fatal(err) - } - - if out == nil || len(out.Value()) == 0 { - t.Error("expected out to have at least one value") - } - - for i, value := range out.Value() { - t.Log("collection", i, "of", info.Name()) - switch value.Type() { - case TypeList: - assertEquivalentValues(value.(List).values, t) - case TypeSExp: - assertEquivalentValues(value.(SExp).values, t) - default: - t.Error("top-element item is", value.Type(), "for", info.Name()) - } - } - }) - return nil - } - if err := filepath.Walk(testFilePath, walkFn); err != nil { - t.Fatal(err) - } -} - -func TestIonTests_Binary_Bad(t *testing.T) { - filesToSkip := map[string]bool{ - // TODO: Deal with symbol tables and verification of SymbolIDs. - "annotationSymbolIDUnmapped.10n": true, - "fieldNameSymbolIDUnmapped.10n": true, - "localSymbolTableWithMultipleImportsFields.10n": true, - "localSymbolTableWithMultipleSymbolsAndImportsFields.10n": true, - "localSymbolTableWithMultipleSymbolsFields.10n": true, - "symbolIDUnmapped.10n": true, - // Not performing timestamp verification on parse. - "leapDayNonLeapYear_1.10n": true, - "leapDayNonLeapYear_2.10n": true, - "timestampSept31.10n": true, - // Not performing string verification on parse. - "stringWithLatinEncoding.10n": true, - // TODO amzn/ion-go#4 ('null' annotation) - "type_14_length_15.10n": true, - } - - testFilePath := "../ion-tests/iontestdata/bad" - walkFn := func(path string, info os.FileInfo, err error) error { - if info.IsDir() || !strings.HasSuffix(path, ".10n") { - return nil - } - - name := info.Name() - if _, ok := filesToSkip[name]; ok { - t.Log("skipping", name) - return nil - } - - t.Run(strings.TrimPrefix(path, testFilePath), func(t *testing.T) { - data, err := ioutil.ReadFile(path) - if err != nil { - t.Fatal(err) - } - out, err := parseBinaryBlob(data) - if err == nil { - t.Error("expected error but found none") - } - if out != nil { - t.Errorf("%#v", out) - } - }) - return nil - } - if err := filepath.Walk(testFilePath, walkFn); err != nil { - t.Fatal(err) - } -} - -func Test_parseBinaryNull(t *testing.T) { - // It's not possible to hit this case going through the parser, but we want - // to demonstrate that if something gets broken this is what it looks like. - if val, err := parseBinaryNull(0xFF); val != nil || err == nil { - t.Errorf("expected nil value and error but found %#v and %+v", val, err) - } -} diff --git a/ion/parse_text.go b/ion/parse_text.go deleted file mode 100644 index 889a7c34..00000000 --- a/ion/parse_text.go +++ /dev/null @@ -1,202 +0,0 @@ -/* - * Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"). - * You may not use this file except in compliance with the License. - * A copy of the License is located at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * or in the "license" file accompanying this file. This file is distributed - * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either - * express or implied. See the License for the specific language governing - * permissions and limitations under the License. - */ - -package ion - -import ( - "bytes" - "fmt" - "io" - "io/ioutil" - "runtime" - - "github.com/amzn/ion-go/internal/lex" - "github.com/pkg/errors" -) - -// ParseText parses all of the bytes from the given Reader into an -// instance of Digest. Assume that the entire contents of reading the -// given reader will be kept in memory. This allows for lazy evaluation -// of values, e.g. don't turn "1234" into an int unless its value is -// accessed. -func ParseText(reader io.Reader) (*Digest, error) { - text, err := ioutil.ReadAll(reader) - if err != nil { - return nil, errors.Wrap(err, "unable to read ion text to parse") - } - - t := &parser{} - if err := t.Parse(text); err != nil { - return nil, err - } - - return t.digest, nil -} - -type parser struct { - digest *Digest - lex *lex.Lexer - token [3]lex.Item - peekCount int -} - -// panicf formats the given error and panics. Panicking provides a quick exit from -// whatever depth of parsing we are at and is recovered from during the call to Parse(). -func (p *parser) panicf(format string, args ...interface{}) { - format = fmt.Sprintf("parsing line %d - %s", p.lex.LineNumber(), format) - panic(fmt.Errorf(format, args...)) -} - -func (p *parser) next() lex.Item { - if p.peekCount > 0 { - p.peekCount-- - } else { - p.token[0] = p.lex.NextItem() - if p.token[0].Type == lex.IonError { - p.panicf("Encountered error lexing the next value: %v", p.token[0]) - } - } - return p.token[p.peekCount] -} - -// Backs the input stream up one item. -func (p *parser) backup() { - p.peekCount++ -} - -// Returns the next non-comment item, consuming all comments. -func (p *parser) nextNonComment() (item lex.Item) { - for { - item = p.next() - if item.Type != lex.IonCommentBlock && item.Type != lex.IonCommentLine { - break - } - } - return item -} - -// Returns but does not consume the next non-comment token, -// while consuming all comments. -func (p *parser) peekNonComment() lex.Item { - var item lex.Item - for { - item = p.next() - if item.Type != lex.IonCommentBlock && item.Type != lex.IonCommentLine { - break - } - } - p.backup() - return item -} - -// recover is a handler that turns panics into error returns from Parse. The -// panic is retained if it is of the runtime.Error variety. -func (p *parser) recover(err *error) { - e := recover() - if e != nil { - // We only want to capture errors that we panic on. - if _, ok := e.(runtime.Error); ok { - panic(e) - } - *err = e.(error) - } - return -} - -// Parses the given input string and makes the resulting parse tree available -// in the Root object of the tree. The nodes that are parsed are assigned the -// given priority. -func (p *parser) Parse(text []byte) (err error) { - defer p.recover(&err) - p.lex = lex.New(text) - p.parse() - return nil -} - -func (p *parser) parse() { - var values []Value - for item := p.peekNonComment(); item.Type != lex.IonEOF; item = p.peekNonComment() { - values = append(values, p.parseValue(false)) - } - - // If there is a version marker then it will be the first symbol. - if len(values) > 0 && values[0].Type() == TypeSymbol { - sym := values[0].Text() - if bytes.HasPrefix(sym, []byte("$ion_")) && !bytes.Equal(sym, []byte("$ion_1_0")) { - p.panicf("unsupported ION version %s", sym) - } - } - - p.digest = &Digest{values: values} -} - -func (p *parser) parseValue(allowOperator bool) Value { - var annotations []Symbol - for { - item := p.peekNonComment() - switch item.Type { - case lex.IonError: - p.panicf("unable to parse input: " + item.String()) - case lex.IonBinaryStart: - return p.parseBinary(annotations) - case lex.IonDecimal: - return p.parseDecimal(annotations) - case lex.IonFloat, lex.IonInfinity: - return p.parseFloat(annotations) - case lex.IonInt: - return p.parseInt(annotations, intBase10) - case lex.IonIntBinary: - return p.parseInt(annotations, intBase2) - case lex.IonIntHex: - return p.parseInt(annotations, intBase16) - case lex.IonListStart: - return p.parseList(annotations) - case lex.IonNull: - return p.parseNull(annotations) - case lex.IonOperator: - if !allowOperator { - p.panicf("operator not allowed outside s-expression %v", item) - } - return p.parseSymbol(annotations, true) - case lex.IonSExpStart: - return p.parseSExpression(annotations) - case lex.IonString: - return String{annotations: annotations, text: p.doStringReplacements(p.next().Val)} - case lex.IonStringLong: - return p.parseLongString(annotations) - case lex.IonStructStart: - return p.parseStruct(annotations) - case lex.IonSymbol, lex.IonSymbolQuoted: - symbol := p.parseSymbol(annotations, false) - if item = p.peekNonComment(); item.Type == lex.IonDoubleColon { - annotation, ok := symbol.(Symbol) - if !ok { - p.panicf("invalid annotation type %q", symbol.Type()) - } - // Annotations themselves don't have annotations. - annotation.annotations = nil - annotations = append(annotations, annotation) - fmt.Printf("Annotations: %#v\n", annotations) - p.nextNonComment() - continue - } - return symbol - case lex.IonTimestamp: - return p.parseTimestamp(annotations) - default: - p.panicf("unexpected item type %q", item.Type) - } - } -} diff --git a/ion/parse_text_container.go b/ion/parse_text_container.go deleted file mode 100644 index 9ce35245..00000000 --- a/ion/parse_text_container.go +++ /dev/null @@ -1,125 +0,0 @@ -/* - * Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"). - * You may not use this file except in compliance with the License. - * A copy of the License is located at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * or in the "license" file accompanying this file. This file is distributed - * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either - * express or implied. See the License for the specific language governing - * permissions and limitations under the License. - */ - -package ion - -import ( - "github.com/amzn/ion-go/internal/lex" -) - -// This file contains text parsers for List, SExp, and Struct. - -func (p *parser) parseList(annotations []Symbol) List { - if item := p.next(); item.Type != lex.IonListStart { - p.panicf("expected list start but found %s", item) - } - - var values []Value - var prev lex.Item - for item := p.peekNonComment(); item.Type != lex.IonListEnd && item.Type != lex.IonError; prev, item = item, p.peekNonComment() { - if item.Type == lex.IonComma { - if prev.Type == 0 { - p.panicf("list may not start with a comma") - } - if prev.Type == lex.IonComma { - p.panicf("list must have a value defined between commas") - } - p.next() - continue - } else if prev.Type != lex.IonComma && prev.Type != 0 { - p.panicf("list items must be separated by commas") - } - - values = append(values, p.parseValue(false)) - } - - // Eat the end of the list. An improperly terminated list creates an error - // before we hit this spot, but check it just to be safe. - if item := p.next(); item.Type != lex.IonListEnd { - p.panicf("expected list end but found %s", item) - } - - return List{annotations: annotations, values: values} -} - -func (p *parser) parseSExpression(annotations []Symbol) SExp { - if item := p.next(); item.Type != lex.IonSExpStart { - p.panicf("expected s-expression start but found %s", item) - } - - var values []Value - for item := p.peekNonComment(); item.Type != lex.IonSExpEnd && item.Type != lex.IonError; item = p.peekNonComment() { - values = append(values, p.parseValue(true)) - } - - // Eat the end of the s-expression. An improperly terminated s-expression creates an error - // before we hit this spot, but check it just to be safe. - if item := p.next(); item.Type != lex.IonSExpEnd { - p.panicf("expected s-expression end but found %s", item) - } - - return SExp{annotations: annotations, values: values} -} - -func (p *parser) parseStruct(annotations []Symbol) Struct { - if item := p.next(); item.Type != lex.IonStructStart { - p.panicf("expected struct start but found %s", item) - } - - var values []StructField - var prev lex.Item - for item := p.peekNonComment(); item.Type != lex.IonStructEnd && item.Type != lex.IonError; prev, item = item, p.peekNonComment() { - if item.Type == lex.IonComma { - if prev.Type == 0 { - p.panicf("struct may not start with a comma") - } - if prev.Type == lex.IonComma { - p.panicf("struct must have a field defined between commas") - } - p.next() - continue - } else if prev.Type != lex.IonComma && prev.Type != 0 { - p.panicf("struct fields must be separated by commas") - } - - // Struct field names are not allowed to have annotations. - // It's possible for the symbol that gets parsed to be a special reserved - // Symbol, e.g. true, that resolves to a non-Symbol type. We need to put - // that back into a Symbol for the struct. - parsed := p.parseSymbol(nil, false) - if pt := parsed.Type(); pt == TypeBool || pt == TypeNull || pt == TypeFloat || pt == TypeDecimal { - p.panicf("invalid type for field: %s", pt) - } - - symbol, ok := parsed.(Symbol) - if !ok { - symbol = Symbol{text: parsed.Text()} - } - - if item = p.nextNonComment(); item.Type != lex.IonColon { - p.panicf("expected colon after symbol in struct but found %s", item) - } - value := p.parseValue(false) - values = append(values, StructField{Symbol: symbol, Value: value}) - } - - // Eat the end of the structure. An improperly terminated struct creates an error - // before we hit this spot, but check it just to be safe. - if item := p.next(); item.Type != lex.IonStructEnd { - p.panicf("expected struct end but found %s", item) - } - - return Struct{annotations: annotations, fields: values} -} diff --git a/ion/parse_text_numeric.go b/ion/parse_text_numeric.go deleted file mode 100644 index 54a55908..00000000 --- a/ion/parse_text_numeric.go +++ /dev/null @@ -1,44 +0,0 @@ -/* - * Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"). - * You may not use this file except in compliance with the License. - * A copy of the License is located at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * or in the "license" file accompanying this file. This file is distributed - * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either - * express or implied. See the License for the specific language governing - * permissions and limitations under the License. - */ - -package ion - -// parseDecimal parses a decimal value and counts the number of characters -// in the coefficient. -func (p *parser) parseDecimal(annotations []Symbol) Decimal { - text := p.next().Val - - // TODO: Properly calculate and track the precision of the decimal. - - return Decimal{annotations: annotations, isSet: true, text: text} -} - -// parseFloat parses the next value as a float. -func (p *parser) parseFloat(annotations []Symbol) Float { - return Float{annotations: annotations, isSet: true, text: p.next().Val} -} - -// parseInt parses an int for the given base. -func (p *parser) parseInt(annotations []Symbol, base intBase) Int { - text := p.next().Val - // An empty slice of bytes is not a valid int, so we're going to make the assumption - // that we can check the first element of the text slice. - return Int{annotations: annotations, isSet: true, base: base, isNegative: text[0] == '-', text: text} -} - -// parseTimestamp parses the next value as a Timestamp. -func (p *parser) parseTimestamp(annotations []Symbol) Timestamp { - return Timestamp{annotations: annotations, text: p.next().Val} -} diff --git a/ion/parse_text_simple.go b/ion/parse_text_simple.go deleted file mode 100644 index 2fd73a77..00000000 --- a/ion/parse_text_simple.go +++ /dev/null @@ -1,327 +0,0 @@ -/* - * Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"). - * You may not use this file except in compliance with the License. - * A copy of the License is located at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * or in the "license" file accompanying this file. This file is distributed - * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either - * express or implied. See the License for the specific language governing - * permissions and limitations under the License. - */ - -package ion - -import ( - "bytes" - "github.com/amzn/ion-go/internal/lex" - "math" - "strconv" - "unicode/utf8" -) - -// This file contains text parsers for Null, Padding, Bool, Symbol, String, Blob, and Clob. - -func (p *parser) parseBinary(annotations []Symbol) Value { - // Eat the IonBinaryStart. - if item := p.next(); item.Type != lex.IonBinaryStart { - p.panicf("expected binary start but found %q", item) - } - - var value Value - switch item := p.next(); item.Type { - case lex.IonBlob: - value = Blob{annotations: annotations, text: removeAny(item.Val, []byte(" \t\n\r\f\v"))} - case lex.IonClobShort: - value = Clob{annotations: annotations, text: p.doClobReplacements(item.Val)} - case lex.IonClobLong: - text := item.Val - for peek := p.peekNonComment(); peek.Type == lex.IonClobLong; peek = p.peekNonComment() { - text = append(text, p.next().Val...) - } - value = Clob{annotations: annotations, text: p.doClobReplacements(text)} - default: - p.panicf("expected a blob or clob but found %q", item) - } - - if item := p.next(); item.Type != lex.IonBinaryEnd { - p.panicf("expected binary end but found %q", item) - } - - return value -} - -func (p *parser) parseLongString(annotations []Symbol) String { - text := []byte{} - for item := p.peekNonComment(); item.Type == lex.IonStringLong; item = p.peekNonComment() { - text = append(text, p.doStringReplacements(p.next().Val)...) - } - - return String{annotations: annotations, text: text} -} - -// decodeHex takes a slice that is the input buffer and decodes it into a slice containing UTF-8 code units. -// The start parameter is the index in the slice start of a hex-encoded rune to decode. -// The escapeLen parameter is the length of the escape to decode and must be a length -// that is a power of two, no longer than size of uint32, and within the length of the input slice. -func (p *parser) decodeHex(input []byte, start int, hexLen int) []byte { - // Length must be a power of two and no larger that size of uint32. - if hexLen <= 0 || hexLen > 8 || hexLen%2 != 0 { - // calling code must give us a proper slice... - p.panicf("hex escape is invalid length (%d)", hexLen) - } - // Construct a working slice to decode with. - inLen := len(input) - if start < 0 || start >= inLen { - p.panicf("start of hex escape (%d) is negative or greater than or equal to input length (%d)", start, inLen) - } - end := start + hexLen - if end > inLen { - p.panicf("end of hex escape (%d) is greater than input length (%d)", end, inLen) - } - hex := input[start:end] - - // Decode the hex string into a UTF-32 scalar. - buf := make([]byte, utf8.UTFMax) - var cp uint64 - for i := 0; i < hexLen; i += 2 { - octet, errParse := strconv.ParseUint(string(hex[i:i+2]), 16, 8) - if errParse != nil { - p.panicf("invalid hex escape (%q) was not caught by lexer: %v", hex, errParse) - } - cp = (cp << 8) | octet - } - - // Now serialize back as UTF-8 code units. - encodeLen := utf8.EncodeRune(buf, rune(cp)) - - return buf[0:encodeLen] -} - -// doStringReplacements converts escaped characters into their equivalent -// character while handling cases involving \r. -func (p *parser) doStringReplacements(str []byte) []byte { - strLen := len(str) - ret := make([]byte, 0, strLen) - for index := 0; index < strLen; index++ { - switch ch := str[index]; ch { - case '\r': - // Turn \r into \n. - ret = append(ret, '\n') - // We need to treat both "\r\n" and "\r" as "\n", so - // skip extra if what comes next is "\n". - if index < strLen-1 && str[index+1] == '\n' { - index++ - } - case '\\': - if index >= strLen-1 { - continue - } - // We have an escape character. Do different things depending on - // what we are escaping. - switch next := str[index+1]; next { - case '\r': - // Newline being escaped. - index++ - // Newline being escaped may be \r\n or just \r. - if index < strLen-1 && str[index+1] == '\n' { - index++ - } - case '\n': - // Newline being escaped. - index++ - case 'r', 'n': - // Treat both "\\r" and "\\n" as "\n" - ret = append(ret, '\n') - index++ - case '\'', '"', '\\': - ret = append(ret, next) - index++ - case 'x': - index += 2 - data := p.decodeHex(str, index, 2) - index += 1 - ret = append(ret, data...) - case 'u': - index += 2 - data := p.decodeHex(str, index, 4) - index += 3 - ret = append(ret, data...) - case 'U': - index += 2 - data := p.decodeHex(str, index, 8) - index += 7 - ret = append(ret, data...) - default: - // Don't have anything special to do with the next character, so - // just add the current character and let the next one get added - // as normal. - ret = append(ret, ch) - } - default: - ret = append(ret, ch) - } - } - - return ret -} - -// doClobReplacements is like doStringReplacements but is restricted to escapes that CLOBs have. -func (p *parser) doClobReplacements(str []byte) []byte { - strLen := len(str) - ret := make([]byte, 0, strLen) - for index := 0; index < strLen; index++ { - switch ch := str[index]; ch { - case '\r': - // We normalize "\r" and "\r\n" as "\n". - if index < strLen-1 && str[index+1] == '\n' { - index++ - } - ret = append(ret, '\n') - case '\\': - if index >= strLen-1 { - continue - } - // We have an escape character. Do different things depending on - // what we are escaping. - switch next := str[index+1]; next { - case '\r': - // Newline being escaped. - index++ - // Newline being escaped may be \r\n or just \r. - if index < strLen-1 && str[index+1] == '\n' { - index++ - } - case '\n': - // Newline being escaped. - index++ - case 'n': - ret = append(ret, '\n') - index++ - case 'r': - ret = append(ret, '\r') - index++ - case '\'', '"', '\\': - ret = append(ret, next) - index++ - case 'x': - index += 2 - data := p.decodeHex(str, index, 2) - index += 1 - ret = append(ret, data...) - default: - // Don't have anything special to do with the next character, so - // just add the current character and let the next one get added - // as normal. - ret = append(ret, ch) - } - default: - ret = append(ret, ch) - } - } - - return ret -} - -// removeAny removes any occurrence of any of the given bytes from the given -// data and returns the result. -func removeAny(data []byte, any []byte) []byte { - ret := make([]byte, 0, len(data)) - for _, ch := range data { - if fnd := bytes.IndexRune(any, rune(ch)); fnd >= 0 { - continue - } - ret = append(ret, ch) - } - - return ret -} - -func (p *parser) parseNull(annotations []Symbol) Null { - item := p.next() - - switch string(item.Val) { - case "null.blob": - return Null{annotations: annotations, typ: TypeBlob} - case "null.bool": - return Null{annotations: annotations, typ: TypeBool} - case "null.clob": - return Null{annotations: annotations, typ: TypeClob} - case "null.decimal": - return Null{annotations: annotations, typ: TypeDecimal} - case "null.float": - return Null{annotations: annotations, typ: TypeFloat} - case "null.int": - return Null{annotations: annotations, typ: TypeInt} - case "null.list": - return Null{annotations: annotations, typ: TypeList} - case "null.null": - return Null{annotations: annotations, typ: TypeNull} - case "null.sexp": - return Null{annotations: annotations, typ: TypeSExp} - case "null.string": - return Null{annotations: annotations, typ: TypeString} - case "null.struct": - return Null{annotations: annotations, typ: TypeStruct} - case "null.symbol": - return Null{annotations: annotations, typ: TypeSymbol} - case "null.timestamp": - return Null{annotations: annotations, typ: TypeTimestamp} - default: - p.panicf("invalid null type: %v", item) - } - - // Not reach-able, but Go doesn't know that p.panicf always panics. - return Null{} -} - -// parseSymbol parses a quoted or unquoted symbol. There are several reserved symbols -// that hold special meaning, e.g. null.bool, that the lexer does not differentiate -// from other symbols. This method treats the reserved symbols differently and returns -// the correct type. -func (p *parser) parseSymbol(annotations []Symbol, allowOperator bool) Value { - item := p.next() - // Include IonString here since struct fields are ostensibly symbols, but - // quoted strings can be used to express them. IonOperator is basically a - // specialized version of Symbol. Long strings have more extensive parsing - // so back up and kick off that process. - switch item.Type { - case lex.IonOperator: - if !allowOperator { - p.panicf("operator not allowed here: %v", item) - } - case lex.IonNull: - // Null has its own special parsing, so backup and give that a go. - p.backup() - return p.parseNull(annotations) - case lex.IonSymbol, lex.IonSymbolQuoted, lex.IonString: - case lex.IonStringLong: - p.backup() - return p.parseLongString(annotations) - default: - p.panicf("expected operation, symbol, quoted symbol, or string but found %v", item) - } - - quoted := item.Type == lex.IonSymbolQuoted || item.Type == lex.IonString || item.Type == lex.IonStringLong - if !quoted { - switch string(item.Val) { - case "true": - return Bool{annotations: annotations, isSet: true, value: true} - case "false": - return Bool{annotations: annotations, isSet: true, value: false} - case "nan": - nan := math.NaN() - return Float{isSet: true, value: &nan} - case "null": - return Null{annotations: annotations} - } - } - - // TODO: Figure out why the bytes in item.Val get overwritten when we don't - // make an explicit copy of the data. - //return Symbol{annotations: annotations, quoted: quoted, text: doStringReplacements(item.Val)} - return Symbol{annotations: annotations, quoted: quoted, text: append([]byte{}, p.doStringReplacements(item.Val)...)} -} diff --git a/ion/parse_text_test.go b/ion/parse_text_test.go deleted file mode 100644 index 1af599dc..00000000 --- a/ion/parse_text_test.go +++ /dev/null @@ -1,555 +0,0 @@ -/* - * Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"). - * You may not use this file except in compliance with the License. - * A copy of the License is located at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * or in the "license" file accompanying this file. This file is distributed - * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either - * express or implied. See the License for the specific language governing - * permissions and limitations under the License. - */ - -package ion - -import ( - "bytes" - "errors" - "io/ioutil" - "os" - "path/filepath" - "strings" - "testing" -) - -// TODO: Need to wire up tests that call some of the underlying parse functions individually -// so that we can bypass the checks that the lexer makes and trigger the panics. - -func TestParseText(t *testing.T) { - tests := []struct { - name string - text string - expected *Digest - expectedErr error - }{ - // Strings. - - { - name: "short and long strings", - text: "\"short string\"\n'''long'''\n'''string'''", - expected: &Digest{values: []Value{ - String{text: []byte("short string")}, - String{text: []byte("longstring")}, - }}, - }, - { - name: "escaped strings", - text: `"H\x48\u0048\U00000048" '''h\x68\u0068\U00000068'''`, - expected: &Digest{values: []Value{ - String{text: []byte("HHHH")}, - String{text: []byte("hhhh")}, - }}, - }, - - // Symbols - - { - name: "symbol", - text: "'short symbol'", - expected: &Digest{values: []Value{ - Symbol{text: []byte("short symbol")}, - }}, - }, - { - name: "escaped symbols", - text: `'H\x48\u0048\U00000048'`, - expected: &Digest{values: []Value{ - Symbol{text: []byte("HHHH")}, - }}, - }, - - // Numeric - - { - name: "infinity", - text: "inf +inf -inf", - expected: &Digest{values: []Value{ - // "inf" must have a plus or minus on it to be considered a number. - Symbol{text: []byte("inf")}, - Float{isSet: true, text: []byte("+inf")}, - Float{isSet: true, text: []byte("-inf")}, - }}, - }, - { - name: "integers", - text: "0 -1 1_2_3 0xFf -0xFf 0Xe_d 0b10 -0b10 0B1_0", - expected: &Digest{values: []Value{ - Int{isSet: true, text: []byte("0")}, - Int{isSet: true, isNegative: true, text: []byte("-1")}, - Int{isSet: true, text: []byte("1_2_3")}, - Int{isSet: true, base: intBase16, text: []byte("0xFf")}, - Int{isSet: true, isNegative: true, base: intBase16, text: []byte("-0xFf")}, - Int{isSet: true, base: intBase16, text: []byte("0Xe_d")}, - Int{isSet: true, base: intBase2, text: []byte("0b10")}, - Int{isSet: true, isNegative: true, base: intBase2, text: []byte("-0b10")}, - Int{isSet: true, base: intBase2, text: []byte("0B1_0")}, - }}, - }, - { - name: "decimals", - text: "0. 0.123 -0.12d4 0D-0 0d+0 12_34.56_78", - expected: &Digest{values: []Value{ - Decimal{isSet: true, text: []byte("0.")}, - Decimal{isSet: true, text: []byte("0.123")}, - Decimal{isSet: true, text: []byte("-0.12d4")}, - Decimal{isSet: true, text: []byte("0D-0")}, - Decimal{isSet: true, text: []byte("0d+0")}, - Decimal{isSet: true, text: []byte("12_34.56_78")}, - }}, - }, - { - name: "floats", - text: "0E0 0.12e-4 -0e+0", - expected: &Digest{values: []Value{ - Float{isSet: true, text: []byte("0E0")}, - Float{isSet: true, text: []byte("0.12e-4")}, - Float{isSet: true, text: []byte("-0e+0")}, - }}, - }, - - { - name: "dates", - text: "2019T 2019-10T 2019-10-30 2019-10-30T", - expected: &Digest{values: []Value{ - Timestamp{precision: TimestampPrecisionYear, text: []byte("2019T")}, - Timestamp{precision: TimestampPrecisionMonth, text: []byte("2019-10T")}, - Timestamp{precision: TimestampPrecisionDay, text: []byte("2019-10-30")}, - Timestamp{precision: TimestampPrecisionDay, text: []byte("2019-10-30T")}, - }}, - }, - { - name: "times", - text: "2019-10-30T22:30Z 2019-10-30T12:30:59+02:30 2019-10-30T12:30:59.999-02:30", - expected: &Digest{values: []Value{ - Timestamp{precision: TimestampPrecisionMinute, text: []byte("2019-10-30T22:30Z")}, - Timestamp{precision: TimestampPrecisionSecond, text: []byte("2019-10-30T12:30:59+02:30")}, - Timestamp{precision: TimestampPrecisionMillisecond3, text: []byte("2019-10-30T12:30:59.999-02:30")}, - }}, - }, - - // Binary. - - { - name: "short blob", - text: "{{+AB/}}", - expected: &Digest{values: []Value{Blob{text: []byte("+AB/")}}}, - }, - { - name: "padded blob with whitespace", - text: "{{ + A\nB\t/abc= }}", - expected: &Digest{values: []Value{Blob{text: []byte("+AB/abc=")}}}, - }, - { - name: "short clob", - text: `{{ "A\n" }}`, - expected: &Digest{values: []Value{Clob{text: []byte("A\n")}}}, - }, - { - name: "long clob", - text: "{{ '''+AB/''' }}", - expected: &Digest{values: []Value{Clob{text: []byte("+AB/")}}}, - }, - { - name: "multiple long clobs", - text: "{{ '''A\\nB'''\n'''foo''' }}", - expected: &Digest{values: []Value{Clob{text: []byte("A\nBfoo")}}}, - }, - { - name: "escaped clobs", - text: `{{"H\x48\x48H"}} {{'''h\x68\x68h'''}}`, - expected: &Digest{values: []Value{ - Clob{text: []byte("HHHH")}, - Clob{text: []byte("hhhh")}, - }}, - }, - - // Containers - - { - name: "struct with symbol to symbol", - text: `{symbol1: 'symbol', 'symbol2': symbol}`, - expected: &Digest{values: []Value{ - Struct{fields: []StructField{ - {Symbol: Symbol{text: []byte("symbol1")}, Value: Symbol{quoted: true, text: []byte("symbol")}}, - {Symbol: Symbol{quoted: true, text: []byte("symbol2")}, Value: Symbol{text: []byte("symbol")}}, - }}, - }}, - }, - { - name: "struct with annotated field", - text: `{symbol1: ann::'symbol'}`, - expected: &Digest{values: []Value{ - Struct{fields: []StructField{ - {Symbol: Symbol{text: []byte("symbol1")}, Value: Symbol{annotations: []Symbol{{text: []byte("ann")}}, quoted: true, text: []byte("symbol")}}, - }}, - }}, - }, - { - name: "struct with doubly-annotated field", - text: `{symbol1: ann1::ann2::'symbol'}`, - expected: &Digest{values: []Value{ - Struct{fields: []StructField{ - {Symbol: Symbol{text: []byte("symbol1")}, Value: Symbol{annotations: []Symbol{{text: []byte("ann1")}, {text: []byte("ann2")}}, quoted: true, text: []byte("symbol")}}, - }}, - }}, - }, - { - name: "struct with comments between symbol and value", - text: "{abc : // Line\n/* Block */ {{ \"A\\n\" }}}", - expected: &Digest{values: []Value{ - Struct{fields: []StructField{ - {Symbol: Symbol{text: []byte("abc")}, Value: Clob{text: []byte("A\n")}}, - }}, - }}, - }, - - { - name: "struct with empty list, struct, and sexp", - text: "{a:[], b:{}, c:()}", - expected: &Digest{values: []Value{ - Struct{fields: []StructField{ - {Symbol: Symbol{text: []byte("a")}, Value: List{}}, - {Symbol: Symbol{text: []byte("b")}, Value: Struct{}}, - {Symbol: Symbol{text: []byte("c")}, Value: SExp{}}, - }}, - }}, - }, - { - name: "list with empty list, struct, and sexp", - text: "[[], {}, ()]", - expected: &Digest{values: []Value{ - List{values: []Value{List{}, Struct{}, SExp{}}}, - }}, - }, - { - name: "list of things", - text: "[a, 1, ' ', {}, () /* comment */ ]", - expected: &Digest{values: []Value{ - List{values: []Value{ - Symbol{text: []byte("a")}, - Int{isSet: true, text: []byte("1")}, - Symbol{text: []byte(" ")}, - Struct{}, - SExp{}, - }}, - }}, - }, - { - name: "struct of things", - text: "{'a' : 1 , s:'', 'st': {}, \n/* comment */lst:[],\"sexp\":()}", - expected: &Digest{values: []Value{ - Struct{fields: []StructField{ - {Symbol: Symbol{text: []byte("a")}, Value: Int{isSet: true, text: []byte("1")}}, - {Symbol: Symbol{text: []byte("s")}, Value: Symbol{text: []byte("")}}, - {Symbol: Symbol{text: []byte("st")}, Value: Struct{}}, - {Symbol: Symbol{text: []byte("lst")}, Value: List{}}, - {Symbol: Symbol{text: []byte("sexp")}, Value: SExp{}}, - }}, - }}, - }, - { - name: "s-expression of things", - text: "(a+b/c<( j * k))", - expected: &Digest{values: []Value{ - SExp{values: []Value{ - Symbol{text: []byte("a")}, - Symbol{text: []byte("+")}, - Symbol{text: []byte("b")}, - Symbol{text: []byte("/")}, - Symbol{text: []byte("c")}, - Symbol{text: []byte("<")}, - SExp{values: []Value{ - Symbol{text: []byte("j")}, - Symbol{text: []byte("*")}, - Symbol{text: []byte("k")}, - }}, - }}, - }}, - }, - - // Error cases - - { - name: "list starts with comma", - text: "[, [], {}, ()]", - expectedErr: errors.New("parsing line 1 - list may not start with a comma"), - }, - { - name: "struct starts with comma", - text: "{, a:1}", - expectedErr: errors.New("parsing line 1 - struct may not start with a comma"), - }, - { - name: "list without commas", - text: "[[] {} ()]", - expectedErr: errors.New("parsing line 1 - list items must be separated by commas"), - }, - { - name: "struct without commas", - text: "{a:1 b:2}", - expectedErr: errors.New("parsing line 1 - struct fields must be separated by commas"), - }, - } - for _, tst := range tests { - test := tst - t.Run(test.name, func(t *testing.T) { - digest, err := ParseText(strings.NewReader(test.text)) - if diff := cmpDigests(test.expected, digest); diff != "" { - t.Logf("expected: %#v", test.expected) - t.Logf("found: %#v", digest) - t.Error("(-expected, +found)", diff) - } - if diff := cmpErrs(test.expectedErr, err); diff != "" { - t.Error("err: (-expected, +found)", diff) - } - }) - } -} - -func TestIonTests_Text_Good(t *testing.T) { - // We don't support UTF-16 or UTF-32 so skip those two test files. - filesToSkip := map[string]bool{ - "utf16.ion": true, - "utf32.ion": true, - } - - testFilePath := "../ion-tests/iontestdata/good" - walkFn := func(path string, info os.FileInfo, err error) error { - if info.IsDir() || !strings.HasSuffix(path, ".ion") { - return nil - } - - name := info.Name() - if _, ok := filesToSkip[name]; ok { - t.Log("skipping", name) - return nil - } - - t.Run(strings.TrimPrefix(path, testFilePath), func(t *testing.T) { - data, err := ioutil.ReadFile(path) - if err != nil { - t.Fatal(err) - } - out, err := ParseText(bytes.NewReader(data)) - if err != nil { - t.Fatal(err) - } - - // There are a couple of files where correct parsing yields an empty Digest. - if strings.HasSuffix(path, "blank.ion") || strings.HasSuffix(path, "empty.ion") { - if out == nil { - t.Error("expected out to not be nil") - } - } else if out == nil || len(out.Value()) == 0 { - t.Error("expected out to have at least one value") - } - }) - return nil - } - if err := filepath.Walk(testFilePath, walkFn); err != nil { - t.Fatal(err) - } -} - -func TestIonTests_Text_Equivalents(t *testing.T) { - // We have some use-cases that are not yet supported. - filesToSkip := map[string]bool{ - // TODO: Deal with symbol tables and verification of SymbolIDs. - "annotatedIvms.ion": true, - "keywordPrefixes.ion": true, - "localSymbolTableAppend.ion": true, - "localSymbolTableNullSlots.ion": true, - "localSymbolTableWithAnnotations.ion": true, - "localSymbolTables.ion": true, - "localSymbolTablesValuesWithAnnotations.ion": true, - "nonIVMNoOps.ion": true, - "systemSymbols.ion": true, - // "Structures are unordered collections of name/value pairs." Comparing - // the structs for equivalency requires specialized logic that is not part - // of the spec. - "structsFieldsDiffOrder.ion": true, - "structsFieldsRepeatedNames.ion": true, - // We don't support arbitrary precision for timestamps. Once you get - // past microseconds it's pretty meaningless. - "timestampsLargeFractionalPrecision.ion": true, - // These files contain UTF16 and UTF32 which we do not support. - "stringU0001D11E.ion": true, - "stringUtf8.ion": true, - } - - testFilePath := "../ion-tests/iontestdata/good/equivs" - walkFn := func(path string, info os.FileInfo, err error) error { - if info.IsDir() || !strings.HasSuffix(path, ".ion") { - return nil - } - - name := info.Name() - if _, ok := filesToSkip[name]; ok { - t.Log("skipping", name) - return nil - } - - t.Run(strings.TrimPrefix(path, testFilePath), func(t *testing.T) { - data, err := ioutil.ReadFile(path) - if err != nil { - t.Fatal(err) - } - out, err := ParseText(bytes.NewReader(data)) - if err != nil { - t.Fatal(err) - } - - if out == nil || len(out.Value()) == 0 { - t.Error("expected out to have at least one value") - } - - for i, value := range out.Value() { - t.Log("collection", i, "of", info.Name()) - switch value.Type() { - case TypeList: - assertEquivalentValues(value.(List).values, t) - case TypeSExp: - assertEquivalentValues(value.(SExp).values, t) - default: - t.Error("top-element item is", value.Type(), "for", info.Name()) - } - } - }) - return nil - } - if err := filepath.Walk(testFilePath, walkFn); err != nil { - t.Fatal(err) - } -} - -func TestIonTests_Text_NonEquivalents(t *testing.T) { - // We have some use-cases that are not yet supported. - filesToSkip := map[string]bool{ - // TODO: Deal with symbol tables and verification of SymbolIDs. - "annotations.ion": true, - "symbols.ion": true, - // Not properly tracking decimal precision yet. - "decimals.ion": true, - "nonNulls.ion": true, - // Not handling negative zero. - "floats.ion": true, - "floatsVsDecimals.ion": true, - // Not properly handling unknown local offset yet. - "timestamps.ion": true, - } - - testFilePath := "../ion-tests/iontestdata/good/non-equivs" - walkFn := func(path string, info os.FileInfo, err error) error { - if info.IsDir() || !strings.HasSuffix(path, ".ion") { - return nil - } - - name := info.Name() - if _, ok := filesToSkip[name]; ok { - t.Log("skipping", name) - return nil - } - - t.Run(strings.TrimPrefix(path, testFilePath), func(t *testing.T) { - data, err := ioutil.ReadFile(path) - if err != nil { - t.Fatal(err) - } - out, err := ParseText(bytes.NewReader(data)) - if err != nil { - t.Fatal(err) - } - - if out == nil || len(out.Value()) == 0 { - t.Error("expected out to have at least one value") - } - - for i, value := range out.Value() { - t.Log("collection", i, "of", info.Name()) - switch value.Type() { - case TypeList: - assertNonEquivalentValues(value.(List).values, t) - case TypeSExp: - assertNonEquivalentValues(value.(SExp).values, t) - default: - t.Error("top-element item is", value.Type(), "for", info.Name()) - } - } - }) - return nil - } - if err := filepath.Walk(testFilePath, walkFn); err != nil { - t.Fatal(err) - } -} - -func TestIonTests_Text_Bad(t *testing.T) { - filesToSkip := map[string]bool{ - // TODO: Deal with symbol tables and verification of SymbolIDs. - "annotationSymbolIDUnmapped.ion": true, - "localSymbolTableImportNegativeMaxId.ion": true, - "localSymbolTableImportNonIntegerMaxId.ion": true, - "localSymbolTableImportNullMaxId.ion": true, - "localSymbolTableWithMultipleImportsFields.ion": true, - "localSymbolTableWithMultipleSymbolsAndImportsFields.ion": true, - "localSymbolTableWithMultipleSymbolsFields.ion": true, - "symbolIDUnmapped.ion": true, - // We only support UTF-8 - "fieldNameSymbolIDUnmapped.ion": true, - "longStringSplitEscape_2.ion": true, - "surrogate_1.ion": true, - "surrogate_2.ion": true, - "surrogate_3.ion": true, - "surrogate_4.ion": true, - "surrogate_5.ion": true, - "surrogate_6.ion": true, - "surrogate_7.ion": true, - "surrogate_8.ion": true, - "surrogate_9.ion": true, - "surrogate_10.ion": true, - } - - testFilePath := "../ion-tests/iontestdata/bad" - walkFn := func(path string, info os.FileInfo, err error) error { - if info.IsDir() || !strings.HasSuffix(path, ".ion") { - return nil - } - - name := info.Name() - if _, ok := filesToSkip[name]; ok { - t.Log("skipping", name) - return nil - } - - t.Run(strings.TrimPrefix(path, testFilePath), func(t *testing.T) { - data, err := ioutil.ReadFile(path) - if err != nil { - t.Fatal(err) - } - out, err := ParseText(bytes.NewReader(data)) - if err == nil { - t.Error("expected error but found none") - } - if out != nil { - t.Errorf("%#v", out.values) - } - }) - return nil - } - if err := filepath.Walk(testFilePath, walkFn); err != nil { - t.Fatal(err) - } -} diff --git a/ion/symbols_system.go b/ion/symbols_system.go deleted file mode 100644 index 9a3e5ad8..00000000 --- a/ion/symbols_system.go +++ /dev/null @@ -1,66 +0,0 @@ -/* - * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"). - * You may not use this file except in compliance with the License. - * A copy of the License is located at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * or in the "license" file accompanying this file. This file is distributed - * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either - * express or implied. See the License for the specific language governing - * permissions and limitations under the License. - */ - -package ion - -const ( - symbolTextIon = "$ion" - symbolIDIon = 1 - - // Version Identifier for Ion 1.0. - symbolTextIon10 = "$ion_1_0" - symbolIDIon10 = 2 - - symbolTextTable = "$ion_symbol_table" - symbolIDTable = 3 - - symbolTextName = "name" - symbolIDName = 4 - - symbolTextVersion = "version" - symbolIDVersion = 5 - - symbolTextImports = "imports" - symbolIDImports = 6 - - symbolTextSymbols = "symbols" - symbolIDSymbols = 7 - - symbolTextMaxID = "max_id" - symbolIDMaxID = 8 - - symbolTextSharedTable = "$ion_shared_symbol_table" - symbolIDSharedTable = 9 -) - -func newSystemSymbolTable() *SymbolTable { - table := newSymbolTableRaw(symbolTableKindSystem, symbolTextIon, 1) - table.InternToken(symbolTextIon) - table.InternToken(symbolTextIon10) - table.InternToken(symbolTextTable) - table.InternToken(symbolTextName) - table.InternToken(symbolTextVersion) - table.InternToken(symbolTextImports) - table.InternToken(symbolTextSymbols) - table.InternToken(symbolTextMaxID) - table.InternToken(symbolTextSharedTable) - - return table -} - -var ( - // systemSymbolTable is the implicitly defined Ion 1.0 symbol table that all local symbol tables inherit. - systemSymbolTable = *newSystemSymbolTable() -) diff --git a/ion/symbols_table.go b/ion/symbols_table.go deleted file mode 100644 index 50469ec2..00000000 --- a/ion/symbols_table.go +++ /dev/null @@ -1,127 +0,0 @@ -/* - * Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"). - * You may not use this file except in compliance with the License. - * A copy of the License is located at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * or in the "license" file accompanying this file. This file is distributed - * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either - * express or implied. See the License for the specific language governing - * permissions and limitations under the License. - */ - -package ion - -type symbolTableKind uint8 - -const ( - symbolTableKindSystem symbolTableKind = iota - symbolTableKindShared - symbolTableKindLocal -) - -// TODO Consider using `maligned` from `golangci-lint` on this struct. - -// SymbolTable is the core lookup structure for Tokens from their symbolic ID. -type SymbolTable struct { - kind symbolTableKind - name string - version int32 - imports []SymbolTable - textMap map[string]SymbolToken - tokens []SymbolToken - maxSID int64 -} - -func newSymbolTableRaw(kind symbolTableKind, name string, version int32) *SymbolTable { - table := SymbolTable{ - kind: kind, - name: name, - version: version, - textMap: make(map[string]SymbolToken), - } - return &table -} - -// newSymbolTable constructs a new empty symbol table of the given type. -// The name and version are applicable for shared/system symbol tables and should be default -// values for local symbol tables. -// The given imports will be loaded into the newly constructed symbol table. -func newSymbolTable(kind symbolTableKind, name string, version int32, imports ...SymbolTable) *SymbolTable { - table := newSymbolTableRaw(kind, name, version) - // Copy in the import "references." - table.imports = append(imports[:0:0], imports...) - if kind == symbolTableKindLocal { - // For local symbol table, import system symbol table. - imports = append([]SymbolTable{systemSymbolTable}, imports...) - } - for _, importTable := range imports { - // TODO Consider if we should not inline the imports or make this delegate to the imports or configurable. - table.tokens = append(table.tokens, importTable.tokens...) - for text, newToken := range importTable.textMap { - if _, exists := table.textMap[text]; !exists { - table.textMap[text] = newToken - } - } - table.maxSID += importTable.maxSID - } - return table -} - -// newLocalSymbolTable creates an instance with symbolTableKindLocal. -// The name is empty and the version is zero. These fields are inapplicable to local symbol tables. -func newLocalSymbolTable(imports ...SymbolTable) *SymbolTable { - return newSymbolTable(symbolTableKindLocal, "", 0, imports...) -} - -// newSharedSymbolTable creates an instance with symbolTableKindShared. -// Returns `nil` if the version is not positive. -func newSharedSymbolTable(name string, version int32, imports ...SymbolTable) *SymbolTable { - if version <= 0 { - return nil - } - return newSymbolTable(symbolTableKindShared, name, version, imports...) -} - -// BySID returns the underlying SymbolToken by local ID. -func (t *SymbolTable) BySID(sid int64) (SymbolToken, bool) { - if sid <= 0 || sid > t.maxSID { - return symbolTokenUndefined, false - } - return t.tokens[sid-1], true -} - -// ByText returns the underlying SymbolToken by text lookup. -func (t *SymbolTable) ByText(text string) (SymbolToken, bool) { - if tok, exists := t.textMap[text]; exists { - return tok, true - } - return symbolTokenUndefined, false -} - -// InternToken adds a symbol to the given table if it does not exist. -func (t *SymbolTable) InternToken(symText string) SymbolToken { - if tok, exists := t.textMap[symText]; exists { - return tok - } - - t.maxSID += 1 - var source *ImportSource = nil - if t.kind != symbolTableKindLocal { - // defining a token within a shared/system table makes the token have a source referring to its - // own table and SID - source = newSource(t.name, t.maxSID) - } - - tok := SymbolToken{ - Text: &symText, - localSID: t.maxSID, - Source: source, - } - t.tokens = append(t.tokens, tok) - t.textMap[symText] = tok - return tok -} diff --git a/ion/symbols_table_test.go b/ion/symbols_table_test.go deleted file mode 100644 index 06f387c0..00000000 --- a/ion/symbols_table_test.go +++ /dev/null @@ -1,270 +0,0 @@ -/* - * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"). - * You may not use this file except in compliance with the License. - * A copy of the License is located at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * or in the "license" file accompanying this file. This file is distributed - * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either - * express or implied. See the License for the specific language governing - * permissions and limitations under the License. - */ - -package ion - -import ( - "math/rand" - "reflect" - "testing" - - "github.com/google/go-cmp/cmp" -) - -func TestNewSharedSymbolTableBadVersion(t *testing.T) { - if newSharedSymbolTable("foo", 0) != nil { - t.Error("Expected version zero to return nil") - } - if newSharedSymbolTable("bar", -1) != nil { - t.Error("Expected negative version to return nil") - } -} - -var exportAll = cmp.Exporter(func(reflect.Type) bool { - return true -}) - -func assertSymbolTokenEquals(x, y SymbolToken, t *testing.T) { - t.Helper() - - diff := cmp.Diff(x, y, exportAll) - if diff != "" { - t.Errorf("Tokens are structurally different: %s", diff) - } -} - -func newString(value string) *string { - return &value -} - -type symbolTableCase struct { - desc string - kind symbolTableKind - name string - version int32 - imports []SymbolTable - maxSID int64 - tokens []SymbolToken - table SymbolTable -} - -func makeSystemSymbolTableCase() symbolTableCase { - // construct up all the expected tokens to test against - tokens := []SymbolToken{ - { - Text: newString("$ion"), - localSID: symbolIDIon, - Source: newSource(symbolTextIon, symbolIDIon), - }, - { - Text: newString("$ion_1_0"), - localSID: symbolIDIon10, - Source: newSource(symbolTextIon, symbolIDIon10), - }, - { - Text: newString("$ion_symbol_table"), - localSID: symbolIDTable, - Source: newSource(symbolTextIon, symbolIDTable), - }, - { - Text: newString("name"), - localSID: symbolIDName, - Source: newSource(symbolTextIon, symbolIDName), - }, - { - Text: newString("version"), - localSID: symbolIDVersion, - Source: newSource(symbolTextIon, symbolIDVersion), - }, - { - Text: newString("imports"), - localSID: symbolIDImports, - Source: newSource(symbolTextIon, symbolIDImports), - }, - { - Text: newString("symbols"), - localSID: symbolIDSymbols, - Source: newSource(symbolTextIon, symbolIDSymbols), - }, - { - Text: newString("max_id"), - localSID: symbolIDMaxID, - Source: newSource(symbolTextIon, symbolIDMaxID), - }, - { - Text: newString("$ion_shared_symbol_table"), - localSID: symbolIDSharedTable, - Source: newSource(symbolTextIon, symbolIDSharedTable), - }, - } - - // explicitly make a new system table to test against so we don't manipulate the global one - table := newSystemSymbolTable() - return symbolTableCase{ - desc: "systemSymbolTable", - kind: symbolTableKindSystem, - name: symbolTextIon, - version: 1, - maxSID: 9, - tokens: tokens, - table: *table, - } -} - -func makeSymbolTableCase( - desc string, - start int64, - kind symbolTableKind, - name string, - version int32, - makeToken func(text string, sid int64) SymbolToken, - newTable func(imports ...SymbolTable) *SymbolTable, - texts []string, - imports ...SymbolTable) symbolTableCase { - - // expect that we start at end of system table - var sid int64 = start - for _, imp := range imports { - sid += imp.maxSID - } - - // manually calculate what the implementation should do - var tokens []SymbolToken - for _, text := range texts { - sid += 1 - tokens = append(tokens, makeToken(text, sid)) - } - - // construct the table to test - table := newTable(imports...) - for _, text := range texts { - table.InternToken(text) - } - - return symbolTableCase{ - desc: desc, - kind: kind, - name: name, - version: version, - imports: imports, - maxSID: sid, - tokens: tokens, - table: *table, - } -} - -func makeSharedSymbolTableCase( - desc string, name string, version int32, texts []string, imports ...SymbolTable) symbolTableCase { - makeToken := func(text string, sid int64) SymbolToken { - source := ImportSource{ - Table: name, - SID: sid, - } - return SymbolToken{ - Text: newString(text), - localSID: sid, - Source: &source, - } - } - newTable := func(imports ...SymbolTable) *SymbolTable { - return newSharedSymbolTable(name, version, imports...) - } - return makeSymbolTableCase( - desc, 0, symbolTableKindShared, name, version, makeToken, newTable, texts, imports...) -} - -func makeLocalSymbolTableCase(desc string, texts []string, imports ...SymbolTable) symbolTableCase { - makeToken := func(text string, sid int64) SymbolToken { - return SymbolToken{ - Text: newString(text), - localSID: sid, - Source: nil, - } - } - return makeSymbolTableCase( - desc, 9, symbolTableKindLocal, "", 0, makeToken, newLocalSymbolTable, texts, imports...) -} - -// Arbitrary, deterministic seed for our symbol table testing -const testSystemSymbolTableSeed = 0x03CB815D3119751B - -func TestSymbolTable(t *testing.T) { - // TODO test tables with imports - cases := []symbolTableCase{ - makeSystemSymbolTableCase(), - makeLocalSymbolTableCase("Empty LST", nil), - makeLocalSymbolTableCase("Simple LST", []string{"a", "b", "c"}), - makeSharedSymbolTableCase("Empty SST", "foo", 1, nil), - makeSharedSymbolTableCase("Simple SST", "bar", 3, []string{"cat", "dog", "moose"}), - } - - for _, c := range cases { - t.Run(c.desc, func(t *testing.T) { - // TODO Consider using cmp.Transform to make these tests less verbose... - if c.kind != c.table.kind { - t.Errorf("Table kind mismatch: %d != %d", c.kind, c.table.kind) - } - if c.name != c.table.name { - t.Errorf("Table name mismatch: %s != %s", c.name, c.table.name) - } - if c.version != c.table.version { - t.Errorf("Table version mismatch: %d != %d", c.version, c.table.version) - } - importDiff := cmp.Diff(c.imports, c.table.imports) - if importDiff != "" { - t.Errorf("Table imports mismatch:\n%s", importDiff) - } - if c.maxSID != c.table.maxSID { - t.Errorf("Table maxSID mismatch: %d != %d", c.maxSID, c.table.maxSID) - } - for _, badSID := range []int64{0, -10, c.maxSID + 1, c.maxSID + 128} { - tok, ok := c.table.BySID(badSID) - if ok { - t.Errorf("Found a token %v for a non-existent SID %d", tok, badSID) - } - } - - // shuffle the tokens to randomize search a bit - rnd := rand.New(rand.NewSource(testSystemSymbolTableSeed)) - rnd.Shuffle(len(c.tokens), func(i, j int) { - c.tokens[i], c.tokens[j] = c.tokens[j], c.tokens[i] - }) - for _, expected := range c.tokens { - for i := 0; i < 2; i++ { - actualByID, ok := c.table.BySID(expected.localSID) - if !ok { - t.Error("Could not find ", expected) - } else { - assertSymbolTokenEquals(expected, actualByID, t) - } - - actualByText, ok := c.table.ByText(*expected.Text) - if !ok { - t.Error("Could not find ", expected) - } else { - assertSymbolTokenEquals(expected, actualByText, t) - } - - // interning the same text should not affect the above - newTok := c.table.InternToken(*expected.Text) - if newTok != actualByText { - t.Errorf("Interned Token should be the same: %v != %v", newTok, actualByText) - } - } - } - }) - } - -} diff --git a/ion/symbols_token.go b/ion/symbols_token.go deleted file mode 100644 index 32fe6e76..00000000 --- a/ion/symbols_token.go +++ /dev/null @@ -1,74 +0,0 @@ -/* - * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"). - * You may not use this file except in compliance with the License. - * A copy of the License is located at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * or in the "license" file accompanying this file. This file is distributed - * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either - * express or implied. See the License for the specific language governing - * permissions and limitations under the License. - */ - -package ion - -import "fmt" - -const ( - // The placeholder for when a symbol token has no symbol ID. - SymbolIDUnknown = -1 -) - -// ImportSource is a reference to a SID within a shared symbol table. -type ImportSource struct { - // The name of the shared symbol table that this token refers to. - Table string - // The ID of the interned symbol text within the shared SymbolTable. - // This must be greater than 1. - SID int64 -} - -func newSource(table string, sid int64) *ImportSource { - value := ImportSource{ - Table: table, - SID: sid, - } - return &value -} - -// A symbolic token for Ion. -// Symbol tokens are the values that annotations, field names, and the textual content of Ion symbol values. -// The `nil` value for SymbolToken is $0. -type SymbolToken struct { - // The string text of the token or nil if unknown. - Text *string - // Local symbol ID associated with the token. - localSID int64 - // The shared symbol table location that this token came from, or nil if undefined. - Source *ImportSource -} - -var ( - // symbolTokenUndefined is the sentinel for invalid tokens. - // The `nil` value is actually $0 which is a defined token. - symbolTokenUndefined = SymbolToken{ - localSID: SymbolIDUnknown, - } -) - -func (t SymbolToken) String() string { - text := "nil" - if t.Text != nil { - text = fmt.Sprintf("%q", *t.Text) - } - - source := "nil" - if t.Source != nil { - source = fmt.Sprintf("{%q %d}", t.Source.Table, t.Source.SID) - } - - return fmt.Sprintf("{%s %d %s}", text, t.localSID, source) -} diff --git a/ion/symbols_token_test.go b/ion/symbols_token_test.go deleted file mode 100644 index 71281c74..00000000 --- a/ion/symbols_token_test.go +++ /dev/null @@ -1,69 +0,0 @@ -/* - * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"). - * You may not use this file except in compliance with the License. - * A copy of the License is located at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * or in the "license" file accompanying this file. This file is distributed - * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either - * express or implied. See the License for the specific language governing - * permissions and limitations under the License. - */ -package ion - -import ( - "fmt" - "testing" - - "github.com/google/go-cmp/cmp" -) - -// Make sure SymbolToken conforms to Stringer -var _ fmt.Stringer = SymbolToken{} - -func TestSymbolToken_String(t *testing.T) { - cases := []struct { - desc string - token SymbolToken - expected string - }{ - { - desc: "Text and SID", - token: SymbolToken{ - Text: newString("hello"), - localSID: 10, - Source: nil, - }, - expected: `{"hello" 10 nil}`, - }, - { - desc: "nil Text", - token: SymbolToken{ - Text: nil, - localSID: 11, - Source: nil, - }, - expected: `{nil 11 nil}`, - }, - { - desc: "Text and SID with Import", - token: SymbolToken{ - Text: newString("world"), - localSID: 12, - Source: newSource("foobar", 3), - }, - expected: `{"world" 12 {"foobar" 3}}`, - }, - } - - for _, c := range cases { - t.Run(c.desc, func(t *testing.T) { - if diff := cmp.Diff(c.expected, c.token.String()); diff != "" { - t.Errorf("Token String() differs (-expected, +actual):\n%s", diff) - } - }) - } -} diff --git a/ion/types.go b/ion/types.go deleted file mode 100644 index 5243e022..00000000 --- a/ion/types.go +++ /dev/null @@ -1,237 +0,0 @@ -/* - * Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"). - * You may not use this file except in compliance with the License. - * A copy of the License is located at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * or in the "license" file accompanying this file. This file is distributed - * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either - * express or implied. See the License for the specific language governing - * permissions and limitations under the License. - */ - -package ion - -type Type int - -const ( - TypeNull Type = iota - TypeAnnotation - TypeBlob - TypeBool - TypeClob - TypeDecimal - TypeFloat - TypeInt - TypeList - TypeLongString - TypePadding - TypeSExp - TypeString - TypeStruct - TypeSymbol - TypeTimestamp -) - -var typeNameMap = map[Type]string{ - TypeNull: "Null", TypeAnnotation: "Annotation", TypeBlob: "Blob", TypeClob: "Clob", TypeDecimal: "Decimal", - TypeFloat: "Float", TypeInt: "Int", TypeList: "List", TypeLongString: "LongString", TypePadding: "Padding", - TypeSExp: "S-Expression", TypeString: "String", TypeStruct: "Struct", TypeSymbol: "Symbol", TypeTimestamp: "Timestamp", -} - -// String satisfies Stringer. -func (t Type) String() string { - if s, ok := typeNameMap[t]; ok { - return s - } - return "Unknown" -} - -const ( - textBoolFalse = "false" - textBoolTrue = "true" - textNullBool = "null.bool" - textNullTimestamp = "null.timestamp" -) - -// Digest is a top-level Ion container of Ion values. It is also the -// granularity of binary encoding Ion content. It is not a defined -// type in the Ion spec, but is used as a container of types and Symbols. -type Digest struct { - values []Value -} - -// Value returns the Values that make up this Digest. -func (d Digest) Value() []Value { - return d.values -} - -// Value is a basic interface for all Ion types. -// http://amzn.github.io/ion-docs/docs/spec.html -type Value interface { - // Annotations returns any annotations that have been set for this Value. - Annotations() []Symbol - // Binary returns the binary representation of the Value. - // http://amzn.github.io/ion-docs/docs/binary.html - Binary() []byte - // Text returns the text representation of the Value. - // http://amzn.github.io/ion-docs/docs/text.html - Text() []byte - // IsNull returns whether or not this instance of the value represents a - // null value for a given type. - // TODO: Determine if we want to use IsNull or use the Null struct. - IsNull() bool - // Type returns the Type of the Value. - Type() Type - - // Note: There is no general "Value" function to retrieve the Go version of - // the underlying value because we would need to define it to return - // interface{}. This decision may be revisited after playing with - // the library a bit. -} - -// Padding represents no-op padding in a binary stream. -type padding struct { - // Note that the name "length" is a little bit of a misnomer since a - // padding of length n pads n+1 bytes. - binary []byte -} - -// Annotations satisfies Value. -func (p padding) Annotations() []Symbol { - return nil -} - -// Binary satisfies Value. -func (p padding) Binary() []byte { - return p.binary -} - -// Text satisfies Value. -func (p padding) Text() []byte { - // Text padding isn't a thing. - return nil -} - -// IsNull satisfies Value. -func (p padding) IsNull() bool { - return false -} - -// Type satisfies Value. -func (p padding) Type() Type { - return TypePadding -} - -// Bool is the boolean type. -type Bool struct { - annotations []Symbol - isSet bool - value bool -} - -// Value returns the boolean value. This will be false if it has not been set. -func (b Bool) Value() bool { - if !b.isSet { - return false - } - return b.value -} - -// Annotations satisfies Value. -func (b Bool) Annotations() []Symbol { - return b.annotations -} - -// Binary satisfies Value. -func (b Bool) Binary() []byte { - return nil -} - -// Text satisfies Value. -func (b Bool) Text() []byte { - if !b.isSet { - return []byte(textNullBool) - } - if b.value { - return []byte(textBoolTrue) - } - return []byte(textBoolFalse) -} - -// IsNull satisfies Value. -func (b Bool) IsNull() bool { - return !b.isSet -} - -// Type satisfies Value. -func (b Bool) Type() Type { - return TypeBool -} - -// Null represents Null values and is able to take on the guise of -// any of the null-able types. -type Null struct { - annotations []Symbol - typ Type -} - -// Annotations satisfies Value. -func (n Null) Annotations() []Symbol { - return n.annotations -} - -// Binary satisfies Value. -func (n Null) Binary() []byte { - // TODO: Implement returning a byte based on the null type. - return nil -} - -// Text satisfies Value. -func (n Null) Text() []byte { - switch n.typ { - case TypeBlob: - return []byte("null.blob") - case TypeBool: - return []byte("null.bool") - case TypeClob: - return []byte("null.clob") - case TypeDecimal: - return []byte("null.decimal") - case TypeFloat: - return []byte("null.float") - case TypeInt: - return []byte("null.int") - case TypeList: - return []byte("null.list") - case TypeLongString: - return []byte("null.string") - case TypeNull: - return []byte("null.null") - case TypeSExp: - return []byte("null.sexp") - case TypeString: - return []byte("null.string") - case TypeStruct: - return []byte("null.struct") - case TypeSymbol: - return []byte("null.symbol") - case TypeTimestamp: - return []byte("null.timestamp") - default: - return []byte("null") - } -} - -// IsNull satisfies Value. -func (n Null) IsNull() bool { - return true -} - -// Type satisfies Value. -func (n Null) Type() Type { - return n.typ -} diff --git a/ion/types_binary.go b/ion/types_binary.go deleted file mode 100644 index 300d183f..00000000 --- a/ion/types_binary.go +++ /dev/null @@ -1,144 +0,0 @@ -/* - * Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"). - * You may not use this file except in compliance with the License. - * A copy of the License is located at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * or in the "license" file accompanying this file. This file is distributed - * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either - * express or implied. See the License for the specific language governing - * permissions and limitations under the License. - */ - -package ion - -import ( - "bytes" - "encoding/base64" -) - -// This file contains the binary-like types Blob and Clob. - -const ( - textNullBlob = "null.blob" - textNullClob = "null.clob" -) - -// Blob is binary data of user-defined encoding. -type Blob struct { - annotations []Symbol - binary []byte - text []byte -} - -// Annotations satisfies Value. -func (b Blob) Annotations() []Symbol { - return b.annotations -} - -// Value returns the Base64 encoded version of the Blob. -func (b Blob) Value() []byte { - return b.Text() -} - -// Binary returns the raw binary representation of the Blob. If the -// representation was originally text and there is a problem decoding it, -// then this will panic. This is because it is assumed that the original -// parsing of the text value will catch improperly formatted encodings. -func (b Blob) Binary() []byte { - if len(b.binary) != 0 || len(b.text) == 0 { - return b.binary - } - - // Trim any whitespace characters from the text representation of - // the Blob. - trimmedText := bytes.Map(func(r rune) rune { - if r == ' ' || r == '\r' || r == '\n' || r == '\t' || r == '\f' || r == '\v' { - return -1 - } - return r - }, b.text) - - b.binary = make([]byte, base64.StdEncoding.DecodedLen(len(trimmedText))) - if _, err := base64.StdEncoding.Decode(b.binary, trimmedText); err != nil { - panic(err) - } - return b.binary -} - -// Text returns the Base64 encoded version of the Blob. -func (b Blob) Text() []byte { - if b.IsNull() { - return []byte(textNullBlob) - } - - if len(b.text) != 0 || len(b.binary) == 0 { - return b.text - } - - b.text = make([]byte, base64.StdEncoding.EncodedLen(len(b.binary))) - base64.StdEncoding.Encode(b.text, b.binary) - return b.text -} - -// IsNull satisfies Value. -func (b Blob) IsNull() bool { - return b.binary == nil && b.text == nil -} - -// Type satisfies Value. -func (b Blob) Type() Type { - return TypeBlob -} - -// Clob is text data of user-defined encoding. It is a binary type that is -// designed for binary values that are either text encoded in a code page that -// is ASCII compatible or should be octet editable by a human (escaped string -// syntax vs. base64 encoded data). -type Clob struct { - annotations []Symbol - text []byte -} - -// Value returns the single string version of the Clob. -func (c Clob) Value() string { - if c.IsNull() { - return "" - } - - return string(c.text) -} - -// Annotations satisfies Value. -func (c Clob) Annotations() []Symbol { - return c.annotations -} - -// Binary returns the raw binary representation of the Clob. Because the binary -// format is represented directly as the octet values this returns the same -// representation as Text(). -func (c Clob) Binary() []byte { - return c.text -} - -// Text returns a text representation of the Clob. -func (c Clob) Text() []byte { - if c.IsNull() { - return []byte(textNullClob) - } - - return c.text -} - -// IsNull satisfies Value. -func (c Clob) IsNull() bool { - return c.text == nil -} - -// Type satisfies Value. -func (c Clob) Type() Type { - return TypeClob -} diff --git a/ion/types_character.go b/ion/types_character.go deleted file mode 100644 index 25cf3de2..00000000 --- a/ion/types_character.go +++ /dev/null @@ -1,107 +0,0 @@ -/* - * Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"). - * You may not use this file except in compliance with the License. - * A copy of the License is located at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * or in the "license" file accompanying this file. This file is distributed - * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either - * express or implied. See the License for the specific language governing - * permissions and limitations under the License. - */ - -package ion - -// This file contains the string-like types: String and Symbol. - -// String is a unicode text literal of arbitrary length. -type String struct { - annotations []Symbol - text []byte -} - -func (s String) Value() string { - return string(s.text) -} - -// Annotations satisfies Value. -func (s String) Annotations() []Symbol { - return s.annotations -} - -// Binary satisfies Value. -func (s String) Binary() []byte { - // These are always sequences of Unicode characters, encoded as a sequence of UTF-8 octets. - return s.text -} - -// Text returns a string representation of the symbol if a string representation -// has been set. Otherwise it will be empty. -func (s String) Text() []byte { - return s.text -} - -// IsNull satisfies Value. -func (s String) IsNull() bool { - return s.text == nil -} - -// Type satisfies Value. -func (String) Type() Type { - return TypeString -} - -// Symbol is an interned identifier that is represented as an ID -// and/or text. If the id is 0 and the text is empty, then this -// represent null.symbol. -type Symbol struct { - annotations []Symbol - id int32 - quoted bool - text []byte -} - -// Id returns the ID of the Symbol if it has been set, or SymbolIDUnknown if -// it has not. -func (s Symbol) Id() int32 { - if s.id == 0 { - return SymbolIDUnknown - } - return s.id -} - -func (s Symbol) Value() string { - // TODO: Things with Symbol tables and looking up the value when we - // only have an ID. - return string(s.text) -} - -// Annotations satisfies Value. -func (s Symbol) Annotations() []Symbol { - return s.annotations -} - -// Binary satisfies Value. -func (s Symbol) Binary() []byte { - // TODO: Return symbol ID. - return nil -} - -// Text returns a string representation of the symbol if a string representation -// has been set. Otherwise it will be empty. -func (s Symbol) Text() []byte { - return s.text -} - -// IsNull satisfies Value. -func (s Symbol) IsNull() bool { - return s.id == 0 && len(s.text) == 0 -} - -// Type satisfies Value. -func (Symbol) Type() Type { - return TypeSymbol -} diff --git a/ion/types_container.go b/ion/types_container.go deleted file mode 100644 index cf7fe078..00000000 --- a/ion/types_container.go +++ /dev/null @@ -1,181 +0,0 @@ -/* - * Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"). - * You may not use this file except in compliance with the License. - * A copy of the License is located at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * or in the "license" file accompanying this file. This file is distributed - * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either - * express or implied. See the License for the specific language governing - * permissions and limitations under the License. - */ - -package ion - -import ( - "bytes" -) - -// This file contains the container-like types: List, SExp, and Struct. - -const ( - textNullList = "null.list" - textNullSExp = "null.sexp" - textNullStruct = "null.struct" -) - -// List is an ordered collections of Values. The contents of the list are -// heterogeneous, each element can have a different type. Homogeneous lists -// may be imposed by schema validation tools. -type List struct { - annotations []Symbol - values []Value -} - -// Value returns the values that this list holds. -func (lst List) Value() []Value { - return lst.values -} - -// Annotations satisfies Value. -func (lst List) Annotations() []Symbol { - return lst.annotations -} - -// Binary satisfies Value. -func (lst List) Binary() []byte { - // TODO: Figure out how we want to do binary serialization of containers. - return nil -} - -// Text satisfies Value. -func (lst List) Text() []byte { - if lst.values == nil { - return []byte(textNullList) - } - - parts := make([][]byte, len(lst.values)) - for index, value := range lst.values { - parts[index] = value.Text() - } - - return bytes.Join(parts, []byte(",")) -} - -// IsNull satisfies Value. -func (lst List) IsNull() bool { - return lst.values == nil -} - -// Type satisfies Value. -func (List) Type() Type { - return TypeList -} - -// SExp (S-Expression) is an ordered collection of values with application-defined -// semantics. The contents of the list are -// heterogeneous, each element can have a different type. Homogeneous lists -// may be imposed by schema validation tools. -type SExp struct { - annotations []Symbol - values []Value -} - -// Value returns the values held within the s-expression. -func (s SExp) Value() []Value { - return s.values -} - -// Annotations satisfies Value. -func (s SExp) Annotations() []Symbol { - return s.annotations -} - -// Binary satisfies Value. -func (s SExp) Binary() []byte { - // TODO: Figure out how we want to do binary serialization of containers. - return nil -} - -// Text satisfies Value. -func (s SExp) Text() []byte { - if s.values == nil { - return []byte(textNullSExp) - } - - parts := make([][]byte, len(s.values)) - for index, value := range s.values { - parts[index] = value.Text() - } - - return bytes.Join(parts, []byte(" ")) -} - -// IsNull satisfies Value. -func (s SExp) IsNull() bool { - return s.values == nil -} - -// Type satisfies Value. -func (SExp) Type() Type { - return TypeSExp -} - -// StructField represents the field of a Struct. -type StructField struct { - Symbol Symbol - Value Value -} - -// Struct is an unordered collection of tagged values. -// When two fields in the same struct have the same name we say there -// are “repeated names” or “repeated fields”. All such fields must be -// preserved, any StructField that has a repeated name must not be discarded. -type Struct struct { - annotations []Symbol - fields []StructField -} - -// Value returns the fields that this struct holds. -func (s Struct) Value() []StructField { - return s.fields -} - -// Annotations satisfies Value. -func (s Struct) Annotations() []Symbol { - return s.annotations -} - -// Binary satisfies Value. -func (s Struct) Binary() []byte { - // TODO: Figure out how we want to do binary serialization of containers. - return nil -} - -// Text satisfies Value. -func (s Struct) Text() []byte { - if s.fields == nil { - return []byte(textNullStruct) - } - - parts := make([][]byte, len(s.fields)) - for index, fld := range s.fields { - line := append(fld.Symbol.Text(), ':') - parts[index] = append(line, fld.Value.Text()...) - } - - return bytes.Join(parts, []byte(",")) -} - -// IsNull satisfies Value. -func (s Struct) IsNull() bool { - return s.fields == nil -} - -// Type satisfies Value. -func (Struct) Type() Type { - return TypeStruct -} diff --git a/ion/types_numeric.go b/ion/types_numeric.go deleted file mode 100644 index 3fe6e882..00000000 --- a/ion/types_numeric.go +++ /dev/null @@ -1,347 +0,0 @@ -/* - * Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"). - * You may not use this file except in compliance with the License. - * A copy of the License is located at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * or in the "license" file accompanying this file. This file is distributed - * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either - * express or implied. See the License for the specific language governing - * permissions and limitations under the License. - */ - -package ion - -import ( - "bytes" - "fmt" - "math" - "math/big" - "strconv" - - "github.com/pkg/errors" - "github.com/shopspring/decimal" -) - -// This file contains the numeric types Decimal, Float, and Int. - -const ( - textNullDecimal = "null.decimal" - textNullFloat = "null.float" - textNullInt = "null.int" -) - -// Decimal is a decimal-encoded real number of arbitrary precision. -// The decimal’s value is coefficient * 10 ^ exponent. -type Decimal struct { - annotations []Symbol - value *decimal.Decimal - binary []byte - text []byte - isSet bool -} - -func (d Decimal) Value() decimal.Decimal { - if d.IsNull() { - return decimal.New(0, 0) - } - - if len(d.text) > 0 { - // decimal uses "e" for exponent while Ion uses "d" - text := bytes.ReplaceAll(d.text, []byte{'d'}, []byte{'e'}) - text = bytes.ReplaceAll(text, []byte{'D'}, []byte{'E'}) - // decimal doesn't handle underscores - text = bytes.ReplaceAll(text, []byte{'_'}, []byte{}) - - val, err := decimal.NewFromString(string(text)) - if err != nil { - panic(err) - } - - d.value = &val - return val - } - - if binLen := len(d.binary); binLen > 0 { - // Bytes are comprised of a VarInt and an Int. - dataReader := bytes.NewReader(d.binary) - exponent, errExp := readVarInt64(dataReader) - if errExp != nil { - panic(errors.WithMessage(errExp, "unable to read exponent part of decimal")) - } - - coefficient := d.binary[binLen-dataReader.Len():] - // SetBytes takes unsigned bytes in big-endian order, so we need to copy the - // sign of our Int and then erase the traces of that sign. - isNegative := (coefficient[0] & 0x80) != 0 - coefficient[0] &= 0x7F - - bigInt := &big.Int{} - bigInt.SetBytes(coefficient) - if isNegative { - bigInt.Neg(bigInt) - } - - val := decimal.NewFromBigInt(bigInt, int32(exponent)) - d.value = &val - return val - } - return decimal.New(0, 0) -} - -// Annotations satisfies Value. -func (d Decimal) Annotations() []Symbol { - return d.annotations -} - -// Binary returns the Decimal in binary form. -func (d Decimal) Binary() []byte { - if d.binary != nil { - return d.binary - } - - if d.IsNull() { - return []byte{} - } - - // TODO: Turn value into the binary representation. - _ = d.Value() - - return d.binary -} - -// Text returns the Decimal in text form. If no exponent is set, -// then the text form will not include one. Otherwise the formatted -// text is in the form: d -func (d Decimal) Text() []byte { - if d.text != nil { - return d.text - } - - if d.IsNull() { - return []byte(textNullDecimal) - } - - val := d.Value() - text := "" - // If the decimal we have is represented by an exact float value, then - // use that to make the string. Otherwise we need to turn the big.Int - // into a string. - if floatVal, exact := val.Float64(); exact { - d.text = strconv.AppendFloat(nil, floatVal, 'f', -1, 64) - } else { - if val.IsNegative() { - text = "-" - } - text += val.Coefficient().String() - } - - if exp := val.Exponent(); exp != 0 { - d.text = append(d.text, []byte(fmt.Sprintf("d%d", exp))...) - } - - d.text = []byte(text) - return d.text -} - -// IsNull satisfies Value. -func (d Decimal) IsNull() bool { - return !d.isSet -} - -// Type satisfies Value. -func (d Decimal) Type() Type { - return TypeDecimal -} - -// Float is a binary-encoded floating point number (IEEE 64-bit). -type Float struct { - annotations []Symbol - value *float64 - binary []byte - text []byte - isSet bool -} - -// Value returns the value of this Float, or 0 if the value is null. -func (f Float) Value() float64 { - if f.value != nil { - return *f.value - } - - if f.IsNull() { - return 0 - } - - // A binary length other than 0 (no value), 4, or 8 is not accepted by - // the parser. - if binLen := len(f.binary); binLen == 4 || binLen == 8 { - if binLen == 4 { - u32 := (uint32(f.binary[3]) << 24) | (uint32(f.binary[2]) << 16) | (uint32(f.binary[1]) << 8) | uint32(f.binary[0]) - f64 := float64(math.Float32frombits(u32)) - f.value = &f64 - } else { - u64 := (uint64(f.binary[7]) << 56) | (uint64(f.binary[6]) << 48) | (uint64(f.binary[5]) << 40) | (uint64(f.binary[4]) << 32) | - (uint64(f.binary[3]) << 24) | (uint64(f.binary[2]) << 16) | (uint64(f.binary[1]) << 8) | uint64(f.binary[0]) - f64 := math.Float64frombits(u64) - f.value = &f64 - } - return *f.value - } - - if len(f.text) > 0 { - text := string(bytes.ReplaceAll(f.text, []byte{'_'}, []byte{})) - f64, err := strconv.ParseFloat(text, 64) - // The float value when the given string is too big is - // +/- infinity. - if err != nil { - numErr, ok := err.(*strconv.NumError) - if !ok || numErr.Err != strconv.ErrRange { - panic(err) - } - } - f.value = &f64 - return f64 - } - - // It's possible for a binary float to be set to a zero-length slice, in which case - // the value is not null but there is no binary or text value to parse. - return 0 -} - -// Annotations satisfies Value. -func (f Float) Annotations() []Symbol { - return f.annotations -} - -// Binary returns the Float in binary form. -func (f Float) Binary() []byte { - if f.binary != nil { - return f.binary - } - - if f.IsNull() { - return []byte{} - } - - // TODO: Turn value into the binary representation. - _ = f.Value() - - return f.binary -} - -// Text returns the Float in text form. -func (f Float) Text() []byte { - if f.text != nil { - return f.text - } - - if f.IsNull() { - return []byte(textNullFloat) - } - - f.text = strconv.AppendFloat(nil, f.Value(), 'f', -1, 64) - return f.text -} - -// IsNull satisfies Value. -func (f Float) IsNull() bool { - return !f.isSet -} - -// Type satisfies Value. -func (f Float) Type() Type { - return TypeFloat -} - -// intBase represents the various bases (binary, decimal, hexadecimal) -// that can be used to represent an integer in text. The zero value -// is intBase10 which is decimal. -type intBase int - -const ( - intBase10 intBase = iota - intBase2 - intBase16 -) - -// Int is a signed integer of arbitrary size. -type Int struct { - annotations []Symbol - value *big.Int - binary []byte - text []byte - base intBase - isNegative bool - isSet bool -} - -// Value returns the representation of this Int as a big.Int. -// If this represents null.Int, then nil is returned. -func (i Int) Value() *big.Int { - if i.IsNull() || i.value != nil { - return i.value - } - - if len(i.text) > 0 { - text := string(bytes.ReplaceAll(i.text, []byte{'_'}, []byte{})) - i.value = new(big.Int) - i.value.SetString(text, 0) - return i.value - } - - if len(i.binary) > 0 { - // TODO - } - return nil -} - -// Annotations satisfies Value. -func (i Int) Annotations() []Symbol { - return i.annotations -} - -// Binary returns the Int in binary form. -func (i Int) Binary() []byte { - if i.binary != nil { - return i.binary - } - - val := i.Value() - if val == nil { - return []byte{} - } - - // TODO: Turn value into the binary representation. - _ = i.Value() - - return i.binary -} - -// Text returns the Int in text form. -func (i Int) Text() []byte { - if i.text != nil { - return i.text - } - - val := i.Value() - if val == nil { - return []byte(textNullInt) - } - - i.text = val.Append(nil, 10) - return i.text -} - -// IsNull satisfies Value. -func (i Int) IsNull() bool { - return !i.isSet -} - -// Type satisfies Value. -func (i Int) Type() Type { - return TypeInt -} diff --git a/ion/types_test.go b/ion/types_test.go deleted file mode 100644 index 50790461..00000000 --- a/ion/types_test.go +++ /dev/null @@ -1,110 +0,0 @@ -/* - * Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"). - * You may not use this file except in compliance with the License. - * A copy of the License is located at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * or in the "license" file accompanying this file. This file is distributed - * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either - * express or implied. See the License for the specific language governing - * permissions and limitations under the License. - */ - -package ion - -import ( - "fmt" - "strconv" - "testing" - - "github.com/google/go-cmp/cmp" -) - -// Ensure that all of our types satisfy the Value interface. -var _ Value = Blob{} -var _ Value = Bool{} -var _ Value = Clob{} -var _ Value = Decimal{} -var _ Value = Float{} -var _ Value = Int{} -var _ Value = List{} -var _ Value = Null{} -var _ Value = padding{} -var _ Value = SExp{} -var _ Value = String{} -var _ Value = Struct{} -var _ Value = Symbol{} -var _ Value = Timestamp{} - -func TestBool(t *testing.T) { - tests := []struct { - b Bool - isNull bool - expectedValue bool - expectedText string - }{ - {isNull: true, expectedText: "null.bool"}, - {b: Bool{value: true}, isNull: true, expectedText: "null.bool"}, - {b: Bool{isSet: true}, expectedText: "false"}, - {b: Bool{isSet: true, value: true}, expectedValue: true, expectedText: "true"}, - } - - for _, tst := range tests { - test := tst - t.Run(fmt.Sprintf("%#v", test.b), func(t *testing.T) { - if isNull := test.b.IsNull(); isNull != test.isNull { - t.Error("expected IsNull", test.isNull, "but found", isNull) - } - if found := test.b.Value(); found != test.expectedValue { - t.Error("expected value", test.expectedValue, "but found", found) - } - if diff := cmp.Diff(test.expectedText, string(test.b.Text())); diff != "" { - t.Error("(-expected, +found)", diff) - } - if typ := test.b.Type(); typ != TypeBool { - t.Error("expected TypeBool", TypeBool, "but found", typ) - } - }) - } -} - -func TestNull(t *testing.T) { - tests := []struct { - typ Type - expectedText string - }{ - {expectedText: "null.null"}, - {typ: TypeBlob, expectedText: "null.blob"}, - {typ: TypeBool, expectedText: "null.bool"}, - {typ: TypeClob, expectedText: "null.clob"}, - {typ: TypeDecimal, expectedText: "null.decimal"}, - {typ: TypeFloat, expectedText: "null.float"}, - {typ: TypeInt, expectedText: "null.int"}, - {typ: TypeList, expectedText: "null.list"}, - {typ: TypeLongString, expectedText: "null.string"}, - {typ: TypeSExp, expectedText: "null.sexp"}, - {typ: TypeString, expectedText: "null.string"}, - {typ: TypeStruct, expectedText: "null.struct"}, - {typ: TypeSymbol, expectedText: "null.symbol"}, - {typ: TypeTimestamp, expectedText: "null.timestamp"}, - } - - for _, tst := range tests { - test := tst - t.Run(strconv.Itoa(int(test.typ)), func(t *testing.T) { - null := &Null{typ: test.typ} - if diff := cmp.Diff(test.expectedText, string(null.Text())); diff != "" { - t.Error("(-expected, +found)", diff) - } - if diff := cmp.Diff(test.typ, null.Type()); diff != "" { - t.Error("(-expected, +found)", diff) - } - if !null.IsNull() { - t.Error("expected IsNull to be true") - } - }) - } -} diff --git a/ion/types_timestamp.go b/ion/types_timestamp.go deleted file mode 100644 index b7078ba4..00000000 --- a/ion/types_timestamp.go +++ /dev/null @@ -1,286 +0,0 @@ -/* - * Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"). - * You may not use this file except in compliance with the License. - * A copy of the License is located at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * or in the "license" file accompanying this file. This file is distributed - * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either - * express or implied. See the License for the specific language governing - * permissions and limitations under the License. - */ - -package ion - -import ( - "bytes" - "fmt" - "time" - - "github.com/pkg/errors" -) - -// This file contains the Timestamp type. - -type TimestampPrecision int - -const ( - TimestampPrecisionYear TimestampPrecision = iota + 1 - TimestampPrecisionMonth - TimestampPrecisionDay - TimestampPrecisionMinute - TimestampPrecisionSecond - TimestampPrecisionMillisecond1 - TimestampPrecisionMillisecond2 - TimestampPrecisionMillisecond3 - TimestampPrecisionMillisecond4 - TimestampPrecisionMicrosecond1 - TimestampPrecisionMicrosecond2 - TimestampPrecisionMicrosecond3 - TimestampPrecisionMicrosecond4 -) - -var timestampPrecisionNameMap = map[TimestampPrecision]string{ - TimestampPrecisionYear: "Year", TimestampPrecisionMonth: "Month", TimestampPrecisionDay: "Day", - TimestampPrecisionMinute: "Minute", TimestampPrecisionSecond: "Second", - TimestampPrecisionMillisecond1: "Millisecond Tenths", TimestampPrecisionMillisecond2: "Millisecond Hundredths", - TimestampPrecisionMillisecond3: "Millisecond Thousandths", TimestampPrecisionMillisecond4: "Millisecond TenThousandths", - TimestampPrecisionMicrosecond1: "Microsecond Tenths", TimestampPrecisionMicrosecond2: "Microsecond Hundredths", - TimestampPrecisionMicrosecond3: "Microsecond Thousandths", TimestampPrecisionMicrosecond4: "Microsecond TenThousandths", -} - -// String satisfies Stringer. -func (t TimestampPrecision) String() string { - if s, ok := timestampPrecisionNameMap[t]; ok { - return s - } - return "Unknown" -} - -const ( - offsetLocalUnknown = -1 -) - -/* - From https://www.w3.org/TR/NOTE-datetime - - Year: - YYYY (eg 1997) - Year and month: - YYYY-MM (eg 1997-07) - Complete date: - YYYY-MM-DD (eg 1997-07-16) - Complete date plus hours and minutes: - YYYY-MM-DDThh:mmTZD (eg 1997-07-16T19:20+01:00) - Complete date plus hours, minutes and seconds: - YYYY-MM-DDThh:mm:ssTZD (eg 1997-07-16T19:20:30+01:00) - Complete date plus hours, minutes, seconds and a decimal fraction of a second - YYYY-MM-DDThh:mm:ss.sTZD (eg 1997-07-16T19:20:30.45+01:00) - - where: - - YYYY = four-digit year - MM = two-digit month (01=January, etc.) - DD = two-digit day of month (01 through 31) - hh = two digits of hour (00 through 23) (am/pm NOT allowed) - mm = two digits of minute (00 through 59) - ss = two digits of second (00 through 59) - s = one or more digits representing a decimal fraction of a second - TZD = time zone designator (Z or +hh:mm or -hh:mm) -*/ -var ( - // precisionFormatMap maps the above valid types to the Go magic time - // format string "Mon Jan 2 15:04:05 -0700 MST 2006" - precisionFormatMap = map[TimestampPrecision]string{ - TimestampPrecisionYear: "2006T", - TimestampPrecisionMonth: "2006-01T", - TimestampPrecisionDay: "2006-01-02T", - TimestampPrecisionMinute: "2006-01-02T15:04", - TimestampPrecisionSecond: "2006-01-02T15:04:05", - TimestampPrecisionMillisecond1: "2006-01-02T15:04:05.0", - TimestampPrecisionMillisecond2: "2006-01-02T15:04:05.00", - TimestampPrecisionMillisecond3: "2006-01-02T15:04:05.000", - TimestampPrecisionMillisecond4: "2006-01-02T15:04:05.0000", - TimestampPrecisionMicrosecond1: "2006-01-02T15:04:05.00000", - TimestampPrecisionMicrosecond2: "2006-01-02T15:04:05.000000", - TimestampPrecisionMicrosecond3: "2006-01-02T15:04:05.0000000", - TimestampPrecisionMicrosecond4: "2006-01-02T15:04:05.00000000", - } -) - -// Timestamp represents date/time/timezone moments of arbitrary precision. -// Two timestamps are only equivalent if they represent the same instant -// with the same offset and precision. -type Timestamp struct { - annotations []Symbol - binary []byte - text []byte - // offset in minutes. Use offsetLocalUnknown to denote when an explicit offset - // is not known. - offset time.Duration - precision TimestampPrecision - value time.Time -} - -// Precision returns to what precision the timestamp was set to. -func (t Timestamp) Precision() TimestampPrecision { - if t.precision == 0 && len(t.text) > 0 { - t.precision = determinePrecision(t.text) - } - return t.precision -} - -// Value returns the value of the timestamp. -func (t Timestamp) Value() time.Time { - if t.IsNull() || !t.value.IsZero() { - return t.value - } - - if len(t.text) > 0 { - if t.precision == 0 { - t.precision = determinePrecision(t.text) - } - - text := bytes.TrimSuffix(t.text, []byte("Z")) - // TODO: Handle the unknown local offset case properly. - // -00:00 is a special offset which means that the offset is local. - text = bytes.TrimSuffix(text, []byte("-00:00")) - - format := precisionFormatMap[t.precision] - if bytes.Count(text, []byte("-")) == 3 || bytes.LastIndex(text, []byte("+")) > 0 { - format += "-07:00" - } - - // The "T" is optional when the precision is month or day, so add it - // to the text if it's missing so that the Time parser doesn't fail. - if (t.precision == TimestampPrecisionMonth || t.precision == TimestampPrecisionDay) && !bytes.HasSuffix(text, []byte("T")) { - text = append(text, 'T') - } - - timestamp, err := time.Parse(format, string(text)) - if err != nil { - panic(errors.Wrap(err, "unable to parse timestamp")) - } - t.value = timestamp - } - - return t.value -} - -func determinePrecision(text []byte) TimestampPrecision { - // There is no real variability in format until we get to dealing - // with time, so we can handle all of the date-only cases using length. - switch size := len(text); { - case size <= 5: - return TimestampPrecisionYear - case size <= 8: - return TimestampPrecisionMonth - case size <= 11: - return TimestampPrecisionDay - } - - // Trim off the date portion. We only care about time now. - tim := text[bytes.Index(text, []byte("T"))+1:] - - // Trim off any timezone portion. - tim = bytes.TrimSuffix(tim, []byte("Z")) - if plusIndex := bytes.Index(tim, []byte("+")); plusIndex > 0 { - tim = tim[:plusIndex] - } - if minusIndex := bytes.Index(tim, []byte("-")); minusIndex > 0 { - tim = tim[:minusIndex] - } - - // Now we can just count characters. - switch len(tim) { - case 5: - return TimestampPrecisionMinute - case 8: - return TimestampPrecisionSecond - case 10: - return TimestampPrecisionMillisecond1 - case 11: - return TimestampPrecisionMillisecond2 - case 12: - return TimestampPrecisionMillisecond3 - case 13: - return TimestampPrecisionMillisecond4 - case 14: - return TimestampPrecisionMicrosecond1 - case 15: - return TimestampPrecisionMicrosecond2 - case 16: - return TimestampPrecisionMicrosecond3 - case 17: - return TimestampPrecisionMicrosecond4 - } - - return TimestampPrecisionDay -} - -// Annotations satisfies Value. -func (t Timestamp) Annotations() []Symbol { - return t.annotations -} - -// Binary satisfies Value. -func (t Timestamp) Binary() []byte { - if len(t.binary) > 0 { - return t.binary - } - - // TODO - return t.binary -} - -// Text satisfies Value. -func (t Timestamp) Text() []byte { - if t.IsNull() { - return []byte(textNullTimestamp) - } - if len(t.text) > 0 { - return t.text - } - - val := t.Value() - t.text = []byte(val.Format(precisionFormatMap[t.precision])) - - // If the precision doesn't include time, then we don't - // need to generate an offset. - if t.precision <= TimestampPrecisionDay { - return t.text - } - - var offset []byte - switch t.offset { - case offsetLocalUnknown: - // Do nothing. - case 0: - offset = append(offset, 'Z') - default: - if t.offset > 0 { - offset = append(offset, '+') - } else { - offset = append(offset, '-') - } - hh := int(t.offset.Hours()) - mm := int(t.offset.Minutes()) - offset = append(offset, []byte(fmt.Sprintf("%02d:%02d", hh, mm))...) - } - t.text = append(t.text, offset...) - - return t.text -} - -// IsNull satisfies Value. -func (t Timestamp) IsNull() bool { - return t.binary == nil && t.text == nil -} - -// Type satisfies Value. -func (t Timestamp) Type() Type { - return TypeTimestamp -} From 9bf5542bdd41a3aa890b09800995c4c9500b4b0d Mon Sep 17 00:00:00 2001 From: Therapon Skoteiniotis Date: Thu, 7 May 2020 19:55:07 -0700 Subject: [PATCH 42/56] Move into ion folder --- ion/binaryreader.go | 487 +++++++++++++++ ion/binaryreader_test.go | 454 ++++++++++++++ ion/binarywriter.go | 562 +++++++++++++++++ ion/binarywriter_test.go | 387 ++++++++++++ ion/bits.go | 285 +++++++++ ion/bits_test.go | 205 ++++++ ion/bitstream.go | 935 ++++++++++++++++++++++++++++ ion/bitstream_test.go | 111 ++++ ion/buf.go | 119 ++++ ion/buf_test.go | 79 +++ ion/catalog.go | 88 +++ ion/catalog_test.go | 62 ++ ion/consts.go | 52 ++ ion/ctx.go | 65 ++ ion/decimal.go | 342 +++++++++++ ion/decimal_test.go | 312 ++++++++++ ion/err.go | 88 +++ ion/fields.go | 122 ++++ ion/marshal.go | 338 ++++++++++ ion/marshal_test.go | 162 +++++ ion/reader.go | 388 ++++++++++++ ion/reader_test.go | 127 ++++ ion/skipper.go | 863 ++++++++++++++++++++++++++ ion/skipper_test.go | 178 ++++++ ion/symboltable.go | 475 ++++++++++++++ ion/symboltable_test.go | 181 ++++++ ion/textreader.go | 658 ++++++++++++++++++++ ion/textreader_test.go | 788 ++++++++++++++++++++++++ ion/textutils.go | 382 ++++++++++++ ion/textutils_test.go | 164 +++++ ion/textwriter.go | 359 +++++++++++ ion/textwriter_test.go | 351 +++++++++++ ion/tokenizer.go | 1265 ++++++++++++++++++++++++++++++++++++++ ion/tokenizer_test.go | 571 +++++++++++++++++ ion/type.go | 125 ++++ ion/type_test.go | 21 + ion/unmarshal.go | 673 ++++++++++++++++++++ ion/unmarshal_test.go | 539 ++++++++++++++++ ion/writer.go | 165 +++++ 39 files changed, 13528 insertions(+) create mode 100644 ion/binaryreader.go create mode 100644 ion/binaryreader_test.go create mode 100644 ion/binarywriter.go create mode 100644 ion/binarywriter_test.go create mode 100644 ion/bits.go create mode 100644 ion/bits_test.go create mode 100644 ion/bitstream.go create mode 100644 ion/bitstream_test.go create mode 100644 ion/buf.go create mode 100644 ion/buf_test.go create mode 100644 ion/catalog.go create mode 100644 ion/catalog_test.go create mode 100644 ion/consts.go create mode 100644 ion/ctx.go create mode 100644 ion/decimal.go create mode 100644 ion/decimal_test.go create mode 100644 ion/err.go create mode 100644 ion/fields.go create mode 100644 ion/marshal.go create mode 100644 ion/marshal_test.go create mode 100644 ion/reader.go create mode 100644 ion/reader_test.go create mode 100644 ion/skipper.go create mode 100644 ion/skipper_test.go create mode 100644 ion/symboltable.go create mode 100644 ion/symboltable_test.go create mode 100644 ion/textreader.go create mode 100644 ion/textreader_test.go create mode 100644 ion/textutils.go create mode 100644 ion/textutils_test.go create mode 100644 ion/textwriter.go create mode 100644 ion/textwriter_test.go create mode 100644 ion/tokenizer.go create mode 100644 ion/tokenizer_test.go create mode 100644 ion/type.go create mode 100644 ion/type_test.go create mode 100644 ion/unmarshal.go create mode 100644 ion/unmarshal_test.go create mode 100644 ion/writer.go diff --git a/ion/binaryreader.go b/ion/binaryreader.go new file mode 100644 index 00000000..9a64a7c2 --- /dev/null +++ b/ion/binaryreader.go @@ -0,0 +1,487 @@ +package ion + +import ( + "bufio" + "fmt" +) + +// A binaryReader reads binary Ion. +type binaryReader struct { + reader + + bits bitstream + cat Catalog + lst SymbolTable +} + +func newBinaryReaderBuf(in *bufio.Reader, cat Catalog) Reader { + r := &binaryReader{ + cat: cat, + } + r.bits.Init(in) + return r +} + +// SymbolTable returns the current symbol table. +func (r *binaryReader) SymbolTable() SymbolTable { + return r.lst +} + +// Next moves the reader to the next value. +func (r *binaryReader) Next() bool { + if r.eof || r.err != nil { + return false + } + + r.clear() + + done := false + for !done { + done, r.err = r.next() + if r.err != nil { + return false + } + } + + return !r.eof +} + +// Next consumes the next raw value from the stream, returning true if it +// represents a user-facing value and false if it does not. +func (r *binaryReader) next() (bool, error) { + if err := r.bits.Next(); err != nil { + return false, err + } + + code := r.bits.Code() + switch code { + case bitcodeEOF: + r.eof = true + return true, nil + + case bitcodeBVM: + err := r.readBVM() + return false, err + + case bitcodeFieldID: + err := r.readFieldName() + return false, err + + case bitcodeAnnotation: + err := r.readAnnotations() + return false, err + + case bitcodeNull: + if !r.bits.IsNull() { + // NOP padding; skip it and keep going. + err := r.bits.SkipValue() + return false, err + } + r.valueType = NullType + return true, nil + + case bitcodeFalse, bitcodeTrue: + r.valueType = BoolType + if !r.bits.IsNull() { + r.value = (r.bits.Code() == bitcodeTrue) + } + return true, nil + + case bitcodeInt, bitcodeNegInt: + r.valueType = IntType + if !r.bits.IsNull() { + val, err := r.bits.ReadInt() + if err != nil { + return false, err + } + r.value = val + } + return true, nil + + case bitcodeFloat: + r.valueType = FloatType + if !r.bits.IsNull() { + val, err := r.bits.ReadFloat() + if err != nil { + return false, err + } + r.value = val + } + return true, nil + + case bitcodeDecimal: + r.valueType = DecimalType + if !r.bits.IsNull() { + val, err := r.bits.ReadDecimal() + if err != nil { + return false, err + } + r.value = val + } + return true, nil + + case bitcodeTimestamp: + r.valueType = TimestampType + if !r.bits.IsNull() { + val, err := r.bits.ReadTimestamp() + if err != nil { + return false, err + } + r.value = val + } + return true, nil + + case bitcodeSymbol: + r.valueType = SymbolType + if !r.bits.IsNull() { + id, err := r.bits.ReadSymbolID() + if err != nil { + return false, err + } + r.value = r.resolve(id) + } + return true, nil + + case bitcodeString: + r.valueType = StringType + if !r.bits.IsNull() { + val, err := r.bits.ReadString() + if err != nil { + return false, err + } + r.value = val + } + return true, nil + + case bitcodeClob: + r.valueType = ClobType + if !r.bits.IsNull() { + val, err := r.bits.ReadBytes() + if err != nil { + return false, err + } + r.value = val + } + return true, nil + + case bitcodeBlob: + r.valueType = BlobType + if !r.bits.IsNull() { + val, err := r.bits.ReadBytes() + if err != nil { + return false, err + } + r.value = val + } + return true, nil + + case bitcodeList: + r.valueType = ListType + if !r.bits.IsNull() { + r.value = ListType + } + return true, nil + + case bitcodeSexp: + r.valueType = SexpType + if !r.bits.IsNull() { + r.value = SexpType + } + return true, nil + + case bitcodeStruct: + r.valueType = StructType + if !r.bits.IsNull() { + r.value = StructType + } + + // If it's a local symbol table, install it and keep going. + if r.ctx.peek() == ctxAtTopLevel && isIonSymbolTable(r.annotations) { + err := r.readLocalSymbolTable() + return false, err + } + + return true, nil + } + panic(fmt.Sprintf("invalid bitcode %v", code)) +} + +func isIonSymbolTable(as []string) bool { + return len(as) > 0 && as[0] == "$ion_symbol_table" +} + +// ReadBVM reads a BVM, validates it, and resets the local symbol table. +func (r *binaryReader) readBVM() error { + major, minor, err := r.bits.ReadBVM() + if err != nil { + return err + } + + switch major { + case 1: + switch minor { + case 0: + r.lst = V1SystemSymbolTable + return nil + } + } + + return &UnsupportedVersionError{ + int(major), + int(minor), + r.bits.Pos() - 4, + } +} + +// ReadLocalSymbolTable reads and installs a new local symbol table. +func (r *binaryReader) readLocalSymbolTable() error { + if r.IsNull() { + r.clear() + r.lst = V1SystemSymbolTable + return nil + } + + if err := r.StepIn(); err != nil { + return err + } + + imps := []SharedSymbolTable{} + syms := []string{} + + for r.Next() { + var err error + switch r.FieldName() { + case "imports": + imps, err = r.readImports() + case "symbols": + syms, err = r.readSymbols() + } + if err != nil { + return err + } + } + + if err := r.StepOut(); err != nil { + return err + } + + r.lst = NewLocalSymbolTable(imps, syms) + return nil +} + +// ReadImports reads the imports field of a local symbol table. +func (r *binaryReader) readImports() ([]SharedSymbolTable, error) { + if r.valueType == SymbolType && r.value == "$ion_symbol_table" { + // Special case that imports the current local symbol table. + if r.lst == nil || r.lst == V1SystemSymbolTable { + return nil, nil + } + + imps := r.lst.Imports() + lsst := NewSharedSymbolTable("", 0, r.lst.Symbols()) + return append(imps, lsst), nil + } + + if r.Type() != ListType || r.IsNull() { + return nil, nil + } + if err := r.StepIn(); err != nil { + return nil, err + } + + imps := []SharedSymbolTable{} + for r.Next() { + imp, err := r.readImport() + if err != nil { + return nil, err + } + if imp != nil { + imps = append(imps, imp) + } + } + + err := r.StepOut() + return imps, err +} + +// ReadImport reads an import definition. +func (r *binaryReader) readImport() (SharedSymbolTable, error) { + if r.Type() != StructType || r.IsNull() { + return nil, nil + } + if err := r.StepIn(); err != nil { + return nil, err + } + + name := "" + version := 0 + maxID := uint64(0) + + for r.Next() { + var err error + switch r.FieldName() { + case "name": + if r.Type() == StringType { + name, err = r.StringValue() + } + case "version": + if r.Type() == IntType { + version, err = r.IntValue() + } + case "max_id": + if r.Type() == IntType { + var i int64 + i, err = r.Int64Value() + if i < 0 { + i = 0 + } + maxID = uint64(i) + } + } + if err != nil { + return nil, err + } + } + + if err := r.StepOut(); err != nil { + return nil, err + } + + if name == "" || name == "$ion" { + return nil, nil + } + if version < 1 { + version = 1 + } + + var imp SharedSymbolTable + if r.cat != nil { + imp = r.cat.FindExact(name, version) + if imp == nil { + imp = r.cat.FindLatest(name) + } + } + + if maxID == 0 { + if imp == nil || version != imp.Version() { + return nil, fmt.Errorf("ion: import of shared table %v/%v lacks a valid max_id, but an exact "+ + "match was not found in the catalog", name, version) + } + maxID = imp.MaxID() + } + + if imp == nil { + imp = &bogusSST{ + name: name, + version: version, + maxID: maxID, + } + } else { + imp = imp.Adjust(maxID) + } + + return imp, nil +} + +// ReadSymbols reads the symbols from a symbol table. +func (r *binaryReader) readSymbols() ([]string, error) { + if r.Type() != ListType { + return nil, nil + } + if err := r.StepIn(); err != nil { + return nil, err + } + + syms := []string{} + for r.Next() { + if r.Type() == StringType { + sym, err := r.StringValue() + if err != nil { + return nil, err + } + syms = append(syms, sym) + } else { + syms = append(syms, "") + } + } + + err := r.StepOut() + return syms, err +} + +// ReadFieldName reads and resolves a field name. +func (r *binaryReader) readFieldName() error { + id, err := r.bits.ReadFieldID() + if err != nil { + return err + } + + r.fieldName = r.resolve(id) + return nil +} + +// ReadAnnotations reads and resolves a set of annotations. +func (r *binaryReader) readAnnotations() error { + ids, err := r.bits.ReadAnnotationIDs() + if err != nil { + return err + } + + as := make([]string, len(ids)) + for i, id := range ids { + as[i] = r.resolve(id) + } + + r.annotations = as + return nil +} + +// Resolve resolves a symbol ID to a symbol value (possibly ${id} if we're +// missing the appropriate symbol table). +func (r *binaryReader) resolve(id uint64) string { + s, ok := r.lst.FindByID(id) + if !ok { + return fmt.Sprintf("$%v", id) + } + return s +} + +// StepIn steps in to a container-type value +func (r *binaryReader) StepIn() error { + if r.err != nil { + return r.err + } + + if r.valueType != ListType && r.valueType != SexpType && r.valueType != StructType { + return &UsageError{"Reader.StepIn", fmt.Sprintf("cannot step in to a %v", r.valueType)} + } + if r.value == nil { + return &UsageError{"Reader.StepIn", "cannot step in to a null container"} + } + + r.ctx.push(containerTypeToCtx(r.valueType)) + r.clear() + r.bits.StepIn() + + return nil +} + +// StepOut steps out of a container-type value. +func (r *binaryReader) StepOut() error { + if r.err != nil { + return r.err + } + if r.ctx.peek() == ctxAtTopLevel { + return &UsageError{"Reader.StepOut", "cannot step out of top-level datagram"} + } + + if err := r.bits.StepOut(); err != nil { + return err + } + + r.clear() + r.ctx.pop() + r.eof = false + + return nil +} diff --git a/ion/binaryreader_test.go b/ion/binaryreader_test.go new file mode 100644 index 00000000..4f0901b5 --- /dev/null +++ b/ion/binaryreader_test.go @@ -0,0 +1,454 @@ +package ion + +import ( + "fmt" + "math" + "math/big" + "testing" + "time" +) + +func TestReadBadBVMs(t *testing.T) { + t.Run("E00200E9", func(t *testing.T) { + // Need a good first one or we'll get sent to the text reader. + r := NewReaderBytes([]byte{0xE0, 0x01, 0x00, 0xEA, 0xE0, 0x02, 0x00, 0xE9}) + if r.Next() { + t.Errorf("next returned true") + } + if r.Err() == nil { + t.Fatal("err is nil") + } + }) + + t.Run("E00200EA", func(t *testing.T) { + r := NewReaderBytes([]byte{0xE0, 0x02, 0x00, 0xEA}) + if r.Next() { + t.Errorf("next returned true") + } + if r.Err() == nil { + t.Fatal("err is nil") + } + + uve, ok := r.Err().(*UnsupportedVersionError) + if !ok { + t.Fatal("err is not an UnsupportedVersionError") + } + if uve.Major != 2 { + t.Errorf("expected major=2, got %v", uve.Major) + } + if uve.Minor != 0 { + t.Errorf("expected minor=0, got %v", uve.Minor) + } + }) +} + +func TestReadNullLST(t *testing.T) { + ion := []byte{ + 0xE0, 0x01, 0x00, 0xEA, + 0xE4, 0x82, 0x83, 0x87, 0xDF, + 0x71, 0x09, + } + r := NewReaderBytes(ion) + _symbol(t, r, "$ion_shared_symbol_table") + _eof(t, r) +} + +func TestReadEmptyLST(t *testing.T) { + ion := []byte{ + 0xE0, 0x01, 0x00, 0xEA, + 0xE4, 0x82, 0x83, 0x87, 0xD0, + 0x71, 0x09, + } + r := NewReaderBytes(ion) + _symbol(t, r, "$ion_shared_symbol_table") + _eof(t, r) +} + +func TestReadBadLST(t *testing.T) { + ion := []byte{ + 0xE0, 0x01, 0x00, 0xEA, + 0xE3, 0x81, 0x83, 0xD9, + 0x86, 0xB7, 0xD6, // imports:[{ + 0x84, 0x81, 'a', // name: "a", + 0x85, 0x21, 0x01, // version: 1}]} + 0x0F, // null + } + r := NewReaderBytes(ion) + if r.Next() { + t.Fatal("next returned true") + } + if r.Err() == nil { + t.Fatal("err is nil") + } +} + +func TestReadMultipleLSTs(t *testing.T) { + r := readBinary([]byte{ + 0x71, 0x0B, // $11 + 0x71, 0x6F, // bar + 0xE3, 0x81, 0x83, 0xDF, // $ion_symbol_table::null.struct + 0xEE, 0x8F, 0x81, 0x83, 0xDD, // $ion_symbol_table::{ + 0x86, 0x71, 0x03, // imports: $ion_symbol_table, + 0x87, 0xB8, // symbols:[ + 0x83, 'f', 'o', 'o', // "foo" + 0x83, 'b', 'a', 'r', // "bar" ]} + 0x71, 0x0B, // bar + 0x71, 0x0C, // $12 + 0x71, 0x6F, // $111 + 0xEC, 0x81, 0x83, 0xD9, // $ion_symbol_table::{ + 0x86, 0x71, 0x03, // imports: $ion_symbol_table + 0x87, 0xB4, // symbols:[ + 0x83, 'b', 'a', 'z', // "baz" ]} + 0x71, 0x0B, // bar + 0x71, 0x0C, // baz + }) + _symbol(t, r, "$11") + _symbol(t, r, "bar") + + _symbol(t, r, "bar") + _symbol(t, r, "$12") + _symbol(t, r, "$111") + + _symbol(t, r, "bar") + _symbol(t, r, "baz") + _eof(t, r) +} + +func TestReadBinaryLST(t *testing.T) { + r := readBinary([]byte{0x0F}) + _next(t, r, NullType) + + lst := r.SymbolTable() + if lst == nil { + t.Fatal("symboltable is nil") + } + + if lst.MaxID() != 111 { + t.Errorf("expected maxid=111, got %v", lst.MaxID()) + } + + if _, ok := lst.FindByID(109); ok { + t.Error("found a symbol for $109") + } + + sym, ok := lst.FindByID(111) + if !ok { + t.Fatal("no symbol defined for $111") + } + if sym != "bar" { + t.Errorf("expected $111=bar, got %v", sym) + } + + id, ok := lst.FindByName("foo") + if !ok { + t.Fatal("no id defined for foo") + } + if id != 110 { + t.Errorf("expected foo=$110, got $%v", id) + } + + if _, ok := lst.FindByID(112); ok { + t.Error("found a symbol for $112") + } + + if _, ok := lst.FindByName("bogus"); ok { + t.Error("found a symbol for bogus") + } +} + +func TestReadBinaryStructs(t *testing.T) { + r := readBinary([]byte{ + 0xDF, // null.struct + 0xD0, // {} + 0xEA, 0x81, 0xEE, 0xD7, // foo::{ + 0x84, 0xE3, 0x81, 0xEF, 0xD0, // name:bar::{}, + 0x88, 0x20, // max_id:0 + // } + }) + + _null(t, r, StructType) + _struct(t, r, func(t *testing.T, r Reader) { + _eof(t, r) + }) + _structAF(t, r, "", []string{"foo"}, func(t *testing.T, r Reader) { + _structAF(t, r, "name", []string{"bar"}, func(t *testing.T, r Reader) { + _eof(t, r) + }) + _intAF(t, r, "max_id", nil, 0) + }) + _eof(t, r) +} + +func TestReadBinarySexps(t *testing.T) { + r := readBinary([]byte{ + 0xCF, + 0xC3, 0xC1, 0xC0, 0xC0, + }) + + _null(t, r, SexpType) + _sexp(t, r, func(t *testing.T, r Reader) { + _sexp(t, r, func(t *testing.T, r Reader) { + _sexp(t, r, func(t *testing.T, r Reader) { + _eof(t, r) + }) + }) + _sexp(t, r, func(t *testing.T, r Reader) { + _eof(t, r) + }) + _eof(t, r) + }) + _eof(t, r) +} + +func TestReadBinaryLists(t *testing.T) { + r := readBinary([]byte{ + 0xBF, + 0xB3, 0xB1, 0xB0, 0xB0, + }) + + _null(t, r, ListType) + _list(t, r, func(t *testing.T, r Reader) { + _list(t, r, func(t *testing.T, r Reader) { + _list(t, r, func(t *testing.T, r Reader) { + _eof(t, r) + }) + }) + _list(t, r, func(t *testing.T, r Reader) { + _eof(t, r) + }) + _eof(t, r) + }) + _eof(t, r) +} + +func TestReadBinaryBlobs(t *testing.T) { + r := readBinary([]byte{ + 0xAF, + 0xA0, + 0xA1, 'a', + 0xAE, 0x96, + 'h', 'e', 'l', 'l', 'o', ' ', 'w', 'o', 'r', 'l', 'd', ' ', 'b', 'u', 't', + ' ', 'l', 'o', 'n', 'g', 'e', 'r', + }) + + _null(t, r, BlobType) + _blob(t, r, []byte("")) + _blob(t, r, []byte("a")) + _blob(t, r, []byte("hello world but longer")) + _eof(t, r) +} + +func TestReadBinaryClobs(t *testing.T) { + r := readBinary([]byte{ + 0x9F, + 0x90, // {{}} + 0x91, 'a', // {{a}} + 0x9E, 0x96, + 'h', 'e', 'l', 'l', 'o', ' ', 'w', 'o', 'r', 'l', 'd', ' ', 'b', 'u', 't', + ' ', 'l', 'o', 'n', 'g', 'e', 'r', + }) + + _null(t, r, ClobType) + _clob(t, r, []byte("")) + _clob(t, r, []byte("a")) + _clob(t, r, []byte("hello world but longer")) + _eof(t, r) +} + +func TestReadBinaryStrings(t *testing.T) { + r := readBinary([]byte{ + 0x8F, + 0x80, // "" + 0x81, 'a', // "a" + 0x8E, 0x96, + 'h', 'e', 'l', 'l', 'o', ' ', 'w', 'o', 'r', 'l', 'd', ' ', 'b', 'u', 't', + ' ', 'l', 'o', 'n', 'g', 'e', 'r', + }) + + _null(t, r, StringType) + _string(t, r, "") + _string(t, r, "a") + _string(t, r, "hello world but longer") + _eof(t, r) +} + +func TestReadBinarySymbols(t *testing.T) { + r := readBinary([]byte{ + 0x7F, + 0x70, // $0 + 0x71, 0x01, // $ion + 0x71, 0x0A, // $10 + 0x71, 0x6E, // foo + 0xE4, 0x81, 0xEE, 0x71, 0x6F, // foo::bar + 0x78, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, // ${maxint64} + }) + + _null(t, r, SymbolType) + _symbol(t, r, "$0") + _symbol(t, r, "$ion") + _symbol(t, r, "$10") + _symbol(t, r, "foo") + _symbolAF(t, r, "", []string{"foo"}, "bar") + _symbol(t, r, fmt.Sprintf("$%v", uint64(math.MaxUint64))) + _eof(t, r) +} + +func TestReadBinaryTimestamps(t *testing.T) { + r := readBinary([]byte{ + 0x6F, + 0x62, 0x80, 0x81, // 0001T + 0x63, 0x80, 0x81, 0x81, // 0001-01T + 0x64, 0x80, 0x81, 0x81, 0x81, // 0001-01-01T + 0x66, 0x80, 0x81, 0x81, 0x81, 0x80, 0x80, // 0001-01-01T00:00Z + 0x67, 0x80, 0x81, 0x81, 0x81, 0x80, 0x80, 0x80, // 0001-01-01T00:00:00Z + 0x6E, 0x8E, // 0x0E-bit timestamp + 0x04, 0xD8, // offset: +600 minutes (+10:00) + 0x0F, 0xE3, // year: 2019 + 0x88, // month: 8 + 0x84, // day: 4 + 0x88, // hour: 8 utc (18 local) + 0x8F, // minute: 15 + 0xAB, // second: 43 + 0xC9, // exp: -9 + 0x33, 0x77, 0xDF, 0x70, // nsec: 863494000 + }) + + _null(t, r, TimestampType) + + for i := 0; i < 5; i++ { + _timestamp(t, r, time.Time{}) + } + + nowish, _ := time.Parse(time.RFC3339Nano, "2019-08-04T18:15:43.863494+10:00") + _timestamp(t, r, nowish) + _eof(t, r) +} + +func TestReadBinaryDecimals(t *testing.T) { + r := readBinary([]byte{ + 0x50, // 0. + 0x5F, // null.decimal + 0x51, 0xC3, // 0.000, aka 0 x 10^-3 + 0x53, 0xC3, 0x03, 0xE8, // 1.000, aka 1000 x 10^-3 + 0x53, 0xC3, 0x83, 0xE8, // -1.000, aka -1000 x 10^-3 + 0x53, 0x00, 0xE4, 0x01, // 1d100, aka 1 * 10^100 + 0x53, 0x00, 0xE4, 0x81, // -1d100, aka -1 * 10^100 + }) + + _decimal(t, r, MustParseDecimal("0.")) + _null(t, r, DecimalType) + _decimal(t, r, MustParseDecimal("0.000")) + _decimal(t, r, MustParseDecimal("1.000")) + _decimal(t, r, MustParseDecimal("-1.000")) + _decimal(t, r, MustParseDecimal("1d100")) + _decimal(t, r, MustParseDecimal("-1d100")) + _eof(t, r) +} + +func TestReadBinaryFloats(t *testing.T) { + r := readBinary([]byte{ + 0x40, // 0 + 0x4F, // null.float + 0x44, 0x7F, 0x7F, 0xFF, 0xFF, // MaxFloat32 + 0x44, 0xFF, 0x7F, 0xFF, 0xFF, // -MaxFloat32 + 0x48, 0x7F, 0xEF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, // MaxFloat64 + 0x48, 0xFF, 0xEF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, // -MaxFloat64 + 0x48, 0x7F, 0xF0, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // +inf + 0x48, 0xFF, 0xF0, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // -inf + 0x48, 0x7F, 0xF8, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, // NaN + }) + + _float(t, r, 0) + _null(t, r, FloatType) + _float(t, r, math.MaxFloat32) + _float(t, r, -math.MaxFloat32) + _float(t, r, math.MaxFloat64) + _float(t, r, -math.MaxFloat64) + _float(t, r, math.Inf(1)) + _float(t, r, math.Inf(-1)) + _float(t, r, math.NaN()) + _eof(t, r) +} + +func TestReadBinaryInts(t *testing.T) { + r := readBinary([]byte{ + 0x20, // 0 + 0x2F, // null.int + 0x21, 0x01, // 1 + 0x31, 0x01, // -1 + 0x28, 0x7F, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, // 0x7FFFFFFFFFFFFFFF + 0x38, 0x7F, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, // -0x7FFFFFFFFFFFFFFF + 0x28, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // 0x8000000000000000 + 0x38, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // 0x8000000000000000 + }) + + _int(t, r, 0) + _null(t, r, IntType) + _int(t, r, 1) + _int(t, r, -1) + _int64(t, r, math.MaxInt64) + _int64(t, r, -math.MaxInt64) + + _uint(t, r, math.MaxInt64+1) + + i := new(big.Int).SetUint64(math.MaxInt64 + 1) + _bigInt(t, r, new(big.Int).Neg(i)) + + _eof(t, r) +} + +func TestReadBinaryBools(t *testing.T) { + r := readBinary([]byte{ + 0x10, // false + 0x11, // true + 0x1F, // null.bool + }) + + _bool(t, r, false) + _bool(t, r, true) + _null(t, r, BoolType) + _eof(t, r) +} + +func TestReadBinaryNulls(t *testing.T) { + r := readBinary([]byte{ + 0x00, // 1-byte NOP + 0x0F, // null + 0x01, 0xFF, // 2-byte NOP + 0xE3, 0x81, 0x81, 0x0F, // $ion::null + 0x0E, 0x8F, // 16-byte NOP + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0xE4, 0x82, 0xEE, 0xEF, 0x0F, // foo::bar::null + }) + + _null(t, r, NullType) + _nullAF(t, r, NullType, "", []string{"$ion"}) + _nullAF(t, r, NullType, "", []string{"foo", "bar"}) + _eof(t, r) +} + +func TestReadEmptyBinary(t *testing.T) { + r := NewReaderBytes([]byte{0xE0, 0x01, 0x00, 0xEA}) + _eof(t, r) + _eof(t, r) +} + +func readBinary(ion []byte) Reader { + prefix := []byte{ + 0xE0, 0x01, 0x00, 0xEA, // $ion_1_0 + 0xEE, 0x9F, 0x81, 0x83, 0xDE, 0x9B, // $ion_symbol_table::{ + 0x86, 0xBE, 0x8E, // imports:[ + 0xDD, // { + 0x84, 0x85, 'b', 'o', 'g', 'u', 's', // name: "bogus" + 0x85, 0x21, 0x2A, // version: 42 + 0x88, 0x21, 0x64, // max_id: 100 + // }] + 0x87, 0xB8, // symbols: [ + 0x83, 'f', 'o', 'o', // "foo" + 0x83, 'b', 'a', 'r', // "bar" + // ] + // } + } + return NewReaderBytes(append(prefix, ion...)) +} diff --git a/ion/binarywriter.go b/ion/binarywriter.go new file mode 100644 index 00000000..0115faf6 --- /dev/null +++ b/ion/binarywriter.go @@ -0,0 +1,562 @@ +package ion + +import ( + "encoding/binary" + "fmt" + "io" + "math" + "math/big" + "strconv" + "strings" + "time" +) + +// A binaryWriter writes binary ion. +type binaryWriter struct { + writer + bufs bufstack + + lst SymbolTable + lstb SymbolTableBuilder + + wroteLST bool +} + +// NewBinaryWriter creates a new binary writer that will construct a +// local symbol table as it is written to. +func NewBinaryWriter(out io.Writer, sts ...SharedSymbolTable) Writer { + w := &binaryWriter{ + writer: writer{ + out: out, + }, + lstb: NewSymbolTableBuilder(sts...), + } + w.bufs.push(&datagram{}) + return w +} + +// NewBinaryWriterLST creates a new binary writer with a pre-built local +// symbol table. +func NewBinaryWriterLST(out io.Writer, lst SymbolTable) Writer { + return &binaryWriter{ + writer: writer{ + out: out, + }, + lst: lst, + } +} + +// WriteNull writes an untyped null. +func (w *binaryWriter) WriteNull() error { + return w.writeValue("Writer.WriteNull", []byte{0x0F}) +} + +// WriteNullType writes a typed null. +func (w *binaryWriter) WriteNullType(t Type) error { + return w.writeValue("Writer.WriteNullType", []byte{binaryNulls[t]}) +} + +// WriteBool writes a bool. +func (w *binaryWriter) WriteBool(val bool) error { + b := byte(0x10) + if val { + b = 0x11 + } + return w.writeValue("Writer.WriteBool", []byte{b}) +} + +// WriteInt writes an integer. +func (w *binaryWriter) WriteInt(val int64) error { + if val == 0 { + return w.writeValue("Writer.WriteInt", []byte{0x20}) + } + + code := byte(0x20) + mag := uint64(val) + + if val < 0 { + code = 0x30 + mag = uint64(-val) + } + + len := uintLen(mag) + buflen := len + tagLen(len) + + buf := make([]byte, 0, buflen) + buf = appendTag(buf, code, len) + buf = appendUint(buf, mag) + + return w.writeValue("Writer.WriteInt", buf) +} + +// WriteUint writes an unsigned integer. +func (w *binaryWriter) WriteUint(val uint64) error { + if val == 0 { + return w.writeValue("Writer.WriteUint", []byte{0x20}) + } + + len := uintLen(val) + buflen := len + tagLen(len) + + buf := make([]byte, 0, buflen) + buf = appendTag(buf, 0x20, len) + buf = appendUint(buf, val) + + return w.writeValue("Writer.WriteUint", buf) +} + +// WriteBigInt writes a big integer. +func (w *binaryWriter) WriteBigInt(val *big.Int) error { + if w.err != nil { + return w.err + } + if w.err = w.beginValue("Writer.WriteBigInt"); w.err != nil { + return w.err + } + + if w.err = w.writeBigInt(val); w.err != nil { + return w.err + } + + w.err = w.endValue() + return w.err +} + +// WriteBigInt writes the actual big integer value. +func (w *binaryWriter) writeBigInt(val *big.Int) error { + sign := val.Sign() + if sign == 0 { + return w.write([]byte{0x20}) + } + + code := byte(0x20) + if sign < 0 { + code = 0x30 + } + + bs := val.Bytes() + + bl := uint64(len(bs)) + if bl < 64 { + buflen := bl + tagLen(bl) + buf := make([]byte, 0, buflen) + + buf = appendTag(buf, code, bl) + buf = append(buf, bs...) + return w.write(buf) + } + + // no sense in copying, emit tag separately. + if err := w.writeTag(code, bl); err != nil { + return err + } + return w.write(bs) +} + +// WriteFloat writes a floating-point value. +func (w *binaryWriter) WriteFloat(val float64) error { + if val == 0 { + return w.writeValue("Writer.WriteFloat", []byte{0x40}) + } + + bs := make([]byte, 9) + bs[0] = 0x48 + + bits := math.Float64bits(val) + binary.BigEndian.PutUint64(bs[1:], bits) + + return w.writeValue("Writer.WriteFloat", bs) +} + +// WriteDecimal writes a decimal value. +func (w *binaryWriter) WriteDecimal(val *Decimal) error { + coef, exp := val.CoEx() + + vlen := uint64(0) + if exp != 0 { + vlen += varIntLen(int64(exp)) + } + if coef.Sign() != 0 { + vlen += bigIntLen(coef) + } + + buflen := vlen + tagLen(vlen) + buf := make([]byte, 0, buflen) + + buf = appendTag(buf, 0x50, vlen) + if exp != 0 { + buf = appendVarInt(buf, int64(exp)) + } + buf = appendBigInt(buf, coef) + + return w.writeValue("Writer.WriteDecimal", buf) +} + +// WriteTimestamp writes a timestamp value. +func (w *binaryWriter) WriteTimestamp(val time.Time) error { + _, offset := val.Zone() + offset /= 60 + utc := val.In(time.UTC) + + vlen := timeLen(offset, utc) + buflen := vlen + tagLen(vlen) + + buf := make([]byte, 0, buflen) + + buf = appendTag(buf, 0x60, vlen) + buf = appendTime(buf, offset, utc) + + return w.writeValue("Writer.WriteTimestamp", buf) +} + +// WriteSymbol writes a symbol value. +func (w *binaryWriter) WriteSymbol(val string) error { + id, err := w.resolve("Writer.WriteSymbol", val) + if err != nil { + w.err = err + return err + } + + vlen := uintLen(uint64(id)) + buflen := vlen + tagLen(vlen) + buf := make([]byte, 0, buflen) + + buf = appendTag(buf, 0x70, vlen) + buf = appendUint(buf, uint64(id)) + + return w.writeValue("Writer.WriteSymbol", buf) +} + +// WriteString writes a string. +func (w *binaryWriter) WriteString(val string) error { + if len(val) == 0 { + return w.writeValue("Writer.WriteString", []byte{0x80}) + } + + vlen := uint64(len(val)) + buflen := vlen + tagLen(vlen) + buf := make([]byte, 0, buflen) + + buf = appendTag(buf, 0x80, vlen) + buf = append(buf, val...) + + return w.writeValue("Writer.WriteString", buf) +} + +// WriteClob writes a clob. +func (w *binaryWriter) WriteClob(val []byte) error { + if w.err != nil { + return w.err + } + if w.err = w.beginValue("Writer.WriteClob"); w.err != nil { + return w.err + } + + if w.err = w.writeLob(0x90, val); w.err != nil { + return w.err + } + + w.err = w.endValue() + return w.err +} + +// WriteBlob writes a blob. +func (w *binaryWriter) WriteBlob(val []byte) error { + if w.err != nil { + return w.err + } + if w.err = w.beginValue("Writer.WriteBlob"); w.err != nil { + return w.err + } + + if w.err = w.writeLob(0xA0, val); w.err != nil { + return w.err + } + + w.err = w.endValue() + return w.err +} + +func (w *binaryWriter) writeLob(code byte, val []byte) error { + vlen := uint64(len(val)) + + if vlen < 64 { + buflen := vlen + tagLen(vlen) + buf := make([]byte, 0, buflen) + + buf = appendTag(buf, code, vlen) + buf = append(buf, val...) + + return w.write(buf) + } + + if err := w.writeTag(code, vlen); err != nil { + return err + } + return w.write(val) +} + +// BeginList begins writing a list. +func (w *binaryWriter) BeginList() error { + if w.err == nil { + w.err = w.begin("Writer.BeginList", ctxInList, 0xB0) + } + return w.err +} + +// EndList finishes writing a list. +func (w *binaryWriter) EndList() error { + if w.err == nil { + w.err = w.end("Writer.EndList", ctxInList) + } + return w.err +} + +// BeginSexp begins writing an s-expression. +func (w *binaryWriter) BeginSexp() error { + if w.err == nil { + w.err = w.begin("Writer.BeginSexp", ctxInSexp, 0xC0) + } + return w.err +} + +// EndSexp finishes writing an s-expression. +func (w *binaryWriter) EndSexp() error { + if w.err == nil { + w.err = w.end("Writer.EndSexp", ctxInSexp) + } + return w.err +} + +// BeginStruct begins writing a struct. +func (w *binaryWriter) BeginStruct() error { + if w.err == nil { + w.err = w.begin("Writer.BeginStruct", ctxInStruct, 0xD0) + } + return w.err +} + +// EndStruct finishes writing a struct. +func (w *binaryWriter) EndStruct() error { + if w.err == nil { + w.err = w.end("Writer.EndStruct", ctxInStruct) + } + return w.err +} + +// Finish finishes writing a datagram. +func (w *binaryWriter) Finish() error { + if w.err != nil { + return w.err + } + if w.ctx.peek() != ctxAtTopLevel { + return &UsageError{"Writer.Finish", "not at top level"} + } + + w.clear() + w.wroteLST = false + + seq := w.bufs.peek() + if seq != nil { + w.bufs.pop() + if w.bufs.peek() != nil { + panic("at top level but too many bufseqs") + } + + lst := w.lstb.Build() + if err := w.writeLST(lst); err != nil { + return err + } + if w.err = w.emit(seq); w.err != nil { + return w.err + } + } + + return nil +} + +// Emit emits the given node. If we're currently at the top level, that +// means actually emitting to the output stream. If not, we emit append +// to the current bufseq. +func (w *binaryWriter) emit(node bufnode) error { + s := w.bufs.peek() + if s == nil { + return node.EmitTo(w.out) + } + s.Append(node) + return nil +} + +// Write emits the given bytes as an atom. +func (w *binaryWriter) write(bs []byte) error { + return w.emit(atom(bs)) +} + +// WriteValue writes a serialized value to the output stream. +func (w *binaryWriter) writeValue(api string, val []byte) error { + if w.err != nil { + return w.err + } + if w.err = w.beginValue(api); w.err != nil { + return w.err + } + + if w.err = w.write(val); w.err != nil { + return w.err + } + + w.err = w.endValue() + return w.err +} + +// WriteTag writes out a type+length tag. Use me when you've already got the value to +// be written as a []byte and don't want to copy it. +func (w *binaryWriter) writeTag(code byte, len uint64) error { + tl := tagLen(len) + + tag := make([]byte, 0, tl) + tag = appendTag(tag, code, len) + + return w.write(tag) +} + +// WriteLST writes out a local symbol table. +func (w *binaryWriter) writeLST(lst SymbolTable) error { + if err := w.write([]byte{0xE0, 0x01, 0x00, 0xEA}); err != nil { + return err + } + return lst.WriteTo(w) +} + +// BeginValue begins the process of writing a value by writing out +// its field name and annotations. +func (w *binaryWriter) beginValue(api string) error { + // We have to record/empty these before calling writeLST, which + // will end up using/modifying them. Ugh. + name := w.fieldName + as := w.annotations + w.clear() + + // If we have a local symbol table and haven't written it out yet, do that now. + if w.lst != nil && !w.wroteLST { + w.wroteLST = true + if err := w.writeLST(w.lst); err != nil { + return err + } + } + + if w.inStruct() { + if name == "" { + return &UsageError{api, "field name not set"} + } + + id, err := w.resolve(api, name) + if err != nil { + return err + } + + buf := make([]byte, 0, 10) + buf = appendVarUint(buf, id) + if err := w.write(buf); err != nil { + return err + } + } + + if len(as) > 0 { + ids := make([]uint64, len(as)) + idlen := uint64(0) + + for i, a := range as { + id, err := w.resolve(api, a) + if err != nil { + return err + } + + ids[i] = id + idlen += varUintLen(id) + } + + buflen := idlen + varUintLen(idlen) + buf := make([]byte, 0, buflen) + + buf = appendVarUint(buf, idlen) + for _, id := range ids { + buf = appendVarUint(buf, id) + } + + // TODO: We could theoretically write the actual tag here if we know the + // length of the value ahead of time. + w.bufs.push(&container{code: 0xE0}) + if err := w.write(buf); err != nil { + return err + } + } + + return nil +} + +// EndValue ends the process of writing a value by flushing it and its annotations +// up a level, if needed. +func (w *binaryWriter) endValue() error { + seq := w.bufs.peek() + if seq != nil { + if c, ok := seq.(*container); ok && c.code == 0xE0 { + w.bufs.pop() + return w.emit(seq) + } + } + return nil +} + +// Begin begins writing a new container. +func (w *binaryWriter) begin(api string, t ctx, code byte) error { + if err := w.beginValue(api); err != nil { + return err + } + + w.ctx.push(t) + w.bufs.push(&container{code: code}) + + return nil +} + +// End ends writing a container, emitting its buffered contents up a level in the stack. +func (w *binaryWriter) end(api string, t ctx) error { + if w.ctx.peek() != t { + return &UsageError{api, "not in that kind of container"} + } + + seq := w.bufs.peek() + if seq != nil { + w.bufs.pop() + if err := w.emit(seq); err != nil { + return err + } + } + + w.clear() + w.ctx.pop() + + return w.endValue() +} + +// Resolve resolves a symbol to its ID. +func (w *binaryWriter) resolve(api, sym string) (uint64, error) { + if strings.HasPrefix(sym, "$") { + id, err := strconv.ParseUint(sym[1:], 10, 64) + if err == nil { + return id, nil + } + } + + if w.lst != nil { + id, ok := w.lst.FindByName(sym) + if !ok { + return 0, &UsageError{api, fmt.Sprintf("symbol '%v' not defined", sym)} + } + return id, nil + } + + id, _ := w.lstb.Add(sym) + return id, nil +} diff --git a/ion/binarywriter_test.go b/ion/binarywriter_test.go new file mode 100644 index 00000000..c1cbaaaa --- /dev/null +++ b/ion/binarywriter_test.go @@ -0,0 +1,387 @@ +package ion + +import ( + "bytes" + "encoding/hex" + "fmt" + "math" + "math/big" + "strings" + "testing" + "time" +) + +func TestWriteBinaryStruct(t *testing.T) { + eval := []byte{ + 0xD0, // {} + 0xEA, 0x81, 0xEE, 0xD7, // foo::{ + 0x84, 0xE3, 0x81, 0xEF, 0x0F, // name:bar::null, + 0x88, 0x20, // max_id:0 + // } + } + testBinaryWriter(t, eval, func(w Writer) { + w.BeginStruct() + w.EndStruct() + + w.Annotation("foo") + w.BeginStruct() + { + w.FieldName("name") + w.Annotation("bar") + w.WriteNull() + + w.FieldName("max_id") + w.WriteInt(0) + } + w.EndStruct() + }) +} + +func TestWriteBinarySexp(t *testing.T) { + eval := []byte{ + 0xC0, // () + 0xE8, 0x81, 0xEE, 0xC5, // foo::( + 0xE3, 0x81, 0xEF, 0x0F, // bar::null, + 0x20, // 0 + // ) + } + testBinaryWriter(t, eval, func(w Writer) { + w.BeginSexp() + w.EndSexp() + + w.Annotation("foo") + w.BeginSexp() + { + w.Annotation("bar") + w.WriteNull() + + w.WriteInt(0) + } + w.EndSexp() + }) +} + +func TestWriteBinaryList(t *testing.T) { + eval := []byte{ + 0xB0, // [] + 0xE8, 0x81, 0xEE, 0xB5, // foo::[ + 0xE3, 0x81, 0xEF, 0x0F, // bar::null, + 0x20, // 0 + // ] + } + testBinaryWriter(t, eval, func(w Writer) { + w.BeginList() + w.EndList() + + w.Annotation("foo") + w.BeginList() + { + w.Annotation("bar") + w.WriteNull() + + w.WriteInt(0) + } + w.EndList() + }) +} + +func TestWriteBinaryBlob(t *testing.T) { + eval := []byte{ + 0xA0, + 0xAB, 'H', 'e', 'l', 'l', 'o', ' ', 'W', 'o', 'r', 'l', 'd', + } + testBinaryWriter(t, eval, func(w Writer) { + w.WriteBlob([]byte{}) + w.WriteBlob([]byte("Hello World")) + }) +} + +func TestWriteLargeBinaryBlob(t *testing.T) { + eval := make([]byte, 131) + eval[0] = 0xAE + eval[1] = 0x01 + eval[2] = 0x80 + testBinaryWriter(t, eval, func(w Writer) { + w.WriteBlob(make([]byte, 128)) + }) +} + +func TestWriteBinaryClob(t *testing.T) { + eval := []byte{ + 0x90, + 0x9B, 'H', 'e', 'l', 'l', 'o', ' ', 'W', 'o', 'r', 'l', 'd', + } + testBinaryWriter(t, eval, func(w Writer) { + w.WriteClob([]byte{}) + w.WriteClob([]byte("Hello World")) + }) +} + +func TestWriteBinaryString(t *testing.T) { + eval := []byte{ + 0x80, // "" + 0x8B, 'H', 'e', 'l', 'l', 'o', ' ', 'W', 'o', 'r', 'l', 'd', + 0x8E, 0x9B, 'H', 'e', 'l', 'l', 'o', ' ', 'W', 'o', 'r', 'l', 'd', + ' ', 'B', 'u', 't', ' ', 'E', 'v', 'e', 'n', ' ', 'L', 'o', 'n', 'g', 'e', 'r', + 0x84, 0xE0, 0x01, 0x00, 0xEA, + } + testBinaryWriter(t, eval, func(w Writer) { + w.WriteString("") + w.WriteString("Hello World") + w.WriteString("Hello World But Even Longer") + w.WriteString("\xE0\x01\x00\xEA") + }) +} + +func TestWriteBinarySymbol(t *testing.T) { + eval := []byte{ + 0x71, 0x01, // $ion + 0x71, 0x04, // name + 0x71, 0x05, // version + 0x71, 0x09, // $ion_shared_symbol_table + 0x74, 0xFF, 0xFF, 0xFF, 0xFF, // $4294967295 + } + testBinaryWriter(t, eval, func(w Writer) { + w.WriteSymbol("$ion") + w.WriteSymbol("name") + w.WriteSymbol("version") + w.WriteSymbol("$ion_shared_symbol_table") + w.WriteSymbol("$4294967295") + }) +} + +func TestWriteBinaryTimestamp(t *testing.T) { + eval := []byte{ + 0x67, 0x80, 0x81, 0x81, 0x81, 0x80, 0x80, 0x80, // 0001-01-01T00:00:00Z + 0x6E, 0x8E, // 0x0E-bit timestamp + 0x04, 0xD8, // offset: +600 minutes (+10:00) + 0x0F, 0xE3, // year: 2019 + 0x88, // month: 8 + 0x84, // day: 4 + 0x88, // hour: 8 utc (18 local) + 0x8F, // minute: 15 + 0xAB, // second: 43 + 0xC9, // exp: -9 + 0x33, 0x77, 0xDF, 0x70, // nsec: 863494000 + } + + nowish, _ := time.Parse(time.RFC3339Nano, "2019-08-04T18:15:43.863494+10:00") + + testBinaryWriter(t, eval, func(w Writer) { + w.WriteTimestamp(time.Time{}) + w.WriteTimestamp(nowish) + }) +} + +func TestWriteBinaryDecimal(t *testing.T) { + eval := []byte{ + 0x50, // 0. + 0x51, 0xC3, // 0.000, aka 0 x 10^-3 + 0x53, 0xC3, 0x03, 0xE8, // 1.000, aka 1000 x 10^-3 + 0x53, 0xC3, 0x83, 0xE8, // -1.000, aka -1000 x 10^-3 + 0x53, 0x00, 0xE4, 0x01, // 1d100, aka 1 * 10^100 + 0x53, 0x00, 0xE4, 0x81, // -1d100, aka -1 * 10^100 + } + + testBinaryWriter(t, eval, func(w Writer) { + w.WriteDecimal(MustParseDecimal("0.")) + w.WriteDecimal(MustParseDecimal("0.000")) + w.WriteDecimal(MustParseDecimal("1.000")) + w.WriteDecimal(MustParseDecimal("-1.000")) + w.WriteDecimal(MustParseDecimal("1d100")) + w.WriteDecimal(MustParseDecimal("-1d100")) + }) +} + +func TestWriteBinaryFloats(t *testing.T) { + eval := []byte{ + 0x40, // 0 + 0x48, 0x7F, 0xEF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, // MaxFloat64 + 0x48, 0xFF, 0xEF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, // -MaxFloat64 + 0x48, 0x7F, 0xF0, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // +inf + 0x48, 0xFF, 0xF0, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // -inf + 0x48, 0x7F, 0xF8, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, // NaN + } + testBinaryWriter(t, eval, func(w Writer) { + w.WriteFloat(0) + w.WriteFloat(math.MaxFloat64) + w.WriteFloat(-math.MaxFloat64) + w.WriteFloat(math.Inf(1)) + w.WriteFloat(math.Inf(-1)) + w.WriteFloat(math.NaN()) + }) +} + +func TestWriteBinaryBigInts(t *testing.T) { + eval := []byte{ + 0x20, // 0 + 0x21, 0xFF, // 0xFF + 0x31, 0xFF, // -0xFF + 0x2E, 0x90, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, // a really big integer + } + + testBinaryWriter(t, eval, func(w Writer) { + w.WriteBigInt(big.NewInt(0)) + w.WriteBigInt(big.NewInt(0xFF)) + w.WriteBigInt(big.NewInt(-0xFF)) + w.WriteBigInt(new(big.Int).SetBytes([]byte{1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0})) + }) +} + +func TestWriteBinaryReallyBigInts(t *testing.T) { + eval := []byte{ + 0x2E, 0x01, 0x80, // 128-byte positive integer + 0x80, // high bit set + } + eval = append(eval, make([]byte, 127)...) + testBinaryWriter(t, eval, func(w Writer) { + i := new(big.Int) + i = i.SetBit(i, 1023, 1) + w.WriteBigInt(i) + }) +} + +func TestWriteBinaryInts(t *testing.T) { + eval := []byte{ + 0x20, // 0 + 0x21, 0xFF, // 0xFF + 0x31, 0xFF, // -0xFF + 0x22, 0xFF, 0xFF, // 0xFFFF + 0x33, 0xFF, 0xFF, 0xFF, // -0xFFFFFF + 0x28, 0x7F, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, // math.MaxInt64 + } + + testBinaryWriter(t, eval, func(w Writer) { + w.WriteInt(0) + w.WriteInt(0xFF) + w.WriteInt(-0xFF) + w.WriteInt(0xFFFF) + w.WriteInt(-0xFFFFFF) + w.WriteInt(math.MaxInt64) + }) +} + +func TestWriteBinaryBoolAnnotated(t *testing.T) { + eval := []byte{ + 0xE4, // 4-byte annotated value + 0x82, // 2 bytes of annotations + 0x84, // $4 (name) + 0x85, // $5 (version) + 0x10, // false + } + + testBinaryWriter(t, eval, func(w Writer) { + w.Annotations("name", "version") + w.WriteBool(false) + }) +} + +func TestWriteBinaryBools(t *testing.T) { + eval := []byte{ + 0x10, // false + 0x11, // true + } + + testBinaryWriter(t, eval, func(w Writer) { + w.WriteBool(false) + w.WriteBool(true) + }) +} + +func TestWriteBinaryNulls(t *testing.T) { + eval := []byte{ + 0x0F, + 0x1F, + 0x2F, + // 0x3F, // negative integer, not actually valid + 0x4F, + 0x5F, + 0x6F, + 0x7F, + 0x8F, + 0x9F, + 0xAF, + 0xBF, + 0xCF, + 0xDF, + } + + testBinaryWriter(t, eval, func(w Writer) { + w.WriteNull() + w.WriteNullType(BoolType) + w.WriteNullType(IntType) + w.WriteNullType(FloatType) + w.WriteNullType(DecimalType) + w.WriteNullType(TimestampType) + w.WriteNullType(SymbolType) + w.WriteNullType(StringType) + w.WriteNullType(ClobType) + w.WriteNullType(BlobType) + w.WriteNullType(ListType) + w.WriteNullType(SexpType) + w.WriteNullType(StructType) + }) +} + +func testBinaryWriter(t *testing.T, eval []byte, f func(w Writer)) { + val := writeBinary(t, f) + + prefix := []byte{ + 0xE0, 0x01, 0x00, 0xEA, // $ion_1_0 + 0xEE, 0x9F, 0x81, 0x83, 0xDE, 0x9B, // $ion_symbol_table::{ + 0x86, 0xBE, 0x8E, // imports:[ + 0xDD, // { + 0x84, 0x85, 'b', 'o', 'g', 'u', 's', // name: "bogus" + 0x85, 0x21, 0x2A, // version: 42 + 0x88, 0x21, 0x64, // max_id: 100 + // }] + 0x87, 0xB8, // symbols: [ + 0x83, 'f', 'o', 'o', // "foo" + 0x83, 'b', 'a', 'r', // "bar" + // ] + // } + } + eval = append(prefix, eval...) + + if !bytes.Equal(val, eval) { + t.Errorf("expected %v, got %v", fmtbytes(eval), fmtbytes(val)) + } +} + +func fmtbytes(bs []byte) string { + buf := strings.Builder{} + buf.WriteByte('[') + for i, b := range bs { + if i > 0 { + buf.WriteByte(' ') + } + buf.WriteString(hex.EncodeToString([]byte{b})) + } + buf.WriteByte(']') + return buf.String() +} + +func writeBinary(t *testing.T, f func(w Writer)) []byte { + bogusSyms := []string{} + for i := 0; i < 100; i++ { + bogusSyms = append(bogusSyms, fmt.Sprintf("bogus_sym_%v", i)) + } + + bogus := []SharedSymbolTable{ + NewSharedSymbolTable("bogus", 42, bogusSyms), + } + + buf := bytes.Buffer{} + w := NewBinaryWriterLST(&buf, NewLocalSymbolTable(bogus, []string{ + "foo", + "bar", + })) + + f(w) + + if err := w.Finish(); err != nil { + t.Fatal(err) + } + + return buf.Bytes() +} diff --git a/ion/bits.go b/ion/bits.go new file mode 100644 index 00000000..f3772702 --- /dev/null +++ b/ion/bits.go @@ -0,0 +1,285 @@ +package ion + +import ( + "math/big" + "time" +) + +// uintLen pre-calculates the length, in bytes, of the given uint value. +func uintLen(v uint64) uint64 { + len := uint64(1) + v >>= 8 + + for v > 0 { + len++ + v >>= 8 + } + + return len +} + +// appendUint appends a uint value to the given slice. The reader is +// expected to know how many bytes the value takes up. +func appendUint(b []byte, v uint64) []byte { + var buf [8]byte + + i := 7 + buf[i] = byte(v & 0xFF) + v >>= 8 + + for v > 0 { + i-- + buf[i] = byte(v & 0xFF) + v >>= 8 + } + + return append(b, buf[i:]...) +} + +// intLen pre-calculates the length, in bytes, of the given int value. +func intLen(n int64) uint64 { + if n == 0 { + return 0 + } + + mag := uint64(n) + if n < 0 { + mag = uint64(-n) + } + + len := uintLen(mag) + + // If the high bit is a one, we need an extra byte to store the sign bit. + hb := mag >> ((len - 1) * 8) + if hb&0x80 != 0 { + len++ + } + + return len +} + +// appendInt appends a (signed) int to the given slice. The reader is +// expected to know how many bytes the value takes up. +func appendInt(b []byte, n int64) []byte { + if n == 0 { + return b + } + + neg := false + mag := uint64(n) + + if n < 0 { + neg = true + mag = uint64(-n) + } + + var buf [8]byte + bits := buf[:0] + bits = appendUint(bits, mag) + + if bits[0]&0x80 == 0 { + // We've got space we can use for the sign bit. + if neg { + bits[0] ^= 0x80 + } + } else { + // We need to add more space. + bit := byte(0) + if neg { + bit = 0x80 + } + b = append(b, bit) + } + + return append(b, bits...) +} + +// bigIntLen pre-calculates the length, in bytes, of the given big.Int value. +func bigIntLen(v *big.Int) uint64 { + if v.Sign() == 0 { + return 0 + } + + bitl := v.BitLen() + bytel := bitl / 8 + + // Either bitl is evenly divisibly by 8, in which case we need another + // byte for the sign bit, or its not in which case we need to round up + // (but will then have room for the sign bit). + return uint64(bytel) + 1 +} + +// appendBigInt appends a (signed) big.Int to the given slice. The reader is +// expected to know how many bytes the value takes up. +func appendBigInt(b []byte, v *big.Int) []byte { + sign := v.Sign() + if sign == 0 { + return b + } + + bits := v.Bytes() + + if bits[0]&0x80 == 0 { + // We've got space we can use for the sign bit. + if sign < 0 { + bits[0] ^= 0x80 + } + } else { + // We need to add more space. + bit := byte(0) + if sign < 0 { + bit = 0x80 + } + b = append(b, bit) + } + + return append(b, bits...) +} + +// varUintLen pre-calculates the length, in bytes, of the given varUint value. +func varUintLen(v uint64) uint64 { + len := uint64(1) + v >>= 7 + + for v > 0 { + len++ + v >>= 7 + } + + return len +} + +// appendVarUint appends a variable-length-encoded uint to the given slice. +// Each byte stores seven bits of value; the high bit is a flag marking the +// last byte of the value. +func appendVarUint(b []byte, v uint64) []byte { + var buf [10]byte + + i := 9 + buf[i] = 0x80 | byte(v&0x7F) + v >>= 7 + + for v > 0 { + i-- + buf[i] = byte(v & 0x7F) + v >>= 7 + } + + return append(b, buf[i:]...) +} + +// varIntLen pre-calculates the length, in bytes, of the given varInt value. +func varIntLen(v int64) uint64 { + mag := uint64(v) + if v < 0 { + mag = uint64(-v) + } + + // Reserve one extra bit of the first byte for sign. + len := uint64(1) + mag >>= 6 + + for mag > 0 { + len++ + mag >>= 7 + } + + return len +} + +// appendVarInt appends a variable-length-encoded int to the given slice. +// Most bytes store seven bits of value; the high bit is a flag marking the +// last byte of the value. The first byte additionally stores a sign bit. +func appendVarInt(b []byte, v int64) []byte { + var buf [10]byte + + signbit := byte(0) + mag := uint64(v) + if v < 0 { + signbit = 0x40 + mag = uint64(-v) + } + + next := mag >> 6 + if next == 0 { + // The whole thing fits in one byte. + return append(b, 0x80|signbit|byte(mag&0x3F)) + } + + i := 9 + buf[i] = 0x80 | byte(mag&0x7F) + mag >>= 7 + next = mag >> 6 + + for next > 0 { + i-- + buf[i] = byte(mag & 0x7F) + mag >>= 7 + next = mag >> 6 + } + + i-- + buf[i] = signbit | byte(mag&0x3F) + + return append(b, buf[i:]...) +} + +// tagLen pre-calculates the length, in bytes, of a tag. +func tagLen(len uint64) uint64 { + if len < 0x0E { + return 1 + } + return 1 + varUintLen(len) +} + +// appendTag appends a code+len tag to the given slice. +func appendTag(b []byte, code byte, len uint64) []byte { + if len < 0x0E { + // Short form, with length embedded in the code byte. + return append(b, code|byte(len)) + } + + // Long form, with separate length. + b = append(b, code|0x0E) + return appendVarUint(b, len) +} + +// timeLen pre-calculates the length, in bytes, of the given time value. +func timeLen(offset int, utc time.Time) uint64 { + ret := varIntLen(int64(offset)) + + // Almost certainly two but let's be safe. + ret += varUintLen(uint64(utc.Year())) + + // Month, day, hour, minute, and second are all guaranteed to be one byte. + ret += 5 + + ns := utc.Nanosecond() + if ns > 0 { + ret++ // varIntLen(-9) + ret += intLen(int64(ns)) + } + + return ret +} + +// appendTime appends a timestamp value +func appendTime(b []byte, offset int, utc time.Time) []byte { + b = appendVarInt(b, int64(offset)) + + b = appendVarUint(b, uint64(utc.Year())) + b = appendVarUint(b, uint64(utc.Month())) + b = appendVarUint(b, uint64(utc.Day())) + + b = appendVarUint(b, uint64(utc.Hour())) + b = appendVarUint(b, uint64(utc.Minute())) + b = appendVarUint(b, uint64(utc.Second())) + + ns := utc.Nanosecond() + if ns > 0 { + b = appendVarInt(b, -9) + b = appendInt(b, int64(ns)) + } + + return b +} diff --git a/ion/bits_test.go b/ion/bits_test.go new file mode 100644 index 00000000..db2c15c4 --- /dev/null +++ b/ion/bits_test.go @@ -0,0 +1,205 @@ +package ion + +import ( + "bytes" + "fmt" + "math" + "math/big" + "testing" + "time" +) + +func TestAppendUint(t *testing.T) { + test := func(val uint64, elen uint64, ebits []byte) { + t.Run(fmt.Sprintf("%x", val), func(t *testing.T) { + len := uintLen(val) + if len != elen { + t.Errorf("expected len=%v, got len=%v", elen, len) + } + + bits := appendUint(nil, val) + if !bytes.Equal(bits, ebits) { + t.Errorf("expected %v, got %v", fmtbytes(ebits), fmtbytes(bits)) + } + }) + } + + test(0, 1, []byte{0}) + test(0xFF, 1, []byte{0xFF}) + test(0x1FF, 2, []byte{0x01, 0xFF}) + test(math.MaxUint64, 8, []byte{0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF}) +} + +func TestAppendInt(t *testing.T) { + test := func(val int64, elen uint64, ebits []byte) { + t.Run(fmt.Sprintf("%x", val), func(t *testing.T) { + len := intLen(val) + if len != elen { + t.Errorf("expected len=%v, got len=%v", elen, len) + } + + bits := appendInt(nil, val) + if !bytes.Equal(bits, ebits) { + t.Errorf("expected %v, got %v", fmtbytes(ebits), fmtbytes(bits)) + } + }) + } + + test(0, 0, []byte{}) + test(0x7F, 1, []byte{0x7F}) + test(-0x7F, 1, []byte{0xFF}) + + test(0xFF, 2, []byte{0x00, 0xFF}) + test(-0xFF, 2, []byte{0x80, 0xFF}) + + test(0x7FFF, 2, []byte{0x7F, 0xFF}) + test(-0x7FFF, 2, []byte{0xFF, 0xFF}) + + test(math.MaxInt64, 8, []byte{0x7F, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF}) + test(-math.MaxInt64, 8, []byte{0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF}) + test(math.MinInt64, 9, []byte{0x80, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}) +} + +func TestAppendBigInt(t *testing.T) { + test := func(val *big.Int, elen uint64, ebits []byte) { + t.Run(fmt.Sprintf("%x", val), func(t *testing.T) { + len := bigIntLen(val) + if len != elen { + t.Errorf("expected len=%v, got len=%v", elen, len) + } + + bits := appendBigInt(nil, val) + if !bytes.Equal(bits, ebits) { + t.Errorf("expected %v, got %v", fmtbytes(ebits), fmtbytes(bits)) + } + }) + } + + test(big.NewInt(0), 0, []byte{}) + test(big.NewInt(0x7F), 1, []byte{0x7F}) + test(big.NewInt(-0x7F), 1, []byte{0xFF}) + + test(big.NewInt(0xFF), 2, []byte{0x00, 0xFF}) + test(big.NewInt(-0xFF), 2, []byte{0x80, 0xFF}) + + test(big.NewInt(0x7FFF), 2, []byte{0x7F, 0xFF}) + test(big.NewInt(-0x7FFF), 2, []byte{0xFF, 0xFF}) +} + +func TestAppendVarUint(t *testing.T) { + test := func(val uint64, elen uint64, ebits []byte) { + t.Run(fmt.Sprintf("%x", val), func(t *testing.T) { + len := varUintLen(val) + if len != elen { + t.Errorf("expected len=%v, got len=%v", elen, len) + } + + bits := appendVarUint(nil, val) + if !bytes.Equal(bits, ebits) { + t.Errorf("expected %v, got %v", fmtbytes(ebits), fmtbytes(bits)) + } + }) + } + + test(0, 1, []byte{0x80}) + test(0x7F, 1, []byte{0xFF}) + test(0xFF, 2, []byte{0x01, 0xFF}) + test(0x1FF, 2, []byte{0x03, 0xFF}) + test(0x3FFF, 2, []byte{0x7F, 0xFF}) + test(0x7FFF, 3, []byte{0x01, 0x7F, 0xFF}) + test(0x7FFFFFFFFFFFFFFF, 9, []byte{0x7F, 0x7F, 0x7F, 0x7F, 0x7F, 0x7F, 0x7F, 0x7F, 0xFF}) + test(0xFFFFFFFFFFFFFFFF, 10, []byte{0x01, 0x7F, 0x7F, 0x7F, 0x7F, 0x7F, 0x7F, 0x7F, 0x7F, 0xFF}) +} + +func TestAppendVarInt(t *testing.T) { + test := func(val int64, elen uint64, ebits []byte) { + t.Run(fmt.Sprintf("%x", val), func(t *testing.T) { + len := varIntLen(val) + if len != elen { + t.Errorf("expected len=%v, got len=%v", elen, len) + } + + bits := appendVarInt(nil, val) + if !bytes.Equal(bits, ebits) { + t.Errorf("expected %v, got %v", fmtbytes(ebits), fmtbytes(bits)) + } + }) + } + + test(0, 1, []byte{0x80}) + + test(0x3F, 1, []byte{0xBF}) // 1011 1111 + test(-0x3F, 1, []byte{0xFF}) + + test(0x7F, 2, []byte{0x00, 0xFF}) + test(-0x7F, 2, []byte{0x40, 0xFF}) + + test(0x1FFF, 2, []byte{0x3F, 0xFF}) + test(-0x1FFF, 2, []byte{0x7F, 0xFF}) + + test(0x3FFF, 3, []byte{0x00, 0x7F, 0xFF}) + test(-0x3FFF, 3, []byte{0x40, 0x7F, 0xFF}) + + test(0x3FFFFFFFFFFFFFFF, 9, []byte{0x3F, 0x7F, 0x7F, 0x7F, 0x7F, 0x7F, 0x7F, 0x7F, 0xFF}) + test(-0x3FFFFFFFFFFFFFFF, 9, []byte{0x7F, 0x7F, 0x7F, 0x7F, 0x7F, 0x7F, 0x7F, 0x7F, 0xFF}) + + test(math.MaxInt64, 10, []byte{0x00, 0x7F, 0x7F, 0x7F, 0x7F, 0x7F, 0x7F, 0x7F, 0x7F, 0xFF}) + test(-math.MaxInt64, 10, []byte{0x40, 0x7F, 0x7F, 0x7F, 0x7F, 0x7F, 0x7F, 0x7F, 0x7F, 0xFF}) + test(math.MinInt64, 10, []byte{0x41, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x80}) +} + +func TestAppendTag(t *testing.T) { + test := func(code byte, vlen uint64, elen uint64, ebits []byte) { + t.Run(fmt.Sprintf("(%x,%v)", code, vlen), func(t *testing.T) { + len := tagLen(vlen) + if len != elen { + t.Errorf("expected len=%v, got len=%v", elen, len) + } + + bits := appendTag(nil, code, vlen) + if !bytes.Equal(bits, ebits) { + t.Errorf("expected %v, got %v", fmtbytes(ebits), fmtbytes(bits)) + } + }) + } + + test(0x20, 1, 1, []byte{0x21}) + test(0x30, 0x0D, 1, []byte{0x3D}) + test(0x40, 0x0E, 2, []byte{0x4E, 0x8E}) + test(0x50, math.MaxInt64, 10, []byte{0x5E, 0x7F, 0x7F, 0x7F, 0x7F, 0x7F, 0x7F, 0x7F, 0x7F, 0xFF}) +} + +func TestAppendTime(t *testing.T) { + test := func(val time.Time, elen uint64, ebits []byte) { + t.Run(fmt.Sprintf("%x", val), func(t *testing.T) { + _, offset := val.Zone() + offset /= 60 + utc := val.In(time.UTC) + + len := timeLen(offset, utc) + if len != elen { + t.Errorf("expected len=%v, got len=%v", elen, len) + } + + bits := appendTime(nil, offset, utc) + if !bytes.Equal(bits, ebits) { + t.Errorf("expected %v, got %v", fmtbytes(ebits), fmtbytes(bits)) + } + }) + } + + nowish, _ := time.Parse(time.RFC3339Nano, "2019-08-04T18:15:43.863494+10:00") + + test(time.Time{}, 7, []byte{0x80, 0x81, 0x81, 0x81, 0x80, 0x80, 0x80}) + test(nowish, 14, []byte{ + 0x04, 0xD8, // offset: +600 minutes (+10:00) + 0x0F, 0xE3, // year: 2019 + 0x88, // month: 8 + 0x84, // day: 4 + 0x88, // hour: 8 utc (18 local) + 0x8F, // minute: 15 + 0xAB, // second: 43 + 0xC9, // exp: -9 + 0x33, 0x77, 0xDF, 0x70, // nsec: 863494000 + }) +} diff --git a/ion/bitstream.go b/ion/bitstream.go new file mode 100644 index 00000000..512aa109 --- /dev/null +++ b/ion/bitstream.go @@ -0,0 +1,935 @@ +package ion + +import ( + "bufio" + "bytes" + "encoding/binary" + "fmt" + "io" + "math" + "math/big" + "time" +) + +type bss uint8 + +const ( + bssBeforeValue bss = iota + bssOnValue + bssBeforeFieldID + bssOnFieldID +) + +type bitcode uint8 + +const ( + bitcodeNone bitcode = iota + bitcodeEOF + bitcodeBVM + bitcodeNull + bitcodeFalse + bitcodeTrue + bitcodeInt + bitcodeNegInt + bitcodeFloat + bitcodeDecimal + bitcodeTimestamp + bitcodeSymbol + bitcodeString + bitcodeClob + bitcodeBlob + bitcodeList + bitcodeSexp + bitcodeStruct + bitcodeFieldID + bitcodeAnnotation +) + +func (b bitcode) String() string { + switch b { + case bitcodeNone: + return "none" + case bitcodeEOF: + return "eof" + case bitcodeBVM: + return "bvm" + case bitcodeFalse: + return "false" + case bitcodeTrue: + return "true" + case bitcodeInt: + return "int" + case bitcodeNegInt: + return "negint" + case bitcodeFloat: + return "float" + case bitcodeDecimal: + return "decimal" + case bitcodeTimestamp: + return "timestamp" + case bitcodeSymbol: + return "symbol" + case bitcodeString: + return "string" + case bitcodeClob: + return "clob" + case bitcodeBlob: + return "blob" + case bitcodeList: + return "list" + case bitcodeSexp: + return "sexp" + case bitcodeStruct: + return "struct" + case bitcodeFieldID: + return "fieldid" + case bitcodeAnnotation: + return "annotation" + default: + return fmt.Sprintf("", uint8(b)) + } +} + +// A bitstream is a low-level parser for binary Ion values. +type bitstream struct { + in *bufio.Reader + pos uint64 + state bss + stack bitstack + + code bitcode + null bool + len uint64 +} + +// Init initializes this stream with the given bufio.Reader. +func (b *bitstream) Init(in *bufio.Reader) { + b.in = in +} + +// InitBytes initializes this stream with the given bytes. +func (b *bitstream) InitBytes(in []byte) { + b.in = bufio.NewReader(bytes.NewReader(in)) +} + +// Code returns the typecode of the current value. +func (b *bitstream) Code() bitcode { + return b.code +} + +// IsNull returns true if the current value is null. +func (b *bitstream) IsNull() bool { + return b.null +} + +// Pos returns the current position. +func (b *bitstream) Pos() uint64 { + return b.pos +} + +// Len returns the length of the current value. +func (b *bitstream) Len() uint64 { + return b.len +} + +// Next advances the stream to the next value. +func (b *bitstream) Next() error { + // If we have an unread value, skip over it to get to the next one. + switch b.state { + case bssOnValue, bssOnFieldID: + if err := b.SkipValue(); err != nil { + return err + } + } + + // If we're at the end of the current container, stop and make the user step out. + if !b.stack.empty() { + cur := b.stack.peek() + if b.pos == cur.end { + b.code = bitcodeEOF + return nil + } + } + + // If it's time to read a field id, do that. + if b.state == bssBeforeFieldID { + b.code = bitcodeFieldID + b.state = bssOnFieldID + return nil + } + + // Otherwise it's time to read a value. Read the tag byte. + c, err := b.read() + if err != nil { + return err + } + + // Found the end of the file. + if c == -1 { + b.code = bitcodeEOF + return nil + } + + // Parse the tag. + code, len := parseTag(c) + if code == bitcodeNone { + return &InvalidTagByteError{byte(c), b.pos - 1} + } + + b.state = bssOnValue + + if code == bitcodeAnnotation { + switch len { + case 0: + // This value is actually a BVM. It's invalid if we're not at the top level. + if !b.stack.empty() { + return &SyntaxError{"invalid BVM in a container", b.pos - 1} + } + b.code = bitcodeBVM + b.len = 3 + return nil + + case 0x0F: + // No such thing as a null annotation. + return &InvalidTagByteError{byte(c), b.pos - 1} + } + } + + // Booleans are a bit special; the 'length' stores the value. + if code == bitcodeFalse { + switch len { + case 0, 0x0F: + break + case 1: + code = bitcodeTrue + len = 0 + default: + // Other forms are invalid. + return &InvalidTagByteError{byte(c), b.pos - 1} + } + } + + if len == 0x0F { + // This value is actually a null. + b.code = code + b.null = true + return nil + } + + pos := b.pos + rem := b.remaining() + + // This value's actual len is encoded as a separate varUint. + if len == 0x0E { + var lenlen uint64 + len, lenlen, err = b.readVarUintLen(rem) + if err != nil { + return err + } + rem -= lenlen + } + + if len > rem { + msg := fmt.Sprintf("value overruns its container: %v vs %v", len, rem) + return &SyntaxError{msg, pos - 1} + } + + b.code = code + b.len = len + return nil +} + +// SkipValue skips over the current value. +func (b *bitstream) SkipValue() error { + switch b.state { + case bssBeforeFieldID, bssBeforeValue: + // No current value to skip yet. + return nil + + case bssOnFieldID: + if err := b.skipVarUint(); err != nil { + return err + } + b.state = bssBeforeValue + + case bssOnValue: + if b.len > 0 { + if err := b.skip(b.len); err != nil { + return err + } + } + b.state = b.stateAfterValue() + + default: + panic(fmt.Sprintf("invalid state %v", b.state)) + } + + b.clear() + return nil +} + +// StepIn steps in to a container. +func (b *bitstream) StepIn() { + switch b.code { + case bitcodeStruct: + b.state = bssBeforeFieldID + + case bitcodeList, bitcodeSexp: + b.state = bssBeforeValue + + default: + panic(fmt.Sprintf("StepIn called with b.code=%v", b.code)) + } + + b.stack.push(b.code, b.pos+b.len) + b.clear() +} + +// StepOut steps out of a container. +func (b *bitstream) StepOut() error { + if b.stack.empty() { + panic("StepOut called at top level") + } + + cur := b.stack.peek() + b.stack.pop() + + if cur.end < b.pos { + panic(fmt.Sprintf("end (%v) greater than b.pos (%v)", cur.end, b.pos)) + } + diff := cur.end - b.pos + + // Skip over anything left in the container we're stepping out of. + if diff > 0 { + if err := b.skip(diff); err != nil { + return err + } + } + + b.state = b.stateAfterValue() + b.clear() + + return nil +} + +// ReadBVM reads a binary version marker, returning its major and minor version. +func (b *bitstream) ReadBVM() (byte, byte, error) { + if b.code != bitcodeBVM { + panic("not a BVM") + } + + major, err := b.read1() + if err != nil { + return 0, 0, err + } + + minor, err := b.read1() + if err != nil { + return 0, 0, err + } + + end, err := b.read1() + if err != nil { + return 0, 0, err + } + + if end != 0xEA { + msg := fmt.Sprintf("invalid BVM: 0xE0 0x%02X 0x%02X 0x%02X", major, minor, end) + return 0, 0, &SyntaxError{msg, b.pos - 4} + } + + b.state = bssBeforeValue + b.clear() + + return byte(major), byte(minor), nil +} + +// ReadFieldID reads a field ID. +func (b *bitstream) ReadFieldID() (uint64, error) { + if b.code != bitcodeFieldID { + panic("not a field ID") + } + + id, err := b.readVarUint() + if err != nil { + return 0, err + } + + b.state = bssBeforeValue + b.code = bitcodeNone + + return id, nil +} + +// ReadAnnotationIDs reads a set of annotation IDs. +func (b *bitstream) ReadAnnotationIDs() ([]uint64, error) { + if b.code != bitcodeAnnotation { + panic("not an annotation") + } + + alen, lenlen, err := b.readVarUintLen(b.len) + if err != nil { + return nil, err + } + + if b.len-lenlen <= alen { + // The size of the annotations is larger than the remaining free space inside the + // annotation container. + return nil, &SyntaxError{"malformed annotation", b.pos - lenlen} + } + + as := []uint64{} + for alen > 0 { + id, idlen, err := b.readVarUintLen(alen) + if err != nil { + return nil, err + } + + as = append(as, id) + alen -= idlen + } + + b.state = bssBeforeValue + b.clear() + + return as, nil +} + +// ReadInt reads an integer value. +func (b *bitstream) ReadInt() (interface{}, error) { + if b.code != bitcodeInt && b.code != bitcodeNegInt { + panic("not an integer") + } + + bs, err := b.readN(b.len) + if err != nil { + return "", err + } + + var ret interface{} + switch { + case b.len == 0: + // Special case for zero. + ret = int64(0) + + case b.len < 8, (b.len == 8 && bs[0]&0x80 == 0): + // It'll fit in an int64. + i := int64(0) + for _, b := range bs { + i <<= 8 + i ^= int64(b) + } + if b.code == bitcodeNegInt { + i = -i + } + ret = i + + default: + // Need to go big.Int. + i := new(big.Int).SetBytes(bs) + if b.code == bitcodeNegInt { + i = i.Neg(i) + } + ret = i + } + + b.state = b.stateAfterValue() + b.clear() + + return ret, nil +} + +// ReadFloat reads a float value. +func (b *bitstream) ReadFloat() (float64, error) { + if b.code != bitcodeFloat { + panic("not a float") + } + + bs, err := b.readN(b.len) + if err != nil { + return 0, err + } + + var ret float64 + switch len(bs) { + case 0: + ret = 0 + + case 4: + ui := binary.BigEndian.Uint32(bs) + ret = float64(math.Float32frombits(ui)) + + case 8: + ui := binary.BigEndian.Uint64(bs) + ret = math.Float64frombits(ui) + + default: + return 0, &SyntaxError{"invalid float size", b.pos - b.len} + } + + b.state = b.stateAfterValue() + b.clear() + + return ret, nil +} + +// ReadDecimal reads a decimal value. +func (b *bitstream) ReadDecimal() (*Decimal, error) { + if b.code != bitcodeDecimal { + panic("not a decimal") + } + + d, err := b.readDecimal(b.len) + if err != nil { + return nil, err + } + + b.state = b.stateAfterValue() + b.clear() + + return d, nil +} + +// ReadTimestamp reads a timestamp value. +func (b *bitstream) ReadTimestamp() (time.Time, error) { + if b.code != bitcodeTimestamp { + panic("not a timestamp") + } + + len := b.len + + offset, olen, err := b.readVarIntLen(len) + if err != nil { + return time.Time{}, err + } + len -= olen + + ts := []int{1, 1, 1, 0, 0, 0} + for i := 0; len > 0 && i < 6; i++ { + val, vlen, err := b.readVarUintLen(len) + if err != nil { + return time.Time{}, err + } + len -= vlen + ts[i] = int(val) + } + + nsecs, err := b.readNsecs(len) + if err != nil { + return time.Time{}, err + } + + b.state = b.stateAfterValue() + b.clear() + + utc := time.Date(ts[0], time.Month(ts[1]), ts[2], ts[3], ts[4], ts[5], int(nsecs), time.UTC) + return utc.In(time.FixedZone("fixed", int(offset)*60)), nil +} + +// ReadNsecs reads the fraction part of a timestamp and truncates it to nanoseconds. +func (b *bitstream) readNsecs(len uint64) (int, error) { + d, err := b.readDecimal(len) + if err != nil { + return 0, err + } + + nsec, err := d.ShiftL(9).Trunc() + if err != nil || nsec < 0 || nsec > 999999999 { + msg := fmt.Sprintf("invalid timestamp fraction: %v", d) + return 0, &SyntaxError{msg, b.pos} + } + + return int(nsec), nil +} + +// ReadDecimal reads a decimal value of the given length: an exponent encoded as a +// varInt, followed by an integer coefficient taking up the remaining bytes. +func (b *bitstream) readDecimal(len uint64) (*Decimal, error) { + exp := int64(0) + coef := new(big.Int) + + if len > 0 { + val, vlen, err := b.readVarIntLen(len) + if err != nil { + return nil, err + } + + if val > math.MaxInt32 || val < math.MinInt32 { + msg := fmt.Sprintf("decimal exponent out of range: %v", val) + return nil, &SyntaxError{msg, b.pos - vlen} + } + + exp = val + len -= vlen + } + + if len > 0 { + if err := b.readBigInt(len, coef); err != nil { + return nil, err + } + } + + return NewDecimal(coef, int32(exp)), nil +} + +// ReadSymbolID reads a symbol value. +func (b *bitstream) ReadSymbolID() (uint64, error) { + if b.code != bitcodeSymbol { + panic("not a symbol") + } + + if b.len > 8 { + return 0, &SyntaxError{"symbol id too large", b.pos} + } + + bs, err := b.readN(b.len) + if err != nil { + return 0, err + } + + b.state = b.stateAfterValue() + b.clear() + + ret := uint64(0) + for _, b := range bs { + ret <<= 8 + ret ^= uint64(b) + } + return ret, nil +} + +// ReadString reads a string value. +func (b *bitstream) ReadString() (string, error) { + if b.code != bitcodeString { + panic("not a string") + } + + bs, err := b.readN(b.len) + if err != nil { + return "", err + } + + b.state = b.stateAfterValue() + b.clear() + + return string(bs), nil +} + +// ReadBytes reads a blob or clob value. +func (b *bitstream) ReadBytes() ([]byte, error) { + if b.code != bitcodeClob && b.code != bitcodeBlob { + panic("not a lob") + } + + bs, err := b.readN(b.len) + if err != nil { + return nil, err + } + + b.state = b.stateAfterValue() + b.clear() + + return bs, nil +} + +// Clear clears the current code and len. +func (b *bitstream) clear() { + b.code = bitcodeNone + b.null = false + b.len = 0 +} + +// ReadBigInt reads a fixed-length integer of the given length and stores +// the value in the given big.Int. +func (b *bitstream) readBigInt(len uint64, ret *big.Int) error { + bs, err := b.readN(len) + if err != nil { + return err + } + + neg := (bs[0]&0x80 != 0) + bs[0] &= 0x7F + if bs[0] == 0 { + bs = bs[1:] + } + + ret.SetBytes(bs) + if neg { + ret.Neg(ret) + } + + return nil +} + +// ReadVarUint reads a variable-length-encoded uint. +func (b *bitstream) readVarUint() (uint64, error) { + val, _, err := b.readVarUintLen(b.remaining()) + return val, err +} + +// ReadVarUintLen reads a variable-length-encoded uint of at most max bytes, +// returning the value and its actual length in bytes. +func (b *bitstream) readVarUintLen(max uint64) (uint64, uint64, error) { + if max > 10 { + max = 10 + } + + val := uint64(0) + len := uint64(0) + + for { + if len >= max { + return 0, 0, &SyntaxError{"varuint too large", b.pos} + } + + c, err := b.read1() + if err != nil { + return 0, 0, err + } + + val <<= 7 + val ^= uint64(c & 0x7F) + len++ + + if c&0x80 != 0 { + return val, len, nil + } + } +} + +// SkipVarUint skips over a variable-length-encoded uint. +func (b *bitstream) skipVarUint() error { + _, err := b.skipVarUintLen(b.remaining()) + return err +} + +// SkipVarUintLen skips over a variable-length-encoded uint of at most max bytes. +func (b *bitstream) skipVarUintLen(max uint64) (uint64, error) { + if max > 10 { + max = 10 + } + + len := uint64(0) + for { + if len >= max { + return 0, &SyntaxError{"varuint too large", b.pos - len} + } + + c, err := b.read1() + if err != nil { + return 0, err + } + + len++ + + if c&0x80 != 0 { + return len, nil + } + } +} + +// Remaining returns the number of bytes remaining in the current container. +func (b *bitstream) remaining() uint64 { + if b.stack.empty() { + return math.MaxUint64 + } + + end := b.stack.peek().end + if b.pos > end { + panic(fmt.Sprintf("pos (%v) > end (%v)", b.pos, end)) + } + + return end - b.pos +} + +// ReadVarIntLen reads a variable-length-encoded int of at most max bytes, +// returning the value and its actual length in bytes +func (b *bitstream) readVarIntLen(max uint64) (int64, uint64, error) { + if max == 0 { + return 0, 0, &SyntaxError{"varint too large", b.pos} + } + if max > 10 { + max = 10 + } + + // Read the first byte, which contains the sign bit. + c, err := b.read1() + if err != nil { + return 0, 0, err + } + + sign := int64(1) + if c&0x40 != 0 { + sign = -1 + } + + val := int64(c & 0x3F) + len := uint64(1) + + // Check if that was the last (only) byte. + if c&0x80 != 0 { + return val * sign, len, nil + } + + for { + if len >= max { + return 0, 0, &SyntaxError{"varint too large", b.pos - len} + } + + c, err := b.read1() + if err != nil { + return 0, 0, err + } + + val <<= 7 + val ^= int64(c & 0x7F) + len++ + + if c&0x80 != 0 { + return val * sign, len, nil + } + } +} + +// StateAfterValue returns the state this stream is in after reading a value. +func (b *bitstream) stateAfterValue() bss { + if b.stack.peek().code == bitcodeStruct { + return bssBeforeFieldID + } + return bssBeforeValue +} + +var bitcodes = []bitcode{ + bitcodeNull, // 0x00 + bitcodeFalse, // 0x10 + bitcodeInt, // 0x20 + bitcodeNegInt, // 0x30 + bitcodeFloat, // 0x40 + bitcodeDecimal, // 0x50 + bitcodeTimestamp, // 0x60 + bitcodeSymbol, // 0x70 + bitcodeString, // 0x80 + bitcodeClob, // 0x90 + bitcodeBlob, // 0xA0 + bitcodeList, // 0xB0 + bitcodeSexp, // 0xC0 + bitcodeStruct, // 0xD0 + bitcodeAnnotation, // 0xE0 +} + +// ParseTag parses a tag byte into a typecode and a length. +func parseTag(c int) (bitcode, uint64) { + high := (c >> 4) & 0x0F + low := c & 0x0F + + code := bitcodeNone + if high < len(bitcodes) { + code = bitcodes[high] + } + + return code, uint64(low) +} + +// ReadN reads the next n bytes of input from the underlying stream. +func (b *bitstream) readN(n uint64) ([]byte, error) { + if n == 0 { + return nil, nil + } + + bs := make([]byte, n) + actual, err := b.in.Read(bs) + b.pos += uint64(actual) + + if err == io.EOF { + return nil, &UnexpectedEOFError{b.pos} + } + if err != nil { + return nil, &IOError{err} + } + + return bs, nil +} + +// Read1 reads the next byte of input from the underlying stream, returning +// an UnexpectedEOFError if it's an EOF. +func (b *bitstream) read1() (int, error) { + c, err := b.read() + if err != nil { + return 0, err + } + if c == -1 { + return 0, &UnexpectedEOFError{b.pos} + } + return c, nil +} + +// Read reads the next byte of input from the underlying stream. It returns +// -1 instead of io.EOF if we've hit the end of the stream, because I find +// that easier to reason about. +func (b *bitstream) read() (int, error) { + c, err := b.in.ReadByte() + b.pos++ + + if err == io.EOF { + return -1, nil + } + if err != nil { + return 0, &IOError{err} + } + + return int(c), nil +} + +// Skip skips n bytes of input from the underlying stream. +func (b *bitstream) skip(n uint64) error { + actual, err := b.in.Discard(int(n)) + b.pos += uint64(actual) + + if err == io.EOF { + return nil + } + if err != nil { + return &IOError{err} + } + + return nil +} + +// A bitnode represents a container value, including its typecode and +// the offset at which it (supposedly) ends. +type bitnode struct { + code bitcode + end uint64 +} + +// A stack of bitnodes representing container values that we're currently +// stepped in to. +type bitstack struct { + arr []bitnode +} + +// Empty returns true if this bitstack is empty. +func (b *bitstack) empty() bool { + return len(b.arr) == 0 +} + +// Peek peeks at the top bitnode on the stack. +func (b *bitstack) peek() bitnode { + if len(b.arr) == 0 { + return bitnode{} + } + return b.arr[len(b.arr)-1] +} + +// Push pushes a bitnode onto the stack. +func (b *bitstack) push(code bitcode, end uint64) { + b.arr = append(b.arr, bitnode{code, end}) +} + +// Pop pops a bitnode from the stack. +func (b *bitstack) pop() { + if len(b.arr) == 0 { + panic("pop called on empty bitstack") + } + b.arr = b.arr[:len(b.arr)-1] +} diff --git a/ion/bitstream_test.go b/ion/bitstream_test.go new file mode 100644 index 00000000..8cbf57ab --- /dev/null +++ b/ion/bitstream_test.go @@ -0,0 +1,111 @@ +package ion + +import "testing" + +func TestBitstream(t *testing.T) { + ion := []byte{ + 0xE0, 0x01, 0x00, 0xEA, // $ion_1_0 + 0xEE, 0x9F, 0x81, 0x83, 0xDE, 0x9B, // $ion_symbol_table::{ + 0x86, 0xBE, 0x8E, // imports:[ + 0xDD, // { + 0x84, 0x85, 'b', 'o', 'g', 'u', 's', // name: "bogus" + 0x85, 0x21, 0x2A, // version: 42 + 0x88, 0x21, 0x64, // max_id: 100 + // }] + 0x87, 0xB8, // symbols: [ + 0x83, 'f', 'o', 'o', // "foo" + 0x83, 'b', 'a', 'r', // "bar" + // ] + // } + 0xD0, // {} + 0xEA, 0x81, 0xEE, 0xD7, // foo::{ + 0x84, 0xE3, 0x81, 0xEF, 0x0F, // name:bar::null, + 0x88, 0x20, // max_id:0 + // } + } + + b := bitstream{} + b.InitBytes(ion) + + next := func(code bitcode, null bool, len uint64) { + if err := b.Next(); err != nil { + t.Fatal(err) + } + if b.Code() != code { + t.Errorf("expected code=%v, got %v", code, b.Code()) + } + if b.IsNull() != null { + t.Errorf("expected null=%v, got %v", null, b.IsNull()) + } + if b.Len() != len { + t.Errorf("expected len=%v, got %v", len, b.Len()) + } + } + + fieldid := func(eid uint64) { + id, err := b.ReadFieldID() + if err != nil { + t.Fatal(err) + } + if id != eid { + t.Errorf("expected %v, got %v", eid, id) + } + } + + next(bitcodeBVM, false, 3) + maj, min, err := b.ReadBVM() + if err != nil { + t.Fatal(err) + } + if maj != 1 && min != 0 { + t.Errorf("expected $ion_1.0, got $ion_%v.%v", maj, min) + } + + next(bitcodeAnnotation, false, 31) + ids, err := b.ReadAnnotationIDs() + if err != nil { + t.Fatal(err) + } + if len(ids) != 1 || ids[0] != 3 { // $ion_symbol_table + t.Errorf("expected [3], got %v", ids) + } + + next(bitcodeStruct, false, 27) + b.StepIn() + { + next(bitcodeFieldID, false, 0) + fieldid(6) // imports + + next(bitcodeList, false, 14) + b.StepIn() + { + next(bitcodeStruct, false, 13) + } + if err := b.StepOut(); err != nil { + t.Fatal(err) + } + + next(bitcodeFieldID, false, 0) + // fieldid(7) // symbols + + next(bitcodeList, false, 8) + next(bitcodeEOF, false, 0) + } + if err := b.StepOut(); err != nil { + t.Fatal(err) + } + + next(bitcodeStruct, false, 0) + next(bitcodeAnnotation, false, 10) + next(bitcodeEOF, false, 0) + next(bitcodeEOF, false, 0) +} + +func TestBitcodeString(t *testing.T) { + for i := bitcodeNone; i <= bitcodeAnnotation+1; i++ { + str := i.String() + if str == "" { + t.Errorf("expected non-empty string for bitcode %v", uint8(i)) + } + } +} diff --git a/ion/buf.go b/ion/buf.go new file mode 100644 index 00000000..b4a30fe1 --- /dev/null +++ b/ion/buf.go @@ -0,0 +1,119 @@ +package ion + +import ( + "io" +) + +// Writing binary ion is a bit tricky: values are preceded by their length, +// which can be hard to predict until we've actually written out the value. +// To make matters worse, we can't predict the length of the /length/ ahead +// of time in order to reserve space for it, because it uses a variable-length +// encoding. To avoid copying bytes around all over the place, we write into +// an in-memory tree structure, which we then blast out to the actual io.Writer +// once all the relevant lengths are known. + +// A bufnode is a node in the partially-serialized tree. +type bufnode interface { + Len() uint64 + EmitTo(w io.Writer) error +} + +// A bufseq is a bufnode that's also an appendable sequence of bufnodes. +type bufseq interface { + bufnode + Append(n bufnode) +} + +var _ bufnode = atom([]byte{}) +var _ bufseq = &datagram{} +var _ bufseq = &container{} + +// An atom is a value that has been fully serialized and can be emitted directly. +type atom []byte + +func (a atom) Len() uint64 { + return uint64(len(a)) +} + +func (a atom) EmitTo(w io.Writer) error { + _, err := w.Write(a) + return err +} + +// A datagram is a sequence of nodes that will be emitted one +// after another. Most notably, used to buffer top-level values +// when we haven't yet finalized the local symbol table. +type datagram struct { + len uint64 + children []bufnode +} + +func (d *datagram) Append(n bufnode) { + d.len += n.Len() + d.children = append(d.children, n) +} + +func (d *datagram) Len() uint64 { + return d.len +} + +func (d *datagram) EmitTo(w io.Writer) error { + for _, child := range d.children { + if err := child.EmitTo(w); err != nil { + return err + } + } + + return nil +} + +// A container is a datagram that's preceeded by a code+length tag. +type container struct { + code byte + datagram +} + +func (c *container) Len() uint64 { + if c.len < 0x0E { + return c.len + 1 + } + return c.len + (varUintLen(c.len) + 1) +} + +func (c *container) EmitTo(w io.Writer) error { + var arr [11]byte + buf := arr[:0] + buf = appendTag(buf, c.code, c.len) + + if _, err := w.Write(buf); err != nil { + return err + } + return c.datagram.EmitTo(w) +} + +// A bufstack is a stack of bufseqs, more or less matching the +// stack of BeginList/Sexp/Struct calls made on a binaryWriter. +// The top of the stack is the sequence we're currently writing +// values into; when it's popped off, it will be appended to the +// bufseq below it. +type bufstack struct { + arr []bufseq +} + +func (s *bufstack) peek() bufseq { + if len(s.arr) == 0 { + return nil + } + return s.arr[len(s.arr)-1] +} + +func (s *bufstack) push(b bufseq) { + s.arr = append(s.arr, b) +} + +func (s *bufstack) pop() { + if len(s.arr) == 0 { + panic("pop called on an empty stack") + } + s.arr = s.arr[:len(s.arr)-1] +} diff --git a/ion/buf_test.go b/ion/buf_test.go new file mode 100644 index 00000000..d4b6a308 --- /dev/null +++ b/ion/buf_test.go @@ -0,0 +1,79 @@ +package ion + +import ( + "bytes" + "testing" +) + +func TestBufnode(t *testing.T) { + root := container{code: 0xE0} + root.Append(atom([]byte{0x81, 0x83})) + { + symtab := &container{code: 0xD0} + { + symtab.Append(atom([]byte{0x86})) // varUint(6) + { + imps := &container{code: 0xB0} + { + imp0 := &container{code: 0xD0} + { + imp0.Append(atom([]byte{0x84})) // varUint(4) + imp0.Append(atom([]byte{0x85, 'b', 'o', 'g', 'u', 's'})) + imp0.Append(atom([]byte{0x85})) // varUint(5) + imp0.Append(atom([]byte{0x21, 0x2A})) + imp0.Append(atom([]byte{0x88})) // varUint(8) + imp0.Append(atom([]byte{0x21, 0x64})) + } + imps.Append(imp0) + } + symtab.Append(imps) + } + + symtab.Append(atom([]byte{0x87})) // varUint(7) + { + syms := &container{code: 0xB0} + { + syms.Append(atom([]byte{0x83, 'f', 'o', 'o'})) + syms.Append(atom([]byte{0x83, 'b', 'a', 'r'})) + } + symtab.Append(syms) + } + } + root.Append(symtab) + } + + buf := bytes.Buffer{} + if err := root.EmitTo(&buf); err != nil { + t.Fatal(err) + } + + val := buf.Bytes() + eval := []byte{ + // $ion_symbol_table::{ + 0xEE, 0x9F, 0x81, 0x83, 0xDE, 0x9B, + // imports:[ + 0x86, 0xBE, 0x8E, + // { + 0xDD, + // name: "bogus" + 0x84, 0x85, 'b', 'o', 'g', 'u', 's', + // version: 42 + 0x85, 0x21, 0x2A, + // max_id: 100 + 0x88, 0x21, 0x64, + // } + // ], + // symbols:[ + 0x87, 0xB8, + // "foo", + 0x83, 'f', 'o', 'o', + // "bar" + 0x83, 'b', 'a', 'r', + // ] + // } + } + + if !bytes.Equal(val, eval) { + t.Errorf("expected %v, got %v", fmtbytes(eval), fmtbytes(val)) + } +} diff --git a/ion/catalog.go b/ion/catalog.go new file mode 100644 index 00000000..65d4f88f --- /dev/null +++ b/ion/catalog.go @@ -0,0 +1,88 @@ +package ion + +import ( + "bytes" + "fmt" + "io" + "strings" +) + +// A Catalog provides access to shared symbol tables. +type Catalog interface { + FindExact(name string, version int) SharedSymbolTable + FindLatest(name string) SharedSymbolTable +} + +// A basicCatalog wraps an in-memory collection of shared symbol tables. +type basicCatalog struct { + ssts map[string]SharedSymbolTable + latest map[string]SharedSymbolTable +} + +// NewCatalog creates a new basic catalog containing the given symbol tables. +func NewCatalog(ssts ...SharedSymbolTable) Catalog { + cat := &basicCatalog{ + ssts: make(map[string]SharedSymbolTable), + latest: make(map[string]SharedSymbolTable), + } + for _, sst := range ssts { + cat.add(sst) + } + return cat +} + +// Add adds a shared symbol table to the catalog. +func (c *basicCatalog) add(sst SharedSymbolTable) { + key := fmt.Sprintf("%v/%v", sst.Name(), sst.Version()) + c.ssts[key] = sst + + cur, ok := c.latest[sst.Name()] + if !ok || sst.Version() > cur.Version() { + c.latest[sst.Name()] = sst + } +} + +// FindExact attempts to find a shared symbol table with the given name and version. +func (c *basicCatalog) FindExact(name string, version int) SharedSymbolTable { + key := fmt.Sprintf("%v/%v", name, version) + return c.ssts[key] +} + +// FindLatest finds the shared symbol table with the given name and largest version. +func (c *basicCatalog) FindLatest(name string) SharedSymbolTable { + return c.latest[name] +} + +// A System is a reader factory wrapping a catalog. +type System struct { + Catalog Catalog +} + +// NewReader creates a new reader using this system's catalog. +func (s System) NewReader(in io.Reader) Reader { + return NewReaderCat(in, s.Catalog) +} + +// NewReaderStr creates a new reader using this system's catalog. +func (s System) NewReaderStr(in string) Reader { + return NewReaderCat(strings.NewReader(in), s.Catalog) +} + +// NewReaderBytes creates a new reader using this system's catalog. +func (s System) NewReaderBytes(in []byte) Reader { + return NewReaderCat(bytes.NewReader(in), s.Catalog) +} + +// Unmarshal unmarshals Ion data using this system's catalog. +func (s System) Unmarshal(data []byte, v interface{}) error { + r := s.NewReaderBytes(data) + d := NewDecoder(r) + return d.DecodeTo(v) +} + +// UnmarshalStr unmarshals Ion data using this system's catalog. +func (s System) UnmarshalStr(data string, v interface{}) error { + r := s.NewReaderStr(data) + d := NewDecoder(r) + return d.DecodeTo(v) +} diff --git a/ion/catalog_test.go b/ion/catalog_test.go new file mode 100644 index 00000000..2e3bd2e0 --- /dev/null +++ b/ion/catalog_test.go @@ -0,0 +1,62 @@ +package ion + +import ( + "bytes" + "fmt" + "testing" +) + +type Item struct { + ID int `json:"id"` + Name string `json:"name"` + Description string `json:"description"` +} + +func TestCatalog(t *testing.T) { + sst := NewSharedSymbolTable("item", 1, []string{ + "item", + "id", + "name", + "description", + }) + + buf := bytes.Buffer{} + out := NewBinaryWriter(&buf, sst) + + for i := 0; i < 10; i++ { + out.Annotation("item") + MarshalTo(out, &Item{ + ID: i, + Name: fmt.Sprintf("Item %v", i), + Description: fmt.Sprintf("The %vth test item", i), + }) + } + if err := out.Finish(); err != nil { + t.Fatal(err) + } + + bs := buf.Bytes() + + sys := System{Catalog: NewCatalog(sst)} + in := sys.NewReaderBytes(bs) + + i := 0 + for ; ; i++ { + item := Item{} + err := UnmarshalFrom(in, &item) + if err == ErrNoInput { + break + } + if err != nil { + t.Fatal(err) + } + + if item.ID != i { + t.Errorf("expected id=%v, got %v", i, item.ID) + } + } + + if i != 10 { + t.Errorf("expected i=10, got %v", i) + } +} diff --git a/ion/consts.go b/ion/consts.go new file mode 100644 index 00000000..1a49dec4 --- /dev/null +++ b/ion/consts.go @@ -0,0 +1,52 @@ +package ion + +import ( + "reflect" + "time" +) + +var binaryNulls = func() []byte { + ret := make([]byte, StructType+1) + ret[NoType] = 0x0F + ret[NullType] = 0x0F + ret[BoolType] = 0x1F + ret[IntType] = 0x2F + ret[FloatType] = 0x4F + ret[DecimalType] = 0x5F + ret[TimestampType] = 0x6F + ret[SymbolType] = 0x7F + ret[StringType] = 0x8F + ret[ClobType] = 0x9F + ret[BlobType] = 0xAF + ret[ListType] = 0xBF + ret[SexpType] = 0xCF + ret[StructType] = 0xDF + return ret +}() + +var textNulls []string = func() []string { + ret := make([]string, StructType+1) + ret[NoType] = "null" + ret[NullType] = "null.null" + ret[BoolType] = "null.bool" + ret[IntType] = "null.int" + ret[FloatType] = "null.float" + ret[DecimalType] = "null.decimal" + ret[TimestampType] = "null.timestamp" + ret[SymbolType] = "null.symbol" + ret[StringType] = "null.string" + ret[ClobType] = "null.clob" + ret[BlobType] = "null.blob" + ret[ListType] = "null.list" + ret[SexpType] = "null.sexp" + ret[StructType] = "null.struct" + return ret +}() + +var hexChars = []byte{ + '0', '1', '2', '3', '4', '5', '6', '7', + '8', '9', 'A', 'B', 'C', 'D', 'E', 'F', +} + +var timeType = reflect.TypeOf(time.Time{}) +var decimalType = reflect.TypeOf(Decimal{}) diff --git a/ion/ctx.go b/ion/ctx.go new file mode 100644 index 00000000..d3a1b9ed --- /dev/null +++ b/ion/ctx.go @@ -0,0 +1,65 @@ +package ion + +import "fmt" + +// ctx is the current reader or writer context. +type ctx uint8 + +const ( + ctxAtTopLevel ctx = iota + ctxInStruct + ctxInList + ctxInSexp +) + +func ctxToContainerType(c ctx) Type { + switch c { + case ctxInList: + return ListType + case ctxInSexp: + return SexpType + case ctxInStruct: + return StructType + default: + return NoType + } +} + +func containerTypeToCtx(t Type) ctx { + switch t { + case ListType: + return ctxInList + case SexpType: + return ctxInSexp + case StructType: + return ctxInStruct + default: + panic(fmt.Sprintf("type %v is not a container type", t)) + } +} + +// ctxstack is a context stack. +type ctxstack struct { + arr []ctx +} + +// peek returns the current context. +func (c *ctxstack) peek() ctx { + if len(c.arr) == 0 { + return ctxAtTopLevel + } + return c.arr[len(c.arr)-1] +} + +// push pushes a new context onto the stack. +func (c *ctxstack) push(ctx ctx) { + c.arr = append(c.arr, ctx) +} + +// pop pops the top context off the stack. +func (c *ctxstack) pop() { + if len(c.arr) == 0 { + panic("pop called at top level") + } + c.arr = c.arr[:len(c.arr)-1] +} diff --git a/ion/decimal.go b/ion/decimal.go new file mode 100644 index 00000000..1b522286 --- /dev/null +++ b/ion/decimal.go @@ -0,0 +1,342 @@ +package ion + +import ( + "fmt" + "math" + "math/big" + "strconv" + "strings" +) + +// A ParseError is returned if ParseDecimal is called with a parameter that +// cannot be parsed as a Decimal. +type ParseError struct { + Num string + Msg string +} + +func (e *ParseError) Error() string { + return fmt.Sprintf("ion: ParseDecimal(%v): %v", e.Num, e.Msg) +} + +// TODO: Explicitly track precision? + +// Decimal is an arbitrary-precision decimal value. +type Decimal struct { + n *big.Int + scale int32 +} + +// NewDecimal creates a new decimal whose value is equal to n * 10^exp. +func NewDecimal(n *big.Int, exp int32) *Decimal { + return &Decimal{ + n: n, + scale: -exp, + } +} + +// NewDecimalInt creates a new decimal whose value is equal to n. +func NewDecimalInt(n int64) *Decimal { + return NewDecimal(big.NewInt(n), 0) +} + +// MustParseDecimal parses the given string into a decimal object, +// panicing on error. +func MustParseDecimal(in string) *Decimal { + d, err := ParseDecimal(in) + if err != nil { + panic(err) + } + return d +} + +// ParseDecimal parses the given string into a decimal object, +// returning an error on failure. +func ParseDecimal(in string) (*Decimal, error) { + if len(in) == 0 { + return nil, &ParseError{in, "empty string"} + } + + exponent := int32(0) + + d := strings.IndexAny(in, "Dd") + if d != -1 { + // There's an explicit exponent. + exp := in[d+1:] + if len(exp) == 0 { + return nil, &ParseError{in, "unexpected end of input after d"} + } + + tmp, err := strconv.ParseInt(exp, 10, 32) + if err != nil { + return nil, &ParseError{in, err.Error()} + } + + exponent = int32(tmp) + in = in[:d] + } + + d = strings.Index(in, ".") + if d != -1 { + // There's zero or more decimal places. + ipart := in[:d] + fpart := in[d+1:] + + exponent -= int32(len(fpart)) + in = ipart + fpart + } + + n, ok := new(big.Int).SetString(in, 10) + if !ok { + // Unfortunately this is all we get? + return nil, &ParseError{in, "cannot parse coefficient"} + } + + return NewDecimal(n, exponent), nil +} + +// CoEx returns this decimal's coefficient and exponent. +func (d *Decimal) CoEx() (*big.Int, int32) { + return d.n, -d.scale +} + +// Abs returns the absolute value of this Decimal. +func (d *Decimal) Abs() *Decimal { + return &Decimal{ + n: new(big.Int).Abs(d.n), + scale: d.scale, + } +} + +// Add returns the result of adding this Decimal to another Decimal. +func (d *Decimal) Add(o *Decimal) *Decimal { + // a*10^x + b*10^y = (a*10^(x-y) + b) * 10^y + dd, oo := rescale(d, o) + return &Decimal{ + n: new(big.Int).Add(dd.n, oo.n), + scale: dd.scale, + } +} + +// Sub returns the result of substrating another Decimal from this Decimal. +func (d *Decimal) Sub(o *Decimal) *Decimal { + dd, oo := rescale(d, o) + return &Decimal{ + n: new(big.Int).Sub(dd.n, oo.n), + scale: dd.scale, + } +} + +// Neg returns the negative of this Decimal. +func (d *Decimal) Neg() *Decimal { + return &Decimal{ + n: new(big.Int).Neg(d.n), + scale: d.scale, + } +} + +// Mul multiplies two decimals and returns the result. +func (d *Decimal) Mul(o *Decimal) *Decimal { + // a*10^x * b*10^y = (a*b) * 10^(x+y) + scale := int64(d.scale) + int64(o.scale) + if scale > math.MaxInt32 || scale < math.MinInt32 { + panic("exponent out of bounds") + } + + return &Decimal{ + n: new(big.Int).Mul(d.n, o.n), + scale: int32(scale), + } +} + +// ShiftL returns a new decimal shifted the given number of decimal +// places to the left. It's a computationally-cheap way to compute +// d * 10^shift. +func (d *Decimal) ShiftL(shift int) *Decimal { + scale := int64(d.scale) - int64(shift) + if scale > math.MaxInt32 || scale < math.MinInt32 { + panic("exponent out of bounds") + } + + return &Decimal{ + n: d.n, + scale: int32(scale), + } +} + +// ShiftR returns a new decimal shifted the given number of decimal +// places to the right. It's a computationally-cheap way to compute +// d / 10^shift. +func (d *Decimal) ShiftR(shift int) *Decimal { + scale := int64(d.scale) + int64(shift) + if scale > math.MaxInt32 || scale < math.MinInt32 { + panic("exponent out of bounds") + } + + return &Decimal{ + n: d.n, + scale: int32(scale), + } +} + +// TODO: Div, Exp, etc? + +// Sign returns -1 if the value is less than 0, 0 if it is equal to zero, +// and +1 if it is greater than zero. +func (d *Decimal) Sign() int { + return d.n.Sign() +} + +// Cmp compares two decimals, returning -1 if d is smaller, +1 if d is +// larger, and 0 if they are equal (ignoring precision). +func (d *Decimal) Cmp(o *Decimal) int { + dd, oo := rescale(d, o) + return dd.n.Cmp(oo.n) +} + +// Equal determines if two decimals are equal (discounting precision, +// at least for now). +func (d *Decimal) Equal(o *Decimal) bool { + return d.Cmp(o) == 0 +} + +func rescale(a, b *Decimal) (*Decimal, *Decimal) { + if a.scale < b.scale { + return a.upscale(b.scale), b + } else if a.scale > b.scale { + return a, b.upscale(a.scale) + } else { + return a, b + } +} + +var ten = big.NewInt(10) + +// Make 'n' bigger by making 'scale' smaller, since we know we can +// do that. (1d100 -> 10d99). Makes comparisons and math easier, at the +// expense of more storage space. Technically speaking implies adding +// more precision, but we're not tracking that too closely. +func (d *Decimal) upscale(scale int32) *Decimal { + diff := int64(scale) - int64(d.scale) + if diff < 0 { + panic("can't upscale to a smaller scale") + } + + pow := new(big.Int).Exp(ten, big.NewInt(diff), nil) + n := new(big.Int).Mul(d.n, pow) + + return &Decimal{ + n: n, + scale: scale, + } +} + +// Trunc attempts to truncate this decimal to an int64, dropping any fractional bits. +func (d *Decimal) Trunc() (int64, error) { + if d.scale < 0 { + // Don't even bother trying this with numbers that *definitely* too big to represent + // as an int64, because upscale(0) will consume a bunch of memory. + if d.scale < -20 { + return 0, &strconv.NumError{ + Func: "ParseInt", + Num: d.String(), + Err: strconv.ErrRange, + } + } + d = d.upscale(0) + } + + str := d.n.String() + + want := len(str) - int(d.scale) + if want <= 0 { + return 0, nil + } + + return strconv.ParseInt(str[:want], 10, 64) +} + +// Truncate returns a new decimal, truncated to the given number of +// decimal digits of precision. It does not round, so 19.Truncate(1) +// = 1d1. +func (d *Decimal) Truncate(precision int) *Decimal { + if precision <= 0 { + panic("precision must be positive") + } + + // Is there a better way to calculate precision? It really + // seems like there should be... + + str := d.n.String() + if str[0] == '-' { + // Cheating a bit. + precision++ + } + + diff := len(str) - precision + if diff <= 0 { + // Already small enough, nothing to truncate. + return d + } + + // Lazy man's division by a power of 10. + n, ok := new(big.Int).SetString(str[:precision], 10) + if !ok { + // Should never happen, since we started with a valid int. + panic("failed to parse integer") + } + + scale := int64(d.scale) - int64(diff) + if scale < math.MinInt32 { + panic("exponent out of range") + } + + return &Decimal{ + n: n, + scale: int32(scale), + } +} + +// String formats the decimal as a string in Ion text format. +func (d *Decimal) String() string { + switch { + case d.scale == 0: + // Value is an unscaled integer. Just mark it as a decimal. + return d.n.String() + "." + + case d.scale < 0: + // Value is a upscaled integer, nn'd'ss + return d.n.String() + "d" + fmt.Sprintf("%d", -d.scale) + + default: + // Value is a downscaled integer nn.nn('d'-ss)? + str := d.n.String() + idx := len(str) - int(d.scale) + + prefix := 1 + if d.n.Sign() < 0 { + // Account for leading '-'. + prefix++ + } + + if idx >= prefix { + // Put the decimal point in the middle, no exponent. + return str[:idx] + "." + str[idx:] + } + + // Put the decimal point at the beginning and + // add a (negative) exponent. + b := strings.Builder{} + b.WriteString(str[:prefix]) + + if len(str) > prefix { + b.WriteString(".") + b.WriteString(str[prefix:]) + } + + b.WriteString("d") + b.WriteString(fmt.Sprintf("%d", idx-prefix)) + + return b.String() + } +} diff --git a/ion/decimal_test.go b/ion/decimal_test.go new file mode 100644 index 00000000..21b72c9e --- /dev/null +++ b/ion/decimal_test.go @@ -0,0 +1,312 @@ +package ion + +import ( + "fmt" + "math/big" + "testing" +) + +func TestDecimalToString(t *testing.T) { + test := func(n int64, scale int32, expected string) { + t.Run(expected, func(t *testing.T) { + d := Decimal{ + n: big.NewInt(n), + scale: scale, + } + actual := d.String() + if actual != expected { + t.Errorf("expected '%v', got '%v'", expected, actual) + } + }) + } + + test(0, 0, "0.") + test(0, -1, "0d1") + test(0, 1, "0d-1") + + test(1, 0, "1.") + test(1, -1, "1d1") + test(1, 1, "1d-1") + + test(-1, 0, "-1.") + test(-1, -1, "-1d1") + test(-1, 1, "-1d-1") + + test(123, 0, "123.") + test(-456, 0, "-456.") + + test(123, -5, "123d5") + test(-456, -5, "-456d5") + + test(123, 1, "12.3") + test(123, 2, "1.23") + test(123, 3, "1.23d-1") + test(123, 4, "1.23d-2") + + test(-456, 1, "-45.6") + test(-456, 2, "-4.56") + test(-456, 3, "-4.56d-1") + test(-456, 4, "-4.56d-2") +} + +func TestParseDecimal(t *testing.T) { + test := func(in string, n *big.Int, scale int32) { + t.Run(in, func(t *testing.T) { + d, err := ParseDecimal(in) + if err != nil { + t.Fatal(err) + } + + if n.Cmp(d.n) != 0 { + t.Errorf("wrong n; expected %v, got %v", n, d.n) + } + if scale != d.scale { + t.Errorf("wrong scale; expected %v, got %v", scale, d.scale) + } + }) + } + + test("0", big.NewInt(0), 0) + test("-0", big.NewInt(0), 0) + test("0D0", big.NewInt(0), 0) + test("-0d-1", big.NewInt(0), 1) + + test("1.", big.NewInt(1), 0) + test("1.0", big.NewInt(10), 1) + test("0.123", big.NewInt(123), 3) + + test("1d0", big.NewInt(1), 0) + test("1d1", big.NewInt(1), -1) + test("1d+1", big.NewInt(1), -1) + test("1d-1", big.NewInt(1), 1) + + test("-0.12d4", big.NewInt(-12), -2) +} + +func absF(d *Decimal) *Decimal { return d.Abs() } +func negF(d *Decimal) *Decimal { return d.Neg() } + +type unaryop struct { + sym string + fun func(d *Decimal) *Decimal +} + +var abs = &unaryop{"abs", absF} +var neg = &unaryop{"neg", negF} + +func testUnaryOp(t *testing.T, a, e string, op *unaryop) { + t.Run(op.sym+"("+a+")="+e, func(t *testing.T) { + aa, _ := ParseDecimal(a) + ee, _ := ParseDecimal(e) + actual := op.fun(aa) + if !actual.Equal(ee) { + t.Errorf("expected %v, got %v", ee, actual) + } + }) +} + +func TestAbs(t *testing.T) { + test := func(a, e string) { + testUnaryOp(t, a, e, abs) + } + + test("0", "0") + test("1d100", "1d100") + test("-1d100", "1d100") + test("1.2d-3", "1.2d-3") + test("-1.2d-3", "1.2d-3") +} + +func TestNeg(t *testing.T) { + test := func(a, e string) { + testUnaryOp(t, a, e, neg) + } + + test("0", "0") + test("1d100", "-1d100") + test("-1d100", "1d100") + test("1.2d-3", "-1.2d-3") + test("-1.2d-3", "1.2d-3") +} + +func TestTrunc(t *testing.T) { + test := func(a string, eval int64) { + t.Run(fmt.Sprintf("trunc(%v)=%v", a, eval), func(t *testing.T) { + aa := MustParseDecimal(a) + val, err := aa.Trunc() + if err != nil { + t.Fatal(err) + } + if val != eval { + t.Errorf("expected %v, got %v", eval, val) + } + }) + } + + test("0.", 0) + test("0.01", 0) + test("1.", 1) + test("-1.", -1) + test("1.01", 1) + test("-1.01", -1) + test("101", 101) + test("1d3", 1000) +} + +func addF(a, b *Decimal) *Decimal { return a.Add(b) } +func subF(a, b *Decimal) *Decimal { return a.Sub(b) } +func mulF(a, b *Decimal) *Decimal { return a.Mul(b) } + +type binop struct { + sym string + fun func(a, b *Decimal) *Decimal +} + +func TestShiftL(t *testing.T) { + test := func(a string, b int, e string) { + aa, _ := ParseDecimal(a) + ee, _ := ParseDecimal(e) + actual := aa.ShiftL(b) + if !actual.Equal(ee) { + t.Errorf("expected %v, got %v", ee, actual) + } + } + + test("0", 10, "0") + test("1", 0, "1") + test("123", 1, "1230") + test("123", 100, "123d100") + test("1.23d-100", 102, "123") +} + +func TestShiftR(t *testing.T) { + test := func(a string, b int, e string) { + aa, _ := ParseDecimal(a) + ee, _ := ParseDecimal(e) + actual := aa.ShiftR(b) + if !actual.Equal(ee) { + t.Errorf("expected %v, got %v", ee, actual) + } + } + + test("0", 10, "0") + test("1", 0, "1") + test("123", 1, "12.3") + test("123", 100, "1.23d-98") + test("1.23d100", 98, "123") +} + +var add = &binop{"+", addF} +var sub = &binop{"-", subF} +var mul = &binop{"*", mulF} + +func testBinaryOp(t *testing.T, a, b, e string, op *binop) { + t.Run(a+op.sym+b+"="+e, func(t *testing.T) { + aa, _ := ParseDecimal(a) + bb, _ := ParseDecimal(b) + ee, _ := ParseDecimal(e) + + actual := op.fun(aa, bb) + if !actual.Equal(ee) { + t.Errorf("expected %v, got %v", ee, actual) + } + }) +} + +func TestAdd(t *testing.T) { + test := func(a, b, e string) { + testBinaryOp(t, a, b, e, add) + } + + test("1", "0", "1") + test("1", "1", "2") + test("1", "0.1", "1.1") + test("0.3", "0.06", "0.36") + test("1", "100", "101") + test("1d100", "1d98", "101d98") + test("1d-100", "1d-98", "1.01d-98") +} + +func TestSub(t *testing.T) { + test := func(a, b, e string) { + testBinaryOp(t, a, b, e, sub) + } + + test("1", "0", "1") + test("1", "1", "0") + test("1", "0.1", "0.9") + test("0.3", "0.06", "0.24") + test("1", "100", "-99") + test("1d100", "1d98", "99d98") + test("1d-100", "1d-98", "-99d-100") +} + +func TestMul(t *testing.T) { + test := func(a, b, e string) { + testBinaryOp(t, a, b, e, mul) + } + + test("1", "0", "0") + test("1", "1", "1") + test("2", "-1", "-2") + test("7", "6", "42") + test("10", "0.3", "3") + test("3d100", "2d50", "6d150") + test("3d-100", "2d-50", "6d-150") + test("2d100", "4d-98", "8d2") +} + +func TestTruncate(t *testing.T) { + test := func(a string, p int, expected string) { + t.Run(fmt.Sprintf("trunc(%v,%v)", a, p), func(t *testing.T) { + aa := MustParseDecimal(a) + actual := aa.Truncate(p).String() + if actual != expected { + t.Errorf("expected %v, got %v", expected, actual) + } + }) + } + + test("1", 1, "1.") + test("1", 10, "1.") + test("10", 1, "1d1") + test("1999", 1, "1d3") + test("1.2345", 3, "1.23") + test("100d100", 2, "10d101") + test("1.2345d-100", 2, "1.2d-100") +} + +func TestCmp(t *testing.T) { + test := func(a, b string, expected int) { + t.Run("("+a+","+b+")", func(t *testing.T) { + ad, _ := ParseDecimal(a) + bd, _ := ParseDecimal(b) + actual := ad.Cmp(bd) + if actual != expected { + t.Errorf("expected %v, got %v", expected, actual) + } + }) + } + + test("0", "0", 0) + test("0", "1", -1) + test("0", "-1", 1) + + test("1d2", "100", 0) + test("100", "1d2", 0) + test("1d2", "10", 1) + test("10", "1d2", -1) + + test("0.01", "1d-2", 0) + test("1d-2", "0.01", 0) + test("0.01", "1d-3", 1) + test("1d-3", "0.01", -1) +} + +func TestUpscale(t *testing.T) { + d, _ := ParseDecimal("1d1") + actual := d.upscale(4).String() + if actual != "10.0000" { + t.Errorf("expected 10.0000, got %v", actual) + } +} diff --git a/ion/err.go b/ion/err.go new file mode 100644 index 00000000..737076a9 --- /dev/null +++ b/ion/err.go @@ -0,0 +1,88 @@ +package ion + +import "fmt" + +// A UsageError is returned when you use a Reader or Writer in an inappropriate way. +type UsageError struct { + API string + Msg string +} + +func (e *UsageError) Error() string { + return fmt.Sprintf("ion: usage error in %v: %v", e.API, e.Msg) +} + +// An IOError is returned when there is an error reading from or writing to an +// underlying io.Reader or io.Writer. +type IOError struct { + Err error +} + +func (e *IOError) Error() string { + return fmt.Sprintf("ion: i/o error: %v", e.Err) +} + +// A SyntaxError is returned when a Reader encounters invalid input for which no more +// specific error type is defined. +type SyntaxError struct { + Msg string + Offset uint64 +} + +func (e *SyntaxError) Error() string { + return fmt.Sprintf("ion: syntax error: %v (offset %v)", e.Msg, e.Offset) +} + +// An UnexpectedEOFError is returned when a Reader unexpectedly encounters an +// io.EOF error. +type UnexpectedEOFError struct { + Offset uint64 +} + +func (e *UnexpectedEOFError) Error() string { + return fmt.Sprintf("ion: unexpected end of input (offset %v)", e.Offset) +} + +// An UnsupportedVersionError is returned when a Reader encounters a binary version +// marker with a version that this library does not understand. +type UnsupportedVersionError struct { + Major int + Minor int + Offset uint64 +} + +func (e *UnsupportedVersionError) Error() string { + return fmt.Sprintf("ion: unsupported version %v.%v (offset %v)", e.Major, e.Minor, e.Offset) +} + +// An InvalidTagByteError is returned when a binary Reader encounters an invalid +// tag byte. +type InvalidTagByteError struct { + Byte byte + Offset uint64 +} + +func (e *InvalidTagByteError) Error() string { + return fmt.Sprintf("ion: invalid tag byte 0x%02X (offset %v)", e.Byte, e.Offset) +} + +// An UnexpectedRuneError is returned when a text Reader encounters an unexpected rune. +type UnexpectedRuneError struct { + Rune rune + Offset uint64 +} + +func (e *UnexpectedRuneError) Error() string { + return fmt.Sprintf("ion: unexpected rune %q (offset %v)", e.Rune, e.Offset) +} + +// An UnexpectedTokenError is returned when a text Reader encounters an unexpected +// token. +type UnexpectedTokenError struct { + Token string + Offset uint64 +} + +func (e *UnexpectedTokenError) Error() string { + return fmt.Sprintf("ion: unexpected token '%v' (offset %v)", e.Token, e.Offset) +} diff --git a/ion/fields.go b/ion/fields.go new file mode 100644 index 00000000..2f4a8f06 --- /dev/null +++ b/ion/fields.go @@ -0,0 +1,122 @@ +package ion + +import ( + "fmt" + "reflect" + "strings" +) + +// A field is a reflectively-accessed field of a struct type. +type field struct { + name string + typ reflect.Type + path []int + omitEmpty bool +} + +// A fielder maps out the fields of a type. +type fielder struct { + fields []field + index map[string]bool +} + +// FieldsFor returns the fields of the given struct type. +// TODO: cache me. +func fieldsFor(t reflect.Type) []field { + fldr := fielder{index: map[string]bool{}} + fldr.inspect(t, nil) + return fldr.fields +} + +// Inspect recursively inspects a type to determine all of its fields. +func (f *fielder) inspect(t reflect.Type, path []int) { + for i := 0; i < t.NumField(); i++ { + sf := t.Field(i) + if !visible(&sf) { + // Skip non-visible fields. + continue + } + + tag := sf.Tag.Get("json") + if tag == "-" { + // Skip fields that are explicitly hidden by tag. + continue + } + name, opts := parseJSONTag(tag) + + newpath := make([]int, len(path)+1) + copy(newpath, path) + newpath[len(path)] = i + + ft := sf.Type + if ft.Name() == "" && ft.Kind() == reflect.Ptr { + ft = ft.Elem() + } + + if name == "" && sf.Anonymous && ft.Kind() == reflect.Struct { + // Dig in to the embedded struct. + f.inspect(ft, newpath) + } else { + // Add this named field. + if name == "" { + name = sf.Name + } + + if f.index[name] { + panic(fmt.Sprintf("too many fields named %v", name)) + } + f.index[name] = true + + f.fields = append(f.fields, field{ + name: name, + typ: ft, + path: newpath, + omitEmpty: omitEmpty(opts), + }) + } + } +} + +// Visible returns true if the given StructField should show up in the output. +func visible(sf *reflect.StructField) bool { + exported := sf.PkgPath == "" + if sf.Anonymous { + t := sf.Type + if t.Kind() == reflect.Ptr { + t = t.Elem() + } + if t.Kind() == reflect.Struct { + // Fields of embedded structs are visible even if the struct type itself is not. + return true + } + } + return exported +} + +// ParseJSONTag parses a `json:"..."` field tag, returning the name and opts. +func parseJSONTag(tag string) (string, string) { + if idx := strings.Index(tag, ","); idx != -1 { + // Ignore additional JSON options, at least for now. + return tag[:idx], tag[idx+1:] + } + return tag, "" +} + +// OmitEmpty returns true if opts includes "omitempty". +func omitEmpty(opts string) bool { + for opts != "" { + var o string + + i := strings.Index(opts, ",") + if i >= 0 { + o, opts = opts[:i], opts[i+1:] + } else { + o, opts = opts, "" + } + + if o == "omitempty" { + return true + } + } + return false +} diff --git a/ion/marshal.go b/ion/marshal.go new file mode 100644 index 00000000..7a5529d4 --- /dev/null +++ b/ion/marshal.go @@ -0,0 +1,338 @@ +package ion + +import ( + "bytes" + "fmt" + "io" + "math/big" + "reflect" + "sort" + "time" +) + +// EncoderOpts holds bit-flag options for an Encoder. +type EncoderOpts uint + +const ( + // EncodeSortMaps instructs the encoder to write map keys in sorted order. + EncodeSortMaps EncoderOpts = 1 +) + +// MarshalText marshals values to text ion. +func MarshalText(v interface{}) ([]byte, error) { + buf := bytes.Buffer{} + w := NewTextWriterOpts(&buf, TextWriterQuietFinish) + e := Encoder{ + w: w, + opts: EncodeSortMaps, + } + + if err := e.Encode(v); err != nil { + return nil, err + } + if err := e.Finish(); err != nil { + return nil, err + } + + return buf.Bytes(), nil +} + +// MarshalBinary marshals values to binary ion. +func MarshalBinary(v interface{}, ssts ...SharedSymbolTable) ([]byte, error) { + buf := bytes.Buffer{} + w := NewBinaryWriter(&buf, ssts...) + e := Encoder{w: w} + + if err := e.Encode(v); err != nil { + return nil, err + } + if err := e.Finish(); err != nil { + return nil, err + } + + return buf.Bytes(), nil +} + +// MarshalBinaryLST marshals values to binary ion with a fixed local symbol table. +func MarshalBinaryLST(v interface{}, lst SymbolTable) ([]byte, error) { + buf := bytes.Buffer{} + w := NewBinaryWriterLST(&buf, lst) + e := Encoder{w: w} + + if err := e.Encode(v); err != nil { + return nil, err + } + if err := e.Finish(); err != nil { + return nil, err + } + + return buf.Bytes(), nil +} + +// MarshalTo marshals the given value to the given writer. It does +// not call Finish, so is suitable for encoding values inside of +// a partially-constructed Ion value. +func MarshalTo(w Writer, v interface{}) error { + e := Encoder{ + w: w, + } + return e.Encode(v) +} + +// An Encoder writes Ion values to an output stream. +type Encoder struct { + w Writer + opts EncoderOpts +} + +// NewEncoder creates a new encoder. +func NewEncoder(w Writer) *Encoder { + return NewEncoderOpts(w, 0) +} + +// NewEncoderOpts creates a new encoder with the specified options. +func NewEncoderOpts(w Writer, opts EncoderOpts) *Encoder { + return &Encoder{ + w: w, + opts: opts, + } +} + +// NewTextEncoder creates a new text Encoder. +func NewTextEncoder(w io.Writer) *Encoder { + return NewEncoder(NewTextWriter(w)) +} + +// NewBinaryEncoder creates a new binary Encoder. +func NewBinaryEncoder(w io.Writer, ssts ...SharedSymbolTable) *Encoder { + return NewEncoder(NewBinaryWriter(w, ssts...)) +} + +// NewBinaryEncoderLST creates a new binary Encoder with a fixed local symbol table. +func NewBinaryEncoderLST(w io.Writer, lst SymbolTable) *Encoder { + return NewEncoder(NewBinaryWriterLST(w, lst)) +} + +// Encode marshals the given value to Ion, writing it to the underlying writer. +func (m *Encoder) Encode(v interface{}) error { + return m.encodeValue(reflect.ValueOf(v)) +} + +// Finish finishes writing the current Ion datagram. +func (m *Encoder) Finish() error { + return m.w.Finish() +} + +// EncodeValue recursively encodes a value. +func (m *Encoder) encodeValue(v reflect.Value) error { + if !v.IsValid() { + m.w.WriteNull() + return nil + } + + t := v.Type() + switch t.Kind() { + case reflect.Bool: + return m.w.WriteBool(v.Bool()) + + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + return m.w.WriteInt(v.Int()) + + case reflect.Uint8, reflect.Uint16, reflect.Uint32: + return m.w.WriteInt(int64(v.Uint())) + + case reflect.Uint, reflect.Uint64, reflect.Uintptr: + i := big.Int{} + i.SetUint64(v.Uint()) + return m.w.WriteBigInt(&i) + + case reflect.Float32, reflect.Float64: + return m.w.WriteFloat(v.Float()) + + case reflect.String: + return m.w.WriteString(v.String()) + + case reflect.Interface, reflect.Ptr: + return m.encodePtr(v) + + case reflect.Struct: + return m.encodeStruct(v) + + case reflect.Map: + return m.encodeMap(v) + + case reflect.Slice: + return m.encodeSlice(v) + + case reflect.Array: + return m.encodeArray(v) + + default: + return fmt.Errorf("ion: unsupported type: %v", v.Type().String()) + } +} + +// EncodePtr encodes an Ion null if the pointer is nil, and otherwise encodes the value that +// the pointer is pointing to. +func (m *Encoder) encodePtr(v reflect.Value) error { + if v.IsNil() { + return m.w.WriteNull() + } + return m.encodeValue(v.Elem()) +} + +// EncodeMap encodes a map to the output writer as an Ion struct. +func (m *Encoder) encodeMap(v reflect.Value) error { + if v.IsNil() { + return m.w.WriteNull() + } + + m.w.BeginStruct() + + keys := keysFor(v) + if m.opts&EncodeSortMaps != 0 { + sort.Slice(keys, func(i, j int) bool { return keys[i].s < keys[j].s }) + } + + for _, key := range keys { + m.w.FieldName(key.s) + value := v.MapIndex(key.v) + if err := m.encodeValue(value); err != nil { + return err + } + } + + return m.w.EndStruct() +} + +// A mapkey holds the reflective map key value as well as its stringified form. +type mapkey struct { + v reflect.Value + s string +} + +// KeysFor returns the stringified keys for the given map. +func keysFor(v reflect.Value) []mapkey { + keys := v.MapKeys() + res := make([]mapkey, len(keys)) + + for i, key := range keys { + // TODO: Handle other kinds of keys. + if key.Kind() != reflect.String { + panic("unexpected map key type") + } + res[i] = mapkey{ + v: key, + s: key.String(), + } + } + + return res +} + +// EncodeSlice encodes a slice to the output writer as an appropriate Ion type. +func (m *Encoder) encodeSlice(v reflect.Value) error { + if v.Type().Elem().Kind() == reflect.Uint8 { + return m.encodeBlob(v) + } + + if v.IsNil() { + return m.w.WriteNull() + } + + return m.encodeArray(v) +} + +// EncodeBlob encodes a []byte to the output writer as an Ion blob. +func (m *Encoder) encodeBlob(v reflect.Value) error { + if v.IsNil() { + return m.w.WriteNull() + } + return m.w.WriteBlob(v.Bytes()) +} + +// EncodeArray encodes an array to the output writer as an Ion list. +func (m *Encoder) encodeArray(v reflect.Value) error { + m.w.BeginList() + + for i := 0; i < v.Len(); i++ { + if err := m.encodeValue(v.Index(i)); err != nil { + return err + } + } + + return m.w.EndList() +} + +// EncodeStruct encodes a struct to the output writer as an Ion struct. +func (m *Encoder) encodeStruct(v reflect.Value) error { + t := v.Type() + if t == timeType { + return m.encodeTime(v) + } + if t == decimalType { + return m.encodeDecimal(v) + } + + fields := fieldsFor(v.Type()) + + m.w.BeginStruct() + +FieldLoop: + for i := range fields { + f := &fields[i] + + fv := v + for _, i := range f.path { + if fv.Kind() == reflect.Ptr { + if fv.IsNil() { + continue FieldLoop + } + fv = fv.Elem() + } + fv = fv.Field(i) + } + + if f.omitEmpty && emptyValue(fv) { + continue + } + + m.w.FieldName(f.name) + if err := m.encodeValue(fv); err != nil { + return err + } + } + + return m.w.EndStruct() +} + +// EncodeTime encodes a time.Time to the output writer as an Ion timestamp. +func (m *Encoder) encodeTime(v reflect.Value) error { + t := v.Interface().(time.Time) + return m.w.WriteTimestamp(t) +} + +// EncodeDecimal encodes an ion.Decimal to the output writer as an Ion decimal. +func (m *Encoder) encodeDecimal(v reflect.Value) error { + d := v.Addr().Interface().(*Decimal) + return m.w.WriteDecimal(d) +} + +// EmptyValue returns true if the given value is the empty value for its type. +func emptyValue(v reflect.Value) bool { + switch v.Kind() { + case reflect.Array, reflect.Map, reflect.Slice, reflect.String: + return v.Len() == 0 + case reflect.Bool: + return !v.Bool() + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + return v.Int() == 0 + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: + return v.Uint() == 0 + case reflect.Float32, reflect.Float64: + return v.Float() == 0 + case reflect.Interface, reflect.Ptr: + return v.IsNil() + } + return false +} diff --git a/ion/marshal_test.go b/ion/marshal_test.go new file mode 100644 index 00000000..72e8dc7a --- /dev/null +++ b/ion/marshal_test.go @@ -0,0 +1,162 @@ +package ion + +import ( + "bytes" + "math" + "testing" + "time" +) + +func TestMarshalText(t *testing.T) { + test := func(v interface{}, eval string) { + t.Run(eval, func(t *testing.T) { + val, err := MarshalText(v) + if err != nil { + t.Fatal(err) + } + if string(val) != eval { + t.Errorf("expected '%v', got '%v'", eval, string(val)) + } + }) + } + + test(nil, "null") + test(true, "true") + test(false, "false") + + test(byte(42), "42") + test(-42, "-42") + test(uint64(math.MaxUint64), "18446744073709551615") + test(math.MinInt64, "-9223372036854775808") + + test(42.0, "4.2e+1") + test(math.Inf(1), "+inf") + test(math.Inf(-1), "-inf") + test(math.NaN(), "nan") + + test(MustParseDecimal("1.20"), "1.20") + test(time.Date(2010, 1, 1, 0, 0, 0, 0, time.UTC), "2010-01-01T00:00:00Z") + + test("hello\tworld", "\"hello\\tworld\"") + + test(struct{ A, B int }{42, 0}, "{A:42,B:0}") + test(struct { + A int `json:"val,ignoreme"` + B int `json:"-"` + C int `json:",omitempty"` + d int + }{42, 0, 0, 0}, "{val:42}") + + test(struct{ V interface{} }{}, "{V:null}") + test(struct{ V interface{} }{"42"}, "{V:\"42\"}") + + fourtytwo := 42 + + test(struct{ V *int }{}, "{V:null}") + test(struct{ V *int }{&fourtytwo}, "{V:42}") + + test(map[string]int{"b": 2, "a": 1}, "{a:1,b:2}") + + test(struct{ V []int }{}, "{V:null}") + test(struct{ V []int }{[]int{4, 2}}, "{V:[4,2]}") + + test(struct{ V []byte }{}, "{V:null}") + test(struct{ V []byte }{[]byte{4, 2}}, "{V:{{BAI=}}}") + + test(struct{ V [2]byte }{[2]byte{4, 2}}, "{V:[4,2]}") +} + +func TestMarshalBinary(t *testing.T) { + test := func(v interface{}, name string, eval []byte) { + t.Run(name, func(t *testing.T) { + val, err := MarshalBinary(v) + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(val, eval) { + t.Errorf("expected '%v', got '%v'", fmtbytes(eval), fmtbytes(val)) + } + }) + } + + test(nil, "null", []byte{0xE0, 0x01, 0x00, 0xEA, 0x0F}) + test(struct{ A, B int }{42, 0}, "{A:42,B:0}", []byte{ + 0xE0, 0x01, 0x00, 0xEA, + 0xE9, 0x81, 0x83, 0xD6, 0x87, 0xB4, 0x81, 'A', 0x81, 'B', + 0xD5, + 0x8A, 0x21, 0x2A, + 0x8B, 0x20, + }) +} + +func TestMarshalBinaryLST(t *testing.T) { + lsta := NewLocalSymbolTable(nil, nil) + lstb := NewLocalSymbolTable(nil, []string{ + "A", "B", + }) + + test := func(v interface{}, name string, lst SymbolTable, eval []byte) { + t.Run(name, func(t *testing.T) { + val, err := MarshalBinaryLST(v, lst) + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(val, eval) { + t.Errorf("expected '%v', got '%v'", fmtbytes(eval), fmtbytes(val)) + } + }) + } + + test(nil, "null", lsta, []byte{0xE0, 0x01, 0x00, 0xEA, 0x0F}) + test(struct{ A, B int }{42, 0}, "{A:42,B:0}", lstb, []byte{ + 0xE0, 0x01, 0x00, 0xEA, + 0xE9, 0x81, 0x83, 0xD6, 0x87, 0xB4, 0x81, 'A', 0x81, 'B', + 0xD5, + 0x8A, 0x21, 0x2A, + 0x8B, 0x20, + }) +} + +func TestMarshalNestedStructs(t *testing.T) { + type gp struct { + A int `json:"a"` + } + + type gp2 struct { + B int `json:"b"` + } + + type parent struct { + gp + *gp2 + C int `json:"c"` + } + + type root struct { + parent + D int `json:"d"` + } + + v := root{ + parent: parent{ + gp: gp{ + A: 1, + }, + gp2: &gp2{ + B: 2, + }, + C: 3, + }, + D: 4, + } + + val, err := MarshalText(v) + if err != nil { + t.Fatal(err) + } + + eval := "{a:1,b:2,c:3,d:4}" + if string(val) != eval { + t.Errorf("expected %v, got %v", eval, string(val)) + } +} diff --git a/ion/reader.go b/ion/reader.go new file mode 100644 index 00000000..9f16c5af --- /dev/null +++ b/ion/reader.go @@ -0,0 +1,388 @@ +package ion + +import ( + "bufio" + "bytes" + "io" + "math" + "math/big" + "strings" + "time" +) + +// A Reader reads a stream of Ion values. +// +// The Reader has a logical position within the stream of values, influencing the +// values returnedd from its methods. Initially, the Reader is positioned before the +// first value in the stream. A call to Next advances the Reader to the first value +// in the stream, with subsequent calls advancing to subsequent values. When a call to +// Next moves the Reader to the position after the final value in the stream, it returns +// false, making it easy to loop through the values in a stream. +// +// var r Reader +// for r.Next() { +// // ... +// } +// +// Next also returns false in case of error. This can be distinguished from a legitimate +// end-of-stream by calling Err after exiting the loop. +// +// When positioned on an Ion value, the type of the value can be retrieved by calling +// Type. If it has an associated field name (inside a struct) or annotations, they can +// be read by calling FieldName and Annotations respectively. +// +// For atomic values, an appropriate XxxValue method can be called to read the value. +// For lists, sexps, and structs, you should instead call StepIn to move the Reader in +// to the contained sequence of values. The Reader will initially be positioned before +// the first value in the container. Calling Next without calling StepIn will skip over +// the composite value and return the next value in the outer value stream. +// +// At any point while reading through a composite value, including when Next returns false +// to indicate the end of the contained values, you may call StepOut to move back to the +// outer sequence of values. The Reader will be positioned at the end of the composite value, +// such that a call to Next will move to the immediately-following value (if any). +// +// r := NewTextReaderStr("[foo, bar] [baz]") +// for r.Next() { +// if err := r.StepIn(); err != nil { +// return err +// } +// for r.Next() { +// fmt.Println(r.StringValue()) +// } +// if err := r.StepOut(); err != nil { +// return err +// } +// } +// if err := r.Err(); err != nil { +// return err +// } +// +type Reader interface { + + // SymbolTable returns the current symbol table, or nil if there isn't one. + // Text Readers do not, generally speaking, have an associated symbol table. + // Binary Readers do. + SymbolTable() SymbolTable + + // Next advances the Reader to the next position in the current value stream. + // It returns true if this is the position of an Ion value, and false if it + // is not. On error, it returns false and sets Err. + Next() bool + + // Err returns an error if a previous call call to Next has failed. + Err() error + + // Type returns the type of the Ion value the Reader is currently positioned on. + // It returns NoType if the Reader is positioned before or after a value. + Type() Type + + // IsNull returns true if the current value is an explicit null. This may be true + // even if the Type is not NullType (for example, null.struct has type Struct). Yes, + // that's a bit confusing. + IsNull() bool + + // FieldName returns the field name associated with the current value. It returns + // the empty string if there is no current value or the current value has no field + // name. + FieldName() string + + // Annotations returns the set of annotations associated with the current value. + // It returns nil if there is no current value or the current value has no annotations. + Annotations() []string + + // StepIn steps in to the current value if it is a container. It returns an error if there + // is no current value or if the value is not a container. On success, the Reader is + // positioned before the first value in the container. + StepIn() error + + // StepOut steps out of the current container value being read. It returns an error if + // this Reader is not currently stepped in to a container. On success, the Reader is + // positioned after the end of the container, but before any subsequent values in the + // stream. + StepOut() error + + // BoolValue returns the current value as a boolean (if that makes sense). It returns + // an error if the current value is not an Ion bool. + BoolValue() (bool, error) + + // IntSize returns the size of integer needed to losslessly represent the current value + // (if that makes sense). It returns an error if the current value is not an Ion int. + IntSize() (IntSize, error) + + // IntValue returns the current value as a 32-bit integer (if that makes sense). It + // returns an error if the current value is not an Ion integer or requires more than + // 32 bits to represent losslessly. + IntValue() (int, error) + + // Int64Value returns the current value as a 64-bit integer (if that makes sense). It + // returns an error if the current value is not an Ion integer or requires more than + // 64 bits to represent losslessly. + Int64Value() (int64, error) + + // Uint64Value returns the current value as an unsigned 64-bit integer (if that makes + // sense). It returns an error if the current value is not an Ion integer, is negative, + // or requires more than 64 bits to represent losslessly. + Uint64Value() (uint64, error) + + // BigIntValue returns the current value as a big.Integer (if that makes sense). It + // returns an error if the current value is not an Ion integer. + BigIntValue() (*big.Int, error) + + // FloatValue returns the current value as a 64-bit floating point number (if that + // makes sense). It returns an error if the current value is not an Ion float. + FloatValue() (float64, error) + + // DecimalValue returns the current value as an arbitrary-precision Decimal (if that + // makes sense). It returns an error if the current value is not an Ion decimal. + DecimalValue() (*Decimal, error) + + // TimeValue returns the current value as a timestamp (if that makes sense). It returns + // an error if the current value is not an Ion timestamp. + TimeValue() (time.Time, error) + + // StringValue returns the current value as a string (if that makes sense). It returns + // an error if the current value is not an Ion symbol or an Ion string. + StringValue() (string, error) + + // ByteValue returns the current value as a byte slice (if that makes sense). It returns + // an error if the current value is not an Ion clob or an Ion blob. + ByteValue() ([]byte, error) +} + +// NewReader creates a new Ion reader of the appropriate type by peeking +// at the first several bytes of input for a binary version marker. +func NewReader(in io.Reader) Reader { + return NewReaderCat(in, nil) +} + +// NewReaderStr creates a new reader from a string. +func NewReaderStr(str string) Reader { + return NewReader(strings.NewReader(str)) +} + +// NewReaderBytes creates a new reader for the given bytes. +func NewReaderBytes(in []byte) Reader { + return NewReader(bytes.NewReader(in)) +} + +// NewReaderCat creates a new reader with the given catalog. +func NewReaderCat(in io.Reader, cat Catalog) Reader { + br := bufio.NewReader(in) + + bs, err := br.Peek(4) + if err == nil && bs[0] == 0xE0 && bs[3] == 0xEA { + return newBinaryReaderBuf(br, cat) + } + + return newTextReaderBuf(br) +} + +// A reader holds common implementation stuff to both the text and binary readers. +type reader struct { + ctx ctxstack + eof bool + err error + + fieldName string + annotations []string + valueType Type + value interface{} +} + +// Err returns the current error. +func (r *reader) Err() error { + return r.err +} + +// Type returns the current value's type. +func (r *reader) Type() Type { + return r.valueType +} + +// IsNull returns true if the current value is null. +func (r *reader) IsNull() bool { + return r.valueType != NoType && r.value == nil +} + +// FieldName returns the current value's field name. +func (r *reader) FieldName() string { + return r.fieldName +} + +// Annotations returns the current value's annotations. +func (r *reader) Annotations() []string { + return r.annotations +} + +// BoolValue returns the current value as a bool. +func (r *reader) BoolValue() (bool, error) { + if r.valueType != BoolType { + return false, &UsageError{"Reader.BoolValue", "value is not a bool"} + } + if r.value == nil { + return false, nil + } + return r.value.(bool), nil +} + +// IntSize returns the size of the current int value. +func (r *reader) IntSize() (IntSize, error) { + if r.valueType != IntType { + return NullInt, &UsageError{"Reader.IntSize", "value is not a int"} + } + if r.value == nil { + return NullInt, nil + } + + if i, ok := r.value.(int64); ok { + if i > math.MaxInt32 || i < math.MinInt32 { + return Int64, nil + } + return Int32, nil + } + + i := r.value.(*big.Int) + if i.IsUint64() { + return Uint64, nil + } + + return BigInt, nil +} + +// IntValue returns the current value as an int. +func (r *reader) IntValue() (int, error) { + i, err := r.Int64Value() + if err != nil { + return 0, err + } + if i > math.MaxInt32 || i < math.MinInt32 { + return 0, &UsageError{"Reader.IntValue", "value too large for an int32"} + } + return int(i), nil +} + +// Int64Value returns the current value as an int64. +func (r *reader) Int64Value() (int64, error) { + if r.valueType != IntType { + return 0, &UsageError{"Reader.Int64Value", "value is not an int"} + } + if r.value == nil { + return 0, nil + } + + if i, ok := r.value.(int64); ok { + return i, nil + } + + bi := r.value.(*big.Int) + if bi.IsInt64() { + return bi.Int64(), nil + } + + return 0, &UsageError{"Reader.Int64Value", "value too large for an int64"} +} + +// Uint64Value returns the current value as a uint64. +func (r *reader) Uint64Value() (uint64, error) { + if r.valueType != IntType { + return 0, &UsageError{"Reader.Uint64Value", "value is not an int"} + } + if r.value == nil { + return 0, nil + } + + if i, ok := r.value.(int64); ok { + if i >= 0 { + return uint64(i), nil + } + return 0, &UsageError{"Reader.Uint64Value", "value is negative"} + } + + bi := r.value.(*big.Int) + if bi.Sign() < 0 { + return 0, &UsageError{"Reader.Uint64Value", "value is negative"} + } + if !bi.IsUint64() { + return 0, &UsageError{"Reader.Uint64Value", "value too large for a uint64"} + } + return bi.Uint64(), nil +} + +// BigIntValue returns the current value as a big int. +func (r *reader) BigIntValue() (*big.Int, error) { + if r.valueType != IntType { + return nil, &UsageError{"Reader.BigIntValue", "value is not an int"} + } + if r.value == nil { + return nil, nil + } + + if i, ok := r.value.(int64); ok { + return big.NewInt(i), nil + } + return r.value.(*big.Int), nil +} + +// FloatValue returns the current value as a float. +func (r *reader) FloatValue() (float64, error) { + if r.valueType != FloatType { + return 0, &UsageError{"Reader.FloatValue", "value is not a float"} + } + if r.value == nil { + return 0.0, nil + } + return r.value.(float64), nil +} + +// DecimalValue returns the current value as a Decimal. +func (r *reader) DecimalValue() (*Decimal, error) { + if r.valueType != DecimalType { + return nil, &UsageError{"Reader.DecimalValue", "value is not a decimal"} + } + if r.value == nil { + return nil, nil + } + return r.value.(*Decimal), nil +} + +// TimeValue returns the current value as a time. +func (r *reader) TimeValue() (time.Time, error) { + if r.valueType != TimestampType { + return time.Time{}, &UsageError{"Reader.TimestampValue", "value is not a timestamp"} + } + if r.value == nil { + return time.Time{}, nil + } + return r.value.(time.Time), nil +} + +// StringValue returns the current value as a string. +func (r *reader) StringValue() (string, error) { + if r.valueType != StringType && r.valueType != SymbolType { + return "", &UsageError{"Reader.StringValue", "value is not a string"} + } + if r.value == nil { + return "", nil + } + return r.value.(string), nil +} + +// ByteValue returns the current value as a byte slice. +func (r *reader) ByteValue() ([]byte, error) { + if r.valueType != BlobType && r.valueType != ClobType { + return nil, &UsageError{"Reader.ByteValue", "value is not a lob"} + } + if r.value == nil { + return nil, nil + } + return r.value.([]byte), nil +} + +// Clear clears the current value from the reader. +func (r *reader) clear() { + r.fieldName = "" + r.annotations = nil + r.valueType = NoType + r.value = nil +} diff --git a/ion/reader_test.go b/ion/reader_test.go new file mode 100644 index 00000000..2e204de1 --- /dev/null +++ b/ion/reader_test.go @@ -0,0 +1,127 @@ +package ion + +import ( + "fmt" + "io/ioutil" + "os" + "path/filepath" + "testing" +) + +var blacklist = map[string]bool{ + "ion-tests/iontestdata/good/emptyAnnotatedInt.10n": true, + "ion-tests/iontestdata/good/subfieldVarUInt32bit.ion": true, + "ion-tests/iontestdata/good/utf16.ion": true, + "ion-tests/iontestdata/good/utf32.ion": true, + "ion-tests/iontestdata/good/whitespace.ion": true, + "ion-tests/iontestdata/good/item1.10n": true, +} + +type drainfunc func(t *testing.T, r Reader, f string) + +func TestReadFiles(t *testing.T) { + testReadDir(t, "ion-tests/iontestdata/good", func(t *testing.T, r Reader, f string) { + drain(t, r, 0) + }) +} + +func drain(t *testing.T, r Reader, level int) { + for r.Next() { + // print(level, r.Type()) + + if !r.IsNull() { + switch r.Type() { + case StructType, ListType, SexpType: + if err := r.StepIn(); err != nil { + t.Fatal(err) + } + + drain(t, r, level+1) + + if err := r.StepOut(); err != nil { + t.Fatal(err) + } + } + } + } + + if r.Err() != nil { + t.Fatal(r.Err()) + } +} + +func print(level int, obj interface{}) { + fmt.Print(" > ") + for i := 0; i < level; i++ { + fmt.Print(" ") + } + fmt.Println(obj) +} + +func TestDecodeFiles(t *testing.T) { + testReadDir(t, "ion-tests/iontestdata/good", func(t *testing.T, r Reader, f string) { + // fmt.Println(f) + d := NewDecoder(r) + for { + v, err := d.Decode() + if err == ErrNoInput { + break + } + if err != nil { + t.Fatal(err) + } + // fmt.Println(v) + _ = v + } + }) +} + +var emptyFiles = []string{ + "ion-tests/iontestdata/good/blank.ion", + "ion-tests/iontestdata/good/empty.ion", +} + +func isEmptyFile(f string) bool { + for _, s := range emptyFiles { + if f == s { + return true + } + } + return false +} + +func testReadDir(t *testing.T, path string, d drainfunc) { + files, err := ioutil.ReadDir(path) + if err != nil { + t.Fatal(err) + } + + for _, file := range files { + fp := filepath.Join(path, file.Name()) + if file.IsDir() { + testReadDir(t, fp, d) + } else { + t.Run(fp, func(t *testing.T) { + testReadFile(t, fp, d) + }) + } + } +} + +func testReadFile(t *testing.T, path string, d drainfunc) { + if _, ok := blacklist[path]; ok { + return + } + + // fmt.Println(path) + + file, err := os.Open(path) + if err != nil { + t.Fatal(err) + } + defer file.Close() + + r := NewReader(file) + + d(t, r, path) +} diff --git a/ion/skipper.go b/ion/skipper.go new file mode 100644 index 00000000..cf74ebb7 --- /dev/null +++ b/ion/skipper.go @@ -0,0 +1,863 @@ +package ion + +import ( + "fmt" + "io" +) + +// SkipContainerContents skips over the contents of a container of the given type. +func (t *tokenizer) SkipContainerContents(typ Type) error { + switch typ { + case StructType: + return t.skipStructHelper() + case ListType: + return t.skipListHelper() + case SexpType: + return t.skipSexpHelper() + default: + panic(fmt.Sprintf("invalid container type: %v", typ)) + } +} + +// Skips whitespace and a double-colon token, if there is one. +func (t *tokenizer) SkipDoubleColon() (bool, bool, error) { + ws, err := t.skipWhitespaceHelper() + if err != nil { + return false, false, err + } + + ok, err := t.skipDoubleColon() + if err != nil { + return false, false, err + } + + return ok, ws, nil +} + +// Peeks ahead to see if the next token is a dot, and +// if so skips it. If not, leaves the next token unconsumed. +func (t *tokenizer) SkipDot() (bool, error) { + c, err := t.peek() + if err != nil { + return false, err + } + if c != '.' { + return false, nil + } + + t.read() + return true, nil +} + +// SkipLobWhitespace skips whitespace when we're inside a large +// object ({{ ///= }} or {{ '''///=''' }}) where comments are +// not allowed. +func (t *tokenizer) SkipLobWhitespace() (int, error) { + c, _, err := t.skipLobWhitespace() + return c, err +} + +// SkipValue skips to the end of the current value, if the caller +// didn't bother to consume it before calling Next again. +func (t *tokenizer) skipValue() (int, error) { + var c int + var err error + + switch t.token { + case tokenNumber: + c, err = t.skipNumber() + case tokenBinary: + c, err = t.skipBinary() + case tokenHex: + c, err = t.skipHex() + case tokenTimestamp: + c, err = t.skipTimestamp() + case tokenSymbol: + c, err = t.skipSymbol() + case tokenSymbolQuoted: + c, err = t.skipSymbolQuoted() + case tokenSymbolOperator: + c, err = t.skipSymbolOperator() + case tokenString: + c, err = t.skipString() + case tokenLongString: + c, err = t.skipLongString() + case tokenOpenDoubleBrace: + c, err = t.skipBlob() + case tokenOpenBrace: + c, err = t.skipStruct() + case tokenOpenParen: + c, err = t.skipSexp() + case tokenOpenBracket: + c, err = t.skipList() + default: + panic(fmt.Sprintf("skipValue called with token=%v", t.token)) + } + + if err != nil { + return 0, err + } + + if isWhitespace(c) { + c, _, err = t.skipWhitespace() + if err != nil { + return 0, err + } + } + + t.unfinished = false + return c, nil +} + +// SkipNumber skips a (non-binary, non-hex) number. +func (t *tokenizer) skipNumber() (int, error) { + c, err := t.read() + if err != nil { + return 0, err + } + + if c == '-' { + c, err = t.read() + if err != nil { + return 0, err + } + } + + c, err = t.skipDigits(c) + if err != nil { + return 0, err + } + + if c == '.' { + c, err = t.read() + if err != nil { + return 0, err + } + c, err = t.skipDigits(c) + if err != nil { + return 0, err + } + } + + if c == 'd' || c == 'D' || c == 'e' || c == 'E' { + c, err = t.read() + if err != nil { + return 0, err + } + if c == '+' || c == '-' { + c, err = t.read() + if err != nil { + return 0, err + } + } + c, err = t.skipDigits(c) + if err != nil { + return 0, err + } + } + + ok, err := t.isStopChar(c) + if err != nil { + return 0, err + } + if !ok { + return 0, t.invalidChar(c) + } + return c, nil +} + +// SkipBinary skips a binary literal value. +func (t *tokenizer) skipBinary() (int, error) { + isB := func(c int) bool { + return c == 'b' || c == 'B' + } + isBinaryDigit := func(c int) bool { + return c == '0' || c == '1' + } + return t.skipRadix(isB, isBinaryDigit) +} + +// SkipHex skips a hex value. +func (t *tokenizer) skipHex() (int, error) { + isX := func(c int) bool { + return c == 'x' || c == 'X' + } + return t.skipRadix(isX, isHexDigit) +} + +func (t *tokenizer) skipRadix(pok, dok matcher) (int, error) { + c, err := t.read() + if err != nil { + return 0, err + } + + if c == '-' { + c, err = t.read() + if err != nil { + return 0, err + } + } + + if c != '0' { + return 0, t.invalidChar(c) + } + if err = t.expect(pok); err != nil { + return 0, err + } + + for { + c, err = t.read() + if err != nil { + return 0, err + } + if !dok(c) { + break + } + } + + ok, err := t.isStopChar(c) + if err != nil { + return 0, err + } + if !ok { + return 0, t.invalidChar(c) + } + + return c, nil +} + +// SkipTimestamp skips a timestamp value, returning the next character. +func (t *tokenizer) skipTimestamp() (int, error) { + // Read the first four digits, yyyy. + c, err := t.skipTimestampDigits(4) + if err != nil { + return 0, err + } + if c == 'T' { + // yyyyT + return t.read() + } + if c != '-' { + return 0, t.invalidChar(c) + } + + // Read the next two, yyyy-mm. + c, err = t.skipTimestampDigits(2) + if err != nil { + return 0, err + } + if c == 'T' { + // yyyy-mmT + return t.read() + } + if c != '-' { + return 0, t.invalidChar(c) + } + + // Read the day. + c, err = t.skipTimestampDigits(2) + if err != nil { + return 0, err + } + if c != 'T' { + // yyyy-mm-dd. + return t.skipTimestampFinish(c) + } + + c, err = t.read() + if err != nil { + return 0, err + } + if !isDigit(c) { + // yyyy-mm-ddT(+hh:mm)? + c, err = t.skipTimestampOffset(c) + if err != nil { + return 0, err + } + return t.skipTimestampFinish(c) + } + + // Already read the first hour digit above. + c, err = t.skipTimestampDigits(1) + if err != nil { + return 0, err + } + if c != ':' { + return 0, t.invalidChar(c) + } + + c, err = t.skipTimestampDigits(2) + if err != nil { + return 0, err + } + if c != ':' { + // yyyy-mm-ddThh:mmZ + c, err = t.skipTimestampOffsetOrZ(c) + if err != nil { + return 0, err + } + return t.skipTimestampFinish(c) + } + + c, err = t.skipTimestampDigits(2) + if err != nil { + return 0, err + } + if c != '.' { + // yyyy-mm-ddThh:mm:ssZ + c, err = t.skipTimestampOffsetOrZ(c) + if err != nil { + return 0, err + } + return t.skipTimestampFinish(c) + } + + // yyyy-mm-ddThh:mm:ss.ssssZ + c, err = t.read() + if err != nil { + return 0, err + } + if isDigit(c) { + c, err = t.skipDigits(c) + if err != nil { + return 0, err + } + } + + c, err = t.skipTimestampOffsetOrZ(c) + if err != nil { + return 0, err + } + return t.skipTimestampFinish(c) +} + +// SkipTimestampOffsetOrZ skips a (required) timestamp offset value or +// letter 'Z' (indicating UTC). +func (t *tokenizer) skipTimestampOffsetOrZ(c int) (int, error) { + if c == '-' || c == '+' { + return t.skipTimestampOffset(c) + } + if c == 'z' || c == 'Z' { + return t.read() + } + return 0, t.invalidChar(c) +} + +// SkipTimestampOffset skips an (optional) +-hh:mm timestamp zone offset +// value. +func (t *tokenizer) skipTimestampOffset(c int) (int, error) { + if c != '-' && c != '+' { + return c, nil + } + + c, err := t.skipTimestampDigits(2) + if err != nil { + return 0, err + } + if c != ':' { + return 0, t.invalidChar(c) + } + return t.skipTimestampDigits(2) +} + +// SkipTimestampDigits skips a bounded sequence of digits inside a +// timestamp. +func (t *tokenizer) skipTimestampDigits(n int) (int, error) { + for n > 0 { + if err := t.expect(func(c int) bool { + return isDigit(c) + }); err != nil { + return 0, err + } + n-- + } + + return t.read() +} + +// SkipTimestampFinish makes sure the character after a timestamp +// value is a valid ending point. If so, it returns it. +func (t *tokenizer) skipTimestampFinish(c int) (int, error) { + ok, err := t.isStopChar(c) + if err != nil { + return 0, err + } + if !ok { + return 0, t.invalidChar(c) + } + return c, nil +} + +// SkipSymbol skips a normal symbol and returns the next character. +func (t *tokenizer) skipSymbol() (int, error) { + c, err := t.read() + if err != nil { + return 0, err + } + + for isIdentifierPart(c) { + c, err = t.read() + if err != nil { + return 0, err + } + } + + return c, nil +} + +// SkipSymbolQuoted skips a quoted symbol and returns the next char. +func (t *tokenizer) skipSymbolQuoted() (int, error) { + if err := t.skipSymbolQuotedHelper(); err != nil { + return 0, err + } + return t.read() +} + +// SkipSymbolQuotedHelper skips a quoted symbol. +func (t *tokenizer) skipSymbolQuotedHelper() error { + for { + c, err := t.read() + if err != nil { + return err + } + + switch c { + case -1, '\n': + return t.invalidChar(c) + + case '\'': + return nil + + case '\\': + if _, err := t.read(); err != nil { + return err + } + } + } +} + +// SkipSymbolOperator skips an operator-style symbol inside an sexp. +func (t *tokenizer) skipSymbolOperator() (int, error) { + c, err := t.read() + if err != nil { + return 0, err + } + + for isOperatorChar(c) { + c, err = t.read() + if err != nil { + return 0, err + } + } + + return c, nil +} + +// SkipString skips over a "-enclosed string, returning the next char. +func (t *tokenizer) skipString() (int, error) { + if err := t.skipStringHelper(); err != nil { + return 0, err + } + return t.read() +} + +// SkipStringHelper skips over a "-enclosed string. +func (t *tokenizer) skipStringHelper() error { + for { + c, err := t.read() + if err != nil { + return err + } + + switch c { + case -1, '\n': + return t.invalidChar(c) + + case '"': + return nil + + case '\\': + if _, err := t.read(); err != nil { + return err + } + } + } +} + +// SkipLongString skips over a '''-enclosed string, returning the next +// character after the closing '''. +func (t *tokenizer) skipLongString() (int, error) { + if err := t.skipLongStringHelper(t.skipCommentsHandler); err != nil { + return 0, err + } + return t.read() +} + +// SkipLongStringHelper skips over a '''-enclosed string. +func (t *tokenizer) skipLongStringHelper(handler commentHandler) error { + for { + c, err := t.read() + if err != nil { + return err + } + + switch c { + case -1: + return t.invalidChar(c) + + case '\'': + ok, err := t.skipEndOfLongString(handler) + if err != nil { + return err + } + if ok { + return nil + } + + case '\\': + if _, err = t.read(); err != nil { + return err + } + } + } +} + +// SkipEndOfLongString is called after reading a ' to determine if we've +// hit the end of the long string.. +func (t *tokenizer) skipEndOfLongString(handler commentHandler) (bool, error) { + // We just read a ', check for two more ''s. + cs, err := t.peekN(2) + if err != nil && err != io.EOF { + return false, err + } + + // If it's not a triple-quote, keep going. + if len(cs) < 2 || cs[0] != '\'' || cs[1] != '\'' { + return false, nil + } + + // Consume the triple-quote. + if err := t.skipN(2); err != nil { + return false, err + } + + // Consume any additional whitespace/comments. + c, _, err := t.skipWhitespaceWith(handler) + if err != nil { + return false, err + } + + // Check if it's another triple-quote; if so, keep going. + if c == '\'' { + ok, err := t.IsTripleQuote() + if err != nil { + return false, err + } + if ok { + return false, nil + } + } + + t.unread(c) + return true, nil +} + +// SkipBlob skips over a blob value, returning the next character. +func (t *tokenizer) skipBlob() (int, error) { + if err := t.skipBlobHelper(); err != nil { + return 0, err + } + return t.read() +} + +// SkipBlobHelper skips over a blob value, stopping after reading the +// final '}'. +func (t *tokenizer) skipBlobHelper() error { + c, _, err := t.skipLobWhitespace() + if err != nil { + return err + } + + // TODO: If this is a clob, could we potentially have an embedded + // '}' here? + for c != '}' { + c, _, err = t.skipLobWhitespace() + if err != nil { + return err + } + if c == -1 { + return t.invalidChar(c) + } + } + + return t.expect(func(c int) bool { + return c == '}' + }) +} + +func (t *tokenizer) skipStruct() (int, error) { + return t.skipContainer('}') +} + +func (t *tokenizer) skipStructHelper() error { + return t.skipContainerHelper('}') +} + +func (t *tokenizer) skipSexp() (int, error) { + return t.skipContainer(')') +} + +func (t *tokenizer) skipSexpHelper() error { + return t.skipContainerHelper(')') +} + +// SkipList skips forward past a list that the caller doesn't care to +// step in to. +func (t *tokenizer) skipList() (int, error) { + return t.skipContainer(']') +} + +func (t *tokenizer) skipListHelper() error { + return t.skipContainerHelper(']') +} + +// SkipContainer skips a container terminated by the given char and +// returns the next character. +func (t *tokenizer) skipContainer(term int) (int, error) { + if err := t.skipContainerHelper(term); err != nil { + return 0, err + } + return t.read() +} + +// SkipContainerHelper skips over a container terminated by the given +// char. +func (t *tokenizer) skipContainerHelper(term int) error { + if term != ']' && term != ')' && term != '}' { + panic("wat") + } + + for { + c, _, err := t.skipWhitespace() + if err != nil { + return err + } + + switch c { + case -1: + return t.invalidChar(c) + + case term: + return nil + + case '"': + if err := t.skipStringHelper(); err != nil { + return err + } + + case '\'': + ok, err := t.IsTripleQuote() + if err != nil { + return err + } + if ok { + if err = t.skipLongStringHelper(t.skipCommentsHandler); err != nil { + return err + } + } else { + if err = t.skipSymbolQuotedHelper(); err != nil { + return err + } + } + + case '(': + if err := t.skipContainerHelper(')'); err != nil { + return err + } + + case '[': + if err := t.skipContainerHelper(']'); err != nil { + return err + } + + case '{': + c, err := t.peek() + if err != nil { + return err + } + + if c == '{' { + if _, err := t.read(); err != nil { + return err + } + if err := t.skipBlobHelper(); err != nil { + return err + } + } else if c == '}' { + if _, err := t.read(); err != nil { + return err + } + } else { + if err := t.skipContainerHelper('}'); err != nil { + return err + } + } + } + } +} + +// SkipDigits skips a sequence of digits starting with the +// given character. +func (t *tokenizer) skipDigits(c int) (int, error) { + var err error + for err == nil && isDigit(c) { + c, err = t.read() + } + return c, err +} + +// SkipWhitespace skips whitespace (and comments) when we're out +// in normal parsing territory. +func (t *tokenizer) skipWhitespace() (int, bool, error) { + return t.skipWhitespaceWith(t.skipCommentsHandler) +} + +// SkipWhitespaceHelper is a 'helper' form of SkipWhitespace that +// unreads the first non-whitespace char instead of returning it. +func (t *tokenizer) skipWhitespaceHelper() (bool, error) { + c, ok, err := t.skipWhitespace() + if err != nil { + return false, err + } + t.unread(c) + return ok, err +} + +// SkipLobWhitespace skips whitespace when we're inside a large +// object ({{ ///= }} or {{ '''///=''' }}) where comments are +// not allowed. +func (t *tokenizer) skipLobWhitespace() (int, bool, error) { + // Comments are not allowed inside a lob value; if we see a '/', + // it's the start of a base64-encoded value. + return t.skipWhitespaceWith(stopForCommentsHandler) +} + +// CommentHandler is a strategy for handling comments. Returns true +// if it found and handled a comment, false if it didn't find a +// comment, and returns an error if it choked on the comment. +type commentHandler func() (bool, error) + +// SkipWhitespaceWith skips whitespace using the given strategy for +// handling comments--generally speaking, either skipping over them +// using skipCommentsHandler, or stopping with a stopForCommentsHandler. +// Returns the first non-whitespace character it reads, and whether it +// actually skipped anything to find it. +func (t *tokenizer) skipWhitespaceWith(handler commentHandler) (int, bool, error) { + skipped := false + for { + c, err := t.read() + if err != nil { + return 0, skipped, err + } + + switch c { + case ' ', '\t', '\n', '\r': + // Skipped. + + case '/': + comment, err := handler() + if err != nil { + return 0, skipped, err + } + if !comment { + return '/', skipped, nil + } + + default: + return c, skipped, nil + } + skipped = true + } +} + +// StopForCommentsHandler is a commentHandler that stops skipping +// whitespace when it finds a (potential) comment. Use it when you +// expect a '/' to be an actual '/', not a comment. +func stopForCommentsHandler() (bool, error) { + return false, nil +} + +// SkipCommentsHandler is a commentHandler that skips over any +// comments it finds. +func (t *tokenizer) skipCommentsHandler() (bool, error) { + // We've just read a '/', which might be the start of a comment. + // Peek ahead to see if it is, and if so skip over it. + c, err := t.peek() + if err != nil { + return false, err + } + + switch c { + case '/': + return true, t.skipSingleLineComment() + case '*': + return true, t.skipBlockComment() + default: + return false, nil + } +} + +// SkipSingleLineComment skips over the body of a single-line comment, +// terminated by the end of the line (or file). +func (t *tokenizer) skipSingleLineComment() error { + for { + c, err := t.read() + if err != nil { + return err + } + + if c == -1 || c == '\n' { + return nil + } + } +} + +// SkipBlockComment skips over the body of a block comment, terminated +// by a '*/' sequence. +func (t *tokenizer) skipBlockComment() error { + star := false + for { + c, err := t.read() + if err != nil { + return err + } + if c == -1 { + return t.invalidChar(c) + } + + if star && c == '/' { + return nil + } + + star = (c == '*') + } +} + +// Peeks ahead to see if the next token is a double colon, and +// if so skips it. If not, leaves the next token unconsumed. +func (t *tokenizer) skipDoubleColon() (bool, error) { + cs, err := t.peekN(2) + if err == io.EOF { + return false, nil + } + if err != nil { + return false, err + } + + if cs[0] == ':' && cs[1] == ':' { + t.skipN(2) + return true, nil + } + + return false, nil +} diff --git a/ion/skipper_test.go b/ion/skipper_test.go new file mode 100644 index 00000000..afb48874 --- /dev/null +++ b/ion/skipper_test.go @@ -0,0 +1,178 @@ +package ion + +import ( + "testing" +) + +func TestSkipNumber(t *testing.T) { + test, testErr := testSkip(t, (*tokenizer).skipNumber) + + test("", -1) + test("0", -1) + test("-1234567890,", ',') + test("1.2 ", ' ') + test("1d45\n", '\n') + test("1.4e-12//", '/') + + testErr("1.2d3d", "ion: unexpected rune 'd' (offset 5)") +} + +func TestSkipBinary(t *testing.T) { + test, testErr := testSkip(t, (*tokenizer).skipBinary) + + test("0b0", -1) + test("-0b10 ", ' ') + test("0b010101,", ',') + + testErr("0b2", "ion: unexpected rune '2' (offset 2)") +} + +func TestSkipHex(t *testing.T) { + test, testErr := testSkip(t, (*tokenizer).skipHex) + + test("0x0", -1) + test("-0x0F ", ' ') + test("0x1234567890abcdefABCDEF,", ',') + + testErr("0x0G", "ion: unexpected rune 'G' (offset 3)") +} + +func TestSkipTimestamp(t *testing.T) { + test, testErr := testSkip(t, (*tokenizer).skipTimestamp) + + test("2001T", -1) + test("2001-01T,", ',') + test("2001-01-02}", '}') + test("2001-01-02T ", ' ') + test("2001-01-02T+00:00\t", '\t') + test("2001-01-02T-00:00\n", '\n') + test("2001-01-02T03:04+00:00 ", ' ') + test("2001-01-02T03:04-00:00 ", ' ') + test("2001-01-02T03:04Z ", ' ') + test("2001-01-02T03:04z ", ' ') + test("2001-01-02T03:04:05Z ", ' ') + test("2001-01-02T03:04:05+00:00 ", ' ') + test("2001-01-02T03:04:05.666Z ", ' ') + test("2001-01-02T03:04:05.666666z ", ' ') + + testErr("", "ion: unexpected end of input (offset 0)") + testErr("2001", "ion: unexpected end of input (offset 4)") + testErr("2001z", "ion: unexpected rune 'z' (offset 4)") + testErr("20011", "ion: unexpected rune '1' (offset 4)") + testErr("2001-0", "ion: unexpected end of input (offset 6)") + testErr("2001-01", "ion: unexpected end of input (offset 7)") + testErr("2001-01-02Tz", "ion: unexpected rune 'z' (offset 11)") + testErr("2001-01-02T03", "ion: unexpected end of input (offset 13)") + testErr("2001-01-02T03z", "ion: unexpected rune 'z' (offset 13)") + testErr("2001-01-02T03:04x ", "ion: unexpected rune 'x' (offset 16)") + testErr("2001-01-02T03:04:05x ", "ion: unexpected rune 'x' (offset 19)") +} + +func TestSkipSymbol(t *testing.T) { + test, _ := testSkip(t, (*tokenizer).skipSymbol) + + test("f", -1) + test("foo:", ':') + test("foo,", ',') + test("foo ", ' ') + test("foo\n", '\n') + test("foo]", ']') + test("foo}", '}') + test("foo)", ')') + test("foo\\n", '\\') +} + +func TestSkipSymbolQuoted(t *testing.T) { + test, testErr := testSkip(t, (*tokenizer).skipSymbolQuoted) + + test("'", -1) + test("foo',", ',') + test("foo\\'bar':", ':') + test("foo\\\nbar',", ',') + + testErr("foo", "ion: unexpected end of input (offset 3)") + testErr("foo\n", "ion: unexpected rune '\\n' (offset 3)") +} + +func TestSkipSymbolOperator(t *testing.T) { + test, _ := testSkip(t, (*tokenizer).skipSymbolOperator) + + test("+", -1) + test("++", -1) + test("+= ", ' ') + test("%b", 'b') +} + +func TestSkipString(t *testing.T) { + test, testErr := testSkip(t, (*tokenizer).skipString) + + test("\"", -1) + test("\",", ',') + test("foo\\\"bar\"], \"\"", ']') + test("foo\\\nbar\" \t\t\t", ' ') + + testErr("foobar", "ion: unexpected end of input (offset 6)") + testErr("foobar\n", "ion: unexpected rune '\\n' (offset 6)") +} + +func TestSkipLongString(t *testing.T) { + test, _ := testSkip(t, (*tokenizer).skipLongString) + + test("'''", -1) + test("''',", ',') + test("abc''',", ',') + test("abc''' }", '}') + test("abc''' /*more*/ '''def'''\t//more\r\n]", ']') +} + +func TestSkipBlob(t *testing.T) { + test, testErr := testSkip(t, (*tokenizer).skipBlob) + + test("}}", -1) + test("oogboog}},{{}}", ',') + test("'''not encoded'''}}\n", '\n') + + testErr("", "ion: unexpected end of input (offset 1)") + testErr("oogboog", "ion: unexpected end of input (offset 7)") + testErr("oogboog}", "ion: unexpected end of input (offset 8)") + testErr("oog}{boog", "ion: unexpected rune '{' (offset 4)") +} + +func TestSkipList(t *testing.T) { + test, testErr := testSkip(t, (*tokenizer).skipList) + + test("]", -1) + test("[]],", ',') + test("[123, \"]\", ']']] ", ' ') + + testErr("abc, def, ", "ion: unexpected end of input (offset 10)") +} + +type skipFunc func(*tokenizer) (int, error) +type skipTestFunc func(string, int) +type skipTestErrFunc func(string, string) + +func testSkip(t *testing.T, f skipFunc) (skipTestFunc, skipTestErrFunc) { + test := func(str string, ec int) { + t.Run(str, func(t *testing.T) { + tok := tokenizeString(str) + c, err := f(tok) + if err != nil { + t.Fatal(err) + } + if c != ec { + t.Errorf("expected '%c', got '%c'", ec, c) + } + }) + } + testErr := func(str string, e string) { + t.Run(str, func(t *testing.T) { + tok := tokenizeString(str) + _, err := f(tok) + if err == nil || err.Error() != e { + t.Errorf("expected err=%v, got err=%v", e, err) + } + }) + } + return test, testErr +} diff --git a/ion/symboltable.go b/ion/symboltable.go new file mode 100644 index 00000000..922a1dc3 --- /dev/null +++ b/ion/symboltable.go @@ -0,0 +1,475 @@ +package ion + +import ( + "strings" +) + +// A SymbolTable maps binary-representation symbol IDs to +// text-representation strings and vice versa. +type SymbolTable interface { + // Imports returns the symbol tables this table imports. + Imports() []SharedSymbolTable + // Symbols returns the symbols this symbol table defines. + Symbols() []string + // MaxID returns the maximum ID this symbol table defines. + MaxID() uint64 + + // FindByName finds the ID of a symbol by its name. + FindByName(symbol string) (uint64, bool) + // FindByID finds the name of a symbol given its ID. + FindByID(id uint64) (string, bool) + // WriteTo serializes the symbol table to an ion.Writer. + WriteTo(w Writer) error + // String returns an ion text representation of the symbol table. + String() string +} + +// A SharedSymbolTable is distributed out-of-band and referenced from +// a local SymbolTable to save space. +type SharedSymbolTable interface { + SymbolTable + + // Name returns the name of this shared symbol table. + Name() string + // Version returns the version of this shared symbol table. + Version() int + // Adjust returns a new shared symbol table limited or extended to the given max ID. + Adjust(maxID uint64) SharedSymbolTable +} + +type sst struct { + name string + version int + symbols []string + index map[string]uint64 + maxID uint64 +} + +// NewSharedSymbolTable creates a new shared symbol table. +func NewSharedSymbolTable(name string, version int, symbols []string) SharedSymbolTable { + syms := make([]string, len(symbols)) + copy(syms, symbols) + + index := buildIndex(syms, 1) + + return &sst{ + name: name, + version: version, + symbols: syms, + index: index, + maxID: uint64(len(syms)), + } +} + +func (s *sst) Name() string { + return s.name +} + +func (s *sst) Version() int { + return s.version +} + +func (s *sst) Imports() []SharedSymbolTable { + return nil +} + +func (s *sst) Symbols() []string { + syms := make([]string, s.maxID) + copy(syms, s.symbols) + return syms +} + +func (s *sst) MaxID() uint64 { + return uint64(s.maxID) +} + +func (s *sst) Adjust(maxID uint64) SharedSymbolTable { + if maxID == s.maxID { + // Nothing needs to change. + return s + } + + if maxID > uint64(len(s.symbols)) { + // Old index will work fine, just adjust the maxID. + return &sst{ + name: s.name, + version: s.version, + symbols: s.symbols, + index: s.index, + maxID: maxID, + } + } + + // Slice the symbols down to size and reindex. + symbols := s.symbols[:maxID] + index := buildIndex(symbols, 1) + + return &sst{ + name: s.name, + version: s.version, + symbols: symbols, + index: index, + maxID: maxID, + } +} + +func (s *sst) FindByName(sym string) (uint64, bool) { + id, ok := s.index[sym] + return uint64(id), ok +} + +func (s *sst) FindByID(id uint64) (string, bool) { + if id <= 0 || id > uint64(len(s.symbols)) { + return "", false + } + return s.symbols[id-1], true +} + +func (s *sst) WriteTo(w Writer) error { + w.Annotation("$ion_shared_symbol_table") + w.BeginStruct() + { + w.FieldName("name") + w.WriteString(s.name) + + w.FieldName("version") + w.WriteInt(int64(s.version)) + + w.FieldName("symbols") + w.BeginList() + { + for _, sym := range s.symbols { + w.WriteString(sym) + } + } + w.EndList() + } + return w.EndStruct() +} + +func (s *sst) String() string { + buf := strings.Builder{} + + w := NewTextWriter(&buf) + s.WriteTo(w) + + return buf.String() +} + +// V1SystemSymbolTable is the (implied) system symbol table for Ion v1.0. +var V1SystemSymbolTable = NewSharedSymbolTable("$ion", 1, []string{ + "$ion", + "$ion_1_0", + "$ion_symbol_table", + "name", + "version", + "imports", + "symbols", + "max_id", + "$ion_shared_symbol_table", +}) + +// A BogusSST represents an SST imported by an LST that cannot be found in the +// local catalog. It exists to reserve some part of the symbol ID space so other +// symbol tables get mapped to the right IDs. +type bogusSST struct { + name string + version int + maxID uint64 +} + +var _ SharedSymbolTable = &bogusSST{} + +func (s *bogusSST) Name() string { + return s.name +} + +func (s *bogusSST) Version() int { + return s.version +} + +func (s *bogusSST) Imports() []SharedSymbolTable { + return nil +} + +func (s *bogusSST) Symbols() []string { + return nil +} + +func (s *bogusSST) MaxID() uint64 { + return s.maxID +} + +func (s *bogusSST) Adjust(maxID uint64) SharedSymbolTable { + return &bogusSST{ + name: s.name, + version: s.version, + maxID: maxID, + } +} + +func (s *bogusSST) FindByName(sym string) (uint64, bool) { + return 0, false +} + +func (s *bogusSST) FindByID(id uint64) (string, bool) { + return "", false +} + +func (s *bogusSST) WriteTo(w Writer) error { + return &UsageError{"SharedSymbolTable.WriteTo", "bogus symbol table should never be written"} +} + +func (s *bogusSST) String() string { + buf := strings.Builder{} + w := NewTextWriter(&buf) + w.Annotations("$ion_shared_symbol_table", "bogus") + w.BeginStruct() + + w.FieldName("name") + w.WriteString(s.name) + + w.FieldName("version") + w.WriteInt(int64(s.version)) + + w.FieldName("max_id") + w.WriteUint(s.maxID) + + w.EndStruct() + return buf.String() +} + +// A LocalSymbolTable is transmitted in-band along with the binary data +// it describes. It may include SharedSymbolTables by reference. +type lst struct { + imports []SharedSymbolTable + offsets []uint64 + maxImportID uint64 + + symbols []string + index map[string]uint64 +} + +// NewLocalSymbolTable creates a new local symbol table. +func NewLocalSymbolTable(imports []SharedSymbolTable, symbols []string) SymbolTable { + imps, offsets, maxID := processImports(imports) + syms := make([]string, len(symbols)) + copy(syms, symbols) + + index := buildIndex(syms, maxID+1) + + return &lst{ + imports: imps, + offsets: offsets, + maxImportID: maxID, + symbols: syms, + index: index, + } +} + +func (t *lst) Imports() []SharedSymbolTable { + imps := make([]SharedSymbolTable, len(t.imports)) + copy(imps, t.imports) + return imps +} + +func (t *lst) Symbols() []string { + syms := make([]string, len(t.symbols)) + copy(syms, t.symbols) + return syms +} + +func (t *lst) MaxID() uint64 { + return t.maxImportID + uint64(len(t.symbols)) +} + +func (t *lst) FindByName(s string) (uint64, bool) { + for i, imp := range t.imports { + if id, ok := imp.FindByName(s); ok { + return t.offsets[i] + id, true + } + } + + if id, ok := t.index[s]; ok { + return id, true + } + + return 0, false +} + +func (t *lst) FindByID(id uint64) (string, bool) { + if id <= 0 { + return "", false + } + if id <= t.maxImportID { + return t.findByIDInImports(id) + } + + // Local to this symbol table. + idx := id - t.maxImportID - 1 + if idx < uint64(len(t.symbols)) { + return t.symbols[idx], true + } + + return "", false +} + +func (t *lst) findByIDInImports(id uint64) (string, bool) { + i := 1 + off := uint64(0) + + for ; i < len(t.imports); i++ { + if id <= t.offsets[i] { + break + } + off = t.offsets[i] + } + + return t.imports[i-1].FindByID(id - off) +} + +func (t *lst) WriteTo(w Writer) error { + if len(t.imports) == 1 && len(t.symbols) == 0 { + return nil + } + + w.Annotation("$ion_symbol_table") + w.BeginStruct() + + if len(t.imports) > 1 { + w.FieldName("imports") + w.BeginList() + for i := 1; i < len(t.imports); i++ { + imp := t.imports[i] + w.BeginStruct() + + w.FieldName("name") + w.WriteString(imp.Name()) + + w.FieldName("version") + w.WriteInt(int64(imp.Version())) + + w.FieldName("max_id") + w.WriteUint(imp.MaxID()) + + w.EndStruct() + } + w.EndList() + } + + if len(t.symbols) > 0 { + w.FieldName("symbols") + + w.BeginList() + for _, sym := range t.symbols { + w.WriteString(sym) + } + w.EndList() + } + + return w.EndStruct() +} + +func (t *lst) String() string { + buf := strings.Builder{} + + w := NewTextWriter(&buf) + t.WriteTo(w) + + return buf.String() +} + +// A SymbolTableBuilder helps you iteratively build a local symbol table. +type SymbolTableBuilder interface { + SymbolTable + + // Add adds a symbol to this symbol table. + Add(symbol string) (uint64, bool) + // Build creates an immutable local symbol table. + Build() SymbolTable +} + +type symbolTableBuilder struct { + lst +} + +// NewSymbolTableBuilder creates a new symbol table builder with the given imports. +func NewSymbolTableBuilder(imports ...SharedSymbolTable) SymbolTableBuilder { + imps, offsets, maxID := processImports(imports) + return &symbolTableBuilder{ + lst{ + imports: imps, + offsets: offsets, + maxImportID: maxID, + index: make(map[string]uint64), + }, + } +} + +func (b *symbolTableBuilder) Add(symbol string) (uint64, bool) { + if id, ok := b.FindByName(symbol); ok { + return id, false + } + + b.symbols = append(b.symbols, symbol) + id := b.maxImportID + uint64(len(b.symbols)) + b.index[symbol] = id + + return id, true +} + +func (b *symbolTableBuilder) Build() SymbolTable { + symbols := append([]string{}, b.symbols...) + index := make(map[string]uint64) + for s, i := range b.index { + index[s] = uint64(i) + } + + return &lst{ + imports: b.imports, + offsets: b.offsets, + maxImportID: b.maxImportID, + symbols: symbols, + index: index, + } +} + +// ProcessImports processes a slice of imports, returning an (augmented) copy, a set of +// offsets for each import, and the overall max ID. +func processImports(imports []SharedSymbolTable) ([]SharedSymbolTable, []uint64, uint64) { + // Add in V1SystemSymbolTable at the head of the list if it's not already included. + var imps []SharedSymbolTable + if len(imports) > 0 && imports[0].Name() == "$ion" { + imps = make([]SharedSymbolTable, len(imports)) + copy(imps, imports) + } else { + imps = make([]SharedSymbolTable, len(imports)+1) + imps[0] = V1SystemSymbolTable + copy(imps[1:], imports) + } + + // Calculate offsets. + maxID := uint64(0) + offsets := make([]uint64, len(imps)) + for i, imp := range imps { + offsets[i] = maxID + maxID += imp.MaxID() + } + + return imps, offsets, maxID +} + +// BuildIndex builds an index from symbol name to symbol ID. +func buildIndex(symbols []string, offset uint64) map[string]uint64 { + index := make(map[string]uint64) + + for i, sym := range symbols { + if sym != "" { + if _, ok := index[sym]; !ok { + index[sym] = offset + uint64(i) + } + } + } + + return index +} diff --git a/ion/symboltable_test.go b/ion/symboltable_test.go new file mode 100644 index 00000000..4fe1aaed --- /dev/null +++ b/ion/symboltable_test.go @@ -0,0 +1,181 @@ +package ion + +import ( + "fmt" + "testing" +) + +func TestSharedSymbolTable(t *testing.T) { + st := NewSharedSymbolTable("test", 2, []string{ + "abc", + "def", + "foo'bar", + "null", + "def", + "ghi", + }) + + if st.Name() != "test" { + t.Errorf("wrong name: %v", st.Name()) + } + if st.Version() != 2 { + t.Errorf("wrong version: %v", st.Version()) + } + if st.MaxID() != 6 { + t.Errorf("wrong maxid: %v", st.MaxID()) + } + + testFindByName(t, st, "def", 2) + testFindByName(t, st, "null", 4) + testFindByName(t, st, "bogus", 0) + + testFindByID(t, st, 0, "") + testFindByID(t, st, 2, "def") + testFindByID(t, st, 4, "null") + testFindByID(t, st, 7, "") + + testString(t, st, `$ion_shared_symbol_table::{name:"test",version:2,symbols:["abc","def","foo'bar","null","def","ghi"]}`) +} + +func TestLocalSymbolTable(t *testing.T) { + st := NewLocalSymbolTable(nil, []string{"foo", "bar"}) + + if st.MaxID() != 11 { + t.Errorf("wrong maxid: %v", st.MaxID()) + } + + testFindByName(t, st, "$ion", 1) + testFindByName(t, st, "foo", 10) + testFindByName(t, st, "bar", 11) + testFindByName(t, st, "bogus", 0) + + testFindByID(t, st, 0, "") + testFindByID(t, st, 1, "$ion") + testFindByID(t, st, 10, "foo") + testFindByID(t, st, 11, "bar") + testFindByID(t, st, 12, "") + + testString(t, st, `$ion_symbol_table::{symbols:["foo","bar"]}`) +} + +func TestLocalSymbolTableWithImports(t *testing.T) { + shared := NewSharedSymbolTable("shared", 1, []string{ + "foo", + "bar", + }) + imports := []SharedSymbolTable{shared} + + st := NewLocalSymbolTable(imports, []string{ + "foo2", + "bar2", + }) + + if st.MaxID() != 13 { // 9 from $ion.1, 2 from test.1, 2 local. + t.Errorf("wrong maxid: %v", st.MaxID()) + } + + testFindByName(t, st, "$ion", 1) + testFindByName(t, st, "$ion_shared_symbol_table", 9) + testFindByName(t, st, "foo", 10) + testFindByName(t, st, "bar", 11) + testFindByName(t, st, "foo2", 12) + testFindByName(t, st, "bar2", 13) + testFindByName(t, st, "bogus", 0) + + testFindByID(t, st, 0, "") + testFindByID(t, st, 1, "$ion") + testFindByID(t, st, 9, "$ion_shared_symbol_table") + testFindByID(t, st, 10, "foo") + testFindByID(t, st, 11, "bar") + testFindByID(t, st, 12, "foo2") + testFindByID(t, st, 13, "bar2") + testFindByID(t, st, 14, "") + + testString(t, st, `$ion_symbol_table::{imports:[{name:"shared",version:1,max_id:2}],symbols:["foo2","bar2"]}`) +} + +func TestSymbolTableBuilder(t *testing.T) { + b := NewSymbolTableBuilder() + + id, ok := b.Add("name") + if ok { + t.Error("Add(name) returned true") + } + if id != 4 { + t.Errorf("Add(name) returned %v", id) + } + + id, ok = b.Add("foo") + if !ok { + t.Error("Add(foo) returned false") + } + if id != 10 { + t.Errorf("Add(foo) returned %v", id) + } + + id, ok = b.Add("foo") + if ok { + t.Error("Second Add(foo) returned true") + } + if id != 10 { + t.Errorf("Second Add(foo) returned %v", id) + } + + st := b.Build() + if st.MaxID() != 10 { + t.Errorf("maxid returned %v", st.MaxID()) + } + + testFindByName(t, st, "$ion", 1) + testFindByName(t, st, "foo", 10) + testFindByName(t, st, "bogus", 0) + + testFindByID(t, st, 1, "$ion") + testFindByID(t, st, 10, "foo") + testFindByID(t, st, 11, "") +} + +func testFindByName(t *testing.T, st SymbolTable, sym string, expected uint64) { + t.Run("FindByName("+sym+")", func(t *testing.T) { + actual, ok := st.FindByName(sym) + if expected == 0 { + if ok { + t.Fatalf("unexpectedly found: %v", actual) + } + } else { + if !ok { + t.Fatal("unexpectedly not found") + } + if actual != expected { + t.Errorf("expected %v, got %v", expected, actual) + } + } + }) +} + +func testFindByID(t *testing.T, st SymbolTable, id uint64, expected string) { + t.Run(fmt.Sprintf("FindByID(%v)", id), func(t *testing.T) { + actual, ok := st.FindByID(id) + if expected == "" { + if ok { + t.Fatalf("unexpectedly found: %v", actual) + } + } else { + if !ok { + t.Fatal("unexpectedly not found") + } + if actual != expected { + t.Errorf("expected %v, got %v", expected, actual) + } + } + }) +} + +func testString(t *testing.T, st SymbolTable, expected string) { + t.Run("String()", func(t *testing.T) { + actual := st.String() + if actual != expected { + t.Errorf("expected %v, got %v", expected, actual) + } + }) +} diff --git a/ion/textreader.go b/ion/textreader.go new file mode 100644 index 00000000..d9cf4077 --- /dev/null +++ b/ion/textreader.go @@ -0,0 +1,658 @@ +package ion + +import ( + "bufio" + "encoding/base64" + "fmt" + "math" + "strconv" +) + +// trs is the state of the text reader. +type trs uint8 + +const ( + trsDone trs = iota + trsBeforeFieldName + trsBeforeTypeAnnotations + trsBeforeContainer + trsAfterValue +) + +func (s trs) String() string { + switch s { + case trsDone: + return "" + case trsBeforeFieldName: + return "" + case trsBeforeTypeAnnotations: + return "" + case trsBeforeContainer: + return "" + case trsAfterValue: + return "" + default: + return strconv.Itoa(int(s)) + } +} + +// A textReader is a Reader that reads text Ion. +type textReader struct { + reader + + tok tokenizer + state trs +} + +func newTextReaderBuf(in *bufio.Reader) Reader { + return &textReader{ + tok: tokenizer{ + in: in, + }, + state: trsBeforeTypeAnnotations, + } +} + +// SymbolTable returns the current symbol table. +func (t *textReader) SymbolTable() SymbolTable { + // TODO: Include me if present in the input stream? + return nil +} + +// Next moves the reader to the next value. +func (t *textReader) Next() bool { + if t.state == trsDone || t.eof { + return false + } + + // If we haven't fully read the current value, skip over it. + err := t.finishValue() + if err != nil { + t.explode(err) + return false + } + + t.clear() + + // Loop until we've consumed enough tokens to know what the next value is. + for { + if err := t.tok.Next(); err != nil { + t.explode(err) + return false + } + + var done bool + var err error + + switch t.state { + case trsAfterValue: + done, err = t.nextAfterValue() + case trsBeforeFieldName: + done, err = t.nextBeforeFieldName() + case trsBeforeTypeAnnotations: + done, err = t.nextBeforeTypeAnnotations() + default: + panic(fmt.Sprintf("unexpected state: %v", t.state)) + } + if err != nil { + t.explode(err) + return false + } + + if done { + // We're done reading tokens. If we hit the end of the current sequence, + // return false. Otherwise, we've got a value for the caller. + return !t.eof + } + } +} + +// NextAfterValue moves to the next value when we're in the +// AfterValue state. +func (t *textReader) nextAfterValue() (bool, error) { + tok := t.tok.Token() + switch tok { + case tokenComma: + // There's another value coming; eat the comma and move to the + // appropriate next state. + switch t.ctx.peek() { + case ctxInStruct: + t.state = trsBeforeFieldName + case ctxInList: + t.state = trsBeforeTypeAnnotations + default: + panic(fmt.Sprintf("unexpected context: %v", t.ctx.peek())) + } + return false, nil + + case tokenCloseBrace: + // No more values in this struct. + if t.ctx.peek() == ctxInStruct { + t.eof = true + return true, nil + } + return false, &UnexpectedTokenError{"}", t.tok.Pos() - 1} + + case tokenCloseBracket: + // No more values in this list. + if t.ctx.peek() == ctxInList { + t.eof = true + return true, nil + } + return false, &UnexpectedTokenError{"]", t.tok.Pos() - 1} + + default: + return false, &UnexpectedTokenError{tok.String(), t.tok.Pos() - 1} + } +} + +// NextBeforeFieldName moves to the next value when we're in the +// BeforeFieldName state. +func (t *textReader) nextBeforeFieldName() (bool, error) { + tok := t.tok.Token() + switch tok { + case tokenCloseBrace: + // No more values in this struct. + t.eof = true + return true, nil + + case tokenSymbol, tokenSymbolQuoted, tokenString, tokenLongString: + // Read the field name. + val, err := t.tok.ReadValue(tok) + if err != nil { + return false, err + } + if tok == tokenSymbol { + if err := t.verifyUnquotedSymbol(val, "field name"); err != nil { + return false, err + } + } + + // Skip over the following colon. + if err = t.tok.Next(); err != nil { + return false, err + } + if tok = t.tok.Token(); tok != tokenColon { + return false, &UnexpectedTokenError{tok.String(), t.tok.Pos() - 1} + } + + t.fieldName = val + t.state = trsBeforeTypeAnnotations + + return false, nil + + default: + return false, &UnexpectedTokenError{tok.String(), t.tok.Pos() - 1} + } +} + +// NextBeforeTypeAnnotations moves to the next value when we're in the +// BeforeTypeAnnotations state. +func (t *textReader) nextBeforeTypeAnnotations() (bool, error) { + tok := t.tok.Token() + switch tok { + case tokenEOF: + if t.ctx.peek() == ctxAtTopLevel { + t.eof = true + return true, nil + } + return false, &UnexpectedEOFError{t.tok.Pos() - 1} + + case tokenSymbolOperator, tokenDot: + if t.ctx.peek() != ctxInSexp { + // Operators can only appear inside an sexp. + return false, &UnexpectedTokenError{tok.String(), t.tok.Pos() - 1} + } + fallthrough + + case tokenSymbol, tokenSymbolQuoted: + val, err := t.tok.ReadValue(tok) + if err != nil { + return false, err + } + + ok, ws, err := t.tok.SkipDoubleColon() + if err != nil { + return false, err + } + + if ok { + // val was an annotation; remember it and keep going. + if tok == tokenSymbol { + if err := t.verifyUnquotedSymbol(val, "annotation"); err != nil { + return false, err + } + } + t.annotations = append(t.annotations, val) + return false, nil + } + + // val was a legit symbol value. + if err := t.onSymbol(val, tok, ws); err != nil { + return false, err + } + return true, nil + + case tokenString, tokenLongString: + val, err := t.tok.ReadValue(tok) + if err != nil { + return false, err + } + + t.state = t.stateAfterValue() + t.valueType = StringType + t.value = val + return true, nil + + case tokenBinary, tokenHex, tokenNumber, tokenFloatInf, tokenFloatMinusInf: + if err := t.onNumber(tok); err != nil { + return false, err + } + return true, nil + + case tokenTimestamp: + if err := t.onTimestamp(); err != nil { + return false, err + } + return true, nil + + case tokenOpenDoubleBrace: + if err := t.onLob(); err != nil { + return false, err + } + return true, nil + + case tokenOpenBrace: + t.state = trsBeforeContainer + t.valueType = StructType + t.value = StructType + return true, nil + + case tokenOpenBracket: + t.state = trsBeforeContainer + t.valueType = ListType + t.value = ListType + return true, nil + + case tokenOpenParen: + t.state = trsBeforeContainer + t.valueType = SexpType + t.value = SexpType + return true, nil + + case tokenCloseBracket: + // No more values in this list. + if t.ctx.peek() == ctxInList { + t.eof = true + return true, nil + } + return false, &UnexpectedTokenError{"]", t.tok.Pos() - 1} + + case tokenCloseParen: + // No more values in this sexp. + if t.ctx.peek() == ctxInSexp { + t.eof = true + return true, nil + } + return false, &UnexpectedTokenError{")", t.tok.Pos() - 1} + + default: + return false, &UnexpectedTokenError{tok.String(), t.tok.Pos() - 1} + } +} + +// StepIn steps in to a container. +func (t *textReader) StepIn() error { + if t.err != nil { + return t.err + } + if t.state != trsBeforeContainer { + return &UsageError{"Reader.StepIn", fmt.Sprintf("cannot step in to a %v", t.valueType)} + } + + ctx := containerTypeToCtx(t.valueType) + t.ctx.push(ctx) + + if ctx == ctxInStruct { + t.state = trsBeforeFieldName + } else { + t.state = trsBeforeTypeAnnotations + } + + t.clear() + + t.tok.SetFinished() + return nil +} + +// StepOut steps out of a container. +func (t *textReader) StepOut() error { + if t.err != nil { + return t.err + } + + ctx := t.ctx.peek() + if ctx == ctxAtTopLevel { + return &UsageError{"Reader.StepOut", "cannot step out of top-level datagram"} + } + ctype := ctxToContainerType(ctx) + + // Finish off whatever value *inside* the container that we're currently reading. + _, err := t.tok.FinishValue() + if err != nil { + t.explode(err) + return err + } + + // If we haven't seen the end of the container yet, skip values until we find it. + if !t.eof { + if err := t.tok.SkipContainerContents(ctype); err != nil { + t.explode(err) + return err + } + } + + t.ctx.pop() + t.state = t.stateAfterValue() + t.clear() + t.eof = false + + return nil +} + +// VerifyUnquotedSymbol checks for certain 'special' values that are returned from +// the tokenizer as symbols but cannot be used as field names or annotations. +func (t *textReader) verifyUnquotedSymbol(val string, ctx string) error { + switch val { + case "null", "true", "false", "nan": + return &SyntaxError{fmt.Sprintf("unquoted keyword '%v' as %v", val, ctx), t.tok.Pos() - 1} + } + return nil +} + +// OnSymbol handles finding a symbol-token value. +func (t *textReader) onSymbol(val string, tok token, ws bool) error { + valueType := SymbolType + var value interface{} = val + + if tok == tokenSymbol { + switch val { + case "null": + vt, err := t.onNull(ws) + if err != nil { + return err + } + valueType = vt + value = nil + + case "true": + valueType = BoolType + value = true + + case "false": + valueType = BoolType + value = false + + case "nan": + valueType = FloatType + value = math.NaN() + } + } + + t.state = t.stateAfterValue() + t.valueType = valueType + t.value = value + + return nil +} + +// OnNull handles finding a null token. +func (t *textReader) onNull(ws bool) (Type, error) { + if !ws { + ok, err := t.tok.SkipDot() + if err != nil { + return NoType, err + } + if ok { + return t.readNullType() + } + } + return NullType, nil +} + +// readNullType reads the null.{this} type symbol. +func (t *textReader) readNullType() (Type, error) { + if err := t.tok.Next(); err != nil { + return NoType, err + } + if t.tok.Token() != tokenSymbol { + msg := fmt.Sprintf("invalid symbol null.%v", t.tok.Token()) + return NoType, &SyntaxError{msg, t.tok.Pos() - 1} + } + + val, err := t.tok.ReadValue(tokenSymbol) + if err != nil { + return NoType, err + } + + switch val { + case "null": + return NullType, nil + case "bool": + return BoolType, nil + case "int": + return IntType, nil + case "float": + return FloatType, nil + case "decimal": + return DecimalType, nil + case "timestamp": + return TimestampType, nil + case "symbol": + return SymbolType, nil + case "string": + return StringType, nil + case "blob": + return BlobType, nil + case "clob": + return ClobType, nil + case "list": + return ListType, nil + case "struct": + return StructType, nil + case "sexp": + return SexpType, nil + default: + msg := fmt.Sprintf("invalid symbol null.%v", t.tok.Token()) + return NoType, &SyntaxError{msg, t.tok.Pos() - 1} + } +} + +// OnNumber handles finding a number token. +func (t *textReader) onNumber(tok token) error { + var valueType Type + var value interface{} + + switch tok { + case tokenBinary: + val, err := t.tok.ReadValue(tok) + if err != nil { + return err + } + + valueType = IntType + value, err = parseInt(val, 2) + if err != nil { + return err + } + + case tokenHex: + val, err := t.tok.ReadValue(tok) + if err != nil { + return err + } + + valueType = IntType + value, err = parseInt(val, 16) + if err != nil { + return err + } + + case tokenNumber: + val, tt, err := t.tok.ReadNumber() + if err != nil { + return err + } + + valueType = tt + + switch tt { + case IntType: + value, err = parseInt(val, 10) + case FloatType: + value, err = parseFloat(val) + case DecimalType: + value, err = parseDecimal(val) + default: + panic(fmt.Sprintf("unexpected type %v", tt)) + } + + if err != nil { + return err + } + + case tokenFloatInf: + valueType = FloatType + value = math.Inf(1) + + case tokenFloatMinusInf: + valueType = FloatType + value = math.Inf(-1) + + default: + panic(fmt.Sprintf("unexpected token type %v", tok)) + } + + t.state = t.stateAfterValue() + t.valueType = valueType + t.value = value + + return nil +} + +// OnTimestamp handles finding a timestamp token. +func (t *textReader) onTimestamp() error { + val, err := t.tok.ReadValue(tokenTimestamp) + if err != nil { + return err + } + + value, err := parseTimestamp(val) + if err != nil { + return err + } + + t.state = t.stateAfterValue() + t.valueType = TimestampType + t.value = value + + return nil +} + +// OnLob handles finding a [bc]lob token. +func (t *textReader) onLob() error { + c, err := t.tok.SkipLobWhitespace() + if err != nil { + return err + } + + var ( + valType Type + val []byte + ) + + if c == '"' { + // Short clob. + valType = ClobType + + str, err := t.tok.ReadShortClob() + if err != nil { + return err + } + + val = []byte(str) + + } else if c == '\'' { + // Long clob. + ok, err := t.tok.IsTripleQuote() + if err != nil { + return err + } + if !ok { + return t.tok.invalidChar(c) + } + + valType = ClobType + + str, err := t.tok.ReadLongClob() + if err != nil { + return err + } + + val = []byte(str) + + } else { + // Normal blob. + valType = BlobType + t.tok.unread(c) + + b64, err := t.tok.ReadBlob() + if err != nil { + return err + } + + val, err = base64.StdEncoding.DecodeString(b64) + if err != nil { + return err + } + } + + t.state = t.stateAfterValue() + t.valueType = valType + t.value = val + + return nil +} + +// FinishValue finishes reading the current value, if there is one. +func (t *textReader) finishValue() error { + ok, err := t.tok.FinishValue() + if err != nil { + return err + } + + if ok { + t.state = t.stateAfterValue() + } + + return nil +} + +func (t *textReader) stateAfterValue() trs { + ctx := t.ctx.peek() + switch ctx { + case ctxInList, ctxInStruct: + return trsAfterValue + case ctxInSexp, ctxAtTopLevel: + return trsBeforeTypeAnnotations + default: + panic(fmt.Sprintf("invalid ctx %v", ctx)) + } +} + +// Explode explodes the reader state when something unexpected +// happens and further calls to Next are a bad idea. +func (t *textReader) explode(err error) { + t.state = trsDone + t.err = err +} diff --git a/ion/textreader_test.go b/ion/textreader_test.go new file mode 100644 index 00000000..68a84ba8 --- /dev/null +++ b/ion/textreader_test.go @@ -0,0 +1,788 @@ +package ion + +import ( + "bytes" + "math" + "math/big" + "testing" + "time" +) + +func TestIgnoreValues(t *testing.T) { + r := NewReaderStr("(skip ++ me / please) {skip: me, please: 0}\n[skip, me, please]\nfoo") + + _next(t, r, SexpType) + _next(t, r, StructType) + _next(t, r, ListType) + + _symbol(t, r, "foo") + _eof(t, r) +} + +func TestReadSexps(t *testing.T) { + test := func(str string, f containerhandler) { + t.Run(str, func(t *testing.T) { + r := NewReaderStr(str) + _sexp(t, r, f) + _eof(t, r) + }) + } + + test("(\t)", func(t *testing.T, r Reader) { + if r.Next() { + t.Errorf("next returned true") + } + if r.Err() != nil { + t.Fatal(r.Err()) + } + }) + + test("(foo)", func(t *testing.T, r Reader) { + _symbol(t, r, "foo") + }) + + test("(foo bar baz :: boop)", func(t *testing.T, r Reader) { + _symbol(t, r, "foo") + _symbol(t, r, "bar") + _symbolAF(t, r, "", []string{"baz"}, "boop") + }) +} + +func TestStructs(t *testing.T) { + test := func(str string, f containerhandler) { + t.Run(str, func(t *testing.T) { + r := NewReaderStr(str) + _struct(t, r, f) + _eof(t, r) + }) + } + + test("{\r\n}", func(t *testing.T, r Reader) { + _eof(t, r) + }) + + test("{foo : bar :: baz}", func(t *testing.T, r Reader) { + _symbolAF(t, r, "foo", []string{"bar"}, "baz") + }) + + test("{foo: a, bar: b, baz: c}", func(t *testing.T, r Reader) { + _symbolAF(t, r, "foo", nil, "a") + _symbolAF(t, r, "bar", nil, "b") + _symbolAF(t, r, "baz", nil, "c") + }) +} + +func TestMultipleStructs(t *testing.T) { + r := NewReaderStr("{} {} {}") + + for i := 0; i < 3; i++ { + _struct(t, r, func(t *testing.T, r Reader) { + _eof(t, r) + }) + } + + _eof(t, r) +} + +func TestNullStructs(t *testing.T) { + r := NewReaderStr("null.struct 'null'::{foo:bar}") + + _null(t, r, StructType) + _nextAF(t, r, StructType, "", []string{"null"}) + _eof(t, r) +} + +func TestLists(t *testing.T) { + test := func(str string, f containerhandler) { + t.Run(str, func(t *testing.T) { + r := NewReaderStr(str) + _list(t, r, f) + _eof(t, r) + }) + } + + test("[ ]", func(t *testing.T, r Reader) { + _eof(t, r) + }) + + test("[foo]", func(t *testing.T, r Reader) { + _symbol(t, r, "foo") + _eof(t, r) + }) + + test("[foo, bar, baz::boop]", func(t *testing.T, r Reader) { + _symbol(t, r, "foo") + _symbol(t, r, "bar") + _symbolAF(t, r, "", []string{"baz"}, "boop") + _eof(t, r) + }) +} + +func TestReadNestedLists(t *testing.T) { + empty := func(t *testing.T, r Reader) { + _eof(t, r) + } + + r := NewReaderStr("[[], [[]]]") + + _list(t, r, func(t *testing.T, r Reader) { + _list(t, r, empty) + + _list(t, r, func(t *testing.T, r Reader) { + _list(t, r, empty) + }) + + _eof(t, r) + }) + + _eof(t, r) +} + +func TestClobs(t *testing.T) { + test := func(str string, eval []byte) { + t.Run(str, func(t *testing.T) { + r := NewReaderStr(str) + _next(t, r, ClobType) + + val, err := r.ByteValue() + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(val, eval) { + t.Errorf("expected %v, got %v", eval, val) + } + + _eof(t, r) + }) + } + + test("{{\"\"}}", []byte{}) + test("{{ \"hello world\" }}", []byte("hello world")) + test("{{'''hello world'''}}", []byte("hello world")) + test("{{'''hello'''\n'''world'''}}", []byte("helloworld")) +} + +func TestBlobs(t *testing.T) { + test := func(str string, eval []byte) { + t.Run(str, func(t *testing.T) { + r := NewReaderStr(str) + _next(t, r, BlobType) + + val, err := r.ByteValue() + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(val, eval) { + t.Errorf("expected %v, got %v", eval, val) + } + + _eof(t, r) + }) + } + + test("{{}}", []byte{}) + test("{{AA==}}", []byte{0}) + test("{{ SGVsbG8g\r\nV29ybGQ= }}", []byte("Hello World")) +} + +func TestTimestamps(t *testing.T) { + testA := func(str string, etas []string, eval time.Time) { + t.Run(str, func(t *testing.T) { + r := NewReaderStr(str) + _nextAF(t, r, TimestampType, "", etas) + + val, err := r.TimeValue() + if err != nil { + t.Fatal(err) + } + if !val.Equal(eval) { + t.Errorf("expected %v, got %v", eval, val) + } + + _eof(t, r) + }) + } + + test := func(str string, eval time.Time) { + testA(str, nil, eval) + } + + et := time.Date(2001, time.January, 1, 0, 0, 0, 0, time.UTC) + test("2001T", et) + test("2001-01T", et) + test("2001-01-01", et) + test("2001-01-01T", et) + test("2001-01-01T00:00Z", et) + test("2001-01-01T00:00:00Z", et) + test("2001-01-01T00:00:00.000Z", et) + test("2001-01-01T00:00:00.000+00:00", et) + test("2001-01-01T00:00:00.000000Z", et) + test("2001-01-01T00:00:00.000000000Z", et) + test("2001-01-01T00:00:00.000000000999Z", et) // We truncate, at least for now. + + testA("foo::'bar'::2001-01-01T00:00:00.000Z", []string{"foo", "bar"}, et) +} + +func TestDecimals(t *testing.T) { + testA := func(str string, etas []string, eval string) { + t.Run(str, func(t *testing.T) { + ee := MustParseDecimal(eval) + + r := NewReaderStr(str) + _nextAF(t, r, DecimalType, "", etas) + + val, err := r.DecimalValue() + if err != nil { + t.Fatal(err) + } + if !ee.Equal(val) { + t.Errorf("expected %v, got %v", ee, val) + } + + _eof(t, r) + }) + } + + test := func(str string, eval string) { + testA(str, nil, eval) + } + + test("123.", "123") + test("123.0", "123") + test("123.456", "123.456") + test("123d2", "12300") + test("123d+2", "12300") + test("123d-2", "1.23") + + testA(" foo :: 'bar' :: 123. ", []string{"foo", "bar"}, "123") +} + +func TestFloats(t *testing.T) { + testA := func(str string, etas []string, eval float64) { + t.Run(str, func(t *testing.T) { + r := NewReaderStr(str) + _floatAF(t, r, "", etas, eval) + _eof(t, r) + }) + } + + test := func(str string, eval float64) { + testA(str, nil, eval) + } + + test("1e100\n", 1e100) + test("1.2e+0", 1.2) + test("-123.456e-78", -123.456e-78) + test("+inf", math.Inf(1)) + test("-inf", math.Inf(-1)) + + testA("foo::'bar'::1e100", []string{"foo", "bar"}, 1e100) +} + +func TestInts(t *testing.T) { + test := func(str string, f func(*testing.T, Reader)) { + t.Run(str, func(t *testing.T) { + r := NewReaderStr(str) + _next(t, r, IntType) + + f(t, r) + + _eof(t, r) + }) + } + + test("null.int", func(t *testing.T, r Reader) { + if !r.IsNull() { + t.Fatal("expected isnull=true, got false") + } + }) + + testInt := func(str string, eval int) { + test(str, func(t *testing.T, r Reader) { + val, err := r.IntValue() + if err != nil { + t.Fatal(err) + } + if val != eval { + t.Errorf("expected %v, got %v", eval, val) + } + }) + } + + testInt("0", 0) + testInt("12_345", 12345) + testInt("-1_2_3_4_5", -12345) + testInt("0b00_0101", 5) + testInt("-0b00_0101", -5) + testInt("0x01_02_0e_0F", 0x01020e0f) + testInt("-0x0102_0e0F", -0x01020e0f) + + testInt64 := func(str string, eval int64) { + test(str, func(t *testing.T, r Reader) { + val, err := r.Int64Value() + if err != nil { + t.Fatal(err) + } + if val != eval { + t.Errorf("expected %v, got %v", eval, val) + } + }) + } + + testInt64("0x123_FFFF_FFFF", 0x123FFFFFFFF) + testInt64("-0x123_FFFF_FFFF", -0x123FFFFFFFF) + + testBigInt := func(str string, estr string) { + test(str, func(t *testing.T, r Reader) { + val, err := r.BigIntValue() + if err != nil { + t.Fatal(err) + } + + eval, _ := (&big.Int{}).SetString(estr, 0) + if eval.Cmp(val) != 0 { + t.Errorf("expected %v, got %v", eval, val) + } + }) + } + + testBigInt("0xEFFF_FFFF_FFFF_FFFF", "0xEFFFFFFFFFFFFFFF") + testBigInt("0xFFFF_FFFF_FFFF_FFFF", "0xFFFFFFFFFFFFFFFF") + testBigInt("-0x1_FFFF_FFFF_FFFF_FFFF", "-0x1FFFFFFFFFFFFFFFF") +} + +func TestStrings(t *testing.T) { + r := NewReaderStr(`foo::"bar" "baz" 'a'::'b'::'''beep''' '''boop''' null.string`) + + _stringAF(t, r, "", []string{"foo"}, "bar") + _string(t, r, "baz") + _stringAF(t, r, "", []string{"a", "b"}, "beepboop") + _null(t, r, StringType) + + _eof(t, r) +} + +func TestSymbols(t *testing.T) { + r := NewReaderStr("'null'::foo bar a::b::'baz' null.symbol") + + _symbolAF(t, r, "", []string{"null"}, "foo") + _symbol(t, r, "bar") + _symbolAF(t, r, "", []string{"a", "b"}, "baz") + _null(t, r, SymbolType) + + _eof(t, r) +} + +func TestSpecialSymbols(t *testing.T) { + r := NewReaderStr("null\nnull.struct\ntrue\nfalse\nnan") + + _null(t, r, NullType) + _null(t, r, StructType) + + _bool(t, r, true) + _bool(t, r, false) + _float(t, r, math.NaN()) + _eof(t, r) +} + +func TestOperators(t *testing.T) { + r := NewReaderStr("(a*(b+c))") + + _sexp(t, r, func(t *testing.T, r Reader) { + _symbol(t, r, "a") + _symbol(t, r, "*") + _sexp(t, r, func(t *testing.T, r Reader) { + _symbol(t, r, "b") + _symbol(t, r, "+") + _symbol(t, r, "c") + _eof(t, r) + }) + _eof(t, r) + }) +} + +func TestTopLevelOperators(t *testing.T) { + r := NewReaderStr("a + b") + + _symbol(t, r, "a") + + if r.Next() { + t.Errorf("next returned true") + } + if r.Err() == nil { + t.Error("no error") + } +} + +func TestTrsToString(t *testing.T) { + for i := trsDone; i <= trsAfterValue+1; i++ { + str := i.String() + if str == "" { + t.Errorf("expected a non-empty string for trs %v", uint8(i)) + } + } +} + +type containerhandler func(t *testing.T, r Reader) + +func _sexp(t *testing.T, r Reader, f containerhandler) { + _sexpAF(t, r, "", nil, f) +} + +func _sexpAF(t *testing.T, r Reader, efn string, etas []string, f containerhandler) { + _containerAF(t, r, SexpType, efn, etas, f) +} + +func _struct(t *testing.T, r Reader, f containerhandler) { + _structAF(t, r, "", nil, f) +} + +func _structAF(t *testing.T, r Reader, efn string, etas []string, f containerhandler) { + _containerAF(t, r, StructType, efn, etas, f) +} + +func _list(t *testing.T, r Reader, f containerhandler) { + _listAF(t, r, "", nil, f) +} + +func _listAF(t *testing.T, r Reader, efn string, etas []string, f containerhandler) { + _containerAF(t, r, ListType, efn, etas, f) +} + +func _containerAF(t *testing.T, r Reader, et Type, efn string, etas []string, f containerhandler) { + _nextAF(t, r, et, efn, etas) + if r.IsNull() { + t.Fatalf("expected %v, got null.%v", et, et) + } + + if err := r.StepIn(); err != nil { + t.Fatal(err) + } + + f(t, r) + + if err := r.StepOut(); err != nil { + t.Fatal(err) + } +} + +func _int(t *testing.T, r Reader, eval int) { + _intAF(t, r, "", nil, eval) +} + +func _intAF(t *testing.T, r Reader, efn string, etas []string, eval int) { + _nextAF(t, r, IntType, efn, etas) + if r.IsNull() { + t.Fatalf("expected %v, got null.int", eval) + } + + size, err := r.IntSize() + if err != nil { + t.Fatal(err) + } + if size != Int32 { + t.Errorf("expected size=Int32, got %v", size) + } + + val, err := r.IntValue() + if err != nil { + t.Fatal(err) + } + if val != eval { + t.Errorf("expected %v, got %v", eval, val) + } +} + +func _int64(t *testing.T, r Reader, eval int64) { + _int64AF(t, r, "", nil, eval) +} + +func _int64AF(t *testing.T, r Reader, efn string, etas []string, eval int64) { + _nextAF(t, r, IntType, efn, etas) + if r.IsNull() { + t.Fatalf("expected %v, got null.int", eval) + } + + size, err := r.IntSize() + if err != nil { + t.Fatal(err) + } + if size != Int64 { + t.Errorf("expected size=Int64, got %v", size) + } + + val, err := r.Int64Value() + if err != nil { + t.Fatal(err) + } + if val != eval { + t.Errorf("expected %v, got %v", eval, val) + } +} + +func _uint(t *testing.T, r Reader, eval uint64) { + _uintAF(t, r, "", nil, eval) +} + +func _uintAF(t *testing.T, r Reader, efn string, etas []string, eval uint64) { + _nextAF(t, r, IntType, efn, etas) + if r.IsNull() { + t.Fatalf("expected %v, got null.int", eval) + } + + size, err := r.IntSize() + if err != nil { + t.Fatal(err) + } + if size != Uint64 { + t.Errorf("expected size=Uint, got %v", size) + } + + val, err := r.Uint64Value() + if err != nil { + t.Fatal(err) + } + if val != eval { + t.Errorf("expected %v, got %v", eval, val) + } +} + +func _bigInt(t *testing.T, r Reader, eval *big.Int) { + _bigIntAF(t, r, "", nil, eval) +} + +func _bigIntAF(t *testing.T, r Reader, efn string, etas []string, eval *big.Int) { + _nextAF(t, r, IntType, efn, etas) + if r.IsNull() { + t.Fatalf("expected %v, got null.int", eval) + } + + size, err := r.IntSize() + if err != nil { + t.Fatal(err) + } + if size != BigInt { + t.Errorf("expected size=BigInt, got %v", size) + } + + val, err := r.BigIntValue() + if err != nil { + t.Fatal(err) + } + if val.Cmp(eval) != 0 { + t.Errorf("expected %v, got %v", eval, val) + } +} + +func _float(t *testing.T, r Reader, eval float64) { + _floatAF(t, r, "", nil, eval) +} + +func _floatAF(t *testing.T, r Reader, efn string, etas []string, eval float64) { + _nextAF(t, r, FloatType, efn, etas) + if r.IsNull() { + t.Fatalf("expected %v, got null.float", eval) + } + + val, err := r.FloatValue() + if err != nil { + t.Fatal(err) + } + + if math.IsNaN(eval) { + if !math.IsNaN(val) { + t.Errorf("expected %v, got %v", eval, val) + } + } else if eval != val { + t.Errorf("expected %v, got %v", eval, val) + } +} + +func _decimal(t *testing.T, r Reader, eval *Decimal) { + _decimalAF(t, r, "", nil, eval) +} + +func _decimalAF(t *testing.T, r Reader, efn string, etas []string, eval *Decimal) { + _nextAF(t, r, DecimalType, efn, etas) + if r.IsNull() { + t.Fatalf("expected %v, got null.decimal", eval) + } + + val, err := r.DecimalValue() + if err != nil { + t.Fatal(err) + } + + if !eval.Equal(val) { + t.Errorf("expected %v, got %v", eval, val) + } +} + +func _timestamp(t *testing.T, r Reader, eval time.Time) { + _timestampAF(t, r, "", nil, eval) +} + +func _timestampAF(t *testing.T, r Reader, efn string, etas []string, eval time.Time) { + _nextAF(t, r, TimestampType, efn, etas) + if r.IsNull() { + t.Fatalf("expected %v, got null.timestamp", eval) + } + + val, err := r.TimeValue() + if err != nil { + t.Fatal(err) + } + + if !val.Equal(eval) { + t.Errorf("expected %v, got %v", eval, val) + } +} + +func _string(t *testing.T, r Reader, eval string) { + _stringAF(t, r, "", nil, eval) +} + +func _stringAF(t *testing.T, r Reader, efn string, etas []string, eval string) { + _nextAF(t, r, StringType, efn, etas) + if r.IsNull() { + t.Fatalf("expected %v, got null.string", eval) + } + + val, err := r.StringValue() + if err != nil { + t.Fatal(err) + } + if val != eval { + t.Errorf("expected %v, got %v", eval, val) + } +} + +func _symbol(t *testing.T, r Reader, eval string) { + _symbolAF(t, r, "", nil, eval) +} + +func _symbolAF(t *testing.T, r Reader, efn string, etas []string, eval string) { + _nextAF(t, r, SymbolType, efn, etas) + if r.IsNull() { + t.Fatalf("expected %v, got null.symbol", eval) + } + + val, err := r.StringValue() + if err != nil { + t.Fatal(err) + } + if val != eval { + t.Errorf("expected %v, got %v", eval, val) + } +} + +func _bool(t *testing.T, r Reader, eval bool) { + _boolAF(t, r, "", nil, eval) +} + +func _boolAF(t *testing.T, r Reader, efn string, etas []string, eval bool) { + _nextAF(t, r, BoolType, efn, etas) + if r.IsNull() { + t.Fatalf("expected %v, got null.bool", eval) + } + + val, err := r.BoolValue() + if err != nil { + t.Fatal(err) + } + if val != eval { + t.Errorf("expected %v, got %v", eval, val) + } +} + +func _clob(t *testing.T, r Reader, eval []byte) { + _clobAF(t, r, "", nil, eval) +} + +func _clobAF(t *testing.T, r Reader, efn string, etas []string, eval []byte) { + _nextAF(t, r, ClobType, efn, etas) + if r.IsNull() { + t.Fatalf("expected %v, got null.clob", eval) + } + + val, err := r.ByteValue() + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(val, eval) { + t.Errorf("expected %v, got %v", eval, val) + } +} + +func _blob(t *testing.T, r Reader, eval []byte) { + _blobAF(t, r, "", nil, eval) +} + +func _blobAF(t *testing.T, r Reader, efn string, etas []string, eval []byte) { + _nextAF(t, r, BlobType, efn, etas) + if r.IsNull() { + t.Fatalf("expected %v, got null.blob", eval) + } + + val, err := r.ByteValue() + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(val, eval) { + t.Errorf("expected %v, got %v", eval, val) + } +} + +func _null(t *testing.T, r Reader, et Type) { + _nullAF(t, r, et, "", nil) +} + +func _nullAF(t *testing.T, r Reader, et Type, efn string, etas []string) { + _nextAF(t, r, et, efn, etas) + if !r.IsNull() { + t.Error("isnull returned false") + } +} + +func _next(t *testing.T, r Reader, et Type) { + _nextAF(t, r, et, "", nil) +} + +func _nextAF(t *testing.T, r Reader, et Type, efn string, etas []string) { + if !r.Next() { + t.Fatal(r.Err()) + } + if r.Type() != et { + t.Fatalf("expected %v, got %v", et, r.Type()) + } + + if efn != r.FieldName() { + t.Errorf("expected fieldname=%v, got %v", efn, r.FieldName()) + } + if !_strequals(etas, r.Annotations()) { + t.Errorf("expected type annotations=%v, got %v", etas, r.Annotations()) + } +} + +func _strequals(a, b []string) bool { + if len(a) != len(b) { + return false + } + + for i := 0; i < len(a); i++ { + if a[i] != b[i] { + return false + } + } + + return true +} + +func _eof(t *testing.T, r Reader) { + if r.Next() { + t.Fatal("next returned true") + } + if r.Err() != nil { + t.Fatal(r.Err()) + } +} diff --git a/ion/textutils.go b/ion/textutils.go new file mode 100644 index 00000000..e4a03de4 --- /dev/null +++ b/ion/textutils.go @@ -0,0 +1,382 @@ +package ion + +import ( + "fmt" + "io" + "math/big" + "strconv" + "strings" + "time" +) + +// Does this symbol need to be quoted in text form? +func symbolNeedsQuoting(sym string) bool { + switch sym { + case "", "null", "true", "false", "nan": + return true + } + + if !isIdentifierStart(int(sym[0])) { + return true + } + + for i := 1; i < len(sym); i++ { + if !isIdentifierPart(int(sym[i])) { + return true + } + } + + return false +} + +// Is this the text form of a symbol reference ($)? +func isSymbolRef(sym string) bool { + if len(sym) == 0 || sym[0] != '$' { + return false + } + + if len(sym) == 1 { + return false + } + + for i := 1; i < len(sym); i++ { + if !isDigit(int(sym[i])) { + return false + } + } + + return true +} + +// Is this a valid first character for an identifier? +func isIdentifierStart(c int) bool { + if c >= 'a' && c <= 'z' { + return true + } + if c >= 'A' && c <= 'Z' { + return true + } + if c == '_' || c == '$' { + return true + } + return false +} + +// Is this a valid character for later in an identifier? +func isIdentifierPart(c int) bool { + return isIdentifierStart(c) || isDigit(c) +} + +// Is this a valid hex digit? +func isHexDigit(c int) bool { + if isDigit(c) { + return true + } + if c >= 'a' && c <= 'f' { + return true + } + if c >= 'A' && c <= 'F' { + return true + } + return false +} + +// Is this a digit? +func isDigit(c int) bool { + return c >= '0' && c <= '9' +} + +// Is this a valid part of an operator symbol? +func isOperatorChar(c int) bool { + switch c { + case '!', '#', '%', '&', '*', '+', '-', '.', '/', ';', '<', '=', + '>', '?', '@', '^', '`', '|', '~': + return true + default: + return false + } +} + +// Does this character mark the end of a normal (unquoted) value? Does +// *not* check for the start of a comment, because that requires two +// characters. Use tokenizer.isStopChar(c) or check for it yourself. +func isStopChar(c int) bool { + switch c { + case -1, '{', '}', '[', ']', '(', ')', ',', '"', '\'', + ' ', '\t', '\n', '\r': + return true + default: + return false + } +} + +// Is this character whitespace? +func isWhitespace(c int) bool { + switch c { + case ' ', '\t', '\n', '\r': + return true + } + return false +} + +// Formats a float64 in Ion text style. +func formatFloat(val float64) string { + str := strconv.FormatFloat(val, 'e', -1, 64) + + // Ion uses lower case for special values. + switch str { + case "NaN": + return "nan" + case "+Inf": + return "+inf" + case "-Inf": + return "-inf" + } + + idx := strings.Index(str, "e") + if idx < 0 { + // We need to add an 'e' or it will get interpreted as an Ion decimal. + str += "e0" + } else if idx+2 < len(str) && str[idx+2] == '0' { + // FormatFloat returns exponents with a leading ±0 in some cases; strip it. + str = str[:idx+2] + str[idx+3:] + } + + return str +} + +// Write the given symbol out, quoting and encoding if necessary. +func writeSymbol(sym string, out io.Writer) error { + if symbolNeedsQuoting(sym) { + if err := writeRawChar('\'', out); err != nil { + return err + } + if err := writeEscapedSymbol(sym, out); err != nil { + return err + } + return writeRawChar('\'', out) + } + return writeRawString(sym, out) +} + +// Write the given symbol out, escaping any characters that need escaping. +func writeEscapedSymbol(sym string, out io.Writer) error { + for i := 0; i < len(sym); i++ { + c := sym[i] + if c < 32 || c == '\\' || c == '\'' { + if err := writeEscapedChar(c, out); err != nil { + return err + } + } else { + if err := writeRawChar(c, out); err != nil { + return err + } + } + } + return nil +} + +// Write the given string out, escaping any characters that need escaping. +func writeEscapedString(str string, out io.Writer) error { + for i := 0; i < len(str); i++ { + c := str[i] + if c < 32 || c == '\\' || c == '"' { + if err := writeEscapedChar(c, out); err != nil { + return err + } + } else { + if err := writeRawChar(c, out); err != nil { + return err + } + } + } + return nil +} + +// Write out the given character in escaped form. +func writeEscapedChar(c byte, out io.Writer) error { + switch c { + case 0: + return writeRawString("\\0", out) + case '\a': + return writeRawString("\\a", out) + case '\b': + return writeRawString("\\b", out) + case '\t': + return writeRawString("\\t", out) + case '\n': + return writeRawString("\\n", out) + case '\f': + return writeRawString("\\f", out) + case '\r': + return writeRawString("\\r", out) + case '\v': + return writeRawString("\\v", out) + case '\'': + return writeRawString("\\'", out) + case '"': + return writeRawString("\\\"", out) + case '\\': + return writeRawString("\\\\", out) + default: + buf := []byte{'\\', 'x', hexChars[(c>>4)&0xF], hexChars[c&0xF]} + return writeRawChars(buf, out) + } +} + +// Write out the given raw string. +func writeRawString(s string, out io.Writer) error { + _, err := out.Write([]byte(s)) + return err +} + +// Write out the given raw character sequence. +func writeRawChars(cs []byte, out io.Writer) error { + _, err := out.Write(cs) + return err +} + +// Write out the given raw character. +func writeRawChar(c byte, out io.Writer) error { + _, err := out.Write([]byte{c}) + return err +} + +func parseFloat(str string) (float64, error) { + val, err := strconv.ParseFloat(str, 64) + if err != nil { + if ne, ok := err.(*strconv.NumError); ok { + if ne.Err == strconv.ErrRange { + // Ignore me, val will be +-inf which is fine. + return val, nil + } + } + } + return val, err +} + +func parseDecimal(str string) (*Decimal, error) { + return ParseDecimal(str) +} + +func parseInt(str string, radix int) (interface{}, error) { + digits := str + + switch radix { + case 10: + // All set. + + case 2, 16: + neg := false + if digits[0] == '-' { + neg = true + digits = digits[1:] + } + + // Skip over the '0x' prefix. + digits = digits[2:] + if neg { + digits = "-" + digits + } + + default: + panic("unsupported radix") + } + + i, err := strconv.ParseInt(digits, radix, 64) + if err == nil { + return i, nil + } + if err.(*strconv.NumError).Err != strconv.ErrRange { + return nil, err + } + + bi, ok := (&big.Int{}).SetString(digits, radix) + if !ok { + return nil, &strconv.NumError{ + Func: "ParseInt", + Num: str, + Err: strconv.ErrSyntax, + } + } + + return bi, nil +} + +func parseTimestamp(val string) (time.Time, error) { + if len(val) < 5 { + return invalidTimestamp(val) + } + + year, err := strconv.ParseInt(val[:4], 10, 32) + if err != nil { + return invalidTimestamp(val) + } + if len(val) == 5 && (val[4] == 't' || val[4] == 'T') { + // yyyyT + return time.Date(int(year), 1, 1, 0, 0, 0, 0, time.UTC), nil + } + if val[4] != '-' { + return invalidTimestamp(val) + } + + if len(val) < 8 { + return invalidTimestamp(val) + } + + month, err := strconv.ParseInt(val[5:7], 10, 32) + if err != nil { + return invalidTimestamp(val) + } + + if len(val) == 8 && (val[7] == 't' || val[7] == 'T') { + // yyyy-mmT + return time.Date(int(year), time.Month(month), 1, 0, 0, 0, 0, time.UTC), nil + } + if val[7] != '-' { + return invalidTimestamp(val) + } + + if len(val) < 10 { + return invalidTimestamp(val) + } + + day, err := strconv.ParseInt(val[8:10], 10, 32) + if err != nil { + return invalidTimestamp(val) + } + + if len(val) == 10 || (len(val) == 11 && (val[10] == 't' || val[10] == 'T')) { + // yyyy-mm-dd or yyyy-mm-ddT + return time.Date(int(year), time.Month(month), int(day), 0, 0, 0, 0, time.UTC), nil + } + if val[10] != 't' && val[10] != 'T' { + return invalidTimestamp(val) + } + + if len(val) < 17 { + return invalidTimestamp(val) + } + if val[16] != ':' { + return time.Parse("2006-01-02T15:04Z07:00", val) + } + + if len(val) > 19 && val[19] == '.' { + i := 20 + for i < len(val) && isDigit(int(val[i])) { + i++ + } + + if i >= 29 { + // Too much precision for a go Time. + // TODO: We should probably round instead of truncating? Ah well. + return time.Parse(time.RFC3339Nano, val[:29]+val[i:]) + } + } + + return time.Parse(time.RFC3339Nano, val) +} + +func invalidTimestamp(val string) (time.Time, error) { + return time.Time{}, fmt.Errorf("ion: invalid timestamp: %v", val) +} diff --git a/ion/textutils_test.go b/ion/textutils_test.go new file mode 100644 index 00000000..67144a82 --- /dev/null +++ b/ion/textutils_test.go @@ -0,0 +1,164 @@ +package ion + +import ( + "strings" + "testing" + "time" +) + +func TestParseTimestamp(t *testing.T) { + test := func(str string, eval string) { + t.Run(str, func(t *testing.T) { + val, err := parseTimestamp(str) + if err != nil { + t.Fatal(err) + } + + et, err := time.Parse(time.RFC3339Nano, eval) + if err != nil { + t.Fatal(err) + } + + if !val.Equal(et) { + t.Errorf("expected %v, got %v", eval, val) + } + }) + } + + test("1234T", "1234-01-01T00:00:00Z") + test("1234-05T", "1234-05-01T00:00:00Z") + test("1234-05-06", "1234-05-06T00:00:00Z") + test("1234-05-06T", "1234-05-06T00:00:00Z") + test("1234-05-06T07:08Z", "1234-05-06T07:08:00Z") + test("1234-05-06T07:08:09Z", "1234-05-06T07:08:09Z") + test("1234-05-06T07:08:09.100Z", "1234-05-06T07:08:09.100Z") + test("1234-05-06T07:08:09.100100Z", "1234-05-06T07:08:09.100100Z") + + test("1234-05-06T07:08+09:10", "1234-05-06T07:08:00+09:10") + test("1234-05-06T07:08:09-10:11", "1234-05-06T07:08:09-10:11") +} + +func TestWriteSymbol(t *testing.T) { + test := func(sym, expected string) { + t.Run(expected, func(t *testing.T) { + buf := strings.Builder{} + if err := writeSymbol(sym, &buf); err != nil { + t.Fatal(err) + } + actual := buf.String() + if actual != expected { + t.Errorf("expected \"%v\", got \"%v\"", expected, actual) + } + }) + } + + test("", "''") + test("null", "'null'") + test("null.null", "'null.null'") + + test("basic", "basic") + test("_basic_", "_basic_") + test("$basic$", "$basic$") + test("$123", "$123") + + test("123", "'123'") + test("abc'def", "'abc\\'def'") + test("abc\"def", "'abc\"def'") +} + +func TestSymbolNeedsQuoting(t *testing.T) { + test := func(sym string, expected bool) { + t.Run(sym, func(t *testing.T) { + actual := symbolNeedsQuoting(sym) + if actual != expected { + t.Errorf("expected %v, got %v", expected, actual) + } + }) + } + + test("", true) + test("null", true) + test("true", true) + test("false", true) + test("nan", true) + + test("basic", false) + test("_basic_", false) + test("basic$123", false) + test("$", false) + test("$basic", false) + test("$123", false) + + test("123", true) + test("abc.def", true) + test("abc,def", true) + test("abc:def", true) + test("abc{def", true) + test("abc}def", true) + test("abc[def", true) + test("abc]def", true) + test("abc'def", true) + test("abc\"def", true) +} + +func TestIsSymbolRef(t *testing.T) { + test := func(sym string, expected bool) { + t.Run(sym, func(t *testing.T) { + actual := isSymbolRef(sym) + if actual != expected { + t.Errorf("expected %v, got %v", expected, actual) + } + }) + } + + test("", false) + test("1", false) + test("a", false) + test("$", false) + test("$1", true) + test("$1234567890", true) + test("$a", false) + test("$1234a567890", false) +} + +func TestWriteEscapedSymbol(t *testing.T) { + test := func(sym, expected string) { + t.Run(expected, func(t *testing.T) { + buf := strings.Builder{} + if err := writeEscapedSymbol(sym, &buf); err != nil { + t.Fatal(err) + } + actual := buf.String() + if actual != expected { + t.Errorf("bad encoding of \"%v\": \"%v\"", + expected, actual) + } + }) + } + + test("basic", "basic") + test("\"basic\"", "\"basic\"") + test("o'clock", "o\\'clock") + test("c:\\", "c:\\\\") +} + +func TestWriteEscapedChar(t *testing.T) { + test := func(c byte, expected string) { + t.Run(expected, func(t *testing.T) { + buf := strings.Builder{} + if err := writeEscapedChar(c, &buf); err != nil { + t.Fatal(err) + } + actual := buf.String() + if actual != expected { + t.Errorf("bad encoding of '%v': \"%v\"", + expected, actual) + } + }) + } + + test(0, "\\0") + test('\n', "\\n") + test(1, "\\x01") + test('\xFF', "\\xFF") +} diff --git a/ion/textwriter.go b/ion/textwriter.go new file mode 100644 index 00000000..2ae82c1f --- /dev/null +++ b/ion/textwriter.go @@ -0,0 +1,359 @@ +package ion + +import ( + "encoding/base64" + "fmt" + "io" + "math/big" + "time" +) + +// TextWriterOpts defines a set of bit flag options for text writers. +type TextWriterOpts uint8 + +const ( + // TextWriterQuietFinish disables emiting a newline in Finish(). Convenient if you + // know you're only emiting one datagram; dangerous if there's a chance you're going + // to emit another datagram using the same Writer. + TextWriterQuietFinish TextWriterOpts = 1 +) + +// textWriter is a writer that writes human-readable text +type textWriter struct { + writer + needsSeparator bool + opts TextWriterOpts +} + +// NewTextWriter returns a new text writer. +func NewTextWriter(out io.Writer) Writer { + return NewTextWriterOpts(out, 0) +} + +// NewTextWriterOpts returns a new text writer with the given options. +func NewTextWriterOpts(out io.Writer, opts TextWriterOpts) Writer { + return &textWriter{ + writer: writer{ + out: out, + }, + opts: opts, + } +} + +// WriteNull writes an untyped null. +func (w *textWriter) WriteNull() error { + return w.writeValue("Writer.WriteNull", textNulls[NoType]) +} + +// WriteNullType writes a typed null. +func (w *textWriter) WriteNullType(t Type) error { + return w.writeValue("Writer.WriteNullType", textNulls[t]) +} + +// WriteBool writes a boolean value. +func (w *textWriter) WriteBool(val bool) error { + str := "false" + if val { + str = "true" + } + return w.writeValue("Writer.WriteBool", str) +} + +// WriteInt writes an integer value. +func (w *textWriter) WriteInt(val int64) error { + return w.writeValue("Writer.WriteInt", fmt.Sprintf("%d", val)) +} + +// WriteUint writes an unsigned integer value. +func (w *textWriter) WriteUint(val uint64) error { + return w.writeValue("Writer.WriteUint", fmt.Sprintf("%d", val)) +} + +// WriteBigInt writes a (big) integer value. +func (w *textWriter) WriteBigInt(val *big.Int) error { + return w.writeValue("Writer.WriteBigInt", val.String()) +} + +// WriteFloat writes a floating-point value. +func (w *textWriter) WriteFloat(val float64) error { + return w.writeValue("Writer.WriteFloat", formatFloat(val)) +} + +// WriteDecimal writes an arbitrary-precision decimal value. +func (w *textWriter) WriteDecimal(val *Decimal) error { + return w.writeValue("Writer.WriteDecimal", val.String()) +} + +// WriteTimestamp writes a timestamp. +func (w *textWriter) WriteTimestamp(val time.Time) error { + return w.writeValue("Writer.WriteTimestamp", val.Format(time.RFC3339Nano)) +} + +// WriteSymbol writes a symbol. +func (w *textWriter) WriteSymbol(val string) error { + if w.err != nil { + return w.err + } + if w.err = w.beginValue("Writer.WriteSymbol"); w.err != nil { + return w.err + } + + if w.err = writeSymbol(val, w.out); w.err != nil { + return w.err + } + + w.endValue() + return nil +} + +// WriteString writes a string. +func (w *textWriter) WriteString(val string) error { + if w.err != nil { + return w.err + } + if w.err = w.beginValue("Writer.WriteString"); w.err != nil { + return w.err + } + + if w.err = writeRawChar('"', w.out); w.err != nil { + return w.err + } + if w.err = writeEscapedString(val, w.out); w.err != nil { + return w.err + } + if w.err = writeRawChar('"', w.out); w.err != nil { + return w.err + } + + w.endValue() + return nil +} + +// WriteClob writes a clob. +func (w *textWriter) WriteClob(val []byte) error { + if w.err != nil { + return w.err + } + if w.err = w.beginValue("Writer.WriteBlob"); w.err != nil { + return w.err + } + + if w.err = writeRawString("{{\"", w.out); w.err != nil { + return w.err + } + for _, c := range val { + if c < 32 || c == '\\' || c == '"' || c > 0x7F { + if err := writeEscapedChar(c, w.out); err != nil { + return err + } + } else { + if err := writeRawChar(c, w.out); err != nil { + return err + } + } + } + if w.err = writeRawString("\"}}", w.out); w.err != nil { + return w.err + } + + w.endValue() + return nil +} + +// WriteBlob writes a blob. +func (w *textWriter) WriteBlob(val []byte) error { + if w.err != nil { + return w.err + } + if w.err = w.beginValue("Writer.WriteBlob"); w.err != nil { + return w.err + } + + if w.err = writeRawString("{{", w.out); w.err != nil { + return w.err + } + + enc := base64.NewEncoder(base64.StdEncoding, w.out) + enc.Write(val) + if w.err = enc.Close(); w.err != nil { + return w.err + } + + if w.err = writeRawString("}}", w.out); w.err != nil { + return w.err + } + + w.endValue() + return nil +} + +// BeginList begins writing a list. +func (w *textWriter) BeginList() error { + if w.err == nil { + w.err = w.begin("Writer.BeginList", ctxInList, '[') + } + return w.err +} + +// EndList finishes writing a list. +func (w *textWriter) EndList() error { + if w.err == nil { + w.err = w.end("Writer.EndList", ctxInList, ']') + } + return w.err +} + +// BeginSexp begins writing an s-expression. +func (w *textWriter) BeginSexp() error { + if w.err == nil { + w.err = w.begin("Writer.BeginSexp", ctxInSexp, '(') + } + return w.err +} + +// EndSexp finishes writing an s-expression. +func (w *textWriter) EndSexp() error { + if w.err == nil { + w.err = w.end("Writer.EndSexp", ctxInSexp, ')') + } + return w.err +} + +// BeginStruct begins writing a struct. +func (w *textWriter) BeginStruct() error { + if w.err == nil { + w.err = w.begin("Writer.BeginStruct", ctxInStruct, '{') + } + return w.err +} + +// EndStruct finishes writing a struct. +func (w *textWriter) EndStruct() error { + if w.err == nil { + w.err = w.end("Writer.EndStruct", ctxInStruct, '}') + } + return w.err +} + +// Finish finishes writing the current datagram. +func (w *textWriter) Finish() error { + if w.err != nil { + return w.err + } + if w.ctx.peek() != ctxAtTopLevel { + return &UsageError{"Writer.Finish", "not at top level"} + } + + if w.opts&TextWriterQuietFinish == 0 { + if w.err = writeRawChar('\n', w.out); w.err != nil { + return w.err + } + w.needsSeparator = false + } + + w.clear() + return nil +} + +// writeValue writes a stringified value to the output stream. +func (w *textWriter) writeValue(api string, val string) error { + if w.err != nil { + return w.err + } + if w.err = w.beginValue(api); w.err != nil { + return w.err + } + + if w.err = writeRawString(val, w.out); w.err != nil { + return w.err + } + + w.endValue() + return nil +} + +// beginValue begins the process of writing a value, by writing out +// a separator (if needed), field name (if in a struct), and type +// annotations (if any). +func (w *textWriter) beginValue(api string) error { + if w.needsSeparator { + var sep byte + switch w.ctx.peek() { + case ctxInStruct, ctxInList: + sep = ',' + case ctxInSexp: + sep = ' ' + default: + sep = '\n' + } + + if err := writeRawChar(sep, w.out); err != nil { + return err + } + } + + if w.inStruct() { + if w.fieldName == "" { + return &UsageError{api, "field name not set"} + } + name := w.fieldName + w.fieldName = "" + + if err := writeSymbol(name, w.out); err != nil { + return err + } + if err := writeRawChar(':', w.out); err != nil { + return err + } + } + + if len(w.annotations) > 0 { + as := w.annotations + w.annotations = nil + + for _, a := range as { + if err := writeSymbol(a, w.out); err != nil { + return err + } + if err := writeRawString("::", w.out); err != nil { + return err + } + } + } + + return nil +} + +// endValue finishes the process of writing a value. +func (w *textWriter) endValue() { + w.needsSeparator = true +} + +// begin starts writing a container of the given type. +func (w *textWriter) begin(api string, t ctx, c byte) error { + if err := w.beginValue(api); err != nil { + return err + } + + w.ctx.push(t) + w.needsSeparator = false + + return writeRawChar(c, w.out) +} + +// end finishes writing a container of the given type +func (w *textWriter) end(api string, t ctx, c byte) error { + if w.ctx.peek() != t { + return &UsageError{api, "not in that kind of container"} + } + + if err := writeRawChar(c, w.out); err != nil { + return err + } + + w.clear() + w.ctx.pop() + w.endValue() + + return nil +} diff --git a/ion/textwriter_test.go b/ion/textwriter_test.go new file mode 100644 index 00000000..b3124c32 --- /dev/null +++ b/ion/textwriter_test.go @@ -0,0 +1,351 @@ +package ion + +import ( + "math" + "math/big" + "strings" + "testing" + "time" +) + +func TestWriteTextTopLevelFieldName(t *testing.T) { + writeText(func(w Writer) { + if err := w.FieldName("foo"); err == nil { + t.Error("expected an error") + } + }) +} + +func TestWriteTextEmptyStruct(t *testing.T) { + testTextWriter(t, "{}", func(w Writer) { + if err := w.BeginStruct(); err != nil { + t.Fatal(err) + } + + if err := w.EndStruct(); err != nil { + t.Fatal(err) + } + + if err := w.EndStruct(); err == nil { + t.Fatal("no error from ending struct too many times") + } + }) +} + +func TestWriteTextAnnotatedStruct(t *testing.T) { + testTextWriter(t, "foo::$bar::'.baz'::{}", func(w Writer) { + w.Annotation("foo") + w.Annotation("$bar") + w.Annotation(".baz") + w.BeginStruct() + err := w.EndStruct() + + if err != nil { + t.Fatal(err) + } + }) +} + +func TestWriteTextNestedStruct(t *testing.T) { + testTextWriter(t, "{foo:'true'::{},'null':{}}", func(w Writer) { + w.BeginStruct() + + w.FieldName("foo") + w.Annotation("true") + w.BeginStruct() + w.EndStruct() + + w.FieldName("null") + w.BeginStruct() + w.EndStruct() + + w.EndStruct() + }) +} + +func TestWriteTextEmptyList(t *testing.T) { + testTextWriter(t, "[]", func(w Writer) { + if err := w.BeginList(); err != nil { + t.Fatal(err) + } + + if err := w.EndList(); err != nil { + t.Fatal(err) + } + + if err := w.EndList(); err == nil { + t.Error("no error calling endlist at top level") + } + }) +} + +func TestWriteTextNestedLists(t *testing.T) { + testTextWriter(t, "[{},foo::{},'null'::[]]", func(w Writer) { + w.BeginList() + + w.BeginStruct() + w.EndStruct() + + w.Annotation("foo") + w.BeginStruct() + w.EndStruct() + + w.Annotation("null") + w.BeginList() + w.EndList() + + w.EndList() + }) +} + +func TestWriteTextSexps(t *testing.T) { + testTextWriter(t, "()\n(())\n(() ())", func(w Writer) { + w.BeginSexp() + w.EndSexp() + + w.BeginSexp() + w.BeginSexp() + w.EndSexp() + w.EndSexp() + + w.BeginSexp() + w.BeginSexp() + w.EndSexp() + w.BeginSexp() + w.EndSexp() + w.EndSexp() + }) +} + +func TestWriteTextNulls(t *testing.T) { + expected := "[null,foo::null.null,null.bool,null.int,null.float,null.decimal," + + "null.timestamp,null.symbol,null.string,null.clob,null.blob," + + "null.list,'null'::null.sexp,null.struct]" + + testTextWriter(t, expected, func(w Writer) { + w.BeginList() + + w.WriteNull() + w.Annotation("foo") + w.WriteNullType(NullType) + w.WriteNullType(BoolType) + w.WriteNullType(IntType) + w.WriteNullType(FloatType) + w.WriteNullType(DecimalType) + w.WriteNullType(TimestampType) + w.WriteNullType(SymbolType) + w.WriteNullType(StringType) + w.WriteNullType(ClobType) + w.WriteNullType(BlobType) + w.WriteNullType(ListType) + w.Annotation("null") + w.WriteNullType(SexpType) + w.WriteNullType(StructType) + + w.EndList() + }) +} + +func TestWriteTextBool(t *testing.T) { + expected := "true\n(false '123'::true)\n'false'::false" + testTextWriter(t, expected, func(w Writer) { + w.WriteBool(true) + + w.BeginSexp() + + w.WriteBool(false) + w.Annotation("123") + w.WriteBool(true) + + w.EndSexp() + + w.Annotation("false") + w.WriteBool(false) + }) +} + +func TestWriteTextInt(t *testing.T) { + expected := "(zero::0 1 -1 (9223372036854775807 -9223372036854775808))" + testTextWriter(t, expected, func(w Writer) { + w.BeginSexp() + + w.Annotation("zero") + w.WriteInt(0) + w.WriteInt(1) + w.WriteInt(-1) + + w.BeginSexp() + w.WriteInt(math.MaxInt64) + w.WriteInt(math.MinInt64) + w.EndSexp() + + w.EndSexp() + }) +} + +func TestWriteTextBigInt(t *testing.T) { + expected := "[0,big::18446744073709551616]" + testTextWriter(t, expected, func(w Writer) { + w.BeginList() + + w.WriteBigInt(big.NewInt(0)) + + var val, max, one big.Int + max.SetUint64(math.MaxUint64) + one.SetInt64(1) + val.Add(&max, &one) + + w.Annotation("big") + w.WriteBigInt(&val) + + w.EndList() + }) +} + +func TestWriteTextFloat(t *testing.T) { + expected := "{z:0e+0,nz:-0e+0,s:1.234e+1,l:1.234e-55,n:nan,i:+inf,ni:-inf}" + testTextWriter(t, expected, func(w Writer) { + w.BeginStruct() + + w.FieldName("z") + w.WriteFloat(0.0) + w.FieldName("nz") + w.WriteFloat(-1.0 / math.Inf(1)) + + w.FieldName("s") + w.WriteFloat(12.34) + w.FieldName("l") + w.WriteFloat(12.34e-56) + + w.FieldName("n") + w.WriteFloat(math.NaN()) + w.FieldName("i") + w.WriteFloat(math.Inf(1)) + w.FieldName("ni") + w.WriteFloat(math.Inf(-1)) + + w.EndStruct() + }) +} + +func TestWriteTextDecimal(t *testing.T) { + expected := "0.\n-1.23d-98" + testTextWriter(t, expected, func(w Writer) { + w.WriteDecimal(MustParseDecimal("0")) + w.WriteDecimal(MustParseDecimal("-123d-100")) + }) +} + +func TestWriteTextTimestamp(t *testing.T) { + expected := "1970-01-01T00:00:00.001Z\n1970-01-01T01:23:00+01:23" + testTextWriter(t, expected, func(w Writer) { + w.WriteTimestamp(time.Unix(0, 1000000).In(time.UTC)) + w.WriteTimestamp(time.Unix(0, 0).In(time.FixedZone("wtf", 4980))) + }) +} + +func TestWriteTextSymbol(t *testing.T) { + expected := "{foo:bar,empty:'','null':'null',f:a::b::u::'lo🇺🇸',$123:$456}" + testTextWriter(t, expected, func(w Writer) { + w.BeginStruct() + + w.FieldName("foo") + w.WriteSymbol("bar") + w.FieldName("empty") + w.WriteSymbol("") + w.FieldName("null") + w.WriteSymbol("null") + + w.FieldName("f") + w.Annotation("a") + w.Annotation("b") + w.Annotation("u") + w.WriteSymbol("lo🇺🇸") + + w.FieldName("$123") + w.WriteSymbol("$456") + + w.EndStruct() + }) +} + +func TestWriteTextString(t *testing.T) { + expected := `("hello" "" ("\\\"\n\"\\" zany::"🤪"))` + testTextWriter(t, expected, func(w Writer) { + w.BeginSexp() + w.WriteString("hello") + w.WriteString("") + + w.BeginSexp() + w.WriteString("\\\"\n\"\\") + w.Annotation("zany") + w.WriteString("🤪") + w.EndSexp() + + w.EndSexp() + }) +} + +func TestWriteTextBlob(t *testing.T) { + expected := "{{AAEC/f7/}}\n{{SGVsbG8gV29ybGQ=}}\nempty::{{}}" + testTextWriter(t, expected, func(w Writer) { + w.WriteBlob([]byte{0, 1, 2, 0xFD, 0xFE, 0xFF}) + w.WriteBlob([]byte("Hello World")) + w.Annotation("empty") + w.WriteBlob(nil) + }) +} + +func TestWriteTextClob(t *testing.T) { + expected := "{hello:{{\"world\"}},bits:{{\"\\0\\x01\\xFE\\xFF\"}}}" + testTextWriter(t, expected, func(w Writer) { + w.BeginStruct() + w.FieldName("hello") + w.WriteClob([]byte("world")) + w.FieldName("bits") + w.WriteClob([]byte{0, 1, 0xFE, 0xFF}) + w.EndStruct() + }) +} + +func TestWriteTextFinish(t *testing.T) { + expected := "1\nfoo\n\"bar\"\n{}\n" + testTextWriter(t, expected, func(w Writer) { + w.WriteInt(1) + w.WriteSymbol("foo") + w.WriteString("bar") + w.BeginStruct() + w.EndStruct() + if err := w.Finish(); err != nil { + t.Fatal(err) + } + }) +} + +func TestWriteTextBadFinish(t *testing.T) { + buf := strings.Builder{} + w := NewTextWriter(&buf) + + w.BeginStruct() + err := w.Finish() + + if err == nil { + t.Error("should not be able to finish in the middle of a struct") + } +} + +func testTextWriter(t *testing.T, expected string, f func(Writer)) { + actual := writeText(f) + if actual != expected { + t.Errorf("expected: %v, actual: %v", expected, actual) + } +} + +func writeText(f func(Writer)) string { + buf := strings.Builder{} + w := NewTextWriter(&buf) + + f(w) + + return buf.String() +} diff --git a/ion/tokenizer.go b/ion/tokenizer.go new file mode 100644 index 00000000..3413de7f --- /dev/null +++ b/ion/tokenizer.go @@ -0,0 +1,1265 @@ +package ion + +import ( + "bufio" + "bytes" + "fmt" + "io" + "strings" +) + +type token int + +const ( + tokenError token = iota + + tokenEOF // End of input + + tokenNumber // Haven't seen enough to know which, yet + tokenBinary // 0b[01]+ + tokenHex // 0x[0-9a-fA-F]+ + tokenFloatInf // +inf + tokenFloatMinusInf // -inf + tokenTimestamp // 2001-01-01T00:00:00.000Z + + tokenSymbol // [a-zA-Z_]+ + tokenSymbolQuoted // '[^']+' + tokenSymbolOperator // +-/* + + tokenString // "[^"]+" + tokenLongString // '''[^']+''' + + tokenDot // . + tokenComma // , + tokenColon // : + tokenDoubleColon // :: + + tokenOpenParen // ( + tokenCloseParen // ) + tokenOpenBrace // { + tokenCloseBrace // } + tokenOpenBracket // [ + tokenCloseBracket // ] + tokenOpenDoubleBrace // {{ + tokenCloseDoubleBrace // }} +) + +func (t token) String() string { + switch t { + case tokenError: + return "" + case tokenEOF: + return "" + case tokenNumber: + return "" + case tokenBinary: + return "" + case tokenHex: + return "" + case tokenFloatInf: + return "+inf" + case tokenFloatMinusInf: + return "-inf" + case tokenTimestamp: + return "" + case tokenSymbol: + return "" + case tokenSymbolQuoted: + return "" + case tokenSymbolOperator: + return "" + + case tokenString: + return "" + case tokenLongString: + return "" + + case tokenDot: + return "." + case tokenComma: + return "," + case tokenColon: + return ":" + case tokenDoubleColon: + return "::" + + case tokenOpenParen: + return "(" + case tokenCloseParen: + return ")" + + case tokenOpenBrace: + return "{" + case tokenCloseBrace: + return "}" + + case tokenOpenBracket: + return "[" + case tokenCloseBracket: + return "]" + + case tokenOpenDoubleBrace: + return "{{" + case tokenCloseDoubleBrace: + return "}}" + + default: + return "" + } +} + +type tokenizer struct { + in *bufio.Reader + buffer []int + + token token + unfinished bool + pos uint64 +} + +func tokenizeString(in string) *tokenizer { + return tokenizeBytes([]byte(in)) +} + +func tokenizeBytes(in []byte) *tokenizer { + return tokenize(bytes.NewReader(in)) +} + +func tokenize(in io.Reader) *tokenizer { + return &tokenizer{ + in: bufio.NewReader(in), + } +} + +// Token returns the type of the current token. +func (t *tokenizer) Token() token { + return t.token +} + +func (t *tokenizer) Pos() uint64 { + return t.pos +} + +// Next advances to the next token in the input stream. +func (t *tokenizer) Next() error { + var c int + var err error + + if t.unfinished { + c, err = t.skipValue() + } else { + c, _, err = t.skipWhitespace() + } + + if err != nil { + return err + } + + switch { + case c == -1: + return t.ok(tokenEOF, true) + + case c == ':': + c2, err := t.peek() + if err != nil { + return err + } + if c2 == ':' { + t.read() + return t.ok(tokenDoubleColon, false) + } + return t.ok(tokenColon, false) + + case c == '{': + c2, err := t.peek() + if err != nil { + return err + } + if c2 == '{' { + t.read() + return t.ok(tokenOpenDoubleBrace, true) + } + return t.ok(tokenOpenBrace, true) + + case c == '}': + return t.ok(tokenCloseBrace, false) + + case c == '[': + return t.ok(tokenOpenBracket, true) + + case c == ']': + return t.ok(tokenCloseBracket, false) + + case c == '(': + return t.ok(tokenOpenParen, true) + + case c == ')': + return t.ok(tokenCloseParen, false) + + case c == ',': + return t.ok(tokenComma, false) + + case c == '.': + c2, err := t.peek() + if err != nil { + return err + } + if isOperatorChar(c2) { + t.unread(c) + return t.ok(tokenSymbolOperator, true) + } + return t.ok(tokenDot, false) + + case c == '\'': + ok, err := t.IsTripleQuote() + if err != nil { + return err + } + if ok { + return t.ok(tokenLongString, true) + } + return t.ok(tokenSymbolQuoted, true) + + case c == '+': + ok, err := t.isInf(c) + if err != nil { + return err + } + if ok { + return t.ok(tokenFloatInf, false) + } + t.unread(c) + return t.ok(tokenSymbolOperator, true) + + case c == '-': + c2, err := t.peek() + if err != nil { + return err + } + + if isDigit(c2) { + t.read() + tt, err := t.scanForNumericType(c2) + if err != nil { + return err + } + if tt == tokenTimestamp { + // can't have negative timestamps. + return t.invalidChar(c2) + } + t.unread(c2) + t.unread(c) + return t.ok(tt, true) + } + + ok, err := t.isInf(c) + if err != nil { + return err + } + if ok { + return t.ok(tokenFloatMinusInf, false) + } + + t.unread(c) + return t.ok(tokenSymbolOperator, true) + + case isOperatorChar(c): + t.unread(c) + return t.ok(tokenSymbolOperator, true) + + case c == '"': + return t.ok(tokenString, true) + + case isIdentifierStart(c): + t.unread(c) + return t.ok(tokenSymbol, true) + + case isDigit(c): + tt, err := t.scanForNumericType(c) + if err != nil { + return err + } + + t.unread(c) + return t.ok(tt, true) + + default: + return t.invalidChar(c) + } +} + +func (t *tokenizer) ok(tok token, more bool) error { + t.token = tok + t.unfinished = more + return nil +} + +// SetFinished marks the current token finished (indicating that the caller has +// chosen to step in to a list, sexp, or struct and Next should not skip over its +// contents in search of the next token). +func (t *tokenizer) SetFinished() { + t.unfinished = false +} + +// FinishValue skips to the end of the current value if (and only if) +// we're currently in the middle of reading it. +func (t *tokenizer) FinishValue() (bool, error) { + if !t.unfinished { + return false, nil + } + + c, err := t.skipValue() + if err != nil { + return true, err + } + + t.unread(c) + t.unfinished = false + return true, nil +} + +// ReadValue reads the value of a token of the given type. +func (t *tokenizer) ReadValue(tok token) (string, error) { + var str string + var err error + + switch tok { + case tokenSymbol: + str, err = t.readSymbol() + case tokenSymbolQuoted: + str, err = t.readQuotedSymbol() + case tokenSymbolOperator, tokenDot: + str, err = t.readOperator() + case tokenString: + str, err = t.readString() + case tokenLongString: + str, err = t.readLongString() + case tokenBinary: + str, err = t.readBinary() + case tokenHex: + str, err = t.readHex() + case tokenTimestamp: + str, err = t.readTimestamp() + default: + panic(fmt.Sprintf("unsupported token type %v", tok)) + } + + if err != nil { + return "", err + } + + t.unfinished = false + return str, nil +} + +// ReadNumber reads a number and determines the type. +func (t *tokenizer) ReadNumber() (string, Type, error) { + w := strings.Builder{} + + c, err := t.read() + if err != nil { + return "", NoType, err + } + + if c == '-' { + w.WriteByte('-') + c, err = t.read() + if err != nil { + return "", NoType, err + } + } + + first := c + oldlen := w.Len() + + c, err = t.readDigits(c, &w) + if err != nil { + return "", NoType, err + } + + if first == '0' { + if w.Len()-oldlen > 1 { + return "", NoType, &SyntaxError{"invalid leading zeroes", t.pos - 1} + } + } + + tt := IntType + + if c == '.' { + w.WriteByte('.') + tt = DecimalType + + if c, err = t.read(); err != nil { + return "", NoType, err + } + if c, err = t.readDigits(c, &w); err != nil { + return "", NoType, err + } + } + + switch c { + case 'e', 'E': + tt = FloatType + + w.WriteByte(byte(c)) + if c, err = t.readExponent(&w); err != nil { + return "", NoType, err + } + + case 'd', 'D': + tt = DecimalType + + w.WriteByte(byte(c)) + if c, err = t.readExponent(&w); err != nil { + return "", NoType, err + } + } + + ok, err := t.isStopChar(c) + if err != nil { + return "", NoType, err + } + if !ok { + return "", NoType, t.invalidChar(c) + } + t.unread(c) + + return w.String(), tt, nil +} + +func (t *tokenizer) readExponent(w io.ByteWriter) (int, error) { + c, err := t.read() + if err != nil { + return 0, err + } + + if c == '+' || c == '-' { + w.WriteByte(byte(c)) + if c, err = t.read(); err != nil { + return 0, err + } + } + + return t.readDigits(c, w) +} + +func (t *tokenizer) readDigits(c int, w io.ByteWriter) (int, error) { + if !isDigit(c) { + return c, nil + } + w.WriteByte(byte(c)) + + return t.readRadixDigits(isDigit, w) +} + +// ReadSymbol reads an unquoted symbol value. +func (t *tokenizer) readSymbol() (string, error) { + ret := strings.Builder{} + + c, err := t.peek() + if err != nil { + return "", err + } + + for isIdentifierPart(c) { + ret.WriteByte(byte(c)) + t.read() + c, err = t.peek() + if err != nil { + return "", err + } + } + + return ret.String(), nil +} + +// ReadQuotedSymbol reads a quoted symbol. +func (t *tokenizer) readQuotedSymbol() (string, error) { + ret := strings.Builder{} + + for { + c, err := t.read() + if err != nil { + return "", err + } + + switch c { + case -1, '\n': + return "", t.invalidChar(c) + + case '\'': + return ret.String(), nil + + case '\\': + c, err = t.peek() + if err != nil { + return "", err + } + + if c == '\n' { + t.read() + continue + } + + r, err := t.readEscapedChar(false) + if err != nil { + return "", err + } + ret.WriteRune(r) + + default: + ret.WriteByte(byte(c)) + } + } +} + +func (t *tokenizer) readOperator() (string, error) { + ret := strings.Builder{} + + c, err := t.peek() + if err != nil { + return "", err + } + + for isOperatorChar(c) { + ret.WriteByte(byte(c)) + t.read() + c, err = t.peek() + if err != nil { + return "", err + } + } + + return ret.String(), nil +} + +// ReadString reads a quoted string. +func (t *tokenizer) readString() (string, error) { + ret := strings.Builder{} + + for { + c, err := t.read() + if err != nil { + return "", err + } + + switch c { + case -1, '\n': + return "", t.invalidChar(c) + + case '"': + return ret.String(), nil + + case '\\': + c, err = t.peek() + if err != nil { + return "", err + } + + if c == '\n' { + t.read() + continue + } + + r, err := t.readEscapedChar(false) + if err != nil { + return "", err + } + ret.WriteRune(r) + + default: + ret.WriteByte(byte(c)) + } + } +} + +// ReadLongString reads a triple-quoted string. +func (t *tokenizer) readLongString() (string, error) { + ret := strings.Builder{} + + for { + c, err := t.read() + if err != nil { + return "", err + } + + switch c { + case -1: + return "", t.invalidChar(c) + + case '\'': + ok, err := t.skipEndOfLongString(t.skipCommentsHandler) + if err != nil { + return "", err + } + if ok { + return ret.String(), nil + } + + case '\\': + c, err = t.peek() + if err != nil { + return "", err + } + + if c == '\n' { + t.read() + continue + } + + r, err := t.readEscapedChar(false) + if err != nil { + return "", err + } + ret.WriteRune(r) + + default: + ret.WriteByte(byte(c)) + } + } +} + +// ReadEscapedChar reads an escaped character. +func (t *tokenizer) readEscapedChar(clob bool) (rune, error) { + // We just read the '\', grab the next char. + c, err := t.read() + if err != nil { + return 0, err + } + + switch c { + case '0': + return '\x00', nil + case 'a': + return '\a', nil + case 'b': + return '\b', nil + case 't': + return '\t', nil + case 'n': + return '\n', nil + case 'f': + return '\f', nil + case 'r': + return '\r', nil + case 'v': + return '\v', nil + case '?': + return '?', nil + case '/': + return '/', nil + case '\'': + return '\'', nil + case '"': + return '"', nil + case '\\': + return '\\', nil + case 'U': + if clob { + return 0, t.invalidChar('U') + } + return t.readHexEscapeSeq(8) + case 'u': + return t.readHexEscapeSeq(4) + case 'x': + return t.readHexEscapeSeq(2) + } + + return 0, &SyntaxError{fmt.Sprintf("bad escape sequence '\\%c'", c), t.pos - 2} +} + +func (t *tokenizer) readHexEscapeSeq(len int) (rune, error) { + val := rune(0) + + for len > 0 { + c, err := t.read() + if err != nil { + return 0, err + } + + d, err := t.fromHex(c) + if err != nil { + return 0, err + } + + val = (val << 4) | rune(d) + len-- + } + + return val, nil +} + +func (t *tokenizer) fromHex(c int) (int, error) { + if c >= '0' && c <= '9' { + return c - '0', nil + } + if c >= 'a' && c <= 'f' { + return 10 + (c - 'a'), nil + } + if c >= 'A' && c <= 'F' { + return 10 + (c - 'A'), nil + } + return 0, t.invalidChar(c) +} + +func (t *tokenizer) readBinary() (string, error) { + isB := func(c int) bool { + return c == 'b' || c == 'B' + } + isDigit := func(c int) bool { + return c == '0' || c == '1' + } + return t.readRadix(isB, isDigit) +} + +func (t *tokenizer) readHex() (string, error) { + isX := func(c int) bool { + return c == 'x' || c == 'X' + } + return t.readRadix(isX, isHexDigit) +} + +func (t *tokenizer) readRadix(pok, dok matcher) (string, error) { + w := strings.Builder{} + + c, err := t.read() + if err != nil { + return "", err + } + + if c == '-' { + w.WriteByte('-') + c, err = t.read() + if err != nil { + return "", err + } + } + + if c != '0' { + return "", t.invalidChar(c) + } + w.WriteByte('0') + + c, err = t.read() + if err != nil { + return "", err + } + if !pok(c) { + return "", t.invalidChar(c) + } + w.WriteByte(byte(c)) + + c, err = t.readRadixDigits(dok, &w) + if err != nil { + return "", err + } + + ok, err := t.isStopChar(c) + if err != nil { + return "", err + } + if !ok { + return "", t.invalidChar(c) + } + t.unread(c) + + return w.String(), nil +} + +func (t *tokenizer) readRadixDigits(dok matcher, w io.ByteWriter) (int, error) { + var c int + var err error + + for { + c, err = t.read() + if err != nil { + return 0, err + } + if c == '_' { + continue + } + if !dok(c) { + return c, nil + } + w.WriteByte(byte(c)) + } +} + +func (t *tokenizer) readTimestamp() (string, error) { + w := strings.Builder{} + + c, err := t.readTimestampDigits(4, &w) + if err != nil { + return "", err + } + if c == 'T' { + // yyyyT + w.WriteByte('T') + return w.String(), nil + } + if c != '-' { + return "", t.invalidChar(c) + } + w.WriteByte('-') + + if c, err = t.readTimestampDigits(2, &w); err != nil { + return "", err + } + if c == 'T' { + // yyyy-mmT + w.WriteByte('T') + return w.String(), nil + } + if c != '-' { + return "", t.invalidChar(c) + } + w.WriteByte('-') + + if c, err = t.readTimestampDigits(2, &w); err != nil { + return "", err + } + if c != 'T' { + // yyyy-mm-dd + return t.readTimestampFinish(c, &w) + } + w.WriteByte('T') + + if c, err = t.read(); err != nil { + return "", err + } + if !isDigit(c) { + // yyyy-mm-ddT(+hh:mm)? + if c, err = t.readTimestampOffset(c, &w); err != nil { + return "", err + } + return t.readTimestampFinish(c, &w) + } + w.WriteByte(byte(c)) + + if c, err = t.readTimestampDigits(1, &w); err != nil { + return "", err + } + if c != ':' { + return "", t.invalidChar(c) + } + w.WriteByte(':') + + if c, err = t.readTimestampDigits(2, &w); err != nil { + return "", err + } + if c != ':' { + // yyyy-mm-ddThh:mmZ + if c, err = t.readTimestampOffsetOrZ(c, &w); err != nil { + return "", err + } + return t.readTimestampFinish(c, &w) + } + w.WriteByte(':') + + if c, err = t.readTimestampDigits(2, &w); err != nil { + return "", err + } + if c != '.' { + // yyyy-mm-ddThh:mm:ssZ + if c, err = t.readTimestampOffsetOrZ(c, &w); err != nil { + return "", err + } + return t.readTimestampFinish(c, &w) + } + w.WriteByte('.') + + // yyyy-mm-ddThh:mm:ss.ssssZ + if c, err = t.read(); err != nil { + return "", err + } + if isDigit(c) { + if c, err = t.readDigits(c, &w); err != nil { + return "", err + } + } + + if c, err = t.readTimestampOffsetOrZ(c, &w); err != nil { + return "", err + } + return t.readTimestampFinish(c, &w) +} + +func (t *tokenizer) readTimestampOffsetOrZ(c int, w io.ByteWriter) (int, error) { + if c == '-' || c == '+' { + return t.readTimestampOffset(c, w) + } + if c == 'z' || c == 'Z' { + w.WriteByte(byte(c)) + return t.read() + } + return 0, t.invalidChar(c) +} + +func (t *tokenizer) readTimestampOffset(c int, w io.ByteWriter) (int, error) { + if c != '-' && c != '+' { + return c, nil + } + w.WriteByte(byte(c)) + + c, err := t.readTimestampDigits(2, w) + if err != nil { + return 0, err + } + if c != ':' { + return 0, t.invalidChar(c) + } + w.WriteByte(':') + return t.readTimestampDigits(2, w) +} + +func (t *tokenizer) readTimestampDigits(n int, w io.ByteWriter) (int, error) { + for n > 0 { + c, err := t.read() + if err != nil { + return 0, err + } + if !isDigit(c) { + return 0, t.invalidChar(c) + } + w.WriteByte(byte(c)) + n-- + } + return t.read() +} + +func (t *tokenizer) readTimestampFinish(c int, w fmt.Stringer) (string, error) { + ok, err := t.isStopChar(c) + if err != nil { + return "", err + } + if !ok { + return "", t.invalidChar(c) + } + t.unread(c) + return w.String(), nil +} + +func (t *tokenizer) ReadBlob() (string, error) { + w := strings.Builder{} + + var ( + c int + err error + ) + + for { + if c, _, err = t.skipLobWhitespace(); err != nil { + return "", err + } + if c == -1 { + return "", t.invalidChar(c) + } + if c == '}' { + break + } + w.WriteByte(byte(c)) + } + + if c, err = t.read(); err != nil { + return "", err + } + if c != '}' { + return "", t.invalidChar(c) + } + + t.unfinished = false + return w.String(), nil +} + +func (t *tokenizer) ReadShortClob() (string, error) { + str, err := t.readString() + if err != nil { + return "", err + } + + c, _, err := t.skipLobWhitespace() + if err != nil { + return "", err + } + if c != '}' { + return "", t.invalidChar(c) + } + + if c, err = t.read(); err != nil { + return "", err + } + if c != '}' { + return "", t.invalidChar(c) + } + + t.unfinished = false + return str, nil +} + +func (t *tokenizer) ReadLongClob() (string, error) { + str, err := t.readLongString() + if err != nil { + return "", err + } + + c, _, err := t.skipLobWhitespace() + if err != nil { + return "", err + } + if c != '}' { + return "", t.invalidChar(c) + } + + if c, err = t.read(); err != nil { + return "", err + } + if c != '}' { + return "", t.invalidChar(c) + } + + t.unfinished = false + return str, nil +} + +// IsTripleQuote returns true if this is a triple-quote sequence ('''). +func (t *tokenizer) IsTripleQuote() (bool, error) { + // We've just read a '\'', check if the next two are too. + cs, err := t.peekN(2) + if err == io.EOF { + return false, nil + } + if err != nil { + return false, err + } + + if cs[0] == '\'' && cs[1] == '\'' { + t.skipN(2) + return true, nil + } + + return false, nil +} + +// IsInf returns true if the given character begins a '+inf' or +// '-inf' keyword. +func (t *tokenizer) isInf(c int) (bool, error) { + if c != '+' && c != '-' { + return false, nil + } + + cs, err := t.peekN(5) + if err != nil && err != io.EOF { + return false, err + } + + if len(cs) < 3 || cs[0] != 'i' || cs[1] != 'n' || cs[2] != 'f' { + // Definitely not +-inf. + return false, nil + } + + if len(cs) == 3 || isStopChar(cs[3]) { + // Cleanly-terminated +-inf. + t.skipN(3) + return true, nil + } + + if cs[3] == '/' && len(cs) > 4 && (cs[4] == '/' || cs[4] == '*') { + t.skipN(3) + // +-inf followed immediately by a comment works too. + return true, nil + } + + return false, nil +} + +// ScanForNumericType attempts to determine what type of number we +// have by peeking at a fininte number of characters. We can rule +// out binary (0b...), hex (0x...), and timestamps (....-) via this +// method. There are a couple other cases where we *could* distinguish, +// but it's unclear that it's worth it. +func (t *tokenizer) scanForNumericType(c int) (token, error) { + if !isDigit(c) { + panic("scanForNumericType with non-digit") + } + + cs, err := t.peekN(4) + if err != nil && err != io.EOF { + return tokenError, err + } + + if c == '0' && len(cs) > 0 { + switch { + case cs[0] == 'b' || cs[0] == 'B': + return tokenBinary, nil + + case cs[0] == 'x' || cs[0] == 'X': + return tokenHex, nil + } + } + + if len(cs) >= 4 { + if isDigit(cs[0]) && isDigit(cs[1]) && isDigit(cs[2]) { + if cs[3] == '-' || cs[3] == 'T' { + return tokenTimestamp, nil + } + } + } + + // Can't tell yet; wait until actually reading it to find out. + return tokenNumber, nil +} + +// Is this character a valid way to end a 'normal' (unquoted) value? +// Peeks in case of '/', so don't call it with a character you've +// peeked. +func (t *tokenizer) isStopChar(c int) (bool, error) { + if isStopChar(c) { + return true, nil + } + + if c == '/' { + c2, err := t.peek() + if err != nil { + return false, err + } + if c2 == '/' || c2 == '*' { + // Comment, also all done. + return true, nil + } + } + + return false, nil +} + +type matcher func(int) bool + +// Expect reads a byte of input and asserts that it matches some +// condition, returning an error if it does not. +func (t *tokenizer) expect(f matcher) error { + c, err := t.read() + if err != nil { + return err + } + if !f(c) { + return t.invalidChar(c) + } + return nil +} + +// InvalidChar returns an error complaining that the given character was +// unexpected. +func (t *tokenizer) invalidChar(c int) error { + if c == -1 { + return &UnexpectedEOFError{t.pos - 1} + } + return &UnexpectedRuneError{rune(c), t.pos - 1} +} + +// SkipN skips over the next n bytes of input. Presumably you've +// already peeked at them, and decided they're not worth keeping. +func (t *tokenizer) skipN(n int) error { + for i := 0; i < n; i++ { + c, err := t.read() + if err != nil { + return err + } + if c == -1 { + break + } + } + return nil +} + +// PeekN peeks at the next n bytes of input. Unlike read/peek, does +// NOT return -1 to indicate EOF. If it cannot peek N bytes ahead +// because of an EOF (or other error), it returns the bytes it was +// able to peek at along with the error. +func (t *tokenizer) peekN(n int) ([]int, error) { + var ret []int + var err error + + // Read ahead. + for i := 0; i < n; i++ { + var c int + c, err = t.read() + if err != nil { + break + } + if c == -1 { + err = io.EOF + break + } + ret = append(ret, c) + } + + // Put back the ones we got. + if err == io.EOF { + t.unread(-1) + } + for i := len(ret) - 1; i >= 0; i-- { + t.unread(ret[i]) + } + + return ret, err +} + +// Peek at the next byte of input without removing it. Other conditions +// from Read all apply. +func (t *tokenizer) peek() (int, error) { + if len(t.buffer) > 0 { + // Short-circuit and peek from the buffer. + return t.buffer[len(t.buffer)-1], nil + } + + c, err := t.read() + if err != nil { + return 0, err + } + + t.unread(c) + return c, nil +} + +// Read reads a byte of input from the underlying reader. EOF is +// returned as (-1, nil) rather than (0, io.EOF), because I find it +// easier to reason about that way. Newlines are normalized to '\n'. +func (t *tokenizer) read() (int, error) { + t.pos++ + if len(t.buffer) > 0 { + // We've already peeked ahead; read from our buffer. + c := t.buffer[len(t.buffer)-1] + t.buffer = t.buffer[:len(t.buffer)-1] + return c, nil + } + + c, err := t.in.ReadByte() + if err == io.EOF { + return -1, nil + } + if err != nil { + return 0, &IOError{err} + } + + // Normalize \r and \r\n to just \n. + if c == '\r' { + cs, err := t.in.Peek(1) + if err != nil && err != io.EOF { + // Not EOF, because we haven't dealt with the '\r' yet. + return 0, &IOError{err} + } + if len(cs) > 0 && cs[0] == '\n' { + // Skip over the '\n' as well. + t.in.ReadByte() + } + return '\n', nil + } + + return int(c), nil +} + +// Unread pushes a character (or -1) back into the input stream to +// be read again later. +func (t *tokenizer) unread(c int) { + t.pos-- + t.buffer = append(t.buffer, c) +} diff --git a/ion/tokenizer_test.go b/ion/tokenizer_test.go new file mode 100644 index 00000000..f4af1584 --- /dev/null +++ b/ion/tokenizer_test.go @@ -0,0 +1,571 @@ +package ion + +import ( + "io" + "testing" +) + +func TestNext(t *testing.T) { + tok := tokenizeString("foo::'foo':[] 123, {})") + + next := func(tt token) { + if err := tok.Next(); err != nil { + t.Fatal(err) + } + if tok.Token() != tt { + t.Fatalf("expected %v, got %v", tt, tok.Token()) + } + } + + next(tokenSymbol) + next(tokenDoubleColon) + next(tokenSymbolQuoted) + next(tokenColon) + next(tokenOpenBracket) + next(tokenNumber) + next(tokenComma) + next(tokenOpenBrace) +} + +func TestReadSymbol(t *testing.T) { + test := func(str string, expected string, next token) { + t.Run(str, func(t *testing.T) { + tok := tokenizeString(str) + if err := tok.Next(); err != nil { + t.Fatal(err) + } + + if tok.Token() != tokenSymbol { + t.Fatal("not a symbol") + } + + actual, err := tok.readSymbol() + if err != nil { + t.Fatal(err) + } + + if actual != expected { + t.Errorf("expected '%v', got '%v'", expected, actual) + } + + if err := tok.Next(); err != nil { + t.Fatal(err) + } + if tok.Token() != next { + t.Errorf("expected next=%v, got next=%v", next, tok.Token()) + } + }) + } + + test("a", "a", tokenEOF) + test("abc", "abc", tokenEOF) + test("null +inf", "null", tokenFloatInf) + test("false,", "false", tokenComma) + test("nan]", "nan", tokenCloseBracket) +} + +func TestReadSymbols(t *testing.T) { + tok := tokenizeString("foo bar baz beep boop null") + expected := []string{"foo", "bar", "baz", "beep", "boop", "null"} + + for i := 0; i < len(expected); i++ { + if err := tok.Next(); err != nil { + t.Fatal(err) + } + if tok.Token() != tokenSymbol { + t.Fatalf("expected %v, got %v", tokenSymbol, tok.Token()) + } + + val, err := tok.readSymbol() + if err != nil { + t.Fatal(err) + } + + if val != expected[i] { + t.Errorf("expected %v, got %v", expected[i], val) + } + } +} + +func TestReadQuotedSymbol(t *testing.T) { + test := func(str string, expected string, next int) { + t.Run(str, func(t *testing.T) { + tok := tokenizeString(str) + if err := tok.Next(); err != nil { + t.Fatal(err) + } + + if tok.Token() != tokenSymbolQuoted { + t.Fatal("not a quoted symbol") + } + + actual, err := tok.readQuotedSymbol() + if err != nil { + t.Fatal(err) + } + + if actual != expected { + t.Errorf("expected '%v', got '%v'", expected, actual) + } + + c, err := tok.read() + if err != nil { + t.Fatal(err) + } + if c != next { + t.Errorf("expected next=%q, got next=%q", next, c) + } + }) + } + + test("'a'", "a", -1) + test("'a b c'", "a b c", -1) + test("'null' ", "null", ' ') + test("'false',", "false", ',') + test("'nan']", "nan", ']') + + test("'a\\'b'", "a'b", -1) + test("'a\\\nb'", "ab", -1) + test("'a\\\\b'", "a\\b", -1) + test("'a\x20b'", "a b", -1) + test("'a\\u2248b'", "a≈b", -1) + test("'a\\U0001F44Db'", "a👍b", -1) +} + +func TestReadTimestamp(t *testing.T) { + test := func(str string, eval string, next int) { + t.Run(str, func(t *testing.T) { + tok := tokenizeString(str) + if err := tok.Next(); err != nil { + t.Fatal(err) + } + if tok.Token() != tokenTimestamp { + t.Fatalf("unexpected token %v", tok.Token()) + } + + val, err := tok.ReadValue(tokenTimestamp) + if err != nil { + t.Fatal(err) + } + if val != eval { + t.Errorf("expected %v, got %v", eval, val) + } + + c, err := tok.read() + if err != nil { + t.Fatal(err) + } + if c != next { + t.Errorf("expected %q, got %q", next, c) + } + }) + } + + test("2001T", "2001T", -1) + test("2001-01T,", "2001-01T", ',') + test("2001-01-02}", "2001-01-02", '}') + test("2001-01-02T ", "2001-01-02T", ' ') + test("2001-01-02T+00:00\t", "2001-01-02T+00:00", '\t') + test("2001-01-02T-00:00\n", "2001-01-02T-00:00", '\n') + test("2001-01-02T03:04+00:00 ", "2001-01-02T03:04+00:00", ' ') + test("2001-01-02T03:04-00:00 ", "2001-01-02T03:04-00:00", ' ') + test("2001-01-02T03:04Z ", "2001-01-02T03:04Z", ' ') + test("2001-01-02T03:04z ", "2001-01-02T03:04z", ' ') + test("2001-01-02T03:04:05Z ", "2001-01-02T03:04:05Z", ' ') + test("2001-01-02T03:04:05+00:00 ", "2001-01-02T03:04:05+00:00", ' ') + test("2001-01-02T03:04:05.666Z ", "2001-01-02T03:04:05.666Z", ' ') + test("2001-01-02T03:04:05.666666z ", "2001-01-02T03:04:05.666666z", ' ') +} + +func TestIsTripleQuote(t *testing.T) { + test := func(str string, eok bool, next int) { + t.Run(str, func(t *testing.T) { + tok := tokenizeString(str) + + ok, err := tok.IsTripleQuote() + if err != nil { + t.Fatal(err) + } + if ok != eok { + t.Errorf("expected ok=%v, got ok=%v", eok, ok) + } + + read(t, tok, next) + }) + } + + test("''string'''", true, 's') + test("'string'''", false, '\'') + test("'", false, '\'') + test("", false, -1) +} + +func TestIsInf(t *testing.T) { + test := func(str string, eok bool, next int) { + t.Run(str, func(t *testing.T) { + tok := tokenizeString(str) + c, err := tok.read() + if err != nil { + t.Fatal(err) + } + + ok, err := tok.isInf(c) + if err != nil { + t.Fatal(err) + } + + if ok != eok { + t.Errorf("expected %v, got %v", eok, ok) + } + + c, err = tok.read() + if err != nil { + t.Fatal(err) + } + if c != next { + t.Errorf("expected '%c', got '%c'", next, c) + } + }) + } + + test("+inf", true, -1) + test("-inf", true, -1) + test("+inf ", true, ' ') + test("-inf\t", true, '\t') + test("-inf\n", true, '\n') + test("+inf,", true, ',') + test("-inf}", true, '}') + test("+inf)", true, ')') + test("-inf]", true, ']') + test("+inf//", true, '/') + test("+inf/*", true, '/') + + test("+inf/", false, 'i') + test("-inf/0", false, 'i') + test("+int", false, 'i') + test("-iot", false, 'i') + test("+unf", false, 'u') + test("_inf", false, 'i') + + test("-in", false, 'i') + test("+i", false, 'i') + test("+", false, -1) + test("-", false, -1) +} + +func TestScanForNumericType(t *testing.T) { + test := func(str string, ett token) { + t.Run(str, func(t *testing.T) { + tok := tokenizeString(str) + c, err := tok.read() + if err != nil { + t.Fatal(err) + } + + tt, err := tok.scanForNumericType(c) + if err != nil { + t.Fatal(err) + } + if tt != ett { + t.Errorf("expected %v, got %v", ett, tt) + } + }) + } + + test("0b0101", tokenBinary) + test("0B", tokenBinary) + test("0xABCD", tokenHex) + test("0X", tokenHex) + test("0000-00-00", tokenTimestamp) + test("0000T", tokenTimestamp) + + test("0", tokenNumber) + test("1b0101", tokenNumber) + test("1B", tokenNumber) + test("1x0101", tokenNumber) + test("1X", tokenNumber) + test("1234", tokenNumber) + test("12345", tokenNumber) + test("1,23T", tokenNumber) + test("12,3T", tokenNumber) + test("123,T", tokenNumber) +} + +func TestSkipWhitespace(t *testing.T) { + test := func(str string, eok bool, ec int) { + t.Run(str, func(t *testing.T) { + tok := tokenizeString(str) + c, ok, err := tok.skipWhitespace() + if err != nil { + t.Fatal(err) + } + + if ok != eok { + t.Errorf("expected ok=%v, got ok=%v", eok, ok) + } + if c != ec { + t.Errorf("expected c='%c', got c='%c'", ec, c) + } + }) + } + + test("/ 0)", false, '/') + test("xyz_", false, 'x') + test(" / 0)", true, '/') + test(" xyz_", true, 'x') + test(" \t\r\n / 0)", true, '/') + test("\t\t // comment\t\r\n\t\t x", true, 'x') + test(" \r\n /* comment *//* \r\n comment */x", true, 'x') +} + +func TestSkipLobWhitespace(t *testing.T) { + test := func(str string, eok bool, ec int) { + t.Run(str, func(t *testing.T) { + tok := tokenizeString(str) + c, ok, err := tok.skipLobWhitespace() + if err != nil { + t.Fatal(err) + } + + if ok != eok { + t.Errorf("expected ok=%v, got ok=%v", eok, ok) + } + if c != ec { + t.Errorf("expected c='%c', got c='%c'", ec, c) + } + }) + } + + test("///=", false, '/') + test("xyz_", false, 'x') + test(" ///=", true, '/') + test(" xyz_", true, 'x') + test("\r\n\t///=", true, '/') + test("\r\n\txyz_", true, 'x') +} + +func TestSkipCommentsHandler(t *testing.T) { + t.Run("SingleLine", func(t *testing.T) { + tok := tokenizeString("/comment\nok") + ok, err := tok.skipCommentsHandler() + if err != nil { + t.Fatal(err) + } + if !ok { + t.Error("expected ok=true, got ok=false") + } + + read(t, tok, 'o') + read(t, tok, 'k') + read(t, tok, -1) + }) + + t.Run("Block", func(t *testing.T) { + tok := tokenizeString("*comm\nent*/ok") + ok, err := tok.skipCommentsHandler() + if err != nil { + t.Fatal(err) + } + if !ok { + t.Error("expected ok=true, got ok=false") + } + + read(t, tok, 'o') + read(t, tok, 'k') + read(t, tok, -1) + }) + + t.Run("FalseAlarm", func(t *testing.T) { + tok := tokenizeString(" 0)") + ok, err := tok.skipCommentsHandler() + if err != nil { + t.Fatal(err) + } + if ok { + t.Error("expected ok=false, got ok=true") + } + + read(t, tok, ' ') + read(t, tok, '0') + read(t, tok, ')') + read(t, tok, -1) + }) +} + +func TestSkipSingleLineComment(t *testing.T) { + tok := tokenizeString("single-line comment\r\nok") + err := tok.skipSingleLineComment() + if err != nil { + t.Fatal(err) + } + + read(t, tok, 'o') + read(t, tok, 'k') + read(t, tok, -1) +} + +func TestSkipSingleLineCommentOnLastLine(t *testing.T) { + tok := tokenizeString("single-line comment") + err := tok.skipSingleLineComment() + if err != nil { + t.Fatal(err) + } + + read(t, tok, -1) +} + +func TestSkipBlockComment(t *testing.T) { + tok := tokenizeString("this is/ a\nmulti-line /** comment.**/ok") + err := tok.skipBlockComment() + if err != nil { + t.Fatal(err) + } + + read(t, tok, 'o') + read(t, tok, 'k') + read(t, tok, -1) +} + +func TestSkipInvalidBlockComment(t *testing.T) { + tok := tokenizeString("this is a comment that never ends") + err := tok.skipBlockComment() + if err == nil { + t.Error("did not fail on bad block comment") + } +} + +func TestPeekN(t *testing.T) { + tok := tokenizeString("abc\r\ndef") + + peekN(t, tok, 1, nil, 'a') + peekN(t, tok, 2, nil, 'a', 'b') + peekN(t, tok, 3, nil, 'a', 'b', 'c') + + read(t, tok, 'a') + read(t, tok, 'b') + + peekN(t, tok, 3, nil, 'c', '\n', 'd') + peekN(t, tok, 2, nil, 'c', '\n') + peekN(t, tok, 3, nil, 'c', '\n', 'd') + + read(t, tok, 'c') + read(t, tok, '\n') + read(t, tok, 'd') + + peekN(t, tok, 3, io.EOF, 'e', 'f') + peekN(t, tok, 3, io.EOF, 'e', 'f') + peekN(t, tok, 2, nil, 'e', 'f') + + read(t, tok, 'e') + read(t, tok, 'f') + read(t, tok, -1) + + peekN(t, tok, 10, io.EOF) +} + +func peekN(t *testing.T, tok *tokenizer, n int, ee error, ecs ...int) { + cs, err := tok.peekN(n) + if err != ee { + t.Fatalf("expected err=%v, got err=%v", ee, err) + } + if !equal(ecs, cs) { + t.Errorf("expected %v, got %v", ecs, cs) + } +} + +func equal(a, b []int) bool { + if len(a) != len(b) { + return false + } + + for i := range a { + if a[i] != b[i] { + return false + } + } + + return true +} + +func TestPeek(t *testing.T) { + tok := tokenizeString("abc") + + peek(t, tok, 'a') + peek(t, tok, 'a') + read(t, tok, 'a') + + peek(t, tok, 'b') + tok.unread('a') + + peek(t, tok, 'a') + read(t, tok, 'a') + read(t, tok, 'b') + peek(t, tok, 'c') + peek(t, tok, 'c') + + read(t, tok, 'c') + peek(t, tok, -1) + peek(t, tok, -1) + read(t, tok, -1) +} + +func peek(t *testing.T, tok *tokenizer, expected int) { + c, err := tok.peek() + if err != nil { + t.Fatal(err) + } + if c != expected { + t.Errorf("expected %v, got %v", expected, c) + } +} + +func TestReadUnread(t *testing.T) { + tok := tokenizeString("abc\rd\ne\r\n") + + read(t, tok, 'a') + tok.unread('a') + + read(t, tok, 'a') + read(t, tok, 'b') + read(t, tok, 'c') + tok.unread('c') + tok.unread('b') + + read(t, tok, 'b') + read(t, tok, 'c') + read(t, tok, '\n') + tok.unread('\n') + + read(t, tok, '\n') + read(t, tok, 'd') + read(t, tok, '\n') + read(t, tok, 'e') + read(t, tok, '\n') + read(t, tok, -1) + + tok.unread(-1) + tok.unread('\n') + + read(t, tok, '\n') + read(t, tok, -1) + read(t, tok, -1) +} + +func TestTokenToString(t *testing.T) { + for i := tokenError; i <= tokenCloseDoubleBrace+1; i++ { + str := i.String() + if str == "" { + t.Errorf("expected non-empty string for token %v", int(i)) + } + } +} + +func read(t *testing.T, tok *tokenizer, expected int) { + c, err := tok.read() + if err != nil { + t.Fatal(err) + } + if c != expected { + t.Errorf("expected %v, got %v", expected, c) + } +} diff --git a/ion/type.go b/ion/type.go new file mode 100644 index 00000000..f0165909 --- /dev/null +++ b/ion/type.go @@ -0,0 +1,125 @@ +package ion + +import "fmt" + +// A Type represents the type of an Ion Value. +type Type uint8 + +const ( + // NoType is returned by a Reader that is not currently pointing at a value. + NoType Type = iota + + // NullType is the type of the (unqualified) Ion null value. + NullType + + // BoolType is the type of an Ion boolean, true or false. + BoolType + + // IntType is the type of a signed Ion integer of arbitrary size. + IntType + + // FloatType is the type of a fixed-precision Ion floating-point value. + FloatType + + // DecimalType is the type of an arbitrary-precision Ion decimal value. + DecimalType + + // TimestampType is the type of an arbitrary-precision Ion timestamp. + TimestampType + + // SymbolType is the type of an Ion symbol, mapped to an integer ID by a SymbolTable + // to (potentially) save space. + SymbolType + + // StringType is the type of a non-symbol Unicode string, represented directly. + StringType + + // ClobType is the type of a character large object. Like a BlobType, it stores an + // arbitrary sequence of bytes, but it represents them in text form as an escaped-ASCII + // string rather than a base64-encoded string. + ClobType + + // BlobType is the type of a binary large object; a sequence of arbitrary bytes. + BlobType + + // ListType is the type of a list, recursively containing zero or more Ion values. + ListType + + // SexpType is the type of an s-expression. Like a ListType, it contains a sequence + // of zero or more Ion values, but with a lisp-like syntax when encoded as text. + SexpType + + // StructType is the type of a structure, recursively containing a sequence of named + // (by an Ion symbol) Ion values. + StructType +) + +// String implements fmt.Stringer for Type. +func (t Type) String() string { + switch t { + case NoType: + return "" + case NullType: + return "null" + case BoolType: + return "bool" + case IntType: + return "int" + case FloatType: + return "float" + case DecimalType: + return "decimal" + case TimestampType: + return "timestamp" + case StringType: + return "string" + case SymbolType: + return "symbol" + case BlobType: + return "blob" + case ClobType: + return "clob" + case StructType: + return "struct" + case ListType: + return "list" + case SexpType: + return "sexp" + default: + return fmt.Sprintf("", uint8(t)) + } +} + +// IntSize represents the size of an integer. +type IntSize uint8 + +const ( + // NullInt is the size of null.int and other things that aren't actually ints. + NullInt IntSize = iota + // Int32 is the size of an Ion integer that can be losslessly stored in an int32. + Int32 + // Int64 is the size of an Ion integer that can be losslessly stored in an int64. + Int64 + // Uint64 is the size of an Ion integer that can be losslessly stored in a uint64. + Uint64 + // BigInt is the size of an Ion integer that can only be losslessly stored in a big.Int. + BigInt +) + +// String implements fmt.Stringer for IntSize. +func (i IntSize) String() string { + switch i { + case NullInt: + return "null.int" + case Int32: + return "int32" + case Int64: + return "int64" + case Uint64: + return "uint64" + case BigInt: + return "big.Int" + default: + return fmt.Sprintf("", uint8(i)) + } +} diff --git a/ion/type_test.go b/ion/type_test.go new file mode 100644 index 00000000..e1702baa --- /dev/null +++ b/ion/type_test.go @@ -0,0 +1,21 @@ +package ion + +import "testing" + +func TestTypeToString(t *testing.T) { + for i := NoType; i <= StructType+1; i++ { + str := i.String() + if str == "" { + t.Errorf("expected a non-empty string for type %v", uint8(i)) + } + } +} + +func TestIntSizeToString(t *testing.T) { + for i := NullInt; i <= BigInt+1; i++ { + str := i.String() + if str == "" { + t.Errorf("expected a non-empty string for size %v", uint8(i)) + } + } +} diff --git a/ion/unmarshal.go b/ion/unmarshal.go new file mode 100644 index 00000000..e5e61090 --- /dev/null +++ b/ion/unmarshal.go @@ -0,0 +1,673 @@ +package ion + +import ( + "bytes" + "errors" + "fmt" + "io" + "math/big" + "reflect" + "strconv" + "strings" +) + +var ( + // ErrNoInput is returned when there is no input to decode + ErrNoInput = errors.New("ion: no input to decode") +) + +// Unmarshal unmarshals Ion data to the given object. +func Unmarshal(data []byte, v interface{}) error { + return NewDecoder(NewReader(bytes.NewReader(data))).DecodeTo(v) +} + +// UnmarshalStr unmarshals Ion data from a string to the given object. +func UnmarshalStr(data string, v interface{}) error { + return Unmarshal([]byte(data), v) +} + +// UnmarshalFrom unmarshal Ion data from a reader to the given object. +func UnmarshalFrom(r Reader, v interface{}) error { + d := Decoder{ + r: r, + } + return d.DecodeTo(v) +} + +// A Decoder decodes go values from an Ion reader. +type Decoder struct { + r Reader +} + +// NewDecoder creates a new decoder. +func NewDecoder(r Reader) *Decoder { + return &Decoder{ + r: r, + } +} + +// NewTextDecoder creates a new text decoder. Well, a decoder that uses a reader with +// no shared symbol tables, it'll work to read binary too if the binary doesn't reference +// any shared symbol tables. +func NewTextDecoder(in io.Reader) *Decoder { + return NewDecoder(NewReader(in)) +} + +// Decode decodes a value from the underlying Ion reader without any expectations +// about what it's going to get. Structs become map[string]interface{}s, Lists and +// Sexps become []interface{}s. +func (d *Decoder) Decode() (interface{}, error) { + if !d.r.Next() { + if d.r.Err() != nil { + return nil, d.r.Err() + } + return nil, ErrNoInput + } + + return d.decode() +} + +// Helper form of Decode for when you've already called Next. +func (d *Decoder) decode() (interface{}, error) { + if d.r.IsNull() { + return nil, nil + } + + switch d.r.Type() { + case BoolType: + return d.r.BoolValue() + + case IntType: + return d.decodeInt() + + case FloatType: + return d.r.FloatValue() + + case DecimalType: + return d.r.DecimalValue() + + case TimestampType: + return d.r.TimeValue() + + case StringType, SymbolType: + return d.r.StringValue() + + case BlobType, ClobType: + return d.r.ByteValue() + + case StructType: + return d.decodeMap() + + case ListType, SexpType: + return d.decodeSlice() + + default: + panic("wat?") + } +} + +func (d *Decoder) decodeInt() (interface{}, error) { + size, err := d.r.IntSize() + if err != nil { + return nil, err + } + + switch size { + case NullInt: + return nil, nil + case Int32: + return d.r.IntValue() + case Int64: + return d.r.Int64Value() + default: + return d.r.BigIntValue() + } +} + +// DecodeMap decodes an Ion struct to a go map. +func (d *Decoder) decodeMap() (map[string]interface{}, error) { + if err := d.r.StepIn(); err != nil { + return nil, err + } + + result := map[string]interface{}{} + + for d.r.Next() { + name := d.r.FieldName() + value, err := d.decode() + if err != nil { + return nil, err + } + result[name] = value + } + + if err := d.r.StepOut(); err != nil { + return nil, err + } + + return result, nil +} + +// DecodeSlice decodes an Ion list or sexp to a go slice. +func (d *Decoder) decodeSlice() ([]interface{}, error) { + if err := d.r.StepIn(); err != nil { + return nil, err + } + + result := []interface{}{} + + for d.r.Next() { + value, err := d.decode() + if err != nil { + return nil, err + } + result = append(result, value) + } + + if err := d.r.StepOut(); err != nil { + return nil, err + } + + return result, nil +} + +// DecodeTo decodes an Ion value from the underlying Ion reader into the +// value provided. +func (d *Decoder) DecodeTo(v interface{}) error { + rv := reflect.ValueOf(v) + if rv.Kind() != reflect.Ptr { + return errors.New("ion: v must be a pointer") + } + if rv.IsNil() { + return errors.New("ion: v must not be nil") + } + + if !d.r.Next() { + if d.r.Err() != nil { + return d.r.Err() + } + return ErrNoInput + } + + return d.decodeTo(rv) +} + +func (d *Decoder) decodeTo(v reflect.Value) error { + if !v.IsValid() { + // Don't actually have anywhere to put this value; skip it. + return nil + } + + isNull := d.r.IsNull() + v = indirect(v, isNull) + if isNull { + v.Set(reflect.Zero(v.Type())) + return nil + } + + switch d.r.Type() { + case BoolType: + return d.decodeBoolTo(v) + + case IntType: + return d.decodeIntTo(v) + + case FloatType: + return d.decodeFloatTo(v) + + case DecimalType: + return d.decodeDecimalTo(v) + + case TimestampType: + return d.decodeTimestampTo(v) + + case StringType, SymbolType: + return d.decodeStringTo(v) + + case BlobType, ClobType: + return d.decodeLobTo(v) + + case StructType: + return d.decodeStructTo(v) + + case ListType, SexpType: + return d.decodeSliceTo(v) + + default: + panic("wat?") + } +} + +func (d *Decoder) decodeBoolTo(v reflect.Value) error { + val, err := d.r.BoolValue() + if err != nil { + return err + } + + switch v.Kind() { + case reflect.Bool: + // Too easy. + v.SetBool(val) + return nil + + case reflect.Interface: + if v.NumMethod() == 0 { + v.Set(reflect.ValueOf(val)) + return nil + } + } + return fmt.Errorf("ion: cannot decode bool to %v", v.Type().String()) +} + +var bigIntType = reflect.TypeOf(big.Int{}) + +func (d *Decoder) decodeIntTo(v reflect.Value) error { + switch v.Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + val, err := d.r.Int64Value() + if err != nil { + return err + } + if v.OverflowInt(val) { + return fmt.Errorf("ion: value %v won't fit in type %v", val, v.Type().String()) + } + v.SetInt(val) + return nil + + case reflect.Uint8, reflect.Uint16, reflect.Uint32: + val, err := d.r.Int64Value() + if err != nil { + return err + } + if val < 0 || v.OverflowUint(uint64(val)) { + return fmt.Errorf("ion: value %v won't fit in type %v", val, v.Type().String()) + } + v.SetUint(uint64(val)) + return nil + + case reflect.Uint, reflect.Uint64, reflect.Uintptr: + val, err := d.r.BigIntValue() + if err != nil { + return err + } + if !val.IsUint64() { + return fmt.Errorf("ion: value %v won't fit in type %v", val, v.Type().String()) + } + uiv := val.Uint64() + if v.OverflowUint(uiv) { + return fmt.Errorf("ion: value %v won't fit in type %v", val, v.Type().String()) + } + v.SetUint(uiv) + return nil + + case reflect.Struct: + if v.Type() == bigIntType { + val, err := d.r.BigIntValue() + if err != nil { + return err + } + v.Set(reflect.ValueOf(*val)) + return nil + } + + case reflect.Interface: + if v.NumMethod() == 0 { + val, err := d.decodeInt() + if err != nil { + return err + } + v.Set(reflect.ValueOf(val)) + return nil + } + } + return fmt.Errorf("ion: cannot decode int to %v", v.Type().String()) +} + +func (d *Decoder) decodeFloatTo(v reflect.Value) error { + val, err := d.r.FloatValue() + if err != nil { + return err + } + + switch v.Kind() { + case reflect.Float32, reflect.Float64: + if v.OverflowFloat(val) { + return fmt.Errorf("ion: value %v won't fit in type %v", val, v.Type().String()) + } + v.SetFloat(val) + return nil + + case reflect.Struct: + if v.Type() == decimalType { + flt := strconv.FormatFloat(val, 'g', -1, 64) + dec, err := ParseDecimal(strings.Replace(flt, "e", "d", 1)) + if err != nil { + return err + } + v.Set(reflect.ValueOf(*dec)) + return nil + } + + case reflect.Interface: + if v.NumMethod() == 0 { + v.Set(reflect.ValueOf(val)) + return nil + } + } + return fmt.Errorf("ion: cannot decode float to %v", v.Type().String()) +} + +func (d *Decoder) decodeDecimalTo(v reflect.Value) error { + val, err := d.r.DecimalValue() + if err != nil { + return err + } + + switch v.Kind() { + case reflect.Struct: + if v.Type() == decimalType { + v.Set(reflect.ValueOf(*val)) + return nil + } + + case reflect.Interface: + if v.NumMethod() == 0 { + v.Set(reflect.ValueOf(val)) + return nil + } + } + return fmt.Errorf("ion: cannot decode decimal to %v", v.Type().String()) +} + +func (d *Decoder) decodeTimestampTo(v reflect.Value) error { + val, err := d.r.TimeValue() + if err != nil { + return err + } + + switch v.Kind() { + case reflect.Struct: + if v.Type() == timeType { + v.Set(reflect.ValueOf(val)) + return nil + } + + case reflect.Interface: + if v.NumMethod() == 0 { + v.Set(reflect.ValueOf(val)) + return nil + } + } + return fmt.Errorf("ion: cannot decode timestamp to %v", v.Type().String()) +} + +func (d *Decoder) decodeStringTo(v reflect.Value) error { + val, err := d.r.StringValue() + if err != nil { + return err + } + + switch v.Kind() { + case reflect.String: + v.SetString(val) + return nil + + case reflect.Interface: + if v.NumMethod() == 0 { + v.Set(reflect.ValueOf(val)) + return nil + } + } + return fmt.Errorf("ion: cannot decode string to %v", v.Type().String()) +} + +func (d *Decoder) decodeLobTo(v reflect.Value) error { + val, err := d.r.ByteValue() + if err != nil { + return err + } + + switch v.Kind() { + case reflect.Slice: + if v.Type().Elem().Kind() == reflect.Uint8 { + v.SetBytes(val) + return nil + } + + case reflect.Array: + if v.Type().Elem().Kind() == reflect.Uint8 { + i := reflect.Copy(v, reflect.ValueOf(val)) + for ; i < v.Len(); i++ { + v.Index(i).SetUint(0) + } + return nil + } + + case reflect.Interface: + if v.NumMethod() == 0 { + v.Set(reflect.ValueOf(val)) + return nil + } + } + return fmt.Errorf("ion: cannot decode lob to %v", v.Type().String()) +} + +func (d *Decoder) decodeStructTo(v reflect.Value) error { + switch v.Kind() { + case reflect.Struct: + return d.decodeStructToStruct(v) + + case reflect.Map: + return d.decodeStructToMap(v) + + case reflect.Interface: + if v.NumMethod() == 0 { + m, err := d.decodeMap() + if err != nil { + return err + } + v.Set(reflect.ValueOf(m)) + return nil + } + } + return fmt.Errorf("ion: cannot decode struct to %v", v.Type().String()) +} + +func (d *Decoder) decodeStructToStruct(v reflect.Value) error { + fields := fieldsFor(v.Type()) + + if err := d.r.StepIn(); err != nil { + return err + } + + for d.r.Next() { + name := d.r.FieldName() + field := findField(fields, name) + if field != nil { + subv, err := findSubvalue(v, field) + if err != nil { + return err + } + + if err := d.decodeTo(subv); err != nil { + return err + } + } + } + + return d.r.StepOut() +} + +func findField(fields []field, name string) *field { + var f *field + for i := range fields { + ff := &fields[i] + if ff.name == name { + return ff + } + if f == nil && strings.EqualFold(ff.name, name) { + f = ff + } + } + return f +} + +func findSubvalue(v reflect.Value, f *field) (reflect.Value, error) { + for _, i := range f.path { + if v.Kind() == reflect.Ptr { + if v.IsNil() { + if !v.CanSet() { + return reflect.Value{}, fmt.Errorf("ion: cannot set embedded pointer to unexported struct: %v", v.Type().Elem()) + } + v.Set(reflect.New(v.Type().Elem())) + } + v = v.Elem() + } + v = v.Field(i) + } + return v, nil +} + +func (d *Decoder) decodeStructToMap(v reflect.Value) error { + t := v.Type() + switch t.Key().Kind() { + case reflect.String: + default: + return fmt.Errorf("ion: cannot decode struct to %v", t.String()) + } + + if v.IsNil() { + v.Set(reflect.MakeMap(t)) + } + + subv := reflect.New(t.Elem()).Elem() + + if err := d.r.StepIn(); err != nil { + return err + } + + for d.r.Next() { + name := d.r.FieldName() + if err := d.decodeTo(subv); err != nil { + return err + } + + var kv reflect.Value + switch t.Key().Kind() { + case reflect.String: + kv = reflect.ValueOf(name) + default: + panic("wat?") + } + + if kv.IsValid() { + v.SetMapIndex(kv, subv) + } + } + + return d.r.StepOut() +} + +func (d *Decoder) decodeSliceTo(v reflect.Value) error { + k := v.Kind() + + // If all we know is we need an interface{}, decode an []interface{} with + // types based on the Ion value stream. + if k == reflect.Interface && v.NumMethod() == 0 { + s, err := d.decodeSlice() + if err != nil { + return err + } + v.Set(reflect.ValueOf(s)) + return nil + } + + // Only other valid targets are arrays and slices. + if k != reflect.Array && k != reflect.Slice { + return fmt.Errorf("ion: cannot unmarshal slice to %v", v.Type().String()) + } + + if err := d.r.StepIn(); err != nil { + return err + } + + i := 0 + + // Decode values into the array or slice. + for d.r.Next() { + if v.Kind() == reflect.Slice { + // If it's a slice, we can grow it as needed. + if i >= v.Cap() { + newcap := v.Cap() + v.Cap()/2 + if newcap < 4 { + newcap = 4 + } + newv := reflect.MakeSlice(v.Type(), v.Len(), newcap) + reflect.Copy(newv, v) + v.Set(newv) + } + if i >= v.Len() { + v.SetLen(i + 1) + } + } + + if i < v.Len() { + if err := d.decodeTo(v.Index(i)); err != nil { + return err + } + } + + i++ + } + + if err := d.r.StepOut(); err != nil { + return err + } + + if i < v.Len() { + if v.Kind() == reflect.Array { + // Zero out any additional values. + z := reflect.Zero(v.Type().Elem()) + for ; i < v.Len(); i++ { + v.Index(i).Set(z) + } + } else { + v.SetLen(i) + } + } + + return nil +} + +// Dig in through any pointers to find the actual underlying value that we want +// to set. If wantPtr is false, the algorithm terminates at a non-ptr value (e.g., +// if passed an *int, it returns the int it points to, allocating such an int if the +// pointer is currently nil). If wantPtr is true, it terminates on a pointer to that +// value (allowing said pointer to be set to nil, generally). +func indirect(v reflect.Value, wantPtr bool) reflect.Value { + for { + if v.Kind() == reflect.Interface && !v.IsNil() { + e := v.Elem() + if e.Kind() == reflect.Ptr && !e.IsNil() && (!wantPtr || e.Elem().Kind() == reflect.Ptr) { + v = e + continue + } + } + + if v.Kind() != reflect.Ptr { + break + } + + if v.Elem().Kind() != reflect.Ptr && wantPtr && v.CanSet() { + break + } + + if v.IsNil() { + v.Set(reflect.New(v.Type().Elem())) + } + + v = v.Elem() + } + + return v +} diff --git a/ion/unmarshal_test.go b/ion/unmarshal_test.go new file mode 100644 index 00000000..e3e2d2d6 --- /dev/null +++ b/ion/unmarshal_test.go @@ -0,0 +1,539 @@ +package ion + +import ( + "bytes" + "math" + "math/big" + "reflect" + "testing" + "time" +) + +func TestUnmarshalBool(t *testing.T) { + test := func(str string, eval bool) { + t.Run(str, func(t *testing.T) { + var val bool + err := UnmarshalStr(str, &val) + if err != nil { + t.Fatal(err) + } + + if val != eval { + t.Errorf("expected %v, got %v", eval, val) + } + }) + } + + test("null", false) + test("true", true) + test("false", false) +} +func TestUnmarshalBoolPtr(t *testing.T) { + test := func(str string, eval interface{}) { + t.Run(str, func(t *testing.T) { + var bval bool + val := &bval + err := UnmarshalStr(str, &val) + if err != nil { + t.Fatal(err) + } + + if eval == nil { + if val != nil { + t.Errorf("expected , got %v", *val) + } + } else { + switch { + case val == nil: + t.Errorf("expected %v, got ", eval) + case *val != eval.(bool): + t.Errorf("expected %v, got %v", eval, *val) + } + } + }) + } + + test("null", nil) + test("null.bool", nil) + test("false", false) + test("true", true) +} + +func TestUnmarshalInt(t *testing.T) { + testInt8 := func(str string, eval int8) { + t.Run(str, func(t *testing.T) { + var val int8 + err := UnmarshalStr(str, &val) + if err != nil { + t.Fatal(err) + } + + if val != eval { + t.Errorf("expected %v, got %v", eval, val) + } + }) + } + testInt8("null", 0) + testInt8("0", 0) + testInt8("0x7F", 0x7F) + testInt8("-0x80", -0x80) + + testInt16 := func(str string, eval int16) { + t.Run(str, func(t *testing.T) { + var val int16 + err := UnmarshalStr(str, &val) + if err != nil { + t.Fatal(err) + } + + if val != eval { + t.Errorf("expected %v, got %v", eval, val) + } + }) + } + testInt16("0x7F", 0x7F) + testInt16("-0x80", -0x80) + testInt16("0x7FFF", 0x7FFF) + testInt16("-0x8000", -0x8000) + + testInt32 := func(str string, eval int32) { + t.Run(str, func(t *testing.T) { + var val int32 + err := UnmarshalStr(str, &val) + if err != nil { + t.Fatal(err) + } + + if val != eval { + t.Errorf("expected %v, got %v", eval, val) + } + }) + } + testInt32("0x7FFF", 0x7FFF) + testInt32("-0x8000", -0x8000) + testInt32("0x7FFFFFFF", 0x7FFFFFFF) + testInt32("-0x80000000", -0x80000000) + + testInt := func(str string, eval int) { + t.Run(str, func(t *testing.T) { + var val int + err := UnmarshalStr(str, &val) + if err != nil { + t.Fatal(err) + } + + if val != eval { + t.Errorf("expected %v, got %v", eval, val) + } + }) + } + testInt("0x7FFF", 0x7FFF) + testInt("-0x8000", -0x8000) + testInt("0x7FFFFFFF", 0x7FFFFFFF) + testInt("-0x80000000", -0x80000000) + + testInt64 := func(str string, eval int64) { + t.Run(str, func(t *testing.T) { + var val int64 + err := UnmarshalStr(str, &val) + if err != nil { + t.Fatal(err) + } + + if val != eval { + t.Errorf("expected %v, got %v", eval, val) + } + }) + } + testInt64("0x7FFFFFFF", 0x7FFFFFFF) + testInt64("-0x80000000", -0x80000000) + testInt64("0x7FFFFFFFFFFFFFFF", 0x7FFFFFFFFFFFFFFF) + testInt64("-0x8000000000000000", -0x8000000000000000) +} + +func TestUnmarshalUint(t *testing.T) { + testUint8 := func(str string, eval uint8) { + t.Run(str, func(t *testing.T) { + var val uint8 + err := UnmarshalStr(str, &val) + if err != nil { + t.Fatal(err) + } + + if val != eval { + t.Errorf("expected %v, got %v", eval, val) + } + }) + } + testUint8("null", 0) + testUint8("0", 0) + testUint8("0xFF", 0xFF) + + testUint16 := func(str string, eval uint16) { + t.Run(str, func(t *testing.T) { + var val uint16 + err := UnmarshalStr(str, &val) + if err != nil { + t.Fatal(err) + } + + if val != eval { + t.Errorf("expected %v, got %v", eval, val) + } + }) + } + testUint16("0xFF", 0xFF) + testUint16("0xFFFF", 0xFFFF) + + testUint32 := func(str string, eval uint32) { + t.Run(str, func(t *testing.T) { + var val uint32 + err := UnmarshalStr(str, &val) + if err != nil { + t.Fatal(err) + } + + if val != eval { + t.Errorf("expected %v, got %v", eval, val) + } + }) + } + testUint32("0xFFFF", 0xFFFF) + testUint32("0xFFFFFFFF", 0xFFFFFFFF) + + testUint := func(str string, eval uint) { + t.Run(str, func(t *testing.T) { + var val uint + err := UnmarshalStr(str, &val) + if err != nil { + t.Fatal(err) + } + + if val != eval { + t.Errorf("expected %v, got %v", eval, val) + } + }) + } + testUint("0xFFFF", 0xFFFF) + testUint("0xFFFFFFFF", 0xFFFFFFFF) + + testUintptr := func(str string, eval uintptr) { + t.Run(str, func(t *testing.T) { + var val uintptr + err := UnmarshalStr(str, &val) + if err != nil { + t.Fatal(err) + } + + if val != eval { + t.Errorf("expected %v, got %v", eval, val) + } + }) + } + testUintptr("0xFFFF", 0xFFFF) + testUintptr("0xFFFFFFFF", 0xFFFFFFFF) + + testUint64 := func(str string, eval uint64) { + t.Run(str, func(t *testing.T) { + var val uint64 + err := UnmarshalStr(str, &val) + if err != nil { + t.Fatal(err) + } + + if val != eval { + t.Errorf("expected %v, got %v", eval, val) + } + }) + } + testUint64("0xFFFFFFFF", 0xFFFFFFFF) + testUint64("0xFFFFFFFFFFFFFFFF", 0xFFFFFFFFFFFFFFFF) +} + +func TestUnmarshalBigInt(t *testing.T) { + test := func(str string, eval *big.Int) { + t.Run(str, func(t *testing.T) { + var val big.Int + err := UnmarshalStr(str, &val) + if err != nil { + t.Fatal(err) + } + + if val.Cmp(eval) != 0 { + t.Errorf("expected %v, got %v", eval, val) + } + }) + } + test("null", new(big.Int)) + test("1", new(big.Int).SetUint64(1)) + test("-0xFFFFFFFFFFFFFFFF", new(big.Int).Neg(new(big.Int).SetUint64(0xFFFFFFFFFFFFFFFF))) +} + +func TestDecodeFloat(t *testing.T) { + test32 := func(str string, eval float32) { + t.Run(str, func(t *testing.T) { + d := NewDecoder(NewReaderStr(str)) + + var val float32 + err := d.DecodeTo(&val) + if err != nil { + t.Fatal(err) + } + + if val != eval { + t.Errorf("expected %v, got %v", eval, val) + } + }) + } + test32("null", 0) + test32("1e0", 1) + test32("1e38", 1e38) + test32("+inf", float32(math.Inf(1))) + + test64 := func(str string, eval float64) { + t.Run(str, func(t *testing.T) { + d := NewDecoder(NewReaderStr(str)) + + var val float64 + err := d.DecodeTo(&val) + if err != nil { + t.Fatal(err) + } + + if val != eval { + t.Errorf("expected %v, got %v", eval, val) + } + }) + } + test64("1e0", 1) + test64("1e308", 1e308) + test64("+inf", math.Inf(1)) +} + +func TestDecodeDecimal(t *testing.T) { + test := func(str string, eval *Decimal) { + t.Run(str, func(t *testing.T) { + d := NewDecoder(NewReaderStr(str)) + + var val *Decimal + err := d.DecodeTo(&val) + if err != nil { + t.Fatal(err) + } + + if !val.Equal(eval) { + t.Errorf("expected %v, got %v", eval, val) + } + }) + } + + test("1e10", MustParseDecimal("1d10")) + test("1.20", MustParseDecimal("1.20")) +} + +func TestDecodeTimeTo(t *testing.T) { + test := func(str string, eval time.Time) { + t.Run(str, func(t *testing.T) { + d := NewDecoder(NewReaderStr(str)) + + var val time.Time + err := d.DecodeTo(&val) + if err != nil { + t.Fatal(err) + } + + if val != eval { + t.Errorf("expected %v, got %v", eval, val) + } + }) + } + test("null", time.Time{}) + test("2020T", time.Date(2020, 1, 1, 0, 0, 0, 0, time.UTC)) +} + +func TestDecodeStringTo(t *testing.T) { + test := func(str string, eval string) { + t.Run(str, func(t *testing.T) { + d := NewDecoder(NewReaderStr(str)) + + var val string + err := d.DecodeTo(&val) + if err != nil { + t.Fatal(err) + } + + if val != eval { + t.Errorf("expected %v, got %v", eval, val) + } + }) + } + + test("null", "") + test("hello", "hello") + test("\"hello\"", "hello") +} + +func TestDecodeLobTo(t *testing.T) { + testSlice := func(str string, eval []byte) { + t.Run(str, func(t *testing.T) { + d := NewDecoder(NewReaderStr(str)) + + var val []byte + err := d.DecodeTo(&val) + if err != nil { + t.Fatal(err) + } + + if !bytes.Equal(val, eval) { + t.Errorf("expected %v, got %v", eval, val) + } + }) + } + testSlice("null", nil) + testSlice("{{}}", []byte{}) + testSlice("{{aGVsbG8=}}", []byte("hello")) + testSlice("{{'''hello'''}}", []byte("hello")) + + testArray := func(str string, eval []byte) { + t.Run(str, func(t *testing.T) { + d := NewDecoder(NewReaderStr(str)) + + var val [8]byte + err := d.DecodeTo(&val) + if err != nil { + t.Fatal(err) + } + + if !bytes.Equal(val[:], eval) { + t.Errorf("expected %v, got %v", eval, val) + } + }) + } + testArray("null", make([]byte, 8)) + testArray("{{aGVsbG8=}}", append([]byte("hello"), []byte{0, 0, 0}...)) +} + +func TestDecodeStructTo(t *testing.T) { + test := func(str string, val, eval interface{}) { + t.Run(str, func(t *testing.T) { + d := NewDecoder(NewReaderStr(str)) + err := d.DecodeTo(val) + if err != nil { + t.Fatal(err) + } + if !reflect.DeepEqual(val, eval) { + t.Errorf("expected %v, got %v", eval, val) + } + }) + } + + type foo struct { + Foo string + Baz int `json:"bar"` + } + + test("{}", &struct{}{}, &struct{}{}) + test("{bogus:(ignore me)}", &foo{}, &foo{}) + test("{foo:bar}", &foo{}, &foo{"bar", 0}) + test("{bar:42}", &foo{}, &foo{"", 42}) + test("{foo:bar,bar:42,bogus:(ignore me)}", &foo{}, &foo{"bar", 42}) + + test("{}", &map[string]string{}, &map[string]string{}) + test("{foo:bar}", &map[string]string{}, &map[string]string{"foo": "bar"}) + test("{a:4,b:2}", &map[string]int{}, &map[string]int{"a": 4, "b": 2}) +} + +func TestDecodeListTo(t *testing.T) { + test := func(str string, val, eval interface{}) { + t.Run(str, func(t *testing.T) { + d := NewDecoder(NewReaderStr(str)) + err := d.DecodeTo(val) + if err != nil { + t.Fatal(err) + } + if !reflect.DeepEqual(val, eval) { + t.Errorf("expected %v, got %v", eval, val) + } + }) + } + + f := false + pf := &f + ppf := &pf + + test("[]", &[]bool{}, &[]bool{}) + test("[]", &[]bool{true}, &[]bool{}) + + test("[false]", &[]bool{}, &[]bool{false}) + test("[false]", &[]*bool{}, &[]*bool{pf}) + test("[false,false]", &[]**bool{}, &[]**bool{ppf, ppf}) + + test("[true,false]", &[]interface{}{}, &[]interface{}{true, false}) + + var i interface{} + var ei interface{} = []interface{}{true, false} + test("[true,false]", &i, &ei) +} + +func TestDecode(t *testing.T) { + test := func(data string, eval interface{}) { + t.Run(data, func(t *testing.T) { + d := NewDecoder(NewReaderStr(data)) + val, err := d.Decode() + if err != nil { + t.Fatal(err) + } + if !reflect.DeepEqual(val, eval) { + t.Errorf("expected %v, got %v", eval, val) + } + }) + } + + test("null", nil) + test("null.null", nil) + + test("null.bool", nil) + test("true", true) + test("false", false) + + test("null.int", nil) + test("0", int(0)) + test("2147483647", math.MaxInt32) + test("-2147483648", math.MinInt32) + test("2147483648", int64(math.MaxInt32)+1) + test("-2147483649", int64(math.MinInt32)-1) + test("9223372036854775808", new(big.Int).SetUint64(math.MaxInt64+1)) + + test("0e0", float64(0.0)) + test("1e100", float64(1e100)) + + test("0.", MustParseDecimal("0.")) + + test("2020T", time.Date(2020, 1, 1, 0, 0, 0, 0, time.UTC)) + + test("hello", "hello") + test("\"hello\"", "hello") + + test("null.blob", nil) + test("{{}}", []byte{}) + test("{{aGVsbG8=}}", []byte("hello")) + + test("null.clob", nil) + test("{{''''''}}", []byte{}) + test("{{'''hello'''}}", []byte("hello")) + + test("null.struct", nil) + test("{}", map[string]interface{}{}) + test("{a:1,b:two}", map[string]interface{}{ + "a": 1, + "b": "two", + }) + + test("null.list", nil) + test("[]", []interface{}{}) + test("[1, two]", []interface{}{1, "two"}) + + test("null.sexp", nil) + test("()", []interface{}{}) + test("(1 + two)", []interface{}{1, "+", "two"}) +} diff --git a/ion/writer.go b/ion/writer.go new file mode 100644 index 00000000..2a85216a --- /dev/null +++ b/ion/writer.go @@ -0,0 +1,165 @@ +package ion + +import ( + "errors" + "io" + "math/big" + "time" +) + +// A Writer writes a stream of Ion values. +// +// The various Write methods write atomic values to the current output stream. The +// Begin methods begin writing a list, sexp, or struct respectively. Subsequent +// calls to Write will write values inside of the container until a matching +// End method is called. +// +// var w Writer +// w.BeginSexp() +// { +// w.WriteInt(1) +// w.WriteSymbol("+") +// w.WriteInt(1) +// } +// w.EndSexp() +// +// When writing values inside a struct, the FieldName method must be called before +// each value to set the value's field name. The Annotation method may likewise +// be called before writing any value to add an annotation to the value. +// +// var w Writer +// w.Annotation("user") +// w.BeginStruct() +// { +// w.FieldName("id") +// w.WriteString("qu33nb33") +// w.FieldName("name") +// w.WriteString("Beyoncé") +// } +// w.EndStruct() +// +// When you're done writing values, you should call Finish to ensure everything has +// been flushed from in-memory buffers. While individual methods all return an error +// on failure, implementations will remember any errors, no-op subsequent calls, and +// return the previous error. This lets you keep code a bit cleaner by only checking +// the return value of the final method call (generally Finish). +// +// var w Writer +// writeSomeStuff(w) +// if err := w.Finish(); err != nil { +// return err +// } +// +type Writer interface { + + // FieldName sets the field name for the next value written. + FieldName(val string) error + + // Annotation adds a single annotation to the next value written. + Annotation(val string) error + + // Annotations adds multiple annotations to the next value written. + Annotations(vals ...string) error + + // WriteNull writes an untyped null value. + WriteNull() error + // WriteNullType writes a null value with a type qualifier, e.g. null.bool. + WriteNullType(t Type) error + + // WriteBool writes a boolean value. + WriteBool(val bool) error + + // WriteInt writes an integer value. + WriteInt(val int64) error + // WriteUint writes an unsigned integer value. + WriteUint(val uint64) error + // WriteBigInt writes a big integer value. + WriteBigInt(val *big.Int) error + // WriteFloat writes a floating-point value. + WriteFloat(val float64) error + // WriteDecimal writes an arbitrary-precision decimal value. + WriteDecimal(val *Decimal) error + + // WriteTimestamp writes a timestamp value. + WriteTimestamp(val time.Time) error + + // WriteSymbol writes a symbol value. + WriteSymbol(val string) error + // WriteString writes a string value. + WriteString(val string) error + + // WriteClob writes a clob value. + WriteClob(val []byte) error + // WriteBlob writes a blob value. + WriteBlob(val []byte) error + + // BeginList begins writing a list value. + BeginList() error + // EndList finishes writing a list value. + EndList() error + + // BeginSexp begins writing an s-expression value. + BeginSexp() error + // EndSexp finishes writing an s-expression value. + EndSexp() error + + // BeginStruct begins writing a struct value. + BeginStruct() error + // EndStruct finishes writing a struct value. + EndStruct() error + + // Finish finishes writing values and flushes any buffered data. + Finish() error +} + +// A writer holds shared stuff for all writers. +type writer struct { + out io.Writer + ctx ctxstack + err error + + fieldName string + annotations []string +} + +// FieldName sets the field name for the next value written. +// It may only be called while writing a struct. +func (w *writer) FieldName(val string) error { + if w.err != nil { + return w.err + } + if !w.inStruct() { + w.err = errors.New("ion: Writer.FieldName called when not writing a struct") + return w.err + } + + w.fieldName = val + return nil +} + +// Annotation adds an annotation to the next value written. +func (w *writer) Annotation(val string) error { + if w.err == nil { + w.annotations = append(w.annotations, val) + } + return w.err +} + +// Annotations adds one or more annotations to the next value written. +func (w *writer) Annotations(val ...string) error { + if w.err == nil { + w.annotations = append(w.annotations, val...) + } + return w.err +} + +// InStruct returns true if we're currently writing a struct. +func (w *writer) inStruct() bool { + return w.ctx.peek() == ctxInStruct +} + +// Clear clears field name and annotations after writing a value. +func (w *writer) clear() { + w.fieldName = "" + w.annotations = nil +} From 684e59acdf5ff8beec48d7cc690ca1c1511bad26 Mon Sep 17 00:00:00 2001 From: Therapon Skoteiniotis Date: Thu, 7 May 2020 19:56:33 -0700 Subject: [PATCH 43/56] moved files into ion folder --- binaryreader.go | 487 ---------------- binaryreader_test.go | 454 --------------- binarywriter.go | 562 ------------------- binarywriter_test.go | 387 ------------- bits.go | 285 ---------- bits_test.go | 205 ------- bitstream.go | 935 ------------------------------- bitstream_test.go | 111 ---- buf.go | 119 ---- buf_test.go | 79 --- catalog.go | 88 --- catalog_test.go | 62 --- consts.go | 52 -- ctx.go | 65 --- decimal.go | 342 ------------ decimal_test.go | 312 ----------- err.go | 88 --- fields.go | 122 ---- marshal.go | 338 ----------- marshal_test.go | 162 ------ reader.go | 388 ------------- reader_test.go | 127 ----- skipper.go | 863 ---------------------------- skipper_test.go | 178 ------ symboltable.go | 475 ---------------- symboltable_test.go | 181 ------ textreader.go | 658 ---------------------- textreader_test.go | 788 -------------------------- textutils.go | 382 ------------- textutils_test.go | 164 ------ textwriter.go | 359 ------------ textwriter_test.go | 351 ------------ tokenizer.go | 1265 ------------------------------------------ tokenizer_test.go | 571 ------------------- type.go | 125 ----- type_test.go | 21 - unmarshal.go | 673 ---------------------- unmarshal_test.go | 539 ------------------ writer.go | 165 ------ 39 files changed, 13528 deletions(-) delete mode 100644 binaryreader.go delete mode 100644 binaryreader_test.go delete mode 100644 binarywriter.go delete mode 100644 binarywriter_test.go delete mode 100644 bits.go delete mode 100644 bits_test.go delete mode 100644 bitstream.go delete mode 100644 bitstream_test.go delete mode 100644 buf.go delete mode 100644 buf_test.go delete mode 100644 catalog.go delete mode 100644 catalog_test.go delete mode 100644 consts.go delete mode 100644 ctx.go delete mode 100644 decimal.go delete mode 100644 decimal_test.go delete mode 100644 err.go delete mode 100644 fields.go delete mode 100644 marshal.go delete mode 100644 marshal_test.go delete mode 100644 reader.go delete mode 100644 reader_test.go delete mode 100644 skipper.go delete mode 100644 skipper_test.go delete mode 100644 symboltable.go delete mode 100644 symboltable_test.go delete mode 100644 textreader.go delete mode 100644 textreader_test.go delete mode 100644 textutils.go delete mode 100644 textutils_test.go delete mode 100644 textwriter.go delete mode 100644 textwriter_test.go delete mode 100644 tokenizer.go delete mode 100644 tokenizer_test.go delete mode 100644 type.go delete mode 100644 type_test.go delete mode 100644 unmarshal.go delete mode 100644 unmarshal_test.go delete mode 100644 writer.go diff --git a/binaryreader.go b/binaryreader.go deleted file mode 100644 index 9a64a7c2..00000000 --- a/binaryreader.go +++ /dev/null @@ -1,487 +0,0 @@ -package ion - -import ( - "bufio" - "fmt" -) - -// A binaryReader reads binary Ion. -type binaryReader struct { - reader - - bits bitstream - cat Catalog - lst SymbolTable -} - -func newBinaryReaderBuf(in *bufio.Reader, cat Catalog) Reader { - r := &binaryReader{ - cat: cat, - } - r.bits.Init(in) - return r -} - -// SymbolTable returns the current symbol table. -func (r *binaryReader) SymbolTable() SymbolTable { - return r.lst -} - -// Next moves the reader to the next value. -func (r *binaryReader) Next() bool { - if r.eof || r.err != nil { - return false - } - - r.clear() - - done := false - for !done { - done, r.err = r.next() - if r.err != nil { - return false - } - } - - return !r.eof -} - -// Next consumes the next raw value from the stream, returning true if it -// represents a user-facing value and false if it does not. -func (r *binaryReader) next() (bool, error) { - if err := r.bits.Next(); err != nil { - return false, err - } - - code := r.bits.Code() - switch code { - case bitcodeEOF: - r.eof = true - return true, nil - - case bitcodeBVM: - err := r.readBVM() - return false, err - - case bitcodeFieldID: - err := r.readFieldName() - return false, err - - case bitcodeAnnotation: - err := r.readAnnotations() - return false, err - - case bitcodeNull: - if !r.bits.IsNull() { - // NOP padding; skip it and keep going. - err := r.bits.SkipValue() - return false, err - } - r.valueType = NullType - return true, nil - - case bitcodeFalse, bitcodeTrue: - r.valueType = BoolType - if !r.bits.IsNull() { - r.value = (r.bits.Code() == bitcodeTrue) - } - return true, nil - - case bitcodeInt, bitcodeNegInt: - r.valueType = IntType - if !r.bits.IsNull() { - val, err := r.bits.ReadInt() - if err != nil { - return false, err - } - r.value = val - } - return true, nil - - case bitcodeFloat: - r.valueType = FloatType - if !r.bits.IsNull() { - val, err := r.bits.ReadFloat() - if err != nil { - return false, err - } - r.value = val - } - return true, nil - - case bitcodeDecimal: - r.valueType = DecimalType - if !r.bits.IsNull() { - val, err := r.bits.ReadDecimal() - if err != nil { - return false, err - } - r.value = val - } - return true, nil - - case bitcodeTimestamp: - r.valueType = TimestampType - if !r.bits.IsNull() { - val, err := r.bits.ReadTimestamp() - if err != nil { - return false, err - } - r.value = val - } - return true, nil - - case bitcodeSymbol: - r.valueType = SymbolType - if !r.bits.IsNull() { - id, err := r.bits.ReadSymbolID() - if err != nil { - return false, err - } - r.value = r.resolve(id) - } - return true, nil - - case bitcodeString: - r.valueType = StringType - if !r.bits.IsNull() { - val, err := r.bits.ReadString() - if err != nil { - return false, err - } - r.value = val - } - return true, nil - - case bitcodeClob: - r.valueType = ClobType - if !r.bits.IsNull() { - val, err := r.bits.ReadBytes() - if err != nil { - return false, err - } - r.value = val - } - return true, nil - - case bitcodeBlob: - r.valueType = BlobType - if !r.bits.IsNull() { - val, err := r.bits.ReadBytes() - if err != nil { - return false, err - } - r.value = val - } - return true, nil - - case bitcodeList: - r.valueType = ListType - if !r.bits.IsNull() { - r.value = ListType - } - return true, nil - - case bitcodeSexp: - r.valueType = SexpType - if !r.bits.IsNull() { - r.value = SexpType - } - return true, nil - - case bitcodeStruct: - r.valueType = StructType - if !r.bits.IsNull() { - r.value = StructType - } - - // If it's a local symbol table, install it and keep going. - if r.ctx.peek() == ctxAtTopLevel && isIonSymbolTable(r.annotations) { - err := r.readLocalSymbolTable() - return false, err - } - - return true, nil - } - panic(fmt.Sprintf("invalid bitcode %v", code)) -} - -func isIonSymbolTable(as []string) bool { - return len(as) > 0 && as[0] == "$ion_symbol_table" -} - -// ReadBVM reads a BVM, validates it, and resets the local symbol table. -func (r *binaryReader) readBVM() error { - major, minor, err := r.bits.ReadBVM() - if err != nil { - return err - } - - switch major { - case 1: - switch minor { - case 0: - r.lst = V1SystemSymbolTable - return nil - } - } - - return &UnsupportedVersionError{ - int(major), - int(minor), - r.bits.Pos() - 4, - } -} - -// ReadLocalSymbolTable reads and installs a new local symbol table. -func (r *binaryReader) readLocalSymbolTable() error { - if r.IsNull() { - r.clear() - r.lst = V1SystemSymbolTable - return nil - } - - if err := r.StepIn(); err != nil { - return err - } - - imps := []SharedSymbolTable{} - syms := []string{} - - for r.Next() { - var err error - switch r.FieldName() { - case "imports": - imps, err = r.readImports() - case "symbols": - syms, err = r.readSymbols() - } - if err != nil { - return err - } - } - - if err := r.StepOut(); err != nil { - return err - } - - r.lst = NewLocalSymbolTable(imps, syms) - return nil -} - -// ReadImports reads the imports field of a local symbol table. -func (r *binaryReader) readImports() ([]SharedSymbolTable, error) { - if r.valueType == SymbolType && r.value == "$ion_symbol_table" { - // Special case that imports the current local symbol table. - if r.lst == nil || r.lst == V1SystemSymbolTable { - return nil, nil - } - - imps := r.lst.Imports() - lsst := NewSharedSymbolTable("", 0, r.lst.Symbols()) - return append(imps, lsst), nil - } - - if r.Type() != ListType || r.IsNull() { - return nil, nil - } - if err := r.StepIn(); err != nil { - return nil, err - } - - imps := []SharedSymbolTable{} - for r.Next() { - imp, err := r.readImport() - if err != nil { - return nil, err - } - if imp != nil { - imps = append(imps, imp) - } - } - - err := r.StepOut() - return imps, err -} - -// ReadImport reads an import definition. -func (r *binaryReader) readImport() (SharedSymbolTable, error) { - if r.Type() != StructType || r.IsNull() { - return nil, nil - } - if err := r.StepIn(); err != nil { - return nil, err - } - - name := "" - version := 0 - maxID := uint64(0) - - for r.Next() { - var err error - switch r.FieldName() { - case "name": - if r.Type() == StringType { - name, err = r.StringValue() - } - case "version": - if r.Type() == IntType { - version, err = r.IntValue() - } - case "max_id": - if r.Type() == IntType { - var i int64 - i, err = r.Int64Value() - if i < 0 { - i = 0 - } - maxID = uint64(i) - } - } - if err != nil { - return nil, err - } - } - - if err := r.StepOut(); err != nil { - return nil, err - } - - if name == "" || name == "$ion" { - return nil, nil - } - if version < 1 { - version = 1 - } - - var imp SharedSymbolTable - if r.cat != nil { - imp = r.cat.FindExact(name, version) - if imp == nil { - imp = r.cat.FindLatest(name) - } - } - - if maxID == 0 { - if imp == nil || version != imp.Version() { - return nil, fmt.Errorf("ion: import of shared table %v/%v lacks a valid max_id, but an exact "+ - "match was not found in the catalog", name, version) - } - maxID = imp.MaxID() - } - - if imp == nil { - imp = &bogusSST{ - name: name, - version: version, - maxID: maxID, - } - } else { - imp = imp.Adjust(maxID) - } - - return imp, nil -} - -// ReadSymbols reads the symbols from a symbol table. -func (r *binaryReader) readSymbols() ([]string, error) { - if r.Type() != ListType { - return nil, nil - } - if err := r.StepIn(); err != nil { - return nil, err - } - - syms := []string{} - for r.Next() { - if r.Type() == StringType { - sym, err := r.StringValue() - if err != nil { - return nil, err - } - syms = append(syms, sym) - } else { - syms = append(syms, "") - } - } - - err := r.StepOut() - return syms, err -} - -// ReadFieldName reads and resolves a field name. -func (r *binaryReader) readFieldName() error { - id, err := r.bits.ReadFieldID() - if err != nil { - return err - } - - r.fieldName = r.resolve(id) - return nil -} - -// ReadAnnotations reads and resolves a set of annotations. -func (r *binaryReader) readAnnotations() error { - ids, err := r.bits.ReadAnnotationIDs() - if err != nil { - return err - } - - as := make([]string, len(ids)) - for i, id := range ids { - as[i] = r.resolve(id) - } - - r.annotations = as - return nil -} - -// Resolve resolves a symbol ID to a symbol value (possibly ${id} if we're -// missing the appropriate symbol table). -func (r *binaryReader) resolve(id uint64) string { - s, ok := r.lst.FindByID(id) - if !ok { - return fmt.Sprintf("$%v", id) - } - return s -} - -// StepIn steps in to a container-type value -func (r *binaryReader) StepIn() error { - if r.err != nil { - return r.err - } - - if r.valueType != ListType && r.valueType != SexpType && r.valueType != StructType { - return &UsageError{"Reader.StepIn", fmt.Sprintf("cannot step in to a %v", r.valueType)} - } - if r.value == nil { - return &UsageError{"Reader.StepIn", "cannot step in to a null container"} - } - - r.ctx.push(containerTypeToCtx(r.valueType)) - r.clear() - r.bits.StepIn() - - return nil -} - -// StepOut steps out of a container-type value. -func (r *binaryReader) StepOut() error { - if r.err != nil { - return r.err - } - if r.ctx.peek() == ctxAtTopLevel { - return &UsageError{"Reader.StepOut", "cannot step out of top-level datagram"} - } - - if err := r.bits.StepOut(); err != nil { - return err - } - - r.clear() - r.ctx.pop() - r.eof = false - - return nil -} diff --git a/binaryreader_test.go b/binaryreader_test.go deleted file mode 100644 index 4f0901b5..00000000 --- a/binaryreader_test.go +++ /dev/null @@ -1,454 +0,0 @@ -package ion - -import ( - "fmt" - "math" - "math/big" - "testing" - "time" -) - -func TestReadBadBVMs(t *testing.T) { - t.Run("E00200E9", func(t *testing.T) { - // Need a good first one or we'll get sent to the text reader. - r := NewReaderBytes([]byte{0xE0, 0x01, 0x00, 0xEA, 0xE0, 0x02, 0x00, 0xE9}) - if r.Next() { - t.Errorf("next returned true") - } - if r.Err() == nil { - t.Fatal("err is nil") - } - }) - - t.Run("E00200EA", func(t *testing.T) { - r := NewReaderBytes([]byte{0xE0, 0x02, 0x00, 0xEA}) - if r.Next() { - t.Errorf("next returned true") - } - if r.Err() == nil { - t.Fatal("err is nil") - } - - uve, ok := r.Err().(*UnsupportedVersionError) - if !ok { - t.Fatal("err is not an UnsupportedVersionError") - } - if uve.Major != 2 { - t.Errorf("expected major=2, got %v", uve.Major) - } - if uve.Minor != 0 { - t.Errorf("expected minor=0, got %v", uve.Minor) - } - }) -} - -func TestReadNullLST(t *testing.T) { - ion := []byte{ - 0xE0, 0x01, 0x00, 0xEA, - 0xE4, 0x82, 0x83, 0x87, 0xDF, - 0x71, 0x09, - } - r := NewReaderBytes(ion) - _symbol(t, r, "$ion_shared_symbol_table") - _eof(t, r) -} - -func TestReadEmptyLST(t *testing.T) { - ion := []byte{ - 0xE0, 0x01, 0x00, 0xEA, - 0xE4, 0x82, 0x83, 0x87, 0xD0, - 0x71, 0x09, - } - r := NewReaderBytes(ion) - _symbol(t, r, "$ion_shared_symbol_table") - _eof(t, r) -} - -func TestReadBadLST(t *testing.T) { - ion := []byte{ - 0xE0, 0x01, 0x00, 0xEA, - 0xE3, 0x81, 0x83, 0xD9, - 0x86, 0xB7, 0xD6, // imports:[{ - 0x84, 0x81, 'a', // name: "a", - 0x85, 0x21, 0x01, // version: 1}]} - 0x0F, // null - } - r := NewReaderBytes(ion) - if r.Next() { - t.Fatal("next returned true") - } - if r.Err() == nil { - t.Fatal("err is nil") - } -} - -func TestReadMultipleLSTs(t *testing.T) { - r := readBinary([]byte{ - 0x71, 0x0B, // $11 - 0x71, 0x6F, // bar - 0xE3, 0x81, 0x83, 0xDF, // $ion_symbol_table::null.struct - 0xEE, 0x8F, 0x81, 0x83, 0xDD, // $ion_symbol_table::{ - 0x86, 0x71, 0x03, // imports: $ion_symbol_table, - 0x87, 0xB8, // symbols:[ - 0x83, 'f', 'o', 'o', // "foo" - 0x83, 'b', 'a', 'r', // "bar" ]} - 0x71, 0x0B, // bar - 0x71, 0x0C, // $12 - 0x71, 0x6F, // $111 - 0xEC, 0x81, 0x83, 0xD9, // $ion_symbol_table::{ - 0x86, 0x71, 0x03, // imports: $ion_symbol_table - 0x87, 0xB4, // symbols:[ - 0x83, 'b', 'a', 'z', // "baz" ]} - 0x71, 0x0B, // bar - 0x71, 0x0C, // baz - }) - _symbol(t, r, "$11") - _symbol(t, r, "bar") - - _symbol(t, r, "bar") - _symbol(t, r, "$12") - _symbol(t, r, "$111") - - _symbol(t, r, "bar") - _symbol(t, r, "baz") - _eof(t, r) -} - -func TestReadBinaryLST(t *testing.T) { - r := readBinary([]byte{0x0F}) - _next(t, r, NullType) - - lst := r.SymbolTable() - if lst == nil { - t.Fatal("symboltable is nil") - } - - if lst.MaxID() != 111 { - t.Errorf("expected maxid=111, got %v", lst.MaxID()) - } - - if _, ok := lst.FindByID(109); ok { - t.Error("found a symbol for $109") - } - - sym, ok := lst.FindByID(111) - if !ok { - t.Fatal("no symbol defined for $111") - } - if sym != "bar" { - t.Errorf("expected $111=bar, got %v", sym) - } - - id, ok := lst.FindByName("foo") - if !ok { - t.Fatal("no id defined for foo") - } - if id != 110 { - t.Errorf("expected foo=$110, got $%v", id) - } - - if _, ok := lst.FindByID(112); ok { - t.Error("found a symbol for $112") - } - - if _, ok := lst.FindByName("bogus"); ok { - t.Error("found a symbol for bogus") - } -} - -func TestReadBinaryStructs(t *testing.T) { - r := readBinary([]byte{ - 0xDF, // null.struct - 0xD0, // {} - 0xEA, 0x81, 0xEE, 0xD7, // foo::{ - 0x84, 0xE3, 0x81, 0xEF, 0xD0, // name:bar::{}, - 0x88, 0x20, // max_id:0 - // } - }) - - _null(t, r, StructType) - _struct(t, r, func(t *testing.T, r Reader) { - _eof(t, r) - }) - _structAF(t, r, "", []string{"foo"}, func(t *testing.T, r Reader) { - _structAF(t, r, "name", []string{"bar"}, func(t *testing.T, r Reader) { - _eof(t, r) - }) - _intAF(t, r, "max_id", nil, 0) - }) - _eof(t, r) -} - -func TestReadBinarySexps(t *testing.T) { - r := readBinary([]byte{ - 0xCF, - 0xC3, 0xC1, 0xC0, 0xC0, - }) - - _null(t, r, SexpType) - _sexp(t, r, func(t *testing.T, r Reader) { - _sexp(t, r, func(t *testing.T, r Reader) { - _sexp(t, r, func(t *testing.T, r Reader) { - _eof(t, r) - }) - }) - _sexp(t, r, func(t *testing.T, r Reader) { - _eof(t, r) - }) - _eof(t, r) - }) - _eof(t, r) -} - -func TestReadBinaryLists(t *testing.T) { - r := readBinary([]byte{ - 0xBF, - 0xB3, 0xB1, 0xB0, 0xB0, - }) - - _null(t, r, ListType) - _list(t, r, func(t *testing.T, r Reader) { - _list(t, r, func(t *testing.T, r Reader) { - _list(t, r, func(t *testing.T, r Reader) { - _eof(t, r) - }) - }) - _list(t, r, func(t *testing.T, r Reader) { - _eof(t, r) - }) - _eof(t, r) - }) - _eof(t, r) -} - -func TestReadBinaryBlobs(t *testing.T) { - r := readBinary([]byte{ - 0xAF, - 0xA0, - 0xA1, 'a', - 0xAE, 0x96, - 'h', 'e', 'l', 'l', 'o', ' ', 'w', 'o', 'r', 'l', 'd', ' ', 'b', 'u', 't', - ' ', 'l', 'o', 'n', 'g', 'e', 'r', - }) - - _null(t, r, BlobType) - _blob(t, r, []byte("")) - _blob(t, r, []byte("a")) - _blob(t, r, []byte("hello world but longer")) - _eof(t, r) -} - -func TestReadBinaryClobs(t *testing.T) { - r := readBinary([]byte{ - 0x9F, - 0x90, // {{}} - 0x91, 'a', // {{a}} - 0x9E, 0x96, - 'h', 'e', 'l', 'l', 'o', ' ', 'w', 'o', 'r', 'l', 'd', ' ', 'b', 'u', 't', - ' ', 'l', 'o', 'n', 'g', 'e', 'r', - }) - - _null(t, r, ClobType) - _clob(t, r, []byte("")) - _clob(t, r, []byte("a")) - _clob(t, r, []byte("hello world but longer")) - _eof(t, r) -} - -func TestReadBinaryStrings(t *testing.T) { - r := readBinary([]byte{ - 0x8F, - 0x80, // "" - 0x81, 'a', // "a" - 0x8E, 0x96, - 'h', 'e', 'l', 'l', 'o', ' ', 'w', 'o', 'r', 'l', 'd', ' ', 'b', 'u', 't', - ' ', 'l', 'o', 'n', 'g', 'e', 'r', - }) - - _null(t, r, StringType) - _string(t, r, "") - _string(t, r, "a") - _string(t, r, "hello world but longer") - _eof(t, r) -} - -func TestReadBinarySymbols(t *testing.T) { - r := readBinary([]byte{ - 0x7F, - 0x70, // $0 - 0x71, 0x01, // $ion - 0x71, 0x0A, // $10 - 0x71, 0x6E, // foo - 0xE4, 0x81, 0xEE, 0x71, 0x6F, // foo::bar - 0x78, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, // ${maxint64} - }) - - _null(t, r, SymbolType) - _symbol(t, r, "$0") - _symbol(t, r, "$ion") - _symbol(t, r, "$10") - _symbol(t, r, "foo") - _symbolAF(t, r, "", []string{"foo"}, "bar") - _symbol(t, r, fmt.Sprintf("$%v", uint64(math.MaxUint64))) - _eof(t, r) -} - -func TestReadBinaryTimestamps(t *testing.T) { - r := readBinary([]byte{ - 0x6F, - 0x62, 0x80, 0x81, // 0001T - 0x63, 0x80, 0x81, 0x81, // 0001-01T - 0x64, 0x80, 0x81, 0x81, 0x81, // 0001-01-01T - 0x66, 0x80, 0x81, 0x81, 0x81, 0x80, 0x80, // 0001-01-01T00:00Z - 0x67, 0x80, 0x81, 0x81, 0x81, 0x80, 0x80, 0x80, // 0001-01-01T00:00:00Z - 0x6E, 0x8E, // 0x0E-bit timestamp - 0x04, 0xD8, // offset: +600 minutes (+10:00) - 0x0F, 0xE3, // year: 2019 - 0x88, // month: 8 - 0x84, // day: 4 - 0x88, // hour: 8 utc (18 local) - 0x8F, // minute: 15 - 0xAB, // second: 43 - 0xC9, // exp: -9 - 0x33, 0x77, 0xDF, 0x70, // nsec: 863494000 - }) - - _null(t, r, TimestampType) - - for i := 0; i < 5; i++ { - _timestamp(t, r, time.Time{}) - } - - nowish, _ := time.Parse(time.RFC3339Nano, "2019-08-04T18:15:43.863494+10:00") - _timestamp(t, r, nowish) - _eof(t, r) -} - -func TestReadBinaryDecimals(t *testing.T) { - r := readBinary([]byte{ - 0x50, // 0. - 0x5F, // null.decimal - 0x51, 0xC3, // 0.000, aka 0 x 10^-3 - 0x53, 0xC3, 0x03, 0xE8, // 1.000, aka 1000 x 10^-3 - 0x53, 0xC3, 0x83, 0xE8, // -1.000, aka -1000 x 10^-3 - 0x53, 0x00, 0xE4, 0x01, // 1d100, aka 1 * 10^100 - 0x53, 0x00, 0xE4, 0x81, // -1d100, aka -1 * 10^100 - }) - - _decimal(t, r, MustParseDecimal("0.")) - _null(t, r, DecimalType) - _decimal(t, r, MustParseDecimal("0.000")) - _decimal(t, r, MustParseDecimal("1.000")) - _decimal(t, r, MustParseDecimal("-1.000")) - _decimal(t, r, MustParseDecimal("1d100")) - _decimal(t, r, MustParseDecimal("-1d100")) - _eof(t, r) -} - -func TestReadBinaryFloats(t *testing.T) { - r := readBinary([]byte{ - 0x40, // 0 - 0x4F, // null.float - 0x44, 0x7F, 0x7F, 0xFF, 0xFF, // MaxFloat32 - 0x44, 0xFF, 0x7F, 0xFF, 0xFF, // -MaxFloat32 - 0x48, 0x7F, 0xEF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, // MaxFloat64 - 0x48, 0xFF, 0xEF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, // -MaxFloat64 - 0x48, 0x7F, 0xF0, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // +inf - 0x48, 0xFF, 0xF0, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // -inf - 0x48, 0x7F, 0xF8, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, // NaN - }) - - _float(t, r, 0) - _null(t, r, FloatType) - _float(t, r, math.MaxFloat32) - _float(t, r, -math.MaxFloat32) - _float(t, r, math.MaxFloat64) - _float(t, r, -math.MaxFloat64) - _float(t, r, math.Inf(1)) - _float(t, r, math.Inf(-1)) - _float(t, r, math.NaN()) - _eof(t, r) -} - -func TestReadBinaryInts(t *testing.T) { - r := readBinary([]byte{ - 0x20, // 0 - 0x2F, // null.int - 0x21, 0x01, // 1 - 0x31, 0x01, // -1 - 0x28, 0x7F, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, // 0x7FFFFFFFFFFFFFFF - 0x38, 0x7F, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, // -0x7FFFFFFFFFFFFFFF - 0x28, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // 0x8000000000000000 - 0x38, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // 0x8000000000000000 - }) - - _int(t, r, 0) - _null(t, r, IntType) - _int(t, r, 1) - _int(t, r, -1) - _int64(t, r, math.MaxInt64) - _int64(t, r, -math.MaxInt64) - - _uint(t, r, math.MaxInt64+1) - - i := new(big.Int).SetUint64(math.MaxInt64 + 1) - _bigInt(t, r, new(big.Int).Neg(i)) - - _eof(t, r) -} - -func TestReadBinaryBools(t *testing.T) { - r := readBinary([]byte{ - 0x10, // false - 0x11, // true - 0x1F, // null.bool - }) - - _bool(t, r, false) - _bool(t, r, true) - _null(t, r, BoolType) - _eof(t, r) -} - -func TestReadBinaryNulls(t *testing.T) { - r := readBinary([]byte{ - 0x00, // 1-byte NOP - 0x0F, // null - 0x01, 0xFF, // 2-byte NOP - 0xE3, 0x81, 0x81, 0x0F, // $ion::null - 0x0E, 0x8F, // 16-byte NOP - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0xE4, 0x82, 0xEE, 0xEF, 0x0F, // foo::bar::null - }) - - _null(t, r, NullType) - _nullAF(t, r, NullType, "", []string{"$ion"}) - _nullAF(t, r, NullType, "", []string{"foo", "bar"}) - _eof(t, r) -} - -func TestReadEmptyBinary(t *testing.T) { - r := NewReaderBytes([]byte{0xE0, 0x01, 0x00, 0xEA}) - _eof(t, r) - _eof(t, r) -} - -func readBinary(ion []byte) Reader { - prefix := []byte{ - 0xE0, 0x01, 0x00, 0xEA, // $ion_1_0 - 0xEE, 0x9F, 0x81, 0x83, 0xDE, 0x9B, // $ion_symbol_table::{ - 0x86, 0xBE, 0x8E, // imports:[ - 0xDD, // { - 0x84, 0x85, 'b', 'o', 'g', 'u', 's', // name: "bogus" - 0x85, 0x21, 0x2A, // version: 42 - 0x88, 0x21, 0x64, // max_id: 100 - // }] - 0x87, 0xB8, // symbols: [ - 0x83, 'f', 'o', 'o', // "foo" - 0x83, 'b', 'a', 'r', // "bar" - // ] - // } - } - return NewReaderBytes(append(prefix, ion...)) -} diff --git a/binarywriter.go b/binarywriter.go deleted file mode 100644 index 0115faf6..00000000 --- a/binarywriter.go +++ /dev/null @@ -1,562 +0,0 @@ -package ion - -import ( - "encoding/binary" - "fmt" - "io" - "math" - "math/big" - "strconv" - "strings" - "time" -) - -// A binaryWriter writes binary ion. -type binaryWriter struct { - writer - bufs bufstack - - lst SymbolTable - lstb SymbolTableBuilder - - wroteLST bool -} - -// NewBinaryWriter creates a new binary writer that will construct a -// local symbol table as it is written to. -func NewBinaryWriter(out io.Writer, sts ...SharedSymbolTable) Writer { - w := &binaryWriter{ - writer: writer{ - out: out, - }, - lstb: NewSymbolTableBuilder(sts...), - } - w.bufs.push(&datagram{}) - return w -} - -// NewBinaryWriterLST creates a new binary writer with a pre-built local -// symbol table. -func NewBinaryWriterLST(out io.Writer, lst SymbolTable) Writer { - return &binaryWriter{ - writer: writer{ - out: out, - }, - lst: lst, - } -} - -// WriteNull writes an untyped null. -func (w *binaryWriter) WriteNull() error { - return w.writeValue("Writer.WriteNull", []byte{0x0F}) -} - -// WriteNullType writes a typed null. -func (w *binaryWriter) WriteNullType(t Type) error { - return w.writeValue("Writer.WriteNullType", []byte{binaryNulls[t]}) -} - -// WriteBool writes a bool. -func (w *binaryWriter) WriteBool(val bool) error { - b := byte(0x10) - if val { - b = 0x11 - } - return w.writeValue("Writer.WriteBool", []byte{b}) -} - -// WriteInt writes an integer. -func (w *binaryWriter) WriteInt(val int64) error { - if val == 0 { - return w.writeValue("Writer.WriteInt", []byte{0x20}) - } - - code := byte(0x20) - mag := uint64(val) - - if val < 0 { - code = 0x30 - mag = uint64(-val) - } - - len := uintLen(mag) - buflen := len + tagLen(len) - - buf := make([]byte, 0, buflen) - buf = appendTag(buf, code, len) - buf = appendUint(buf, mag) - - return w.writeValue("Writer.WriteInt", buf) -} - -// WriteUint writes an unsigned integer. -func (w *binaryWriter) WriteUint(val uint64) error { - if val == 0 { - return w.writeValue("Writer.WriteUint", []byte{0x20}) - } - - len := uintLen(val) - buflen := len + tagLen(len) - - buf := make([]byte, 0, buflen) - buf = appendTag(buf, 0x20, len) - buf = appendUint(buf, val) - - return w.writeValue("Writer.WriteUint", buf) -} - -// WriteBigInt writes a big integer. -func (w *binaryWriter) WriteBigInt(val *big.Int) error { - if w.err != nil { - return w.err - } - if w.err = w.beginValue("Writer.WriteBigInt"); w.err != nil { - return w.err - } - - if w.err = w.writeBigInt(val); w.err != nil { - return w.err - } - - w.err = w.endValue() - return w.err -} - -// WriteBigInt writes the actual big integer value. -func (w *binaryWriter) writeBigInt(val *big.Int) error { - sign := val.Sign() - if sign == 0 { - return w.write([]byte{0x20}) - } - - code := byte(0x20) - if sign < 0 { - code = 0x30 - } - - bs := val.Bytes() - - bl := uint64(len(bs)) - if bl < 64 { - buflen := bl + tagLen(bl) - buf := make([]byte, 0, buflen) - - buf = appendTag(buf, code, bl) - buf = append(buf, bs...) - return w.write(buf) - } - - // no sense in copying, emit tag separately. - if err := w.writeTag(code, bl); err != nil { - return err - } - return w.write(bs) -} - -// WriteFloat writes a floating-point value. -func (w *binaryWriter) WriteFloat(val float64) error { - if val == 0 { - return w.writeValue("Writer.WriteFloat", []byte{0x40}) - } - - bs := make([]byte, 9) - bs[0] = 0x48 - - bits := math.Float64bits(val) - binary.BigEndian.PutUint64(bs[1:], bits) - - return w.writeValue("Writer.WriteFloat", bs) -} - -// WriteDecimal writes a decimal value. -func (w *binaryWriter) WriteDecimal(val *Decimal) error { - coef, exp := val.CoEx() - - vlen := uint64(0) - if exp != 0 { - vlen += varIntLen(int64(exp)) - } - if coef.Sign() != 0 { - vlen += bigIntLen(coef) - } - - buflen := vlen + tagLen(vlen) - buf := make([]byte, 0, buflen) - - buf = appendTag(buf, 0x50, vlen) - if exp != 0 { - buf = appendVarInt(buf, int64(exp)) - } - buf = appendBigInt(buf, coef) - - return w.writeValue("Writer.WriteDecimal", buf) -} - -// WriteTimestamp writes a timestamp value. -func (w *binaryWriter) WriteTimestamp(val time.Time) error { - _, offset := val.Zone() - offset /= 60 - utc := val.In(time.UTC) - - vlen := timeLen(offset, utc) - buflen := vlen + tagLen(vlen) - - buf := make([]byte, 0, buflen) - - buf = appendTag(buf, 0x60, vlen) - buf = appendTime(buf, offset, utc) - - return w.writeValue("Writer.WriteTimestamp", buf) -} - -// WriteSymbol writes a symbol value. -func (w *binaryWriter) WriteSymbol(val string) error { - id, err := w.resolve("Writer.WriteSymbol", val) - if err != nil { - w.err = err - return err - } - - vlen := uintLen(uint64(id)) - buflen := vlen + tagLen(vlen) - buf := make([]byte, 0, buflen) - - buf = appendTag(buf, 0x70, vlen) - buf = appendUint(buf, uint64(id)) - - return w.writeValue("Writer.WriteSymbol", buf) -} - -// WriteString writes a string. -func (w *binaryWriter) WriteString(val string) error { - if len(val) == 0 { - return w.writeValue("Writer.WriteString", []byte{0x80}) - } - - vlen := uint64(len(val)) - buflen := vlen + tagLen(vlen) - buf := make([]byte, 0, buflen) - - buf = appendTag(buf, 0x80, vlen) - buf = append(buf, val...) - - return w.writeValue("Writer.WriteString", buf) -} - -// WriteClob writes a clob. -func (w *binaryWriter) WriteClob(val []byte) error { - if w.err != nil { - return w.err - } - if w.err = w.beginValue("Writer.WriteClob"); w.err != nil { - return w.err - } - - if w.err = w.writeLob(0x90, val); w.err != nil { - return w.err - } - - w.err = w.endValue() - return w.err -} - -// WriteBlob writes a blob. -func (w *binaryWriter) WriteBlob(val []byte) error { - if w.err != nil { - return w.err - } - if w.err = w.beginValue("Writer.WriteBlob"); w.err != nil { - return w.err - } - - if w.err = w.writeLob(0xA0, val); w.err != nil { - return w.err - } - - w.err = w.endValue() - return w.err -} - -func (w *binaryWriter) writeLob(code byte, val []byte) error { - vlen := uint64(len(val)) - - if vlen < 64 { - buflen := vlen + tagLen(vlen) - buf := make([]byte, 0, buflen) - - buf = appendTag(buf, code, vlen) - buf = append(buf, val...) - - return w.write(buf) - } - - if err := w.writeTag(code, vlen); err != nil { - return err - } - return w.write(val) -} - -// BeginList begins writing a list. -func (w *binaryWriter) BeginList() error { - if w.err == nil { - w.err = w.begin("Writer.BeginList", ctxInList, 0xB0) - } - return w.err -} - -// EndList finishes writing a list. -func (w *binaryWriter) EndList() error { - if w.err == nil { - w.err = w.end("Writer.EndList", ctxInList) - } - return w.err -} - -// BeginSexp begins writing an s-expression. -func (w *binaryWriter) BeginSexp() error { - if w.err == nil { - w.err = w.begin("Writer.BeginSexp", ctxInSexp, 0xC0) - } - return w.err -} - -// EndSexp finishes writing an s-expression. -func (w *binaryWriter) EndSexp() error { - if w.err == nil { - w.err = w.end("Writer.EndSexp", ctxInSexp) - } - return w.err -} - -// BeginStruct begins writing a struct. -func (w *binaryWriter) BeginStruct() error { - if w.err == nil { - w.err = w.begin("Writer.BeginStruct", ctxInStruct, 0xD0) - } - return w.err -} - -// EndStruct finishes writing a struct. -func (w *binaryWriter) EndStruct() error { - if w.err == nil { - w.err = w.end("Writer.EndStruct", ctxInStruct) - } - return w.err -} - -// Finish finishes writing a datagram. -func (w *binaryWriter) Finish() error { - if w.err != nil { - return w.err - } - if w.ctx.peek() != ctxAtTopLevel { - return &UsageError{"Writer.Finish", "not at top level"} - } - - w.clear() - w.wroteLST = false - - seq := w.bufs.peek() - if seq != nil { - w.bufs.pop() - if w.bufs.peek() != nil { - panic("at top level but too many bufseqs") - } - - lst := w.lstb.Build() - if err := w.writeLST(lst); err != nil { - return err - } - if w.err = w.emit(seq); w.err != nil { - return w.err - } - } - - return nil -} - -// Emit emits the given node. If we're currently at the top level, that -// means actually emitting to the output stream. If not, we emit append -// to the current bufseq. -func (w *binaryWriter) emit(node bufnode) error { - s := w.bufs.peek() - if s == nil { - return node.EmitTo(w.out) - } - s.Append(node) - return nil -} - -// Write emits the given bytes as an atom. -func (w *binaryWriter) write(bs []byte) error { - return w.emit(atom(bs)) -} - -// WriteValue writes a serialized value to the output stream. -func (w *binaryWriter) writeValue(api string, val []byte) error { - if w.err != nil { - return w.err - } - if w.err = w.beginValue(api); w.err != nil { - return w.err - } - - if w.err = w.write(val); w.err != nil { - return w.err - } - - w.err = w.endValue() - return w.err -} - -// WriteTag writes out a type+length tag. Use me when you've already got the value to -// be written as a []byte and don't want to copy it. -func (w *binaryWriter) writeTag(code byte, len uint64) error { - tl := tagLen(len) - - tag := make([]byte, 0, tl) - tag = appendTag(tag, code, len) - - return w.write(tag) -} - -// WriteLST writes out a local symbol table. -func (w *binaryWriter) writeLST(lst SymbolTable) error { - if err := w.write([]byte{0xE0, 0x01, 0x00, 0xEA}); err != nil { - return err - } - return lst.WriteTo(w) -} - -// BeginValue begins the process of writing a value by writing out -// its field name and annotations. -func (w *binaryWriter) beginValue(api string) error { - // We have to record/empty these before calling writeLST, which - // will end up using/modifying them. Ugh. - name := w.fieldName - as := w.annotations - w.clear() - - // If we have a local symbol table and haven't written it out yet, do that now. - if w.lst != nil && !w.wroteLST { - w.wroteLST = true - if err := w.writeLST(w.lst); err != nil { - return err - } - } - - if w.inStruct() { - if name == "" { - return &UsageError{api, "field name not set"} - } - - id, err := w.resolve(api, name) - if err != nil { - return err - } - - buf := make([]byte, 0, 10) - buf = appendVarUint(buf, id) - if err := w.write(buf); err != nil { - return err - } - } - - if len(as) > 0 { - ids := make([]uint64, len(as)) - idlen := uint64(0) - - for i, a := range as { - id, err := w.resolve(api, a) - if err != nil { - return err - } - - ids[i] = id - idlen += varUintLen(id) - } - - buflen := idlen + varUintLen(idlen) - buf := make([]byte, 0, buflen) - - buf = appendVarUint(buf, idlen) - for _, id := range ids { - buf = appendVarUint(buf, id) - } - - // TODO: We could theoretically write the actual tag here if we know the - // length of the value ahead of time. - w.bufs.push(&container{code: 0xE0}) - if err := w.write(buf); err != nil { - return err - } - } - - return nil -} - -// EndValue ends the process of writing a value by flushing it and its annotations -// up a level, if needed. -func (w *binaryWriter) endValue() error { - seq := w.bufs.peek() - if seq != nil { - if c, ok := seq.(*container); ok && c.code == 0xE0 { - w.bufs.pop() - return w.emit(seq) - } - } - return nil -} - -// Begin begins writing a new container. -func (w *binaryWriter) begin(api string, t ctx, code byte) error { - if err := w.beginValue(api); err != nil { - return err - } - - w.ctx.push(t) - w.bufs.push(&container{code: code}) - - return nil -} - -// End ends writing a container, emitting its buffered contents up a level in the stack. -func (w *binaryWriter) end(api string, t ctx) error { - if w.ctx.peek() != t { - return &UsageError{api, "not in that kind of container"} - } - - seq := w.bufs.peek() - if seq != nil { - w.bufs.pop() - if err := w.emit(seq); err != nil { - return err - } - } - - w.clear() - w.ctx.pop() - - return w.endValue() -} - -// Resolve resolves a symbol to its ID. -func (w *binaryWriter) resolve(api, sym string) (uint64, error) { - if strings.HasPrefix(sym, "$") { - id, err := strconv.ParseUint(sym[1:], 10, 64) - if err == nil { - return id, nil - } - } - - if w.lst != nil { - id, ok := w.lst.FindByName(sym) - if !ok { - return 0, &UsageError{api, fmt.Sprintf("symbol '%v' not defined", sym)} - } - return id, nil - } - - id, _ := w.lstb.Add(sym) - return id, nil -} diff --git a/binarywriter_test.go b/binarywriter_test.go deleted file mode 100644 index c1cbaaaa..00000000 --- a/binarywriter_test.go +++ /dev/null @@ -1,387 +0,0 @@ -package ion - -import ( - "bytes" - "encoding/hex" - "fmt" - "math" - "math/big" - "strings" - "testing" - "time" -) - -func TestWriteBinaryStruct(t *testing.T) { - eval := []byte{ - 0xD0, // {} - 0xEA, 0x81, 0xEE, 0xD7, // foo::{ - 0x84, 0xE3, 0x81, 0xEF, 0x0F, // name:bar::null, - 0x88, 0x20, // max_id:0 - // } - } - testBinaryWriter(t, eval, func(w Writer) { - w.BeginStruct() - w.EndStruct() - - w.Annotation("foo") - w.BeginStruct() - { - w.FieldName("name") - w.Annotation("bar") - w.WriteNull() - - w.FieldName("max_id") - w.WriteInt(0) - } - w.EndStruct() - }) -} - -func TestWriteBinarySexp(t *testing.T) { - eval := []byte{ - 0xC0, // () - 0xE8, 0x81, 0xEE, 0xC5, // foo::( - 0xE3, 0x81, 0xEF, 0x0F, // bar::null, - 0x20, // 0 - // ) - } - testBinaryWriter(t, eval, func(w Writer) { - w.BeginSexp() - w.EndSexp() - - w.Annotation("foo") - w.BeginSexp() - { - w.Annotation("bar") - w.WriteNull() - - w.WriteInt(0) - } - w.EndSexp() - }) -} - -func TestWriteBinaryList(t *testing.T) { - eval := []byte{ - 0xB0, // [] - 0xE8, 0x81, 0xEE, 0xB5, // foo::[ - 0xE3, 0x81, 0xEF, 0x0F, // bar::null, - 0x20, // 0 - // ] - } - testBinaryWriter(t, eval, func(w Writer) { - w.BeginList() - w.EndList() - - w.Annotation("foo") - w.BeginList() - { - w.Annotation("bar") - w.WriteNull() - - w.WriteInt(0) - } - w.EndList() - }) -} - -func TestWriteBinaryBlob(t *testing.T) { - eval := []byte{ - 0xA0, - 0xAB, 'H', 'e', 'l', 'l', 'o', ' ', 'W', 'o', 'r', 'l', 'd', - } - testBinaryWriter(t, eval, func(w Writer) { - w.WriteBlob([]byte{}) - w.WriteBlob([]byte("Hello World")) - }) -} - -func TestWriteLargeBinaryBlob(t *testing.T) { - eval := make([]byte, 131) - eval[0] = 0xAE - eval[1] = 0x01 - eval[2] = 0x80 - testBinaryWriter(t, eval, func(w Writer) { - w.WriteBlob(make([]byte, 128)) - }) -} - -func TestWriteBinaryClob(t *testing.T) { - eval := []byte{ - 0x90, - 0x9B, 'H', 'e', 'l', 'l', 'o', ' ', 'W', 'o', 'r', 'l', 'd', - } - testBinaryWriter(t, eval, func(w Writer) { - w.WriteClob([]byte{}) - w.WriteClob([]byte("Hello World")) - }) -} - -func TestWriteBinaryString(t *testing.T) { - eval := []byte{ - 0x80, // "" - 0x8B, 'H', 'e', 'l', 'l', 'o', ' ', 'W', 'o', 'r', 'l', 'd', - 0x8E, 0x9B, 'H', 'e', 'l', 'l', 'o', ' ', 'W', 'o', 'r', 'l', 'd', - ' ', 'B', 'u', 't', ' ', 'E', 'v', 'e', 'n', ' ', 'L', 'o', 'n', 'g', 'e', 'r', - 0x84, 0xE0, 0x01, 0x00, 0xEA, - } - testBinaryWriter(t, eval, func(w Writer) { - w.WriteString("") - w.WriteString("Hello World") - w.WriteString("Hello World But Even Longer") - w.WriteString("\xE0\x01\x00\xEA") - }) -} - -func TestWriteBinarySymbol(t *testing.T) { - eval := []byte{ - 0x71, 0x01, // $ion - 0x71, 0x04, // name - 0x71, 0x05, // version - 0x71, 0x09, // $ion_shared_symbol_table - 0x74, 0xFF, 0xFF, 0xFF, 0xFF, // $4294967295 - } - testBinaryWriter(t, eval, func(w Writer) { - w.WriteSymbol("$ion") - w.WriteSymbol("name") - w.WriteSymbol("version") - w.WriteSymbol("$ion_shared_symbol_table") - w.WriteSymbol("$4294967295") - }) -} - -func TestWriteBinaryTimestamp(t *testing.T) { - eval := []byte{ - 0x67, 0x80, 0x81, 0x81, 0x81, 0x80, 0x80, 0x80, // 0001-01-01T00:00:00Z - 0x6E, 0x8E, // 0x0E-bit timestamp - 0x04, 0xD8, // offset: +600 minutes (+10:00) - 0x0F, 0xE3, // year: 2019 - 0x88, // month: 8 - 0x84, // day: 4 - 0x88, // hour: 8 utc (18 local) - 0x8F, // minute: 15 - 0xAB, // second: 43 - 0xC9, // exp: -9 - 0x33, 0x77, 0xDF, 0x70, // nsec: 863494000 - } - - nowish, _ := time.Parse(time.RFC3339Nano, "2019-08-04T18:15:43.863494+10:00") - - testBinaryWriter(t, eval, func(w Writer) { - w.WriteTimestamp(time.Time{}) - w.WriteTimestamp(nowish) - }) -} - -func TestWriteBinaryDecimal(t *testing.T) { - eval := []byte{ - 0x50, // 0. - 0x51, 0xC3, // 0.000, aka 0 x 10^-3 - 0x53, 0xC3, 0x03, 0xE8, // 1.000, aka 1000 x 10^-3 - 0x53, 0xC3, 0x83, 0xE8, // -1.000, aka -1000 x 10^-3 - 0x53, 0x00, 0xE4, 0x01, // 1d100, aka 1 * 10^100 - 0x53, 0x00, 0xE4, 0x81, // -1d100, aka -1 * 10^100 - } - - testBinaryWriter(t, eval, func(w Writer) { - w.WriteDecimal(MustParseDecimal("0.")) - w.WriteDecimal(MustParseDecimal("0.000")) - w.WriteDecimal(MustParseDecimal("1.000")) - w.WriteDecimal(MustParseDecimal("-1.000")) - w.WriteDecimal(MustParseDecimal("1d100")) - w.WriteDecimal(MustParseDecimal("-1d100")) - }) -} - -func TestWriteBinaryFloats(t *testing.T) { - eval := []byte{ - 0x40, // 0 - 0x48, 0x7F, 0xEF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, // MaxFloat64 - 0x48, 0xFF, 0xEF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, // -MaxFloat64 - 0x48, 0x7F, 0xF0, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // +inf - 0x48, 0xFF, 0xF0, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // -inf - 0x48, 0x7F, 0xF8, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, // NaN - } - testBinaryWriter(t, eval, func(w Writer) { - w.WriteFloat(0) - w.WriteFloat(math.MaxFloat64) - w.WriteFloat(-math.MaxFloat64) - w.WriteFloat(math.Inf(1)) - w.WriteFloat(math.Inf(-1)) - w.WriteFloat(math.NaN()) - }) -} - -func TestWriteBinaryBigInts(t *testing.T) { - eval := []byte{ - 0x20, // 0 - 0x21, 0xFF, // 0xFF - 0x31, 0xFF, // -0xFF - 0x2E, 0x90, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, // a really big integer - } - - testBinaryWriter(t, eval, func(w Writer) { - w.WriteBigInt(big.NewInt(0)) - w.WriteBigInt(big.NewInt(0xFF)) - w.WriteBigInt(big.NewInt(-0xFF)) - w.WriteBigInt(new(big.Int).SetBytes([]byte{1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0})) - }) -} - -func TestWriteBinaryReallyBigInts(t *testing.T) { - eval := []byte{ - 0x2E, 0x01, 0x80, // 128-byte positive integer - 0x80, // high bit set - } - eval = append(eval, make([]byte, 127)...) - testBinaryWriter(t, eval, func(w Writer) { - i := new(big.Int) - i = i.SetBit(i, 1023, 1) - w.WriteBigInt(i) - }) -} - -func TestWriteBinaryInts(t *testing.T) { - eval := []byte{ - 0x20, // 0 - 0x21, 0xFF, // 0xFF - 0x31, 0xFF, // -0xFF - 0x22, 0xFF, 0xFF, // 0xFFFF - 0x33, 0xFF, 0xFF, 0xFF, // -0xFFFFFF - 0x28, 0x7F, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, // math.MaxInt64 - } - - testBinaryWriter(t, eval, func(w Writer) { - w.WriteInt(0) - w.WriteInt(0xFF) - w.WriteInt(-0xFF) - w.WriteInt(0xFFFF) - w.WriteInt(-0xFFFFFF) - w.WriteInt(math.MaxInt64) - }) -} - -func TestWriteBinaryBoolAnnotated(t *testing.T) { - eval := []byte{ - 0xE4, // 4-byte annotated value - 0x82, // 2 bytes of annotations - 0x84, // $4 (name) - 0x85, // $5 (version) - 0x10, // false - } - - testBinaryWriter(t, eval, func(w Writer) { - w.Annotations("name", "version") - w.WriteBool(false) - }) -} - -func TestWriteBinaryBools(t *testing.T) { - eval := []byte{ - 0x10, // false - 0x11, // true - } - - testBinaryWriter(t, eval, func(w Writer) { - w.WriteBool(false) - w.WriteBool(true) - }) -} - -func TestWriteBinaryNulls(t *testing.T) { - eval := []byte{ - 0x0F, - 0x1F, - 0x2F, - // 0x3F, // negative integer, not actually valid - 0x4F, - 0x5F, - 0x6F, - 0x7F, - 0x8F, - 0x9F, - 0xAF, - 0xBF, - 0xCF, - 0xDF, - } - - testBinaryWriter(t, eval, func(w Writer) { - w.WriteNull() - w.WriteNullType(BoolType) - w.WriteNullType(IntType) - w.WriteNullType(FloatType) - w.WriteNullType(DecimalType) - w.WriteNullType(TimestampType) - w.WriteNullType(SymbolType) - w.WriteNullType(StringType) - w.WriteNullType(ClobType) - w.WriteNullType(BlobType) - w.WriteNullType(ListType) - w.WriteNullType(SexpType) - w.WriteNullType(StructType) - }) -} - -func testBinaryWriter(t *testing.T, eval []byte, f func(w Writer)) { - val := writeBinary(t, f) - - prefix := []byte{ - 0xE0, 0x01, 0x00, 0xEA, // $ion_1_0 - 0xEE, 0x9F, 0x81, 0x83, 0xDE, 0x9B, // $ion_symbol_table::{ - 0x86, 0xBE, 0x8E, // imports:[ - 0xDD, // { - 0x84, 0x85, 'b', 'o', 'g', 'u', 's', // name: "bogus" - 0x85, 0x21, 0x2A, // version: 42 - 0x88, 0x21, 0x64, // max_id: 100 - // }] - 0x87, 0xB8, // symbols: [ - 0x83, 'f', 'o', 'o', // "foo" - 0x83, 'b', 'a', 'r', // "bar" - // ] - // } - } - eval = append(prefix, eval...) - - if !bytes.Equal(val, eval) { - t.Errorf("expected %v, got %v", fmtbytes(eval), fmtbytes(val)) - } -} - -func fmtbytes(bs []byte) string { - buf := strings.Builder{} - buf.WriteByte('[') - for i, b := range bs { - if i > 0 { - buf.WriteByte(' ') - } - buf.WriteString(hex.EncodeToString([]byte{b})) - } - buf.WriteByte(']') - return buf.String() -} - -func writeBinary(t *testing.T, f func(w Writer)) []byte { - bogusSyms := []string{} - for i := 0; i < 100; i++ { - bogusSyms = append(bogusSyms, fmt.Sprintf("bogus_sym_%v", i)) - } - - bogus := []SharedSymbolTable{ - NewSharedSymbolTable("bogus", 42, bogusSyms), - } - - buf := bytes.Buffer{} - w := NewBinaryWriterLST(&buf, NewLocalSymbolTable(bogus, []string{ - "foo", - "bar", - })) - - f(w) - - if err := w.Finish(); err != nil { - t.Fatal(err) - } - - return buf.Bytes() -} diff --git a/bits.go b/bits.go deleted file mode 100644 index f3772702..00000000 --- a/bits.go +++ /dev/null @@ -1,285 +0,0 @@ -package ion - -import ( - "math/big" - "time" -) - -// uintLen pre-calculates the length, in bytes, of the given uint value. -func uintLen(v uint64) uint64 { - len := uint64(1) - v >>= 8 - - for v > 0 { - len++ - v >>= 8 - } - - return len -} - -// appendUint appends a uint value to the given slice. The reader is -// expected to know how many bytes the value takes up. -func appendUint(b []byte, v uint64) []byte { - var buf [8]byte - - i := 7 - buf[i] = byte(v & 0xFF) - v >>= 8 - - for v > 0 { - i-- - buf[i] = byte(v & 0xFF) - v >>= 8 - } - - return append(b, buf[i:]...) -} - -// intLen pre-calculates the length, in bytes, of the given int value. -func intLen(n int64) uint64 { - if n == 0 { - return 0 - } - - mag := uint64(n) - if n < 0 { - mag = uint64(-n) - } - - len := uintLen(mag) - - // If the high bit is a one, we need an extra byte to store the sign bit. - hb := mag >> ((len - 1) * 8) - if hb&0x80 != 0 { - len++ - } - - return len -} - -// appendInt appends a (signed) int to the given slice. The reader is -// expected to know how many bytes the value takes up. -func appendInt(b []byte, n int64) []byte { - if n == 0 { - return b - } - - neg := false - mag := uint64(n) - - if n < 0 { - neg = true - mag = uint64(-n) - } - - var buf [8]byte - bits := buf[:0] - bits = appendUint(bits, mag) - - if bits[0]&0x80 == 0 { - // We've got space we can use for the sign bit. - if neg { - bits[0] ^= 0x80 - } - } else { - // We need to add more space. - bit := byte(0) - if neg { - bit = 0x80 - } - b = append(b, bit) - } - - return append(b, bits...) -} - -// bigIntLen pre-calculates the length, in bytes, of the given big.Int value. -func bigIntLen(v *big.Int) uint64 { - if v.Sign() == 0 { - return 0 - } - - bitl := v.BitLen() - bytel := bitl / 8 - - // Either bitl is evenly divisibly by 8, in which case we need another - // byte for the sign bit, or its not in which case we need to round up - // (but will then have room for the sign bit). - return uint64(bytel) + 1 -} - -// appendBigInt appends a (signed) big.Int to the given slice. The reader is -// expected to know how many bytes the value takes up. -func appendBigInt(b []byte, v *big.Int) []byte { - sign := v.Sign() - if sign == 0 { - return b - } - - bits := v.Bytes() - - if bits[0]&0x80 == 0 { - // We've got space we can use for the sign bit. - if sign < 0 { - bits[0] ^= 0x80 - } - } else { - // We need to add more space. - bit := byte(0) - if sign < 0 { - bit = 0x80 - } - b = append(b, bit) - } - - return append(b, bits...) -} - -// varUintLen pre-calculates the length, in bytes, of the given varUint value. -func varUintLen(v uint64) uint64 { - len := uint64(1) - v >>= 7 - - for v > 0 { - len++ - v >>= 7 - } - - return len -} - -// appendVarUint appends a variable-length-encoded uint to the given slice. -// Each byte stores seven bits of value; the high bit is a flag marking the -// last byte of the value. -func appendVarUint(b []byte, v uint64) []byte { - var buf [10]byte - - i := 9 - buf[i] = 0x80 | byte(v&0x7F) - v >>= 7 - - for v > 0 { - i-- - buf[i] = byte(v & 0x7F) - v >>= 7 - } - - return append(b, buf[i:]...) -} - -// varIntLen pre-calculates the length, in bytes, of the given varInt value. -func varIntLen(v int64) uint64 { - mag := uint64(v) - if v < 0 { - mag = uint64(-v) - } - - // Reserve one extra bit of the first byte for sign. - len := uint64(1) - mag >>= 6 - - for mag > 0 { - len++ - mag >>= 7 - } - - return len -} - -// appendVarInt appends a variable-length-encoded int to the given slice. -// Most bytes store seven bits of value; the high bit is a flag marking the -// last byte of the value. The first byte additionally stores a sign bit. -func appendVarInt(b []byte, v int64) []byte { - var buf [10]byte - - signbit := byte(0) - mag := uint64(v) - if v < 0 { - signbit = 0x40 - mag = uint64(-v) - } - - next := mag >> 6 - if next == 0 { - // The whole thing fits in one byte. - return append(b, 0x80|signbit|byte(mag&0x3F)) - } - - i := 9 - buf[i] = 0x80 | byte(mag&0x7F) - mag >>= 7 - next = mag >> 6 - - for next > 0 { - i-- - buf[i] = byte(mag & 0x7F) - mag >>= 7 - next = mag >> 6 - } - - i-- - buf[i] = signbit | byte(mag&0x3F) - - return append(b, buf[i:]...) -} - -// tagLen pre-calculates the length, in bytes, of a tag. -func tagLen(len uint64) uint64 { - if len < 0x0E { - return 1 - } - return 1 + varUintLen(len) -} - -// appendTag appends a code+len tag to the given slice. -func appendTag(b []byte, code byte, len uint64) []byte { - if len < 0x0E { - // Short form, with length embedded in the code byte. - return append(b, code|byte(len)) - } - - // Long form, with separate length. - b = append(b, code|0x0E) - return appendVarUint(b, len) -} - -// timeLen pre-calculates the length, in bytes, of the given time value. -func timeLen(offset int, utc time.Time) uint64 { - ret := varIntLen(int64(offset)) - - // Almost certainly two but let's be safe. - ret += varUintLen(uint64(utc.Year())) - - // Month, day, hour, minute, and second are all guaranteed to be one byte. - ret += 5 - - ns := utc.Nanosecond() - if ns > 0 { - ret++ // varIntLen(-9) - ret += intLen(int64(ns)) - } - - return ret -} - -// appendTime appends a timestamp value -func appendTime(b []byte, offset int, utc time.Time) []byte { - b = appendVarInt(b, int64(offset)) - - b = appendVarUint(b, uint64(utc.Year())) - b = appendVarUint(b, uint64(utc.Month())) - b = appendVarUint(b, uint64(utc.Day())) - - b = appendVarUint(b, uint64(utc.Hour())) - b = appendVarUint(b, uint64(utc.Minute())) - b = appendVarUint(b, uint64(utc.Second())) - - ns := utc.Nanosecond() - if ns > 0 { - b = appendVarInt(b, -9) - b = appendInt(b, int64(ns)) - } - - return b -} diff --git a/bits_test.go b/bits_test.go deleted file mode 100644 index db2c15c4..00000000 --- a/bits_test.go +++ /dev/null @@ -1,205 +0,0 @@ -package ion - -import ( - "bytes" - "fmt" - "math" - "math/big" - "testing" - "time" -) - -func TestAppendUint(t *testing.T) { - test := func(val uint64, elen uint64, ebits []byte) { - t.Run(fmt.Sprintf("%x", val), func(t *testing.T) { - len := uintLen(val) - if len != elen { - t.Errorf("expected len=%v, got len=%v", elen, len) - } - - bits := appendUint(nil, val) - if !bytes.Equal(bits, ebits) { - t.Errorf("expected %v, got %v", fmtbytes(ebits), fmtbytes(bits)) - } - }) - } - - test(0, 1, []byte{0}) - test(0xFF, 1, []byte{0xFF}) - test(0x1FF, 2, []byte{0x01, 0xFF}) - test(math.MaxUint64, 8, []byte{0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF}) -} - -func TestAppendInt(t *testing.T) { - test := func(val int64, elen uint64, ebits []byte) { - t.Run(fmt.Sprintf("%x", val), func(t *testing.T) { - len := intLen(val) - if len != elen { - t.Errorf("expected len=%v, got len=%v", elen, len) - } - - bits := appendInt(nil, val) - if !bytes.Equal(bits, ebits) { - t.Errorf("expected %v, got %v", fmtbytes(ebits), fmtbytes(bits)) - } - }) - } - - test(0, 0, []byte{}) - test(0x7F, 1, []byte{0x7F}) - test(-0x7F, 1, []byte{0xFF}) - - test(0xFF, 2, []byte{0x00, 0xFF}) - test(-0xFF, 2, []byte{0x80, 0xFF}) - - test(0x7FFF, 2, []byte{0x7F, 0xFF}) - test(-0x7FFF, 2, []byte{0xFF, 0xFF}) - - test(math.MaxInt64, 8, []byte{0x7F, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF}) - test(-math.MaxInt64, 8, []byte{0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF}) - test(math.MinInt64, 9, []byte{0x80, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}) -} - -func TestAppendBigInt(t *testing.T) { - test := func(val *big.Int, elen uint64, ebits []byte) { - t.Run(fmt.Sprintf("%x", val), func(t *testing.T) { - len := bigIntLen(val) - if len != elen { - t.Errorf("expected len=%v, got len=%v", elen, len) - } - - bits := appendBigInt(nil, val) - if !bytes.Equal(bits, ebits) { - t.Errorf("expected %v, got %v", fmtbytes(ebits), fmtbytes(bits)) - } - }) - } - - test(big.NewInt(0), 0, []byte{}) - test(big.NewInt(0x7F), 1, []byte{0x7F}) - test(big.NewInt(-0x7F), 1, []byte{0xFF}) - - test(big.NewInt(0xFF), 2, []byte{0x00, 0xFF}) - test(big.NewInt(-0xFF), 2, []byte{0x80, 0xFF}) - - test(big.NewInt(0x7FFF), 2, []byte{0x7F, 0xFF}) - test(big.NewInt(-0x7FFF), 2, []byte{0xFF, 0xFF}) -} - -func TestAppendVarUint(t *testing.T) { - test := func(val uint64, elen uint64, ebits []byte) { - t.Run(fmt.Sprintf("%x", val), func(t *testing.T) { - len := varUintLen(val) - if len != elen { - t.Errorf("expected len=%v, got len=%v", elen, len) - } - - bits := appendVarUint(nil, val) - if !bytes.Equal(bits, ebits) { - t.Errorf("expected %v, got %v", fmtbytes(ebits), fmtbytes(bits)) - } - }) - } - - test(0, 1, []byte{0x80}) - test(0x7F, 1, []byte{0xFF}) - test(0xFF, 2, []byte{0x01, 0xFF}) - test(0x1FF, 2, []byte{0x03, 0xFF}) - test(0x3FFF, 2, []byte{0x7F, 0xFF}) - test(0x7FFF, 3, []byte{0x01, 0x7F, 0xFF}) - test(0x7FFFFFFFFFFFFFFF, 9, []byte{0x7F, 0x7F, 0x7F, 0x7F, 0x7F, 0x7F, 0x7F, 0x7F, 0xFF}) - test(0xFFFFFFFFFFFFFFFF, 10, []byte{0x01, 0x7F, 0x7F, 0x7F, 0x7F, 0x7F, 0x7F, 0x7F, 0x7F, 0xFF}) -} - -func TestAppendVarInt(t *testing.T) { - test := func(val int64, elen uint64, ebits []byte) { - t.Run(fmt.Sprintf("%x", val), func(t *testing.T) { - len := varIntLen(val) - if len != elen { - t.Errorf("expected len=%v, got len=%v", elen, len) - } - - bits := appendVarInt(nil, val) - if !bytes.Equal(bits, ebits) { - t.Errorf("expected %v, got %v", fmtbytes(ebits), fmtbytes(bits)) - } - }) - } - - test(0, 1, []byte{0x80}) - - test(0x3F, 1, []byte{0xBF}) // 1011 1111 - test(-0x3F, 1, []byte{0xFF}) - - test(0x7F, 2, []byte{0x00, 0xFF}) - test(-0x7F, 2, []byte{0x40, 0xFF}) - - test(0x1FFF, 2, []byte{0x3F, 0xFF}) - test(-0x1FFF, 2, []byte{0x7F, 0xFF}) - - test(0x3FFF, 3, []byte{0x00, 0x7F, 0xFF}) - test(-0x3FFF, 3, []byte{0x40, 0x7F, 0xFF}) - - test(0x3FFFFFFFFFFFFFFF, 9, []byte{0x3F, 0x7F, 0x7F, 0x7F, 0x7F, 0x7F, 0x7F, 0x7F, 0xFF}) - test(-0x3FFFFFFFFFFFFFFF, 9, []byte{0x7F, 0x7F, 0x7F, 0x7F, 0x7F, 0x7F, 0x7F, 0x7F, 0xFF}) - - test(math.MaxInt64, 10, []byte{0x00, 0x7F, 0x7F, 0x7F, 0x7F, 0x7F, 0x7F, 0x7F, 0x7F, 0xFF}) - test(-math.MaxInt64, 10, []byte{0x40, 0x7F, 0x7F, 0x7F, 0x7F, 0x7F, 0x7F, 0x7F, 0x7F, 0xFF}) - test(math.MinInt64, 10, []byte{0x41, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x80}) -} - -func TestAppendTag(t *testing.T) { - test := func(code byte, vlen uint64, elen uint64, ebits []byte) { - t.Run(fmt.Sprintf("(%x,%v)", code, vlen), func(t *testing.T) { - len := tagLen(vlen) - if len != elen { - t.Errorf("expected len=%v, got len=%v", elen, len) - } - - bits := appendTag(nil, code, vlen) - if !bytes.Equal(bits, ebits) { - t.Errorf("expected %v, got %v", fmtbytes(ebits), fmtbytes(bits)) - } - }) - } - - test(0x20, 1, 1, []byte{0x21}) - test(0x30, 0x0D, 1, []byte{0x3D}) - test(0x40, 0x0E, 2, []byte{0x4E, 0x8E}) - test(0x50, math.MaxInt64, 10, []byte{0x5E, 0x7F, 0x7F, 0x7F, 0x7F, 0x7F, 0x7F, 0x7F, 0x7F, 0xFF}) -} - -func TestAppendTime(t *testing.T) { - test := func(val time.Time, elen uint64, ebits []byte) { - t.Run(fmt.Sprintf("%x", val), func(t *testing.T) { - _, offset := val.Zone() - offset /= 60 - utc := val.In(time.UTC) - - len := timeLen(offset, utc) - if len != elen { - t.Errorf("expected len=%v, got len=%v", elen, len) - } - - bits := appendTime(nil, offset, utc) - if !bytes.Equal(bits, ebits) { - t.Errorf("expected %v, got %v", fmtbytes(ebits), fmtbytes(bits)) - } - }) - } - - nowish, _ := time.Parse(time.RFC3339Nano, "2019-08-04T18:15:43.863494+10:00") - - test(time.Time{}, 7, []byte{0x80, 0x81, 0x81, 0x81, 0x80, 0x80, 0x80}) - test(nowish, 14, []byte{ - 0x04, 0xD8, // offset: +600 minutes (+10:00) - 0x0F, 0xE3, // year: 2019 - 0x88, // month: 8 - 0x84, // day: 4 - 0x88, // hour: 8 utc (18 local) - 0x8F, // minute: 15 - 0xAB, // second: 43 - 0xC9, // exp: -9 - 0x33, 0x77, 0xDF, 0x70, // nsec: 863494000 - }) -} diff --git a/bitstream.go b/bitstream.go deleted file mode 100644 index 512aa109..00000000 --- a/bitstream.go +++ /dev/null @@ -1,935 +0,0 @@ -package ion - -import ( - "bufio" - "bytes" - "encoding/binary" - "fmt" - "io" - "math" - "math/big" - "time" -) - -type bss uint8 - -const ( - bssBeforeValue bss = iota - bssOnValue - bssBeforeFieldID - bssOnFieldID -) - -type bitcode uint8 - -const ( - bitcodeNone bitcode = iota - bitcodeEOF - bitcodeBVM - bitcodeNull - bitcodeFalse - bitcodeTrue - bitcodeInt - bitcodeNegInt - bitcodeFloat - bitcodeDecimal - bitcodeTimestamp - bitcodeSymbol - bitcodeString - bitcodeClob - bitcodeBlob - bitcodeList - bitcodeSexp - bitcodeStruct - bitcodeFieldID - bitcodeAnnotation -) - -func (b bitcode) String() string { - switch b { - case bitcodeNone: - return "none" - case bitcodeEOF: - return "eof" - case bitcodeBVM: - return "bvm" - case bitcodeFalse: - return "false" - case bitcodeTrue: - return "true" - case bitcodeInt: - return "int" - case bitcodeNegInt: - return "negint" - case bitcodeFloat: - return "float" - case bitcodeDecimal: - return "decimal" - case bitcodeTimestamp: - return "timestamp" - case bitcodeSymbol: - return "symbol" - case bitcodeString: - return "string" - case bitcodeClob: - return "clob" - case bitcodeBlob: - return "blob" - case bitcodeList: - return "list" - case bitcodeSexp: - return "sexp" - case bitcodeStruct: - return "struct" - case bitcodeFieldID: - return "fieldid" - case bitcodeAnnotation: - return "annotation" - default: - return fmt.Sprintf("", uint8(b)) - } -} - -// A bitstream is a low-level parser for binary Ion values. -type bitstream struct { - in *bufio.Reader - pos uint64 - state bss - stack bitstack - - code bitcode - null bool - len uint64 -} - -// Init initializes this stream with the given bufio.Reader. -func (b *bitstream) Init(in *bufio.Reader) { - b.in = in -} - -// InitBytes initializes this stream with the given bytes. -func (b *bitstream) InitBytes(in []byte) { - b.in = bufio.NewReader(bytes.NewReader(in)) -} - -// Code returns the typecode of the current value. -func (b *bitstream) Code() bitcode { - return b.code -} - -// IsNull returns true if the current value is null. -func (b *bitstream) IsNull() bool { - return b.null -} - -// Pos returns the current position. -func (b *bitstream) Pos() uint64 { - return b.pos -} - -// Len returns the length of the current value. -func (b *bitstream) Len() uint64 { - return b.len -} - -// Next advances the stream to the next value. -func (b *bitstream) Next() error { - // If we have an unread value, skip over it to get to the next one. - switch b.state { - case bssOnValue, bssOnFieldID: - if err := b.SkipValue(); err != nil { - return err - } - } - - // If we're at the end of the current container, stop and make the user step out. - if !b.stack.empty() { - cur := b.stack.peek() - if b.pos == cur.end { - b.code = bitcodeEOF - return nil - } - } - - // If it's time to read a field id, do that. - if b.state == bssBeforeFieldID { - b.code = bitcodeFieldID - b.state = bssOnFieldID - return nil - } - - // Otherwise it's time to read a value. Read the tag byte. - c, err := b.read() - if err != nil { - return err - } - - // Found the end of the file. - if c == -1 { - b.code = bitcodeEOF - return nil - } - - // Parse the tag. - code, len := parseTag(c) - if code == bitcodeNone { - return &InvalidTagByteError{byte(c), b.pos - 1} - } - - b.state = bssOnValue - - if code == bitcodeAnnotation { - switch len { - case 0: - // This value is actually a BVM. It's invalid if we're not at the top level. - if !b.stack.empty() { - return &SyntaxError{"invalid BVM in a container", b.pos - 1} - } - b.code = bitcodeBVM - b.len = 3 - return nil - - case 0x0F: - // No such thing as a null annotation. - return &InvalidTagByteError{byte(c), b.pos - 1} - } - } - - // Booleans are a bit special; the 'length' stores the value. - if code == bitcodeFalse { - switch len { - case 0, 0x0F: - break - case 1: - code = bitcodeTrue - len = 0 - default: - // Other forms are invalid. - return &InvalidTagByteError{byte(c), b.pos - 1} - } - } - - if len == 0x0F { - // This value is actually a null. - b.code = code - b.null = true - return nil - } - - pos := b.pos - rem := b.remaining() - - // This value's actual len is encoded as a separate varUint. - if len == 0x0E { - var lenlen uint64 - len, lenlen, err = b.readVarUintLen(rem) - if err != nil { - return err - } - rem -= lenlen - } - - if len > rem { - msg := fmt.Sprintf("value overruns its container: %v vs %v", len, rem) - return &SyntaxError{msg, pos - 1} - } - - b.code = code - b.len = len - return nil -} - -// SkipValue skips over the current value. -func (b *bitstream) SkipValue() error { - switch b.state { - case bssBeforeFieldID, bssBeforeValue: - // No current value to skip yet. - return nil - - case bssOnFieldID: - if err := b.skipVarUint(); err != nil { - return err - } - b.state = bssBeforeValue - - case bssOnValue: - if b.len > 0 { - if err := b.skip(b.len); err != nil { - return err - } - } - b.state = b.stateAfterValue() - - default: - panic(fmt.Sprintf("invalid state %v", b.state)) - } - - b.clear() - return nil -} - -// StepIn steps in to a container. -func (b *bitstream) StepIn() { - switch b.code { - case bitcodeStruct: - b.state = bssBeforeFieldID - - case bitcodeList, bitcodeSexp: - b.state = bssBeforeValue - - default: - panic(fmt.Sprintf("StepIn called with b.code=%v", b.code)) - } - - b.stack.push(b.code, b.pos+b.len) - b.clear() -} - -// StepOut steps out of a container. -func (b *bitstream) StepOut() error { - if b.stack.empty() { - panic("StepOut called at top level") - } - - cur := b.stack.peek() - b.stack.pop() - - if cur.end < b.pos { - panic(fmt.Sprintf("end (%v) greater than b.pos (%v)", cur.end, b.pos)) - } - diff := cur.end - b.pos - - // Skip over anything left in the container we're stepping out of. - if diff > 0 { - if err := b.skip(diff); err != nil { - return err - } - } - - b.state = b.stateAfterValue() - b.clear() - - return nil -} - -// ReadBVM reads a binary version marker, returning its major and minor version. -func (b *bitstream) ReadBVM() (byte, byte, error) { - if b.code != bitcodeBVM { - panic("not a BVM") - } - - major, err := b.read1() - if err != nil { - return 0, 0, err - } - - minor, err := b.read1() - if err != nil { - return 0, 0, err - } - - end, err := b.read1() - if err != nil { - return 0, 0, err - } - - if end != 0xEA { - msg := fmt.Sprintf("invalid BVM: 0xE0 0x%02X 0x%02X 0x%02X", major, minor, end) - return 0, 0, &SyntaxError{msg, b.pos - 4} - } - - b.state = bssBeforeValue - b.clear() - - return byte(major), byte(minor), nil -} - -// ReadFieldID reads a field ID. -func (b *bitstream) ReadFieldID() (uint64, error) { - if b.code != bitcodeFieldID { - panic("not a field ID") - } - - id, err := b.readVarUint() - if err != nil { - return 0, err - } - - b.state = bssBeforeValue - b.code = bitcodeNone - - return id, nil -} - -// ReadAnnotationIDs reads a set of annotation IDs. -func (b *bitstream) ReadAnnotationIDs() ([]uint64, error) { - if b.code != bitcodeAnnotation { - panic("not an annotation") - } - - alen, lenlen, err := b.readVarUintLen(b.len) - if err != nil { - return nil, err - } - - if b.len-lenlen <= alen { - // The size of the annotations is larger than the remaining free space inside the - // annotation container. - return nil, &SyntaxError{"malformed annotation", b.pos - lenlen} - } - - as := []uint64{} - for alen > 0 { - id, idlen, err := b.readVarUintLen(alen) - if err != nil { - return nil, err - } - - as = append(as, id) - alen -= idlen - } - - b.state = bssBeforeValue - b.clear() - - return as, nil -} - -// ReadInt reads an integer value. -func (b *bitstream) ReadInt() (interface{}, error) { - if b.code != bitcodeInt && b.code != bitcodeNegInt { - panic("not an integer") - } - - bs, err := b.readN(b.len) - if err != nil { - return "", err - } - - var ret interface{} - switch { - case b.len == 0: - // Special case for zero. - ret = int64(0) - - case b.len < 8, (b.len == 8 && bs[0]&0x80 == 0): - // It'll fit in an int64. - i := int64(0) - for _, b := range bs { - i <<= 8 - i ^= int64(b) - } - if b.code == bitcodeNegInt { - i = -i - } - ret = i - - default: - // Need to go big.Int. - i := new(big.Int).SetBytes(bs) - if b.code == bitcodeNegInt { - i = i.Neg(i) - } - ret = i - } - - b.state = b.stateAfterValue() - b.clear() - - return ret, nil -} - -// ReadFloat reads a float value. -func (b *bitstream) ReadFloat() (float64, error) { - if b.code != bitcodeFloat { - panic("not a float") - } - - bs, err := b.readN(b.len) - if err != nil { - return 0, err - } - - var ret float64 - switch len(bs) { - case 0: - ret = 0 - - case 4: - ui := binary.BigEndian.Uint32(bs) - ret = float64(math.Float32frombits(ui)) - - case 8: - ui := binary.BigEndian.Uint64(bs) - ret = math.Float64frombits(ui) - - default: - return 0, &SyntaxError{"invalid float size", b.pos - b.len} - } - - b.state = b.stateAfterValue() - b.clear() - - return ret, nil -} - -// ReadDecimal reads a decimal value. -func (b *bitstream) ReadDecimal() (*Decimal, error) { - if b.code != bitcodeDecimal { - panic("not a decimal") - } - - d, err := b.readDecimal(b.len) - if err != nil { - return nil, err - } - - b.state = b.stateAfterValue() - b.clear() - - return d, nil -} - -// ReadTimestamp reads a timestamp value. -func (b *bitstream) ReadTimestamp() (time.Time, error) { - if b.code != bitcodeTimestamp { - panic("not a timestamp") - } - - len := b.len - - offset, olen, err := b.readVarIntLen(len) - if err != nil { - return time.Time{}, err - } - len -= olen - - ts := []int{1, 1, 1, 0, 0, 0} - for i := 0; len > 0 && i < 6; i++ { - val, vlen, err := b.readVarUintLen(len) - if err != nil { - return time.Time{}, err - } - len -= vlen - ts[i] = int(val) - } - - nsecs, err := b.readNsecs(len) - if err != nil { - return time.Time{}, err - } - - b.state = b.stateAfterValue() - b.clear() - - utc := time.Date(ts[0], time.Month(ts[1]), ts[2], ts[3], ts[4], ts[5], int(nsecs), time.UTC) - return utc.In(time.FixedZone("fixed", int(offset)*60)), nil -} - -// ReadNsecs reads the fraction part of a timestamp and truncates it to nanoseconds. -func (b *bitstream) readNsecs(len uint64) (int, error) { - d, err := b.readDecimal(len) - if err != nil { - return 0, err - } - - nsec, err := d.ShiftL(9).Trunc() - if err != nil || nsec < 0 || nsec > 999999999 { - msg := fmt.Sprintf("invalid timestamp fraction: %v", d) - return 0, &SyntaxError{msg, b.pos} - } - - return int(nsec), nil -} - -// ReadDecimal reads a decimal value of the given length: an exponent encoded as a -// varInt, followed by an integer coefficient taking up the remaining bytes. -func (b *bitstream) readDecimal(len uint64) (*Decimal, error) { - exp := int64(0) - coef := new(big.Int) - - if len > 0 { - val, vlen, err := b.readVarIntLen(len) - if err != nil { - return nil, err - } - - if val > math.MaxInt32 || val < math.MinInt32 { - msg := fmt.Sprintf("decimal exponent out of range: %v", val) - return nil, &SyntaxError{msg, b.pos - vlen} - } - - exp = val - len -= vlen - } - - if len > 0 { - if err := b.readBigInt(len, coef); err != nil { - return nil, err - } - } - - return NewDecimal(coef, int32(exp)), nil -} - -// ReadSymbolID reads a symbol value. -func (b *bitstream) ReadSymbolID() (uint64, error) { - if b.code != bitcodeSymbol { - panic("not a symbol") - } - - if b.len > 8 { - return 0, &SyntaxError{"symbol id too large", b.pos} - } - - bs, err := b.readN(b.len) - if err != nil { - return 0, err - } - - b.state = b.stateAfterValue() - b.clear() - - ret := uint64(0) - for _, b := range bs { - ret <<= 8 - ret ^= uint64(b) - } - return ret, nil -} - -// ReadString reads a string value. -func (b *bitstream) ReadString() (string, error) { - if b.code != bitcodeString { - panic("not a string") - } - - bs, err := b.readN(b.len) - if err != nil { - return "", err - } - - b.state = b.stateAfterValue() - b.clear() - - return string(bs), nil -} - -// ReadBytes reads a blob or clob value. -func (b *bitstream) ReadBytes() ([]byte, error) { - if b.code != bitcodeClob && b.code != bitcodeBlob { - panic("not a lob") - } - - bs, err := b.readN(b.len) - if err != nil { - return nil, err - } - - b.state = b.stateAfterValue() - b.clear() - - return bs, nil -} - -// Clear clears the current code and len. -func (b *bitstream) clear() { - b.code = bitcodeNone - b.null = false - b.len = 0 -} - -// ReadBigInt reads a fixed-length integer of the given length and stores -// the value in the given big.Int. -func (b *bitstream) readBigInt(len uint64, ret *big.Int) error { - bs, err := b.readN(len) - if err != nil { - return err - } - - neg := (bs[0]&0x80 != 0) - bs[0] &= 0x7F - if bs[0] == 0 { - bs = bs[1:] - } - - ret.SetBytes(bs) - if neg { - ret.Neg(ret) - } - - return nil -} - -// ReadVarUint reads a variable-length-encoded uint. -func (b *bitstream) readVarUint() (uint64, error) { - val, _, err := b.readVarUintLen(b.remaining()) - return val, err -} - -// ReadVarUintLen reads a variable-length-encoded uint of at most max bytes, -// returning the value and its actual length in bytes. -func (b *bitstream) readVarUintLen(max uint64) (uint64, uint64, error) { - if max > 10 { - max = 10 - } - - val := uint64(0) - len := uint64(0) - - for { - if len >= max { - return 0, 0, &SyntaxError{"varuint too large", b.pos} - } - - c, err := b.read1() - if err != nil { - return 0, 0, err - } - - val <<= 7 - val ^= uint64(c & 0x7F) - len++ - - if c&0x80 != 0 { - return val, len, nil - } - } -} - -// SkipVarUint skips over a variable-length-encoded uint. -func (b *bitstream) skipVarUint() error { - _, err := b.skipVarUintLen(b.remaining()) - return err -} - -// SkipVarUintLen skips over a variable-length-encoded uint of at most max bytes. -func (b *bitstream) skipVarUintLen(max uint64) (uint64, error) { - if max > 10 { - max = 10 - } - - len := uint64(0) - for { - if len >= max { - return 0, &SyntaxError{"varuint too large", b.pos - len} - } - - c, err := b.read1() - if err != nil { - return 0, err - } - - len++ - - if c&0x80 != 0 { - return len, nil - } - } -} - -// Remaining returns the number of bytes remaining in the current container. -func (b *bitstream) remaining() uint64 { - if b.stack.empty() { - return math.MaxUint64 - } - - end := b.stack.peek().end - if b.pos > end { - panic(fmt.Sprintf("pos (%v) > end (%v)", b.pos, end)) - } - - return end - b.pos -} - -// ReadVarIntLen reads a variable-length-encoded int of at most max bytes, -// returning the value and its actual length in bytes -func (b *bitstream) readVarIntLen(max uint64) (int64, uint64, error) { - if max == 0 { - return 0, 0, &SyntaxError{"varint too large", b.pos} - } - if max > 10 { - max = 10 - } - - // Read the first byte, which contains the sign bit. - c, err := b.read1() - if err != nil { - return 0, 0, err - } - - sign := int64(1) - if c&0x40 != 0 { - sign = -1 - } - - val := int64(c & 0x3F) - len := uint64(1) - - // Check if that was the last (only) byte. - if c&0x80 != 0 { - return val * sign, len, nil - } - - for { - if len >= max { - return 0, 0, &SyntaxError{"varint too large", b.pos - len} - } - - c, err := b.read1() - if err != nil { - return 0, 0, err - } - - val <<= 7 - val ^= int64(c & 0x7F) - len++ - - if c&0x80 != 0 { - return val * sign, len, nil - } - } -} - -// StateAfterValue returns the state this stream is in after reading a value. -func (b *bitstream) stateAfterValue() bss { - if b.stack.peek().code == bitcodeStruct { - return bssBeforeFieldID - } - return bssBeforeValue -} - -var bitcodes = []bitcode{ - bitcodeNull, // 0x00 - bitcodeFalse, // 0x10 - bitcodeInt, // 0x20 - bitcodeNegInt, // 0x30 - bitcodeFloat, // 0x40 - bitcodeDecimal, // 0x50 - bitcodeTimestamp, // 0x60 - bitcodeSymbol, // 0x70 - bitcodeString, // 0x80 - bitcodeClob, // 0x90 - bitcodeBlob, // 0xA0 - bitcodeList, // 0xB0 - bitcodeSexp, // 0xC0 - bitcodeStruct, // 0xD0 - bitcodeAnnotation, // 0xE0 -} - -// ParseTag parses a tag byte into a typecode and a length. -func parseTag(c int) (bitcode, uint64) { - high := (c >> 4) & 0x0F - low := c & 0x0F - - code := bitcodeNone - if high < len(bitcodes) { - code = bitcodes[high] - } - - return code, uint64(low) -} - -// ReadN reads the next n bytes of input from the underlying stream. -func (b *bitstream) readN(n uint64) ([]byte, error) { - if n == 0 { - return nil, nil - } - - bs := make([]byte, n) - actual, err := b.in.Read(bs) - b.pos += uint64(actual) - - if err == io.EOF { - return nil, &UnexpectedEOFError{b.pos} - } - if err != nil { - return nil, &IOError{err} - } - - return bs, nil -} - -// Read1 reads the next byte of input from the underlying stream, returning -// an UnexpectedEOFError if it's an EOF. -func (b *bitstream) read1() (int, error) { - c, err := b.read() - if err != nil { - return 0, err - } - if c == -1 { - return 0, &UnexpectedEOFError{b.pos} - } - return c, nil -} - -// Read reads the next byte of input from the underlying stream. It returns -// -1 instead of io.EOF if we've hit the end of the stream, because I find -// that easier to reason about. -func (b *bitstream) read() (int, error) { - c, err := b.in.ReadByte() - b.pos++ - - if err == io.EOF { - return -1, nil - } - if err != nil { - return 0, &IOError{err} - } - - return int(c), nil -} - -// Skip skips n bytes of input from the underlying stream. -func (b *bitstream) skip(n uint64) error { - actual, err := b.in.Discard(int(n)) - b.pos += uint64(actual) - - if err == io.EOF { - return nil - } - if err != nil { - return &IOError{err} - } - - return nil -} - -// A bitnode represents a container value, including its typecode and -// the offset at which it (supposedly) ends. -type bitnode struct { - code bitcode - end uint64 -} - -// A stack of bitnodes representing container values that we're currently -// stepped in to. -type bitstack struct { - arr []bitnode -} - -// Empty returns true if this bitstack is empty. -func (b *bitstack) empty() bool { - return len(b.arr) == 0 -} - -// Peek peeks at the top bitnode on the stack. -func (b *bitstack) peek() bitnode { - if len(b.arr) == 0 { - return bitnode{} - } - return b.arr[len(b.arr)-1] -} - -// Push pushes a bitnode onto the stack. -func (b *bitstack) push(code bitcode, end uint64) { - b.arr = append(b.arr, bitnode{code, end}) -} - -// Pop pops a bitnode from the stack. -func (b *bitstack) pop() { - if len(b.arr) == 0 { - panic("pop called on empty bitstack") - } - b.arr = b.arr[:len(b.arr)-1] -} diff --git a/bitstream_test.go b/bitstream_test.go deleted file mode 100644 index 8cbf57ab..00000000 --- a/bitstream_test.go +++ /dev/null @@ -1,111 +0,0 @@ -package ion - -import "testing" - -func TestBitstream(t *testing.T) { - ion := []byte{ - 0xE0, 0x01, 0x00, 0xEA, // $ion_1_0 - 0xEE, 0x9F, 0x81, 0x83, 0xDE, 0x9B, // $ion_symbol_table::{ - 0x86, 0xBE, 0x8E, // imports:[ - 0xDD, // { - 0x84, 0x85, 'b', 'o', 'g', 'u', 's', // name: "bogus" - 0x85, 0x21, 0x2A, // version: 42 - 0x88, 0x21, 0x64, // max_id: 100 - // }] - 0x87, 0xB8, // symbols: [ - 0x83, 'f', 'o', 'o', // "foo" - 0x83, 'b', 'a', 'r', // "bar" - // ] - // } - 0xD0, // {} - 0xEA, 0x81, 0xEE, 0xD7, // foo::{ - 0x84, 0xE3, 0x81, 0xEF, 0x0F, // name:bar::null, - 0x88, 0x20, // max_id:0 - // } - } - - b := bitstream{} - b.InitBytes(ion) - - next := func(code bitcode, null bool, len uint64) { - if err := b.Next(); err != nil { - t.Fatal(err) - } - if b.Code() != code { - t.Errorf("expected code=%v, got %v", code, b.Code()) - } - if b.IsNull() != null { - t.Errorf("expected null=%v, got %v", null, b.IsNull()) - } - if b.Len() != len { - t.Errorf("expected len=%v, got %v", len, b.Len()) - } - } - - fieldid := func(eid uint64) { - id, err := b.ReadFieldID() - if err != nil { - t.Fatal(err) - } - if id != eid { - t.Errorf("expected %v, got %v", eid, id) - } - } - - next(bitcodeBVM, false, 3) - maj, min, err := b.ReadBVM() - if err != nil { - t.Fatal(err) - } - if maj != 1 && min != 0 { - t.Errorf("expected $ion_1.0, got $ion_%v.%v", maj, min) - } - - next(bitcodeAnnotation, false, 31) - ids, err := b.ReadAnnotationIDs() - if err != nil { - t.Fatal(err) - } - if len(ids) != 1 || ids[0] != 3 { // $ion_symbol_table - t.Errorf("expected [3], got %v", ids) - } - - next(bitcodeStruct, false, 27) - b.StepIn() - { - next(bitcodeFieldID, false, 0) - fieldid(6) // imports - - next(bitcodeList, false, 14) - b.StepIn() - { - next(bitcodeStruct, false, 13) - } - if err := b.StepOut(); err != nil { - t.Fatal(err) - } - - next(bitcodeFieldID, false, 0) - // fieldid(7) // symbols - - next(bitcodeList, false, 8) - next(bitcodeEOF, false, 0) - } - if err := b.StepOut(); err != nil { - t.Fatal(err) - } - - next(bitcodeStruct, false, 0) - next(bitcodeAnnotation, false, 10) - next(bitcodeEOF, false, 0) - next(bitcodeEOF, false, 0) -} - -func TestBitcodeString(t *testing.T) { - for i := bitcodeNone; i <= bitcodeAnnotation+1; i++ { - str := i.String() - if str == "" { - t.Errorf("expected non-empty string for bitcode %v", uint8(i)) - } - } -} diff --git a/buf.go b/buf.go deleted file mode 100644 index b4a30fe1..00000000 --- a/buf.go +++ /dev/null @@ -1,119 +0,0 @@ -package ion - -import ( - "io" -) - -// Writing binary ion is a bit tricky: values are preceded by their length, -// which can be hard to predict until we've actually written out the value. -// To make matters worse, we can't predict the length of the /length/ ahead -// of time in order to reserve space for it, because it uses a variable-length -// encoding. To avoid copying bytes around all over the place, we write into -// an in-memory tree structure, which we then blast out to the actual io.Writer -// once all the relevant lengths are known. - -// A bufnode is a node in the partially-serialized tree. -type bufnode interface { - Len() uint64 - EmitTo(w io.Writer) error -} - -// A bufseq is a bufnode that's also an appendable sequence of bufnodes. -type bufseq interface { - bufnode - Append(n bufnode) -} - -var _ bufnode = atom([]byte{}) -var _ bufseq = &datagram{} -var _ bufseq = &container{} - -// An atom is a value that has been fully serialized and can be emitted directly. -type atom []byte - -func (a atom) Len() uint64 { - return uint64(len(a)) -} - -func (a atom) EmitTo(w io.Writer) error { - _, err := w.Write(a) - return err -} - -// A datagram is a sequence of nodes that will be emitted one -// after another. Most notably, used to buffer top-level values -// when we haven't yet finalized the local symbol table. -type datagram struct { - len uint64 - children []bufnode -} - -func (d *datagram) Append(n bufnode) { - d.len += n.Len() - d.children = append(d.children, n) -} - -func (d *datagram) Len() uint64 { - return d.len -} - -func (d *datagram) EmitTo(w io.Writer) error { - for _, child := range d.children { - if err := child.EmitTo(w); err != nil { - return err - } - } - - return nil -} - -// A container is a datagram that's preceeded by a code+length tag. -type container struct { - code byte - datagram -} - -func (c *container) Len() uint64 { - if c.len < 0x0E { - return c.len + 1 - } - return c.len + (varUintLen(c.len) + 1) -} - -func (c *container) EmitTo(w io.Writer) error { - var arr [11]byte - buf := arr[:0] - buf = appendTag(buf, c.code, c.len) - - if _, err := w.Write(buf); err != nil { - return err - } - return c.datagram.EmitTo(w) -} - -// A bufstack is a stack of bufseqs, more or less matching the -// stack of BeginList/Sexp/Struct calls made on a binaryWriter. -// The top of the stack is the sequence we're currently writing -// values into; when it's popped off, it will be appended to the -// bufseq below it. -type bufstack struct { - arr []bufseq -} - -func (s *bufstack) peek() bufseq { - if len(s.arr) == 0 { - return nil - } - return s.arr[len(s.arr)-1] -} - -func (s *bufstack) push(b bufseq) { - s.arr = append(s.arr, b) -} - -func (s *bufstack) pop() { - if len(s.arr) == 0 { - panic("pop called on an empty stack") - } - s.arr = s.arr[:len(s.arr)-1] -} diff --git a/buf_test.go b/buf_test.go deleted file mode 100644 index d4b6a308..00000000 --- a/buf_test.go +++ /dev/null @@ -1,79 +0,0 @@ -package ion - -import ( - "bytes" - "testing" -) - -func TestBufnode(t *testing.T) { - root := container{code: 0xE0} - root.Append(atom([]byte{0x81, 0x83})) - { - symtab := &container{code: 0xD0} - { - symtab.Append(atom([]byte{0x86})) // varUint(6) - { - imps := &container{code: 0xB0} - { - imp0 := &container{code: 0xD0} - { - imp0.Append(atom([]byte{0x84})) // varUint(4) - imp0.Append(atom([]byte{0x85, 'b', 'o', 'g', 'u', 's'})) - imp0.Append(atom([]byte{0x85})) // varUint(5) - imp0.Append(atom([]byte{0x21, 0x2A})) - imp0.Append(atom([]byte{0x88})) // varUint(8) - imp0.Append(atom([]byte{0x21, 0x64})) - } - imps.Append(imp0) - } - symtab.Append(imps) - } - - symtab.Append(atom([]byte{0x87})) // varUint(7) - { - syms := &container{code: 0xB0} - { - syms.Append(atom([]byte{0x83, 'f', 'o', 'o'})) - syms.Append(atom([]byte{0x83, 'b', 'a', 'r'})) - } - symtab.Append(syms) - } - } - root.Append(symtab) - } - - buf := bytes.Buffer{} - if err := root.EmitTo(&buf); err != nil { - t.Fatal(err) - } - - val := buf.Bytes() - eval := []byte{ - // $ion_symbol_table::{ - 0xEE, 0x9F, 0x81, 0x83, 0xDE, 0x9B, - // imports:[ - 0x86, 0xBE, 0x8E, - // { - 0xDD, - // name: "bogus" - 0x84, 0x85, 'b', 'o', 'g', 'u', 's', - // version: 42 - 0x85, 0x21, 0x2A, - // max_id: 100 - 0x88, 0x21, 0x64, - // } - // ], - // symbols:[ - 0x87, 0xB8, - // "foo", - 0x83, 'f', 'o', 'o', - // "bar" - 0x83, 'b', 'a', 'r', - // ] - // } - } - - if !bytes.Equal(val, eval) { - t.Errorf("expected %v, got %v", fmtbytes(eval), fmtbytes(val)) - } -} diff --git a/catalog.go b/catalog.go deleted file mode 100644 index 65d4f88f..00000000 --- a/catalog.go +++ /dev/null @@ -1,88 +0,0 @@ -package ion - -import ( - "bytes" - "fmt" - "io" - "strings" -) - -// A Catalog provides access to shared symbol tables. -type Catalog interface { - FindExact(name string, version int) SharedSymbolTable - FindLatest(name string) SharedSymbolTable -} - -// A basicCatalog wraps an in-memory collection of shared symbol tables. -type basicCatalog struct { - ssts map[string]SharedSymbolTable - latest map[string]SharedSymbolTable -} - -// NewCatalog creates a new basic catalog containing the given symbol tables. -func NewCatalog(ssts ...SharedSymbolTable) Catalog { - cat := &basicCatalog{ - ssts: make(map[string]SharedSymbolTable), - latest: make(map[string]SharedSymbolTable), - } - for _, sst := range ssts { - cat.add(sst) - } - return cat -} - -// Add adds a shared symbol table to the catalog. -func (c *basicCatalog) add(sst SharedSymbolTable) { - key := fmt.Sprintf("%v/%v", sst.Name(), sst.Version()) - c.ssts[key] = sst - - cur, ok := c.latest[sst.Name()] - if !ok || sst.Version() > cur.Version() { - c.latest[sst.Name()] = sst - } -} - -// FindExact attempts to find a shared symbol table with the given name and version. -func (c *basicCatalog) FindExact(name string, version int) SharedSymbolTable { - key := fmt.Sprintf("%v/%v", name, version) - return c.ssts[key] -} - -// FindLatest finds the shared symbol table with the given name and largest version. -func (c *basicCatalog) FindLatest(name string) SharedSymbolTable { - return c.latest[name] -} - -// A System is a reader factory wrapping a catalog. -type System struct { - Catalog Catalog -} - -// NewReader creates a new reader using this system's catalog. -func (s System) NewReader(in io.Reader) Reader { - return NewReaderCat(in, s.Catalog) -} - -// NewReaderStr creates a new reader using this system's catalog. -func (s System) NewReaderStr(in string) Reader { - return NewReaderCat(strings.NewReader(in), s.Catalog) -} - -// NewReaderBytes creates a new reader using this system's catalog. -func (s System) NewReaderBytes(in []byte) Reader { - return NewReaderCat(bytes.NewReader(in), s.Catalog) -} - -// Unmarshal unmarshals Ion data using this system's catalog. -func (s System) Unmarshal(data []byte, v interface{}) error { - r := s.NewReaderBytes(data) - d := NewDecoder(r) - return d.DecodeTo(v) -} - -// UnmarshalStr unmarshals Ion data using this system's catalog. -func (s System) UnmarshalStr(data string, v interface{}) error { - r := s.NewReaderStr(data) - d := NewDecoder(r) - return d.DecodeTo(v) -} diff --git a/catalog_test.go b/catalog_test.go deleted file mode 100644 index 2e3bd2e0..00000000 --- a/catalog_test.go +++ /dev/null @@ -1,62 +0,0 @@ -package ion - -import ( - "bytes" - "fmt" - "testing" -) - -type Item struct { - ID int `json:"id"` - Name string `json:"name"` - Description string `json:"description"` -} - -func TestCatalog(t *testing.T) { - sst := NewSharedSymbolTable("item", 1, []string{ - "item", - "id", - "name", - "description", - }) - - buf := bytes.Buffer{} - out := NewBinaryWriter(&buf, sst) - - for i := 0; i < 10; i++ { - out.Annotation("item") - MarshalTo(out, &Item{ - ID: i, - Name: fmt.Sprintf("Item %v", i), - Description: fmt.Sprintf("The %vth test item", i), - }) - } - if err := out.Finish(); err != nil { - t.Fatal(err) - } - - bs := buf.Bytes() - - sys := System{Catalog: NewCatalog(sst)} - in := sys.NewReaderBytes(bs) - - i := 0 - for ; ; i++ { - item := Item{} - err := UnmarshalFrom(in, &item) - if err == ErrNoInput { - break - } - if err != nil { - t.Fatal(err) - } - - if item.ID != i { - t.Errorf("expected id=%v, got %v", i, item.ID) - } - } - - if i != 10 { - t.Errorf("expected i=10, got %v", i) - } -} diff --git a/consts.go b/consts.go deleted file mode 100644 index 1a49dec4..00000000 --- a/consts.go +++ /dev/null @@ -1,52 +0,0 @@ -package ion - -import ( - "reflect" - "time" -) - -var binaryNulls = func() []byte { - ret := make([]byte, StructType+1) - ret[NoType] = 0x0F - ret[NullType] = 0x0F - ret[BoolType] = 0x1F - ret[IntType] = 0x2F - ret[FloatType] = 0x4F - ret[DecimalType] = 0x5F - ret[TimestampType] = 0x6F - ret[SymbolType] = 0x7F - ret[StringType] = 0x8F - ret[ClobType] = 0x9F - ret[BlobType] = 0xAF - ret[ListType] = 0xBF - ret[SexpType] = 0xCF - ret[StructType] = 0xDF - return ret -}() - -var textNulls []string = func() []string { - ret := make([]string, StructType+1) - ret[NoType] = "null" - ret[NullType] = "null.null" - ret[BoolType] = "null.bool" - ret[IntType] = "null.int" - ret[FloatType] = "null.float" - ret[DecimalType] = "null.decimal" - ret[TimestampType] = "null.timestamp" - ret[SymbolType] = "null.symbol" - ret[StringType] = "null.string" - ret[ClobType] = "null.clob" - ret[BlobType] = "null.blob" - ret[ListType] = "null.list" - ret[SexpType] = "null.sexp" - ret[StructType] = "null.struct" - return ret -}() - -var hexChars = []byte{ - '0', '1', '2', '3', '4', '5', '6', '7', - '8', '9', 'A', 'B', 'C', 'D', 'E', 'F', -} - -var timeType = reflect.TypeOf(time.Time{}) -var decimalType = reflect.TypeOf(Decimal{}) diff --git a/ctx.go b/ctx.go deleted file mode 100644 index d3a1b9ed..00000000 --- a/ctx.go +++ /dev/null @@ -1,65 +0,0 @@ -package ion - -import "fmt" - -// ctx is the current reader or writer context. -type ctx uint8 - -const ( - ctxAtTopLevel ctx = iota - ctxInStruct - ctxInList - ctxInSexp -) - -func ctxToContainerType(c ctx) Type { - switch c { - case ctxInList: - return ListType - case ctxInSexp: - return SexpType - case ctxInStruct: - return StructType - default: - return NoType - } -} - -func containerTypeToCtx(t Type) ctx { - switch t { - case ListType: - return ctxInList - case SexpType: - return ctxInSexp - case StructType: - return ctxInStruct - default: - panic(fmt.Sprintf("type %v is not a container type", t)) - } -} - -// ctxstack is a context stack. -type ctxstack struct { - arr []ctx -} - -// peek returns the current context. -func (c *ctxstack) peek() ctx { - if len(c.arr) == 0 { - return ctxAtTopLevel - } - return c.arr[len(c.arr)-1] -} - -// push pushes a new context onto the stack. -func (c *ctxstack) push(ctx ctx) { - c.arr = append(c.arr, ctx) -} - -// pop pops the top context off the stack. -func (c *ctxstack) pop() { - if len(c.arr) == 0 { - panic("pop called at top level") - } - c.arr = c.arr[:len(c.arr)-1] -} diff --git a/decimal.go b/decimal.go deleted file mode 100644 index 1b522286..00000000 --- a/decimal.go +++ /dev/null @@ -1,342 +0,0 @@ -package ion - -import ( - "fmt" - "math" - "math/big" - "strconv" - "strings" -) - -// A ParseError is returned if ParseDecimal is called with a parameter that -// cannot be parsed as a Decimal. -type ParseError struct { - Num string - Msg string -} - -func (e *ParseError) Error() string { - return fmt.Sprintf("ion: ParseDecimal(%v): %v", e.Num, e.Msg) -} - -// TODO: Explicitly track precision? - -// Decimal is an arbitrary-precision decimal value. -type Decimal struct { - n *big.Int - scale int32 -} - -// NewDecimal creates a new decimal whose value is equal to n * 10^exp. -func NewDecimal(n *big.Int, exp int32) *Decimal { - return &Decimal{ - n: n, - scale: -exp, - } -} - -// NewDecimalInt creates a new decimal whose value is equal to n. -func NewDecimalInt(n int64) *Decimal { - return NewDecimal(big.NewInt(n), 0) -} - -// MustParseDecimal parses the given string into a decimal object, -// panicing on error. -func MustParseDecimal(in string) *Decimal { - d, err := ParseDecimal(in) - if err != nil { - panic(err) - } - return d -} - -// ParseDecimal parses the given string into a decimal object, -// returning an error on failure. -func ParseDecimal(in string) (*Decimal, error) { - if len(in) == 0 { - return nil, &ParseError{in, "empty string"} - } - - exponent := int32(0) - - d := strings.IndexAny(in, "Dd") - if d != -1 { - // There's an explicit exponent. - exp := in[d+1:] - if len(exp) == 0 { - return nil, &ParseError{in, "unexpected end of input after d"} - } - - tmp, err := strconv.ParseInt(exp, 10, 32) - if err != nil { - return nil, &ParseError{in, err.Error()} - } - - exponent = int32(tmp) - in = in[:d] - } - - d = strings.Index(in, ".") - if d != -1 { - // There's zero or more decimal places. - ipart := in[:d] - fpart := in[d+1:] - - exponent -= int32(len(fpart)) - in = ipart + fpart - } - - n, ok := new(big.Int).SetString(in, 10) - if !ok { - // Unfortunately this is all we get? - return nil, &ParseError{in, "cannot parse coefficient"} - } - - return NewDecimal(n, exponent), nil -} - -// CoEx returns this decimal's coefficient and exponent. -func (d *Decimal) CoEx() (*big.Int, int32) { - return d.n, -d.scale -} - -// Abs returns the absolute value of this Decimal. -func (d *Decimal) Abs() *Decimal { - return &Decimal{ - n: new(big.Int).Abs(d.n), - scale: d.scale, - } -} - -// Add returns the result of adding this Decimal to another Decimal. -func (d *Decimal) Add(o *Decimal) *Decimal { - // a*10^x + b*10^y = (a*10^(x-y) + b) * 10^y - dd, oo := rescale(d, o) - return &Decimal{ - n: new(big.Int).Add(dd.n, oo.n), - scale: dd.scale, - } -} - -// Sub returns the result of substrating another Decimal from this Decimal. -func (d *Decimal) Sub(o *Decimal) *Decimal { - dd, oo := rescale(d, o) - return &Decimal{ - n: new(big.Int).Sub(dd.n, oo.n), - scale: dd.scale, - } -} - -// Neg returns the negative of this Decimal. -func (d *Decimal) Neg() *Decimal { - return &Decimal{ - n: new(big.Int).Neg(d.n), - scale: d.scale, - } -} - -// Mul multiplies two decimals and returns the result. -func (d *Decimal) Mul(o *Decimal) *Decimal { - // a*10^x * b*10^y = (a*b) * 10^(x+y) - scale := int64(d.scale) + int64(o.scale) - if scale > math.MaxInt32 || scale < math.MinInt32 { - panic("exponent out of bounds") - } - - return &Decimal{ - n: new(big.Int).Mul(d.n, o.n), - scale: int32(scale), - } -} - -// ShiftL returns a new decimal shifted the given number of decimal -// places to the left. It's a computationally-cheap way to compute -// d * 10^shift. -func (d *Decimal) ShiftL(shift int) *Decimal { - scale := int64(d.scale) - int64(shift) - if scale > math.MaxInt32 || scale < math.MinInt32 { - panic("exponent out of bounds") - } - - return &Decimal{ - n: d.n, - scale: int32(scale), - } -} - -// ShiftR returns a new decimal shifted the given number of decimal -// places to the right. It's a computationally-cheap way to compute -// d / 10^shift. -func (d *Decimal) ShiftR(shift int) *Decimal { - scale := int64(d.scale) + int64(shift) - if scale > math.MaxInt32 || scale < math.MinInt32 { - panic("exponent out of bounds") - } - - return &Decimal{ - n: d.n, - scale: int32(scale), - } -} - -// TODO: Div, Exp, etc? - -// Sign returns -1 if the value is less than 0, 0 if it is equal to zero, -// and +1 if it is greater than zero. -func (d *Decimal) Sign() int { - return d.n.Sign() -} - -// Cmp compares two decimals, returning -1 if d is smaller, +1 if d is -// larger, and 0 if they are equal (ignoring precision). -func (d *Decimal) Cmp(o *Decimal) int { - dd, oo := rescale(d, o) - return dd.n.Cmp(oo.n) -} - -// Equal determines if two decimals are equal (discounting precision, -// at least for now). -func (d *Decimal) Equal(o *Decimal) bool { - return d.Cmp(o) == 0 -} - -func rescale(a, b *Decimal) (*Decimal, *Decimal) { - if a.scale < b.scale { - return a.upscale(b.scale), b - } else if a.scale > b.scale { - return a, b.upscale(a.scale) - } else { - return a, b - } -} - -var ten = big.NewInt(10) - -// Make 'n' bigger by making 'scale' smaller, since we know we can -// do that. (1d100 -> 10d99). Makes comparisons and math easier, at the -// expense of more storage space. Technically speaking implies adding -// more precision, but we're not tracking that too closely. -func (d *Decimal) upscale(scale int32) *Decimal { - diff := int64(scale) - int64(d.scale) - if diff < 0 { - panic("can't upscale to a smaller scale") - } - - pow := new(big.Int).Exp(ten, big.NewInt(diff), nil) - n := new(big.Int).Mul(d.n, pow) - - return &Decimal{ - n: n, - scale: scale, - } -} - -// Trunc attempts to truncate this decimal to an int64, dropping any fractional bits. -func (d *Decimal) Trunc() (int64, error) { - if d.scale < 0 { - // Don't even bother trying this with numbers that *definitely* too big to represent - // as an int64, because upscale(0) will consume a bunch of memory. - if d.scale < -20 { - return 0, &strconv.NumError{ - Func: "ParseInt", - Num: d.String(), - Err: strconv.ErrRange, - } - } - d = d.upscale(0) - } - - str := d.n.String() - - want := len(str) - int(d.scale) - if want <= 0 { - return 0, nil - } - - return strconv.ParseInt(str[:want], 10, 64) -} - -// Truncate returns a new decimal, truncated to the given number of -// decimal digits of precision. It does not round, so 19.Truncate(1) -// = 1d1. -func (d *Decimal) Truncate(precision int) *Decimal { - if precision <= 0 { - panic("precision must be positive") - } - - // Is there a better way to calculate precision? It really - // seems like there should be... - - str := d.n.String() - if str[0] == '-' { - // Cheating a bit. - precision++ - } - - diff := len(str) - precision - if diff <= 0 { - // Already small enough, nothing to truncate. - return d - } - - // Lazy man's division by a power of 10. - n, ok := new(big.Int).SetString(str[:precision], 10) - if !ok { - // Should never happen, since we started with a valid int. - panic("failed to parse integer") - } - - scale := int64(d.scale) - int64(diff) - if scale < math.MinInt32 { - panic("exponent out of range") - } - - return &Decimal{ - n: n, - scale: int32(scale), - } -} - -// String formats the decimal as a string in Ion text format. -func (d *Decimal) String() string { - switch { - case d.scale == 0: - // Value is an unscaled integer. Just mark it as a decimal. - return d.n.String() + "." - - case d.scale < 0: - // Value is a upscaled integer, nn'd'ss - return d.n.String() + "d" + fmt.Sprintf("%d", -d.scale) - - default: - // Value is a downscaled integer nn.nn('d'-ss)? - str := d.n.String() - idx := len(str) - int(d.scale) - - prefix := 1 - if d.n.Sign() < 0 { - // Account for leading '-'. - prefix++ - } - - if idx >= prefix { - // Put the decimal point in the middle, no exponent. - return str[:idx] + "." + str[idx:] - } - - // Put the decimal point at the beginning and - // add a (negative) exponent. - b := strings.Builder{} - b.WriteString(str[:prefix]) - - if len(str) > prefix { - b.WriteString(".") - b.WriteString(str[prefix:]) - } - - b.WriteString("d") - b.WriteString(fmt.Sprintf("%d", idx-prefix)) - - return b.String() - } -} diff --git a/decimal_test.go b/decimal_test.go deleted file mode 100644 index 21b72c9e..00000000 --- a/decimal_test.go +++ /dev/null @@ -1,312 +0,0 @@ -package ion - -import ( - "fmt" - "math/big" - "testing" -) - -func TestDecimalToString(t *testing.T) { - test := func(n int64, scale int32, expected string) { - t.Run(expected, func(t *testing.T) { - d := Decimal{ - n: big.NewInt(n), - scale: scale, - } - actual := d.String() - if actual != expected { - t.Errorf("expected '%v', got '%v'", expected, actual) - } - }) - } - - test(0, 0, "0.") - test(0, -1, "0d1") - test(0, 1, "0d-1") - - test(1, 0, "1.") - test(1, -1, "1d1") - test(1, 1, "1d-1") - - test(-1, 0, "-1.") - test(-1, -1, "-1d1") - test(-1, 1, "-1d-1") - - test(123, 0, "123.") - test(-456, 0, "-456.") - - test(123, -5, "123d5") - test(-456, -5, "-456d5") - - test(123, 1, "12.3") - test(123, 2, "1.23") - test(123, 3, "1.23d-1") - test(123, 4, "1.23d-2") - - test(-456, 1, "-45.6") - test(-456, 2, "-4.56") - test(-456, 3, "-4.56d-1") - test(-456, 4, "-4.56d-2") -} - -func TestParseDecimal(t *testing.T) { - test := func(in string, n *big.Int, scale int32) { - t.Run(in, func(t *testing.T) { - d, err := ParseDecimal(in) - if err != nil { - t.Fatal(err) - } - - if n.Cmp(d.n) != 0 { - t.Errorf("wrong n; expected %v, got %v", n, d.n) - } - if scale != d.scale { - t.Errorf("wrong scale; expected %v, got %v", scale, d.scale) - } - }) - } - - test("0", big.NewInt(0), 0) - test("-0", big.NewInt(0), 0) - test("0D0", big.NewInt(0), 0) - test("-0d-1", big.NewInt(0), 1) - - test("1.", big.NewInt(1), 0) - test("1.0", big.NewInt(10), 1) - test("0.123", big.NewInt(123), 3) - - test("1d0", big.NewInt(1), 0) - test("1d1", big.NewInt(1), -1) - test("1d+1", big.NewInt(1), -1) - test("1d-1", big.NewInt(1), 1) - - test("-0.12d4", big.NewInt(-12), -2) -} - -func absF(d *Decimal) *Decimal { return d.Abs() } -func negF(d *Decimal) *Decimal { return d.Neg() } - -type unaryop struct { - sym string - fun func(d *Decimal) *Decimal -} - -var abs = &unaryop{"abs", absF} -var neg = &unaryop{"neg", negF} - -func testUnaryOp(t *testing.T, a, e string, op *unaryop) { - t.Run(op.sym+"("+a+")="+e, func(t *testing.T) { - aa, _ := ParseDecimal(a) - ee, _ := ParseDecimal(e) - actual := op.fun(aa) - if !actual.Equal(ee) { - t.Errorf("expected %v, got %v", ee, actual) - } - }) -} - -func TestAbs(t *testing.T) { - test := func(a, e string) { - testUnaryOp(t, a, e, abs) - } - - test("0", "0") - test("1d100", "1d100") - test("-1d100", "1d100") - test("1.2d-3", "1.2d-3") - test("-1.2d-3", "1.2d-3") -} - -func TestNeg(t *testing.T) { - test := func(a, e string) { - testUnaryOp(t, a, e, neg) - } - - test("0", "0") - test("1d100", "-1d100") - test("-1d100", "1d100") - test("1.2d-3", "-1.2d-3") - test("-1.2d-3", "1.2d-3") -} - -func TestTrunc(t *testing.T) { - test := func(a string, eval int64) { - t.Run(fmt.Sprintf("trunc(%v)=%v", a, eval), func(t *testing.T) { - aa := MustParseDecimal(a) - val, err := aa.Trunc() - if err != nil { - t.Fatal(err) - } - if val != eval { - t.Errorf("expected %v, got %v", eval, val) - } - }) - } - - test("0.", 0) - test("0.01", 0) - test("1.", 1) - test("-1.", -1) - test("1.01", 1) - test("-1.01", -1) - test("101", 101) - test("1d3", 1000) -} - -func addF(a, b *Decimal) *Decimal { return a.Add(b) } -func subF(a, b *Decimal) *Decimal { return a.Sub(b) } -func mulF(a, b *Decimal) *Decimal { return a.Mul(b) } - -type binop struct { - sym string - fun func(a, b *Decimal) *Decimal -} - -func TestShiftL(t *testing.T) { - test := func(a string, b int, e string) { - aa, _ := ParseDecimal(a) - ee, _ := ParseDecimal(e) - actual := aa.ShiftL(b) - if !actual.Equal(ee) { - t.Errorf("expected %v, got %v", ee, actual) - } - } - - test("0", 10, "0") - test("1", 0, "1") - test("123", 1, "1230") - test("123", 100, "123d100") - test("1.23d-100", 102, "123") -} - -func TestShiftR(t *testing.T) { - test := func(a string, b int, e string) { - aa, _ := ParseDecimal(a) - ee, _ := ParseDecimal(e) - actual := aa.ShiftR(b) - if !actual.Equal(ee) { - t.Errorf("expected %v, got %v", ee, actual) - } - } - - test("0", 10, "0") - test("1", 0, "1") - test("123", 1, "12.3") - test("123", 100, "1.23d-98") - test("1.23d100", 98, "123") -} - -var add = &binop{"+", addF} -var sub = &binop{"-", subF} -var mul = &binop{"*", mulF} - -func testBinaryOp(t *testing.T, a, b, e string, op *binop) { - t.Run(a+op.sym+b+"="+e, func(t *testing.T) { - aa, _ := ParseDecimal(a) - bb, _ := ParseDecimal(b) - ee, _ := ParseDecimal(e) - - actual := op.fun(aa, bb) - if !actual.Equal(ee) { - t.Errorf("expected %v, got %v", ee, actual) - } - }) -} - -func TestAdd(t *testing.T) { - test := func(a, b, e string) { - testBinaryOp(t, a, b, e, add) - } - - test("1", "0", "1") - test("1", "1", "2") - test("1", "0.1", "1.1") - test("0.3", "0.06", "0.36") - test("1", "100", "101") - test("1d100", "1d98", "101d98") - test("1d-100", "1d-98", "1.01d-98") -} - -func TestSub(t *testing.T) { - test := func(a, b, e string) { - testBinaryOp(t, a, b, e, sub) - } - - test("1", "0", "1") - test("1", "1", "0") - test("1", "0.1", "0.9") - test("0.3", "0.06", "0.24") - test("1", "100", "-99") - test("1d100", "1d98", "99d98") - test("1d-100", "1d-98", "-99d-100") -} - -func TestMul(t *testing.T) { - test := func(a, b, e string) { - testBinaryOp(t, a, b, e, mul) - } - - test("1", "0", "0") - test("1", "1", "1") - test("2", "-1", "-2") - test("7", "6", "42") - test("10", "0.3", "3") - test("3d100", "2d50", "6d150") - test("3d-100", "2d-50", "6d-150") - test("2d100", "4d-98", "8d2") -} - -func TestTruncate(t *testing.T) { - test := func(a string, p int, expected string) { - t.Run(fmt.Sprintf("trunc(%v,%v)", a, p), func(t *testing.T) { - aa := MustParseDecimal(a) - actual := aa.Truncate(p).String() - if actual != expected { - t.Errorf("expected %v, got %v", expected, actual) - } - }) - } - - test("1", 1, "1.") - test("1", 10, "1.") - test("10", 1, "1d1") - test("1999", 1, "1d3") - test("1.2345", 3, "1.23") - test("100d100", 2, "10d101") - test("1.2345d-100", 2, "1.2d-100") -} - -func TestCmp(t *testing.T) { - test := func(a, b string, expected int) { - t.Run("("+a+","+b+")", func(t *testing.T) { - ad, _ := ParseDecimal(a) - bd, _ := ParseDecimal(b) - actual := ad.Cmp(bd) - if actual != expected { - t.Errorf("expected %v, got %v", expected, actual) - } - }) - } - - test("0", "0", 0) - test("0", "1", -1) - test("0", "-1", 1) - - test("1d2", "100", 0) - test("100", "1d2", 0) - test("1d2", "10", 1) - test("10", "1d2", -1) - - test("0.01", "1d-2", 0) - test("1d-2", "0.01", 0) - test("0.01", "1d-3", 1) - test("1d-3", "0.01", -1) -} - -func TestUpscale(t *testing.T) { - d, _ := ParseDecimal("1d1") - actual := d.upscale(4).String() - if actual != "10.0000" { - t.Errorf("expected 10.0000, got %v", actual) - } -} diff --git a/err.go b/err.go deleted file mode 100644 index 737076a9..00000000 --- a/err.go +++ /dev/null @@ -1,88 +0,0 @@ -package ion - -import "fmt" - -// A UsageError is returned when you use a Reader or Writer in an inappropriate way. -type UsageError struct { - API string - Msg string -} - -func (e *UsageError) Error() string { - return fmt.Sprintf("ion: usage error in %v: %v", e.API, e.Msg) -} - -// An IOError is returned when there is an error reading from or writing to an -// underlying io.Reader or io.Writer. -type IOError struct { - Err error -} - -func (e *IOError) Error() string { - return fmt.Sprintf("ion: i/o error: %v", e.Err) -} - -// A SyntaxError is returned when a Reader encounters invalid input for which no more -// specific error type is defined. -type SyntaxError struct { - Msg string - Offset uint64 -} - -func (e *SyntaxError) Error() string { - return fmt.Sprintf("ion: syntax error: %v (offset %v)", e.Msg, e.Offset) -} - -// An UnexpectedEOFError is returned when a Reader unexpectedly encounters an -// io.EOF error. -type UnexpectedEOFError struct { - Offset uint64 -} - -func (e *UnexpectedEOFError) Error() string { - return fmt.Sprintf("ion: unexpected end of input (offset %v)", e.Offset) -} - -// An UnsupportedVersionError is returned when a Reader encounters a binary version -// marker with a version that this library does not understand. -type UnsupportedVersionError struct { - Major int - Minor int - Offset uint64 -} - -func (e *UnsupportedVersionError) Error() string { - return fmt.Sprintf("ion: unsupported version %v.%v (offset %v)", e.Major, e.Minor, e.Offset) -} - -// An InvalidTagByteError is returned when a binary Reader encounters an invalid -// tag byte. -type InvalidTagByteError struct { - Byte byte - Offset uint64 -} - -func (e *InvalidTagByteError) Error() string { - return fmt.Sprintf("ion: invalid tag byte 0x%02X (offset %v)", e.Byte, e.Offset) -} - -// An UnexpectedRuneError is returned when a text Reader encounters an unexpected rune. -type UnexpectedRuneError struct { - Rune rune - Offset uint64 -} - -func (e *UnexpectedRuneError) Error() string { - return fmt.Sprintf("ion: unexpected rune %q (offset %v)", e.Rune, e.Offset) -} - -// An UnexpectedTokenError is returned when a text Reader encounters an unexpected -// token. -type UnexpectedTokenError struct { - Token string - Offset uint64 -} - -func (e *UnexpectedTokenError) Error() string { - return fmt.Sprintf("ion: unexpected token '%v' (offset %v)", e.Token, e.Offset) -} diff --git a/fields.go b/fields.go deleted file mode 100644 index 2f4a8f06..00000000 --- a/fields.go +++ /dev/null @@ -1,122 +0,0 @@ -package ion - -import ( - "fmt" - "reflect" - "strings" -) - -// A field is a reflectively-accessed field of a struct type. -type field struct { - name string - typ reflect.Type - path []int - omitEmpty bool -} - -// A fielder maps out the fields of a type. -type fielder struct { - fields []field - index map[string]bool -} - -// FieldsFor returns the fields of the given struct type. -// TODO: cache me. -func fieldsFor(t reflect.Type) []field { - fldr := fielder{index: map[string]bool{}} - fldr.inspect(t, nil) - return fldr.fields -} - -// Inspect recursively inspects a type to determine all of its fields. -func (f *fielder) inspect(t reflect.Type, path []int) { - for i := 0; i < t.NumField(); i++ { - sf := t.Field(i) - if !visible(&sf) { - // Skip non-visible fields. - continue - } - - tag := sf.Tag.Get("json") - if tag == "-" { - // Skip fields that are explicitly hidden by tag. - continue - } - name, opts := parseJSONTag(tag) - - newpath := make([]int, len(path)+1) - copy(newpath, path) - newpath[len(path)] = i - - ft := sf.Type - if ft.Name() == "" && ft.Kind() == reflect.Ptr { - ft = ft.Elem() - } - - if name == "" && sf.Anonymous && ft.Kind() == reflect.Struct { - // Dig in to the embedded struct. - f.inspect(ft, newpath) - } else { - // Add this named field. - if name == "" { - name = sf.Name - } - - if f.index[name] { - panic(fmt.Sprintf("too many fields named %v", name)) - } - f.index[name] = true - - f.fields = append(f.fields, field{ - name: name, - typ: ft, - path: newpath, - omitEmpty: omitEmpty(opts), - }) - } - } -} - -// Visible returns true if the given StructField should show up in the output. -func visible(sf *reflect.StructField) bool { - exported := sf.PkgPath == "" - if sf.Anonymous { - t := sf.Type - if t.Kind() == reflect.Ptr { - t = t.Elem() - } - if t.Kind() == reflect.Struct { - // Fields of embedded structs are visible even if the struct type itself is not. - return true - } - } - return exported -} - -// ParseJSONTag parses a `json:"..."` field tag, returning the name and opts. -func parseJSONTag(tag string) (string, string) { - if idx := strings.Index(tag, ","); idx != -1 { - // Ignore additional JSON options, at least for now. - return tag[:idx], tag[idx+1:] - } - return tag, "" -} - -// OmitEmpty returns true if opts includes "omitempty". -func omitEmpty(opts string) bool { - for opts != "" { - var o string - - i := strings.Index(opts, ",") - if i >= 0 { - o, opts = opts[:i], opts[i+1:] - } else { - o, opts = opts, "" - } - - if o == "omitempty" { - return true - } - } - return false -} diff --git a/marshal.go b/marshal.go deleted file mode 100644 index 7a5529d4..00000000 --- a/marshal.go +++ /dev/null @@ -1,338 +0,0 @@ -package ion - -import ( - "bytes" - "fmt" - "io" - "math/big" - "reflect" - "sort" - "time" -) - -// EncoderOpts holds bit-flag options for an Encoder. -type EncoderOpts uint - -const ( - // EncodeSortMaps instructs the encoder to write map keys in sorted order. - EncodeSortMaps EncoderOpts = 1 -) - -// MarshalText marshals values to text ion. -func MarshalText(v interface{}) ([]byte, error) { - buf := bytes.Buffer{} - w := NewTextWriterOpts(&buf, TextWriterQuietFinish) - e := Encoder{ - w: w, - opts: EncodeSortMaps, - } - - if err := e.Encode(v); err != nil { - return nil, err - } - if err := e.Finish(); err != nil { - return nil, err - } - - return buf.Bytes(), nil -} - -// MarshalBinary marshals values to binary ion. -func MarshalBinary(v interface{}, ssts ...SharedSymbolTable) ([]byte, error) { - buf := bytes.Buffer{} - w := NewBinaryWriter(&buf, ssts...) - e := Encoder{w: w} - - if err := e.Encode(v); err != nil { - return nil, err - } - if err := e.Finish(); err != nil { - return nil, err - } - - return buf.Bytes(), nil -} - -// MarshalBinaryLST marshals values to binary ion with a fixed local symbol table. -func MarshalBinaryLST(v interface{}, lst SymbolTable) ([]byte, error) { - buf := bytes.Buffer{} - w := NewBinaryWriterLST(&buf, lst) - e := Encoder{w: w} - - if err := e.Encode(v); err != nil { - return nil, err - } - if err := e.Finish(); err != nil { - return nil, err - } - - return buf.Bytes(), nil -} - -// MarshalTo marshals the given value to the given writer. It does -// not call Finish, so is suitable for encoding values inside of -// a partially-constructed Ion value. -func MarshalTo(w Writer, v interface{}) error { - e := Encoder{ - w: w, - } - return e.Encode(v) -} - -// An Encoder writes Ion values to an output stream. -type Encoder struct { - w Writer - opts EncoderOpts -} - -// NewEncoder creates a new encoder. -func NewEncoder(w Writer) *Encoder { - return NewEncoderOpts(w, 0) -} - -// NewEncoderOpts creates a new encoder with the specified options. -func NewEncoderOpts(w Writer, opts EncoderOpts) *Encoder { - return &Encoder{ - w: w, - opts: opts, - } -} - -// NewTextEncoder creates a new text Encoder. -func NewTextEncoder(w io.Writer) *Encoder { - return NewEncoder(NewTextWriter(w)) -} - -// NewBinaryEncoder creates a new binary Encoder. -func NewBinaryEncoder(w io.Writer, ssts ...SharedSymbolTable) *Encoder { - return NewEncoder(NewBinaryWriter(w, ssts...)) -} - -// NewBinaryEncoderLST creates a new binary Encoder with a fixed local symbol table. -func NewBinaryEncoderLST(w io.Writer, lst SymbolTable) *Encoder { - return NewEncoder(NewBinaryWriterLST(w, lst)) -} - -// Encode marshals the given value to Ion, writing it to the underlying writer. -func (m *Encoder) Encode(v interface{}) error { - return m.encodeValue(reflect.ValueOf(v)) -} - -// Finish finishes writing the current Ion datagram. -func (m *Encoder) Finish() error { - return m.w.Finish() -} - -// EncodeValue recursively encodes a value. -func (m *Encoder) encodeValue(v reflect.Value) error { - if !v.IsValid() { - m.w.WriteNull() - return nil - } - - t := v.Type() - switch t.Kind() { - case reflect.Bool: - return m.w.WriteBool(v.Bool()) - - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - return m.w.WriteInt(v.Int()) - - case reflect.Uint8, reflect.Uint16, reflect.Uint32: - return m.w.WriteInt(int64(v.Uint())) - - case reflect.Uint, reflect.Uint64, reflect.Uintptr: - i := big.Int{} - i.SetUint64(v.Uint()) - return m.w.WriteBigInt(&i) - - case reflect.Float32, reflect.Float64: - return m.w.WriteFloat(v.Float()) - - case reflect.String: - return m.w.WriteString(v.String()) - - case reflect.Interface, reflect.Ptr: - return m.encodePtr(v) - - case reflect.Struct: - return m.encodeStruct(v) - - case reflect.Map: - return m.encodeMap(v) - - case reflect.Slice: - return m.encodeSlice(v) - - case reflect.Array: - return m.encodeArray(v) - - default: - return fmt.Errorf("ion: unsupported type: %v", v.Type().String()) - } -} - -// EncodePtr encodes an Ion null if the pointer is nil, and otherwise encodes the value that -// the pointer is pointing to. -func (m *Encoder) encodePtr(v reflect.Value) error { - if v.IsNil() { - return m.w.WriteNull() - } - return m.encodeValue(v.Elem()) -} - -// EncodeMap encodes a map to the output writer as an Ion struct. -func (m *Encoder) encodeMap(v reflect.Value) error { - if v.IsNil() { - return m.w.WriteNull() - } - - m.w.BeginStruct() - - keys := keysFor(v) - if m.opts&EncodeSortMaps != 0 { - sort.Slice(keys, func(i, j int) bool { return keys[i].s < keys[j].s }) - } - - for _, key := range keys { - m.w.FieldName(key.s) - value := v.MapIndex(key.v) - if err := m.encodeValue(value); err != nil { - return err - } - } - - return m.w.EndStruct() -} - -// A mapkey holds the reflective map key value as well as its stringified form. -type mapkey struct { - v reflect.Value - s string -} - -// KeysFor returns the stringified keys for the given map. -func keysFor(v reflect.Value) []mapkey { - keys := v.MapKeys() - res := make([]mapkey, len(keys)) - - for i, key := range keys { - // TODO: Handle other kinds of keys. - if key.Kind() != reflect.String { - panic("unexpected map key type") - } - res[i] = mapkey{ - v: key, - s: key.String(), - } - } - - return res -} - -// EncodeSlice encodes a slice to the output writer as an appropriate Ion type. -func (m *Encoder) encodeSlice(v reflect.Value) error { - if v.Type().Elem().Kind() == reflect.Uint8 { - return m.encodeBlob(v) - } - - if v.IsNil() { - return m.w.WriteNull() - } - - return m.encodeArray(v) -} - -// EncodeBlob encodes a []byte to the output writer as an Ion blob. -func (m *Encoder) encodeBlob(v reflect.Value) error { - if v.IsNil() { - return m.w.WriteNull() - } - return m.w.WriteBlob(v.Bytes()) -} - -// EncodeArray encodes an array to the output writer as an Ion list. -func (m *Encoder) encodeArray(v reflect.Value) error { - m.w.BeginList() - - for i := 0; i < v.Len(); i++ { - if err := m.encodeValue(v.Index(i)); err != nil { - return err - } - } - - return m.w.EndList() -} - -// EncodeStruct encodes a struct to the output writer as an Ion struct. -func (m *Encoder) encodeStruct(v reflect.Value) error { - t := v.Type() - if t == timeType { - return m.encodeTime(v) - } - if t == decimalType { - return m.encodeDecimal(v) - } - - fields := fieldsFor(v.Type()) - - m.w.BeginStruct() - -FieldLoop: - for i := range fields { - f := &fields[i] - - fv := v - for _, i := range f.path { - if fv.Kind() == reflect.Ptr { - if fv.IsNil() { - continue FieldLoop - } - fv = fv.Elem() - } - fv = fv.Field(i) - } - - if f.omitEmpty && emptyValue(fv) { - continue - } - - m.w.FieldName(f.name) - if err := m.encodeValue(fv); err != nil { - return err - } - } - - return m.w.EndStruct() -} - -// EncodeTime encodes a time.Time to the output writer as an Ion timestamp. -func (m *Encoder) encodeTime(v reflect.Value) error { - t := v.Interface().(time.Time) - return m.w.WriteTimestamp(t) -} - -// EncodeDecimal encodes an ion.Decimal to the output writer as an Ion decimal. -func (m *Encoder) encodeDecimal(v reflect.Value) error { - d := v.Addr().Interface().(*Decimal) - return m.w.WriteDecimal(d) -} - -// EmptyValue returns true if the given value is the empty value for its type. -func emptyValue(v reflect.Value) bool { - switch v.Kind() { - case reflect.Array, reflect.Map, reflect.Slice, reflect.String: - return v.Len() == 0 - case reflect.Bool: - return !v.Bool() - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - return v.Int() == 0 - case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: - return v.Uint() == 0 - case reflect.Float32, reflect.Float64: - return v.Float() == 0 - case reflect.Interface, reflect.Ptr: - return v.IsNil() - } - return false -} diff --git a/marshal_test.go b/marshal_test.go deleted file mode 100644 index 72e8dc7a..00000000 --- a/marshal_test.go +++ /dev/null @@ -1,162 +0,0 @@ -package ion - -import ( - "bytes" - "math" - "testing" - "time" -) - -func TestMarshalText(t *testing.T) { - test := func(v interface{}, eval string) { - t.Run(eval, func(t *testing.T) { - val, err := MarshalText(v) - if err != nil { - t.Fatal(err) - } - if string(val) != eval { - t.Errorf("expected '%v', got '%v'", eval, string(val)) - } - }) - } - - test(nil, "null") - test(true, "true") - test(false, "false") - - test(byte(42), "42") - test(-42, "-42") - test(uint64(math.MaxUint64), "18446744073709551615") - test(math.MinInt64, "-9223372036854775808") - - test(42.0, "4.2e+1") - test(math.Inf(1), "+inf") - test(math.Inf(-1), "-inf") - test(math.NaN(), "nan") - - test(MustParseDecimal("1.20"), "1.20") - test(time.Date(2010, 1, 1, 0, 0, 0, 0, time.UTC), "2010-01-01T00:00:00Z") - - test("hello\tworld", "\"hello\\tworld\"") - - test(struct{ A, B int }{42, 0}, "{A:42,B:0}") - test(struct { - A int `json:"val,ignoreme"` - B int `json:"-"` - C int `json:",omitempty"` - d int - }{42, 0, 0, 0}, "{val:42}") - - test(struct{ V interface{} }{}, "{V:null}") - test(struct{ V interface{} }{"42"}, "{V:\"42\"}") - - fourtytwo := 42 - - test(struct{ V *int }{}, "{V:null}") - test(struct{ V *int }{&fourtytwo}, "{V:42}") - - test(map[string]int{"b": 2, "a": 1}, "{a:1,b:2}") - - test(struct{ V []int }{}, "{V:null}") - test(struct{ V []int }{[]int{4, 2}}, "{V:[4,2]}") - - test(struct{ V []byte }{}, "{V:null}") - test(struct{ V []byte }{[]byte{4, 2}}, "{V:{{BAI=}}}") - - test(struct{ V [2]byte }{[2]byte{4, 2}}, "{V:[4,2]}") -} - -func TestMarshalBinary(t *testing.T) { - test := func(v interface{}, name string, eval []byte) { - t.Run(name, func(t *testing.T) { - val, err := MarshalBinary(v) - if err != nil { - t.Fatal(err) - } - if !bytes.Equal(val, eval) { - t.Errorf("expected '%v', got '%v'", fmtbytes(eval), fmtbytes(val)) - } - }) - } - - test(nil, "null", []byte{0xE0, 0x01, 0x00, 0xEA, 0x0F}) - test(struct{ A, B int }{42, 0}, "{A:42,B:0}", []byte{ - 0xE0, 0x01, 0x00, 0xEA, - 0xE9, 0x81, 0x83, 0xD6, 0x87, 0xB4, 0x81, 'A', 0x81, 'B', - 0xD5, - 0x8A, 0x21, 0x2A, - 0x8B, 0x20, - }) -} - -func TestMarshalBinaryLST(t *testing.T) { - lsta := NewLocalSymbolTable(nil, nil) - lstb := NewLocalSymbolTable(nil, []string{ - "A", "B", - }) - - test := func(v interface{}, name string, lst SymbolTable, eval []byte) { - t.Run(name, func(t *testing.T) { - val, err := MarshalBinaryLST(v, lst) - if err != nil { - t.Fatal(err) - } - if !bytes.Equal(val, eval) { - t.Errorf("expected '%v', got '%v'", fmtbytes(eval), fmtbytes(val)) - } - }) - } - - test(nil, "null", lsta, []byte{0xE0, 0x01, 0x00, 0xEA, 0x0F}) - test(struct{ A, B int }{42, 0}, "{A:42,B:0}", lstb, []byte{ - 0xE0, 0x01, 0x00, 0xEA, - 0xE9, 0x81, 0x83, 0xD6, 0x87, 0xB4, 0x81, 'A', 0x81, 'B', - 0xD5, - 0x8A, 0x21, 0x2A, - 0x8B, 0x20, - }) -} - -func TestMarshalNestedStructs(t *testing.T) { - type gp struct { - A int `json:"a"` - } - - type gp2 struct { - B int `json:"b"` - } - - type parent struct { - gp - *gp2 - C int `json:"c"` - } - - type root struct { - parent - D int `json:"d"` - } - - v := root{ - parent: parent{ - gp: gp{ - A: 1, - }, - gp2: &gp2{ - B: 2, - }, - C: 3, - }, - D: 4, - } - - val, err := MarshalText(v) - if err != nil { - t.Fatal(err) - } - - eval := "{a:1,b:2,c:3,d:4}" - if string(val) != eval { - t.Errorf("expected %v, got %v", eval, string(val)) - } -} diff --git a/reader.go b/reader.go deleted file mode 100644 index 9f16c5af..00000000 --- a/reader.go +++ /dev/null @@ -1,388 +0,0 @@ -package ion - -import ( - "bufio" - "bytes" - "io" - "math" - "math/big" - "strings" - "time" -) - -// A Reader reads a stream of Ion values. -// -// The Reader has a logical position within the stream of values, influencing the -// values returnedd from its methods. Initially, the Reader is positioned before the -// first value in the stream. A call to Next advances the Reader to the first value -// in the stream, with subsequent calls advancing to subsequent values. When a call to -// Next moves the Reader to the position after the final value in the stream, it returns -// false, making it easy to loop through the values in a stream. -// -// var r Reader -// for r.Next() { -// // ... -// } -// -// Next also returns false in case of error. This can be distinguished from a legitimate -// end-of-stream by calling Err after exiting the loop. -// -// When positioned on an Ion value, the type of the value can be retrieved by calling -// Type. If it has an associated field name (inside a struct) or annotations, they can -// be read by calling FieldName and Annotations respectively. -// -// For atomic values, an appropriate XxxValue method can be called to read the value. -// For lists, sexps, and structs, you should instead call StepIn to move the Reader in -// to the contained sequence of values. The Reader will initially be positioned before -// the first value in the container. Calling Next without calling StepIn will skip over -// the composite value and return the next value in the outer value stream. -// -// At any point while reading through a composite value, including when Next returns false -// to indicate the end of the contained values, you may call StepOut to move back to the -// outer sequence of values. The Reader will be positioned at the end of the composite value, -// such that a call to Next will move to the immediately-following value (if any). -// -// r := NewTextReaderStr("[foo, bar] [baz]") -// for r.Next() { -// if err := r.StepIn(); err != nil { -// return err -// } -// for r.Next() { -// fmt.Println(r.StringValue()) -// } -// if err := r.StepOut(); err != nil { -// return err -// } -// } -// if err := r.Err(); err != nil { -// return err -// } -// -type Reader interface { - - // SymbolTable returns the current symbol table, or nil if there isn't one. - // Text Readers do not, generally speaking, have an associated symbol table. - // Binary Readers do. - SymbolTable() SymbolTable - - // Next advances the Reader to the next position in the current value stream. - // It returns true if this is the position of an Ion value, and false if it - // is not. On error, it returns false and sets Err. - Next() bool - - // Err returns an error if a previous call call to Next has failed. - Err() error - - // Type returns the type of the Ion value the Reader is currently positioned on. - // It returns NoType if the Reader is positioned before or after a value. - Type() Type - - // IsNull returns true if the current value is an explicit null. This may be true - // even if the Type is not NullType (for example, null.struct has type Struct). Yes, - // that's a bit confusing. - IsNull() bool - - // FieldName returns the field name associated with the current value. It returns - // the empty string if there is no current value or the current value has no field - // name. - FieldName() string - - // Annotations returns the set of annotations associated with the current value. - // It returns nil if there is no current value or the current value has no annotations. - Annotations() []string - - // StepIn steps in to the current value if it is a container. It returns an error if there - // is no current value or if the value is not a container. On success, the Reader is - // positioned before the first value in the container. - StepIn() error - - // StepOut steps out of the current container value being read. It returns an error if - // this Reader is not currently stepped in to a container. On success, the Reader is - // positioned after the end of the container, but before any subsequent values in the - // stream. - StepOut() error - - // BoolValue returns the current value as a boolean (if that makes sense). It returns - // an error if the current value is not an Ion bool. - BoolValue() (bool, error) - - // IntSize returns the size of integer needed to losslessly represent the current value - // (if that makes sense). It returns an error if the current value is not an Ion int. - IntSize() (IntSize, error) - - // IntValue returns the current value as a 32-bit integer (if that makes sense). It - // returns an error if the current value is not an Ion integer or requires more than - // 32 bits to represent losslessly. - IntValue() (int, error) - - // Int64Value returns the current value as a 64-bit integer (if that makes sense). It - // returns an error if the current value is not an Ion integer or requires more than - // 64 bits to represent losslessly. - Int64Value() (int64, error) - - // Uint64Value returns the current value as an unsigned 64-bit integer (if that makes - // sense). It returns an error if the current value is not an Ion integer, is negative, - // or requires more than 64 bits to represent losslessly. - Uint64Value() (uint64, error) - - // BigIntValue returns the current value as a big.Integer (if that makes sense). It - // returns an error if the current value is not an Ion integer. - BigIntValue() (*big.Int, error) - - // FloatValue returns the current value as a 64-bit floating point number (if that - // makes sense). It returns an error if the current value is not an Ion float. - FloatValue() (float64, error) - - // DecimalValue returns the current value as an arbitrary-precision Decimal (if that - // makes sense). It returns an error if the current value is not an Ion decimal. - DecimalValue() (*Decimal, error) - - // TimeValue returns the current value as a timestamp (if that makes sense). It returns - // an error if the current value is not an Ion timestamp. - TimeValue() (time.Time, error) - - // StringValue returns the current value as a string (if that makes sense). It returns - // an error if the current value is not an Ion symbol or an Ion string. - StringValue() (string, error) - - // ByteValue returns the current value as a byte slice (if that makes sense). It returns - // an error if the current value is not an Ion clob or an Ion blob. - ByteValue() ([]byte, error) -} - -// NewReader creates a new Ion reader of the appropriate type by peeking -// at the first several bytes of input for a binary version marker. -func NewReader(in io.Reader) Reader { - return NewReaderCat(in, nil) -} - -// NewReaderStr creates a new reader from a string. -func NewReaderStr(str string) Reader { - return NewReader(strings.NewReader(str)) -} - -// NewReaderBytes creates a new reader for the given bytes. -func NewReaderBytes(in []byte) Reader { - return NewReader(bytes.NewReader(in)) -} - -// NewReaderCat creates a new reader with the given catalog. -func NewReaderCat(in io.Reader, cat Catalog) Reader { - br := bufio.NewReader(in) - - bs, err := br.Peek(4) - if err == nil && bs[0] == 0xE0 && bs[3] == 0xEA { - return newBinaryReaderBuf(br, cat) - } - - return newTextReaderBuf(br) -} - -// A reader holds common implementation stuff to both the text and binary readers. -type reader struct { - ctx ctxstack - eof bool - err error - - fieldName string - annotations []string - valueType Type - value interface{} -} - -// Err returns the current error. -func (r *reader) Err() error { - return r.err -} - -// Type returns the current value's type. -func (r *reader) Type() Type { - return r.valueType -} - -// IsNull returns true if the current value is null. -func (r *reader) IsNull() bool { - return r.valueType != NoType && r.value == nil -} - -// FieldName returns the current value's field name. -func (r *reader) FieldName() string { - return r.fieldName -} - -// Annotations returns the current value's annotations. -func (r *reader) Annotations() []string { - return r.annotations -} - -// BoolValue returns the current value as a bool. -func (r *reader) BoolValue() (bool, error) { - if r.valueType != BoolType { - return false, &UsageError{"Reader.BoolValue", "value is not a bool"} - } - if r.value == nil { - return false, nil - } - return r.value.(bool), nil -} - -// IntSize returns the size of the current int value. -func (r *reader) IntSize() (IntSize, error) { - if r.valueType != IntType { - return NullInt, &UsageError{"Reader.IntSize", "value is not a int"} - } - if r.value == nil { - return NullInt, nil - } - - if i, ok := r.value.(int64); ok { - if i > math.MaxInt32 || i < math.MinInt32 { - return Int64, nil - } - return Int32, nil - } - - i := r.value.(*big.Int) - if i.IsUint64() { - return Uint64, nil - } - - return BigInt, nil -} - -// IntValue returns the current value as an int. -func (r *reader) IntValue() (int, error) { - i, err := r.Int64Value() - if err != nil { - return 0, err - } - if i > math.MaxInt32 || i < math.MinInt32 { - return 0, &UsageError{"Reader.IntValue", "value too large for an int32"} - } - return int(i), nil -} - -// Int64Value returns the current value as an int64. -func (r *reader) Int64Value() (int64, error) { - if r.valueType != IntType { - return 0, &UsageError{"Reader.Int64Value", "value is not an int"} - } - if r.value == nil { - return 0, nil - } - - if i, ok := r.value.(int64); ok { - return i, nil - } - - bi := r.value.(*big.Int) - if bi.IsInt64() { - return bi.Int64(), nil - } - - return 0, &UsageError{"Reader.Int64Value", "value too large for an int64"} -} - -// Uint64Value returns the current value as a uint64. -func (r *reader) Uint64Value() (uint64, error) { - if r.valueType != IntType { - return 0, &UsageError{"Reader.Uint64Value", "value is not an int"} - } - if r.value == nil { - return 0, nil - } - - if i, ok := r.value.(int64); ok { - if i >= 0 { - return uint64(i), nil - } - return 0, &UsageError{"Reader.Uint64Value", "value is negative"} - } - - bi := r.value.(*big.Int) - if bi.Sign() < 0 { - return 0, &UsageError{"Reader.Uint64Value", "value is negative"} - } - if !bi.IsUint64() { - return 0, &UsageError{"Reader.Uint64Value", "value too large for a uint64"} - } - return bi.Uint64(), nil -} - -// BigIntValue returns the current value as a big int. -func (r *reader) BigIntValue() (*big.Int, error) { - if r.valueType != IntType { - return nil, &UsageError{"Reader.BigIntValue", "value is not an int"} - } - if r.value == nil { - return nil, nil - } - - if i, ok := r.value.(int64); ok { - return big.NewInt(i), nil - } - return r.value.(*big.Int), nil -} - -// FloatValue returns the current value as a float. -func (r *reader) FloatValue() (float64, error) { - if r.valueType != FloatType { - return 0, &UsageError{"Reader.FloatValue", "value is not a float"} - } - if r.value == nil { - return 0.0, nil - } - return r.value.(float64), nil -} - -// DecimalValue returns the current value as a Decimal. -func (r *reader) DecimalValue() (*Decimal, error) { - if r.valueType != DecimalType { - return nil, &UsageError{"Reader.DecimalValue", "value is not a decimal"} - } - if r.value == nil { - return nil, nil - } - return r.value.(*Decimal), nil -} - -// TimeValue returns the current value as a time. -func (r *reader) TimeValue() (time.Time, error) { - if r.valueType != TimestampType { - return time.Time{}, &UsageError{"Reader.TimestampValue", "value is not a timestamp"} - } - if r.value == nil { - return time.Time{}, nil - } - return r.value.(time.Time), nil -} - -// StringValue returns the current value as a string. -func (r *reader) StringValue() (string, error) { - if r.valueType != StringType && r.valueType != SymbolType { - return "", &UsageError{"Reader.StringValue", "value is not a string"} - } - if r.value == nil { - return "", nil - } - return r.value.(string), nil -} - -// ByteValue returns the current value as a byte slice. -func (r *reader) ByteValue() ([]byte, error) { - if r.valueType != BlobType && r.valueType != ClobType { - return nil, &UsageError{"Reader.ByteValue", "value is not a lob"} - } - if r.value == nil { - return nil, nil - } - return r.value.([]byte), nil -} - -// Clear clears the current value from the reader. -func (r *reader) clear() { - r.fieldName = "" - r.annotations = nil - r.valueType = NoType - r.value = nil -} diff --git a/reader_test.go b/reader_test.go deleted file mode 100644 index 2e204de1..00000000 --- a/reader_test.go +++ /dev/null @@ -1,127 +0,0 @@ -package ion - -import ( - "fmt" - "io/ioutil" - "os" - "path/filepath" - "testing" -) - -var blacklist = map[string]bool{ - "ion-tests/iontestdata/good/emptyAnnotatedInt.10n": true, - "ion-tests/iontestdata/good/subfieldVarUInt32bit.ion": true, - "ion-tests/iontestdata/good/utf16.ion": true, - "ion-tests/iontestdata/good/utf32.ion": true, - "ion-tests/iontestdata/good/whitespace.ion": true, - "ion-tests/iontestdata/good/item1.10n": true, -} - -type drainfunc func(t *testing.T, r Reader, f string) - -func TestReadFiles(t *testing.T) { - testReadDir(t, "ion-tests/iontestdata/good", func(t *testing.T, r Reader, f string) { - drain(t, r, 0) - }) -} - -func drain(t *testing.T, r Reader, level int) { - for r.Next() { - // print(level, r.Type()) - - if !r.IsNull() { - switch r.Type() { - case StructType, ListType, SexpType: - if err := r.StepIn(); err != nil { - t.Fatal(err) - } - - drain(t, r, level+1) - - if err := r.StepOut(); err != nil { - t.Fatal(err) - } - } - } - } - - if r.Err() != nil { - t.Fatal(r.Err()) - } -} - -func print(level int, obj interface{}) { - fmt.Print(" > ") - for i := 0; i < level; i++ { - fmt.Print(" ") - } - fmt.Println(obj) -} - -func TestDecodeFiles(t *testing.T) { - testReadDir(t, "ion-tests/iontestdata/good", func(t *testing.T, r Reader, f string) { - // fmt.Println(f) - d := NewDecoder(r) - for { - v, err := d.Decode() - if err == ErrNoInput { - break - } - if err != nil { - t.Fatal(err) - } - // fmt.Println(v) - _ = v - } - }) -} - -var emptyFiles = []string{ - "ion-tests/iontestdata/good/blank.ion", - "ion-tests/iontestdata/good/empty.ion", -} - -func isEmptyFile(f string) bool { - for _, s := range emptyFiles { - if f == s { - return true - } - } - return false -} - -func testReadDir(t *testing.T, path string, d drainfunc) { - files, err := ioutil.ReadDir(path) - if err != nil { - t.Fatal(err) - } - - for _, file := range files { - fp := filepath.Join(path, file.Name()) - if file.IsDir() { - testReadDir(t, fp, d) - } else { - t.Run(fp, func(t *testing.T) { - testReadFile(t, fp, d) - }) - } - } -} - -func testReadFile(t *testing.T, path string, d drainfunc) { - if _, ok := blacklist[path]; ok { - return - } - - // fmt.Println(path) - - file, err := os.Open(path) - if err != nil { - t.Fatal(err) - } - defer file.Close() - - r := NewReader(file) - - d(t, r, path) -} diff --git a/skipper.go b/skipper.go deleted file mode 100644 index cf74ebb7..00000000 --- a/skipper.go +++ /dev/null @@ -1,863 +0,0 @@ -package ion - -import ( - "fmt" - "io" -) - -// SkipContainerContents skips over the contents of a container of the given type. -func (t *tokenizer) SkipContainerContents(typ Type) error { - switch typ { - case StructType: - return t.skipStructHelper() - case ListType: - return t.skipListHelper() - case SexpType: - return t.skipSexpHelper() - default: - panic(fmt.Sprintf("invalid container type: %v", typ)) - } -} - -// Skips whitespace and a double-colon token, if there is one. -func (t *tokenizer) SkipDoubleColon() (bool, bool, error) { - ws, err := t.skipWhitespaceHelper() - if err != nil { - return false, false, err - } - - ok, err := t.skipDoubleColon() - if err != nil { - return false, false, err - } - - return ok, ws, nil -} - -// Peeks ahead to see if the next token is a dot, and -// if so skips it. If not, leaves the next token unconsumed. -func (t *tokenizer) SkipDot() (bool, error) { - c, err := t.peek() - if err != nil { - return false, err - } - if c != '.' { - return false, nil - } - - t.read() - return true, nil -} - -// SkipLobWhitespace skips whitespace when we're inside a large -// object ({{ ///= }} or {{ '''///=''' }}) where comments are -// not allowed. -func (t *tokenizer) SkipLobWhitespace() (int, error) { - c, _, err := t.skipLobWhitespace() - return c, err -} - -// SkipValue skips to the end of the current value, if the caller -// didn't bother to consume it before calling Next again. -func (t *tokenizer) skipValue() (int, error) { - var c int - var err error - - switch t.token { - case tokenNumber: - c, err = t.skipNumber() - case tokenBinary: - c, err = t.skipBinary() - case tokenHex: - c, err = t.skipHex() - case tokenTimestamp: - c, err = t.skipTimestamp() - case tokenSymbol: - c, err = t.skipSymbol() - case tokenSymbolQuoted: - c, err = t.skipSymbolQuoted() - case tokenSymbolOperator: - c, err = t.skipSymbolOperator() - case tokenString: - c, err = t.skipString() - case tokenLongString: - c, err = t.skipLongString() - case tokenOpenDoubleBrace: - c, err = t.skipBlob() - case tokenOpenBrace: - c, err = t.skipStruct() - case tokenOpenParen: - c, err = t.skipSexp() - case tokenOpenBracket: - c, err = t.skipList() - default: - panic(fmt.Sprintf("skipValue called with token=%v", t.token)) - } - - if err != nil { - return 0, err - } - - if isWhitespace(c) { - c, _, err = t.skipWhitespace() - if err != nil { - return 0, err - } - } - - t.unfinished = false - return c, nil -} - -// SkipNumber skips a (non-binary, non-hex) number. -func (t *tokenizer) skipNumber() (int, error) { - c, err := t.read() - if err != nil { - return 0, err - } - - if c == '-' { - c, err = t.read() - if err != nil { - return 0, err - } - } - - c, err = t.skipDigits(c) - if err != nil { - return 0, err - } - - if c == '.' { - c, err = t.read() - if err != nil { - return 0, err - } - c, err = t.skipDigits(c) - if err != nil { - return 0, err - } - } - - if c == 'd' || c == 'D' || c == 'e' || c == 'E' { - c, err = t.read() - if err != nil { - return 0, err - } - if c == '+' || c == '-' { - c, err = t.read() - if err != nil { - return 0, err - } - } - c, err = t.skipDigits(c) - if err != nil { - return 0, err - } - } - - ok, err := t.isStopChar(c) - if err != nil { - return 0, err - } - if !ok { - return 0, t.invalidChar(c) - } - return c, nil -} - -// SkipBinary skips a binary literal value. -func (t *tokenizer) skipBinary() (int, error) { - isB := func(c int) bool { - return c == 'b' || c == 'B' - } - isBinaryDigit := func(c int) bool { - return c == '0' || c == '1' - } - return t.skipRadix(isB, isBinaryDigit) -} - -// SkipHex skips a hex value. -func (t *tokenizer) skipHex() (int, error) { - isX := func(c int) bool { - return c == 'x' || c == 'X' - } - return t.skipRadix(isX, isHexDigit) -} - -func (t *tokenizer) skipRadix(pok, dok matcher) (int, error) { - c, err := t.read() - if err != nil { - return 0, err - } - - if c == '-' { - c, err = t.read() - if err != nil { - return 0, err - } - } - - if c != '0' { - return 0, t.invalidChar(c) - } - if err = t.expect(pok); err != nil { - return 0, err - } - - for { - c, err = t.read() - if err != nil { - return 0, err - } - if !dok(c) { - break - } - } - - ok, err := t.isStopChar(c) - if err != nil { - return 0, err - } - if !ok { - return 0, t.invalidChar(c) - } - - return c, nil -} - -// SkipTimestamp skips a timestamp value, returning the next character. -func (t *tokenizer) skipTimestamp() (int, error) { - // Read the first four digits, yyyy. - c, err := t.skipTimestampDigits(4) - if err != nil { - return 0, err - } - if c == 'T' { - // yyyyT - return t.read() - } - if c != '-' { - return 0, t.invalidChar(c) - } - - // Read the next two, yyyy-mm. - c, err = t.skipTimestampDigits(2) - if err != nil { - return 0, err - } - if c == 'T' { - // yyyy-mmT - return t.read() - } - if c != '-' { - return 0, t.invalidChar(c) - } - - // Read the day. - c, err = t.skipTimestampDigits(2) - if err != nil { - return 0, err - } - if c != 'T' { - // yyyy-mm-dd. - return t.skipTimestampFinish(c) - } - - c, err = t.read() - if err != nil { - return 0, err - } - if !isDigit(c) { - // yyyy-mm-ddT(+hh:mm)? - c, err = t.skipTimestampOffset(c) - if err != nil { - return 0, err - } - return t.skipTimestampFinish(c) - } - - // Already read the first hour digit above. - c, err = t.skipTimestampDigits(1) - if err != nil { - return 0, err - } - if c != ':' { - return 0, t.invalidChar(c) - } - - c, err = t.skipTimestampDigits(2) - if err != nil { - return 0, err - } - if c != ':' { - // yyyy-mm-ddThh:mmZ - c, err = t.skipTimestampOffsetOrZ(c) - if err != nil { - return 0, err - } - return t.skipTimestampFinish(c) - } - - c, err = t.skipTimestampDigits(2) - if err != nil { - return 0, err - } - if c != '.' { - // yyyy-mm-ddThh:mm:ssZ - c, err = t.skipTimestampOffsetOrZ(c) - if err != nil { - return 0, err - } - return t.skipTimestampFinish(c) - } - - // yyyy-mm-ddThh:mm:ss.ssssZ - c, err = t.read() - if err != nil { - return 0, err - } - if isDigit(c) { - c, err = t.skipDigits(c) - if err != nil { - return 0, err - } - } - - c, err = t.skipTimestampOffsetOrZ(c) - if err != nil { - return 0, err - } - return t.skipTimestampFinish(c) -} - -// SkipTimestampOffsetOrZ skips a (required) timestamp offset value or -// letter 'Z' (indicating UTC). -func (t *tokenizer) skipTimestampOffsetOrZ(c int) (int, error) { - if c == '-' || c == '+' { - return t.skipTimestampOffset(c) - } - if c == 'z' || c == 'Z' { - return t.read() - } - return 0, t.invalidChar(c) -} - -// SkipTimestampOffset skips an (optional) +-hh:mm timestamp zone offset -// value. -func (t *tokenizer) skipTimestampOffset(c int) (int, error) { - if c != '-' && c != '+' { - return c, nil - } - - c, err := t.skipTimestampDigits(2) - if err != nil { - return 0, err - } - if c != ':' { - return 0, t.invalidChar(c) - } - return t.skipTimestampDigits(2) -} - -// SkipTimestampDigits skips a bounded sequence of digits inside a -// timestamp. -func (t *tokenizer) skipTimestampDigits(n int) (int, error) { - for n > 0 { - if err := t.expect(func(c int) bool { - return isDigit(c) - }); err != nil { - return 0, err - } - n-- - } - - return t.read() -} - -// SkipTimestampFinish makes sure the character after a timestamp -// value is a valid ending point. If so, it returns it. -func (t *tokenizer) skipTimestampFinish(c int) (int, error) { - ok, err := t.isStopChar(c) - if err != nil { - return 0, err - } - if !ok { - return 0, t.invalidChar(c) - } - return c, nil -} - -// SkipSymbol skips a normal symbol and returns the next character. -func (t *tokenizer) skipSymbol() (int, error) { - c, err := t.read() - if err != nil { - return 0, err - } - - for isIdentifierPart(c) { - c, err = t.read() - if err != nil { - return 0, err - } - } - - return c, nil -} - -// SkipSymbolQuoted skips a quoted symbol and returns the next char. -func (t *tokenizer) skipSymbolQuoted() (int, error) { - if err := t.skipSymbolQuotedHelper(); err != nil { - return 0, err - } - return t.read() -} - -// SkipSymbolQuotedHelper skips a quoted symbol. -func (t *tokenizer) skipSymbolQuotedHelper() error { - for { - c, err := t.read() - if err != nil { - return err - } - - switch c { - case -1, '\n': - return t.invalidChar(c) - - case '\'': - return nil - - case '\\': - if _, err := t.read(); err != nil { - return err - } - } - } -} - -// SkipSymbolOperator skips an operator-style symbol inside an sexp. -func (t *tokenizer) skipSymbolOperator() (int, error) { - c, err := t.read() - if err != nil { - return 0, err - } - - for isOperatorChar(c) { - c, err = t.read() - if err != nil { - return 0, err - } - } - - return c, nil -} - -// SkipString skips over a "-enclosed string, returning the next char. -func (t *tokenizer) skipString() (int, error) { - if err := t.skipStringHelper(); err != nil { - return 0, err - } - return t.read() -} - -// SkipStringHelper skips over a "-enclosed string. -func (t *tokenizer) skipStringHelper() error { - for { - c, err := t.read() - if err != nil { - return err - } - - switch c { - case -1, '\n': - return t.invalidChar(c) - - case '"': - return nil - - case '\\': - if _, err := t.read(); err != nil { - return err - } - } - } -} - -// SkipLongString skips over a '''-enclosed string, returning the next -// character after the closing '''. -func (t *tokenizer) skipLongString() (int, error) { - if err := t.skipLongStringHelper(t.skipCommentsHandler); err != nil { - return 0, err - } - return t.read() -} - -// SkipLongStringHelper skips over a '''-enclosed string. -func (t *tokenizer) skipLongStringHelper(handler commentHandler) error { - for { - c, err := t.read() - if err != nil { - return err - } - - switch c { - case -1: - return t.invalidChar(c) - - case '\'': - ok, err := t.skipEndOfLongString(handler) - if err != nil { - return err - } - if ok { - return nil - } - - case '\\': - if _, err = t.read(); err != nil { - return err - } - } - } -} - -// SkipEndOfLongString is called after reading a ' to determine if we've -// hit the end of the long string.. -func (t *tokenizer) skipEndOfLongString(handler commentHandler) (bool, error) { - // We just read a ', check for two more ''s. - cs, err := t.peekN(2) - if err != nil && err != io.EOF { - return false, err - } - - // If it's not a triple-quote, keep going. - if len(cs) < 2 || cs[0] != '\'' || cs[1] != '\'' { - return false, nil - } - - // Consume the triple-quote. - if err := t.skipN(2); err != nil { - return false, err - } - - // Consume any additional whitespace/comments. - c, _, err := t.skipWhitespaceWith(handler) - if err != nil { - return false, err - } - - // Check if it's another triple-quote; if so, keep going. - if c == '\'' { - ok, err := t.IsTripleQuote() - if err != nil { - return false, err - } - if ok { - return false, nil - } - } - - t.unread(c) - return true, nil -} - -// SkipBlob skips over a blob value, returning the next character. -func (t *tokenizer) skipBlob() (int, error) { - if err := t.skipBlobHelper(); err != nil { - return 0, err - } - return t.read() -} - -// SkipBlobHelper skips over a blob value, stopping after reading the -// final '}'. -func (t *tokenizer) skipBlobHelper() error { - c, _, err := t.skipLobWhitespace() - if err != nil { - return err - } - - // TODO: If this is a clob, could we potentially have an embedded - // '}' here? - for c != '}' { - c, _, err = t.skipLobWhitespace() - if err != nil { - return err - } - if c == -1 { - return t.invalidChar(c) - } - } - - return t.expect(func(c int) bool { - return c == '}' - }) -} - -func (t *tokenizer) skipStruct() (int, error) { - return t.skipContainer('}') -} - -func (t *tokenizer) skipStructHelper() error { - return t.skipContainerHelper('}') -} - -func (t *tokenizer) skipSexp() (int, error) { - return t.skipContainer(')') -} - -func (t *tokenizer) skipSexpHelper() error { - return t.skipContainerHelper(')') -} - -// SkipList skips forward past a list that the caller doesn't care to -// step in to. -func (t *tokenizer) skipList() (int, error) { - return t.skipContainer(']') -} - -func (t *tokenizer) skipListHelper() error { - return t.skipContainerHelper(']') -} - -// SkipContainer skips a container terminated by the given char and -// returns the next character. -func (t *tokenizer) skipContainer(term int) (int, error) { - if err := t.skipContainerHelper(term); err != nil { - return 0, err - } - return t.read() -} - -// SkipContainerHelper skips over a container terminated by the given -// char. -func (t *tokenizer) skipContainerHelper(term int) error { - if term != ']' && term != ')' && term != '}' { - panic("wat") - } - - for { - c, _, err := t.skipWhitespace() - if err != nil { - return err - } - - switch c { - case -1: - return t.invalidChar(c) - - case term: - return nil - - case '"': - if err := t.skipStringHelper(); err != nil { - return err - } - - case '\'': - ok, err := t.IsTripleQuote() - if err != nil { - return err - } - if ok { - if err = t.skipLongStringHelper(t.skipCommentsHandler); err != nil { - return err - } - } else { - if err = t.skipSymbolQuotedHelper(); err != nil { - return err - } - } - - case '(': - if err := t.skipContainerHelper(')'); err != nil { - return err - } - - case '[': - if err := t.skipContainerHelper(']'); err != nil { - return err - } - - case '{': - c, err := t.peek() - if err != nil { - return err - } - - if c == '{' { - if _, err := t.read(); err != nil { - return err - } - if err := t.skipBlobHelper(); err != nil { - return err - } - } else if c == '}' { - if _, err := t.read(); err != nil { - return err - } - } else { - if err := t.skipContainerHelper('}'); err != nil { - return err - } - } - } - } -} - -// SkipDigits skips a sequence of digits starting with the -// given character. -func (t *tokenizer) skipDigits(c int) (int, error) { - var err error - for err == nil && isDigit(c) { - c, err = t.read() - } - return c, err -} - -// SkipWhitespace skips whitespace (and comments) when we're out -// in normal parsing territory. -func (t *tokenizer) skipWhitespace() (int, bool, error) { - return t.skipWhitespaceWith(t.skipCommentsHandler) -} - -// SkipWhitespaceHelper is a 'helper' form of SkipWhitespace that -// unreads the first non-whitespace char instead of returning it. -func (t *tokenizer) skipWhitespaceHelper() (bool, error) { - c, ok, err := t.skipWhitespace() - if err != nil { - return false, err - } - t.unread(c) - return ok, err -} - -// SkipLobWhitespace skips whitespace when we're inside a large -// object ({{ ///= }} or {{ '''///=''' }}) where comments are -// not allowed. -func (t *tokenizer) skipLobWhitespace() (int, bool, error) { - // Comments are not allowed inside a lob value; if we see a '/', - // it's the start of a base64-encoded value. - return t.skipWhitespaceWith(stopForCommentsHandler) -} - -// CommentHandler is a strategy for handling comments. Returns true -// if it found and handled a comment, false if it didn't find a -// comment, and returns an error if it choked on the comment. -type commentHandler func() (bool, error) - -// SkipWhitespaceWith skips whitespace using the given strategy for -// handling comments--generally speaking, either skipping over them -// using skipCommentsHandler, or stopping with a stopForCommentsHandler. -// Returns the first non-whitespace character it reads, and whether it -// actually skipped anything to find it. -func (t *tokenizer) skipWhitespaceWith(handler commentHandler) (int, bool, error) { - skipped := false - for { - c, err := t.read() - if err != nil { - return 0, skipped, err - } - - switch c { - case ' ', '\t', '\n', '\r': - // Skipped. - - case '/': - comment, err := handler() - if err != nil { - return 0, skipped, err - } - if !comment { - return '/', skipped, nil - } - - default: - return c, skipped, nil - } - skipped = true - } -} - -// StopForCommentsHandler is a commentHandler that stops skipping -// whitespace when it finds a (potential) comment. Use it when you -// expect a '/' to be an actual '/', not a comment. -func stopForCommentsHandler() (bool, error) { - return false, nil -} - -// SkipCommentsHandler is a commentHandler that skips over any -// comments it finds. -func (t *tokenizer) skipCommentsHandler() (bool, error) { - // We've just read a '/', which might be the start of a comment. - // Peek ahead to see if it is, and if so skip over it. - c, err := t.peek() - if err != nil { - return false, err - } - - switch c { - case '/': - return true, t.skipSingleLineComment() - case '*': - return true, t.skipBlockComment() - default: - return false, nil - } -} - -// SkipSingleLineComment skips over the body of a single-line comment, -// terminated by the end of the line (or file). -func (t *tokenizer) skipSingleLineComment() error { - for { - c, err := t.read() - if err != nil { - return err - } - - if c == -1 || c == '\n' { - return nil - } - } -} - -// SkipBlockComment skips over the body of a block comment, terminated -// by a '*/' sequence. -func (t *tokenizer) skipBlockComment() error { - star := false - for { - c, err := t.read() - if err != nil { - return err - } - if c == -1 { - return t.invalidChar(c) - } - - if star && c == '/' { - return nil - } - - star = (c == '*') - } -} - -// Peeks ahead to see if the next token is a double colon, and -// if so skips it. If not, leaves the next token unconsumed. -func (t *tokenizer) skipDoubleColon() (bool, error) { - cs, err := t.peekN(2) - if err == io.EOF { - return false, nil - } - if err != nil { - return false, err - } - - if cs[0] == ':' && cs[1] == ':' { - t.skipN(2) - return true, nil - } - - return false, nil -} diff --git a/skipper_test.go b/skipper_test.go deleted file mode 100644 index afb48874..00000000 --- a/skipper_test.go +++ /dev/null @@ -1,178 +0,0 @@ -package ion - -import ( - "testing" -) - -func TestSkipNumber(t *testing.T) { - test, testErr := testSkip(t, (*tokenizer).skipNumber) - - test("", -1) - test("0", -1) - test("-1234567890,", ',') - test("1.2 ", ' ') - test("1d45\n", '\n') - test("1.4e-12//", '/') - - testErr("1.2d3d", "ion: unexpected rune 'd' (offset 5)") -} - -func TestSkipBinary(t *testing.T) { - test, testErr := testSkip(t, (*tokenizer).skipBinary) - - test("0b0", -1) - test("-0b10 ", ' ') - test("0b010101,", ',') - - testErr("0b2", "ion: unexpected rune '2' (offset 2)") -} - -func TestSkipHex(t *testing.T) { - test, testErr := testSkip(t, (*tokenizer).skipHex) - - test("0x0", -1) - test("-0x0F ", ' ') - test("0x1234567890abcdefABCDEF,", ',') - - testErr("0x0G", "ion: unexpected rune 'G' (offset 3)") -} - -func TestSkipTimestamp(t *testing.T) { - test, testErr := testSkip(t, (*tokenizer).skipTimestamp) - - test("2001T", -1) - test("2001-01T,", ',') - test("2001-01-02}", '}') - test("2001-01-02T ", ' ') - test("2001-01-02T+00:00\t", '\t') - test("2001-01-02T-00:00\n", '\n') - test("2001-01-02T03:04+00:00 ", ' ') - test("2001-01-02T03:04-00:00 ", ' ') - test("2001-01-02T03:04Z ", ' ') - test("2001-01-02T03:04z ", ' ') - test("2001-01-02T03:04:05Z ", ' ') - test("2001-01-02T03:04:05+00:00 ", ' ') - test("2001-01-02T03:04:05.666Z ", ' ') - test("2001-01-02T03:04:05.666666z ", ' ') - - testErr("", "ion: unexpected end of input (offset 0)") - testErr("2001", "ion: unexpected end of input (offset 4)") - testErr("2001z", "ion: unexpected rune 'z' (offset 4)") - testErr("20011", "ion: unexpected rune '1' (offset 4)") - testErr("2001-0", "ion: unexpected end of input (offset 6)") - testErr("2001-01", "ion: unexpected end of input (offset 7)") - testErr("2001-01-02Tz", "ion: unexpected rune 'z' (offset 11)") - testErr("2001-01-02T03", "ion: unexpected end of input (offset 13)") - testErr("2001-01-02T03z", "ion: unexpected rune 'z' (offset 13)") - testErr("2001-01-02T03:04x ", "ion: unexpected rune 'x' (offset 16)") - testErr("2001-01-02T03:04:05x ", "ion: unexpected rune 'x' (offset 19)") -} - -func TestSkipSymbol(t *testing.T) { - test, _ := testSkip(t, (*tokenizer).skipSymbol) - - test("f", -1) - test("foo:", ':') - test("foo,", ',') - test("foo ", ' ') - test("foo\n", '\n') - test("foo]", ']') - test("foo}", '}') - test("foo)", ')') - test("foo\\n", '\\') -} - -func TestSkipSymbolQuoted(t *testing.T) { - test, testErr := testSkip(t, (*tokenizer).skipSymbolQuoted) - - test("'", -1) - test("foo',", ',') - test("foo\\'bar':", ':') - test("foo\\\nbar',", ',') - - testErr("foo", "ion: unexpected end of input (offset 3)") - testErr("foo\n", "ion: unexpected rune '\\n' (offset 3)") -} - -func TestSkipSymbolOperator(t *testing.T) { - test, _ := testSkip(t, (*tokenizer).skipSymbolOperator) - - test("+", -1) - test("++", -1) - test("+= ", ' ') - test("%b", 'b') -} - -func TestSkipString(t *testing.T) { - test, testErr := testSkip(t, (*tokenizer).skipString) - - test("\"", -1) - test("\",", ',') - test("foo\\\"bar\"], \"\"", ']') - test("foo\\\nbar\" \t\t\t", ' ') - - testErr("foobar", "ion: unexpected end of input (offset 6)") - testErr("foobar\n", "ion: unexpected rune '\\n' (offset 6)") -} - -func TestSkipLongString(t *testing.T) { - test, _ := testSkip(t, (*tokenizer).skipLongString) - - test("'''", -1) - test("''',", ',') - test("abc''',", ',') - test("abc''' }", '}') - test("abc''' /*more*/ '''def'''\t//more\r\n]", ']') -} - -func TestSkipBlob(t *testing.T) { - test, testErr := testSkip(t, (*tokenizer).skipBlob) - - test("}}", -1) - test("oogboog}},{{}}", ',') - test("'''not encoded'''}}\n", '\n') - - testErr("", "ion: unexpected end of input (offset 1)") - testErr("oogboog", "ion: unexpected end of input (offset 7)") - testErr("oogboog}", "ion: unexpected end of input (offset 8)") - testErr("oog}{boog", "ion: unexpected rune '{' (offset 4)") -} - -func TestSkipList(t *testing.T) { - test, testErr := testSkip(t, (*tokenizer).skipList) - - test("]", -1) - test("[]],", ',') - test("[123, \"]\", ']']] ", ' ') - - testErr("abc, def, ", "ion: unexpected end of input (offset 10)") -} - -type skipFunc func(*tokenizer) (int, error) -type skipTestFunc func(string, int) -type skipTestErrFunc func(string, string) - -func testSkip(t *testing.T, f skipFunc) (skipTestFunc, skipTestErrFunc) { - test := func(str string, ec int) { - t.Run(str, func(t *testing.T) { - tok := tokenizeString(str) - c, err := f(tok) - if err != nil { - t.Fatal(err) - } - if c != ec { - t.Errorf("expected '%c', got '%c'", ec, c) - } - }) - } - testErr := func(str string, e string) { - t.Run(str, func(t *testing.T) { - tok := tokenizeString(str) - _, err := f(tok) - if err == nil || err.Error() != e { - t.Errorf("expected err=%v, got err=%v", e, err) - } - }) - } - return test, testErr -} diff --git a/symboltable.go b/symboltable.go deleted file mode 100644 index 922a1dc3..00000000 --- a/symboltable.go +++ /dev/null @@ -1,475 +0,0 @@ -package ion - -import ( - "strings" -) - -// A SymbolTable maps binary-representation symbol IDs to -// text-representation strings and vice versa. -type SymbolTable interface { - // Imports returns the symbol tables this table imports. - Imports() []SharedSymbolTable - // Symbols returns the symbols this symbol table defines. - Symbols() []string - // MaxID returns the maximum ID this symbol table defines. - MaxID() uint64 - - // FindByName finds the ID of a symbol by its name. - FindByName(symbol string) (uint64, bool) - // FindByID finds the name of a symbol given its ID. - FindByID(id uint64) (string, bool) - // WriteTo serializes the symbol table to an ion.Writer. - WriteTo(w Writer) error - // String returns an ion text representation of the symbol table. - String() string -} - -// A SharedSymbolTable is distributed out-of-band and referenced from -// a local SymbolTable to save space. -type SharedSymbolTable interface { - SymbolTable - - // Name returns the name of this shared symbol table. - Name() string - // Version returns the version of this shared symbol table. - Version() int - // Adjust returns a new shared symbol table limited or extended to the given max ID. - Adjust(maxID uint64) SharedSymbolTable -} - -type sst struct { - name string - version int - symbols []string - index map[string]uint64 - maxID uint64 -} - -// NewSharedSymbolTable creates a new shared symbol table. -func NewSharedSymbolTable(name string, version int, symbols []string) SharedSymbolTable { - syms := make([]string, len(symbols)) - copy(syms, symbols) - - index := buildIndex(syms, 1) - - return &sst{ - name: name, - version: version, - symbols: syms, - index: index, - maxID: uint64(len(syms)), - } -} - -func (s *sst) Name() string { - return s.name -} - -func (s *sst) Version() int { - return s.version -} - -func (s *sst) Imports() []SharedSymbolTable { - return nil -} - -func (s *sst) Symbols() []string { - syms := make([]string, s.maxID) - copy(syms, s.symbols) - return syms -} - -func (s *sst) MaxID() uint64 { - return uint64(s.maxID) -} - -func (s *sst) Adjust(maxID uint64) SharedSymbolTable { - if maxID == s.maxID { - // Nothing needs to change. - return s - } - - if maxID > uint64(len(s.symbols)) { - // Old index will work fine, just adjust the maxID. - return &sst{ - name: s.name, - version: s.version, - symbols: s.symbols, - index: s.index, - maxID: maxID, - } - } - - // Slice the symbols down to size and reindex. - symbols := s.symbols[:maxID] - index := buildIndex(symbols, 1) - - return &sst{ - name: s.name, - version: s.version, - symbols: symbols, - index: index, - maxID: maxID, - } -} - -func (s *sst) FindByName(sym string) (uint64, bool) { - id, ok := s.index[sym] - return uint64(id), ok -} - -func (s *sst) FindByID(id uint64) (string, bool) { - if id <= 0 || id > uint64(len(s.symbols)) { - return "", false - } - return s.symbols[id-1], true -} - -func (s *sst) WriteTo(w Writer) error { - w.Annotation("$ion_shared_symbol_table") - w.BeginStruct() - { - w.FieldName("name") - w.WriteString(s.name) - - w.FieldName("version") - w.WriteInt(int64(s.version)) - - w.FieldName("symbols") - w.BeginList() - { - for _, sym := range s.symbols { - w.WriteString(sym) - } - } - w.EndList() - } - return w.EndStruct() -} - -func (s *sst) String() string { - buf := strings.Builder{} - - w := NewTextWriter(&buf) - s.WriteTo(w) - - return buf.String() -} - -// V1SystemSymbolTable is the (implied) system symbol table for Ion v1.0. -var V1SystemSymbolTable = NewSharedSymbolTable("$ion", 1, []string{ - "$ion", - "$ion_1_0", - "$ion_symbol_table", - "name", - "version", - "imports", - "symbols", - "max_id", - "$ion_shared_symbol_table", -}) - -// A BogusSST represents an SST imported by an LST that cannot be found in the -// local catalog. It exists to reserve some part of the symbol ID space so other -// symbol tables get mapped to the right IDs. -type bogusSST struct { - name string - version int - maxID uint64 -} - -var _ SharedSymbolTable = &bogusSST{} - -func (s *bogusSST) Name() string { - return s.name -} - -func (s *bogusSST) Version() int { - return s.version -} - -func (s *bogusSST) Imports() []SharedSymbolTable { - return nil -} - -func (s *bogusSST) Symbols() []string { - return nil -} - -func (s *bogusSST) MaxID() uint64 { - return s.maxID -} - -func (s *bogusSST) Adjust(maxID uint64) SharedSymbolTable { - return &bogusSST{ - name: s.name, - version: s.version, - maxID: maxID, - } -} - -func (s *bogusSST) FindByName(sym string) (uint64, bool) { - return 0, false -} - -func (s *bogusSST) FindByID(id uint64) (string, bool) { - return "", false -} - -func (s *bogusSST) WriteTo(w Writer) error { - return &UsageError{"SharedSymbolTable.WriteTo", "bogus symbol table should never be written"} -} - -func (s *bogusSST) String() string { - buf := strings.Builder{} - w := NewTextWriter(&buf) - w.Annotations("$ion_shared_symbol_table", "bogus") - w.BeginStruct() - - w.FieldName("name") - w.WriteString(s.name) - - w.FieldName("version") - w.WriteInt(int64(s.version)) - - w.FieldName("max_id") - w.WriteUint(s.maxID) - - w.EndStruct() - return buf.String() -} - -// A LocalSymbolTable is transmitted in-band along with the binary data -// it describes. It may include SharedSymbolTables by reference. -type lst struct { - imports []SharedSymbolTable - offsets []uint64 - maxImportID uint64 - - symbols []string - index map[string]uint64 -} - -// NewLocalSymbolTable creates a new local symbol table. -func NewLocalSymbolTable(imports []SharedSymbolTable, symbols []string) SymbolTable { - imps, offsets, maxID := processImports(imports) - syms := make([]string, len(symbols)) - copy(syms, symbols) - - index := buildIndex(syms, maxID+1) - - return &lst{ - imports: imps, - offsets: offsets, - maxImportID: maxID, - symbols: syms, - index: index, - } -} - -func (t *lst) Imports() []SharedSymbolTable { - imps := make([]SharedSymbolTable, len(t.imports)) - copy(imps, t.imports) - return imps -} - -func (t *lst) Symbols() []string { - syms := make([]string, len(t.symbols)) - copy(syms, t.symbols) - return syms -} - -func (t *lst) MaxID() uint64 { - return t.maxImportID + uint64(len(t.symbols)) -} - -func (t *lst) FindByName(s string) (uint64, bool) { - for i, imp := range t.imports { - if id, ok := imp.FindByName(s); ok { - return t.offsets[i] + id, true - } - } - - if id, ok := t.index[s]; ok { - return id, true - } - - return 0, false -} - -func (t *lst) FindByID(id uint64) (string, bool) { - if id <= 0 { - return "", false - } - if id <= t.maxImportID { - return t.findByIDInImports(id) - } - - // Local to this symbol table. - idx := id - t.maxImportID - 1 - if idx < uint64(len(t.symbols)) { - return t.symbols[idx], true - } - - return "", false -} - -func (t *lst) findByIDInImports(id uint64) (string, bool) { - i := 1 - off := uint64(0) - - for ; i < len(t.imports); i++ { - if id <= t.offsets[i] { - break - } - off = t.offsets[i] - } - - return t.imports[i-1].FindByID(id - off) -} - -func (t *lst) WriteTo(w Writer) error { - if len(t.imports) == 1 && len(t.symbols) == 0 { - return nil - } - - w.Annotation("$ion_symbol_table") - w.BeginStruct() - - if len(t.imports) > 1 { - w.FieldName("imports") - w.BeginList() - for i := 1; i < len(t.imports); i++ { - imp := t.imports[i] - w.BeginStruct() - - w.FieldName("name") - w.WriteString(imp.Name()) - - w.FieldName("version") - w.WriteInt(int64(imp.Version())) - - w.FieldName("max_id") - w.WriteUint(imp.MaxID()) - - w.EndStruct() - } - w.EndList() - } - - if len(t.symbols) > 0 { - w.FieldName("symbols") - - w.BeginList() - for _, sym := range t.symbols { - w.WriteString(sym) - } - w.EndList() - } - - return w.EndStruct() -} - -func (t *lst) String() string { - buf := strings.Builder{} - - w := NewTextWriter(&buf) - t.WriteTo(w) - - return buf.String() -} - -// A SymbolTableBuilder helps you iteratively build a local symbol table. -type SymbolTableBuilder interface { - SymbolTable - - // Add adds a symbol to this symbol table. - Add(symbol string) (uint64, bool) - // Build creates an immutable local symbol table. - Build() SymbolTable -} - -type symbolTableBuilder struct { - lst -} - -// NewSymbolTableBuilder creates a new symbol table builder with the given imports. -func NewSymbolTableBuilder(imports ...SharedSymbolTable) SymbolTableBuilder { - imps, offsets, maxID := processImports(imports) - return &symbolTableBuilder{ - lst{ - imports: imps, - offsets: offsets, - maxImportID: maxID, - index: make(map[string]uint64), - }, - } -} - -func (b *symbolTableBuilder) Add(symbol string) (uint64, bool) { - if id, ok := b.FindByName(symbol); ok { - return id, false - } - - b.symbols = append(b.symbols, symbol) - id := b.maxImportID + uint64(len(b.symbols)) - b.index[symbol] = id - - return id, true -} - -func (b *symbolTableBuilder) Build() SymbolTable { - symbols := append([]string{}, b.symbols...) - index := make(map[string]uint64) - for s, i := range b.index { - index[s] = uint64(i) - } - - return &lst{ - imports: b.imports, - offsets: b.offsets, - maxImportID: b.maxImportID, - symbols: symbols, - index: index, - } -} - -// ProcessImports processes a slice of imports, returning an (augmented) copy, a set of -// offsets for each import, and the overall max ID. -func processImports(imports []SharedSymbolTable) ([]SharedSymbolTable, []uint64, uint64) { - // Add in V1SystemSymbolTable at the head of the list if it's not already included. - var imps []SharedSymbolTable - if len(imports) > 0 && imports[0].Name() == "$ion" { - imps = make([]SharedSymbolTable, len(imports)) - copy(imps, imports) - } else { - imps = make([]SharedSymbolTable, len(imports)+1) - imps[0] = V1SystemSymbolTable - copy(imps[1:], imports) - } - - // Calculate offsets. - maxID := uint64(0) - offsets := make([]uint64, len(imps)) - for i, imp := range imps { - offsets[i] = maxID - maxID += imp.MaxID() - } - - return imps, offsets, maxID -} - -// BuildIndex builds an index from symbol name to symbol ID. -func buildIndex(symbols []string, offset uint64) map[string]uint64 { - index := make(map[string]uint64) - - for i, sym := range symbols { - if sym != "" { - if _, ok := index[sym]; !ok { - index[sym] = offset + uint64(i) - } - } - } - - return index -} diff --git a/symboltable_test.go b/symboltable_test.go deleted file mode 100644 index 4fe1aaed..00000000 --- a/symboltable_test.go +++ /dev/null @@ -1,181 +0,0 @@ -package ion - -import ( - "fmt" - "testing" -) - -func TestSharedSymbolTable(t *testing.T) { - st := NewSharedSymbolTable("test", 2, []string{ - "abc", - "def", - "foo'bar", - "null", - "def", - "ghi", - }) - - if st.Name() != "test" { - t.Errorf("wrong name: %v", st.Name()) - } - if st.Version() != 2 { - t.Errorf("wrong version: %v", st.Version()) - } - if st.MaxID() != 6 { - t.Errorf("wrong maxid: %v", st.MaxID()) - } - - testFindByName(t, st, "def", 2) - testFindByName(t, st, "null", 4) - testFindByName(t, st, "bogus", 0) - - testFindByID(t, st, 0, "") - testFindByID(t, st, 2, "def") - testFindByID(t, st, 4, "null") - testFindByID(t, st, 7, "") - - testString(t, st, `$ion_shared_symbol_table::{name:"test",version:2,symbols:["abc","def","foo'bar","null","def","ghi"]}`) -} - -func TestLocalSymbolTable(t *testing.T) { - st := NewLocalSymbolTable(nil, []string{"foo", "bar"}) - - if st.MaxID() != 11 { - t.Errorf("wrong maxid: %v", st.MaxID()) - } - - testFindByName(t, st, "$ion", 1) - testFindByName(t, st, "foo", 10) - testFindByName(t, st, "bar", 11) - testFindByName(t, st, "bogus", 0) - - testFindByID(t, st, 0, "") - testFindByID(t, st, 1, "$ion") - testFindByID(t, st, 10, "foo") - testFindByID(t, st, 11, "bar") - testFindByID(t, st, 12, "") - - testString(t, st, `$ion_symbol_table::{symbols:["foo","bar"]}`) -} - -func TestLocalSymbolTableWithImports(t *testing.T) { - shared := NewSharedSymbolTable("shared", 1, []string{ - "foo", - "bar", - }) - imports := []SharedSymbolTable{shared} - - st := NewLocalSymbolTable(imports, []string{ - "foo2", - "bar2", - }) - - if st.MaxID() != 13 { // 9 from $ion.1, 2 from test.1, 2 local. - t.Errorf("wrong maxid: %v", st.MaxID()) - } - - testFindByName(t, st, "$ion", 1) - testFindByName(t, st, "$ion_shared_symbol_table", 9) - testFindByName(t, st, "foo", 10) - testFindByName(t, st, "bar", 11) - testFindByName(t, st, "foo2", 12) - testFindByName(t, st, "bar2", 13) - testFindByName(t, st, "bogus", 0) - - testFindByID(t, st, 0, "") - testFindByID(t, st, 1, "$ion") - testFindByID(t, st, 9, "$ion_shared_symbol_table") - testFindByID(t, st, 10, "foo") - testFindByID(t, st, 11, "bar") - testFindByID(t, st, 12, "foo2") - testFindByID(t, st, 13, "bar2") - testFindByID(t, st, 14, "") - - testString(t, st, `$ion_symbol_table::{imports:[{name:"shared",version:1,max_id:2}],symbols:["foo2","bar2"]}`) -} - -func TestSymbolTableBuilder(t *testing.T) { - b := NewSymbolTableBuilder() - - id, ok := b.Add("name") - if ok { - t.Error("Add(name) returned true") - } - if id != 4 { - t.Errorf("Add(name) returned %v", id) - } - - id, ok = b.Add("foo") - if !ok { - t.Error("Add(foo) returned false") - } - if id != 10 { - t.Errorf("Add(foo) returned %v", id) - } - - id, ok = b.Add("foo") - if ok { - t.Error("Second Add(foo) returned true") - } - if id != 10 { - t.Errorf("Second Add(foo) returned %v", id) - } - - st := b.Build() - if st.MaxID() != 10 { - t.Errorf("maxid returned %v", st.MaxID()) - } - - testFindByName(t, st, "$ion", 1) - testFindByName(t, st, "foo", 10) - testFindByName(t, st, "bogus", 0) - - testFindByID(t, st, 1, "$ion") - testFindByID(t, st, 10, "foo") - testFindByID(t, st, 11, "") -} - -func testFindByName(t *testing.T, st SymbolTable, sym string, expected uint64) { - t.Run("FindByName("+sym+")", func(t *testing.T) { - actual, ok := st.FindByName(sym) - if expected == 0 { - if ok { - t.Fatalf("unexpectedly found: %v", actual) - } - } else { - if !ok { - t.Fatal("unexpectedly not found") - } - if actual != expected { - t.Errorf("expected %v, got %v", expected, actual) - } - } - }) -} - -func testFindByID(t *testing.T, st SymbolTable, id uint64, expected string) { - t.Run(fmt.Sprintf("FindByID(%v)", id), func(t *testing.T) { - actual, ok := st.FindByID(id) - if expected == "" { - if ok { - t.Fatalf("unexpectedly found: %v", actual) - } - } else { - if !ok { - t.Fatal("unexpectedly not found") - } - if actual != expected { - t.Errorf("expected %v, got %v", expected, actual) - } - } - }) -} - -func testString(t *testing.T, st SymbolTable, expected string) { - t.Run("String()", func(t *testing.T) { - actual := st.String() - if actual != expected { - t.Errorf("expected %v, got %v", expected, actual) - } - }) -} diff --git a/textreader.go b/textreader.go deleted file mode 100644 index d9cf4077..00000000 --- a/textreader.go +++ /dev/null @@ -1,658 +0,0 @@ -package ion - -import ( - "bufio" - "encoding/base64" - "fmt" - "math" - "strconv" -) - -// trs is the state of the text reader. -type trs uint8 - -const ( - trsDone trs = iota - trsBeforeFieldName - trsBeforeTypeAnnotations - trsBeforeContainer - trsAfterValue -) - -func (s trs) String() string { - switch s { - case trsDone: - return "" - case trsBeforeFieldName: - return "" - case trsBeforeTypeAnnotations: - return "" - case trsBeforeContainer: - return "" - case trsAfterValue: - return "" - default: - return strconv.Itoa(int(s)) - } -} - -// A textReader is a Reader that reads text Ion. -type textReader struct { - reader - - tok tokenizer - state trs -} - -func newTextReaderBuf(in *bufio.Reader) Reader { - return &textReader{ - tok: tokenizer{ - in: in, - }, - state: trsBeforeTypeAnnotations, - } -} - -// SymbolTable returns the current symbol table. -func (t *textReader) SymbolTable() SymbolTable { - // TODO: Include me if present in the input stream? - return nil -} - -// Next moves the reader to the next value. -func (t *textReader) Next() bool { - if t.state == trsDone || t.eof { - return false - } - - // If we haven't fully read the current value, skip over it. - err := t.finishValue() - if err != nil { - t.explode(err) - return false - } - - t.clear() - - // Loop until we've consumed enough tokens to know what the next value is. - for { - if err := t.tok.Next(); err != nil { - t.explode(err) - return false - } - - var done bool - var err error - - switch t.state { - case trsAfterValue: - done, err = t.nextAfterValue() - case trsBeforeFieldName: - done, err = t.nextBeforeFieldName() - case trsBeforeTypeAnnotations: - done, err = t.nextBeforeTypeAnnotations() - default: - panic(fmt.Sprintf("unexpected state: %v", t.state)) - } - if err != nil { - t.explode(err) - return false - } - - if done { - // We're done reading tokens. If we hit the end of the current sequence, - // return false. Otherwise, we've got a value for the caller. - return !t.eof - } - } -} - -// NextAfterValue moves to the next value when we're in the -// AfterValue state. -func (t *textReader) nextAfterValue() (bool, error) { - tok := t.tok.Token() - switch tok { - case tokenComma: - // There's another value coming; eat the comma and move to the - // appropriate next state. - switch t.ctx.peek() { - case ctxInStruct: - t.state = trsBeforeFieldName - case ctxInList: - t.state = trsBeforeTypeAnnotations - default: - panic(fmt.Sprintf("unexpected context: %v", t.ctx.peek())) - } - return false, nil - - case tokenCloseBrace: - // No more values in this struct. - if t.ctx.peek() == ctxInStruct { - t.eof = true - return true, nil - } - return false, &UnexpectedTokenError{"}", t.tok.Pos() - 1} - - case tokenCloseBracket: - // No more values in this list. - if t.ctx.peek() == ctxInList { - t.eof = true - return true, nil - } - return false, &UnexpectedTokenError{"]", t.tok.Pos() - 1} - - default: - return false, &UnexpectedTokenError{tok.String(), t.tok.Pos() - 1} - } -} - -// NextBeforeFieldName moves to the next value when we're in the -// BeforeFieldName state. -func (t *textReader) nextBeforeFieldName() (bool, error) { - tok := t.tok.Token() - switch tok { - case tokenCloseBrace: - // No more values in this struct. - t.eof = true - return true, nil - - case tokenSymbol, tokenSymbolQuoted, tokenString, tokenLongString: - // Read the field name. - val, err := t.tok.ReadValue(tok) - if err != nil { - return false, err - } - if tok == tokenSymbol { - if err := t.verifyUnquotedSymbol(val, "field name"); err != nil { - return false, err - } - } - - // Skip over the following colon. - if err = t.tok.Next(); err != nil { - return false, err - } - if tok = t.tok.Token(); tok != tokenColon { - return false, &UnexpectedTokenError{tok.String(), t.tok.Pos() - 1} - } - - t.fieldName = val - t.state = trsBeforeTypeAnnotations - - return false, nil - - default: - return false, &UnexpectedTokenError{tok.String(), t.tok.Pos() - 1} - } -} - -// NextBeforeTypeAnnotations moves to the next value when we're in the -// BeforeTypeAnnotations state. -func (t *textReader) nextBeforeTypeAnnotations() (bool, error) { - tok := t.tok.Token() - switch tok { - case tokenEOF: - if t.ctx.peek() == ctxAtTopLevel { - t.eof = true - return true, nil - } - return false, &UnexpectedEOFError{t.tok.Pos() - 1} - - case tokenSymbolOperator, tokenDot: - if t.ctx.peek() != ctxInSexp { - // Operators can only appear inside an sexp. - return false, &UnexpectedTokenError{tok.String(), t.tok.Pos() - 1} - } - fallthrough - - case tokenSymbol, tokenSymbolQuoted: - val, err := t.tok.ReadValue(tok) - if err != nil { - return false, err - } - - ok, ws, err := t.tok.SkipDoubleColon() - if err != nil { - return false, err - } - - if ok { - // val was an annotation; remember it and keep going. - if tok == tokenSymbol { - if err := t.verifyUnquotedSymbol(val, "annotation"); err != nil { - return false, err - } - } - t.annotations = append(t.annotations, val) - return false, nil - } - - // val was a legit symbol value. - if err := t.onSymbol(val, tok, ws); err != nil { - return false, err - } - return true, nil - - case tokenString, tokenLongString: - val, err := t.tok.ReadValue(tok) - if err != nil { - return false, err - } - - t.state = t.stateAfterValue() - t.valueType = StringType - t.value = val - return true, nil - - case tokenBinary, tokenHex, tokenNumber, tokenFloatInf, tokenFloatMinusInf: - if err := t.onNumber(tok); err != nil { - return false, err - } - return true, nil - - case tokenTimestamp: - if err := t.onTimestamp(); err != nil { - return false, err - } - return true, nil - - case tokenOpenDoubleBrace: - if err := t.onLob(); err != nil { - return false, err - } - return true, nil - - case tokenOpenBrace: - t.state = trsBeforeContainer - t.valueType = StructType - t.value = StructType - return true, nil - - case tokenOpenBracket: - t.state = trsBeforeContainer - t.valueType = ListType - t.value = ListType - return true, nil - - case tokenOpenParen: - t.state = trsBeforeContainer - t.valueType = SexpType - t.value = SexpType - return true, nil - - case tokenCloseBracket: - // No more values in this list. - if t.ctx.peek() == ctxInList { - t.eof = true - return true, nil - } - return false, &UnexpectedTokenError{"]", t.tok.Pos() - 1} - - case tokenCloseParen: - // No more values in this sexp. - if t.ctx.peek() == ctxInSexp { - t.eof = true - return true, nil - } - return false, &UnexpectedTokenError{")", t.tok.Pos() - 1} - - default: - return false, &UnexpectedTokenError{tok.String(), t.tok.Pos() - 1} - } -} - -// StepIn steps in to a container. -func (t *textReader) StepIn() error { - if t.err != nil { - return t.err - } - if t.state != trsBeforeContainer { - return &UsageError{"Reader.StepIn", fmt.Sprintf("cannot step in to a %v", t.valueType)} - } - - ctx := containerTypeToCtx(t.valueType) - t.ctx.push(ctx) - - if ctx == ctxInStruct { - t.state = trsBeforeFieldName - } else { - t.state = trsBeforeTypeAnnotations - } - - t.clear() - - t.tok.SetFinished() - return nil -} - -// StepOut steps out of a container. -func (t *textReader) StepOut() error { - if t.err != nil { - return t.err - } - - ctx := t.ctx.peek() - if ctx == ctxAtTopLevel { - return &UsageError{"Reader.StepOut", "cannot step out of top-level datagram"} - } - ctype := ctxToContainerType(ctx) - - // Finish off whatever value *inside* the container that we're currently reading. - _, err := t.tok.FinishValue() - if err != nil { - t.explode(err) - return err - } - - // If we haven't seen the end of the container yet, skip values until we find it. - if !t.eof { - if err := t.tok.SkipContainerContents(ctype); err != nil { - t.explode(err) - return err - } - } - - t.ctx.pop() - t.state = t.stateAfterValue() - t.clear() - t.eof = false - - return nil -} - -// VerifyUnquotedSymbol checks for certain 'special' values that are returned from -// the tokenizer as symbols but cannot be used as field names or annotations. -func (t *textReader) verifyUnquotedSymbol(val string, ctx string) error { - switch val { - case "null", "true", "false", "nan": - return &SyntaxError{fmt.Sprintf("unquoted keyword '%v' as %v", val, ctx), t.tok.Pos() - 1} - } - return nil -} - -// OnSymbol handles finding a symbol-token value. -func (t *textReader) onSymbol(val string, tok token, ws bool) error { - valueType := SymbolType - var value interface{} = val - - if tok == tokenSymbol { - switch val { - case "null": - vt, err := t.onNull(ws) - if err != nil { - return err - } - valueType = vt - value = nil - - case "true": - valueType = BoolType - value = true - - case "false": - valueType = BoolType - value = false - - case "nan": - valueType = FloatType - value = math.NaN() - } - } - - t.state = t.stateAfterValue() - t.valueType = valueType - t.value = value - - return nil -} - -// OnNull handles finding a null token. -func (t *textReader) onNull(ws bool) (Type, error) { - if !ws { - ok, err := t.tok.SkipDot() - if err != nil { - return NoType, err - } - if ok { - return t.readNullType() - } - } - return NullType, nil -} - -// readNullType reads the null.{this} type symbol. -func (t *textReader) readNullType() (Type, error) { - if err := t.tok.Next(); err != nil { - return NoType, err - } - if t.tok.Token() != tokenSymbol { - msg := fmt.Sprintf("invalid symbol null.%v", t.tok.Token()) - return NoType, &SyntaxError{msg, t.tok.Pos() - 1} - } - - val, err := t.tok.ReadValue(tokenSymbol) - if err != nil { - return NoType, err - } - - switch val { - case "null": - return NullType, nil - case "bool": - return BoolType, nil - case "int": - return IntType, nil - case "float": - return FloatType, nil - case "decimal": - return DecimalType, nil - case "timestamp": - return TimestampType, nil - case "symbol": - return SymbolType, nil - case "string": - return StringType, nil - case "blob": - return BlobType, nil - case "clob": - return ClobType, nil - case "list": - return ListType, nil - case "struct": - return StructType, nil - case "sexp": - return SexpType, nil - default: - msg := fmt.Sprintf("invalid symbol null.%v", t.tok.Token()) - return NoType, &SyntaxError{msg, t.tok.Pos() - 1} - } -} - -// OnNumber handles finding a number token. -func (t *textReader) onNumber(tok token) error { - var valueType Type - var value interface{} - - switch tok { - case tokenBinary: - val, err := t.tok.ReadValue(tok) - if err != nil { - return err - } - - valueType = IntType - value, err = parseInt(val, 2) - if err != nil { - return err - } - - case tokenHex: - val, err := t.tok.ReadValue(tok) - if err != nil { - return err - } - - valueType = IntType - value, err = parseInt(val, 16) - if err != nil { - return err - } - - case tokenNumber: - val, tt, err := t.tok.ReadNumber() - if err != nil { - return err - } - - valueType = tt - - switch tt { - case IntType: - value, err = parseInt(val, 10) - case FloatType: - value, err = parseFloat(val) - case DecimalType: - value, err = parseDecimal(val) - default: - panic(fmt.Sprintf("unexpected type %v", tt)) - } - - if err != nil { - return err - } - - case tokenFloatInf: - valueType = FloatType - value = math.Inf(1) - - case tokenFloatMinusInf: - valueType = FloatType - value = math.Inf(-1) - - default: - panic(fmt.Sprintf("unexpected token type %v", tok)) - } - - t.state = t.stateAfterValue() - t.valueType = valueType - t.value = value - - return nil -} - -// OnTimestamp handles finding a timestamp token. -func (t *textReader) onTimestamp() error { - val, err := t.tok.ReadValue(tokenTimestamp) - if err != nil { - return err - } - - value, err := parseTimestamp(val) - if err != nil { - return err - } - - t.state = t.stateAfterValue() - t.valueType = TimestampType - t.value = value - - return nil -} - -// OnLob handles finding a [bc]lob token. -func (t *textReader) onLob() error { - c, err := t.tok.SkipLobWhitespace() - if err != nil { - return err - } - - var ( - valType Type - val []byte - ) - - if c == '"' { - // Short clob. - valType = ClobType - - str, err := t.tok.ReadShortClob() - if err != nil { - return err - } - - val = []byte(str) - - } else if c == '\'' { - // Long clob. - ok, err := t.tok.IsTripleQuote() - if err != nil { - return err - } - if !ok { - return t.tok.invalidChar(c) - } - - valType = ClobType - - str, err := t.tok.ReadLongClob() - if err != nil { - return err - } - - val = []byte(str) - - } else { - // Normal blob. - valType = BlobType - t.tok.unread(c) - - b64, err := t.tok.ReadBlob() - if err != nil { - return err - } - - val, err = base64.StdEncoding.DecodeString(b64) - if err != nil { - return err - } - } - - t.state = t.stateAfterValue() - t.valueType = valType - t.value = val - - return nil -} - -// FinishValue finishes reading the current value, if there is one. -func (t *textReader) finishValue() error { - ok, err := t.tok.FinishValue() - if err != nil { - return err - } - - if ok { - t.state = t.stateAfterValue() - } - - return nil -} - -func (t *textReader) stateAfterValue() trs { - ctx := t.ctx.peek() - switch ctx { - case ctxInList, ctxInStruct: - return trsAfterValue - case ctxInSexp, ctxAtTopLevel: - return trsBeforeTypeAnnotations - default: - panic(fmt.Sprintf("invalid ctx %v", ctx)) - } -} - -// Explode explodes the reader state when something unexpected -// happens and further calls to Next are a bad idea. -func (t *textReader) explode(err error) { - t.state = trsDone - t.err = err -} diff --git a/textreader_test.go b/textreader_test.go deleted file mode 100644 index 68a84ba8..00000000 --- a/textreader_test.go +++ /dev/null @@ -1,788 +0,0 @@ -package ion - -import ( - "bytes" - "math" - "math/big" - "testing" - "time" -) - -func TestIgnoreValues(t *testing.T) { - r := NewReaderStr("(skip ++ me / please) {skip: me, please: 0}\n[skip, me, please]\nfoo") - - _next(t, r, SexpType) - _next(t, r, StructType) - _next(t, r, ListType) - - _symbol(t, r, "foo") - _eof(t, r) -} - -func TestReadSexps(t *testing.T) { - test := func(str string, f containerhandler) { - t.Run(str, func(t *testing.T) { - r := NewReaderStr(str) - _sexp(t, r, f) - _eof(t, r) - }) - } - - test("(\t)", func(t *testing.T, r Reader) { - if r.Next() { - t.Errorf("next returned true") - } - if r.Err() != nil { - t.Fatal(r.Err()) - } - }) - - test("(foo)", func(t *testing.T, r Reader) { - _symbol(t, r, "foo") - }) - - test("(foo bar baz :: boop)", func(t *testing.T, r Reader) { - _symbol(t, r, "foo") - _symbol(t, r, "bar") - _symbolAF(t, r, "", []string{"baz"}, "boop") - }) -} - -func TestStructs(t *testing.T) { - test := func(str string, f containerhandler) { - t.Run(str, func(t *testing.T) { - r := NewReaderStr(str) - _struct(t, r, f) - _eof(t, r) - }) - } - - test("{\r\n}", func(t *testing.T, r Reader) { - _eof(t, r) - }) - - test("{foo : bar :: baz}", func(t *testing.T, r Reader) { - _symbolAF(t, r, "foo", []string{"bar"}, "baz") - }) - - test("{foo: a, bar: b, baz: c}", func(t *testing.T, r Reader) { - _symbolAF(t, r, "foo", nil, "a") - _symbolAF(t, r, "bar", nil, "b") - _symbolAF(t, r, "baz", nil, "c") - }) -} - -func TestMultipleStructs(t *testing.T) { - r := NewReaderStr("{} {} {}") - - for i := 0; i < 3; i++ { - _struct(t, r, func(t *testing.T, r Reader) { - _eof(t, r) - }) - } - - _eof(t, r) -} - -func TestNullStructs(t *testing.T) { - r := NewReaderStr("null.struct 'null'::{foo:bar}") - - _null(t, r, StructType) - _nextAF(t, r, StructType, "", []string{"null"}) - _eof(t, r) -} - -func TestLists(t *testing.T) { - test := func(str string, f containerhandler) { - t.Run(str, func(t *testing.T) { - r := NewReaderStr(str) - _list(t, r, f) - _eof(t, r) - }) - } - - test("[ ]", func(t *testing.T, r Reader) { - _eof(t, r) - }) - - test("[foo]", func(t *testing.T, r Reader) { - _symbol(t, r, "foo") - _eof(t, r) - }) - - test("[foo, bar, baz::boop]", func(t *testing.T, r Reader) { - _symbol(t, r, "foo") - _symbol(t, r, "bar") - _symbolAF(t, r, "", []string{"baz"}, "boop") - _eof(t, r) - }) -} - -func TestReadNestedLists(t *testing.T) { - empty := func(t *testing.T, r Reader) { - _eof(t, r) - } - - r := NewReaderStr("[[], [[]]]") - - _list(t, r, func(t *testing.T, r Reader) { - _list(t, r, empty) - - _list(t, r, func(t *testing.T, r Reader) { - _list(t, r, empty) - }) - - _eof(t, r) - }) - - _eof(t, r) -} - -func TestClobs(t *testing.T) { - test := func(str string, eval []byte) { - t.Run(str, func(t *testing.T) { - r := NewReaderStr(str) - _next(t, r, ClobType) - - val, err := r.ByteValue() - if err != nil { - t.Fatal(err) - } - if !bytes.Equal(val, eval) { - t.Errorf("expected %v, got %v", eval, val) - } - - _eof(t, r) - }) - } - - test("{{\"\"}}", []byte{}) - test("{{ \"hello world\" }}", []byte("hello world")) - test("{{'''hello world'''}}", []byte("hello world")) - test("{{'''hello'''\n'''world'''}}", []byte("helloworld")) -} - -func TestBlobs(t *testing.T) { - test := func(str string, eval []byte) { - t.Run(str, func(t *testing.T) { - r := NewReaderStr(str) - _next(t, r, BlobType) - - val, err := r.ByteValue() - if err != nil { - t.Fatal(err) - } - if !bytes.Equal(val, eval) { - t.Errorf("expected %v, got %v", eval, val) - } - - _eof(t, r) - }) - } - - test("{{}}", []byte{}) - test("{{AA==}}", []byte{0}) - test("{{ SGVsbG8g\r\nV29ybGQ= }}", []byte("Hello World")) -} - -func TestTimestamps(t *testing.T) { - testA := func(str string, etas []string, eval time.Time) { - t.Run(str, func(t *testing.T) { - r := NewReaderStr(str) - _nextAF(t, r, TimestampType, "", etas) - - val, err := r.TimeValue() - if err != nil { - t.Fatal(err) - } - if !val.Equal(eval) { - t.Errorf("expected %v, got %v", eval, val) - } - - _eof(t, r) - }) - } - - test := func(str string, eval time.Time) { - testA(str, nil, eval) - } - - et := time.Date(2001, time.January, 1, 0, 0, 0, 0, time.UTC) - test("2001T", et) - test("2001-01T", et) - test("2001-01-01", et) - test("2001-01-01T", et) - test("2001-01-01T00:00Z", et) - test("2001-01-01T00:00:00Z", et) - test("2001-01-01T00:00:00.000Z", et) - test("2001-01-01T00:00:00.000+00:00", et) - test("2001-01-01T00:00:00.000000Z", et) - test("2001-01-01T00:00:00.000000000Z", et) - test("2001-01-01T00:00:00.000000000999Z", et) // We truncate, at least for now. - - testA("foo::'bar'::2001-01-01T00:00:00.000Z", []string{"foo", "bar"}, et) -} - -func TestDecimals(t *testing.T) { - testA := func(str string, etas []string, eval string) { - t.Run(str, func(t *testing.T) { - ee := MustParseDecimal(eval) - - r := NewReaderStr(str) - _nextAF(t, r, DecimalType, "", etas) - - val, err := r.DecimalValue() - if err != nil { - t.Fatal(err) - } - if !ee.Equal(val) { - t.Errorf("expected %v, got %v", ee, val) - } - - _eof(t, r) - }) - } - - test := func(str string, eval string) { - testA(str, nil, eval) - } - - test("123.", "123") - test("123.0", "123") - test("123.456", "123.456") - test("123d2", "12300") - test("123d+2", "12300") - test("123d-2", "1.23") - - testA(" foo :: 'bar' :: 123. ", []string{"foo", "bar"}, "123") -} - -func TestFloats(t *testing.T) { - testA := func(str string, etas []string, eval float64) { - t.Run(str, func(t *testing.T) { - r := NewReaderStr(str) - _floatAF(t, r, "", etas, eval) - _eof(t, r) - }) - } - - test := func(str string, eval float64) { - testA(str, nil, eval) - } - - test("1e100\n", 1e100) - test("1.2e+0", 1.2) - test("-123.456e-78", -123.456e-78) - test("+inf", math.Inf(1)) - test("-inf", math.Inf(-1)) - - testA("foo::'bar'::1e100", []string{"foo", "bar"}, 1e100) -} - -func TestInts(t *testing.T) { - test := func(str string, f func(*testing.T, Reader)) { - t.Run(str, func(t *testing.T) { - r := NewReaderStr(str) - _next(t, r, IntType) - - f(t, r) - - _eof(t, r) - }) - } - - test("null.int", func(t *testing.T, r Reader) { - if !r.IsNull() { - t.Fatal("expected isnull=true, got false") - } - }) - - testInt := func(str string, eval int) { - test(str, func(t *testing.T, r Reader) { - val, err := r.IntValue() - if err != nil { - t.Fatal(err) - } - if val != eval { - t.Errorf("expected %v, got %v", eval, val) - } - }) - } - - testInt("0", 0) - testInt("12_345", 12345) - testInt("-1_2_3_4_5", -12345) - testInt("0b00_0101", 5) - testInt("-0b00_0101", -5) - testInt("0x01_02_0e_0F", 0x01020e0f) - testInt("-0x0102_0e0F", -0x01020e0f) - - testInt64 := func(str string, eval int64) { - test(str, func(t *testing.T, r Reader) { - val, err := r.Int64Value() - if err != nil { - t.Fatal(err) - } - if val != eval { - t.Errorf("expected %v, got %v", eval, val) - } - }) - } - - testInt64("0x123_FFFF_FFFF", 0x123FFFFFFFF) - testInt64("-0x123_FFFF_FFFF", -0x123FFFFFFFF) - - testBigInt := func(str string, estr string) { - test(str, func(t *testing.T, r Reader) { - val, err := r.BigIntValue() - if err != nil { - t.Fatal(err) - } - - eval, _ := (&big.Int{}).SetString(estr, 0) - if eval.Cmp(val) != 0 { - t.Errorf("expected %v, got %v", eval, val) - } - }) - } - - testBigInt("0xEFFF_FFFF_FFFF_FFFF", "0xEFFFFFFFFFFFFFFF") - testBigInt("0xFFFF_FFFF_FFFF_FFFF", "0xFFFFFFFFFFFFFFFF") - testBigInt("-0x1_FFFF_FFFF_FFFF_FFFF", "-0x1FFFFFFFFFFFFFFFF") -} - -func TestStrings(t *testing.T) { - r := NewReaderStr(`foo::"bar" "baz" 'a'::'b'::'''beep''' '''boop''' null.string`) - - _stringAF(t, r, "", []string{"foo"}, "bar") - _string(t, r, "baz") - _stringAF(t, r, "", []string{"a", "b"}, "beepboop") - _null(t, r, StringType) - - _eof(t, r) -} - -func TestSymbols(t *testing.T) { - r := NewReaderStr("'null'::foo bar a::b::'baz' null.symbol") - - _symbolAF(t, r, "", []string{"null"}, "foo") - _symbol(t, r, "bar") - _symbolAF(t, r, "", []string{"a", "b"}, "baz") - _null(t, r, SymbolType) - - _eof(t, r) -} - -func TestSpecialSymbols(t *testing.T) { - r := NewReaderStr("null\nnull.struct\ntrue\nfalse\nnan") - - _null(t, r, NullType) - _null(t, r, StructType) - - _bool(t, r, true) - _bool(t, r, false) - _float(t, r, math.NaN()) - _eof(t, r) -} - -func TestOperators(t *testing.T) { - r := NewReaderStr("(a*(b+c))") - - _sexp(t, r, func(t *testing.T, r Reader) { - _symbol(t, r, "a") - _symbol(t, r, "*") - _sexp(t, r, func(t *testing.T, r Reader) { - _symbol(t, r, "b") - _symbol(t, r, "+") - _symbol(t, r, "c") - _eof(t, r) - }) - _eof(t, r) - }) -} - -func TestTopLevelOperators(t *testing.T) { - r := NewReaderStr("a + b") - - _symbol(t, r, "a") - - if r.Next() { - t.Errorf("next returned true") - } - if r.Err() == nil { - t.Error("no error") - } -} - -func TestTrsToString(t *testing.T) { - for i := trsDone; i <= trsAfterValue+1; i++ { - str := i.String() - if str == "" { - t.Errorf("expected a non-empty string for trs %v", uint8(i)) - } - } -} - -type containerhandler func(t *testing.T, r Reader) - -func _sexp(t *testing.T, r Reader, f containerhandler) { - _sexpAF(t, r, "", nil, f) -} - -func _sexpAF(t *testing.T, r Reader, efn string, etas []string, f containerhandler) { - _containerAF(t, r, SexpType, efn, etas, f) -} - -func _struct(t *testing.T, r Reader, f containerhandler) { - _structAF(t, r, "", nil, f) -} - -func _structAF(t *testing.T, r Reader, efn string, etas []string, f containerhandler) { - _containerAF(t, r, StructType, efn, etas, f) -} - -func _list(t *testing.T, r Reader, f containerhandler) { - _listAF(t, r, "", nil, f) -} - -func _listAF(t *testing.T, r Reader, efn string, etas []string, f containerhandler) { - _containerAF(t, r, ListType, efn, etas, f) -} - -func _containerAF(t *testing.T, r Reader, et Type, efn string, etas []string, f containerhandler) { - _nextAF(t, r, et, efn, etas) - if r.IsNull() { - t.Fatalf("expected %v, got null.%v", et, et) - } - - if err := r.StepIn(); err != nil { - t.Fatal(err) - } - - f(t, r) - - if err := r.StepOut(); err != nil { - t.Fatal(err) - } -} - -func _int(t *testing.T, r Reader, eval int) { - _intAF(t, r, "", nil, eval) -} - -func _intAF(t *testing.T, r Reader, efn string, etas []string, eval int) { - _nextAF(t, r, IntType, efn, etas) - if r.IsNull() { - t.Fatalf("expected %v, got null.int", eval) - } - - size, err := r.IntSize() - if err != nil { - t.Fatal(err) - } - if size != Int32 { - t.Errorf("expected size=Int32, got %v", size) - } - - val, err := r.IntValue() - if err != nil { - t.Fatal(err) - } - if val != eval { - t.Errorf("expected %v, got %v", eval, val) - } -} - -func _int64(t *testing.T, r Reader, eval int64) { - _int64AF(t, r, "", nil, eval) -} - -func _int64AF(t *testing.T, r Reader, efn string, etas []string, eval int64) { - _nextAF(t, r, IntType, efn, etas) - if r.IsNull() { - t.Fatalf("expected %v, got null.int", eval) - } - - size, err := r.IntSize() - if err != nil { - t.Fatal(err) - } - if size != Int64 { - t.Errorf("expected size=Int64, got %v", size) - } - - val, err := r.Int64Value() - if err != nil { - t.Fatal(err) - } - if val != eval { - t.Errorf("expected %v, got %v", eval, val) - } -} - -func _uint(t *testing.T, r Reader, eval uint64) { - _uintAF(t, r, "", nil, eval) -} - -func _uintAF(t *testing.T, r Reader, efn string, etas []string, eval uint64) { - _nextAF(t, r, IntType, efn, etas) - if r.IsNull() { - t.Fatalf("expected %v, got null.int", eval) - } - - size, err := r.IntSize() - if err != nil { - t.Fatal(err) - } - if size != Uint64 { - t.Errorf("expected size=Uint, got %v", size) - } - - val, err := r.Uint64Value() - if err != nil { - t.Fatal(err) - } - if val != eval { - t.Errorf("expected %v, got %v", eval, val) - } -} - -func _bigInt(t *testing.T, r Reader, eval *big.Int) { - _bigIntAF(t, r, "", nil, eval) -} - -func _bigIntAF(t *testing.T, r Reader, efn string, etas []string, eval *big.Int) { - _nextAF(t, r, IntType, efn, etas) - if r.IsNull() { - t.Fatalf("expected %v, got null.int", eval) - } - - size, err := r.IntSize() - if err != nil { - t.Fatal(err) - } - if size != BigInt { - t.Errorf("expected size=BigInt, got %v", size) - } - - val, err := r.BigIntValue() - if err != nil { - t.Fatal(err) - } - if val.Cmp(eval) != 0 { - t.Errorf("expected %v, got %v", eval, val) - } -} - -func _float(t *testing.T, r Reader, eval float64) { - _floatAF(t, r, "", nil, eval) -} - -func _floatAF(t *testing.T, r Reader, efn string, etas []string, eval float64) { - _nextAF(t, r, FloatType, efn, etas) - if r.IsNull() { - t.Fatalf("expected %v, got null.float", eval) - } - - val, err := r.FloatValue() - if err != nil { - t.Fatal(err) - } - - if math.IsNaN(eval) { - if !math.IsNaN(val) { - t.Errorf("expected %v, got %v", eval, val) - } - } else if eval != val { - t.Errorf("expected %v, got %v", eval, val) - } -} - -func _decimal(t *testing.T, r Reader, eval *Decimal) { - _decimalAF(t, r, "", nil, eval) -} - -func _decimalAF(t *testing.T, r Reader, efn string, etas []string, eval *Decimal) { - _nextAF(t, r, DecimalType, efn, etas) - if r.IsNull() { - t.Fatalf("expected %v, got null.decimal", eval) - } - - val, err := r.DecimalValue() - if err != nil { - t.Fatal(err) - } - - if !eval.Equal(val) { - t.Errorf("expected %v, got %v", eval, val) - } -} - -func _timestamp(t *testing.T, r Reader, eval time.Time) { - _timestampAF(t, r, "", nil, eval) -} - -func _timestampAF(t *testing.T, r Reader, efn string, etas []string, eval time.Time) { - _nextAF(t, r, TimestampType, efn, etas) - if r.IsNull() { - t.Fatalf("expected %v, got null.timestamp", eval) - } - - val, err := r.TimeValue() - if err != nil { - t.Fatal(err) - } - - if !val.Equal(eval) { - t.Errorf("expected %v, got %v", eval, val) - } -} - -func _string(t *testing.T, r Reader, eval string) { - _stringAF(t, r, "", nil, eval) -} - -func _stringAF(t *testing.T, r Reader, efn string, etas []string, eval string) { - _nextAF(t, r, StringType, efn, etas) - if r.IsNull() { - t.Fatalf("expected %v, got null.string", eval) - } - - val, err := r.StringValue() - if err != nil { - t.Fatal(err) - } - if val != eval { - t.Errorf("expected %v, got %v", eval, val) - } -} - -func _symbol(t *testing.T, r Reader, eval string) { - _symbolAF(t, r, "", nil, eval) -} - -func _symbolAF(t *testing.T, r Reader, efn string, etas []string, eval string) { - _nextAF(t, r, SymbolType, efn, etas) - if r.IsNull() { - t.Fatalf("expected %v, got null.symbol", eval) - } - - val, err := r.StringValue() - if err != nil { - t.Fatal(err) - } - if val != eval { - t.Errorf("expected %v, got %v", eval, val) - } -} - -func _bool(t *testing.T, r Reader, eval bool) { - _boolAF(t, r, "", nil, eval) -} - -func _boolAF(t *testing.T, r Reader, efn string, etas []string, eval bool) { - _nextAF(t, r, BoolType, efn, etas) - if r.IsNull() { - t.Fatalf("expected %v, got null.bool", eval) - } - - val, err := r.BoolValue() - if err != nil { - t.Fatal(err) - } - if val != eval { - t.Errorf("expected %v, got %v", eval, val) - } -} - -func _clob(t *testing.T, r Reader, eval []byte) { - _clobAF(t, r, "", nil, eval) -} - -func _clobAF(t *testing.T, r Reader, efn string, etas []string, eval []byte) { - _nextAF(t, r, ClobType, efn, etas) - if r.IsNull() { - t.Fatalf("expected %v, got null.clob", eval) - } - - val, err := r.ByteValue() - if err != nil { - t.Fatal(err) - } - if !bytes.Equal(val, eval) { - t.Errorf("expected %v, got %v", eval, val) - } -} - -func _blob(t *testing.T, r Reader, eval []byte) { - _blobAF(t, r, "", nil, eval) -} - -func _blobAF(t *testing.T, r Reader, efn string, etas []string, eval []byte) { - _nextAF(t, r, BlobType, efn, etas) - if r.IsNull() { - t.Fatalf("expected %v, got null.blob", eval) - } - - val, err := r.ByteValue() - if err != nil { - t.Fatal(err) - } - if !bytes.Equal(val, eval) { - t.Errorf("expected %v, got %v", eval, val) - } -} - -func _null(t *testing.T, r Reader, et Type) { - _nullAF(t, r, et, "", nil) -} - -func _nullAF(t *testing.T, r Reader, et Type, efn string, etas []string) { - _nextAF(t, r, et, efn, etas) - if !r.IsNull() { - t.Error("isnull returned false") - } -} - -func _next(t *testing.T, r Reader, et Type) { - _nextAF(t, r, et, "", nil) -} - -func _nextAF(t *testing.T, r Reader, et Type, efn string, etas []string) { - if !r.Next() { - t.Fatal(r.Err()) - } - if r.Type() != et { - t.Fatalf("expected %v, got %v", et, r.Type()) - } - - if efn != r.FieldName() { - t.Errorf("expected fieldname=%v, got %v", efn, r.FieldName()) - } - if !_strequals(etas, r.Annotations()) { - t.Errorf("expected type annotations=%v, got %v", etas, r.Annotations()) - } -} - -func _strequals(a, b []string) bool { - if len(a) != len(b) { - return false - } - - for i := 0; i < len(a); i++ { - if a[i] != b[i] { - return false - } - } - - return true -} - -func _eof(t *testing.T, r Reader) { - if r.Next() { - t.Fatal("next returned true") - } - if r.Err() != nil { - t.Fatal(r.Err()) - } -} diff --git a/textutils.go b/textutils.go deleted file mode 100644 index e4a03de4..00000000 --- a/textutils.go +++ /dev/null @@ -1,382 +0,0 @@ -package ion - -import ( - "fmt" - "io" - "math/big" - "strconv" - "strings" - "time" -) - -// Does this symbol need to be quoted in text form? -func symbolNeedsQuoting(sym string) bool { - switch sym { - case "", "null", "true", "false", "nan": - return true - } - - if !isIdentifierStart(int(sym[0])) { - return true - } - - for i := 1; i < len(sym); i++ { - if !isIdentifierPart(int(sym[i])) { - return true - } - } - - return false -} - -// Is this the text form of a symbol reference ($)? -func isSymbolRef(sym string) bool { - if len(sym) == 0 || sym[0] != '$' { - return false - } - - if len(sym) == 1 { - return false - } - - for i := 1; i < len(sym); i++ { - if !isDigit(int(sym[i])) { - return false - } - } - - return true -} - -// Is this a valid first character for an identifier? -func isIdentifierStart(c int) bool { - if c >= 'a' && c <= 'z' { - return true - } - if c >= 'A' && c <= 'Z' { - return true - } - if c == '_' || c == '$' { - return true - } - return false -} - -// Is this a valid character for later in an identifier? -func isIdentifierPart(c int) bool { - return isIdentifierStart(c) || isDigit(c) -} - -// Is this a valid hex digit? -func isHexDigit(c int) bool { - if isDigit(c) { - return true - } - if c >= 'a' && c <= 'f' { - return true - } - if c >= 'A' && c <= 'F' { - return true - } - return false -} - -// Is this a digit? -func isDigit(c int) bool { - return c >= '0' && c <= '9' -} - -// Is this a valid part of an operator symbol? -func isOperatorChar(c int) bool { - switch c { - case '!', '#', '%', '&', '*', '+', '-', '.', '/', ';', '<', '=', - '>', '?', '@', '^', '`', '|', '~': - return true - default: - return false - } -} - -// Does this character mark the end of a normal (unquoted) value? Does -// *not* check for the start of a comment, because that requires two -// characters. Use tokenizer.isStopChar(c) or check for it yourself. -func isStopChar(c int) bool { - switch c { - case -1, '{', '}', '[', ']', '(', ')', ',', '"', '\'', - ' ', '\t', '\n', '\r': - return true - default: - return false - } -} - -// Is this character whitespace? -func isWhitespace(c int) bool { - switch c { - case ' ', '\t', '\n', '\r': - return true - } - return false -} - -// Formats a float64 in Ion text style. -func formatFloat(val float64) string { - str := strconv.FormatFloat(val, 'e', -1, 64) - - // Ion uses lower case for special values. - switch str { - case "NaN": - return "nan" - case "+Inf": - return "+inf" - case "-Inf": - return "-inf" - } - - idx := strings.Index(str, "e") - if idx < 0 { - // We need to add an 'e' or it will get interpreted as an Ion decimal. - str += "e0" - } else if idx+2 < len(str) && str[idx+2] == '0' { - // FormatFloat returns exponents with a leading ±0 in some cases; strip it. - str = str[:idx+2] + str[idx+3:] - } - - return str -} - -// Write the given symbol out, quoting and encoding if necessary. -func writeSymbol(sym string, out io.Writer) error { - if symbolNeedsQuoting(sym) { - if err := writeRawChar('\'', out); err != nil { - return err - } - if err := writeEscapedSymbol(sym, out); err != nil { - return err - } - return writeRawChar('\'', out) - } - return writeRawString(sym, out) -} - -// Write the given symbol out, escaping any characters that need escaping. -func writeEscapedSymbol(sym string, out io.Writer) error { - for i := 0; i < len(sym); i++ { - c := sym[i] - if c < 32 || c == '\\' || c == '\'' { - if err := writeEscapedChar(c, out); err != nil { - return err - } - } else { - if err := writeRawChar(c, out); err != nil { - return err - } - } - } - return nil -} - -// Write the given string out, escaping any characters that need escaping. -func writeEscapedString(str string, out io.Writer) error { - for i := 0; i < len(str); i++ { - c := str[i] - if c < 32 || c == '\\' || c == '"' { - if err := writeEscapedChar(c, out); err != nil { - return err - } - } else { - if err := writeRawChar(c, out); err != nil { - return err - } - } - } - return nil -} - -// Write out the given character in escaped form. -func writeEscapedChar(c byte, out io.Writer) error { - switch c { - case 0: - return writeRawString("\\0", out) - case '\a': - return writeRawString("\\a", out) - case '\b': - return writeRawString("\\b", out) - case '\t': - return writeRawString("\\t", out) - case '\n': - return writeRawString("\\n", out) - case '\f': - return writeRawString("\\f", out) - case '\r': - return writeRawString("\\r", out) - case '\v': - return writeRawString("\\v", out) - case '\'': - return writeRawString("\\'", out) - case '"': - return writeRawString("\\\"", out) - case '\\': - return writeRawString("\\\\", out) - default: - buf := []byte{'\\', 'x', hexChars[(c>>4)&0xF], hexChars[c&0xF]} - return writeRawChars(buf, out) - } -} - -// Write out the given raw string. -func writeRawString(s string, out io.Writer) error { - _, err := out.Write([]byte(s)) - return err -} - -// Write out the given raw character sequence. -func writeRawChars(cs []byte, out io.Writer) error { - _, err := out.Write(cs) - return err -} - -// Write out the given raw character. -func writeRawChar(c byte, out io.Writer) error { - _, err := out.Write([]byte{c}) - return err -} - -func parseFloat(str string) (float64, error) { - val, err := strconv.ParseFloat(str, 64) - if err != nil { - if ne, ok := err.(*strconv.NumError); ok { - if ne.Err == strconv.ErrRange { - // Ignore me, val will be +-inf which is fine. - return val, nil - } - } - } - return val, err -} - -func parseDecimal(str string) (*Decimal, error) { - return ParseDecimal(str) -} - -func parseInt(str string, radix int) (interface{}, error) { - digits := str - - switch radix { - case 10: - // All set. - - case 2, 16: - neg := false - if digits[0] == '-' { - neg = true - digits = digits[1:] - } - - // Skip over the '0x' prefix. - digits = digits[2:] - if neg { - digits = "-" + digits - } - - default: - panic("unsupported radix") - } - - i, err := strconv.ParseInt(digits, radix, 64) - if err == nil { - return i, nil - } - if err.(*strconv.NumError).Err != strconv.ErrRange { - return nil, err - } - - bi, ok := (&big.Int{}).SetString(digits, radix) - if !ok { - return nil, &strconv.NumError{ - Func: "ParseInt", - Num: str, - Err: strconv.ErrSyntax, - } - } - - return bi, nil -} - -func parseTimestamp(val string) (time.Time, error) { - if len(val) < 5 { - return invalidTimestamp(val) - } - - year, err := strconv.ParseInt(val[:4], 10, 32) - if err != nil { - return invalidTimestamp(val) - } - if len(val) == 5 && (val[4] == 't' || val[4] == 'T') { - // yyyyT - return time.Date(int(year), 1, 1, 0, 0, 0, 0, time.UTC), nil - } - if val[4] != '-' { - return invalidTimestamp(val) - } - - if len(val) < 8 { - return invalidTimestamp(val) - } - - month, err := strconv.ParseInt(val[5:7], 10, 32) - if err != nil { - return invalidTimestamp(val) - } - - if len(val) == 8 && (val[7] == 't' || val[7] == 'T') { - // yyyy-mmT - return time.Date(int(year), time.Month(month), 1, 0, 0, 0, 0, time.UTC), nil - } - if val[7] != '-' { - return invalidTimestamp(val) - } - - if len(val) < 10 { - return invalidTimestamp(val) - } - - day, err := strconv.ParseInt(val[8:10], 10, 32) - if err != nil { - return invalidTimestamp(val) - } - - if len(val) == 10 || (len(val) == 11 && (val[10] == 't' || val[10] == 'T')) { - // yyyy-mm-dd or yyyy-mm-ddT - return time.Date(int(year), time.Month(month), int(day), 0, 0, 0, 0, time.UTC), nil - } - if val[10] != 't' && val[10] != 'T' { - return invalidTimestamp(val) - } - - if len(val) < 17 { - return invalidTimestamp(val) - } - if val[16] != ':' { - return time.Parse("2006-01-02T15:04Z07:00", val) - } - - if len(val) > 19 && val[19] == '.' { - i := 20 - for i < len(val) && isDigit(int(val[i])) { - i++ - } - - if i >= 29 { - // Too much precision for a go Time. - // TODO: We should probably round instead of truncating? Ah well. - return time.Parse(time.RFC3339Nano, val[:29]+val[i:]) - } - } - - return time.Parse(time.RFC3339Nano, val) -} - -func invalidTimestamp(val string) (time.Time, error) { - return time.Time{}, fmt.Errorf("ion: invalid timestamp: %v", val) -} diff --git a/textutils_test.go b/textutils_test.go deleted file mode 100644 index 67144a82..00000000 --- a/textutils_test.go +++ /dev/null @@ -1,164 +0,0 @@ -package ion - -import ( - "strings" - "testing" - "time" -) - -func TestParseTimestamp(t *testing.T) { - test := func(str string, eval string) { - t.Run(str, func(t *testing.T) { - val, err := parseTimestamp(str) - if err != nil { - t.Fatal(err) - } - - et, err := time.Parse(time.RFC3339Nano, eval) - if err != nil { - t.Fatal(err) - } - - if !val.Equal(et) { - t.Errorf("expected %v, got %v", eval, val) - } - }) - } - - test("1234T", "1234-01-01T00:00:00Z") - test("1234-05T", "1234-05-01T00:00:00Z") - test("1234-05-06", "1234-05-06T00:00:00Z") - test("1234-05-06T", "1234-05-06T00:00:00Z") - test("1234-05-06T07:08Z", "1234-05-06T07:08:00Z") - test("1234-05-06T07:08:09Z", "1234-05-06T07:08:09Z") - test("1234-05-06T07:08:09.100Z", "1234-05-06T07:08:09.100Z") - test("1234-05-06T07:08:09.100100Z", "1234-05-06T07:08:09.100100Z") - - test("1234-05-06T07:08+09:10", "1234-05-06T07:08:00+09:10") - test("1234-05-06T07:08:09-10:11", "1234-05-06T07:08:09-10:11") -} - -func TestWriteSymbol(t *testing.T) { - test := func(sym, expected string) { - t.Run(expected, func(t *testing.T) { - buf := strings.Builder{} - if err := writeSymbol(sym, &buf); err != nil { - t.Fatal(err) - } - actual := buf.String() - if actual != expected { - t.Errorf("expected \"%v\", got \"%v\"", expected, actual) - } - }) - } - - test("", "''") - test("null", "'null'") - test("null.null", "'null.null'") - - test("basic", "basic") - test("_basic_", "_basic_") - test("$basic$", "$basic$") - test("$123", "$123") - - test("123", "'123'") - test("abc'def", "'abc\\'def'") - test("abc\"def", "'abc\"def'") -} - -func TestSymbolNeedsQuoting(t *testing.T) { - test := func(sym string, expected bool) { - t.Run(sym, func(t *testing.T) { - actual := symbolNeedsQuoting(sym) - if actual != expected { - t.Errorf("expected %v, got %v", expected, actual) - } - }) - } - - test("", true) - test("null", true) - test("true", true) - test("false", true) - test("nan", true) - - test("basic", false) - test("_basic_", false) - test("basic$123", false) - test("$", false) - test("$basic", false) - test("$123", false) - - test("123", true) - test("abc.def", true) - test("abc,def", true) - test("abc:def", true) - test("abc{def", true) - test("abc}def", true) - test("abc[def", true) - test("abc]def", true) - test("abc'def", true) - test("abc\"def", true) -} - -func TestIsSymbolRef(t *testing.T) { - test := func(sym string, expected bool) { - t.Run(sym, func(t *testing.T) { - actual := isSymbolRef(sym) - if actual != expected { - t.Errorf("expected %v, got %v", expected, actual) - } - }) - } - - test("", false) - test("1", false) - test("a", false) - test("$", false) - test("$1", true) - test("$1234567890", true) - test("$a", false) - test("$1234a567890", false) -} - -func TestWriteEscapedSymbol(t *testing.T) { - test := func(sym, expected string) { - t.Run(expected, func(t *testing.T) { - buf := strings.Builder{} - if err := writeEscapedSymbol(sym, &buf); err != nil { - t.Fatal(err) - } - actual := buf.String() - if actual != expected { - t.Errorf("bad encoding of \"%v\": \"%v\"", - expected, actual) - } - }) - } - - test("basic", "basic") - test("\"basic\"", "\"basic\"") - test("o'clock", "o\\'clock") - test("c:\\", "c:\\\\") -} - -func TestWriteEscapedChar(t *testing.T) { - test := func(c byte, expected string) { - t.Run(expected, func(t *testing.T) { - buf := strings.Builder{} - if err := writeEscapedChar(c, &buf); err != nil { - t.Fatal(err) - } - actual := buf.String() - if actual != expected { - t.Errorf("bad encoding of '%v': \"%v\"", - expected, actual) - } - }) - } - - test(0, "\\0") - test('\n', "\\n") - test(1, "\\x01") - test('\xFF', "\\xFF") -} diff --git a/textwriter.go b/textwriter.go deleted file mode 100644 index 2ae82c1f..00000000 --- a/textwriter.go +++ /dev/null @@ -1,359 +0,0 @@ -package ion - -import ( - "encoding/base64" - "fmt" - "io" - "math/big" - "time" -) - -// TextWriterOpts defines a set of bit flag options for text writers. -type TextWriterOpts uint8 - -const ( - // TextWriterQuietFinish disables emiting a newline in Finish(). Convenient if you - // know you're only emiting one datagram; dangerous if there's a chance you're going - // to emit another datagram using the same Writer. - TextWriterQuietFinish TextWriterOpts = 1 -) - -// textWriter is a writer that writes human-readable text -type textWriter struct { - writer - needsSeparator bool - opts TextWriterOpts -} - -// NewTextWriter returns a new text writer. -func NewTextWriter(out io.Writer) Writer { - return NewTextWriterOpts(out, 0) -} - -// NewTextWriterOpts returns a new text writer with the given options. -func NewTextWriterOpts(out io.Writer, opts TextWriterOpts) Writer { - return &textWriter{ - writer: writer{ - out: out, - }, - opts: opts, - } -} - -// WriteNull writes an untyped null. -func (w *textWriter) WriteNull() error { - return w.writeValue("Writer.WriteNull", textNulls[NoType]) -} - -// WriteNullType writes a typed null. -func (w *textWriter) WriteNullType(t Type) error { - return w.writeValue("Writer.WriteNullType", textNulls[t]) -} - -// WriteBool writes a boolean value. -func (w *textWriter) WriteBool(val bool) error { - str := "false" - if val { - str = "true" - } - return w.writeValue("Writer.WriteBool", str) -} - -// WriteInt writes an integer value. -func (w *textWriter) WriteInt(val int64) error { - return w.writeValue("Writer.WriteInt", fmt.Sprintf("%d", val)) -} - -// WriteUint writes an unsigned integer value. -func (w *textWriter) WriteUint(val uint64) error { - return w.writeValue("Writer.WriteUint", fmt.Sprintf("%d", val)) -} - -// WriteBigInt writes a (big) integer value. -func (w *textWriter) WriteBigInt(val *big.Int) error { - return w.writeValue("Writer.WriteBigInt", val.String()) -} - -// WriteFloat writes a floating-point value. -func (w *textWriter) WriteFloat(val float64) error { - return w.writeValue("Writer.WriteFloat", formatFloat(val)) -} - -// WriteDecimal writes an arbitrary-precision decimal value. -func (w *textWriter) WriteDecimal(val *Decimal) error { - return w.writeValue("Writer.WriteDecimal", val.String()) -} - -// WriteTimestamp writes a timestamp. -func (w *textWriter) WriteTimestamp(val time.Time) error { - return w.writeValue("Writer.WriteTimestamp", val.Format(time.RFC3339Nano)) -} - -// WriteSymbol writes a symbol. -func (w *textWriter) WriteSymbol(val string) error { - if w.err != nil { - return w.err - } - if w.err = w.beginValue("Writer.WriteSymbol"); w.err != nil { - return w.err - } - - if w.err = writeSymbol(val, w.out); w.err != nil { - return w.err - } - - w.endValue() - return nil -} - -// WriteString writes a string. -func (w *textWriter) WriteString(val string) error { - if w.err != nil { - return w.err - } - if w.err = w.beginValue("Writer.WriteString"); w.err != nil { - return w.err - } - - if w.err = writeRawChar('"', w.out); w.err != nil { - return w.err - } - if w.err = writeEscapedString(val, w.out); w.err != nil { - return w.err - } - if w.err = writeRawChar('"', w.out); w.err != nil { - return w.err - } - - w.endValue() - return nil -} - -// WriteClob writes a clob. -func (w *textWriter) WriteClob(val []byte) error { - if w.err != nil { - return w.err - } - if w.err = w.beginValue("Writer.WriteBlob"); w.err != nil { - return w.err - } - - if w.err = writeRawString("{{\"", w.out); w.err != nil { - return w.err - } - for _, c := range val { - if c < 32 || c == '\\' || c == '"' || c > 0x7F { - if err := writeEscapedChar(c, w.out); err != nil { - return err - } - } else { - if err := writeRawChar(c, w.out); err != nil { - return err - } - } - } - if w.err = writeRawString("\"}}", w.out); w.err != nil { - return w.err - } - - w.endValue() - return nil -} - -// WriteBlob writes a blob. -func (w *textWriter) WriteBlob(val []byte) error { - if w.err != nil { - return w.err - } - if w.err = w.beginValue("Writer.WriteBlob"); w.err != nil { - return w.err - } - - if w.err = writeRawString("{{", w.out); w.err != nil { - return w.err - } - - enc := base64.NewEncoder(base64.StdEncoding, w.out) - enc.Write(val) - if w.err = enc.Close(); w.err != nil { - return w.err - } - - if w.err = writeRawString("}}", w.out); w.err != nil { - return w.err - } - - w.endValue() - return nil -} - -// BeginList begins writing a list. -func (w *textWriter) BeginList() error { - if w.err == nil { - w.err = w.begin("Writer.BeginList", ctxInList, '[') - } - return w.err -} - -// EndList finishes writing a list. -func (w *textWriter) EndList() error { - if w.err == nil { - w.err = w.end("Writer.EndList", ctxInList, ']') - } - return w.err -} - -// BeginSexp begins writing an s-expression. -func (w *textWriter) BeginSexp() error { - if w.err == nil { - w.err = w.begin("Writer.BeginSexp", ctxInSexp, '(') - } - return w.err -} - -// EndSexp finishes writing an s-expression. -func (w *textWriter) EndSexp() error { - if w.err == nil { - w.err = w.end("Writer.EndSexp", ctxInSexp, ')') - } - return w.err -} - -// BeginStruct begins writing a struct. -func (w *textWriter) BeginStruct() error { - if w.err == nil { - w.err = w.begin("Writer.BeginStruct", ctxInStruct, '{') - } - return w.err -} - -// EndStruct finishes writing a struct. -func (w *textWriter) EndStruct() error { - if w.err == nil { - w.err = w.end("Writer.EndStruct", ctxInStruct, '}') - } - return w.err -} - -// Finish finishes writing the current datagram. -func (w *textWriter) Finish() error { - if w.err != nil { - return w.err - } - if w.ctx.peek() != ctxAtTopLevel { - return &UsageError{"Writer.Finish", "not at top level"} - } - - if w.opts&TextWriterQuietFinish == 0 { - if w.err = writeRawChar('\n', w.out); w.err != nil { - return w.err - } - w.needsSeparator = false - } - - w.clear() - return nil -} - -// writeValue writes a stringified value to the output stream. -func (w *textWriter) writeValue(api string, val string) error { - if w.err != nil { - return w.err - } - if w.err = w.beginValue(api); w.err != nil { - return w.err - } - - if w.err = writeRawString(val, w.out); w.err != nil { - return w.err - } - - w.endValue() - return nil -} - -// beginValue begins the process of writing a value, by writing out -// a separator (if needed), field name (if in a struct), and type -// annotations (if any). -func (w *textWriter) beginValue(api string) error { - if w.needsSeparator { - var sep byte - switch w.ctx.peek() { - case ctxInStruct, ctxInList: - sep = ',' - case ctxInSexp: - sep = ' ' - default: - sep = '\n' - } - - if err := writeRawChar(sep, w.out); err != nil { - return err - } - } - - if w.inStruct() { - if w.fieldName == "" { - return &UsageError{api, "field name not set"} - } - name := w.fieldName - w.fieldName = "" - - if err := writeSymbol(name, w.out); err != nil { - return err - } - if err := writeRawChar(':', w.out); err != nil { - return err - } - } - - if len(w.annotations) > 0 { - as := w.annotations - w.annotations = nil - - for _, a := range as { - if err := writeSymbol(a, w.out); err != nil { - return err - } - if err := writeRawString("::", w.out); err != nil { - return err - } - } - } - - return nil -} - -// endValue finishes the process of writing a value. -func (w *textWriter) endValue() { - w.needsSeparator = true -} - -// begin starts writing a container of the given type. -func (w *textWriter) begin(api string, t ctx, c byte) error { - if err := w.beginValue(api); err != nil { - return err - } - - w.ctx.push(t) - w.needsSeparator = false - - return writeRawChar(c, w.out) -} - -// end finishes writing a container of the given type -func (w *textWriter) end(api string, t ctx, c byte) error { - if w.ctx.peek() != t { - return &UsageError{api, "not in that kind of container"} - } - - if err := writeRawChar(c, w.out); err != nil { - return err - } - - w.clear() - w.ctx.pop() - w.endValue() - - return nil -} diff --git a/textwriter_test.go b/textwriter_test.go deleted file mode 100644 index b3124c32..00000000 --- a/textwriter_test.go +++ /dev/null @@ -1,351 +0,0 @@ -package ion - -import ( - "math" - "math/big" - "strings" - "testing" - "time" -) - -func TestWriteTextTopLevelFieldName(t *testing.T) { - writeText(func(w Writer) { - if err := w.FieldName("foo"); err == nil { - t.Error("expected an error") - } - }) -} - -func TestWriteTextEmptyStruct(t *testing.T) { - testTextWriter(t, "{}", func(w Writer) { - if err := w.BeginStruct(); err != nil { - t.Fatal(err) - } - - if err := w.EndStruct(); err != nil { - t.Fatal(err) - } - - if err := w.EndStruct(); err == nil { - t.Fatal("no error from ending struct too many times") - } - }) -} - -func TestWriteTextAnnotatedStruct(t *testing.T) { - testTextWriter(t, "foo::$bar::'.baz'::{}", func(w Writer) { - w.Annotation("foo") - w.Annotation("$bar") - w.Annotation(".baz") - w.BeginStruct() - err := w.EndStruct() - - if err != nil { - t.Fatal(err) - } - }) -} - -func TestWriteTextNestedStruct(t *testing.T) { - testTextWriter(t, "{foo:'true'::{},'null':{}}", func(w Writer) { - w.BeginStruct() - - w.FieldName("foo") - w.Annotation("true") - w.BeginStruct() - w.EndStruct() - - w.FieldName("null") - w.BeginStruct() - w.EndStruct() - - w.EndStruct() - }) -} - -func TestWriteTextEmptyList(t *testing.T) { - testTextWriter(t, "[]", func(w Writer) { - if err := w.BeginList(); err != nil { - t.Fatal(err) - } - - if err := w.EndList(); err != nil { - t.Fatal(err) - } - - if err := w.EndList(); err == nil { - t.Error("no error calling endlist at top level") - } - }) -} - -func TestWriteTextNestedLists(t *testing.T) { - testTextWriter(t, "[{},foo::{},'null'::[]]", func(w Writer) { - w.BeginList() - - w.BeginStruct() - w.EndStruct() - - w.Annotation("foo") - w.BeginStruct() - w.EndStruct() - - w.Annotation("null") - w.BeginList() - w.EndList() - - w.EndList() - }) -} - -func TestWriteTextSexps(t *testing.T) { - testTextWriter(t, "()\n(())\n(() ())", func(w Writer) { - w.BeginSexp() - w.EndSexp() - - w.BeginSexp() - w.BeginSexp() - w.EndSexp() - w.EndSexp() - - w.BeginSexp() - w.BeginSexp() - w.EndSexp() - w.BeginSexp() - w.EndSexp() - w.EndSexp() - }) -} - -func TestWriteTextNulls(t *testing.T) { - expected := "[null,foo::null.null,null.bool,null.int,null.float,null.decimal," + - "null.timestamp,null.symbol,null.string,null.clob,null.blob," + - "null.list,'null'::null.sexp,null.struct]" - - testTextWriter(t, expected, func(w Writer) { - w.BeginList() - - w.WriteNull() - w.Annotation("foo") - w.WriteNullType(NullType) - w.WriteNullType(BoolType) - w.WriteNullType(IntType) - w.WriteNullType(FloatType) - w.WriteNullType(DecimalType) - w.WriteNullType(TimestampType) - w.WriteNullType(SymbolType) - w.WriteNullType(StringType) - w.WriteNullType(ClobType) - w.WriteNullType(BlobType) - w.WriteNullType(ListType) - w.Annotation("null") - w.WriteNullType(SexpType) - w.WriteNullType(StructType) - - w.EndList() - }) -} - -func TestWriteTextBool(t *testing.T) { - expected := "true\n(false '123'::true)\n'false'::false" - testTextWriter(t, expected, func(w Writer) { - w.WriteBool(true) - - w.BeginSexp() - - w.WriteBool(false) - w.Annotation("123") - w.WriteBool(true) - - w.EndSexp() - - w.Annotation("false") - w.WriteBool(false) - }) -} - -func TestWriteTextInt(t *testing.T) { - expected := "(zero::0 1 -1 (9223372036854775807 -9223372036854775808))" - testTextWriter(t, expected, func(w Writer) { - w.BeginSexp() - - w.Annotation("zero") - w.WriteInt(0) - w.WriteInt(1) - w.WriteInt(-1) - - w.BeginSexp() - w.WriteInt(math.MaxInt64) - w.WriteInt(math.MinInt64) - w.EndSexp() - - w.EndSexp() - }) -} - -func TestWriteTextBigInt(t *testing.T) { - expected := "[0,big::18446744073709551616]" - testTextWriter(t, expected, func(w Writer) { - w.BeginList() - - w.WriteBigInt(big.NewInt(0)) - - var val, max, one big.Int - max.SetUint64(math.MaxUint64) - one.SetInt64(1) - val.Add(&max, &one) - - w.Annotation("big") - w.WriteBigInt(&val) - - w.EndList() - }) -} - -func TestWriteTextFloat(t *testing.T) { - expected := "{z:0e+0,nz:-0e+0,s:1.234e+1,l:1.234e-55,n:nan,i:+inf,ni:-inf}" - testTextWriter(t, expected, func(w Writer) { - w.BeginStruct() - - w.FieldName("z") - w.WriteFloat(0.0) - w.FieldName("nz") - w.WriteFloat(-1.0 / math.Inf(1)) - - w.FieldName("s") - w.WriteFloat(12.34) - w.FieldName("l") - w.WriteFloat(12.34e-56) - - w.FieldName("n") - w.WriteFloat(math.NaN()) - w.FieldName("i") - w.WriteFloat(math.Inf(1)) - w.FieldName("ni") - w.WriteFloat(math.Inf(-1)) - - w.EndStruct() - }) -} - -func TestWriteTextDecimal(t *testing.T) { - expected := "0.\n-1.23d-98" - testTextWriter(t, expected, func(w Writer) { - w.WriteDecimal(MustParseDecimal("0")) - w.WriteDecimal(MustParseDecimal("-123d-100")) - }) -} - -func TestWriteTextTimestamp(t *testing.T) { - expected := "1970-01-01T00:00:00.001Z\n1970-01-01T01:23:00+01:23" - testTextWriter(t, expected, func(w Writer) { - w.WriteTimestamp(time.Unix(0, 1000000).In(time.UTC)) - w.WriteTimestamp(time.Unix(0, 0).In(time.FixedZone("wtf", 4980))) - }) -} - -func TestWriteTextSymbol(t *testing.T) { - expected := "{foo:bar,empty:'','null':'null',f:a::b::u::'lo🇺🇸',$123:$456}" - testTextWriter(t, expected, func(w Writer) { - w.BeginStruct() - - w.FieldName("foo") - w.WriteSymbol("bar") - w.FieldName("empty") - w.WriteSymbol("") - w.FieldName("null") - w.WriteSymbol("null") - - w.FieldName("f") - w.Annotation("a") - w.Annotation("b") - w.Annotation("u") - w.WriteSymbol("lo🇺🇸") - - w.FieldName("$123") - w.WriteSymbol("$456") - - w.EndStruct() - }) -} - -func TestWriteTextString(t *testing.T) { - expected := `("hello" "" ("\\\"\n\"\\" zany::"🤪"))` - testTextWriter(t, expected, func(w Writer) { - w.BeginSexp() - w.WriteString("hello") - w.WriteString("") - - w.BeginSexp() - w.WriteString("\\\"\n\"\\") - w.Annotation("zany") - w.WriteString("🤪") - w.EndSexp() - - w.EndSexp() - }) -} - -func TestWriteTextBlob(t *testing.T) { - expected := "{{AAEC/f7/}}\n{{SGVsbG8gV29ybGQ=}}\nempty::{{}}" - testTextWriter(t, expected, func(w Writer) { - w.WriteBlob([]byte{0, 1, 2, 0xFD, 0xFE, 0xFF}) - w.WriteBlob([]byte("Hello World")) - w.Annotation("empty") - w.WriteBlob(nil) - }) -} - -func TestWriteTextClob(t *testing.T) { - expected := "{hello:{{\"world\"}},bits:{{\"\\0\\x01\\xFE\\xFF\"}}}" - testTextWriter(t, expected, func(w Writer) { - w.BeginStruct() - w.FieldName("hello") - w.WriteClob([]byte("world")) - w.FieldName("bits") - w.WriteClob([]byte{0, 1, 0xFE, 0xFF}) - w.EndStruct() - }) -} - -func TestWriteTextFinish(t *testing.T) { - expected := "1\nfoo\n\"bar\"\n{}\n" - testTextWriter(t, expected, func(w Writer) { - w.WriteInt(1) - w.WriteSymbol("foo") - w.WriteString("bar") - w.BeginStruct() - w.EndStruct() - if err := w.Finish(); err != nil { - t.Fatal(err) - } - }) -} - -func TestWriteTextBadFinish(t *testing.T) { - buf := strings.Builder{} - w := NewTextWriter(&buf) - - w.BeginStruct() - err := w.Finish() - - if err == nil { - t.Error("should not be able to finish in the middle of a struct") - } -} - -func testTextWriter(t *testing.T, expected string, f func(Writer)) { - actual := writeText(f) - if actual != expected { - t.Errorf("expected: %v, actual: %v", expected, actual) - } -} - -func writeText(f func(Writer)) string { - buf := strings.Builder{} - w := NewTextWriter(&buf) - - f(w) - - return buf.String() -} diff --git a/tokenizer.go b/tokenizer.go deleted file mode 100644 index 3413de7f..00000000 --- a/tokenizer.go +++ /dev/null @@ -1,1265 +0,0 @@ -package ion - -import ( - "bufio" - "bytes" - "fmt" - "io" - "strings" -) - -type token int - -const ( - tokenError token = iota - - tokenEOF // End of input - - tokenNumber // Haven't seen enough to know which, yet - tokenBinary // 0b[01]+ - tokenHex // 0x[0-9a-fA-F]+ - tokenFloatInf // +inf - tokenFloatMinusInf // -inf - tokenTimestamp // 2001-01-01T00:00:00.000Z - - tokenSymbol // [a-zA-Z_]+ - tokenSymbolQuoted // '[^']+' - tokenSymbolOperator // +-/* - - tokenString // "[^"]+" - tokenLongString // '''[^']+''' - - tokenDot // . - tokenComma // , - tokenColon // : - tokenDoubleColon // :: - - tokenOpenParen // ( - tokenCloseParen // ) - tokenOpenBrace // { - tokenCloseBrace // } - tokenOpenBracket // [ - tokenCloseBracket // ] - tokenOpenDoubleBrace // {{ - tokenCloseDoubleBrace // }} -) - -func (t token) String() string { - switch t { - case tokenError: - return "" - case tokenEOF: - return "" - case tokenNumber: - return "" - case tokenBinary: - return "" - case tokenHex: - return "" - case tokenFloatInf: - return "+inf" - case tokenFloatMinusInf: - return "-inf" - case tokenTimestamp: - return "" - case tokenSymbol: - return "" - case tokenSymbolQuoted: - return "" - case tokenSymbolOperator: - return "" - - case tokenString: - return "" - case tokenLongString: - return "" - - case tokenDot: - return "." - case tokenComma: - return "," - case tokenColon: - return ":" - case tokenDoubleColon: - return "::" - - case tokenOpenParen: - return "(" - case tokenCloseParen: - return ")" - - case tokenOpenBrace: - return "{" - case tokenCloseBrace: - return "}" - - case tokenOpenBracket: - return "[" - case tokenCloseBracket: - return "]" - - case tokenOpenDoubleBrace: - return "{{" - case tokenCloseDoubleBrace: - return "}}" - - default: - return "" - } -} - -type tokenizer struct { - in *bufio.Reader - buffer []int - - token token - unfinished bool - pos uint64 -} - -func tokenizeString(in string) *tokenizer { - return tokenizeBytes([]byte(in)) -} - -func tokenizeBytes(in []byte) *tokenizer { - return tokenize(bytes.NewReader(in)) -} - -func tokenize(in io.Reader) *tokenizer { - return &tokenizer{ - in: bufio.NewReader(in), - } -} - -// Token returns the type of the current token. -func (t *tokenizer) Token() token { - return t.token -} - -func (t *tokenizer) Pos() uint64 { - return t.pos -} - -// Next advances to the next token in the input stream. -func (t *tokenizer) Next() error { - var c int - var err error - - if t.unfinished { - c, err = t.skipValue() - } else { - c, _, err = t.skipWhitespace() - } - - if err != nil { - return err - } - - switch { - case c == -1: - return t.ok(tokenEOF, true) - - case c == ':': - c2, err := t.peek() - if err != nil { - return err - } - if c2 == ':' { - t.read() - return t.ok(tokenDoubleColon, false) - } - return t.ok(tokenColon, false) - - case c == '{': - c2, err := t.peek() - if err != nil { - return err - } - if c2 == '{' { - t.read() - return t.ok(tokenOpenDoubleBrace, true) - } - return t.ok(tokenOpenBrace, true) - - case c == '}': - return t.ok(tokenCloseBrace, false) - - case c == '[': - return t.ok(tokenOpenBracket, true) - - case c == ']': - return t.ok(tokenCloseBracket, false) - - case c == '(': - return t.ok(tokenOpenParen, true) - - case c == ')': - return t.ok(tokenCloseParen, false) - - case c == ',': - return t.ok(tokenComma, false) - - case c == '.': - c2, err := t.peek() - if err != nil { - return err - } - if isOperatorChar(c2) { - t.unread(c) - return t.ok(tokenSymbolOperator, true) - } - return t.ok(tokenDot, false) - - case c == '\'': - ok, err := t.IsTripleQuote() - if err != nil { - return err - } - if ok { - return t.ok(tokenLongString, true) - } - return t.ok(tokenSymbolQuoted, true) - - case c == '+': - ok, err := t.isInf(c) - if err != nil { - return err - } - if ok { - return t.ok(tokenFloatInf, false) - } - t.unread(c) - return t.ok(tokenSymbolOperator, true) - - case c == '-': - c2, err := t.peek() - if err != nil { - return err - } - - if isDigit(c2) { - t.read() - tt, err := t.scanForNumericType(c2) - if err != nil { - return err - } - if tt == tokenTimestamp { - // can't have negative timestamps. - return t.invalidChar(c2) - } - t.unread(c2) - t.unread(c) - return t.ok(tt, true) - } - - ok, err := t.isInf(c) - if err != nil { - return err - } - if ok { - return t.ok(tokenFloatMinusInf, false) - } - - t.unread(c) - return t.ok(tokenSymbolOperator, true) - - case isOperatorChar(c): - t.unread(c) - return t.ok(tokenSymbolOperator, true) - - case c == '"': - return t.ok(tokenString, true) - - case isIdentifierStart(c): - t.unread(c) - return t.ok(tokenSymbol, true) - - case isDigit(c): - tt, err := t.scanForNumericType(c) - if err != nil { - return err - } - - t.unread(c) - return t.ok(tt, true) - - default: - return t.invalidChar(c) - } -} - -func (t *tokenizer) ok(tok token, more bool) error { - t.token = tok - t.unfinished = more - return nil -} - -// SetFinished marks the current token finished (indicating that the caller has -// chosen to step in to a list, sexp, or struct and Next should not skip over its -// contents in search of the next token). -func (t *tokenizer) SetFinished() { - t.unfinished = false -} - -// FinishValue skips to the end of the current value if (and only if) -// we're currently in the middle of reading it. -func (t *tokenizer) FinishValue() (bool, error) { - if !t.unfinished { - return false, nil - } - - c, err := t.skipValue() - if err != nil { - return true, err - } - - t.unread(c) - t.unfinished = false - return true, nil -} - -// ReadValue reads the value of a token of the given type. -func (t *tokenizer) ReadValue(tok token) (string, error) { - var str string - var err error - - switch tok { - case tokenSymbol: - str, err = t.readSymbol() - case tokenSymbolQuoted: - str, err = t.readQuotedSymbol() - case tokenSymbolOperator, tokenDot: - str, err = t.readOperator() - case tokenString: - str, err = t.readString() - case tokenLongString: - str, err = t.readLongString() - case tokenBinary: - str, err = t.readBinary() - case tokenHex: - str, err = t.readHex() - case tokenTimestamp: - str, err = t.readTimestamp() - default: - panic(fmt.Sprintf("unsupported token type %v", tok)) - } - - if err != nil { - return "", err - } - - t.unfinished = false - return str, nil -} - -// ReadNumber reads a number and determines the type. -func (t *tokenizer) ReadNumber() (string, Type, error) { - w := strings.Builder{} - - c, err := t.read() - if err != nil { - return "", NoType, err - } - - if c == '-' { - w.WriteByte('-') - c, err = t.read() - if err != nil { - return "", NoType, err - } - } - - first := c - oldlen := w.Len() - - c, err = t.readDigits(c, &w) - if err != nil { - return "", NoType, err - } - - if first == '0' { - if w.Len()-oldlen > 1 { - return "", NoType, &SyntaxError{"invalid leading zeroes", t.pos - 1} - } - } - - tt := IntType - - if c == '.' { - w.WriteByte('.') - tt = DecimalType - - if c, err = t.read(); err != nil { - return "", NoType, err - } - if c, err = t.readDigits(c, &w); err != nil { - return "", NoType, err - } - } - - switch c { - case 'e', 'E': - tt = FloatType - - w.WriteByte(byte(c)) - if c, err = t.readExponent(&w); err != nil { - return "", NoType, err - } - - case 'd', 'D': - tt = DecimalType - - w.WriteByte(byte(c)) - if c, err = t.readExponent(&w); err != nil { - return "", NoType, err - } - } - - ok, err := t.isStopChar(c) - if err != nil { - return "", NoType, err - } - if !ok { - return "", NoType, t.invalidChar(c) - } - t.unread(c) - - return w.String(), tt, nil -} - -func (t *tokenizer) readExponent(w io.ByteWriter) (int, error) { - c, err := t.read() - if err != nil { - return 0, err - } - - if c == '+' || c == '-' { - w.WriteByte(byte(c)) - if c, err = t.read(); err != nil { - return 0, err - } - } - - return t.readDigits(c, w) -} - -func (t *tokenizer) readDigits(c int, w io.ByteWriter) (int, error) { - if !isDigit(c) { - return c, nil - } - w.WriteByte(byte(c)) - - return t.readRadixDigits(isDigit, w) -} - -// ReadSymbol reads an unquoted symbol value. -func (t *tokenizer) readSymbol() (string, error) { - ret := strings.Builder{} - - c, err := t.peek() - if err != nil { - return "", err - } - - for isIdentifierPart(c) { - ret.WriteByte(byte(c)) - t.read() - c, err = t.peek() - if err != nil { - return "", err - } - } - - return ret.String(), nil -} - -// ReadQuotedSymbol reads a quoted symbol. -func (t *tokenizer) readQuotedSymbol() (string, error) { - ret := strings.Builder{} - - for { - c, err := t.read() - if err != nil { - return "", err - } - - switch c { - case -1, '\n': - return "", t.invalidChar(c) - - case '\'': - return ret.String(), nil - - case '\\': - c, err = t.peek() - if err != nil { - return "", err - } - - if c == '\n' { - t.read() - continue - } - - r, err := t.readEscapedChar(false) - if err != nil { - return "", err - } - ret.WriteRune(r) - - default: - ret.WriteByte(byte(c)) - } - } -} - -func (t *tokenizer) readOperator() (string, error) { - ret := strings.Builder{} - - c, err := t.peek() - if err != nil { - return "", err - } - - for isOperatorChar(c) { - ret.WriteByte(byte(c)) - t.read() - c, err = t.peek() - if err != nil { - return "", err - } - } - - return ret.String(), nil -} - -// ReadString reads a quoted string. -func (t *tokenizer) readString() (string, error) { - ret := strings.Builder{} - - for { - c, err := t.read() - if err != nil { - return "", err - } - - switch c { - case -1, '\n': - return "", t.invalidChar(c) - - case '"': - return ret.String(), nil - - case '\\': - c, err = t.peek() - if err != nil { - return "", err - } - - if c == '\n' { - t.read() - continue - } - - r, err := t.readEscapedChar(false) - if err != nil { - return "", err - } - ret.WriteRune(r) - - default: - ret.WriteByte(byte(c)) - } - } -} - -// ReadLongString reads a triple-quoted string. -func (t *tokenizer) readLongString() (string, error) { - ret := strings.Builder{} - - for { - c, err := t.read() - if err != nil { - return "", err - } - - switch c { - case -1: - return "", t.invalidChar(c) - - case '\'': - ok, err := t.skipEndOfLongString(t.skipCommentsHandler) - if err != nil { - return "", err - } - if ok { - return ret.String(), nil - } - - case '\\': - c, err = t.peek() - if err != nil { - return "", err - } - - if c == '\n' { - t.read() - continue - } - - r, err := t.readEscapedChar(false) - if err != nil { - return "", err - } - ret.WriteRune(r) - - default: - ret.WriteByte(byte(c)) - } - } -} - -// ReadEscapedChar reads an escaped character. -func (t *tokenizer) readEscapedChar(clob bool) (rune, error) { - // We just read the '\', grab the next char. - c, err := t.read() - if err != nil { - return 0, err - } - - switch c { - case '0': - return '\x00', nil - case 'a': - return '\a', nil - case 'b': - return '\b', nil - case 't': - return '\t', nil - case 'n': - return '\n', nil - case 'f': - return '\f', nil - case 'r': - return '\r', nil - case 'v': - return '\v', nil - case '?': - return '?', nil - case '/': - return '/', nil - case '\'': - return '\'', nil - case '"': - return '"', nil - case '\\': - return '\\', nil - case 'U': - if clob { - return 0, t.invalidChar('U') - } - return t.readHexEscapeSeq(8) - case 'u': - return t.readHexEscapeSeq(4) - case 'x': - return t.readHexEscapeSeq(2) - } - - return 0, &SyntaxError{fmt.Sprintf("bad escape sequence '\\%c'", c), t.pos - 2} -} - -func (t *tokenizer) readHexEscapeSeq(len int) (rune, error) { - val := rune(0) - - for len > 0 { - c, err := t.read() - if err != nil { - return 0, err - } - - d, err := t.fromHex(c) - if err != nil { - return 0, err - } - - val = (val << 4) | rune(d) - len-- - } - - return val, nil -} - -func (t *tokenizer) fromHex(c int) (int, error) { - if c >= '0' && c <= '9' { - return c - '0', nil - } - if c >= 'a' && c <= 'f' { - return 10 + (c - 'a'), nil - } - if c >= 'A' && c <= 'F' { - return 10 + (c - 'A'), nil - } - return 0, t.invalidChar(c) -} - -func (t *tokenizer) readBinary() (string, error) { - isB := func(c int) bool { - return c == 'b' || c == 'B' - } - isDigit := func(c int) bool { - return c == '0' || c == '1' - } - return t.readRadix(isB, isDigit) -} - -func (t *tokenizer) readHex() (string, error) { - isX := func(c int) bool { - return c == 'x' || c == 'X' - } - return t.readRadix(isX, isHexDigit) -} - -func (t *tokenizer) readRadix(pok, dok matcher) (string, error) { - w := strings.Builder{} - - c, err := t.read() - if err != nil { - return "", err - } - - if c == '-' { - w.WriteByte('-') - c, err = t.read() - if err != nil { - return "", err - } - } - - if c != '0' { - return "", t.invalidChar(c) - } - w.WriteByte('0') - - c, err = t.read() - if err != nil { - return "", err - } - if !pok(c) { - return "", t.invalidChar(c) - } - w.WriteByte(byte(c)) - - c, err = t.readRadixDigits(dok, &w) - if err != nil { - return "", err - } - - ok, err := t.isStopChar(c) - if err != nil { - return "", err - } - if !ok { - return "", t.invalidChar(c) - } - t.unread(c) - - return w.String(), nil -} - -func (t *tokenizer) readRadixDigits(dok matcher, w io.ByteWriter) (int, error) { - var c int - var err error - - for { - c, err = t.read() - if err != nil { - return 0, err - } - if c == '_' { - continue - } - if !dok(c) { - return c, nil - } - w.WriteByte(byte(c)) - } -} - -func (t *tokenizer) readTimestamp() (string, error) { - w := strings.Builder{} - - c, err := t.readTimestampDigits(4, &w) - if err != nil { - return "", err - } - if c == 'T' { - // yyyyT - w.WriteByte('T') - return w.String(), nil - } - if c != '-' { - return "", t.invalidChar(c) - } - w.WriteByte('-') - - if c, err = t.readTimestampDigits(2, &w); err != nil { - return "", err - } - if c == 'T' { - // yyyy-mmT - w.WriteByte('T') - return w.String(), nil - } - if c != '-' { - return "", t.invalidChar(c) - } - w.WriteByte('-') - - if c, err = t.readTimestampDigits(2, &w); err != nil { - return "", err - } - if c != 'T' { - // yyyy-mm-dd - return t.readTimestampFinish(c, &w) - } - w.WriteByte('T') - - if c, err = t.read(); err != nil { - return "", err - } - if !isDigit(c) { - // yyyy-mm-ddT(+hh:mm)? - if c, err = t.readTimestampOffset(c, &w); err != nil { - return "", err - } - return t.readTimestampFinish(c, &w) - } - w.WriteByte(byte(c)) - - if c, err = t.readTimestampDigits(1, &w); err != nil { - return "", err - } - if c != ':' { - return "", t.invalidChar(c) - } - w.WriteByte(':') - - if c, err = t.readTimestampDigits(2, &w); err != nil { - return "", err - } - if c != ':' { - // yyyy-mm-ddThh:mmZ - if c, err = t.readTimestampOffsetOrZ(c, &w); err != nil { - return "", err - } - return t.readTimestampFinish(c, &w) - } - w.WriteByte(':') - - if c, err = t.readTimestampDigits(2, &w); err != nil { - return "", err - } - if c != '.' { - // yyyy-mm-ddThh:mm:ssZ - if c, err = t.readTimestampOffsetOrZ(c, &w); err != nil { - return "", err - } - return t.readTimestampFinish(c, &w) - } - w.WriteByte('.') - - // yyyy-mm-ddThh:mm:ss.ssssZ - if c, err = t.read(); err != nil { - return "", err - } - if isDigit(c) { - if c, err = t.readDigits(c, &w); err != nil { - return "", err - } - } - - if c, err = t.readTimestampOffsetOrZ(c, &w); err != nil { - return "", err - } - return t.readTimestampFinish(c, &w) -} - -func (t *tokenizer) readTimestampOffsetOrZ(c int, w io.ByteWriter) (int, error) { - if c == '-' || c == '+' { - return t.readTimestampOffset(c, w) - } - if c == 'z' || c == 'Z' { - w.WriteByte(byte(c)) - return t.read() - } - return 0, t.invalidChar(c) -} - -func (t *tokenizer) readTimestampOffset(c int, w io.ByteWriter) (int, error) { - if c != '-' && c != '+' { - return c, nil - } - w.WriteByte(byte(c)) - - c, err := t.readTimestampDigits(2, w) - if err != nil { - return 0, err - } - if c != ':' { - return 0, t.invalidChar(c) - } - w.WriteByte(':') - return t.readTimestampDigits(2, w) -} - -func (t *tokenizer) readTimestampDigits(n int, w io.ByteWriter) (int, error) { - for n > 0 { - c, err := t.read() - if err != nil { - return 0, err - } - if !isDigit(c) { - return 0, t.invalidChar(c) - } - w.WriteByte(byte(c)) - n-- - } - return t.read() -} - -func (t *tokenizer) readTimestampFinish(c int, w fmt.Stringer) (string, error) { - ok, err := t.isStopChar(c) - if err != nil { - return "", err - } - if !ok { - return "", t.invalidChar(c) - } - t.unread(c) - return w.String(), nil -} - -func (t *tokenizer) ReadBlob() (string, error) { - w := strings.Builder{} - - var ( - c int - err error - ) - - for { - if c, _, err = t.skipLobWhitespace(); err != nil { - return "", err - } - if c == -1 { - return "", t.invalidChar(c) - } - if c == '}' { - break - } - w.WriteByte(byte(c)) - } - - if c, err = t.read(); err != nil { - return "", err - } - if c != '}' { - return "", t.invalidChar(c) - } - - t.unfinished = false - return w.String(), nil -} - -func (t *tokenizer) ReadShortClob() (string, error) { - str, err := t.readString() - if err != nil { - return "", err - } - - c, _, err := t.skipLobWhitespace() - if err != nil { - return "", err - } - if c != '}' { - return "", t.invalidChar(c) - } - - if c, err = t.read(); err != nil { - return "", err - } - if c != '}' { - return "", t.invalidChar(c) - } - - t.unfinished = false - return str, nil -} - -func (t *tokenizer) ReadLongClob() (string, error) { - str, err := t.readLongString() - if err != nil { - return "", err - } - - c, _, err := t.skipLobWhitespace() - if err != nil { - return "", err - } - if c != '}' { - return "", t.invalidChar(c) - } - - if c, err = t.read(); err != nil { - return "", err - } - if c != '}' { - return "", t.invalidChar(c) - } - - t.unfinished = false - return str, nil -} - -// IsTripleQuote returns true if this is a triple-quote sequence ('''). -func (t *tokenizer) IsTripleQuote() (bool, error) { - // We've just read a '\'', check if the next two are too. - cs, err := t.peekN(2) - if err == io.EOF { - return false, nil - } - if err != nil { - return false, err - } - - if cs[0] == '\'' && cs[1] == '\'' { - t.skipN(2) - return true, nil - } - - return false, nil -} - -// IsInf returns true if the given character begins a '+inf' or -// '-inf' keyword. -func (t *tokenizer) isInf(c int) (bool, error) { - if c != '+' && c != '-' { - return false, nil - } - - cs, err := t.peekN(5) - if err != nil && err != io.EOF { - return false, err - } - - if len(cs) < 3 || cs[0] != 'i' || cs[1] != 'n' || cs[2] != 'f' { - // Definitely not +-inf. - return false, nil - } - - if len(cs) == 3 || isStopChar(cs[3]) { - // Cleanly-terminated +-inf. - t.skipN(3) - return true, nil - } - - if cs[3] == '/' && len(cs) > 4 && (cs[4] == '/' || cs[4] == '*') { - t.skipN(3) - // +-inf followed immediately by a comment works too. - return true, nil - } - - return false, nil -} - -// ScanForNumericType attempts to determine what type of number we -// have by peeking at a fininte number of characters. We can rule -// out binary (0b...), hex (0x...), and timestamps (....-) via this -// method. There are a couple other cases where we *could* distinguish, -// but it's unclear that it's worth it. -func (t *tokenizer) scanForNumericType(c int) (token, error) { - if !isDigit(c) { - panic("scanForNumericType with non-digit") - } - - cs, err := t.peekN(4) - if err != nil && err != io.EOF { - return tokenError, err - } - - if c == '0' && len(cs) > 0 { - switch { - case cs[0] == 'b' || cs[0] == 'B': - return tokenBinary, nil - - case cs[0] == 'x' || cs[0] == 'X': - return tokenHex, nil - } - } - - if len(cs) >= 4 { - if isDigit(cs[0]) && isDigit(cs[1]) && isDigit(cs[2]) { - if cs[3] == '-' || cs[3] == 'T' { - return tokenTimestamp, nil - } - } - } - - // Can't tell yet; wait until actually reading it to find out. - return tokenNumber, nil -} - -// Is this character a valid way to end a 'normal' (unquoted) value? -// Peeks in case of '/', so don't call it with a character you've -// peeked. -func (t *tokenizer) isStopChar(c int) (bool, error) { - if isStopChar(c) { - return true, nil - } - - if c == '/' { - c2, err := t.peek() - if err != nil { - return false, err - } - if c2 == '/' || c2 == '*' { - // Comment, also all done. - return true, nil - } - } - - return false, nil -} - -type matcher func(int) bool - -// Expect reads a byte of input and asserts that it matches some -// condition, returning an error if it does not. -func (t *tokenizer) expect(f matcher) error { - c, err := t.read() - if err != nil { - return err - } - if !f(c) { - return t.invalidChar(c) - } - return nil -} - -// InvalidChar returns an error complaining that the given character was -// unexpected. -func (t *tokenizer) invalidChar(c int) error { - if c == -1 { - return &UnexpectedEOFError{t.pos - 1} - } - return &UnexpectedRuneError{rune(c), t.pos - 1} -} - -// SkipN skips over the next n bytes of input. Presumably you've -// already peeked at them, and decided they're not worth keeping. -func (t *tokenizer) skipN(n int) error { - for i := 0; i < n; i++ { - c, err := t.read() - if err != nil { - return err - } - if c == -1 { - break - } - } - return nil -} - -// PeekN peeks at the next n bytes of input. Unlike read/peek, does -// NOT return -1 to indicate EOF. If it cannot peek N bytes ahead -// because of an EOF (or other error), it returns the bytes it was -// able to peek at along with the error. -func (t *tokenizer) peekN(n int) ([]int, error) { - var ret []int - var err error - - // Read ahead. - for i := 0; i < n; i++ { - var c int - c, err = t.read() - if err != nil { - break - } - if c == -1 { - err = io.EOF - break - } - ret = append(ret, c) - } - - // Put back the ones we got. - if err == io.EOF { - t.unread(-1) - } - for i := len(ret) - 1; i >= 0; i-- { - t.unread(ret[i]) - } - - return ret, err -} - -// Peek at the next byte of input without removing it. Other conditions -// from Read all apply. -func (t *tokenizer) peek() (int, error) { - if len(t.buffer) > 0 { - // Short-circuit and peek from the buffer. - return t.buffer[len(t.buffer)-1], nil - } - - c, err := t.read() - if err != nil { - return 0, err - } - - t.unread(c) - return c, nil -} - -// Read reads a byte of input from the underlying reader. EOF is -// returned as (-1, nil) rather than (0, io.EOF), because I find it -// easier to reason about that way. Newlines are normalized to '\n'. -func (t *tokenizer) read() (int, error) { - t.pos++ - if len(t.buffer) > 0 { - // We've already peeked ahead; read from our buffer. - c := t.buffer[len(t.buffer)-1] - t.buffer = t.buffer[:len(t.buffer)-1] - return c, nil - } - - c, err := t.in.ReadByte() - if err == io.EOF { - return -1, nil - } - if err != nil { - return 0, &IOError{err} - } - - // Normalize \r and \r\n to just \n. - if c == '\r' { - cs, err := t.in.Peek(1) - if err != nil && err != io.EOF { - // Not EOF, because we haven't dealt with the '\r' yet. - return 0, &IOError{err} - } - if len(cs) > 0 && cs[0] == '\n' { - // Skip over the '\n' as well. - t.in.ReadByte() - } - return '\n', nil - } - - return int(c), nil -} - -// Unread pushes a character (or -1) back into the input stream to -// be read again later. -func (t *tokenizer) unread(c int) { - t.pos-- - t.buffer = append(t.buffer, c) -} diff --git a/tokenizer_test.go b/tokenizer_test.go deleted file mode 100644 index f4af1584..00000000 --- a/tokenizer_test.go +++ /dev/null @@ -1,571 +0,0 @@ -package ion - -import ( - "io" - "testing" -) - -func TestNext(t *testing.T) { - tok := tokenizeString("foo::'foo':[] 123, {})") - - next := func(tt token) { - if err := tok.Next(); err != nil { - t.Fatal(err) - } - if tok.Token() != tt { - t.Fatalf("expected %v, got %v", tt, tok.Token()) - } - } - - next(tokenSymbol) - next(tokenDoubleColon) - next(tokenSymbolQuoted) - next(tokenColon) - next(tokenOpenBracket) - next(tokenNumber) - next(tokenComma) - next(tokenOpenBrace) -} - -func TestReadSymbol(t *testing.T) { - test := func(str string, expected string, next token) { - t.Run(str, func(t *testing.T) { - tok := tokenizeString(str) - if err := tok.Next(); err != nil { - t.Fatal(err) - } - - if tok.Token() != tokenSymbol { - t.Fatal("not a symbol") - } - - actual, err := tok.readSymbol() - if err != nil { - t.Fatal(err) - } - - if actual != expected { - t.Errorf("expected '%v', got '%v'", expected, actual) - } - - if err := tok.Next(); err != nil { - t.Fatal(err) - } - if tok.Token() != next { - t.Errorf("expected next=%v, got next=%v", next, tok.Token()) - } - }) - } - - test("a", "a", tokenEOF) - test("abc", "abc", tokenEOF) - test("null +inf", "null", tokenFloatInf) - test("false,", "false", tokenComma) - test("nan]", "nan", tokenCloseBracket) -} - -func TestReadSymbols(t *testing.T) { - tok := tokenizeString("foo bar baz beep boop null") - expected := []string{"foo", "bar", "baz", "beep", "boop", "null"} - - for i := 0; i < len(expected); i++ { - if err := tok.Next(); err != nil { - t.Fatal(err) - } - if tok.Token() != tokenSymbol { - t.Fatalf("expected %v, got %v", tokenSymbol, tok.Token()) - } - - val, err := tok.readSymbol() - if err != nil { - t.Fatal(err) - } - - if val != expected[i] { - t.Errorf("expected %v, got %v", expected[i], val) - } - } -} - -func TestReadQuotedSymbol(t *testing.T) { - test := func(str string, expected string, next int) { - t.Run(str, func(t *testing.T) { - tok := tokenizeString(str) - if err := tok.Next(); err != nil { - t.Fatal(err) - } - - if tok.Token() != tokenSymbolQuoted { - t.Fatal("not a quoted symbol") - } - - actual, err := tok.readQuotedSymbol() - if err != nil { - t.Fatal(err) - } - - if actual != expected { - t.Errorf("expected '%v', got '%v'", expected, actual) - } - - c, err := tok.read() - if err != nil { - t.Fatal(err) - } - if c != next { - t.Errorf("expected next=%q, got next=%q", next, c) - } - }) - } - - test("'a'", "a", -1) - test("'a b c'", "a b c", -1) - test("'null' ", "null", ' ') - test("'false',", "false", ',') - test("'nan']", "nan", ']') - - test("'a\\'b'", "a'b", -1) - test("'a\\\nb'", "ab", -1) - test("'a\\\\b'", "a\\b", -1) - test("'a\x20b'", "a b", -1) - test("'a\\u2248b'", "a≈b", -1) - test("'a\\U0001F44Db'", "a👍b", -1) -} - -func TestReadTimestamp(t *testing.T) { - test := func(str string, eval string, next int) { - t.Run(str, func(t *testing.T) { - tok := tokenizeString(str) - if err := tok.Next(); err != nil { - t.Fatal(err) - } - if tok.Token() != tokenTimestamp { - t.Fatalf("unexpected token %v", tok.Token()) - } - - val, err := tok.ReadValue(tokenTimestamp) - if err != nil { - t.Fatal(err) - } - if val != eval { - t.Errorf("expected %v, got %v", eval, val) - } - - c, err := tok.read() - if err != nil { - t.Fatal(err) - } - if c != next { - t.Errorf("expected %q, got %q", next, c) - } - }) - } - - test("2001T", "2001T", -1) - test("2001-01T,", "2001-01T", ',') - test("2001-01-02}", "2001-01-02", '}') - test("2001-01-02T ", "2001-01-02T", ' ') - test("2001-01-02T+00:00\t", "2001-01-02T+00:00", '\t') - test("2001-01-02T-00:00\n", "2001-01-02T-00:00", '\n') - test("2001-01-02T03:04+00:00 ", "2001-01-02T03:04+00:00", ' ') - test("2001-01-02T03:04-00:00 ", "2001-01-02T03:04-00:00", ' ') - test("2001-01-02T03:04Z ", "2001-01-02T03:04Z", ' ') - test("2001-01-02T03:04z ", "2001-01-02T03:04z", ' ') - test("2001-01-02T03:04:05Z ", "2001-01-02T03:04:05Z", ' ') - test("2001-01-02T03:04:05+00:00 ", "2001-01-02T03:04:05+00:00", ' ') - test("2001-01-02T03:04:05.666Z ", "2001-01-02T03:04:05.666Z", ' ') - test("2001-01-02T03:04:05.666666z ", "2001-01-02T03:04:05.666666z", ' ') -} - -func TestIsTripleQuote(t *testing.T) { - test := func(str string, eok bool, next int) { - t.Run(str, func(t *testing.T) { - tok := tokenizeString(str) - - ok, err := tok.IsTripleQuote() - if err != nil { - t.Fatal(err) - } - if ok != eok { - t.Errorf("expected ok=%v, got ok=%v", eok, ok) - } - - read(t, tok, next) - }) - } - - test("''string'''", true, 's') - test("'string'''", false, '\'') - test("'", false, '\'') - test("", false, -1) -} - -func TestIsInf(t *testing.T) { - test := func(str string, eok bool, next int) { - t.Run(str, func(t *testing.T) { - tok := tokenizeString(str) - c, err := tok.read() - if err != nil { - t.Fatal(err) - } - - ok, err := tok.isInf(c) - if err != nil { - t.Fatal(err) - } - - if ok != eok { - t.Errorf("expected %v, got %v", eok, ok) - } - - c, err = tok.read() - if err != nil { - t.Fatal(err) - } - if c != next { - t.Errorf("expected '%c', got '%c'", next, c) - } - }) - } - - test("+inf", true, -1) - test("-inf", true, -1) - test("+inf ", true, ' ') - test("-inf\t", true, '\t') - test("-inf\n", true, '\n') - test("+inf,", true, ',') - test("-inf}", true, '}') - test("+inf)", true, ')') - test("-inf]", true, ']') - test("+inf//", true, '/') - test("+inf/*", true, '/') - - test("+inf/", false, 'i') - test("-inf/0", false, 'i') - test("+int", false, 'i') - test("-iot", false, 'i') - test("+unf", false, 'u') - test("_inf", false, 'i') - - test("-in", false, 'i') - test("+i", false, 'i') - test("+", false, -1) - test("-", false, -1) -} - -func TestScanForNumericType(t *testing.T) { - test := func(str string, ett token) { - t.Run(str, func(t *testing.T) { - tok := tokenizeString(str) - c, err := tok.read() - if err != nil { - t.Fatal(err) - } - - tt, err := tok.scanForNumericType(c) - if err != nil { - t.Fatal(err) - } - if tt != ett { - t.Errorf("expected %v, got %v", ett, tt) - } - }) - } - - test("0b0101", tokenBinary) - test("0B", tokenBinary) - test("0xABCD", tokenHex) - test("0X", tokenHex) - test("0000-00-00", tokenTimestamp) - test("0000T", tokenTimestamp) - - test("0", tokenNumber) - test("1b0101", tokenNumber) - test("1B", tokenNumber) - test("1x0101", tokenNumber) - test("1X", tokenNumber) - test("1234", tokenNumber) - test("12345", tokenNumber) - test("1,23T", tokenNumber) - test("12,3T", tokenNumber) - test("123,T", tokenNumber) -} - -func TestSkipWhitespace(t *testing.T) { - test := func(str string, eok bool, ec int) { - t.Run(str, func(t *testing.T) { - tok := tokenizeString(str) - c, ok, err := tok.skipWhitespace() - if err != nil { - t.Fatal(err) - } - - if ok != eok { - t.Errorf("expected ok=%v, got ok=%v", eok, ok) - } - if c != ec { - t.Errorf("expected c='%c', got c='%c'", ec, c) - } - }) - } - - test("/ 0)", false, '/') - test("xyz_", false, 'x') - test(" / 0)", true, '/') - test(" xyz_", true, 'x') - test(" \t\r\n / 0)", true, '/') - test("\t\t // comment\t\r\n\t\t x", true, 'x') - test(" \r\n /* comment *//* \r\n comment */x", true, 'x') -} - -func TestSkipLobWhitespace(t *testing.T) { - test := func(str string, eok bool, ec int) { - t.Run(str, func(t *testing.T) { - tok := tokenizeString(str) - c, ok, err := tok.skipLobWhitespace() - if err != nil { - t.Fatal(err) - } - - if ok != eok { - t.Errorf("expected ok=%v, got ok=%v", eok, ok) - } - if c != ec { - t.Errorf("expected c='%c', got c='%c'", ec, c) - } - }) - } - - test("///=", false, '/') - test("xyz_", false, 'x') - test(" ///=", true, '/') - test(" xyz_", true, 'x') - test("\r\n\t///=", true, '/') - test("\r\n\txyz_", true, 'x') -} - -func TestSkipCommentsHandler(t *testing.T) { - t.Run("SingleLine", func(t *testing.T) { - tok := tokenizeString("/comment\nok") - ok, err := tok.skipCommentsHandler() - if err != nil { - t.Fatal(err) - } - if !ok { - t.Error("expected ok=true, got ok=false") - } - - read(t, tok, 'o') - read(t, tok, 'k') - read(t, tok, -1) - }) - - t.Run("Block", func(t *testing.T) { - tok := tokenizeString("*comm\nent*/ok") - ok, err := tok.skipCommentsHandler() - if err != nil { - t.Fatal(err) - } - if !ok { - t.Error("expected ok=true, got ok=false") - } - - read(t, tok, 'o') - read(t, tok, 'k') - read(t, tok, -1) - }) - - t.Run("FalseAlarm", func(t *testing.T) { - tok := tokenizeString(" 0)") - ok, err := tok.skipCommentsHandler() - if err != nil { - t.Fatal(err) - } - if ok { - t.Error("expected ok=false, got ok=true") - } - - read(t, tok, ' ') - read(t, tok, '0') - read(t, tok, ')') - read(t, tok, -1) - }) -} - -func TestSkipSingleLineComment(t *testing.T) { - tok := tokenizeString("single-line comment\r\nok") - err := tok.skipSingleLineComment() - if err != nil { - t.Fatal(err) - } - - read(t, tok, 'o') - read(t, tok, 'k') - read(t, tok, -1) -} - -func TestSkipSingleLineCommentOnLastLine(t *testing.T) { - tok := tokenizeString("single-line comment") - err := tok.skipSingleLineComment() - if err != nil { - t.Fatal(err) - } - - read(t, tok, -1) -} - -func TestSkipBlockComment(t *testing.T) { - tok := tokenizeString("this is/ a\nmulti-line /** comment.**/ok") - err := tok.skipBlockComment() - if err != nil { - t.Fatal(err) - } - - read(t, tok, 'o') - read(t, tok, 'k') - read(t, tok, -1) -} - -func TestSkipInvalidBlockComment(t *testing.T) { - tok := tokenizeString("this is a comment that never ends") - err := tok.skipBlockComment() - if err == nil { - t.Error("did not fail on bad block comment") - } -} - -func TestPeekN(t *testing.T) { - tok := tokenizeString("abc\r\ndef") - - peekN(t, tok, 1, nil, 'a') - peekN(t, tok, 2, nil, 'a', 'b') - peekN(t, tok, 3, nil, 'a', 'b', 'c') - - read(t, tok, 'a') - read(t, tok, 'b') - - peekN(t, tok, 3, nil, 'c', '\n', 'd') - peekN(t, tok, 2, nil, 'c', '\n') - peekN(t, tok, 3, nil, 'c', '\n', 'd') - - read(t, tok, 'c') - read(t, tok, '\n') - read(t, tok, 'd') - - peekN(t, tok, 3, io.EOF, 'e', 'f') - peekN(t, tok, 3, io.EOF, 'e', 'f') - peekN(t, tok, 2, nil, 'e', 'f') - - read(t, tok, 'e') - read(t, tok, 'f') - read(t, tok, -1) - - peekN(t, tok, 10, io.EOF) -} - -func peekN(t *testing.T, tok *tokenizer, n int, ee error, ecs ...int) { - cs, err := tok.peekN(n) - if err != ee { - t.Fatalf("expected err=%v, got err=%v", ee, err) - } - if !equal(ecs, cs) { - t.Errorf("expected %v, got %v", ecs, cs) - } -} - -func equal(a, b []int) bool { - if len(a) != len(b) { - return false - } - - for i := range a { - if a[i] != b[i] { - return false - } - } - - return true -} - -func TestPeek(t *testing.T) { - tok := tokenizeString("abc") - - peek(t, tok, 'a') - peek(t, tok, 'a') - read(t, tok, 'a') - - peek(t, tok, 'b') - tok.unread('a') - - peek(t, tok, 'a') - read(t, tok, 'a') - read(t, tok, 'b') - peek(t, tok, 'c') - peek(t, tok, 'c') - - read(t, tok, 'c') - peek(t, tok, -1) - peek(t, tok, -1) - read(t, tok, -1) -} - -func peek(t *testing.T, tok *tokenizer, expected int) { - c, err := tok.peek() - if err != nil { - t.Fatal(err) - } - if c != expected { - t.Errorf("expected %v, got %v", expected, c) - } -} - -func TestReadUnread(t *testing.T) { - tok := tokenizeString("abc\rd\ne\r\n") - - read(t, tok, 'a') - tok.unread('a') - - read(t, tok, 'a') - read(t, tok, 'b') - read(t, tok, 'c') - tok.unread('c') - tok.unread('b') - - read(t, tok, 'b') - read(t, tok, 'c') - read(t, tok, '\n') - tok.unread('\n') - - read(t, tok, '\n') - read(t, tok, 'd') - read(t, tok, '\n') - read(t, tok, 'e') - read(t, tok, '\n') - read(t, tok, -1) - - tok.unread(-1) - tok.unread('\n') - - read(t, tok, '\n') - read(t, tok, -1) - read(t, tok, -1) -} - -func TestTokenToString(t *testing.T) { - for i := tokenError; i <= tokenCloseDoubleBrace+1; i++ { - str := i.String() - if str == "" { - t.Errorf("expected non-empty string for token %v", int(i)) - } - } -} - -func read(t *testing.T, tok *tokenizer, expected int) { - c, err := tok.read() - if err != nil { - t.Fatal(err) - } - if c != expected { - t.Errorf("expected %v, got %v", expected, c) - } -} diff --git a/type.go b/type.go deleted file mode 100644 index f0165909..00000000 --- a/type.go +++ /dev/null @@ -1,125 +0,0 @@ -package ion - -import "fmt" - -// A Type represents the type of an Ion Value. -type Type uint8 - -const ( - // NoType is returned by a Reader that is not currently pointing at a value. - NoType Type = iota - - // NullType is the type of the (unqualified) Ion null value. - NullType - - // BoolType is the type of an Ion boolean, true or false. - BoolType - - // IntType is the type of a signed Ion integer of arbitrary size. - IntType - - // FloatType is the type of a fixed-precision Ion floating-point value. - FloatType - - // DecimalType is the type of an arbitrary-precision Ion decimal value. - DecimalType - - // TimestampType is the type of an arbitrary-precision Ion timestamp. - TimestampType - - // SymbolType is the type of an Ion symbol, mapped to an integer ID by a SymbolTable - // to (potentially) save space. - SymbolType - - // StringType is the type of a non-symbol Unicode string, represented directly. - StringType - - // ClobType is the type of a character large object. Like a BlobType, it stores an - // arbitrary sequence of bytes, but it represents them in text form as an escaped-ASCII - // string rather than a base64-encoded string. - ClobType - - // BlobType is the type of a binary large object; a sequence of arbitrary bytes. - BlobType - - // ListType is the type of a list, recursively containing zero or more Ion values. - ListType - - // SexpType is the type of an s-expression. Like a ListType, it contains a sequence - // of zero or more Ion values, but with a lisp-like syntax when encoded as text. - SexpType - - // StructType is the type of a structure, recursively containing a sequence of named - // (by an Ion symbol) Ion values. - StructType -) - -// String implements fmt.Stringer for Type. -func (t Type) String() string { - switch t { - case NoType: - return "" - case NullType: - return "null" - case BoolType: - return "bool" - case IntType: - return "int" - case FloatType: - return "float" - case DecimalType: - return "decimal" - case TimestampType: - return "timestamp" - case StringType: - return "string" - case SymbolType: - return "symbol" - case BlobType: - return "blob" - case ClobType: - return "clob" - case StructType: - return "struct" - case ListType: - return "list" - case SexpType: - return "sexp" - default: - return fmt.Sprintf("", uint8(t)) - } -} - -// IntSize represents the size of an integer. -type IntSize uint8 - -const ( - // NullInt is the size of null.int and other things that aren't actually ints. - NullInt IntSize = iota - // Int32 is the size of an Ion integer that can be losslessly stored in an int32. - Int32 - // Int64 is the size of an Ion integer that can be losslessly stored in an int64. - Int64 - // Uint64 is the size of an Ion integer that can be losslessly stored in a uint64. - Uint64 - // BigInt is the size of an Ion integer that can only be losslessly stored in a big.Int. - BigInt -) - -// String implements fmt.Stringer for IntSize. -func (i IntSize) String() string { - switch i { - case NullInt: - return "null.int" - case Int32: - return "int32" - case Int64: - return "int64" - case Uint64: - return "uint64" - case BigInt: - return "big.Int" - default: - return fmt.Sprintf("", uint8(i)) - } -} diff --git a/type_test.go b/type_test.go deleted file mode 100644 index e1702baa..00000000 --- a/type_test.go +++ /dev/null @@ -1,21 +0,0 @@ -package ion - -import "testing" - -func TestTypeToString(t *testing.T) { - for i := NoType; i <= StructType+1; i++ { - str := i.String() - if str == "" { - t.Errorf("expected a non-empty string for type %v", uint8(i)) - } - } -} - -func TestIntSizeToString(t *testing.T) { - for i := NullInt; i <= BigInt+1; i++ { - str := i.String() - if str == "" { - t.Errorf("expected a non-empty string for size %v", uint8(i)) - } - } -} diff --git a/unmarshal.go b/unmarshal.go deleted file mode 100644 index e5e61090..00000000 --- a/unmarshal.go +++ /dev/null @@ -1,673 +0,0 @@ -package ion - -import ( - "bytes" - "errors" - "fmt" - "io" - "math/big" - "reflect" - "strconv" - "strings" -) - -var ( - // ErrNoInput is returned when there is no input to decode - ErrNoInput = errors.New("ion: no input to decode") -) - -// Unmarshal unmarshals Ion data to the given object. -func Unmarshal(data []byte, v interface{}) error { - return NewDecoder(NewReader(bytes.NewReader(data))).DecodeTo(v) -} - -// UnmarshalStr unmarshals Ion data from a string to the given object. -func UnmarshalStr(data string, v interface{}) error { - return Unmarshal([]byte(data), v) -} - -// UnmarshalFrom unmarshal Ion data from a reader to the given object. -func UnmarshalFrom(r Reader, v interface{}) error { - d := Decoder{ - r: r, - } - return d.DecodeTo(v) -} - -// A Decoder decodes go values from an Ion reader. -type Decoder struct { - r Reader -} - -// NewDecoder creates a new decoder. -func NewDecoder(r Reader) *Decoder { - return &Decoder{ - r: r, - } -} - -// NewTextDecoder creates a new text decoder. Well, a decoder that uses a reader with -// no shared symbol tables, it'll work to read binary too if the binary doesn't reference -// any shared symbol tables. -func NewTextDecoder(in io.Reader) *Decoder { - return NewDecoder(NewReader(in)) -} - -// Decode decodes a value from the underlying Ion reader without any expectations -// about what it's going to get. Structs become map[string]interface{}s, Lists and -// Sexps become []interface{}s. -func (d *Decoder) Decode() (interface{}, error) { - if !d.r.Next() { - if d.r.Err() != nil { - return nil, d.r.Err() - } - return nil, ErrNoInput - } - - return d.decode() -} - -// Helper form of Decode for when you've already called Next. -func (d *Decoder) decode() (interface{}, error) { - if d.r.IsNull() { - return nil, nil - } - - switch d.r.Type() { - case BoolType: - return d.r.BoolValue() - - case IntType: - return d.decodeInt() - - case FloatType: - return d.r.FloatValue() - - case DecimalType: - return d.r.DecimalValue() - - case TimestampType: - return d.r.TimeValue() - - case StringType, SymbolType: - return d.r.StringValue() - - case BlobType, ClobType: - return d.r.ByteValue() - - case StructType: - return d.decodeMap() - - case ListType, SexpType: - return d.decodeSlice() - - default: - panic("wat?") - } -} - -func (d *Decoder) decodeInt() (interface{}, error) { - size, err := d.r.IntSize() - if err != nil { - return nil, err - } - - switch size { - case NullInt: - return nil, nil - case Int32: - return d.r.IntValue() - case Int64: - return d.r.Int64Value() - default: - return d.r.BigIntValue() - } -} - -// DecodeMap decodes an Ion struct to a go map. -func (d *Decoder) decodeMap() (map[string]interface{}, error) { - if err := d.r.StepIn(); err != nil { - return nil, err - } - - result := map[string]interface{}{} - - for d.r.Next() { - name := d.r.FieldName() - value, err := d.decode() - if err != nil { - return nil, err - } - result[name] = value - } - - if err := d.r.StepOut(); err != nil { - return nil, err - } - - return result, nil -} - -// DecodeSlice decodes an Ion list or sexp to a go slice. -func (d *Decoder) decodeSlice() ([]interface{}, error) { - if err := d.r.StepIn(); err != nil { - return nil, err - } - - result := []interface{}{} - - for d.r.Next() { - value, err := d.decode() - if err != nil { - return nil, err - } - result = append(result, value) - } - - if err := d.r.StepOut(); err != nil { - return nil, err - } - - return result, nil -} - -// DecodeTo decodes an Ion value from the underlying Ion reader into the -// value provided. -func (d *Decoder) DecodeTo(v interface{}) error { - rv := reflect.ValueOf(v) - if rv.Kind() != reflect.Ptr { - return errors.New("ion: v must be a pointer") - } - if rv.IsNil() { - return errors.New("ion: v must not be nil") - } - - if !d.r.Next() { - if d.r.Err() != nil { - return d.r.Err() - } - return ErrNoInput - } - - return d.decodeTo(rv) -} - -func (d *Decoder) decodeTo(v reflect.Value) error { - if !v.IsValid() { - // Don't actually have anywhere to put this value; skip it. - return nil - } - - isNull := d.r.IsNull() - v = indirect(v, isNull) - if isNull { - v.Set(reflect.Zero(v.Type())) - return nil - } - - switch d.r.Type() { - case BoolType: - return d.decodeBoolTo(v) - - case IntType: - return d.decodeIntTo(v) - - case FloatType: - return d.decodeFloatTo(v) - - case DecimalType: - return d.decodeDecimalTo(v) - - case TimestampType: - return d.decodeTimestampTo(v) - - case StringType, SymbolType: - return d.decodeStringTo(v) - - case BlobType, ClobType: - return d.decodeLobTo(v) - - case StructType: - return d.decodeStructTo(v) - - case ListType, SexpType: - return d.decodeSliceTo(v) - - default: - panic("wat?") - } -} - -func (d *Decoder) decodeBoolTo(v reflect.Value) error { - val, err := d.r.BoolValue() - if err != nil { - return err - } - - switch v.Kind() { - case reflect.Bool: - // Too easy. - v.SetBool(val) - return nil - - case reflect.Interface: - if v.NumMethod() == 0 { - v.Set(reflect.ValueOf(val)) - return nil - } - } - return fmt.Errorf("ion: cannot decode bool to %v", v.Type().String()) -} - -var bigIntType = reflect.TypeOf(big.Int{}) - -func (d *Decoder) decodeIntTo(v reflect.Value) error { - switch v.Kind() { - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - val, err := d.r.Int64Value() - if err != nil { - return err - } - if v.OverflowInt(val) { - return fmt.Errorf("ion: value %v won't fit in type %v", val, v.Type().String()) - } - v.SetInt(val) - return nil - - case reflect.Uint8, reflect.Uint16, reflect.Uint32: - val, err := d.r.Int64Value() - if err != nil { - return err - } - if val < 0 || v.OverflowUint(uint64(val)) { - return fmt.Errorf("ion: value %v won't fit in type %v", val, v.Type().String()) - } - v.SetUint(uint64(val)) - return nil - - case reflect.Uint, reflect.Uint64, reflect.Uintptr: - val, err := d.r.BigIntValue() - if err != nil { - return err - } - if !val.IsUint64() { - return fmt.Errorf("ion: value %v won't fit in type %v", val, v.Type().String()) - } - uiv := val.Uint64() - if v.OverflowUint(uiv) { - return fmt.Errorf("ion: value %v won't fit in type %v", val, v.Type().String()) - } - v.SetUint(uiv) - return nil - - case reflect.Struct: - if v.Type() == bigIntType { - val, err := d.r.BigIntValue() - if err != nil { - return err - } - v.Set(reflect.ValueOf(*val)) - return nil - } - - case reflect.Interface: - if v.NumMethod() == 0 { - val, err := d.decodeInt() - if err != nil { - return err - } - v.Set(reflect.ValueOf(val)) - return nil - } - } - return fmt.Errorf("ion: cannot decode int to %v", v.Type().String()) -} - -func (d *Decoder) decodeFloatTo(v reflect.Value) error { - val, err := d.r.FloatValue() - if err != nil { - return err - } - - switch v.Kind() { - case reflect.Float32, reflect.Float64: - if v.OverflowFloat(val) { - return fmt.Errorf("ion: value %v won't fit in type %v", val, v.Type().String()) - } - v.SetFloat(val) - return nil - - case reflect.Struct: - if v.Type() == decimalType { - flt := strconv.FormatFloat(val, 'g', -1, 64) - dec, err := ParseDecimal(strings.Replace(flt, "e", "d", 1)) - if err != nil { - return err - } - v.Set(reflect.ValueOf(*dec)) - return nil - } - - case reflect.Interface: - if v.NumMethod() == 0 { - v.Set(reflect.ValueOf(val)) - return nil - } - } - return fmt.Errorf("ion: cannot decode float to %v", v.Type().String()) -} - -func (d *Decoder) decodeDecimalTo(v reflect.Value) error { - val, err := d.r.DecimalValue() - if err != nil { - return err - } - - switch v.Kind() { - case reflect.Struct: - if v.Type() == decimalType { - v.Set(reflect.ValueOf(*val)) - return nil - } - - case reflect.Interface: - if v.NumMethod() == 0 { - v.Set(reflect.ValueOf(val)) - return nil - } - } - return fmt.Errorf("ion: cannot decode decimal to %v", v.Type().String()) -} - -func (d *Decoder) decodeTimestampTo(v reflect.Value) error { - val, err := d.r.TimeValue() - if err != nil { - return err - } - - switch v.Kind() { - case reflect.Struct: - if v.Type() == timeType { - v.Set(reflect.ValueOf(val)) - return nil - } - - case reflect.Interface: - if v.NumMethod() == 0 { - v.Set(reflect.ValueOf(val)) - return nil - } - } - return fmt.Errorf("ion: cannot decode timestamp to %v", v.Type().String()) -} - -func (d *Decoder) decodeStringTo(v reflect.Value) error { - val, err := d.r.StringValue() - if err != nil { - return err - } - - switch v.Kind() { - case reflect.String: - v.SetString(val) - return nil - - case reflect.Interface: - if v.NumMethod() == 0 { - v.Set(reflect.ValueOf(val)) - return nil - } - } - return fmt.Errorf("ion: cannot decode string to %v", v.Type().String()) -} - -func (d *Decoder) decodeLobTo(v reflect.Value) error { - val, err := d.r.ByteValue() - if err != nil { - return err - } - - switch v.Kind() { - case reflect.Slice: - if v.Type().Elem().Kind() == reflect.Uint8 { - v.SetBytes(val) - return nil - } - - case reflect.Array: - if v.Type().Elem().Kind() == reflect.Uint8 { - i := reflect.Copy(v, reflect.ValueOf(val)) - for ; i < v.Len(); i++ { - v.Index(i).SetUint(0) - } - return nil - } - - case reflect.Interface: - if v.NumMethod() == 0 { - v.Set(reflect.ValueOf(val)) - return nil - } - } - return fmt.Errorf("ion: cannot decode lob to %v", v.Type().String()) -} - -func (d *Decoder) decodeStructTo(v reflect.Value) error { - switch v.Kind() { - case reflect.Struct: - return d.decodeStructToStruct(v) - - case reflect.Map: - return d.decodeStructToMap(v) - - case reflect.Interface: - if v.NumMethod() == 0 { - m, err := d.decodeMap() - if err != nil { - return err - } - v.Set(reflect.ValueOf(m)) - return nil - } - } - return fmt.Errorf("ion: cannot decode struct to %v", v.Type().String()) -} - -func (d *Decoder) decodeStructToStruct(v reflect.Value) error { - fields := fieldsFor(v.Type()) - - if err := d.r.StepIn(); err != nil { - return err - } - - for d.r.Next() { - name := d.r.FieldName() - field := findField(fields, name) - if field != nil { - subv, err := findSubvalue(v, field) - if err != nil { - return err - } - - if err := d.decodeTo(subv); err != nil { - return err - } - } - } - - return d.r.StepOut() -} - -func findField(fields []field, name string) *field { - var f *field - for i := range fields { - ff := &fields[i] - if ff.name == name { - return ff - } - if f == nil && strings.EqualFold(ff.name, name) { - f = ff - } - } - return f -} - -func findSubvalue(v reflect.Value, f *field) (reflect.Value, error) { - for _, i := range f.path { - if v.Kind() == reflect.Ptr { - if v.IsNil() { - if !v.CanSet() { - return reflect.Value{}, fmt.Errorf("ion: cannot set embedded pointer to unexported struct: %v", v.Type().Elem()) - } - v.Set(reflect.New(v.Type().Elem())) - } - v = v.Elem() - } - v = v.Field(i) - } - return v, nil -} - -func (d *Decoder) decodeStructToMap(v reflect.Value) error { - t := v.Type() - switch t.Key().Kind() { - case reflect.String: - default: - return fmt.Errorf("ion: cannot decode struct to %v", t.String()) - } - - if v.IsNil() { - v.Set(reflect.MakeMap(t)) - } - - subv := reflect.New(t.Elem()).Elem() - - if err := d.r.StepIn(); err != nil { - return err - } - - for d.r.Next() { - name := d.r.FieldName() - if err := d.decodeTo(subv); err != nil { - return err - } - - var kv reflect.Value - switch t.Key().Kind() { - case reflect.String: - kv = reflect.ValueOf(name) - default: - panic("wat?") - } - - if kv.IsValid() { - v.SetMapIndex(kv, subv) - } - } - - return d.r.StepOut() -} - -func (d *Decoder) decodeSliceTo(v reflect.Value) error { - k := v.Kind() - - // If all we know is we need an interface{}, decode an []interface{} with - // types based on the Ion value stream. - if k == reflect.Interface && v.NumMethod() == 0 { - s, err := d.decodeSlice() - if err != nil { - return err - } - v.Set(reflect.ValueOf(s)) - return nil - } - - // Only other valid targets are arrays and slices. - if k != reflect.Array && k != reflect.Slice { - return fmt.Errorf("ion: cannot unmarshal slice to %v", v.Type().String()) - } - - if err := d.r.StepIn(); err != nil { - return err - } - - i := 0 - - // Decode values into the array or slice. - for d.r.Next() { - if v.Kind() == reflect.Slice { - // If it's a slice, we can grow it as needed. - if i >= v.Cap() { - newcap := v.Cap() + v.Cap()/2 - if newcap < 4 { - newcap = 4 - } - newv := reflect.MakeSlice(v.Type(), v.Len(), newcap) - reflect.Copy(newv, v) - v.Set(newv) - } - if i >= v.Len() { - v.SetLen(i + 1) - } - } - - if i < v.Len() { - if err := d.decodeTo(v.Index(i)); err != nil { - return err - } - } - - i++ - } - - if err := d.r.StepOut(); err != nil { - return err - } - - if i < v.Len() { - if v.Kind() == reflect.Array { - // Zero out any additional values. - z := reflect.Zero(v.Type().Elem()) - for ; i < v.Len(); i++ { - v.Index(i).Set(z) - } - } else { - v.SetLen(i) - } - } - - return nil -} - -// Dig in through any pointers to find the actual underlying value that we want -// to set. If wantPtr is false, the algorithm terminates at a non-ptr value (e.g., -// if passed an *int, it returns the int it points to, allocating such an int if the -// pointer is currently nil). If wantPtr is true, it terminates on a pointer to that -// value (allowing said pointer to be set to nil, generally). -func indirect(v reflect.Value, wantPtr bool) reflect.Value { - for { - if v.Kind() == reflect.Interface && !v.IsNil() { - e := v.Elem() - if e.Kind() == reflect.Ptr && !e.IsNil() && (!wantPtr || e.Elem().Kind() == reflect.Ptr) { - v = e - continue - } - } - - if v.Kind() != reflect.Ptr { - break - } - - if v.Elem().Kind() != reflect.Ptr && wantPtr && v.CanSet() { - break - } - - if v.IsNil() { - v.Set(reflect.New(v.Type().Elem())) - } - - v = v.Elem() - } - - return v -} diff --git a/unmarshal_test.go b/unmarshal_test.go deleted file mode 100644 index e3e2d2d6..00000000 --- a/unmarshal_test.go +++ /dev/null @@ -1,539 +0,0 @@ -package ion - -import ( - "bytes" - "math" - "math/big" - "reflect" - "testing" - "time" -) - -func TestUnmarshalBool(t *testing.T) { - test := func(str string, eval bool) { - t.Run(str, func(t *testing.T) { - var val bool - err := UnmarshalStr(str, &val) - if err != nil { - t.Fatal(err) - } - - if val != eval { - t.Errorf("expected %v, got %v", eval, val) - } - }) - } - - test("null", false) - test("true", true) - test("false", false) -} -func TestUnmarshalBoolPtr(t *testing.T) { - test := func(str string, eval interface{}) { - t.Run(str, func(t *testing.T) { - var bval bool - val := &bval - err := UnmarshalStr(str, &val) - if err != nil { - t.Fatal(err) - } - - if eval == nil { - if val != nil { - t.Errorf("expected , got %v", *val) - } - } else { - switch { - case val == nil: - t.Errorf("expected %v, got ", eval) - case *val != eval.(bool): - t.Errorf("expected %v, got %v", eval, *val) - } - } - }) - } - - test("null", nil) - test("null.bool", nil) - test("false", false) - test("true", true) -} - -func TestUnmarshalInt(t *testing.T) { - testInt8 := func(str string, eval int8) { - t.Run(str, func(t *testing.T) { - var val int8 - err := UnmarshalStr(str, &val) - if err != nil { - t.Fatal(err) - } - - if val != eval { - t.Errorf("expected %v, got %v", eval, val) - } - }) - } - testInt8("null", 0) - testInt8("0", 0) - testInt8("0x7F", 0x7F) - testInt8("-0x80", -0x80) - - testInt16 := func(str string, eval int16) { - t.Run(str, func(t *testing.T) { - var val int16 - err := UnmarshalStr(str, &val) - if err != nil { - t.Fatal(err) - } - - if val != eval { - t.Errorf("expected %v, got %v", eval, val) - } - }) - } - testInt16("0x7F", 0x7F) - testInt16("-0x80", -0x80) - testInt16("0x7FFF", 0x7FFF) - testInt16("-0x8000", -0x8000) - - testInt32 := func(str string, eval int32) { - t.Run(str, func(t *testing.T) { - var val int32 - err := UnmarshalStr(str, &val) - if err != nil { - t.Fatal(err) - } - - if val != eval { - t.Errorf("expected %v, got %v", eval, val) - } - }) - } - testInt32("0x7FFF", 0x7FFF) - testInt32("-0x8000", -0x8000) - testInt32("0x7FFFFFFF", 0x7FFFFFFF) - testInt32("-0x80000000", -0x80000000) - - testInt := func(str string, eval int) { - t.Run(str, func(t *testing.T) { - var val int - err := UnmarshalStr(str, &val) - if err != nil { - t.Fatal(err) - } - - if val != eval { - t.Errorf("expected %v, got %v", eval, val) - } - }) - } - testInt("0x7FFF", 0x7FFF) - testInt("-0x8000", -0x8000) - testInt("0x7FFFFFFF", 0x7FFFFFFF) - testInt("-0x80000000", -0x80000000) - - testInt64 := func(str string, eval int64) { - t.Run(str, func(t *testing.T) { - var val int64 - err := UnmarshalStr(str, &val) - if err != nil { - t.Fatal(err) - } - - if val != eval { - t.Errorf("expected %v, got %v", eval, val) - } - }) - } - testInt64("0x7FFFFFFF", 0x7FFFFFFF) - testInt64("-0x80000000", -0x80000000) - testInt64("0x7FFFFFFFFFFFFFFF", 0x7FFFFFFFFFFFFFFF) - testInt64("-0x8000000000000000", -0x8000000000000000) -} - -func TestUnmarshalUint(t *testing.T) { - testUint8 := func(str string, eval uint8) { - t.Run(str, func(t *testing.T) { - var val uint8 - err := UnmarshalStr(str, &val) - if err != nil { - t.Fatal(err) - } - - if val != eval { - t.Errorf("expected %v, got %v", eval, val) - } - }) - } - testUint8("null", 0) - testUint8("0", 0) - testUint8("0xFF", 0xFF) - - testUint16 := func(str string, eval uint16) { - t.Run(str, func(t *testing.T) { - var val uint16 - err := UnmarshalStr(str, &val) - if err != nil { - t.Fatal(err) - } - - if val != eval { - t.Errorf("expected %v, got %v", eval, val) - } - }) - } - testUint16("0xFF", 0xFF) - testUint16("0xFFFF", 0xFFFF) - - testUint32 := func(str string, eval uint32) { - t.Run(str, func(t *testing.T) { - var val uint32 - err := UnmarshalStr(str, &val) - if err != nil { - t.Fatal(err) - } - - if val != eval { - t.Errorf("expected %v, got %v", eval, val) - } - }) - } - testUint32("0xFFFF", 0xFFFF) - testUint32("0xFFFFFFFF", 0xFFFFFFFF) - - testUint := func(str string, eval uint) { - t.Run(str, func(t *testing.T) { - var val uint - err := UnmarshalStr(str, &val) - if err != nil { - t.Fatal(err) - } - - if val != eval { - t.Errorf("expected %v, got %v", eval, val) - } - }) - } - testUint("0xFFFF", 0xFFFF) - testUint("0xFFFFFFFF", 0xFFFFFFFF) - - testUintptr := func(str string, eval uintptr) { - t.Run(str, func(t *testing.T) { - var val uintptr - err := UnmarshalStr(str, &val) - if err != nil { - t.Fatal(err) - } - - if val != eval { - t.Errorf("expected %v, got %v", eval, val) - } - }) - } - testUintptr("0xFFFF", 0xFFFF) - testUintptr("0xFFFFFFFF", 0xFFFFFFFF) - - testUint64 := func(str string, eval uint64) { - t.Run(str, func(t *testing.T) { - var val uint64 - err := UnmarshalStr(str, &val) - if err != nil { - t.Fatal(err) - } - - if val != eval { - t.Errorf("expected %v, got %v", eval, val) - } - }) - } - testUint64("0xFFFFFFFF", 0xFFFFFFFF) - testUint64("0xFFFFFFFFFFFFFFFF", 0xFFFFFFFFFFFFFFFF) -} - -func TestUnmarshalBigInt(t *testing.T) { - test := func(str string, eval *big.Int) { - t.Run(str, func(t *testing.T) { - var val big.Int - err := UnmarshalStr(str, &val) - if err != nil { - t.Fatal(err) - } - - if val.Cmp(eval) != 0 { - t.Errorf("expected %v, got %v", eval, val) - } - }) - } - test("null", new(big.Int)) - test("1", new(big.Int).SetUint64(1)) - test("-0xFFFFFFFFFFFFFFFF", new(big.Int).Neg(new(big.Int).SetUint64(0xFFFFFFFFFFFFFFFF))) -} - -func TestDecodeFloat(t *testing.T) { - test32 := func(str string, eval float32) { - t.Run(str, func(t *testing.T) { - d := NewDecoder(NewReaderStr(str)) - - var val float32 - err := d.DecodeTo(&val) - if err != nil { - t.Fatal(err) - } - - if val != eval { - t.Errorf("expected %v, got %v", eval, val) - } - }) - } - test32("null", 0) - test32("1e0", 1) - test32("1e38", 1e38) - test32("+inf", float32(math.Inf(1))) - - test64 := func(str string, eval float64) { - t.Run(str, func(t *testing.T) { - d := NewDecoder(NewReaderStr(str)) - - var val float64 - err := d.DecodeTo(&val) - if err != nil { - t.Fatal(err) - } - - if val != eval { - t.Errorf("expected %v, got %v", eval, val) - } - }) - } - test64("1e0", 1) - test64("1e308", 1e308) - test64("+inf", math.Inf(1)) -} - -func TestDecodeDecimal(t *testing.T) { - test := func(str string, eval *Decimal) { - t.Run(str, func(t *testing.T) { - d := NewDecoder(NewReaderStr(str)) - - var val *Decimal - err := d.DecodeTo(&val) - if err != nil { - t.Fatal(err) - } - - if !val.Equal(eval) { - t.Errorf("expected %v, got %v", eval, val) - } - }) - } - - test("1e10", MustParseDecimal("1d10")) - test("1.20", MustParseDecimal("1.20")) -} - -func TestDecodeTimeTo(t *testing.T) { - test := func(str string, eval time.Time) { - t.Run(str, func(t *testing.T) { - d := NewDecoder(NewReaderStr(str)) - - var val time.Time - err := d.DecodeTo(&val) - if err != nil { - t.Fatal(err) - } - - if val != eval { - t.Errorf("expected %v, got %v", eval, val) - } - }) - } - test("null", time.Time{}) - test("2020T", time.Date(2020, 1, 1, 0, 0, 0, 0, time.UTC)) -} - -func TestDecodeStringTo(t *testing.T) { - test := func(str string, eval string) { - t.Run(str, func(t *testing.T) { - d := NewDecoder(NewReaderStr(str)) - - var val string - err := d.DecodeTo(&val) - if err != nil { - t.Fatal(err) - } - - if val != eval { - t.Errorf("expected %v, got %v", eval, val) - } - }) - } - - test("null", "") - test("hello", "hello") - test("\"hello\"", "hello") -} - -func TestDecodeLobTo(t *testing.T) { - testSlice := func(str string, eval []byte) { - t.Run(str, func(t *testing.T) { - d := NewDecoder(NewReaderStr(str)) - - var val []byte - err := d.DecodeTo(&val) - if err != nil { - t.Fatal(err) - } - - if !bytes.Equal(val, eval) { - t.Errorf("expected %v, got %v", eval, val) - } - }) - } - testSlice("null", nil) - testSlice("{{}}", []byte{}) - testSlice("{{aGVsbG8=}}", []byte("hello")) - testSlice("{{'''hello'''}}", []byte("hello")) - - testArray := func(str string, eval []byte) { - t.Run(str, func(t *testing.T) { - d := NewDecoder(NewReaderStr(str)) - - var val [8]byte - err := d.DecodeTo(&val) - if err != nil { - t.Fatal(err) - } - - if !bytes.Equal(val[:], eval) { - t.Errorf("expected %v, got %v", eval, val) - } - }) - } - testArray("null", make([]byte, 8)) - testArray("{{aGVsbG8=}}", append([]byte("hello"), []byte{0, 0, 0}...)) -} - -func TestDecodeStructTo(t *testing.T) { - test := func(str string, val, eval interface{}) { - t.Run(str, func(t *testing.T) { - d := NewDecoder(NewReaderStr(str)) - err := d.DecodeTo(val) - if err != nil { - t.Fatal(err) - } - if !reflect.DeepEqual(val, eval) { - t.Errorf("expected %v, got %v", eval, val) - } - }) - } - - type foo struct { - Foo string - Baz int `json:"bar"` - } - - test("{}", &struct{}{}, &struct{}{}) - test("{bogus:(ignore me)}", &foo{}, &foo{}) - test("{foo:bar}", &foo{}, &foo{"bar", 0}) - test("{bar:42}", &foo{}, &foo{"", 42}) - test("{foo:bar,bar:42,bogus:(ignore me)}", &foo{}, &foo{"bar", 42}) - - test("{}", &map[string]string{}, &map[string]string{}) - test("{foo:bar}", &map[string]string{}, &map[string]string{"foo": "bar"}) - test("{a:4,b:2}", &map[string]int{}, &map[string]int{"a": 4, "b": 2}) -} - -func TestDecodeListTo(t *testing.T) { - test := func(str string, val, eval interface{}) { - t.Run(str, func(t *testing.T) { - d := NewDecoder(NewReaderStr(str)) - err := d.DecodeTo(val) - if err != nil { - t.Fatal(err) - } - if !reflect.DeepEqual(val, eval) { - t.Errorf("expected %v, got %v", eval, val) - } - }) - } - - f := false - pf := &f - ppf := &pf - - test("[]", &[]bool{}, &[]bool{}) - test("[]", &[]bool{true}, &[]bool{}) - - test("[false]", &[]bool{}, &[]bool{false}) - test("[false]", &[]*bool{}, &[]*bool{pf}) - test("[false,false]", &[]**bool{}, &[]**bool{ppf, ppf}) - - test("[true,false]", &[]interface{}{}, &[]interface{}{true, false}) - - var i interface{} - var ei interface{} = []interface{}{true, false} - test("[true,false]", &i, &ei) -} - -func TestDecode(t *testing.T) { - test := func(data string, eval interface{}) { - t.Run(data, func(t *testing.T) { - d := NewDecoder(NewReaderStr(data)) - val, err := d.Decode() - if err != nil { - t.Fatal(err) - } - if !reflect.DeepEqual(val, eval) { - t.Errorf("expected %v, got %v", eval, val) - } - }) - } - - test("null", nil) - test("null.null", nil) - - test("null.bool", nil) - test("true", true) - test("false", false) - - test("null.int", nil) - test("0", int(0)) - test("2147483647", math.MaxInt32) - test("-2147483648", math.MinInt32) - test("2147483648", int64(math.MaxInt32)+1) - test("-2147483649", int64(math.MinInt32)-1) - test("9223372036854775808", new(big.Int).SetUint64(math.MaxInt64+1)) - - test("0e0", float64(0.0)) - test("1e100", float64(1e100)) - - test("0.", MustParseDecimal("0.")) - - test("2020T", time.Date(2020, 1, 1, 0, 0, 0, 0, time.UTC)) - - test("hello", "hello") - test("\"hello\"", "hello") - - test("null.blob", nil) - test("{{}}", []byte{}) - test("{{aGVsbG8=}}", []byte("hello")) - - test("null.clob", nil) - test("{{''''''}}", []byte{}) - test("{{'''hello'''}}", []byte("hello")) - - test("null.struct", nil) - test("{}", map[string]interface{}{}) - test("{a:1,b:two}", map[string]interface{}{ - "a": 1, - "b": "two", - }) - - test("null.list", nil) - test("[]", []interface{}{}) - test("[1, two]", []interface{}{1, "two"}) - - test("null.sexp", nil) - test("()", []interface{}{}) - test("(1 + two)", []interface{}{1, "+", "two"}) -} diff --git a/writer.go b/writer.go deleted file mode 100644 index 2a85216a..00000000 --- a/writer.go +++ /dev/null @@ -1,165 +0,0 @@ -package ion - -import ( - "errors" - "io" - "math/big" - "time" -) - -// A Writer writes a stream of Ion values. -// -// The various Write methods write atomic values to the current output stream. The -// Begin methods begin writing a list, sexp, or struct respectively. Subsequent -// calls to Write will write values inside of the container until a matching -// End method is called. -// -// var w Writer -// w.BeginSexp() -// { -// w.WriteInt(1) -// w.WriteSymbol("+") -// w.WriteInt(1) -// } -// w.EndSexp() -// -// When writing values inside a struct, the FieldName method must be called before -// each value to set the value's field name. The Annotation method may likewise -// be called before writing any value to add an annotation to the value. -// -// var w Writer -// w.Annotation("user") -// w.BeginStruct() -// { -// w.FieldName("id") -// w.WriteString("qu33nb33") -// w.FieldName("name") -// w.WriteString("Beyoncé") -// } -// w.EndStruct() -// -// When you're done writing values, you should call Finish to ensure everything has -// been flushed from in-memory buffers. While individual methods all return an error -// on failure, implementations will remember any errors, no-op subsequent calls, and -// return the previous error. This lets you keep code a bit cleaner by only checking -// the return value of the final method call (generally Finish). -// -// var w Writer -// writeSomeStuff(w) -// if err := w.Finish(); err != nil { -// return err -// } -// -type Writer interface { - - // FieldName sets the field name for the next value written. - FieldName(val string) error - - // Annotation adds a single annotation to the next value written. - Annotation(val string) error - - // Annotations adds multiple annotations to the next value written. - Annotations(vals ...string) error - - // WriteNull writes an untyped null value. - WriteNull() error - // WriteNullType writes a null value with a type qualifier, e.g. null.bool. - WriteNullType(t Type) error - - // WriteBool writes a boolean value. - WriteBool(val bool) error - - // WriteInt writes an integer value. - WriteInt(val int64) error - // WriteUint writes an unsigned integer value. - WriteUint(val uint64) error - // WriteBigInt writes a big integer value. - WriteBigInt(val *big.Int) error - // WriteFloat writes a floating-point value. - WriteFloat(val float64) error - // WriteDecimal writes an arbitrary-precision decimal value. - WriteDecimal(val *Decimal) error - - // WriteTimestamp writes a timestamp value. - WriteTimestamp(val time.Time) error - - // WriteSymbol writes a symbol value. - WriteSymbol(val string) error - // WriteString writes a string value. - WriteString(val string) error - - // WriteClob writes a clob value. - WriteClob(val []byte) error - // WriteBlob writes a blob value. - WriteBlob(val []byte) error - - // BeginList begins writing a list value. - BeginList() error - // EndList finishes writing a list value. - EndList() error - - // BeginSexp begins writing an s-expression value. - BeginSexp() error - // EndSexp finishes writing an s-expression value. - EndSexp() error - - // BeginStruct begins writing a struct value. - BeginStruct() error - // EndStruct finishes writing a struct value. - EndStruct() error - - // Finish finishes writing values and flushes any buffered data. - Finish() error -} - -// A writer holds shared stuff for all writers. -type writer struct { - out io.Writer - ctx ctxstack - err error - - fieldName string - annotations []string -} - -// FieldName sets the field name for the next value written. -// It may only be called while writing a struct. -func (w *writer) FieldName(val string) error { - if w.err != nil { - return w.err - } - if !w.inStruct() { - w.err = errors.New("ion: Writer.FieldName called when not writing a struct") - return w.err - } - - w.fieldName = val - return nil -} - -// Annotation adds an annotation to the next value written. -func (w *writer) Annotation(val string) error { - if w.err == nil { - w.annotations = append(w.annotations, val) - } - return w.err -} - -// Annotations adds one or more annotations to the next value written. -func (w *writer) Annotations(val ...string) error { - if w.err == nil { - w.annotations = append(w.annotations, val...) - } - return w.err -} - -// InStruct returns true if we're currently writing a struct. -func (w *writer) inStruct() bool { - return w.ctx.peek() == ctxInStruct -} - -// Clear clears field name and annotations after writing a value. -func (w *writer) clear() { - w.fieldName = "" - w.annotations = nil -} From 02fdea824f8ddcc2e0bc53a4edc0a8516f37318f Mon Sep 17 00:00:00 2001 From: Therapon Skoteiniotis Date: Thu, 7 May 2020 20:25:56 -0700 Subject: [PATCH 44/56] Updates to blacklist and location of ion-tests --- ion/reader_test.go | 23 ++++++++++++++--------- 1 file changed, 14 insertions(+), 9 deletions(-) diff --git a/ion/reader_test.go b/ion/reader_test.go index 2e204de1..77c55357 100644 --- a/ion/reader_test.go +++ b/ion/reader_test.go @@ -5,22 +5,24 @@ import ( "io/ioutil" "os" "path/filepath" + "strings" "testing" ) var blacklist = map[string]bool{ - "ion-tests/iontestdata/good/emptyAnnotatedInt.10n": true, - "ion-tests/iontestdata/good/subfieldVarUInt32bit.ion": true, - "ion-tests/iontestdata/good/utf16.ion": true, - "ion-tests/iontestdata/good/utf32.ion": true, - "ion-tests/iontestdata/good/whitespace.ion": true, - "ion-tests/iontestdata/good/item1.10n": true, + "../ion-tests/iontestdata/good/emptyAnnotatedInt.10n": true, + "../ion-tests/iontestdata/good/subfieldVarUInt32bit.ion": true, + "../ion-tests/iontestdata/good/utf16.ion": true, + "../ion-tests/iontestdata/good/utf32.ion": true, + "../ion-tests/iontestdata/good/whitespace.ion": true, + "../ion-tests/iontestdata/good/item1.10n": true, + "../ion-tests/iontestdata/good/typecodes/T7-large.10n" : true, } type drainfunc func(t *testing.T, r Reader, f string) func TestReadFiles(t *testing.T) { - testReadDir(t, "ion-tests/iontestdata/good", func(t *testing.T, r Reader, f string) { + testReadDir(t, "../ion-tests/iontestdata/good", func(t *testing.T, r Reader, f string) { drain(t, r, 0) }) } @@ -59,7 +61,7 @@ func print(level int, obj interface{}) { } func TestDecodeFiles(t *testing.T) { - testReadDir(t, "ion-tests/iontestdata/good", func(t *testing.T, r Reader, f string) { + testReadDir(t, "../ion-tests/iontestdata/good", func(t *testing.T, r Reader, f string) { // fmt.Println(f) d := NewDecoder(r) for { @@ -112,8 +114,11 @@ func testReadFile(t *testing.T, path string, d drainfunc) { if _, ok := blacklist[path]; ok { return } + if strings.HasSuffix(path, "md") { + return + } - // fmt.Println(path) + //fmt.Printf("**** PATH = %s\n", path) file, err := os.Open(path) if err != nil { From 58c7b218a5af4ccf9b8d6e4ad8bf01af099bc705 Mon Sep 17 00:00:00 2001 From: Therapon Skoteiniotis Date: Thu, 7 May 2020 20:29:10 -0700 Subject: [PATCH 45/56] keeping internal folder --- internal/.keep | 0 1 file changed, 0 insertions(+), 0 deletions(-) create mode 100644 internal/.keep diff --git a/internal/.keep b/internal/.keep new file mode 100644 index 00000000..e69de29b From 8b3a97dd12e355ccb353acfd9ef1f2aee0baa3fe Mon Sep 17 00:00:00 2001 From: Therapon Skoteiniotis Date: Thu, 7 May 2020 20:32:48 -0700 Subject: [PATCH 46/56] Remove leftover line from conflict resolution --- README.md | 1 - 1 file changed, 1 deletion(-) diff --git a/README.md b/README.md index c50e0393..7c627339 100644 --- a/README.md +++ b/README.md @@ -1,4 +1,3 @@ -<<<<<<< HEAD # Amazon Ion Go [![Build Status](https://github.com/amzn/ion-go/workflows/Go%20Build/badge.svg)](https://github.com/amzn/ion-go/actions?query=workflow%3A%22Go+Build%22) From 9769d1466767757d8fd0dcdee23f3b837efe3112 Mon Sep 17 00:00:00 2001 From: Therapon Skoteiniotis Date: Thu, 7 May 2020 21:03:01 -0700 Subject: [PATCH 47/56] Removing and adding later to kick GitHub Action --- .github/workflows/go.yml | 34 ---------------------------------- 1 file changed, 34 deletions(-) delete mode 100644 .github/workflows/go.yml diff --git a/.github/workflows/go.yml b/.github/workflows/go.yml deleted file mode 100644 index a1a266b7..00000000 --- a/.github/workflows/go.yml +++ /dev/null @@ -1,34 +0,0 @@ -name: Go Build -on: [push, pull_request] -jobs: - - build: - name: Build - runs-on: ubuntu-latest - steps: - - - name: Set up Go 1.13 - uses: actions/setup-go@v1 - with: - go-version: 1.13 - id: go - - - name: Check out code into the Go module directory - uses: actions/checkout@v2.0.0 - - - name: Check out submodules - uses: textbook/git-checkout-submodule-action@2.0.0 - - - name: Get dependencies - run: | - go get -v -t -d ./... - if [ -f Gopkg.toml ]; then - curl https://raw.githubusercontent.com/golang/dep/master/install.sh | sh - dep ensure - fi - - - name: Build - run: go build -v ./... - - - name: Test - run: go test -v ./... From 40bd8453cfde1e2249612c22743901f694d15b88 Mon Sep 17 00:00:00 2001 From: Therapon Skoteiniotis Date: Thu, 7 May 2020 21:05:04 -0700 Subject: [PATCH 48/56] Adding back GitHub Action --- .github/workflows/go.yml | 34 ++++++++++++++++++++++++++++++++++ 1 file changed, 34 insertions(+) create mode 100644 .github/workflows/go.yml diff --git a/.github/workflows/go.yml b/.github/workflows/go.yml new file mode 100644 index 00000000..a1a266b7 --- /dev/null +++ b/.github/workflows/go.yml @@ -0,0 +1,34 @@ +name: Go Build +on: [push, pull_request] +jobs: + + build: + name: Build + runs-on: ubuntu-latest + steps: + + - name: Set up Go 1.13 + uses: actions/setup-go@v1 + with: + go-version: 1.13 + id: go + + - name: Check out code into the Go module directory + uses: actions/checkout@v2.0.0 + + - name: Check out submodules + uses: textbook/git-checkout-submodule-action@2.0.0 + + - name: Get dependencies + run: | + go get -v -t -d ./... + if [ -f Gopkg.toml ]; then + curl https://raw.githubusercontent.com/golang/dep/master/install.sh | sh + dep ensure + fi + + - name: Build + run: go build -v ./... + + - name: Test + run: go test -v ./... From 678efdf482f7e0b1bc5275a602bdd322b537ac14 Mon Sep 17 00:00:00 2001 From: Therapon Skoteiniotis Date: Fri, 8 May 2020 00:17:15 -0700 Subject: [PATCH 49/56] Trying code coverage --- .github/workflows/go.yml | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/.github/workflows/go.yml b/.github/workflows/go.yml index a1a266b7..4b51cb6e 100644 --- a/.github/workflows/go.yml +++ b/.github/workflows/go.yml @@ -31,4 +31,10 @@ jobs: run: go build -v ./... - name: Test - run: go test -v ./... + run: go test -v ./... -coverprofile coverage.txt + + + - name: Upload Coverage report to CodeCov + uses: codecov/codecov-action@v1.0.0 + with: + file: ./coverage.txt From 16d77f330ef9bca04d5e5f612146f88b96758194 Mon Sep 17 00:00:00 2001 From: Therapon Skoteiniotis Date: Fri, 8 May 2020 00:21:34 -0700 Subject: [PATCH 50/56] Adding codecov token Add token after adding it to repo's secrets --- .github/workflows/go.yml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/.github/workflows/go.yml b/.github/workflows/go.yml index 4b51cb6e..56dd1ab7 100644 --- a/.github/workflows/go.yml +++ b/.github/workflows/go.yml @@ -37,4 +37,5 @@ jobs: - name: Upload Coverage report to CodeCov uses: codecov/codecov-action@v1.0.0 with: - file: ./coverage.txt + token: ${{secrets.CODECOV_TOKEN}} + file: ./coverage.txt From 3cc0061370ab8bd4fecda119e108a69b85d7b8eb Mon Sep 17 00:00:00 2001 From: Therapon Skoteiniotis Date: Fri, 8 May 2020 00:36:54 -0700 Subject: [PATCH 51/56] Kick the build again --- .github/workflows/go.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/go.yml b/.github/workflows/go.yml index 56dd1ab7..939f5117 100644 --- a/.github/workflows/go.yml +++ b/.github/workflows/go.yml @@ -30,7 +30,7 @@ jobs: - name: Build run: go build -v ./... - - name: Test + - name: Test with coverage run: go test -v ./... -coverprofile coverage.txt From b1bb77776e86c0eef1ee532d83a41b93ed9c5ab8 Mon Sep 17 00:00:00 2001 From: Therapon Skoteiniotis Date: Fri, 8 May 2020 00:43:49 -0700 Subject: [PATCH 52/56] Use latest codecov action version --- .github/workflows/go.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/go.yml b/.github/workflows/go.yml index 939f5117..053ad305 100644 --- a/.github/workflows/go.yml +++ b/.github/workflows/go.yml @@ -35,7 +35,7 @@ jobs: - name: Upload Coverage report to CodeCov - uses: codecov/codecov-action@v1.0.0 + uses: codecov/codecov-action@v1 with: token: ${{secrets.CODECOV_TOKEN}} file: ./coverage.txt From 2dff52be909d4d822c8077e9fe3d84c9858abd30 Mon Sep 17 00:00:00 2001 From: Therapon Skoteiniotis Date: Fri, 8 May 2020 10:55:51 -0700 Subject: [PATCH 53/56] Update NOTICE per Apache v2 --- NOTICE | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/NOTICE b/NOTICE index 0fca886e..f3f80116 100644 --- a/NOTICE +++ b/NOTICE @@ -1 +1,5 @@ Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. + +(https://github.com/fernomac/ion-go/blob/master/NOTICE) +Amazon Ion Go +Copyright 2019 David Murray From 9ae6004dcd72a2334b1ede26d929332e4acb1e19 Mon Sep 17 00:00:00 2001 From: Therapon Skoteiniotis Date: Fri, 8 May 2020 13:15:34 -0700 Subject: [PATCH 54/56] Adding PR template --- .github/PULL_REQUEST_TEMPLATE.md | 8 ++++++++ 1 file changed, 8 insertions(+) create mode 100644 .github/PULL_REQUEST_TEMPLATE.md diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md new file mode 100644 index 00000000..c1711c2b --- /dev/null +++ b/.github/PULL_REQUEST_TEMPLATE.md @@ -0,0 +1,8 @@ + + +Issue #, if available: + +Description of changes: + +By submitting this pull request, I confirm that my contribution is made under the terms of the Apache 2.0 license. + From 749c0a70e42cfe4897820686ba56564f6f89981b Mon Sep 17 00:00:00 2001 From: Therapon Skoteiniotis Date: Sat, 16 May 2020 22:38:52 -0700 Subject: [PATCH 55/56] fix go import in README --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 7c627339..28d55e6a 100644 --- a/README.md +++ b/README.md @@ -60,7 +60,7 @@ It is recommended that you hook this in your favorite IDE (`Tools` > `File Watch ## Usage -Import `github.com/amzn/ion-go` and you're off to the races. +Import `github.com/amzn/ion-go/ion` and you're off to the races. ### Marshaling and Unmarshaling From 0cddde40fbd69fc0ef1bec24bb0ea9de85244551 Mon Sep 17 00:00:00 2001 From: Therapon Skoteiniotis Date: Sat, 16 May 2020 22:57:45 -0700 Subject: [PATCH 56/56] Address PR comments --- README.md | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index 28d55e6a..071c98d7 100644 --- a/README.md +++ b/README.md @@ -64,10 +64,10 @@ Import `github.com/amzn/ion-go/ion` and you're off to the races. ### Marshaling and Unmarshaling -Similar to Golang's built-in [json](https://golang.org/pkg/encoding/json/) package, -you can marshal and unmarshal go types to Ion. Marshaling requires you to specify +Similar to GoLang's built-in [json](https://golang.org/pkg/encoding/json/) package, +you can marshal and unmarshal Go types to Ion. Marshaling requires you to specify whether you'd like text or binary Ion. Unmarshaling is smart enough to do the right -thing. Both respect json name tags, and `Marshal` honors omitempty. +thing. Both respect json name tags, and `Marshal` honors `omitempty`. ```Go type T struct { @@ -81,7 +81,7 @@ type T struct { func main() { t := T{} - err := ion.Unmarshal([]byte("{A:\"Ion!\",B:{C:2,D:[3,4]}}"), &t) + err := ion.Unmarshal([]byte(`{A:"Ion!",B:{C:2,D:[3,4]}}`), &t) if err != nil { panic(err) }