From bd986615d6ed7cdb914d93581d723a44624bed56 Mon Sep 17 00:00:00 2001 From: yakud Date: Tue, 5 Apr 2022 18:19:01 +0200 Subject: [PATCH] AddOverflow, SubOverflow, MulOverflow, DivOverflow functions --- decimal.go | 139 ++++++++++++++++++++++++++++++++++++++++++++---- decimal_test.go | 45 ++++++++++++++++ 2 files changed, 173 insertions(+), 11 deletions(-) diff --git a/decimal.go b/decimal.go index 1afbe39..73cbfb2 100644 --- a/decimal.go +++ b/decimal.go @@ -51,7 +51,7 @@ func (d *Decimal) UnmarshalYAML(unmarshal func(interface{}) error) error { return nil } -// return d == y +// Eq return d == y func (d *Decimal) Eq(y *Decimal) bool { xx := NewDecimal(d) yy := NewDecimal(y) @@ -65,7 +65,7 @@ func (d *Decimal) Eq(y *Decimal) bool { return xx.value.Eq(yy.value) } -// return d > y +// Gt return d > y func (d *Decimal) Gt(y *Decimal) bool { xx := NewDecimal(d) yy := NewDecimal(y) @@ -79,7 +79,7 @@ func (d *Decimal) Gt(y *Decimal) bool { return xx.value.Gt(yy.value) } -// return d < y +// Lt return d < y func (d *Decimal) Lt(y *Decimal) bool { xx := NewDecimal(d) yy := NewDecimal(y) @@ -93,7 +93,7 @@ func (d *Decimal) Lt(y *Decimal) bool { return xx.value.Lt(yy.value) } -// d = d + y and return d +// Add d = d + y and return d func (d *Decimal) Add(y *Decimal) *Decimal { xx := NewDecimal(d) yy := NewDecimal(y) @@ -110,7 +110,33 @@ func (d *Decimal) Add(y *Decimal) *Decimal { return d } -// d = d - y and return d +// AddOverflow d = d + y and return d +func (d *Decimal) AddOverflow(y *Decimal) (*Decimal, bool) { + xx := NewDecimal(d) + yy := NewDecimal(y) + + if yy.mantissa > xx.mantissa { + _, overflow := xx.RescaleOverflow(yy.mantissa) + if overflow { + return nil, true + } + } else if yy.mantissa < xx.mantissa { + _, overflow := yy.RescaleOverflow(xx.mantissa) + if overflow { + return nil, true + } + } + + _, overflow := d.value.AddOverflow(xx.value, yy.value) + if overflow { + return nil, true + } + d.mantissa = xx.mantissa + + return d, false +} + +// Sub d = d - y and return d func (d *Decimal) Sub(y *Decimal) *Decimal { xx := NewDecimal(d) yy := NewDecimal(y) @@ -127,7 +153,37 @@ func (d *Decimal) Sub(y *Decimal) *Decimal { return d } -// d = d * y and return d +// SubOverflow d = d - y and return d +func (d *Decimal) SubOverflow(y *Decimal) (*Decimal, bool) { + xx := NewDecimal(d) + yy := NewDecimal(y) + + if yy.mantissa > xx.mantissa { + _, overflow := xx.RescaleOverflow(yy.mantissa) + if overflow { + return nil, true + } + } else if yy.mantissa < xx.mantissa { + _, overflow := yy.RescaleOverflow(xx.mantissa) + if overflow { + return nil, true + } + } + + if xx.Lt(yy) { + return nil, true + } + + _, overflow := d.value.SubOverflow(xx.value, yy.value) + if overflow { + return nil, true + } + d.mantissa = xx.mantissa + + return d, false +} + +// Mul d = d * y and return d func (d *Decimal) Mul(y *Decimal) *Decimal { xx := NewDecimal(d) yy := NewDecimal(y) @@ -144,9 +200,36 @@ func (d *Decimal) Mul(y *Decimal) *Decimal { return d } +// MulOverflow d = d * y and return d +func (d *Decimal) MulOverflow(y *Decimal) (*Decimal, bool) { + xx := NewDecimal(d) + yy := NewDecimal(y) + + if yy.mantissa > xx.mantissa { + _, overflow := xx.RescaleOverflow(yy.mantissa) + if overflow { + return nil, true + } + } else if yy.mantissa < xx.mantissa { + _, overflow := yy.RescaleOverflow(xx.mantissa) + if overflow { + return nil, true + } + } + + _, overflow := d.value.MulOverflow(xx.value, yy.value) + if overflow { + return nil, true + } + + d.mantissa = xx.mantissa + yy.mantissa + + return d, false +} + const defaultDivScale = 20 -// d = d / y and return d +// Div d = d / y and return d func (d *Decimal) Div(y *Decimal) *Decimal { if y.Eq(Zero) { return NewDecimalZero() @@ -173,8 +256,43 @@ func (d *Decimal) Div(y *Decimal) *Decimal { return d } -func (d *Decimal) SetFromBig(value *big.Int, mantissa uint8) (*Decimal, bool) { - overflow := d.value.SetFromBig(value) +// DivOverflow d = d / y and return d +func (d *Decimal) DivOverflow(y *Decimal) (*Decimal, bool) { + if y.Eq(Zero) { + return NewDecimalZero(), false + } + + xx := NewDecimal(d) + yy := NewDecimal(y) + + var scalerest uint8 + e := int64(xx.mantissa) - int64(yy.mantissa) - int64(defaultDivScale) + if e > MaxUint8 { + return nil, true + } + + if e < 0 { + _, overflow := xx.value.MulOverflow(xx.value, ExpScale(int16(-e))) + if overflow { + return nil, true + } + scalerest = defaultDivScale + } else { + _, overflow := yy.value.MulOverflow(yy.value, ExpScale(int16(e))) + if overflow { + return nil, true + } + scalerest = xx.mantissa + } + + d.value.Div(xx.value, yy.value) + d.mantissa = scalerest + + return d, false +} + +func (d *Decimal) SetFromBig(value *big.Int, mantissa uint8) (v *Decimal, overflow bool) { + overflow = d.value.SetFromBig(value) d.SetMantissa(mantissa) return d, overflow } @@ -197,7 +315,7 @@ func (d *Decimal) GetMantissa() uint8 { return d.mantissa } -func (d *Decimal) FromString(value string) bool { +func (d *Decimal) FromString(value string) (ok bool) { if d == nil { *d = *NewDecimalZero() } @@ -207,7 +325,6 @@ func (d *Decimal) FromString(value string) bool { return true } - var ok bool var mantissa uint8 = 0 var valBig = new(big.Int) var parts = strings.Split(value, ".") diff --git a/decimal_test.go b/decimal_test.go index 60cf181..907845f 100644 --- a/decimal_test.go +++ b/decimal_test.go @@ -1,6 +1,7 @@ package decimal import ( + "math/big" "testing" "github.com/holiman/uint256" @@ -279,6 +280,50 @@ func TestDecimal_Sub(t *testing.T) { } } +func TestDecimal_MulOverflow(t *testing.T) { + a, ok := new(big.Int).SetString("ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff", 16) + assert.True(t, ok) + b, ok := new(big.Int).SetString("ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff", 16) + assert.True(t, ok) + + c, overflow := NewDecimalFromBig(a, 0).MulOverflow(NewDecimalFromBig(b, 0)) + assert.True(t, overflow) + assert.Nil(t, c) +} + +func TestDecimal_DivOverflow(t *testing.T) { + a, ok := new(big.Int).SetString("ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff", 16) + assert.True(t, ok) + b, ok := new(big.Int).SetString("ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff", 16) + assert.True(t, ok) + + c, overflow := NewDecimalFromBig(a, 0).DivOverflow(NewDecimalFromBig(b, 0)) + assert.True(t, overflow) + assert.Nil(t, c) +} + +func TestDecimal_AddOverflow(t *testing.T) { + a, ok := new(big.Int).SetString("ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff", 16) + assert.True(t, ok) + b, ok := new(big.Int).SetString("ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff", 16) + assert.True(t, ok) + + c, overflow := NewDecimalFromBig(a, 0).AddOverflow(NewDecimalFromBig(b, 0)) + assert.True(t, overflow) + assert.Nil(t, c) +} + +func TestDecimal_SubOverflow(t *testing.T) { + a, ok := new(big.Int).SetString("ffffffffffff", 16) + assert.True(t, ok) + b, ok := new(big.Int).SetString("ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff", 16) + assert.True(t, ok) + + c, overflow := NewDecimalFromBig(a, 0).SubOverflow(NewDecimalFromBig(b, 0)) + assert.True(t, overflow) + assert.Nil(t, c) +} + func TestDecimal_Mul(t *testing.T) { x := NewDecimalFromUint256(uint256.NewInt(0).SetUint64(10), 0) y := NewDecimalFromUint256(uint256.NewInt(0).SetUint64(10), 0)