Skip to content

Commit

Permalink
impl numeric range unmarshal test
Browse files Browse the repository at this point in the history
  • Loading branch information
acim committed Dec 4, 2022
1 parent 8900eb1 commit addef24
Show file tree
Hide file tree
Showing 3 changed files with 140 additions and 18 deletions.
34 changes: 34 additions & 0 deletions pgtype/numeric.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"bytes"
"database/sql/driver"
"encoding/binary"
"encoding/json"
"fmt"
"math"
"math/big"
Expand Down Expand Up @@ -240,6 +241,39 @@ func (n Numeric) MarshalJSON() ([]byte, error) {
return n.numberTextBytes(), nil
}

func (n *Numeric) UnmarshalJSON(b []byte) error {
var s *string
err := json.Unmarshal(b, &s)
if err != nil {
return err
}

if s == nil {
*n = Numeric{}
return nil
}

switch *s {
case "infinity":
*n = Numeric{NaN: true, InfinityModifier: Infinity, Valid: true}
case "-infinity":
*n = Numeric{NaN: true, InfinityModifier: -Infinity, Valid: true}
default:
num, exp, err := parseNumericString(*s)
if err != nil {
return fmt.Errorf("failed to decode %s to numeric: %w", *s, err)
}

*n = Numeric{
Int: num,
Exp: exp,
Valid: true,
}
}

return nil
}

