Skip to content

Commit

Permalink
Merge pull request #205 from Consensys/110-support-interleaving-const…
Browse files Browse the repository at this point in the history
…raints

Support Interleaving Constraints
  • Loading branch information
DavePearce authored Jul 4, 2024
2 parents bd76d3a + 7c0e6d2 commit 4f5dd3e
Show file tree
Hide file tree
Showing 48 changed files with 686 additions and 228 deletions.
15 changes: 8 additions & 7 deletions pkg/air/expr.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ type Add struct{ Args []Expr }

// Context determines the evaluation context (i.e. enclosing module) for this
// expression.
func (p *Add) Context(schema sc.Schema) (uint, bool) {
func (p *Add) Context(schema sc.Schema) (uint, uint, bool) {
return sc.JoinContexts[Expr](p.Args, schema)
}

Expand Down Expand Up @@ -73,7 +73,7 @@ type Sub struct{ Args []Expr }

// Context determines the evaluation context (i.e. enclosing module) for this
// expression.
func (p *Sub) Context(schema sc.Schema) (uint, bool) {
func (p *Sub) Context(schema sc.Schema) (uint, uint, bool) {
return sc.JoinContexts[Expr](p.Args, schema)
}

Expand Down Expand Up @@ -102,7 +102,7 @@ type Mul struct{ Args []Expr }

// Context determines the evaluation context (i.e. enclosing module) for this
// expression.
func (p *Mul) Context(schema sc.Schema) (uint, bool) {
func (p *Mul) Context(schema sc.Schema) (uint, uint, bool) {
return sc.JoinContexts[Expr](p.Args, schema)
}

Expand Down Expand Up @@ -154,8 +154,8 @@ func NewConstCopy(val *fr.Element) Expr {

// Context determines the evaluation context (i.e. enclosing module) for this
// expression.
func (p *Constant) Context(schema sc.Schema) (uint, bool) {
return math.MaxUint, true
func (p *Constant) Context(schema sc.Schema) (uint, uint, bool) {
return math.MaxUint, math.MaxUint, true
}

// Add two expressions together, producing a third.
Expand Down Expand Up @@ -193,8 +193,9 @@ func NewColumnAccess(column uint, shift int) Expr {

// Context determines the evaluation context (i.e. enclosing module) for this
// expression.
func (p *ColumnAccess) Context(schema sc.Schema) (uint, bool) {
return schema.Columns().Nth(p.Column).Module(), true
func (p *ColumnAccess) Context(schema sc.Schema) (uint, uint, bool) {
col := schema.Columns().Nth(p.Column)
return col.Module(), col.LengthMultiplier(), true
}

// Add two expressions together, producing a third.
Expand Down
7 changes: 4 additions & 3 deletions pkg/air/gadgets/bits.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ func ApplyBinaryGadget(col uint, schema *air.Schema) {
// Construct X * (X-1)
X_X_m1 := X.Mul(X_m1)
// Done!
schema.AddVanishingConstraint(fmt.Sprintf("%s:u1", name), column.Module(), nil, X_X_m1)
schema.AddVanishingConstraint(fmt.Sprintf("%s:u1", name), column.Module(), column.LengthMultiplier(), nil, X_X_m1)
}

// ApplyBitwidthGadget ensures all values in a given column fit within a given
Expand All @@ -44,7 +44,8 @@ func ApplyBitwidthGadget(col uint, nbits uint, schema *air.Schema) {
name := column.Name()
coefficient := fr.NewElement(1)
// Add decomposition assignment
index := schema.AddAssignment(assignment.NewByteDecomposition(name, column.Module(), col, n))
index := schema.AddAssignment(
assignment.NewByteDecomposition(name, column.Module(), column.LengthMultiplier(), col, n))
// Construct Columns
for i := uint(0); i < n; i++ {
// Create Column + Constraint
Expand All @@ -60,5 +61,5 @@ func ApplyBitwidthGadget(col uint, nbits uint, schema *air.Schema) {
X := air.NewColumnAccess(col, 0)
eq := X.Equate(sum)
// Construct column name
schema.AddVanishingConstraint(fmt.Sprintf("%s:u%d", name, nbits), column.Module(), nil, eq)
schema.AddVanishingConstraint(fmt.Sprintf("%s:u%d", name, nbits), column.Module(), column.LengthMultiplier(), nil, eq)
}
5 changes: 3 additions & 2 deletions pkg/air/gadgets/column_sort.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,11 @@ func ApplyColumnSortGadget(col uint, sign bool, bitwidth uint, schema *air.Schem
deltaName = fmt.Sprintf("-%s", name)
}
// Add delta assignment
deltaIndex := schema.AddAssignment(assignment.NewComputedColumn(column.Module(), deltaName, Xdiff))
deltaIndex := schema.AddAssignment(
assignment.NewComputedColumn(column.Module(), deltaName, column.LengthMultiplier(), Xdiff))
// Add necessary bitwidth constraints
ApplyBitwidthGadget(deltaIndex, bitwidth, schema)
// Configure constraint: Delta[k] = X[k] - X[k-1]
Dk := air.NewColumnAccess(deltaIndex, 0)
schema.AddVanishingConstraint(deltaName, column.Module(), nil, Dk.Equate(Xdiff))
schema.AddVanishingConstraint(deltaName, column.Module(), column.LengthMultiplier(), nil, Dk.Equate(Xdiff))
}
6 changes: 3 additions & 3 deletions pkg/air/gadgets/expand.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,23 +20,23 @@ func Expand(e air.Expr, schema *air.Schema) uint {
return ca.Column
}
// No optimisation, therefore expand using a computedcolumn
module := sc.DetermineEnclosingModuleOfExpression(e, schema)
module, multiplier := sc.DetermineEnclosingModuleOfExpression(e, schema)
// Determine computed column name
name := e.String()
// Look up column
index, ok := sc.ColumnIndexOf(schema, module, name)
// Add new column (if it does not already exist)
if !ok {
// Add computed column
index = schema.AddAssignment(assignment.NewComputedColumn(module, name, e))
index = schema.AddAssignment(assignment.NewComputedColumn(module, name, multiplier, e))
}
// Construct v == [e]
v := air.NewColumnAccess(index, 0)
// Construct 1 == e/e
eq_e_v := v.Equate(e)
// Ensure (e - v) == 0, where v is value of computed column.
c_name := fmt.Sprintf("[%s]", e.String())
schema.AddVanishingConstraint(c_name, module, nil, eq_e_v)
schema.AddVanishingConstraint(c_name, module, multiplier, nil, eq_e_v)
//
return index
}
19 changes: 11 additions & 8 deletions pkg/air/gadgets/lexicographic_sort.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,18 +33,19 @@ func ApplyLexicographicSortingGadget(columns []uint, signs []bool, bitwidth uint
panic("Inconsistent number of columns and signs for lexicographic sort.")
}
// Determine enclosing module for this gadget.
module := sc.DetermineEnclosingModuleOfColumns(columns, schema)
module, multiplier := sc.DetermineEnclosingModuleOfColumns(columns, schema)
// Construct a unique prefix for this sort.
prefix := constructLexicographicSortingPrefix(columns, signs, schema)
// Add trace computation
deltaIndex := schema.AddAssignment(assignment.NewLexicographicSort(prefix, module, columns, signs, bitwidth))
deltaIndex := schema.AddAssignment(
assignment.NewLexicographicSort(prefix, module, multiplier, columns, signs, bitwidth))
// Construct selecto bits.
addLexicographicSelectorBits(prefix, module, deltaIndex, columns, schema)
addLexicographicSelectorBits(prefix, module, multiplier, deltaIndex, columns, schema)
// Construct delta terms
constraint := constructLexicographicDeltaConstraint(deltaIndex, columns, signs)
// Add delta constraint
deltaName := fmt.Sprintf("%s:delta", prefix)
schema.AddVanishingConstraint(deltaName, module, nil, constraint)
schema.AddVanishingConstraint(deltaName, module, multiplier, nil, constraint)
// Add necessary bitwidth constraints
ApplyBitwidthGadget(deltaIndex, bitwidth, schema)
}
Expand Down Expand Up @@ -76,7 +77,8 @@ func constructLexicographicSortingPrefix(columns []uint, signs []bool, schema *a
//
// NOTE: this implementation differs from the original corset which used an
// additional "Eq" bit to help ensure at most one selector bit was enabled.
func addLexicographicSelectorBits(prefix string, module uint, deltaIndex uint, columns []uint, schema *air.Schema) {
func addLexicographicSelectorBits(prefix string, module uint, multiplier uint,
deltaIndex uint, columns []uint, schema *air.Schema) {
ncols := uint(len(columns))
// Calculate column index of first selector bit
bitIndex := deltaIndex + 1
Expand All @@ -100,7 +102,8 @@ func addLexicographicSelectorBits(prefix string, module uint, deltaIndex uint, c
pterms[i] = air.NewColumnAccess(bitIndex+i, 0)
pDiff := air.NewColumnAccess(columns[i], 0).Sub(air.NewColumnAccess(columns[i], -1))
pName := fmt.Sprintf("%s:%d:a", prefix, i)
schema.AddVanishingConstraint(pName, module, nil, air.NewConst64(1).Sub(&air.Add{Args: pterms}).Mul(pDiff))
schema.AddVanishingConstraint(pName, module, multiplier,
nil, air.NewConst64(1).Sub(&air.Add{Args: pterms}).Mul(pDiff))
// (∀j<i.Bj=0) ∧ Bi=1 ==> C[k]≠C[k-1]
qDiff := Normalise(air.NewColumnAccess(columns[i], 0).Sub(air.NewColumnAccess(columns[i], -1)), schema)
qName := fmt.Sprintf("%s:%d:b", prefix, i)
Expand All @@ -112,14 +115,14 @@ func addLexicographicSelectorBits(prefix string, module uint, deltaIndex uint, c
constraint = air.NewConst64(1).Sub(&air.Add{Args: qterms}).Mul(constraint)
}

schema.AddVanishingConstraint(qName, module, nil, constraint)
schema.AddVanishingConstraint(qName, module, multiplier, nil, constraint)
}

sum := &air.Add{Args: terms}
// (sum = 0) ∨ (sum = 1)
constraint := sum.Mul(sum.Equate(air.NewConst64(1)))
name := fmt.Sprintf("%s:xor", prefix)
schema.AddVanishingConstraint(name, module, nil, constraint)
schema.AddVanishingConstraint(name, module, multiplier, nil, constraint)
}

// Construct the lexicographic delta constraint. This states that the delta
Expand Down
10 changes: 5 additions & 5 deletions pkg/air/gadgets/normalisation.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ func Normalise(e air.Expr, schema *air.Schema) air.Expr {
// ensure it really holds the inverted value.
func ApplyPseudoInverseGadget(e air.Expr, schema *air.Schema) air.Expr {
// Determine enclosing module.
module := sc.DetermineEnclosingModuleOfExpression(e, schema)
module, multiplier := sc.DetermineEnclosingModuleOfExpression(e, schema)
// Construct inverse computation
ie := &Inverse{Expr: e}
// Determine computed column name
Expand All @@ -39,7 +39,7 @@ func ApplyPseudoInverseGadget(e air.Expr, schema *air.Schema) air.Expr {
// Add new column (if it does not already exist)
if !ok {
// Add computed column
index = schema.AddAssignment(assignment.NewComputedColumn(module, name, ie))
index = schema.AddAssignment(assignment.NewComputedColumn(module, name, multiplier, ie))
}

// Construct 1/e
Expand All @@ -54,10 +54,10 @@ func ApplyPseudoInverseGadget(e air.Expr, schema *air.Schema) air.Expr {
inv_e_implies_one_e_e := inv_e.Mul(one_e_e)
// Ensure (e != 0) ==> (1 == e/e)
l_name := fmt.Sprintf("[%s <=]", ie.String())
schema.AddVanishingConstraint(l_name, module, nil, e_implies_one_e_e)
schema.AddVanishingConstraint(l_name, module, multiplier, nil, e_implies_one_e_e)
// Ensure (e/e != 0) ==> (1 == e/e)
r_name := fmt.Sprintf("[%s =>]", ie.String())
schema.AddVanishingConstraint(r_name, module, nil, inv_e_implies_one_e_e)
schema.AddVanishingConstraint(r_name, module, multiplier, nil, inv_e_implies_one_e_e)
// Done
return air.NewColumnAccess(index, 0)
}
Expand All @@ -81,7 +81,7 @@ func (e *Inverse) Bounds() util.Bounds { return e.Expr.Bounds() }

// Context determines the evaluation context (i.e. enclosing module) for this
// expression.
func (e *Inverse) Context(schema sc.Schema) (uint, bool) {
func (e *Inverse) Context(schema sc.Schema) (uint, uint, bool) {
return e.Expr.Context(schema)
}

Expand Down
9 changes: 5 additions & 4 deletions pkg/air/schema.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,8 @@ func (p *Schema) AddAssignment(c schema.Assignment) uint {
}

// AddLookupConstraint appends a new lookup constraint.
func (p *Schema) AddLookupConstraint(handle string, source uint, target uint, sources []uint, targets []uint) {
func (p *Schema) AddLookupConstraint(handle string, source uint, source_multiplier uint,
target uint, target_multiplier uint, sources []uint, targets []uint) {
if len(targets) != len(sources) {
panic("differeng number of target / source lookup columns")
}
Expand All @@ -102,7 +103,7 @@ func (p *Schema) AddLookupConstraint(handle string, source uint, target uint, so
}
//
p.constraints = append(p.constraints,
constraint.NewLookupConstraint(handle, source, target, from, into))
constraint.NewLookupConstraint(handle, source, source_multiplier, target, target_multiplier, from, into))
}

// AddPermutationConstraint appends a new permutation constraint which
Expand All @@ -113,13 +114,13 @@ func (p *Schema) AddPermutationConstraint(targets []uint, sources []uint) {
}

// AddVanishingConstraint appends a new vanishing constraint.
func (p *Schema) AddVanishingConstraint(handle string, module uint, domain *int, expr Expr) {
func (p *Schema) AddVanishingConstraint(handle string, module uint, multiplier uint, domain *int, expr Expr) {
if module >= uint(len(p.modules)) {
panic(fmt.Sprintf("invalid module index (%d)", module))
}
// TODO: sanity check expression enclosed by module
p.constraints = append(p.constraints,
constraint.NewVanishingConstraint(handle, module, domain, constraint.ZeroTest[Expr]{Expr: expr}))
constraint.NewVanishingConstraint(handle, module, multiplier, domain, constraint.ZeroTest[Expr]{Expr: expr}))
}

// AddRangeConstraint appends a new range constraint.
Expand Down
13 changes: 10 additions & 3 deletions pkg/binfile/computation.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (

"github.com/consensys/go-corset/pkg/hir"
sc "github.com/consensys/go-corset/pkg/schema"
"github.com/consensys/go-corset/pkg/schema/assignment"
)

type jsonComputationSet struct {
Expand All @@ -26,6 +27,8 @@ type jsonSortedComputation struct {
// =============================================================================

func (e jsonComputationSet) addToSchema(schema *hir.Schema) {
var multiplier uint
//
for _, c := range e.Computations {
if c.Sorted != nil {
targetRefs := asColumnRefs(c.Sorted.Tos)
Expand Down Expand Up @@ -53,13 +56,17 @@ func (e jsonComputationSet) addToSchema(schema *hir.Schema) {
// Sanity check we have a sensible type here.
if ith.Type().AsUint() == nil {
panic(fmt.Sprintf("source column %s has field type", sourceRefs[i]))
} else if i == 0 {
multiplier = ith.LengthMultiplier()
} else if multiplier != ith.LengthMultiplier() {
panic(fmt.Sprintf("source column %s has inconsistent length multiplier", sourceRefs[i]))
}

sources[i] = src_cid
targets[i] = sc.NewColumn(ith.Module(), targetRef.column, ith.Type())
targets[i] = sc.NewColumn(ith.Module(), targetRef.column, multiplier, ith.Type())
}
// Finally, add the permutation column
schema.AddPermutationColumns(module, targets, c.Sorted.Signs, sources)
// Finally, add the sorted permutation assignment
schema.AddAssignment(assignment.NewSortedPermutation(module, multiplier, targets, c.Sorted.Signs, sources))
}
}
}
4 changes: 2 additions & 2 deletions pkg/binfile/constraint.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,9 @@ func (e jsonConstraint) addToSchema(schema *hir.Schema) {
// Translate Domain
domain := e.Vanishes.Domain.toHir()
// Determine enclosing module
module := sc.DetermineEnclosingModuleOfExpression(expr, schema)
module, multiplier := sc.DetermineEnclosingModuleOfExpression(expr, schema)
// Construct the vanishing constraint
schema.AddVanishingConstraint(e.Vanishes.Handle, module, domain, expr)
schema.AddVanishingConstraint(e.Vanishes.Handle, module, multiplier, domain, expr)
} else if e.Permutation == nil {
// Catch all
panic("Unknown JSON constraint encountered")
Expand Down
20 changes: 11 additions & 9 deletions pkg/hir/environment.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package hir
import (
"fmt"

"github.com/consensys/go-corset/pkg/schema"
sc "github.com/consensys/go-corset/pkg/schema"
)

Expand Down Expand Up @@ -68,17 +69,18 @@ func (p *Environment) AddDataColumn(module uint, column string, datatype sc.Type
return cid
}

// AddPermutationColumns registers a new permutation within a given module. Observe that
// this will panic if any of the target columns already exists, or the source
// columns don't exist.
func (p *Environment) AddPermutationColumns(module uint, targets []sc.Column, signs []bool, sources []uint) {
// AddAssignment appends a new assignment (i.e. set of computed columns) to be
// used during trace expansion for this schema. Computed columns are introduced
// by the process of lowering from HIR / MIR to AIR.
func (p *Environment) AddAssignment(decl schema.Assignment) {
// Update schema
p.schema.AddPermutationColumns(module, targets, signs, sources)
index := p.schema.AddAssignment(decl)
// Update cache
for _, col := range targets {
cid := uint(len(p.columns))
cref := columnRef{module, col.Name()}
p.columns[cref] = cid
for i := decl.Columns(); i.HasNext(); {
ith := i.Next()
cref := columnRef{ith.Module(), ith.Name()}
p.columns[cref] = index
index++
}
}

Expand Down
Loading

0 comments on commit 4f5dd3e

Please sign in to comment.