From caa4fab1688a4331ada3b6cdfaeeeb135da7c565 Mon Sep 17 00:00:00 2001 From: Pablo Lopez Date: Wed, 29 Mar 2023 22:04:36 +0300 Subject: [PATCH] add UnmarshalJSON and MarshalJSON to Decimal (#191) * add UnmarshalJSON and MarshalJSON to Decimal * add more tests --- ion/decimal.go | 48 +++++++++++++++++++++++++++++++ ion/decimal_test.go | 70 +++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 118 insertions(+) diff --git a/ion/decimal.go b/ion/decimal.go index fa94c369..2eab9023 100644 --- a/ion/decimal.go +++ b/ion/decimal.go @@ -390,3 +390,51 @@ func (d *Decimal) String() string { return b.String() } } + +// UnmarshalJSON implements the json.Unmarshaler interface. +func (d *Decimal) UnmarshalJSON(decimalBytes []byte) error { + str := string(decimalBytes) + if str == "null" { + return nil + } + str = strings.Replace(str, "E", "D", 1) + str = strings.Replace(str, "e", "d", 1) + parsed, err := ParseDecimal(str) + if err != nil { + return fmt.Errorf("error unmarshalling decimal '%s': %w", str, err) + } + *d = *parsed + return nil +} + +// MarshalJSON implements the json.Marshaler interface. +func (d *Decimal) MarshalJSON() ([]byte, error) { + absN := new(big.Int).Abs(d.n).String() + scale := int(-d.scale) + sign := d.n.Sign() + + var str string + if scale == 0 { + str = absN + } else if scale > 0 { + // add zeroes to the right + str = absN + strings.Repeat("0", scale) + } else { + // add zeroes to the left + absScale := -scale + nLen := len(absN) + + if absScale >= nLen { + str = "0." + strings.Repeat("0", absScale-nLen) + absN + } else { + str = absN[:nLen-absScale] + "." + absN[nLen-absScale:] + } + str = strings.TrimRight(str, "0") + str = strings.TrimSuffix(str, ".") + } + + if sign == -1 { + str = "-" + str + } + return []byte(str), nil +} diff --git a/ion/decimal_test.go b/ion/decimal_test.go index 11d0fe0e..8511ae73 100644 --- a/ion/decimal_test.go +++ b/ion/decimal_test.go @@ -16,6 +16,7 @@ package ion import ( + "encoding/json" "fmt" "math/big" "testing" @@ -329,3 +330,72 @@ func TestUpscale(t *testing.T) { actual := d.upscale(4).String() assert.Equal(t, "10.0000", actual) } + +func TestMarshalJSON(t *testing.T) { + test := func(a string, expected string) { + t.Run("("+a+")", func(t *testing.T) { + ad, err := ParseDecimal(a) + require.NoError(t, err) + + am, err := ad.MarshalJSON() + require.NoError(t, err) + + assert.Equal(t, []byte(expected), am) + }) + } + test("123000", "123000") + + test("1.01", "1.01") + test("0.01", "0.01") + test("0.0", "0") + test("0.123456789012345678901234567890", "0.12345678901234567890123456789") // Trims trailing zeros + test("123456789012345678901234567890.123456789012345678901234567890", "123456789012345678901234567890.12345678901234567890123456789") // Trims trailing zeros + + test("1d-2", "0.01") + test("1d-3", "0.001") + test("1d2", "100") + + test("-1d3", "-1000") + test("-1d-3", "-0.001") + test("-0.0", "0") + test("-0.1", "-0.1") +} + +func TestUnmarshalJSON(t *testing.T) { + test := func(a string, expected string) { + t.Run("("+a+")", func(t *testing.T) { + expectedDec := MustParseDecimal(expected) + + var r struct { + D *Decimal `json:"d"` + } + err := json.Unmarshal([]byte(`{"d":`+a+`}`), &r) + require.NoError(t, err) + + assert.Truef(t, expectedDec.Equal(r.D), "expected %v, got %v", expected, r.D) + }) + } + + test("123000", "123000") + test("123.1", "123.1") + test("123.10", "123.1") + test("-123000", "-123000") + test("-123.1", "-123.1") + test("-123.10", "-123.1") + + test("1e+2", "100") + test("1e2", "100") + test("1E2", "100") + test("1E+2", "100") + + test("-1e+2", "-100") + test("-1e2", "-100") + test("-1E2", "-100") + test("-1E+2", "-100") + + test("1e-2", "0.01") + test("1E-2", "0.01") + + test("-1e-2", "-0.01") + test("-1E-2", "-0.01") +}