// numberString returns a string of the number. undefined if NaN, infinite, or NULL
func (n Numeric) numberTextBytes() []byte {
intStr := n.Int.String()
Expand Down
36 changes: 18 additions & 18 deletions pgtype/range.go
Original file line number Diff line number Diff line change
Expand Up @@ -322,69 +322,69 @@ func (r *Range[T]) SetBoundTypes(lower, upper BoundType) error {
return nil
}

func (src Range[T]) MarshalJSON() ([]byte, error) {
if !src.Valid {
func (r Range[T]) MarshalJSON() ([]byte, error) {
if !r.Valid {
return []byte("null"), nil
}

enc := encodePlanRangeCodecRangeValuerToText{
m: &encodePlanRangeCodecJson{},
}

buf, err := enc.Encode(src, []byte(`"`))
buf, err := enc.Encode(r, []byte(`"`))
if err != nil {
return nil, fmt.Errorf("failed to encode %v as range: %w", src, err)
return nil, fmt.Errorf("failed to encode %v as range: %w", r, err)
}

buf = append(buf, `"`...)

return buf, nil
}

func (dst *Range[T]) UnmarshalJSON(b []byte) error {
func (r *Range[T]) UnmarshalJSON(b []byte) error {
if b[0] == byte('"') && b[len(b)-1] == byte('"') {
b = b[1 : len(b)-1]
}

s := string(b)

if s == "null" {
*dst = Range[T]{}
*r = Range[T]{}
return nil
}

r, err := parseUntypedTextRange(s)
utr, err := parseUntypedTextRange(s)
if err != nil {
return fmt.Errorf("failed to decode %s to range: %w", s, err)
}

*dst = Range[T]{
LowerType: r.LowerType,
UpperType: r.UpperType,
*r = Range[T]{
LowerType: utr.LowerType,
UpperType: utr.UpperType,
Valid: true,
}

if dst.LowerType == Empty && dst.UpperType == Empty {
if r.LowerType == Empty && r.UpperType == Empty {
return nil
}

if dst.LowerType != Unbounded {
if err = dst.unmarshalJSON(r.Lower, &dst.Lower); err != nil {
return fmt.Errorf("failed to decode %s to range lower: %w", r.Lower, err)
if r.LowerType != Unbounded {
if err = r.unmarshalJSON(utr.Lower, &r.Lower); err != nil {
return fmt.Errorf("failed to decode %s to range lower: %w", utr.Lower, err)
}
}

if dst.UpperType != Unbounded {
if err = dst.unmarshalJSON(r.Upper, &dst.Upper); err != nil {
return fmt.Errorf("failed to decode %s to range upper: %w", r.Upper, err)
if r.UpperType != Unbounded {
if err = r.unmarshalJSON(utr.Upper, &r.Upper); err != nil {
return fmt.Errorf("failed to decode %s to range upper: %w", utr.Upper, err)
}
}

return nil
}

func (_ *Range[T]) unmarshalJSON(data string, v *T) error {
buf := make([]byte, 0, len(data))
buf := make([]byte, 0, len(data)+2)
buf = append(buf, `"`...)
buf = append(buf, data...)
buf = append(buf, `"`...)
Expand Down
88 changes: 88 additions & 0 deletions pgtype/range_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -431,3 +431,91 @@ func TestRangeDateUnmarshalJSON(t *testing.T) {
}
}
}

func TestRangeNumericUnmarshalJSON(t *testing.T) {
t.Parallel()

tests := []struct {
src string
result Range[Numeric]
}{
{src: "null", result: Range[Numeric]{}},
{src: `"empty"`, result: Range[Numeric]{
LowerType: Empty,
UpperType: Empty,
Valid: true,
}},
{src: `"(-16,16)"`, result: Range[Numeric]{
Lower: Numeric{Int: big.NewInt(-16), Valid: true},
Upper: Numeric{Int: big.NewInt(16), Valid: true},
LowerType: Exclusive,
UpperType: Exclusive,
Valid: true,
}},
{src: `"(-16,16]"`, result: Range[Numeric]{
Lower: Numeric{Int: big.NewInt(-16), Valid: true},
Upper: Numeric{Int: big.NewInt(16), Valid: true},
LowerType: Exclusive,
UpperType: Inclusive,
Valid: true,
}},
{src: `"[-16,16)"`, result: Range[Numeric]{
Lower: Numeric{Int: big.NewInt(-16), Valid: true},
Upper: Numeric{Int: big.NewInt(16), Valid: true},
LowerType: Inclusive,
UpperType: Exclusive,
Valid: true,
}},
{src: `"[-16,16]"`, result: Range[Numeric]{
Lower: Numeric{Int: big.NewInt(-16), Valid: true},
Upper: Numeric{Int: big.NewInt(16), Valid: true},
LowerType: Inclusive,
UpperType: Inclusive,
Valid: true,
}},
{src: `"(,16)"`, result: Range[Numeric]{
Upper: Numeric{Int: big.NewInt(16), Valid: true},
LowerType: Unbounded,
UpperType: Exclusive,
Valid: true,
}},
{src: `"[-16,)"`, result: Range[Numeric]{
Lower: Numeric{Int: big.NewInt(-16), Valid: true},
LowerType: Inclusive,
UpperType: Unbounded,
Valid: true,
}},
{src: `"(-infinity,16)"`, result: Range[Numeric]{
Lower: Numeric{InfinityModifier: NegativeInfinity, NaN: true, Valid: true},
Upper: Numeric{Int: big.NewInt(16), Valid: true},
LowerType: Exclusive,
UpperType: Exclusive,
Valid: true,
}},
{src: `"[-16,infinity)"`, result: Range[Numeric]{
Lower: Numeric{Int: big.NewInt(-16), Valid: true},
Upper: Numeric{InfinityModifier: Infinity, NaN: true, Valid: true},
LowerType: Inclusive,
UpperType: Exclusive,
Valid: true,
}},
}

for i, tt := range tests {
var r Range[Numeric]
err := r.UnmarshalJSON([]byte(tt.src))
if err != nil {
t.Fatalf("%d: %v", i, err)
}

if r.Lower.Int.Cmp(tt.result.Lower.Int) != 0 ||
r.Lower.InfinityModifier != tt.result.Lower.InfinityModifier ||
r.LowerType != tt.result.LowerType ||
r.Upper.Int.Cmp(tt.result.Upper.Int) != 0 ||
r.Upper.InfinityModifier != tt.result.Upper.InfinityModifier ||
r.UpperType != tt.result.UpperType ||
r.Valid != r.Valid {
t.Errorf("%d: expected %s to decode to %v, got %v", i, tt.src, tt.result, r)
}
}
}

0 comments on commit addef24

Please sign in to comment.