Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Accounting for Spillage #170

Merged
merged 4 commits into from
Jun 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 49 additions & 0 deletions pkg/air/expr.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package air
import (
"github.com/consensys/gnark-crypto/ecc/bls12-377/fr"
"github.com/consensys/go-corset/pkg/table"
"github.com/consensys/go-corset/pkg/util"
)

// Expr represents an expression in the Arithmetic Intermediate Representation
Expand Down Expand Up @@ -33,6 +34,10 @@ type Expr interface {

// Equate one expression with another
Equate(Expr) Expr

// Determine the maximum shift in this expression in either the negative
// (left) or positive direction (right).
MaxShift() util.Pair[uint, uint]
}

// Add represents the sum over zero or more expressions.
Expand All @@ -50,6 +55,10 @@ func (p *Add) Mul(other Expr) Expr { return &Mul{Args: []Expr{p, other}} }
// Equate one expression with another (equivalent to subtraction).
func (p *Add) Equate(other Expr) Expr { return &Sub{Args: []Expr{p, other}} }

// MaxShift returns max shift in either the negative (left) or positive
// direction (right).
func (p *Add) MaxShift() util.Pair[uint, uint] { return maxShiftOfArray(p.Args) }

// Sub represents the subtraction over zero or more expressions.
type Sub struct{ Args []Expr }

Expand All @@ -65,6 +74,10 @@ func (p *Sub) Mul(other Expr) Expr { return &Mul{Args: []Expr{p, other}} }
// Equate one expression with another (equivalent to subtraction).
func (p *Sub) Equate(other Expr) Expr { return &Sub{Args: []Expr{p, other}} }

// MaxShift returns max shift in either the negative (left) or positive
// direction (right).
func (p *Sub) MaxShift() util.Pair[uint, uint] { return maxShiftOfArray(p.Args) }

// Mul represents the product over zero or more expressions.
type Mul struct{ Args []Expr }

Expand All @@ -80,6 +93,10 @@ func (p *Mul) Mul(other Expr) Expr { return &Mul{Args: []Expr{p, other}} }
// Equate one expression with another (equivalent to subtraction).
func (p *Mul) Equate(other Expr) Expr { return &Sub{Args: []Expr{p, other}} }

// MaxShift returns max shift in either the negative (left) or positive
// direction (right).
func (p *Mul) MaxShift() util.Pair[uint, uint] { return maxShiftOfArray(p.Args) }

// Constant represents a constant value within an expression.
type Constant struct{ Value *fr.Element }

Expand Down Expand Up @@ -118,6 +135,10 @@ func (p *Constant) Mul(other Expr) Expr { return &Mul{Args: []Expr{p, other}} }
// Equate one expression with another (equivalent to subtraction).
func (p *Constant) Equate(other Expr) Expr { return &Sub{Args: []Expr{p, other}} }

// MaxShift returns max shift in either the negative (left) or positive
// direction (right). A constant has zero shift.
func (p *Constant) MaxShift() util.Pair[uint, uint] { return util.NewPair[uint, uint](0, 0) }

