diff --git a/internal/interpreter/evaluate_expr.go b/internal/interpreter/evaluate_expr.go index 1b88d9f..b718693 100644 --- a/internal/interpreter/evaluate_expr.go +++ b/internal/interpreter/evaluate_expr.go @@ -148,21 +148,22 @@ func (st *programState) subOp(left parser.ValueExpr, right parser.ValueExpr) (Va return (*leftValue).evalSub(st, right) } -func (st *programState) numOp(left parser.ValueExpr, right parser.ValueExpr, op func(left *big.Int, right *big.Int) Value) (Value, InterpreterError) { - parsedLeft, err := evaluateExprAs(st, left, expectNumber) +func (st *programState) eqOp(left parser.ValueExpr, right parser.ValueExpr) (Value, InterpreterError) { + parsedLeft, err := evaluateExprAs(st, left, expectAnything) if err != nil { return nil, err } - parsedRight, err := evaluateExprAs(st, right, expectNumber) + parsedRight, err := evaluateExprAs(st, right, expectAnything) if err != nil { return nil, err } - return op(parsedLeft, parsedRight), nil + // TODO remove reflect usage + return Bool(reflect.DeepEqual(parsedLeft, parsedRight)), nil } -func (st *programState) eqOp(left parser.ValueExpr, right parser.ValueExpr) (Value, InterpreterError) { +func (st *programState) neqOp(left parser.ValueExpr, right parser.ValueExpr) (Value, InterpreterError) { parsedLeft, err := evaluateExprAs(st, left, expectAnything) if err != nil { return nil, err @@ -174,66 +175,85 @@ func (st *programState) eqOp(left parser.ValueExpr, right parser.ValueExpr) (Val } // TODO remove reflect usage - return Bool(reflect.DeepEqual(parsedLeft, parsedRight)), nil + return Bool(!(reflect.DeepEqual(parsedLeft, parsedRight))), nil } -func (st *programState) neqOp(left parser.ValueExpr, right parser.ValueExpr) (Value, InterpreterError) { - parsedLeft, err := evaluateExprAs(st, left, expectAnything) +func (st *programState) ltOp(left parser.ValueExpr, right parser.ValueExpr) (Value, InterpreterError) { + cmp, err := st.evaluateExprAsCmp(left) if err != nil { return nil, err } - parsedRight, err := evaluateExprAs(st, right, expectAnything) + cmpResult, err := (*cmp).evalCmp(st, right) if err != nil { return nil, err } - // TODO remove reflect usage - return Bool(!(reflect.DeepEqual(parsedLeft, parsedRight))), nil -} + switch *cmpResult { + case -1: + return Bool(true), nil + default: + return Bool(false), nil + } -func (st *programState) ltOp(left parser.ValueExpr, right parser.ValueExpr) (Value, InterpreterError) { - return st.numOp(left, right, func(left, right *big.Int) Value { - switch left.Cmp(right) { - case -1: - return Bool(true) - default: - return Bool(false) - } - }) } func (st *programState) gtOp(left parser.ValueExpr, right parser.ValueExpr) (Value, InterpreterError) { - return st.numOp(left, right, func(left, right *big.Int) Value { - switch left.Cmp(right) { - case 1: - return Bool(true) - default: - return Bool(false) - } - }) + cmp, err := st.evaluateExprAsCmp(left) + if err != nil { + return nil, err + } + + cmpResult, err := (*cmp).evalCmp(st, right) + if err != nil { + return nil, err + } + + switch *cmpResult { + case 1: + return Bool(true), nil + default: + return Bool(false), nil + } } func (st *programState) lteOp(left parser.ValueExpr, right parser.ValueExpr) (Value, InterpreterError) { - return st.numOp(left, right, func(left, right *big.Int) Value { - switch left.Cmp(right) { - case -1, 0: - return Bool(true) - default: - return Bool(false) - } - }) + cmp, err := st.evaluateExprAsCmp(left) + if err != nil { + return nil, err + } + + cmpResult, err := (*cmp).evalCmp(st, right) + if err != nil { + return nil, err + } + + switch *cmpResult { + case -1, 0: + return Bool(true), nil + default: + return Bool(false), nil + } + } func (st *programState) gteOp(left parser.ValueExpr, right parser.ValueExpr) (Value, InterpreterError) { - return st.numOp(left, right, func(left, right *big.Int) Value { - switch left.Cmp(right) { - case 1, 0: - return Bool(true) - default: - return Bool(false) - } - }) + cmp, err := st.evaluateExprAsCmp(left) + if err != nil { + return nil, err + } + + cmpResult, err := (*cmp).evalCmp(st, right) + if err != nil { + return nil, err + } + + switch *cmpResult { + case 1, 0: + return Bool(true), nil + default: + return Bool(false), nil + } } func (st *programState) andOp(left parser.ValueExpr, right parser.ValueExpr) (Value, InterpreterError) { diff --git a/internal/interpreter/overloads.go b/internal/interpreter/overloads.go index a0a9d0f..ce0694e 100644 --- a/internal/interpreter/overloads.go +++ b/internal/interpreter/overloads.go @@ -25,21 +25,15 @@ func (m MonetaryInt) evalAdd(st *programState, other parser.ValueExpr) (Value, I } func (m Monetary) evalAdd(st *programState, other parser.ValueExpr) (Value, InterpreterError) { - m2, err := evaluateExprAs(st, other, expectMonetary) + b2, err := evaluateExprAs(st, other, expectMonetaryOfAsset(string(m.Asset))) if err != nil { return nil, err } - - if m.Asset != m2.Asset { - return nil, MismatchedCurrencyError{ - Expected: m.Asset.String(), - Got: m2.Asset.String(), - } - } - + b1 := big.Int(m.Amount) + sum := new(big.Int).Add(&b1, b2) return Monetary{ Asset: m.Asset, - Amount: m.Amount.Add(m2.Amount), + Amount: MonetaryInt(*sum), }, nil } @@ -80,3 +74,53 @@ func (m Monetary) evalSub(st *programState, other parser.ValueExpr) (Value, Inte }, nil } + +type opCmp interface { + evalCmp(st *programState, other parser.ValueExpr) (*int, InterpreterError) +} + +var _ opCmp = (*MonetaryInt)(nil) +var _ opCmp = (*Monetary)(nil) + +func (m MonetaryInt) evalCmp(st *programState, other parser.ValueExpr) (*int, InterpreterError) { + b2, err := evaluateExprAs(st, other, expectNumber) + if err != nil { + return nil, err + } + + b1 := big.Int(m) + + cmp := b1.Cmp(b2) + return &cmp, nil +} + +func (m Monetary) evalCmp(st *programState, other parser.ValueExpr) (*int, InterpreterError) { + b2, err := evaluateExprAs(st, other, expectMonetaryOfAsset(string(m.Asset))) + if err != nil { + return nil, err + } + + b1 := big.Int(m.Amount) + + cmp := b1.Cmp(b2) + return &cmp, nil +} + +func (st *programState) evaluateExprAsCmp(expr parser.ValueExpr) (*opCmp, InterpreterError) { + exprCmp, err := evaluateExprAs(st, expr, expectOneOf( + expectMapped(expectMonetary, func(m Monetary) opCmp { + return m + }), + + // while "x.map(identity)" is the same as "x", just writing "expectNumber" would't typecheck + expectMapped(expectNumber, func(bi big.Int) opCmp { + return MonetaryInt(bi) + }), + )) + + if err != nil { + return nil, err + } + + return exprCmp, nil +} diff --git a/internal/interpreter/value.go b/internal/interpreter/value.go index 61fc4c5..5a48396 100644 --- a/internal/interpreter/value.go +++ b/internal/interpreter/value.go @@ -234,14 +234,6 @@ func NewMonetaryInt(n int64) MonetaryInt { return MonetaryInt(*bi) } -func (m MonetaryInt) Add(other MonetaryInt) MonetaryInt { - bi := big.Int(m) - otherBi := big.Int(other) - - sum := new(big.Int).Add(&bi, &otherBi) - return MonetaryInt(*sum) -} - func (m MonetaryInt) Sub(other MonetaryInt) MonetaryInt { bi := big.Int(m) otherBi := big.Int(other)