Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Apply Constant Propagation on Lowering to AIR #216

Merged
merged 1 commit into from
Jul 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
147 changes: 147 additions & 0 deletions pkg/mir/const.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
package mir

import (
"fmt"

"github.com/consensys/gnark-crypto/ecc/bls12-377/fr"
"github.com/consensys/go-corset/pkg/util"
)

// ApplyConstantPropagation simply collapses constant expressions down to single
// values. For example, "(+ 1 2)" would be collapsed down to "3".
func applyConstantPropagation(e Expr) Expr {
if p, ok := e.(*Add); ok {
return applyConstantPropagationAdd(p.Args)
} else if _, ok := e.(*Constant); ok {
return e
} else if _, ok := e.(*ColumnAccess); ok {
return e
} else if p, ok := e.(*Mul); ok {
return applyConstantPropagationMul(p.Args)
} else if p, ok := e.(*Exp); ok {
return applyConstantPropagationExp(p.Arg, p.Pow)
} else if p, ok := e.(*Normalise); ok {
return applyConstantPropagationNorm(p.Arg)
} else if p, ok := e.(*Sub); ok {
return applyConstantPropagationSub(p.Args)
}
// Should be unreachable
panic(fmt.Sprintf("unknown expression: %s", e.String()))
}

func applyConstantPropagationAdd(es []Expr) Expr {
var zero = fr.NewElement(0)
sum := &zero
rs := make([]Expr, len(es))
//
for i, e := range es {
rs[i] = applyConstantPropagation(e)
// Check for constant
c, ok := rs[i].(*Constant)
// Try to continue sum
if ok && sum != nil {
sum.Add(sum, c.Value)
} else {
sum = nil
}
}
//
if sum != nil {
// Propagate constant
return &Constant{sum}
}
// Done
return &Add{rs}
}

func applyConstantPropagationSub(es []Expr) Expr {
var sum *fr.Element = nil

rs := make([]Expr, len(es))
//
for i, e := range es {
rs[i] = applyConstantPropagation(e)
// Check for constant
c, ok := rs[i].(*Constant)
// Try to continue sum
if ok && i == 0 {
var val fr.Element
// Clone value
val.Set(c.Value)
sum = &val
} else if ok && sum != nil {
sum.Sub(sum, c.Value)
} else {
sum = nil
}
}
//
if sum != nil {
// Propagate constant
return &Constant{sum}
}
// Done
return &Sub{rs}
}

func applyConstantPropagationMul(es []Expr) Expr {
var one = fr.NewElement(1)
prod := &one
rs := make([]Expr, len(es))
//
for i, e := range es {
rs[i] = applyConstantPropagation(e)
// Check for constant
c, ok := rs[i].(*Constant)
//
if ok && c.Value.IsZero() {
// No matter what, outcome is zero.
return &Constant{c.Value}
} else if ok && prod != nil {
// Continue building constant
prod.Mul(prod, c.Value)
} else {
prod = nil
}
}
// Attempt to propagate constant
if prod != nil {
return &Constant{prod}
}
//
return &Mul{rs}
}

func applyConstantPropagationExp(arg Expr, pow uint64) Expr {
arg = applyConstantPropagation(arg)
//
if c, ok := arg.(*Constant); ok {
var val fr.Element
// Clone value
val.Set(c.Value)
// Compute exponent (in place)
util.Pow(&val, pow)
// Done
return &Constant{&val}
}
//
return &Exp{arg, pow}
}

func applyConstantPropagationNorm(arg Expr) Expr {
arg = applyConstantPropagation(arg)
//
if c, ok := arg.(*Constant); ok {
var val fr.Element
// Clone value
val.Set(c.Value)
// Normalise (in place)
if !val.IsZero() {
val.SetOne()
}
// Done
return &Constant{&val}
}
//
return &Normalise{arg}
}
6 changes: 0 additions & 6 deletions pkg/mir/expr.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package mir