// ColumnAccess represents reading the value held at a given column in the
// tabular context. Furthermore, the current row maybe shifted up (or down) by
// a given amount. Suppose we are evaluating a constraint on row k=5 which
Expand Down Expand Up @@ -146,3 +167,31 @@ func (p *ColumnAccess) Mul(other Expr) Expr { return &Mul{Args: []Expr{p, other}

// Equate one expression with another (equivalent to subtraction).
func (p *ColumnAccess) Equate(other Expr) Expr { return &Sub{Args: []Expr{p, other}} }

// MaxShift returns max shift in either the negative (left) or positive
// direction (right).
func (p *ColumnAccess) MaxShift() util.Pair[uint, uint] {
if p.Shift >= 0 {
// Positive shift
return util.NewPair[uint, uint](0, uint(p.Shift))
}
// Negative shift
return util.NewPair[uint, uint](uint(-p.Shift), 0)
}

// ==========================================================================
// Helpers
// ==========================================================================

func maxShiftOfArray(args []Expr) util.Pair[uint, uint] {
neg := uint(0)
pos := uint(0)

for _, e := range args {
mx := e.MaxShift()
neg = max(neg, mx.Left)
pos = max(pos, mx.Right)
}
// Done
return util.NewPair(neg, pos)
}
3 changes: 2 additions & 1 deletion pkg/air/gadgets/bits.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,8 @@ func ApplyBitwidthGadget(col string, nbits uint, schema *air.Schema) {
// Construct X == (X:0 * 1) + ... + (X:n * 2^n)
X := air.NewColumnAccess(col, 0)
eq := X.Equate(sum)
schema.AddVanishingConstraint(col, nil, eq)
// Construct column name
schema.AddVanishingConstraint(fmt.Sprintf("%s:u%d", col, nbits), nil, eq)
// Finally, add the necessary byte decomposition computation.
schema.AddComputation(table.NewByteDecomposition(col, nbits))
}
Expand Down
12 changes: 9 additions & 3 deletions pkg/air/gadgets/lexicographic_sort.go
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,12 @@ type lexicographicSortExpander struct {
bitwidth uint
}

// RequiredSpillage returns the minimum amount of spillage required to ensure
// valid traces are accepted in the presence of arbitrary padding.
func (p *lexicographicSortExpander) RequiredSpillage() uint {
return uint(0)
}

// Accepts checks whether a given trace has the necessary columns
func (p *lexicographicSortExpander) Accepts(tr table.Trace) error {
prefix := constructLexicographicSortingPrefix(p.columns, p.signs)
Expand Down Expand Up @@ -194,14 +200,14 @@ func (p *lexicographicSortExpander) ExpandTrace(tr table.Trace) error {
bit[i] = make([]*fr.Element, nrows)
}

for i := 0; i < nrows; i++ {
for i := uint(0); i < nrows; i++ {
set := false
// Initialise delta to zero
delta[i] = &zero
// Decide which row is the winner (if any)
for j := 0; j < ncols; j++ {
prev := tr.GetByName(p.columns[j], i-1)
curr := tr.GetByName(p.columns[j], i)
prev := tr.GetByName(p.columns[j], int(i-1))
curr := tr.GetByName(p.columns[j], int(i))

if !set && prev != nil && prev.Cmp(curr) != 0 {
var diff fr.Element
Expand Down
5 changes: 5 additions & 0 deletions pkg/air/gadgets/normalisation.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"github.com/consensys/gnark-crypto/ecc/bls12-377/fr"
"github.com/consensys/go-corset/pkg/air"
"github.com/consensys/go-corset/pkg/table"
"github.com/consensys/go-corset/pkg/util"
)

// Normalise constructs an expression representing the normalised value of e.
Expand Down Expand Up @@ -73,6 +74,10 @@ func (e *Inverse) EvalAt(k int, tbl table.Trace) *fr.Element {
return inv.Inverse(val)
}

// MaxShift returns max shift in either the negative (left) or positive
// direction (right).
func (e *Inverse) MaxShift() util.Pair[uint, uint] { return e.Expr.MaxShift() }

func (e *Inverse) String() string {
return fmt.Sprintf("(inv %s)", e.Expr)
}
40 changes: 33 additions & 7 deletions pkg/air/schema.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,12 +89,38 @@ func (p *Schema) HasColumn(name string) bool {
return false
}

// RequiredSpillage returns the minimum amount of spillage required to ensure
// valid traces are accepted in the presence of arbitrary padding. Spillage can
// only arise from computations as this is where values outside of the user's
// control are determined.
func (p *Schema) RequiredSpillage() uint {
// Ensures always at least one row of spillage (referred to as the "initial
// padding row")
mx := uint(1)
// Determine if any more spillage required
for _, c := range p.computations {
mx = max(mx, c.RequiredSpillage())
}

return mx
}

// ApplyPadding adds n items of padding to each column of the trace.
// Padding values are placed either at the front or the back of a given
// column, depending on their interpretation.
func (p *Schema) ApplyPadding(n uint, tr table.Trace) {
tr.Pad(n, func(j int) *fr.Element {
// Extract front value to use for padding.
return tr.GetByIndex(j, 0)
})
}

// IsInputTrace determines whether a given input trace is a suitable
// input (i.e. non-expanded) trace for this schema. Specifically, the
// input trace must contain a matching column for each non-synthetic
// column in this trace.
func (p *Schema) IsInputTrace(tr table.Trace) error {
count := 0
count := uint(0)

for _, c := range p.dataColumns {
if !c.Synthetic && !tr.HasColumn(c.Name) {
Expand All @@ -112,8 +138,8 @@ func (p *Schema) IsInputTrace(tr table.Trace) error {
// Determine the unknown columns for error reporting.
unknown := make([]string, 0)

for i := 0; i < tr.Width(); i++ {
n := tr.ColumnName(i)
for i := uint(0); i < tr.Width(); i++ {
n := tr.ColumnName(int(i))
if !p.HasColumn(n) {
unknown = append(unknown, n)
}
Expand All @@ -132,7 +158,7 @@ func (p *Schema) IsInputTrace(tr table.Trace) error {
// output trace must contain a matching column for each column in this
// trace (synthetic or otherwise).
func (p *Schema) IsOutputTrace(tr table.Trace) error {
count := 0
count := uint(0)

for _, c := range p.dataColumns {
if !tr.HasColumn(c.Name) {
Expand All @@ -153,7 +179,9 @@ func (p *Schema) IsOutputTrace(tr table.Trace) error {
// AddColumn appends a new data column which is either synthetic or
// not. A synthetic column is one which has been introduced by the
// process of lowering from HIR / MIR to AIR. That is, it is not a
// column which was original specified by the user.
// column which was original specified by the user. Columns also support a
// "padding sign", which indicates whether padding should occur at the front
// (positive sign) or the back (negative sign).
func (p *Schema) AddColumn(name string, synthetic bool) {
// NOTE: the air level has no ability to enforce the type specified for a
// given column.
Expand Down Expand Up @@ -219,8 +247,6 @@ func (p *Schema) Accepts(trace table.Trace) error {
// columns. Observe that computed columns have to be computed in the correct
// order.
func (p *Schema) ExpandTrace(tr table.Trace) error {
// Insert initial padding row
table.PadTrace(1, tr)
// Execute all computations
for _, c := range p.computations {
err := c.ExpandTrace(tr)
Expand Down
Loading