From caa4fab1688a4331ada3b6cdfaeeeb135da7c565 Mon Sep 17 00:00:00 2001
From: Pablo Lopez
Date: Wed, 29 Mar 2023 22:04:36 +0300
Subject: [PATCH] add UnmarshalJSON and MarshalJSON to Decimal (#191)
* add UnmarshalJSON and MarshalJSON to Decimal
* add more tests
---
ion/decimal.go | 48 +++++++++++++++++++++++++++++++
ion/decimal_test.go | 70 +++++++++++++++++++++++++++++++++++++++++++++
2 files changed, 118 insertions(+)
diff --git a/ion/decimal.go b/ion/decimal.go
index fa94c369..2eab9023 100644
--- a/ion/decimal.go
+++ b/ion/decimal.go
@@ -390,3 +390,51 @@ func (d *Decimal) String() string {
return b.String()
}
}
+
+// UnmarshalJSON implements the json.Unmarshaler interface.
+func (d *Decimal) UnmarshalJSON(decimalBytes []byte) error {
+ str := string(decimalBytes)
+ if str == "null" {
+ return nil
+ }
+ str = strings.Replace(str, "E", "D", 1)
+ str = strings.Replace(str, "e", "d", 1)
+ parsed, err := ParseDecimal(str)
+ if err != nil {
+ return fmt.Errorf("error unmarshalling decimal '%s': %w", str, err)
+ }
+ *d = *parsed
+ return nil
+}
+
+// MarshalJSON implements the json.Marshaler interface.
+func (d *Decimal) MarshalJSON() ([]byte, error) {
+ absN := new(big.Int).Abs(d.n).String()
+ scale := int(-d.scale)
+ sign := d.n.Sign()
+
+ var str string
+ if scale == 0 {
+ str = absN
+ } else if scale > 0 {
+ // add zeroes to the right
+ str = absN + strings.Repeat("0", scale)
+ } else {
+ // add zeroes to the left
+ absScale := -scale
+ nLen := len(absN)
+
+ if absScale >= nLen {
+ str = "0." + strings.Repeat("0", absScale-nLen) + absN
+ } else {
+ str = absN[:nLen-absScale] + "." + absN[nLen-absScale:]
+ }
+ str = strings.TrimRight(str, "0")
+ str = strings.TrimSuffix(str, ".")
+ }
+
+ if sign == -1 {
+ str = "-" + str
+ }
+ return []byte(str), nil
+}
diff --git a/ion/decimal_test.go b/ion/decimal_test.go
index 11d0fe0e..8511ae73 100644
--- a/ion/decimal_test.go
+++ b/ion/decimal_test.go
@@ -16,6 +16,7 @@
package ion
import (
+ "encoding/json"
"fmt"
"math/big"
"testing"
@@ -329,3 +330,72 @@ func TestUpscale(t *testing.T) {
actual := d.upscale(4).String()
assert.Equal(t, "10.0000", actual)
}
+
+func TestMarshalJSON(t *testing.T) {
+ test := func(a string, expected string) {
+ t.Run("("+a+")", func(t *testing.T) {
+ ad, err := ParseDecimal(a)
+ require.NoError(t, err)
+
+ am, err := ad.MarshalJSON()
+ require.NoError(t, err)
+
+ assert.Equal(t, []byte(expected), am)
+ })
+ }
+ test("123000", "123000")
+
+ test("1.01", "1.01")
+ test("0.01", "0.01")
+ test("0.0", "0")
+ test("0.123456789012345678901234567890", "0.12345678901234567890123456789") // Trims trailing zeros
+ test("123456789012345678901234567890.123456789012345678901234567890", "123456789012345678901234567890.12345678901234567890123456789") // Trims trailing zeros
+
+ test("1d-2", "0.01")
+ test("1d-3", "0.001")
+ test("1d2", "100")
+
+ test("-1d3", "-1000")
+ test("-1d-3", "-0.001")
+ test("-0.0", "0")
+ test("-0.1", "-0.1")
+}
+
+func TestUnmarshalJSON(t *testing.T) {
+ test := func(a string, expected string) {
+ t.Run("("+a+")", func(t *testing.T) {
+ expectedDec := MustParseDecimal(expected)
+
+ var r struct {
+ D *Decimal `json:"d"`
+ }
+ err := json.Unmarshal([]byte(`{"d":`+a+`}`), &r)
+ require.NoError(t, err)
+
+ assert.Truef(t, expectedDec.Equal(r.D), "expected %v, got %v", expected, r.D)
+ })
+ }
+
+ test("123000", "123000")
+ test("123.1", "123.1")
+ test("123.10", "123.1")
+ test("-123000", "-123000")
+ test("-123.1", "-123.1")
+ test("-123.10", "-123.1")
+
+ test("1e+2", "100")
+ test("1e2", "100")
+ test("1E2", "100")
+ test("1E+2", "100")
+
+ test("-1e+2", "-100")
+ test("-1e2", "-100")
+ test("-1E2", "-100")
+ test("-1E+2", "-100")
+
+ test("1e-2", "0.01")
+ test("1E-2", "0.01")
+
+ test("-1e-2", "-0.01")
+ test("-1E-2", "-0.01")
+}