import (
"github.com/consensys/gnark-crypto/ecc/bls12-377/fr"
"github.com/consensys/go-corset/pkg/air"
sc "github.com/consensys/go-corset/pkg/schema"
"github.com/consensys/go-corset/pkg/trace"
"github.com/consensys/go-corset/pkg/util"
Expand All @@ -16,11 +15,6 @@ import (
type Expr interface {
util.Boundable
sc.Evaluable
// Lower this expression into the Arithmetic Intermediate
// Representation. Essentially, this means eliminating
// normalising expressions by introducing new columns into the
// given table (with appropriate constraints).
LowerTo(*air.Schema) air.Expr
// String produces a string representing this as an S-Expression.
String() string
}
Expand Down
80 changes: 41 additions & 39 deletions pkg/mir/lower.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package mir

import (
"fmt"

"github.com/consensys/go-corset/pkg/air"
air_gadgets "github.com/consensys/go-corset/pkg/air/gadgets"
sc "github.com/consensys/go-corset/pkg/schema"
Expand Down Expand Up @@ -59,7 +61,7 @@ func lowerConstraintToAir(c sc.Constraint, schema *air.Schema) {
if v, ok := c.(LookupConstraint); ok {
lowerLookupConstraintToAir(v, schema)
} else if v, ok := c.(VanishingConstraint); ok {
air_expr := v.Constraint().Expr.LowerTo(schema)
air_expr := lowerExprTo(v.Constraint().Expr, schema)
schema.AddVanishingConstraint(v.Handle(), v.Context(), v.Domain(), air_expr)
} else if v, ok := c.(*constraint.TypeConstraint); ok {
if t := v.Type().AsUint(); t != nil {
Expand Down Expand Up @@ -95,8 +97,8 @@ func lowerLookupConstraintToAir(c LookupConstraint, schema *air.Schema) {
//
for i := 0; i < len(targets); i++ {
// Lower source and target expressions
target := c.Targets()[i].LowerTo(schema)
source := c.Sources()[i].LowerTo(schema)
target := lowerExprTo(c.Targets()[i], schema)
source := lowerExprTo(c.Sources()[i], schema)
// Expand them
targets[i] = air_gadgets.Expand(target, schema)
sources[i] = air_gadgets.Expand(source, schema)
Expand Down Expand Up @@ -155,27 +157,49 @@ func lowerPermutationToAir(c Permutation, mirSchema *Schema, airSchema *air.Sche
}
}

// LowerTo lowers a sum expression to the AIR level by lowering the arguments.
func (e *Add) LowerTo(schema *air.Schema) air.Expr {
return &air.Add{Args: lowerExprs(e.Args, schema)}
}

// LowerTo lowers a subtract expression to the AIR level by lowering the arguments.
func (e *Sub) LowerTo(schema *air.Schema) air.Expr {
return &air.Sub{Args: lowerExprs(e.Args, schema)}
// Lower an expression into the Arithmetic Intermediate Representation.
// Essentially, this means eliminating normalising expressions by introducing
// new columns into the given table (with appropriate constraints). This first
// performs constant propagation to ensure lowering is as efficient as possible.
func lowerExprTo(e1 Expr, schema *air.Schema) air.Expr {
// Apply constant propagation
e2 := applyConstantPropagation(e1)
// Lower properly
return lowerExprToInner(e2, schema)
}

// LowerTo lowers a product expression to the AIR level by lowering the arguments.
func (e *Mul) LowerTo(schema *air.Schema) air.Expr {
return &air.Mul{Args: lowerExprs(e.Args, schema)}
// Inner form is used for recursive calls and does not repeat the constant
// propagation phase.
func lowerExprToInner(e Expr, schema *air.Schema) air.Expr {
if p, ok := e.(*Add); ok {
return &air.Add{Args: lowerExprs(p.Args, schema)}
} else if p, ok := e.(*Constant); ok {
return &air.Constant{Value: p.Value}
} else if p, ok := e.(*ColumnAccess); ok {
return &air.ColumnAccess{Column: p.Column, Shift: p.Shift}
} else if p, ok := e.(*Mul); ok {
return &air.Mul{Args: lowerExprs(p.Args, schema)}
} else if p, ok := e.(*Exp); ok {
return lowerExpTo(p, schema)
} else if p, ok := e.(*Normalise); ok {
// Lower the expression being normalised
e := lowerExprToInner(p.Arg, schema)
// Construct an expression representing the normalised value of e. That is,
// an expression which is 0 when e is 0, and 1 when e is non-zero.
return air_gadgets.Normalise(e, schema)
} else if p, ok := e.(*Sub); ok {
return &air.Sub{Args: lowerExprs(p.Args, schema)}
}
// Should be unreachable
panic(fmt.Sprintf("unknown expression: %s", e.String()))
}

// LowerTo lowers an exponent expression to the AIR level by lowering the
// argument, and then constructing a multiplication. This is because the AIR
// level does not support an explicit exponent operator.
func (e *Exp) LowerTo(schema *air.Schema) air.Expr {
func lowerExpTo(e *Exp, schema *air.Schema) air.Expr {
// Lower the expression being raised
le := e.Arg.LowerTo(schema)
le := lowerExprToInner(e.Arg, schema)
// Multiply it out k times
es := make([]air.Expr, e.Pow)
//
Expand All @@ -186,35 +210,13 @@ func (e *Exp) LowerTo(schema *air.Schema) air.Expr {
return &air.Mul{Args: es}
}

// LowerTo lowers a normalise expression to the AIR level by "compiling it out"
// using a computed column.
func (p *Normalise) LowerTo(schema *air.Schema) air.Expr {
// Lower the expression being normalised
e := p.Arg.LowerTo(schema)
// Construct an expression representing the normalised value of e. That is,
// an expression which is 0 when e is 0, and 1 when e is non-zero.
return air_gadgets.Normalise(e, schema)
}

// LowerTo lowers a column access to the AIR level. This is straightforward as
// it is already in the correct form.
func (e *ColumnAccess) LowerTo(schema *air.Schema) air.Expr {
return &air.ColumnAccess{Column: e.Column, Shift: e.Shift}
}

// LowerTo lowers a constant to the AIR level. This is straightforward as it is
// already in the correct form.
func (e *Constant) LowerTo(schema *air.Schema) air.Expr {
return &air.Constant{Value: e.Value}
}

// Lower a set of zero or more MIR expressions.
func lowerExprs(exprs []Expr, schema *air.Schema) []air.Expr {
n := len(exprs)
nexprs := make([]air.Expr, n)

for i := 0; i < n; i++ {
nexprs[i] = exprs[i].LowerTo(schema)
nexprs[i] = lowerExprToInner(exprs[i], schema)
}

return nexprs
Expand Down
24 changes: 24 additions & 0 deletions pkg/test/ir_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -244,6 +244,30 @@ func Test_Type_03(t *testing.T) {
Check(t, "type_03")
}

// ===================================================================
// Constant Propagation
// ===================================================================

func Test_Constant_01(t *testing.T) {
Check(t, "constant_01")
}

func Test_Constant_02(t *testing.T) {
Check(t, "constant_02")
}

func Test_Constant_03(t *testing.T) {
Check(t, "constant_03")
}

func Test_Constant_04(t *testing.T) {
Check(t, "constant_04")
}

func Test_Constant_05(t *testing.T) {
Check(t, "constant_05")
}

// ===================================================================
// Modules
// ===================================================================
Expand Down
12 changes: 12 additions & 0 deletions testdata/constant_01.accepts
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
{ "X": [1], "Y": [1] }
{ "X": [1,1], "Y": [1,1] }
{ "X": [1,2], "Y": [1,2] }
{ "X": [2,1], "Y": [2,1] }
{ "X": [2,2], "Y": [2,2] }
;;
{ "X": [1,1,1], "Y": [1,1,1] }
{ "X": [1,1,2], "Y": [1,1,2] }
{ "X": [1,2,1], "Y": [1,2,1] }
{ "X": [1,2,2], "Y": [1,2,2] }
{ "X": [2,1,1], "Y": [2,1,1] }
{ "X": [2,2,1], "Y": [2,2,1] }
6 changes: 6 additions & 0 deletions testdata/constant_01.lisp
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
(column X)
(column Y)
;; X == Y + n - n
(vanish c1 (- X Y (+ 1 1) (- 0 2)))
(vanish c2 (- X Y (+ 1 1 1) (- 0 1 2)))
(vanish c3 (- X Y (+ 2 1) (- 0 2 1)))
13 changes: 13 additions & 0 deletions testdata/constant_01.rejects
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
{ "X": [0], "Y": [1] }
{ "X": [1], "Y": [0] }
{ "X": [2], "Y": [1] }
{ "X": [1], "Y": [2] }
{ "X": [0,0], "Y": [0,1] }
{ "X": [0,0], "Y": [1,0] }
{ "X": [0,0], "Y": [1,1] }
{ "X": [0,1], "Y": [1,0] }
{ "X": [0,1], "Y": [1,1] }
{ "X": [1,0], "Y": [0,1] }
{ "X": [1,0], "Y": [1,1] }
{ "X": [1,1], "Y": [0,1] }
{ "X": [1,1], "Y": [1,0] }
12 changes: 12 additions & 0 deletions testdata/constant_02.accepts
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
{ "X": [1], "Y": [1] }
{ "X": [1,1], "Y": [1,1] }
{ "X": [1,2], "Y": [1,2] }
{ "X": [2,1], "Y": [2,1] }
{ "X": [2,2], "Y": [2,2] }
;;
{ "X": [1,1,1], "Y": [1,1,1] }
{ "X": [1,1,2], "Y": [1,1,2] }
{ "X": [1,2,1], "Y": [1,2,1] }
{ "X": [1,2,2], "Y": [1,2,2] }
{ "X": [2,1,1], "Y": [2,1,1] }
{ "X": [2,2,1], "Y": [2,2,1] }
24 changes: 24 additions & 0 deletions testdata/constant_02.lisp
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
(column X)
(column Y)
;; X*2 == Y*2
(vanish c1 (- (* X (* 2 1)) (* Y (* 1 2))))
;; X*1458 == Y*1458
(vanish c1 (- (* X (* 243 22)) (* Y (* 6 891))))
;; X*2916 == Y*2916
(vanish c1 (- (* X (* 2 243 22)) (* Y (* 6 891 2))))
;; X*2916 == Y*2916
(vanish c1 (- (* X (* 243 2 22)) (* Y (* 6 891 2))))
;; X*2916 == Y*2916
(vanish c1 (- (* X (* 22 243 2)) (* Y (* 6 891 2))))
;; X*2916 == Y*2916
(vanish c1 (- (* X (* 2 243 22)) (* Y (* 891 6 2))))
;; X*2916 == Y*2916
(vanish c1 (- (* X (* 2 243 22)) (* Y (* 2 891 6))))
;; X*2916 == Y*2916
(vanish c1 (- (* X (* 2 243 22)) (* Y (* 2 891 6 1))))
;; X*2916 == Y*2916
(vanish c1 (- (* X (* 2 243 22)) (* Y (* 2 891 6 1 1))))
;; X*2916 == Y*2916
(vanish c1 (- (* X (* 2 243 22 1)) (* Y (* 2 891 6))))
;; X*2916 == Y*2916
(vanish c1 (- (* X (* 2 243 22 1 1)) (* Y (* 2 891 6))))
Loading