Skip to content

Commit

Permalink
Merge pull request #171 from Consensys/169-incorrect-out-of-bounds-ca…
Browse files Browse the repository at this point in the history
…lculation-for-padding-values

fix: Incorrect Out-of-Bounds for Padding Values
  • Loading branch information
DavePearce authored Jun 17, 2024
2 parents 28dbfe5 + f6cddd8 commit 407b8c5
Show file tree
Hide file tree
Showing 35 changed files with 492 additions and 159 deletions.
1 change: 1 addition & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ repos:
rev: v1.0.0-rc.1
hooks:
- id: go-test-mod
args: [ --run=Test_ ]
- repo: https://github.com/golangci/golangci-lint
rev: v1.57.1
hooks:
Expand Down
12 changes: 0 additions & 12 deletions pkg/air/eval.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,16 +10,8 @@ import (
// out-of-bounds.
func (e *ColumnAccess) EvalAt(k int, tbl table.Trace) *fr.Element {
val := tbl.GetByName(e.Column, k+e.Shift)
// Sanity check value is not nil
if val == nil {
// Indicates an out-of-bounds access of some kind. Note that this is
// fine and expected under normal conditions. For example, when a
// constraint accesses a row which doesn't exist (e.g. via a shift).
return nil
}

var clone fr.Element

// Clone original value
return clone.Set(val)
}
Expand Down Expand Up @@ -65,10 +57,6 @@ func evalExprsAt(k int, tbl table.Trace, exprs []Expr, fn func(*fr.Element, *fr.
// Continue evaluating the rest
for i := 1; i < len(exprs); i++ {
ith := exprs[i].EvalAt(k, tbl)
if ith == nil {
return ith
}

fn(val, ith)
}

Expand Down
46 changes: 13 additions & 33 deletions pkg/air/expr.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (
// expressed within a polynomial but can be computed externally (e.g. during
// trace expansion).
type Expr interface {
util.Boundable
// EvalAt evaluates this expression in a given tabular context. Observe that
// if this expression is *undefined* within this context then it returns
// "nil". An expression can be undefined for several reasons: firstly, if
Expand All @@ -34,10 +35,6 @@ type Expr interface {

// Equate one expression with another
Equate(Expr) Expr

// Determine the maximum shift in this expression in either the negative
// (left) or positive direction (right).
MaxShift() util.Pair[uint, uint]
}

// Add represents the sum over zero or more expressions.
Expand All @@ -55,9 +52,9 @@ func (p *Add) Mul(other Expr) Expr { return &Mul{Args: []Expr{p, other}} }
// Equate one expression with another (equivalent to subtraction).
func (p *Add) Equate(other Expr) Expr { return &Sub{Args: []Expr{p, other}} }

// MaxShift returns max shift in either the negative (left) or positive
// Bounds returns max shift in either the negative (left) or positive
// direction (right).
func (p *Add) MaxShift() util.Pair[uint, uint] { return maxShiftOfArray(p.Args) }
func (p *Add) Bounds() util.Bounds { return util.BoundsForArray(p.Args) }

// Sub represents the subtraction over zero or more expressions.
type Sub struct{ Args []Expr }
Expand All @@ -74,9 +71,9 @@ func (p *Sub) Mul(other Expr) Expr { return &Mul{Args: []Expr{p, other}} }
// Equate one expression with another (equivalent to subtraction).
func (p *Sub) Equate(other Expr) Expr { return &Sub{Args: []Expr{p, other}} }

// MaxShift returns max shift in either the negative (left) or positive
// Bounds returns max shift in either the negative (left) or positive
// direction (right).
func (p *Sub) MaxShift() util.Pair[uint, uint] { return maxShiftOfArray(p.Args) }
func (p *Sub) Bounds() util.Bounds { return util.BoundsForArray(p.Args) }

// Mul represents the product over zero or more expressions.
type Mul struct{ Args []Expr }
Expand All @@ -93,9 +90,9 @@ func (p *Mul) Mul(other Expr) Expr { return &Mul{Args: []Expr{p, other}} }
// Equate one expression with another (equivalent to subtraction).
func (p *Mul) Equate(other Expr) Expr { return &Sub{Args: []Expr{p, other}} }

// MaxShift returns max shift in either the negative (left) or positive
// Bounds returns max shift in either the negative (left) or positive
// direction (right).
func (p *Mul) MaxShift() util.Pair[uint, uint] { return maxShiftOfArray(p.Args) }
func (p *Mul) Bounds() util.Bounds { return util.BoundsForArray(p.Args) }

// Constant represents a constant value within an expression.
type Constant struct{ Value *fr.Element }
Expand Down Expand Up @@ -135,9 +132,9 @@ func (p *Constant) Mul(other Expr) Expr { return &Mul{Args: []Expr{p, other}} }
// Equate one expression with another (equivalent to subtraction).
func (p *Constant) Equate(other Expr) Expr { return &Sub{Args: []Expr{p, other}} }

// MaxShift returns max shift in either the negative (left) or positive
// Bounds returns max shift in either the negative (left) or positive
// direction (right). A constant has zero shift.
func (p *Constant) MaxShift() util.Pair[uint, uint] { return util.NewPair[uint, uint](0, 0) }
func (p *Constant) Bounds() util.Bounds { return util.EMPTY_BOUND }

// ColumnAccess represents reading the value held at a given column in the
// tabular context. Furthermore, the current row maybe shifted up (or down) by
Expand Down Expand Up @@ -168,30 +165,13 @@ func (p *ColumnAccess) Mul(other Expr) Expr { return &Mul{Args: []Expr{p, other}
// Equate one expression with another (equivalent to subtraction).
func (p *ColumnAccess) Equate(other Expr) Expr { return &Sub{Args: []Expr{p, other}} }

// MaxShift returns max shift in either the negative (left) or positive
// Bounds returns max shift in either the negative (left) or positive
// direction (right).
func (p *ColumnAccess) MaxShift() util.Pair[uint, uint] {
func (p *ColumnAccess) Bounds() util.Bounds {
if p.Shift >= 0 {
// Positive shift
return util.NewPair[uint, uint](0, uint(p.Shift))
return util.NewBounds(0, uint(p.Shift))
}
// Negative shift
return util.NewPair[uint, uint](uint(-p.Shift), 0)
}

// ==========================================================================
// Helpers
// ==========================================================================

func maxShiftOfArray(args []Expr) util.Pair[uint, uint] {
neg := uint(0)
pos := uint(0)

for _, e := range args {
mx := e.MaxShift()
neg = max(neg, mx.Left)
pos = max(pos, mx.Right)
}
// Done
return util.NewPair(neg, pos)
return util.NewBounds(uint(-p.Shift), 0)
}
4 changes: 2 additions & 2 deletions pkg/air/gadgets/lexicographic_sort.go
Original file line number Diff line number Diff line change
Expand Up @@ -230,11 +230,11 @@ func (p *lexicographicSortExpander) ExpandTrace(tr table.Trace) error {
}

// Add delta column data
tr.AddColumn(deltaName, delta)
tr.AddColumn(deltaName, delta, &zero)
// Add bit column data
for i := 0; i < ncols; i++ {
bitName := fmt.Sprintf("%s:%d", prefix, i)
tr.AddColumn(bitName, bit[i])
tr.AddColumn(bitName, bit[i], &zero)
}
// Done.
return nil
Expand Down
4 changes: 2 additions & 2 deletions pkg/air/gadgets/normalisation.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,9 +74,9 @@ func (e *Inverse) EvalAt(k int, tbl table.Trace) *fr.Element {
return inv.Inverse(val)
}

// MaxShift returns max shift in either the negative (left) or positive
// Bounds returns max shift in either the negative (left) or positive
// direction (right).
func (e *Inverse) MaxShift() util.Pair[uint, uint] { return e.Expr.MaxShift() }
func (e *Inverse) Bounds() util.Bounds { return e.Expr.Bounds() }

func (e *Inverse) String() string {
return fmt.Sprintf("(inv %s)", e.Expr)
Expand Down
10 changes: 0 additions & 10 deletions pkg/air/schema.go
Original file line number Diff line number Diff line change
Expand Up @@ -105,16 +105,6 @@ func (p *Schema) RequiredSpillage() uint {
return mx
}

// ApplyPadding adds n items of padding to each column of the trace.
// Padding values are placed either at the front or the back of a given
// column, depending on their interpretation.
func (p *Schema) ApplyPadding(n uint, tr table.Trace) {
tr.Pad(n, func(j int) *fr.Element {
// Extract front value to use for padding.
return tr.GetByIndex(j, 0)
})
}

// IsInputTrace determines whether a given input trace is a suitable
// input (i.e. non-expanded) trace for this schema. Specifically, the
// input trace must contain a matching column for each non-synthetic
Expand Down
6 changes: 3 additions & 3 deletions pkg/cmd/check.go
Original file line number Diff line number Diff line change
Expand Up @@ -160,10 +160,10 @@ func checkTrace(tr *table.ArrayTrace, schema table.Schema, cfg checkConfig) (tab
// Apply spillage
if cfg.spillage >= 0 {
// Apply user-specified spillage
table.FrontPadWithZeros(uint(cfg.spillage), tr)
tr.Pad(uint(cfg.spillage))
} else {
// Apply default inferred spillage
table.FrontPadWithZeros(schema.RequiredSpillage(), tr)
tr.Pad(schema.RequiredSpillage())
}
// Expand trace
if err := schema.ExpandTrace(tr); err != nil {
Expand All @@ -180,7 +180,7 @@ func checkTrace(tr *table.ArrayTrace, schema table.Schema, cfg checkConfig) (tab
// Prevent interference
ptr := tr.Clone()
// Apply padding
schema.ApplyPadding(n, ptr)
ptr.Pad(n)
// Check whether accepted or not.
if err := schema.Accepts(ptr); err != nil {
return ptr, err
Expand Down
7 changes: 0 additions & 7 deletions pkg/hir/eval.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,6 @@ import (
// out-of-bounds.
func (e *ColumnAccess) EvalAllAt(k int, tbl table.Trace) []*fr.Element {
val := tbl.GetByName(e.Column, k+e.Shift)
// Sanity check value is not nil
if val == nil {
// Indicates an out-of-bounds access of some kind. Note that this is
// fine and expected under normal conditions. For example, when a
// constraint accesses a row which doesn't exist (e.g. via a shift).
return nil
}

var clone fr.Element
// Clone original value
Expand Down
57 changes: 57 additions & 0 deletions pkg/hir/expr.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"github.com/consensys/gnark-crypto/ecc/bls12-377/fr"
"github.com/consensys/go-corset/pkg/mir"
"github.com/consensys/go-corset/pkg/table"
"github.com/consensys/go-corset/pkg/util"
)

// ============================================================================
Expand All @@ -15,6 +16,7 @@ import (
// in the AIR level. For example, an "if" expression at this level will be
// "compiled out" into one or more expressions at the MIR level.
type Expr interface {
util.Boundable
// LowerTo lowers this expression into the Mid-Level Intermediate
// Representation. Observe that a single expression at this
// level can expand into *multiple* expressions at the MIR
Expand All @@ -34,18 +36,38 @@ type Expr interface {
// Add represents the sum over zero or more expressions.
type Add struct{ Args []Expr }

// Bounds returns max shift in either the negative (left) or positive
// direction (right).
func (p *Add) Bounds() util.Bounds { return util.BoundsForArray(p.Args) }

// Sub represents the subtraction over zero or more expressions.
type Sub struct{ Args []Expr }

// Bounds returns max shift in either the negative (left) or positive
// direction (right).
func (p *Sub) Bounds() util.Bounds { return util.BoundsForArray(p.Args) }

// Mul represents the product over zero or more expressions.
type Mul struct{ Args []Expr }

// Bounds returns max shift in either the negative (left) or positive
// direction (right).
func (p *Mul) Bounds() util.Bounds { return util.BoundsForArray(p.Args) }

// List represents a block of zero or more expressions.
type List struct{ Args []Expr }

// Bounds returns max shift in either the negative (left) or positive
// direction (right).
func (p *List) Bounds() util.Bounds { return util.BoundsForArray(p.Args) }

// Constant represents a constant value within an expression.
type Constant struct{ Val *fr.Element }

// Bounds returns max shift in either the negative (left) or positive
// direction (right). A constant has zero shift.
func (p *Constant) Bounds() util.Bounds { return util.EMPTY_BOUND }

// IfZero returns the (optional) true branch when the condition evaluates to zero, and
// the (optional false branch otherwise.
type IfZero struct {
Expand All @@ -57,10 +79,34 @@ type IfZero struct {
FalseBranch Expr
}

// Bounds returns max shift in either the negative (left) or positive
// direction (right).
func (p *IfZero) Bounds() util.Bounds {
c := p.Condition.Bounds()
// Get bounds for true branch (if applicable)
if p.TrueBranch != nil {
tbounds := p.TrueBranch.Bounds()
c.Union(&tbounds)
}
// Get bounds for false branch (if applicable)
if p.FalseBranch != nil {
fbounds := p.FalseBranch.Bounds()
c.Union(&fbounds)
}
// Done
return c
}

// Normalise reduces the value of an expression to either zero (if it was zero)
// or one (otherwise).
type Normalise struct{ Arg Expr }

// Bounds returns max shift in either the negative (left) or positive
// direction (right).
func (p *Normalise) Bounds() util.Bounds {
return p.Arg.Bounds()
}

// ColumnAccess represents reading the value held at a given column in the
// tabular context. Furthermore, the current row maybe shifted up (or down) by
// a given amount. Suppose we are evaluating a constraint on row k=5 which
Expand All @@ -71,3 +117,14 @@ type ColumnAccess struct {
Column string
Shift int
}

// Bounds returns max shift in either the negative (left) or positive
// direction (right).
func (p *ColumnAccess) Bounds() util.Bounds {
if p.Shift >= 0 {
// Positive shift
return util.NewBounds(0, uint(p.Shift))
}
// Negative shift
return util.NewBounds(uint(-p.Shift), 0)
}
15 changes: 5 additions & 10 deletions pkg/hir/schema.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package hir

import (
"github.com/consensys/gnark-crypto/ecc/bls12-377/fr"
"github.com/consensys/go-corset/pkg/mir"
"github.com/consensys/go-corset/pkg/table"
"github.com/consensys/go-corset/pkg/util"
Expand Down Expand Up @@ -35,6 +34,11 @@ func (p ZeroArrayTest) String() string {
return p.Expr.String()
}

// Bounds determines the bounds for this zero test.
func (p ZeroArrayTest) Bounds() util.Bounds {
return p.Expr.Bounds()
}

// DataColumn captures the essence of a data column at AIR level.
type DataColumn = *table.DataColumn[table.Type]

Expand Down Expand Up @@ -109,15 +113,6 @@ func (p *Schema) RequiredSpillage() uint {
return uint(1)
}

// ApplyPadding adds n items of padding to each column of the trace.
// Padding values are placed either at the front or the back of a given
// column, depending on their interpretation.
func (p *Schema) ApplyPadding(n uint, tr table.Trace) {
tr.Pad(n, func(j int) *fr.Element {
return tr.GetByIndex(j, 0)
})
}

// GetDeclaration returns the ith declaration in this schema.
func (p *Schema) GetDeclaration(index int) table.Declaration {
ith := util.FlatArrayIndexOf_4(index, p.dataColumns, p.permutations, p.vanishing, p.assertions)
Expand Down
15 changes: 1 addition & 14 deletions pkg/mir/eval.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,6 @@ import (
// out-of-bounds.
func (e *ColumnAccess) EvalAt(k int, tbl table.Trace) *fr.Element {
val := tbl.GetByName(e.Column, k+e.Shift)
// Sanity check value is not nil
if val == nil {
// Indicates an out-of-bounds access of some kind. Note that this is
// fine and expected under normal conditions. For example, when a
// constraint accesses a row which doesn't exist (e.g. via a shift).
return nil
}

var clone fr.Element
// Clone original value
Expand Down Expand Up @@ -52,7 +45,7 @@ func (e *Normalise) EvalAt(k int, tbl table.Trace) *fr.Element {
// Check whether argument evaluates to zero or not.
val := e.Arg.EvalAt(k, tbl)
// Normalise value (if necessary)
if val != nil && !val.IsZero() {
if !val.IsZero() {
val.SetOne()
}
// Done
Expand All @@ -71,15 +64,9 @@ func (e *Sub) EvalAt(k int, tbl table.Trace) *fr.Element {
func evalExprsAt(k int, tbl table.Trace, exprs []Expr, fn func(*fr.Element, *fr.Element)) *fr.Element {
// Evaluate first argument
val := exprs[0].EvalAt(k, tbl)
if val == nil {
return nil
}
// Continue evaluating the rest
for i := 1; i < len(exprs); i++ {
ith := exprs[i].EvalAt(k, tbl)
if ith == nil {
return ith
}

fn(val, ith)
}
Expand Down
Loading

0 comments on commit 407b8c5

Please sign in to comment.