From addef24b9abc24b3d1ed66c3bfd8e4b5d4adcb97 Mon Sep 17 00:00:00 2001 From: Boban Acimovic Date: Sun, 4 Dec 2022 17:51:40 +0100 Subject: [PATCH] impl numeric range unmarshal test --- pgtype/numeric.go | 34 +++++++++++++++++ pgtype/range.go | 36 +++++++++--------- pgtype/range_test.go | 88 ++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 140 insertions(+), 18 deletions(-) diff --git a/pgtype/numeric.go b/pgtype/numeric.go index a5f4ed3ac..bf42d5569 100644 --- a/pgtype/numeric.go +++ b/pgtype/numeric.go @@ -4,6 +4,7 @@ import ( "bytes" "database/sql/driver" "encoding/binary" + "encoding/json" "fmt" "math" "math/big" @@ -240,6 +241,39 @@ func (n Numeric) MarshalJSON() ([]byte, error) { return n.numberTextBytes(), nil } +func (n *Numeric) UnmarshalJSON(b []byte) error { + var s *string + err := json.Unmarshal(b, &s) + if err != nil { + return err + } + + if s == nil { + *n = Numeric{} + return nil + } + + switch *s { + case "infinity": + *n = Numeric{NaN: true, InfinityModifier: Infinity, Valid: true} + case "-infinity": + *n = Numeric{NaN: true, InfinityModifier: -Infinity, Valid: true} + default: + num, exp, err := parseNumericString(*s) + if err != nil { + return fmt.Errorf("failed to decode %s to numeric: %w", *s, err) + } + + *n = Numeric{ + Int: num, + Exp: exp, + Valid: true, + } + } + + return nil +} + // numberString returns a string of the number. undefined if NaN, infinite, or NULL func (n Numeric) numberTextBytes() []byte { intStr := n.Int.String() diff --git a/pgtype/range.go b/pgtype/range.go index a4db79ad8..ff6ba4288 100644 --- a/pgtype/range.go +++ b/pgtype/range.go @@ -322,8 +322,8 @@ func (r *Range[T]) SetBoundTypes(lower, upper BoundType) error { return nil } -func (src Range[T]) MarshalJSON() ([]byte, error) { - if !src.Valid { +func (r Range[T]) MarshalJSON() ([]byte, error) { + if !r.Valid { return []byte("null"), nil } @@ -331,9 +331,9 @@ func (src Range[T]) MarshalJSON() ([]byte, error) { m: &encodePlanRangeCodecJson{}, } - buf, err := enc.Encode(src, []byte(`"`)) + buf, err := enc.Encode(r, []byte(`"`)) if err != nil { - return nil, fmt.Errorf("failed to encode %v as range: %w", src, err) + return nil, fmt.Errorf("failed to encode %v as range: %w", r, err) } buf = append(buf, `"`...) @@ -341,7 +341,7 @@ func (src Range[T]) MarshalJSON() ([]byte, error) { return buf, nil } -func (dst *Range[T]) UnmarshalJSON(b []byte) error { +func (r *Range[T]) UnmarshalJSON(b []byte) error { if b[0] == byte('"') && b[len(b)-1] == byte('"') { b = b[1 : len(b)-1] } @@ -349,34 +349,34 @@ func (dst *Range[T]) UnmarshalJSON(b []byte) error { s := string(b) if s == "null" { - *dst = Range[T]{} + *r = Range[T]{} return nil } - r, err := parseUntypedTextRange(s) + utr, err := parseUntypedTextRange(s) if err != nil { return fmt.Errorf("failed to decode %s to range: %w", s, err) } - *dst = Range[T]{ - LowerType: r.LowerType, - UpperType: r.UpperType, + *r = Range[T]{ + LowerType: utr.LowerType, + UpperType: utr.UpperType, Valid: true, } - if dst.LowerType == Empty && dst.UpperType == Empty { + if r.LowerType == Empty && r.UpperType == Empty { return nil } - if dst.LowerType != Unbounded { - if err = dst.unmarshalJSON(r.Lower, &dst.Lower); err != nil { - return fmt.Errorf("failed to decode %s to range lower: %w", r.Lower, err) + if r.LowerType != Unbounded { + if err = r.unmarshalJSON(utr.Lower, &r.Lower); err != nil { + return fmt.Errorf("failed to decode %s to range lower: %w", utr.Lower, err) } } - if dst.UpperType != Unbounded { - if err = dst.unmarshalJSON(r.Upper, &dst.Upper); err != nil { - return fmt.Errorf("failed to decode %s to range upper: %w", r.Upper, err) + if r.UpperType != Unbounded { + if err = r.unmarshalJSON(utr.Upper, &r.Upper); err != nil { + return fmt.Errorf("failed to decode %s to range upper: %w", utr.Upper, err) } } @@ -384,7 +384,7 @@ func (dst *Range[T]) UnmarshalJSON(b []byte) error { } func (_ *Range[T]) unmarshalJSON(data string, v *T) error { - buf := make([]byte, 0, len(data)) + buf := make([]byte, 0, len(data)+2) buf = append(buf, `"`...) buf = append(buf, data...) buf = append(buf, `"`...) diff --git a/pgtype/range_test.go b/pgtype/range_test.go index 5ed7f385f..2a0534134 100644 --- a/pgtype/range_test.go +++ b/pgtype/range_test.go @@ -431,3 +431,91 @@ func TestRangeDateUnmarshalJSON(t *testing.T) { } } } + +func TestRangeNumericUnmarshalJSON(t *testing.T) { + t.Parallel() + + tests := []struct { + src string + result Range[Numeric] + }{ + {src: "null", result: Range[Numeric]{}}, + {src: `"empty"`, result: Range[Numeric]{ + LowerType: Empty, + UpperType: Empty, + Valid: true, + }}, + {src: `"(-16,16)"`, result: Range[Numeric]{ + Lower: Numeric{Int: big.NewInt(-16), Valid: true}, + Upper: Numeric{Int: big.NewInt(16), Valid: true}, + LowerType: Exclusive, + UpperType: Exclusive, + Valid: true, + }}, + {src: `"(-16,16]"`, result: Range[Numeric]{ + Lower: Numeric{Int: big.NewInt(-16), Valid: true}, + Upper: Numeric{Int: big.NewInt(16), Valid: true}, + LowerType: Exclusive, + UpperType: Inclusive, + Valid: true, + }}, + {src: `"[-16,16)"`, result: Range[Numeric]{ + Lower: Numeric{Int: big.NewInt(-16), Valid: true}, + Upper: Numeric{Int: big.NewInt(16), Valid: true}, + LowerType: Inclusive, + UpperType: Exclusive, + Valid: true, + }}, + {src: `"[-16,16]"`, result: Range[Numeric]{ + Lower: Numeric{Int: big.NewInt(-16), Valid: true}, + Upper: Numeric{Int: big.NewInt(16), Valid: true}, + LowerType: Inclusive, + UpperType: Inclusive, + Valid: true, + }}, + {src: `"(,16)"`, result: Range[Numeric]{ + Upper: Numeric{Int: big.NewInt(16), Valid: true}, + LowerType: Unbounded, + UpperType: Exclusive, + Valid: true, + }}, + {src: `"[-16,)"`, result: Range[Numeric]{ + Lower: Numeric{Int: big.NewInt(-16), Valid: true}, + LowerType: Inclusive, + UpperType: Unbounded, + Valid: true, + }}, + {src: `"(-infinity,16)"`, result: Range[Numeric]{ + Lower: Numeric{InfinityModifier: NegativeInfinity, NaN: true, Valid: true}, + Upper: Numeric{Int: big.NewInt(16), Valid: true}, + LowerType: Exclusive, + UpperType: Exclusive, + Valid: true, + }}, + {src: `"[-16,infinity)"`, result: Range[Numeric]{ + Lower: Numeric{Int: big.NewInt(-16), Valid: true}, + Upper: Numeric{InfinityModifier: Infinity, NaN: true, Valid: true}, + LowerType: Inclusive, + UpperType: Exclusive, + Valid: true, + }}, + } + + for i, tt := range tests { + var r Range[Numeric] + err := r.UnmarshalJSON([]byte(tt.src)) + if err != nil { + t.Fatalf("%d: %v", i, err) + } + + if r.Lower.Int.Cmp(tt.result.Lower.Int) != 0 || + r.Lower.InfinityModifier != tt.result.Lower.InfinityModifier || + r.LowerType != tt.result.LowerType || + r.Upper.Int.Cmp(tt.result.Upper.Int) != 0 || + r.Upper.InfinityModifier != tt.result.Upper.InfinityModifier || + r.UpperType != tt.result.UpperType || + r.Valid != r.Valid { + t.Errorf("%d: expected %s to decode to %v, got %v", i, tt.src, tt.result, r) + } + } +}