Skip to content

Commit

Permalink
Merge pull request #8 from stasundr/wei_converter
Browse files Browse the repository at this point in the history
AddOverflow, SubOverflow, MulOverflow, DivOverflow functions
  • Loading branch information
yakud authored Apr 5, 2022
2 parents 554b101 + bd98661 commit f79326e
Show file tree
Hide file tree
Showing 2 changed files with 173 additions and 11 deletions.
139 changes: 128 additions & 11 deletions decimal.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ func (d *Decimal) UnmarshalYAML(unmarshal func(interface{}) error) error {
return nil
}

// return d == y
// Eq return d == y
func (d *Decimal) Eq(y *Decimal) bool {
xx := NewDecimal(d)
yy := NewDecimal(y)
Expand All @@ -65,7 +65,7 @@ func (d *Decimal) Eq(y *Decimal) bool {
return xx.value.Eq(yy.value)
}

// return d > y
// Gt return d > y
func (d *Decimal) Gt(y *Decimal) bool {
xx := NewDecimal(d)
yy := NewDecimal(y)
Expand All @@ -79,7 +79,7 @@ func (d *Decimal) Gt(y *Decimal) bool {
return xx.value.Gt(yy.value)
}

// return d < y
// Lt return d < y
func (d *Decimal) Lt(y *Decimal) bool {
xx := NewDecimal(d)
yy := NewDecimal(y)
Expand All @@ -93,7 +93,7 @@ func (d *Decimal) Lt(y *Decimal) bool {
return xx.value.Lt(yy.value)
}

// d = d + y and return d
// Add d = d + y and return d
func (d *Decimal) Add(y *Decimal) *Decimal {
xx := NewDecimal(d)
yy := NewDecimal(y)
Expand All @@ -110,7 +110,33 @@ func (d *Decimal) Add(y *Decimal) *Decimal {
return d
}

// d = d - y and return d
// AddOverflow d = d + y and return d
func (d *Decimal) AddOverflow(y *Decimal) (*Decimal, bool) {
xx := NewDecimal(d)
yy := NewDecimal(y)

if yy.mantissa > xx.mantissa {
_, overflow := xx.RescaleOverflow(yy.mantissa)
if overflow {
return nil, true
}
} else if yy.mantissa < xx.mantissa {
_, overflow := yy.RescaleOverflow(xx.mantissa)
if overflow {
return nil, true
}
}

_, overflow := d.value.AddOverflow(xx.value, yy.value)
if overflow {
return nil, true
}
d.mantissa = xx.mantissa

return d, false
}

// Sub d = d - y and return d
func (d *Decimal) Sub(y *Decimal) *Decimal {
xx := NewDecimal(d)
yy := NewDecimal(y)
Expand All @@ -127,7 +153,37 @@ func (d *Decimal) Sub(y *Decimal) *Decimal {
return d
}

// d = d * y and return d
// SubOverflow d = d - y and return d
func (d *Decimal) SubOverflow(y *Decimal) (*Decimal, bool) {
xx := NewDecimal(d)
yy := NewDecimal(y)

if yy.mantissa > xx.mantissa {
_, overflow := xx.RescaleOverflow(yy.mantissa)
if overflow {
return nil, true
}
} else if yy.mantissa < xx.mantissa {
_, overflow := yy.RescaleOverflow(xx.mantissa)
if overflow {
return nil, true
}
}

if xx.Lt(yy) {
return nil, true
}

_, overflow := d.value.SubOverflow(xx.value, yy.value)
if overflow {
return nil, true
}
d.mantissa = xx.mantissa

return d, false
}

// Mul d = d * y and return d
func (d *Decimal) Mul(y *Decimal) *Decimal {
xx := NewDecimal(d)
yy := NewDecimal(y)
Expand All @@ -144,9 +200,36 @@ func (d *Decimal) Mul(y *Decimal) *Decimal {
return d
}

// MulOverflow d = d * y and return d
func (d *Decimal) MulOverflow(y *Decimal) (*Decimal, bool) {
xx := NewDecimal(d)
yy := NewDecimal(y)

if yy.mantissa > xx.mantissa {
_, overflow := xx.RescaleOverflow(yy.mantissa)
if overflow {
return nil, true
}
} else if yy.mantissa < xx.mantissa {
_, overflow := yy.RescaleOverflow(xx.mantissa)
if overflow {
return nil, true
}
}

_, overflow := d.value.MulOverflow(xx.value, yy.value)
if overflow {
return nil, true
}

d.mantissa = xx.mantissa + yy.mantissa

return d, false
}

const defaultDivScale = 20

// d = d / y and return d
// Div d = d / y and return d
func (d *Decimal) Div(y *Decimal) *Decimal {
if y.Eq(Zero) {
return NewDecimalZero()
Expand All @@ -173,8 +256,43 @@ func (d *Decimal) Div(y *Decimal) *Decimal {
return d
}

func (d *Decimal) SetFromBig(value *big.Int, mantissa uint8) (*Decimal, bool) {
overflow := d.value.SetFromBig(value)
// DivOverflow d = d / y and return d
func (d *Decimal) DivOverflow(y *Decimal) (*Decimal, bool) {
if y.Eq(Zero) {
return NewDecimalZero(), false
}

xx := NewDecimal(d)
yy := NewDecimal(y)

var scalerest uint8
e := int64(xx.mantissa) - int64(yy.mantissa) - int64(defaultDivScale)
if e > MaxUint8 {
return nil, true
}

if e < 0 {
_, overflow := xx.value.MulOverflow(xx.value, ExpScale(int16(-e)))
if overflow {
return nil, true
}
scalerest = defaultDivScale
} else {
_, overflow := yy.value.MulOverflow(yy.value, ExpScale(int16(e)))
if overflow {
return nil, true
}
scalerest = xx.mantissa
}

d.value.Div(xx.value, yy.value)
d.mantissa = scalerest

return d, false
}

func (d *Decimal) SetFromBig(value *big.Int, mantissa uint8) (v *Decimal, overflow bool) {
overflow = d.value.SetFromBig(value)
d.SetMantissa(mantissa)
return d, overflow
}
Expand All @@ -197,7 +315,7 @@ func (d *Decimal) GetMantissa() uint8 {
return d.mantissa
}

func (d *Decimal) FromString(value string) bool {
func (d *Decimal) FromString(value string) (ok bool) {
if d == nil {
*d = *NewDecimalZero()
}
Expand All @@ -207,7 +325,6 @@ func (d *Decimal) FromString(value string) bool {
return true
}

var ok bool
var mantissa uint8 = 0
var valBig = new(big.Int)
var parts = strings.Split(value, ".")
Expand Down
45 changes: 45 additions & 0 deletions decimal_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package decimal

import (
"math/big"
"testing"

"github.com/holiman/uint256"
Expand Down Expand Up @@ -279,6 +280,50 @@ func TestDecimal_Sub(t *testing.T) {
}
}

func TestDecimal_MulOverflow(t *testing.T) {
a, ok := new(big.Int).SetString("ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff", 16)
assert.True(t, ok)
b, ok := new(big.Int).SetString("ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff", 16)
assert.True(t, ok)

c, overflow := NewDecimalFromBig(a, 0).MulOverflow(NewDecimalFromBig(b, 0))
assert.True(t, overflow)
assert.Nil(t, c)
}

func TestDecimal_DivOverflow(t *testing.T) {
a, ok := new(big.Int).SetString("ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff", 16)
assert.True(t, ok)
b, ok := new(big.Int).SetString("ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff", 16)
assert.True(t, ok)

c, overflow := NewDecimalFromBig(a, 0).DivOverflow(NewDecimalFromBig(b, 0))
assert.True(t, overflow)
assert.Nil(t, c)
}

func TestDecimal_AddOverflow(t *testing.T) {
a, ok := new(big.Int).SetString("ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff", 16)
assert.True(t, ok)
b, ok := new(big.Int).SetString("ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff", 16)
assert.True(t, ok)

c, overflow := NewDecimalFromBig(a, 0).AddOverflow(NewDecimalFromBig(b, 0))
assert.True(t, overflow)
assert.Nil(t, c)
}

func TestDecimal_SubOverflow(t *testing.T) {
a, ok := new(big.Int).SetString("ffffffffffff", 16)
assert.True(t, ok)
b, ok := new(big.Int).SetString("ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff", 16)
assert.True(t, ok)

c, overflow := NewDecimalFromBig(a, 0).SubOverflow(NewDecimalFromBig(b, 0))
assert.True(t, overflow)
assert.Nil(t, c)
}

func TestDecimal_Mul(t *testing.T) {
x := NewDecimalFromUint256(uint256.NewInt(0).SetUint64(10), 0)
y := NewDecimalFromUint256(uint256.NewInt(0).SetUint64(10), 0)
Expand Down

0 comments on commit f79326e

Please sign in to comment.