diff --git a/pkg/binfile/expr.go b/pkg/binfile/expr.go index eeda32b..d076829 100644 --- a/pkg/binfile/expr.go +++ b/pkg/binfile/expr.go @@ -135,6 +135,24 @@ func (e *jsonExprFuncall) ToHir(schema *hir.Schema) hir.Expr { return &hir.Mul{Args: args} case "VectorSub", "Sub": return &hir.Sub{Args: args} + case "Exp": + if len(args) != 2 { + panic(fmt.Sprintf("incorrect number of arguments for Exp (%d)", len(args))) + } + + c, ok := args[1].(*hir.Constant) + + if !ok { + panic(fmt.Sprintf("constant power expected for Exp, got %s", args[1].String())) + } else if !c.Val.IsUint64() { + panic("constant power too large for Exp") + } + + var k big.Int + // Convert power to uint64 + c.Val.BigInt(&k) + // Done + return &hir.Exp{Arg: args[0], Pow: k.Uint64()} case "IfZero": if len(args) == 2 { return &hir.IfZero{Condition: args[0], TrueBranch: args[1], FalseBranch: nil} diff --git a/pkg/hir/eval.go b/pkg/hir/eval.go index 97c53ce..07df88d 100644 --- a/pkg/hir/eval.go +++ b/pkg/hir/eval.go @@ -3,6 +3,7 @@ package hir import ( "github.com/consensys/gnark-crypto/ecc/bls12-377/fr" "github.com/consensys/go-corset/pkg/trace" + "github.com/consensys/go-corset/pkg/util" ) // EvalAllAt evaluates a column access at a given row in a trace, which returns the @@ -38,6 +39,17 @@ func (e *Mul) EvalAllAt(k int, tr trace.Trace) []*fr.Element { return evalExprsAt(k, tr, e.Args, fn) } +// EvalAllAt evaluates a product at a given row in a trace by first evaluating all of +// its arguments at that row. +func (e *Exp) EvalAllAt(k int, tr trace.Trace) []*fr.Element { + vals := e.Arg.EvalAllAt(k, tr) + for _, v := range vals { + util.Pow(v, e.Pow) + } + + return vals +} + // EvalAllAt evaluates a conditional at a given row in a trace by first evaluating // its condition at that row. If that condition is zero then the true branch // (if applicable) is evaluated; otherwise if the condition is non-zero then diff --git a/pkg/hir/expr.go b/pkg/hir/expr.go index d65cc9c..df5f735 100644 --- a/pkg/hir/expr.go +++ b/pkg/hir/expr.go @@ -86,6 +86,26 @@ func (p *Mul) Context(schema sc.Schema) trace.Context { return sc.JoinContexts[Expr](p.Args, schema) } +// ============================================================================ +// Exponentiation +// ============================================================================ + +// Exp represents the a given value taken to a power. +type Exp struct { + Arg Expr + Pow uint64 +} + +// Bounds returns max shift in either the negative (left) or positive +// direction (right). +func (p *Exp) Bounds() util.Bounds { return p.Arg.Bounds() } + +// Context determines the evaluation context (i.e. enclosing module) for this +// expression. +func (p *Exp) Context(schema sc.Schema) trace.Context { + return p.Arg.Context(schema) +} + // ============================================================================ // List // ============================================================================ diff --git a/pkg/hir/lower.go b/pkg/hir/lower.go index 93f35de..65dfb20 100644 --- a/pkg/hir/lower.go +++ b/pkg/hir/lower.go @@ -110,6 +110,13 @@ func (e *ColumnAccess) LowerTo(schema *mir.Schema) []mir.Expr { return lowerTo(e, schema) } +// LowerTo lowers an exponent expression to the MIR level. This requires expanding +// the argument andn lowering it. Furthermore, conditionals are "lifted" to +// the top. +func (e *Exp) LowerTo(schema *mir.Schema) []mir.Expr { + return lowerTo(e, schema) +} + // LowerTo lowers a product expression to the MIR level. This requires expanding // the arguments, then lowering them. Furthermore, conditionals are "lifted" to // the top. @@ -182,6 +189,8 @@ func lowerCondition(e Expr, schema *mir.Schema) mir.Expr { return lowerConditions(p.Args, schema) } else if p, ok := e.(*Normalise); ok { return lowerCondition(p.Arg, schema) + } else if p, ok := e.(*Exp); ok { + return lowerCondition(p.Arg, schema) } else if p, ok := e.(*IfZero); ok { return lowerIfZeroCondition(p, schema) } else if p, ok := e.(*Sub); ok { @@ -248,6 +257,8 @@ func lowerBody(e Expr, schema *mir.Schema) mir.Expr { return &mir.ColumnAccess{Column: p.Column, Shift: p.Shift} } else if p, ok := e.(*Mul); ok { return &mir.Mul{Args: lowerBodies(p.Args, schema)} + } else if p, ok := e.(*Exp); ok { + return &mir.Exp{Arg: lowerBody(p.Arg, schema), Pow: p.Pow} } else if p, ok := e.(*Normalise); ok { return &mir.Normalise{Arg: lowerBody(p.Arg, schema)} } else if p, ok := e.(*IfZero); ok { @@ -306,6 +317,13 @@ func expand(e Expr) []Expr { ees = append(ees, expand(arg)...) } + return ees + } else if p, ok := e.(*Exp); ok { + ees := expand(p.Arg) + for i, ee := range ees { + ees[i] = &Exp{ee, p.Pow} + } + return ees } else if p, ok := e.(*Normalise); ok { ees := expand(p.Arg) diff --git a/pkg/hir/parser.go b/pkg/hir/parser.go index ceedd7f..23a2071 100644 --- a/pkg/hir/parser.go +++ b/pkg/hir/parser.go @@ -3,6 +3,7 @@ package hir import ( "errors" "fmt" + "math/big" "strconv" "strings" "unicode" @@ -77,6 +78,7 @@ func newHirParser(srcmap *sexp.SourceMap[sexp.SExp]) *hirParser { p.AddRecursiveRule("-", subParserRule) p.AddRecursiveRule("*", mulParserRule) p.AddRecursiveRule("~", normParserRule) + p.AddRecursiveRule("^", powParserRule) p.AddRecursiveRule("if", ifParserRule) p.AddRecursiveRule("ifnot", ifNotParserRule) p.AddRecursiveRule("begin", beginParserRule) @@ -543,6 +545,25 @@ func shiftParserRule(parser *hirParser) func(string, string) (Expr, error) { } } +func powParserRule(args []Expr) (Expr, error) { + var k big.Int + + if len(args) != 2 { + return nil, errors.New("incorrect number of arguments") + } + + c, ok := args[1].(*Constant) + if !ok { + return nil, errors.New("expected constant power") + } else if !c.Val.IsUint64() { + return nil, errors.New("constant power too large") + } + // Convert power to uint64 + c.Val.BigInt(&k) + // Done + return &Exp{Arg: args[0], Pow: k.Uint64()}, nil +} + func normParserRule(args []Expr) (Expr, error) { if len(args) != 1 { return nil, errors.New("incorrect number of arguments") diff --git a/pkg/hir/string.go b/pkg/hir/string.go index e650c29..8113e31 100644 --- a/pkg/hir/string.go +++ b/pkg/hir/string.go @@ -32,6 +32,10 @@ func (e *Mul) String() string { return naryString("*", e.Args) } +func (e *Exp) String() string { + return fmt.Sprintf("(^ %s %d)", e.Arg, e.Pow) +} + func (e *Normalise) String() string { return fmt.Sprintf("(~ %s)", e.Arg) } diff --git a/pkg/mir/eval.go b/pkg/mir/eval.go index 4a2f574..fdfb55b 100644 --- a/pkg/mir/eval.go +++ b/pkg/mir/eval.go @@ -3,6 +3,7 @@ package mir import ( "github.com/consensys/gnark-crypto/ecc/bls12-377/fr" "github.com/consensys/go-corset/pkg/trace" + "github.com/consensys/go-corset/pkg/util" ) // EvalAt evaluates a column access at a given row in a trace, which returns the @@ -38,6 +39,17 @@ func (e *Mul) EvalAt(k int, tr trace.Trace) *fr.Element { return evalExprsAt(k, tr, e.Args, fn) } +// EvalAt evaluates a product at a given row in a trace by first evaluating all of +// its arguments at that row. +func (e *Exp) EvalAt(k int, tr trace.Trace) *fr.Element { + // Check whether argument evaluates to zero or not. + val := e.Arg.EvalAt(k, tr) + // Compute exponent + util.Pow(val, e.Pow) + // Done + return val +} + // EvalAt evaluates the normalisation of some expression by first evaluating // that expression. Then, zero is returned if the result is zero; otherwise one // is returned. diff --git a/pkg/mir/expr.go b/pkg/mir/expr.go index 6c192ea..99e2f76 100644 --- a/pkg/mir/expr.go +++ b/pkg/mir/expr.go @@ -76,6 +76,26 @@ func (p *Mul) Context(schema sc.Schema) trace.Context { return sc.JoinContexts[Expr](p.Args, schema) } +// ============================================================================ +// Exponentiation +// ============================================================================ + +// Exp represents the a given value taken to a power. +type Exp struct { + Arg Expr + Pow uint64 +} + +// Bounds returns max shift in either the negative (left) or positive +// direction (right). +func (p *Exp) Bounds() util.Bounds { return p.Arg.Bounds() } + +// Context determines the evaluation context (i.e. enclosing module) for this +// expression. +func (p *Exp) Context(schema sc.Schema) trace.Context { + return p.Arg.Context(schema) +} + // ============================================================================ // Constant // ============================================================================ diff --git a/pkg/mir/lower.go b/pkg/mir/lower.go index d669e75..f4dc6bd 100644 --- a/pkg/mir/lower.go +++ b/pkg/mir/lower.go @@ -170,6 +170,22 @@ func (e *Mul) LowerTo(schema *air.Schema) air.Expr { return &air.Mul{Args: lowerExprs(e.Args, schema)} } +// LowerTo lowers an exponent expression to the AIR level by lowering the +// argument, and then constructing a multiplication. This is because the AIR +// level does not support an explicit exponent operator. +func (e *Exp) LowerTo(schema *air.Schema) air.Expr { + // Lower the expression being raised + le := e.Arg.LowerTo(schema) + // Multiply it out k times + es := make([]air.Expr, e.Pow) + // + for i := uint64(0); i < e.Pow; i++ { + es[i] = le + } + // Done + return &air.Mul{Args: es} +} + // LowerTo lowers a normalise expression to the AIR level by "compiling it out" // using a computed column. func (p *Normalise) LowerTo(schema *air.Schema) air.Expr { diff --git a/pkg/mir/string.go b/pkg/mir/string.go index cf904f3..7bf37fb 100644 --- a/pkg/mir/string.go +++ b/pkg/mir/string.go @@ -32,6 +32,10 @@ func (e *Normalise) String() string { return fmt.Sprintf("(~ %s)", e.Arg) } +func (e *Exp) String() string { + return fmt.Sprintf("(^ %s %d)", e.Arg, e.Pow) +} + func naryString(operator string, exprs []Expr) string { // This should be generalised and moved into common? var rs string diff --git a/pkg/test/ir_test.go b/pkg/test/ir_test.go index c1d2804..ee5e084 100644 --- a/pkg/test/ir_test.go +++ b/pkg/test/ir_test.go @@ -60,6 +60,10 @@ func Test_Basic_09(t *testing.T) { Check(t, "basic_09") } +func Test_Basic_10(t *testing.T) { + Check(t, "basic_10") +} + // =================================================================== // Domain Tests // =================================================================== @@ -420,12 +424,14 @@ func TestSlow_Mxp(t *testing.T) { // Determines the maximum amount of padding to use when testing. Specifically, // every trace is tested with varying amounts of padding upto this value. -const MAX_PADDING uint = 1 +const MAX_PADDING uint = 5 // For a given set of constraints, check that all traces which we // expect to be accepted are accepted, and all traces that we expect // to be rejected are rejected. func Check(t *testing.T, test string) { + // Enable testing each trace in parallel + t.Parallel() // Read constraints file bytes, err := os.ReadFile(fmt.Sprintf("%s/%s.lisp", TestDir, test)) // Check test file read ok diff --git a/pkg/test/util_test.go b/pkg/test/util_test.go new file mode 100644 index 0000000..bdb0966 --- /dev/null +++ b/pkg/test/util_test.go @@ -0,0 +1,71 @@ +package test + +import ( + "math/big" + "testing" + + "github.com/consensys/gnark-crypto/ecc/bls12-377/fr" + "github.com/consensys/go-corset/pkg/util" +) + +const POW_BASE_MAX uint = 65536 +const POW_BASE_INC uint = 8 + +func Test_Pow_01(t *testing.T) { + PowCheckLoop(t, 0) +} + +func Test_Pow_02(t *testing.T) { + PowCheckLoop(t, 1) +} + +func Test_Pow_03(t *testing.T) { + PowCheckLoop(t, 2) +} + +func Test_Pow_04(t *testing.T) { + PowCheckLoop(t, 3) +} + +func Test_Pow_05(t *testing.T) { + PowCheckLoop(t, 4) +} + +func Test_Pow_06(t *testing.T) { + PowCheckLoop(t, 5) +} + +func Test_Pow_07(t *testing.T) { + PowCheckLoop(t, 6) +} + +func Test_Pow_08(t *testing.T) { + PowCheckLoop(t, 7) +} + +func PowCheckLoop(t *testing.T, first uint) { + // Enable parallel testing + t.Parallel() + // Run through the loop + for i := first; i < POW_BASE_MAX; i += POW_BASE_INC { + for j := uint64(0); j < 256; j++ { + PowCheck(t, i, j) + } + } +} + +// Check pow computed correctly. This is done by comparing against the existing +// gnark function. +func PowCheck(t *testing.T, base uint, pow uint64) { + k := big.NewInt(int64(pow)) + v1 := fr.NewElement(uint64(base)) + v2 := fr.NewElement(uint64(base)) + // V1 computed using our optimised method + util.Pow(&v1, pow) + // V2 computed using existing gnark function + v2.Exp(v2, k) + // Final sanity check + if v1.Cmp(&v2) != 0 { + t.Errorf("Pow(%d,%d)=%s (not %s)", base, pow, v1.String(), v2.String()) + } +} diff --git a/pkg/util/fields.go b/pkg/util/fields.go index dac2e29..bfd66b4 100644 --- a/pkg/util/fields.go +++ b/pkg/util/fields.go @@ -19,3 +19,25 @@ func ToFieldElements(ints []*big.Int) []*fr.Element { // Done. return elements } + +// Pow takes a given value to the power n. +func Pow(val *fr.Element, n uint64) { + if n == 0 { + val.SetOne() + } else if n > 1 { + m := n / 2 + // Check for odd case + if n%2 == 1 { + var tmp fr.Element + // Clone value + tmp.Set(val) + Pow(val, m) + val.Square(val) + val.Mul(val, &tmp) + } else { + // Even case is easy + Pow(val, m) + val.Square(val) + } + } +} diff --git a/testdata/basic_10.accepts b/testdata/basic_10.accepts new file mode 100644 index 0000000..4b5611d --- /dev/null +++ b/testdata/basic_10.accepts @@ -0,0 +1,18 @@ +{ "X": [], "Y": [] } +{ "X": [0], "Y": [0] } +{ "X": [1], "Y": [1] } +{ "X": [2], "Y": [4] } +{ "X": [3], "Y": [9] } +{ "X": [4], "Y": [16] } +{ "X": [5], "Y": [25] } +{ "X": [256], "Y": [65536] } +;; +{ "X": [0,0], "Y": [0,0] } +{ "X": [0,1], "Y": [0,1] } +{ "X": [1,0], "Y": [1,0] } +{ "X": [1,1], "Y": [1,1] } +;; +{ "X": [1,1], "Y": [1,1] } +{ "X": [1,2], "Y": [1,4] } +{ "X": [2,1], "Y": [4,1] } +{ "X": [2,2], "Y": [4,4] } diff --git a/testdata/basic_10.lisp b/testdata/basic_10.lisp new file mode 100644 index 0000000..12a4f44 --- /dev/null +++ b/testdata/basic_10.lisp @@ -0,0 +1,4 @@ +(column X) +(column Y) +;; Y == X*X +(vanish c1 (- Y (^ X 2))) diff --git a/testdata/basic_10.rejects b/testdata/basic_10.rejects new file mode 100644 index 0000000..8115870 --- /dev/null +++ b/testdata/basic_10.rejects @@ -0,0 +1,12 @@ +{ "X": [0], "Y": [1] } +{ "X": [0], "Y": [2] } +{ "X": [1], "Y": [0] } +{ "X": [1], "Y": [2] } +{ "X": [2], "Y": [0] } +{ "X": [2], "Y": [1] } +{ "X": [2], "Y": [2] } +{ "X": [2], "Y": [3] } +{ "X": [1,1], "Y": [1,2] } +{ "X": [1,1], "Y": [2,1] } +{ "X": [2,2], "Y": [2,4] } +{ "X": [2,2], "Y": [4,2] }