Skip to content

Commit

Permalink
Merge pull request #211 from Consensys/207-support-exponent-n-c
Browse files Browse the repository at this point in the history
Implement `Exp` expression at HIR/MIR level
  • Loading branch information
DavePearce authored Jul 5, 2024
2 parents 32ece2d + f034b5d commit ba78071
Show file tree
Hide file tree
Showing 16 changed files with 279 additions and 1 deletion.
18 changes: 18 additions & 0 deletions pkg/binfile/expr.go
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,24 @@ func (e *jsonExprFuncall) ToHir(schema *hir.Schema) hir.Expr {
return &hir.Mul{Args: args}
case "VectorSub", "Sub":
return &hir.Sub{Args: args}
case "Exp":
if len(args) != 2 {
panic(fmt.Sprintf("incorrect number of arguments for Exp (%d)", len(args)))
}

c, ok := args[1].(*hir.Constant)

if !ok {
panic(fmt.Sprintf("constant power expected for Exp, got %s", args[1].String()))
} else if !c.Val.IsUint64() {
panic("constant power too large for Exp")
}

var k big.Int
// Convert power to uint64
c.Val.BigInt(&k)
// Done
return &hir.Exp{Arg: args[0], Pow: k.Uint64()}
case "IfZero":
if len(args) == 2 {
return &hir.IfZero{Condition: args[0], TrueBranch: args[1], FalseBranch: nil}
Expand Down
12 changes: 12 additions & 0 deletions pkg/hir/eval.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package hir
import (
"github.com/consensys/gnark-crypto/ecc/bls12-377/fr"
"github.com/consensys/go-corset/pkg/trace"
"github.com/consensys/go-corset/pkg/util"
)

// EvalAllAt evaluates a column access at a given row in a trace, which returns the
Expand Down Expand Up @@ -38,6 +39,17 @@ func (e *Mul) EvalAllAt(k int, tr trace.Trace) []*fr.Element {
return evalExprsAt(k, tr, e.Args, fn)
}

// EvalAllAt evaluates a product at a given row in a trace by first evaluating all of
// its arguments at that row.
func (e *Exp) EvalAllAt(k int, tr trace.Trace) []*fr.Element {
vals := e.Arg.EvalAllAt(k, tr)
for _, v := range vals {
util.Pow(v, e.Pow)
}

return vals
}

// EvalAllAt evaluates a conditional at a given row in a trace by first evaluating
// its condition at that row. If that condition is zero then the true branch
// (if applicable) is evaluated; otherwise if the condition is non-zero then
Expand Down
20 changes: 20 additions & 0 deletions pkg/hir/expr.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,26 @@ func (p *Mul) Context(schema sc.Schema) trace.Context {
return sc.JoinContexts[Expr](p.Args, schema)
}

// ============================================================================
// Exponentiation
// ============================================================================

// Exp represents the a given value taken to a power.
type Exp struct {
Arg Expr
Pow uint64
}

// Bounds returns max shift in either the negative (left) or positive
// direction (right).
func (p *Exp) Bounds() util.Bounds { return p.Arg.Bounds() }

// Context determines the evaluation context (i.e. enclosing module) for this
// expression.
func (p *Exp) Context(schema sc.Schema) trace.Context {
return p.Arg.Context(schema)
}

// ============================================================================
// List
// ============================================================================
Expand Down
18 changes: 18 additions & 0 deletions pkg/hir/lower.go
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,13 @@ func (e *ColumnAccess) LowerTo(schema *mir.Schema) []mir.Expr {
return lowerTo(e, schema)
}

// LowerTo lowers an exponent expression to the MIR level. This requires expanding
// the argument andn lowering it. Furthermore, conditionals are "lifted" to
// the top.
func (e *Exp) LowerTo(schema *mir.Schema) []mir.Expr {
return lowerTo(e, schema)
}

// LowerTo lowers a product expression to the MIR level. This requires expanding
// the arguments, then lowering them. Furthermore, conditionals are "lifted" to
// the top.
Expand Down Expand Up @@ -182,6 +189,8 @@ func lowerCondition(e Expr, schema *mir.Schema) mir.Expr {
return lowerConditions(p.Args, schema)
} else if p, ok := e.(*Normalise); ok {
return lowerCondition(p.Arg, schema)
} else if p, ok := e.(*Exp); ok {
return lowerCondition(p.Arg, schema)
} else if p, ok := e.(*IfZero); ok {
return lowerIfZeroCondition(p, schema)
} else if p, ok := e.(*Sub); ok {
Expand Down Expand Up @@ -248,6 +257,8 @@ func lowerBody(e Expr, schema *mir.Schema) mir.Expr {
return &mir.ColumnAccess{Column: p.Column, Shift: p.Shift}
} else if p, ok := e.(*Mul); ok {
return &mir.Mul{Args: lowerBodies(p.Args, schema)}
} else if p, ok := e.(*Exp); ok {
return &mir.Exp{Arg: lowerBody(p.Arg, schema), Pow: p.Pow}
} else if p, ok := e.(*Normalise); ok {
return &mir.Normalise{Arg: lowerBody(p.Arg, schema)}
} else if p, ok := e.(*IfZero); ok {
Expand Down Expand Up @@ -306,6 +317,13 @@ func expand(e Expr) []Expr {
ees = append(ees, expand(arg)...)
}

return ees
} else if p, ok := e.(*Exp); ok {
ees := expand(p.Arg)
for i, ee := range ees {
ees[i] = &Exp{ee, p.Pow}
}

return ees
} else if p, ok := e.(*Normalise); ok {
ees := expand(p.Arg)
Expand Down
21 changes: 21 additions & 0 deletions pkg/hir/parser.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package hir
import (
"errors"
"fmt"
"math/big"
"strconv"
"strings"
"unicode"
Expand Down Expand Up @@ -77,6 +78,7 @@ func newHirParser(srcmap *sexp.SourceMap[sexp.SExp]) *hirParser {
p.AddRecursiveRule("-", subParserRule)
p.AddRecursiveRule("*", mulParserRule)
p.AddRecursiveRule("~", normParserRule)
p.AddRecursiveRule("^", powParserRule)
p.AddRecursiveRule("if", ifParserRule)
p.AddRecursiveRule("ifnot", ifNotParserRule)
p.AddRecursiveRule("begin", beginParserRule)
Expand Down Expand Up @@ -543,6 +545,25 @@ func shiftParserRule(parser *hirParser) func(string, string) (Expr, error) {
}
}

func powParserRule(args []Expr) (Expr, error) {
var k big.Int

if len(args) != 2 {
return nil, errors.New("incorrect number of arguments")
}

c, ok := args[1].(*Constant)
if !ok {
return nil, errors.New("expected constant power")
} else if !c.Val.IsUint64() {
return nil, errors.New("constant power too large")
}
// Convert power to uint64
c.Val.BigInt(&k)
// Done
return &Exp{Arg: args[0], Pow: k.Uint64()}, nil
}

func normParserRule(args []Expr) (Expr, error) {
if len(args) != 1 {
return nil, errors.New("incorrect number of arguments")
Expand Down
4 changes: 4 additions & 0 deletions pkg/hir/string.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,10 @@ func (e *Mul) String() string {
return naryString("*", e.Args)
}

func (e *Exp) String() string {
return fmt.Sprintf("(^ %s %d)", e.Arg, e.Pow)
}

func (e *Normalise) String() string {
return fmt.Sprintf("(~ %s)", e.Arg)
}
Expand Down
12 changes: 12 additions & 0 deletions pkg/mir/eval.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package mir
import (
"github.com/consensys/gnark-crypto/ecc/bls12-377/fr"
"github.com/consensys/go-corset/pkg/trace"
"github.com/consensys/go-corset/pkg/util"
)

// EvalAt evaluates a column access at a given row in a trace, which returns the
Expand Down Expand Up @@ -38,6 +39,17 @@ func (e *Mul) EvalAt(k int, tr trace.Trace) *fr.Element {
return evalExprsAt(k, tr, e.Args, fn)
}

// EvalAt evaluates a product at a given row in a trace by first evaluating all of
// its arguments at that row.
func (e *Exp) EvalAt(k int, tr trace.Trace) *fr.Element {
// Check whether argument evaluates to zero or not.
val := e.Arg.EvalAt(k, tr)
// Compute exponent
util.Pow(val, e.Pow)
// Done
return val
}

// EvalAt evaluates the normalisation of some expression by first evaluating
// that expression. Then, zero is returned if the result is zero; otherwise one
// is returned.
Expand Down
20 changes: 20 additions & 0 deletions pkg/mir/expr.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,26 @@ func (p *Mul) Context(schema sc.Schema) trace.Context {
return sc.JoinContexts[Expr](p.Args, schema)
}

// ============================================================================
// Exponentiation
// ============================================================================

// Exp represents the a given value taken to a power.
type Exp struct {
Arg Expr
Pow uint64
}

// Bounds returns max shift in either the negative (left) or positive
// direction (right).
func (p *Exp) Bounds() util.Bounds { return p.Arg.Bounds() }

// Context determines the evaluation context (i.e. enclosing module) for this
// expression.
func (p *Exp) Context(schema sc.Schema) trace.Context {
return p.Arg.Context(schema)
}

// ============================================================================
// Constant
// ============================================================================
Expand Down
16 changes: 16 additions & 0 deletions pkg/mir/lower.go
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,22 @@ func (e *Mul) LowerTo(schema *air.Schema) air.Expr {
return &air.Mul{Args: lowerExprs(e.Args, schema)}
}

// LowerTo lowers an exponent expression to the AIR level by lowering the
// argument, and then constructing a multiplication. This is because the AIR
// level does not support an explicit exponent operator.
func (e *Exp) LowerTo(schema *air.Schema) air.Expr {
// Lower the expression being raised
le := e.Arg.LowerTo(schema)
// Multiply it out k times
es := make([]air.Expr, e.Pow)
//
for i := uint64(0); i < e.Pow; i++ {
es[i] = le
}
// Done
return &air.Mul{Args: es}
}

// LowerTo lowers a normalise expression to the AIR level by "compiling it out"
// using a computed column.
func (p *Normalise) LowerTo(schema *air.Schema) air.Expr {
Expand Down
4 changes: 4 additions & 0 deletions pkg/mir/string.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,10 @@ func (e *Normalise) String() string {
return fmt.Sprintf("(~ %s)", e.Arg)
}

func (e *Exp) String() string {
return fmt.Sprintf("(^ %s %d)", e.Arg, e.Pow)
}

func naryString(operator string, exprs []Expr) string {
// This should be generalised and moved into common?
var rs string
Expand Down
8 changes: 7 additions & 1 deletion pkg/test/ir_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,10 @@ func Test_Basic_09(t *testing.T) {
Check(t, "basic_09")
}

func Test_Basic_10(t *testing.T) {
Check(t, "basic_10")
}

// ===================================================================
// Domain Tests
// ===================================================================
Expand Down Expand Up @@ -420,12 +424,14 @@ func TestSlow_Mxp(t *testing.T) {

// Determines the maximum amount of padding to use when testing. Specifically,
// every trace is tested with varying amounts of padding upto this value.
const MAX_PADDING uint = 1
const MAX_PADDING uint = 5

// For a given set of constraints, check that all traces which we
// expect to be accepted are accepted, and all traces that we expect
// to be rejected are rejected.
func Check(t *testing.T, test string) {
// Enable testing each trace in parallel
t.Parallel()
// Read constraints file
bytes, err := os.ReadFile(fmt.Sprintf("%s/%s.lisp", TestDir, test))
// Check test file read ok
Expand Down
71 changes: 71 additions & 0 deletions pkg/test/util_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
package test

import (
"math/big"
"testing"

"github.com/consensys/gnark-crypto/ecc/bls12-377/fr"
"github.com/consensys/go-corset/pkg/util"
)

const POW_BASE_MAX uint = 65536
const POW_BASE_INC uint = 8

func Test_Pow_01(t *testing.T) {
PowCheckLoop(t, 0)
}

func Test_Pow_02(t *testing.T) {
PowCheckLoop(t, 1)
}

func Test_Pow_03(t *testing.T) {
PowCheckLoop(t, 2)
}

func Test_Pow_04(t *testing.T) {
PowCheckLoop(t, 3)
}

func Test_Pow_05(t *testing.T) {
PowCheckLoop(t, 4)
}

func Test_Pow_06(t *testing.T) {
PowCheckLoop(t, 5)
}

func Test_Pow_07(t *testing.T) {
PowCheckLoop(t, 6)
}

func Test_Pow_08(t *testing.T) {
PowCheckLoop(t, 7)
}

func PowCheckLoop(t *testing.T, first uint) {
// Enable parallel testing
t.Parallel()
// Run through the loop
for i := first; i < POW_BASE_MAX; i += POW_BASE_INC {
for j := uint64(0); j < 256; j++ {
PowCheck(t, i, j)
}
}
}

// Check pow computed correctly. This is done by comparing against the existing
// gnark function.
func PowCheck(t *testing.T, base uint, pow uint64) {
k := big.NewInt(int64(pow))
v1 := fr.NewElement(uint64(base))
v2 := fr.NewElement(uint64(base))
// V1 computed using our optimised method
util.Pow(&v1, pow)
// V2 computed using existing gnark function
v2.Exp(v2, k)
// Final sanity check
if v1.Cmp(&v2) != 0 {
t.Errorf("Pow(%d,%d)=%s (not %s)", base, pow, v1.String(), v2.String())
}
}
22 changes: 22 additions & 0 deletions pkg/util/fields.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,3 +19,25 @@ func ToFieldElements(ints []*big.Int) []*fr.Element {
// Done.
return elements
}

// Pow takes a given value to the power n.
func Pow(val *fr.Element, n uint64) {
if n == 0 {
val.SetOne()
} else if n > 1 {
m := n / 2
// Check for odd case
if n%2 == 1 {
var tmp fr.Element
// Clone value
tmp.Set(val)
Pow(val, m)
val.Square(val)
val.Mul(val, &tmp)
} else {
// Even case is easy
Pow(val, m)
val.Square(val)
}
}
}
Loading

0 comments on commit ba78071

Please sign in to comment.