Skip to content

Commit

Permalink
Merge pull request #177 from Consensys/172-trace-alignment
Browse files Browse the repository at this point in the history
feat: 172 trace alignment
  • Loading branch information
DavePearce authored Jun 20, 2024
2 parents 3c7f5dc + 7272caa commit e117dae
Show file tree
Hide file tree
Showing 30 changed files with 866 additions and 651 deletions.
2 changes: 1 addition & 1 deletion pkg/air/eval.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import (
// value at that row of the column in question or nil is that row is
// out-of-bounds.
func (e *ColumnAccess) EvalAt(k int, tbl table.Trace) *fr.Element {
val := tbl.GetByName(e.Column, k+e.Shift)
val := tbl.ColumnByIndex(e.Column).Get(k + e.Shift)

var clone fr.Element
// Clone original value
Expand Down
6 changes: 3 additions & 3 deletions pkg/air/expr.go
Original file line number Diff line number Diff line change
Expand Up @@ -143,14 +143,14 @@ func (p *Constant) Bounds() util.Bounds { return util.EMPTY_BOUND }
// accesses the STAMP column at row 5, whilst CT(-1) accesses the CT column at
// row 4.
type ColumnAccess struct {
Column string
Column uint
Shift int
}

// NewColumnAccess constructs an AIR expression representing the value of a given
// column on the current row.
func NewColumnAccess(name string, shift int) Expr {
return &ColumnAccess{name, shift}
func NewColumnAccess(column uint, shift int) Expr {
return &ColumnAccess{column, shift}
}

// Add two expressions together, producing a third.
Expand Down
35 changes: 20 additions & 15 deletions pkg/air/gadgets/bits.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,21 +11,23 @@ import (
// ApplyBinaryGadget adds a binarity constraint for a given column in the schema
// which enforces that all values in the given column are either 0 or 1. For a
// column X, this corresponds to the vanishing constraint X * (X-1) == 0.
func ApplyBinaryGadget(col string, schema *air.Schema) {
func ApplyBinaryGadget(column uint, schema *air.Schema) {
// Determine column name
name := schema.Column(column).Name()
// Construct X
X := air.NewColumnAccess(col, 0)
X := air.NewColumnAccess(column, 0)
// Construct X-1
X_m1 := X.Sub(air.NewConst64(1))
// Construct X * (X-1)
X_X_m1 := X.Mul(X_m1)
// Done!
schema.AddVanishingConstraint(col, nil, X_X_m1)
schema.AddVanishingConstraint(fmt.Sprintf("%s:u1", name), nil, X_X_m1)
}

// ApplyBitwidthGadget ensures all values in a given column fit within a given
// number of bits. This is implemented using a *byte decomposition* which adds
// n columns and a vanishing constraint (where n*8 >= nbits).
func ApplyBitwidthGadget(col string, nbits uint, schema *air.Schema) {
func ApplyBitwidthGadget(column uint, nbits uint, schema *air.Schema) {
if nbits%8 != 0 {
panic("asymmetric bitwidth constraints not yet supported")
} else if nbits == 0 {
Expand All @@ -35,38 +37,41 @@ func ApplyBitwidthGadget(col string, nbits uint, schema *air.Schema) {
n := nbits / 8
es := make([]air.Expr, n)
fr256 := fr.NewElement(256)
name := schema.Column(column).Name()
coefficient := fr.NewElement(1)
// Construct Columns
for i := uint(0); i < n; i++ {
// Determine name for the ith byte column
colName := fmt.Sprintf("%s:%d", col, i)
colName := fmt.Sprintf("%s:%d", name, i)
// Create Column + Constraint
schema.AddColumn(colName, true)
schema.AddRangeConstraint(colName, &fr256)
es[i] = air.NewColumnAccess(colName, 0).Mul(air.NewConstCopy(&coefficient))
colIndex := schema.AddColumn(colName, true)
es[i] = air.NewColumnAccess(colIndex, 0).Mul(air.NewConstCopy(&coefficient))

schema.AddRangeConstraint(colIndex, &fr256)
// Update coefficient
coefficient.Mul(&coefficient, &fr256)
}
// Construct (X:0 * 1) + ... + (X:n * 2^n)
sum := &air.Add{Args: es}
// Construct X == (X:0 * 1) + ... + (X:n * 2^n)
X := air.NewColumnAccess(col, 0)
X := air.NewColumnAccess(column, 0)
eq := X.Equate(sum)
// Construct column name
schema.AddVanishingConstraint(fmt.Sprintf("%s:u%d", col, nbits), nil, eq)
schema.AddVanishingConstraint(fmt.Sprintf("%s:u%d", name, nbits), nil, eq)
// Finally, add the necessary byte decomposition computation.
schema.AddComputation(table.NewByteDecomposition(col, nbits))
schema.AddComputation(table.NewByteDecomposition(name, nbits))
}

// AddBitArray adds an array of n bit columns using a given prefix, including
// the necessary binarity constraints.
func AddBitArray(prefix string, count int, schema *air.Schema) []string {
bits := make([]string, count)
func AddBitArray(prefix string, count int, schema *air.Schema) []uint {
bits := make([]uint, count)

for i := 0; i < count; i++ {
// Construct bit column name
bits[i] = fmt.Sprintf("%s:%d", prefix, i)
ith := fmt.Sprintf("%s:%d", prefix, i)
// Add (synthetic) column
schema.AddColumn(bits[i], true)
bits[i] = schema.AddColumn(ith, true)
// Add binarity constraints (i.e. to enfoce that this column is a bit).
ApplyBinaryGadget(bits[i], schema)
}
Expand Down
16 changes: 9 additions & 7 deletions pkg/air/gadgets/column_sort.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import (
"github.com/consensys/go-corset/pkg/table"
)

// ApplyColumnSortGadget Add sorting constraints for a column where the
// ApplyColumnSortGadget adds sorting constraints for a column where the
// difference between any two rows (i.e. the delta) is constrained to fit within
// a given bitwidth. The target column is assumed to have an appropriate
// (enforced) bitwidth to ensure overflow cannot arise. The sorting constraint
Expand All @@ -18,27 +18,29 @@ import (
// This gadget does not attempt to sort the column data during trace expansion,
// and assumes the data either comes sorted or is sorted by some other
// computation.
func ApplyColumnSortGadget(column string, sign bool, bitwidth uint, schema *air.Schema) {
func ApplyColumnSortGadget(column uint, sign bool, bitwidth uint, schema *air.Schema) {
var deltaName string
// Determine column name
name := schema.Column(column).Name()
// Configure computation
Xk := air.NewColumnAccess(column, 0)
Xkm1 := air.NewColumnAccess(column, -1)
// Account for sign
var Xdiff air.Expr
if sign {
Xdiff = Xk.Sub(Xkm1)
deltaName = fmt.Sprintf("+%s", column)
deltaName = fmt.Sprintf("+%s", name)
} else {
Xdiff = Xkm1.Sub(Xk)
deltaName = fmt.Sprintf("-%s", column)
deltaName = fmt.Sprintf("-%s", name)
}
// Add delta column
schema.AddColumn(deltaName, true)
deltaIndex := schema.AddColumn(deltaName, true)
// Add diff computation
schema.AddComputation(table.NewComputedColumn(deltaName, Xdiff))
// Add necessary bitwidth constraints
ApplyBitwidthGadget(deltaName, bitwidth, schema)
ApplyBitwidthGadget(deltaIndex, bitwidth, schema)
// Configure constraint: Delta[k] = X[k] - X[k-1]
Dk := air.NewColumnAccess(deltaName, 0)
Dk := air.NewColumnAccess(deltaIndex, 0)
schema.AddVanishingConstraint(deltaName, nil, Dk.Equate(Xdiff))
}
48 changes: 24 additions & 24 deletions pkg/air/gadgets/lexicographic_sort.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,37 +26,38 @@ import (
// case (see above). The delta value captures the difference Ci[k]-Ci[k-1] to
// ensure it is positive. The delta column is constrained to a given bitwidth,
// with constraints added as necessary to ensure this.
func ApplyLexicographicSortingGadget(columns []string, signs []bool, bitwidth uint, schema *air.Schema) {
func ApplyLexicographicSortingGadget(columns []uint, signs []bool, bitwidth uint, schema *air.Schema) {
// Check preconditions
ncols := len(columns)
if ncols != len(signs) {
panic("Inconsistent number of columns and signs for lexicographic sort.")
}
// Add trace computation
schema.AddComputation(&lexicographicSortExpander{columns, signs, bitwidth})
// Construct a unique prefix for this sort.
prefix := constructLexicographicSortingPrefix(columns, signs)
prefix := constructLexicographicSortingPrefix(columns, signs, schema)
// Add trace computation
schema.AddComputation(&lexicographicSortExpander{prefix, columns, signs, bitwidth})
deltaName := fmt.Sprintf("%s:delta", prefix)
// Construct selecto bits.
bits := addLexicographicSelectorBits(prefix, columns, schema)
// Add delta column
schema.AddColumn(deltaName, true)
deltaIndex := schema.AddColumn(deltaName, true)
// Construct delta terms
constraint := constructLexicographicDeltaConstraint(deltaName, bits, columns, signs)
constraint := constructLexicographicDeltaConstraint(deltaIndex, bits, columns, signs)
// Add delta constraint
schema.AddVanishingConstraint(deltaName, nil, constraint)
// Add necessary bitwidth constraints
ApplyBitwidthGadget(deltaName, bitwidth, schema)
ApplyBitwidthGadget(deltaIndex, bitwidth, schema)
}

// Construct a unique identifier for the given sort. This should not conflict
// with the identifier for any other sort.
func constructLexicographicSortingPrefix(columns []string, signs []bool) string {
func constructLexicographicSortingPrefix(columns []uint, signs []bool, schema *air.Schema) string {
// Use string builder to try and make this vaguely efficient.
var id strings.Builder
// Concatenate column names with their signs.
for i := 0; i < len(columns); i++ {
id.WriteString(columns[i])
ith := schema.Column(columns[i])
id.WriteString(ith.Name())

if signs[i] {
id.WriteString("+")
Expand All @@ -75,7 +76,7 @@ func constructLexicographicSortingPrefix(columns []string, signs []bool) string
//
// NOTE: this implementation differs from the original corset which used an
// additional "Eq" bit to help ensure at most one selector bit was enabled.
func addLexicographicSelectorBits(prefix string, columns []string, schema *air.Schema) []string {
func addLexicographicSelectorBits(prefix string, columns []uint, schema *air.Schema) []uint {
ncols := len(columns)
// Add bits and their binary constraints.
bits := AddBitArray(prefix, ncols, schema)
Expand Down Expand Up @@ -123,11 +124,11 @@ func addLexicographicSelectorBits(prefix string, columns []string, schema *air.S
// appropriately for the sign) between the ith column whose multiplexor bit is
// set. This is assumes that multiplexor bits are mutually exclusive (i.e. at
// most is one).
func constructLexicographicDeltaConstraint(deltaName string, bits []string, columns []string, signs []bool) air.Expr {
func constructLexicographicDeltaConstraint(delta uint, bits []uint, columns []uint, signs []bool) air.Expr {
ncols := len(columns)
// Construct delta terms
terms := make([]air.Expr, ncols)
Dk := air.NewColumnAccess(deltaName, 0)
Dk := air.NewColumnAccess(delta, 0)

for i := 0; i < ncols; i++ {
var Xdiff air.Expr
Expand All @@ -150,7 +151,8 @@ func constructLexicographicDeltaConstraint(deltaName string, bits []string, colu
}

type lexicographicSortExpander struct {
columns []string
prefix string
columns []uint
signs []bool
bitwidth uint
}
Expand All @@ -163,15 +165,15 @@ func (p *lexicographicSortExpander) RequiredSpillage() uint {

// Accepts checks whether a given trace has the necessary columns
func (p *lexicographicSortExpander) Accepts(tr table.Trace) error {
prefix := constructLexicographicSortingPrefix(p.columns, p.signs)
deltaName := fmt.Sprintf("%s:delta", prefix)
//prefix := constructLexicographicSortingPrefix(p.columns, p.signs)
deltaName := fmt.Sprintf("%s:delta", p.prefix)
// Check delta column exists
if !tr.HasColumn(deltaName) {
return fmt.Errorf("Trace missing lexicographic delta column ({%s})", deltaName)
}
// Check selector columns exist
for i := range p.columns {
bitName := fmt.Sprintf("%s:%d", prefix, i)
bitName := fmt.Sprintf("%s:%d", p.prefix, i)
if !tr.HasColumn(bitName) {
return fmt.Errorf("Trace missing lexicographic selector column ({%s})", bitName)
}
Expand All @@ -190,8 +192,7 @@ func (p *lexicographicSortExpander) ExpandTrace(tr table.Trace) error {
// Determine how many rows to be constrained.
nrows := tr.Height()
// Construct a unique prefix for this sort.
prefix := constructLexicographicSortingPrefix(p.columns, p.signs)
deltaName := fmt.Sprintf("%s:delta", prefix)
deltaName := fmt.Sprintf("%s:delta", p.prefix)
// Initialise new data columns
delta := make([]*fr.Element, nrows)
bit := make([][]*fr.Element, ncols)
Expand All @@ -200,14 +201,14 @@ func (p *lexicographicSortExpander) ExpandTrace(tr table.Trace) error {
bit[i] = make([]*fr.Element, nrows)
}

for i := uint(0); i < nrows; i++ {
for i := 0; i < int(nrows); i++ {
set := false
// Initialise delta to zero
delta[i] = &zero
// Decide which row is the winner (if any)
for j := 0; j < ncols; j++ {
prev := tr.GetByName(p.columns[j], int(i-1))
curr := tr.GetByName(p.columns[j], int(i))
prev := tr.ColumnByIndex(p.columns[j]).Get(i - 1)
curr := tr.ColumnByIndex(p.columns[j]).Get(i)

if !set && prev != nil && prev.Cmp(curr) != 0 {
var diff fr.Element
Expand All @@ -228,12 +229,11 @@ func (p *lexicographicSortExpander) ExpandTrace(tr table.Trace) error {
}
}
}

// Add delta column data
tr.AddColumn(deltaName, delta, &zero)
// Add bit column data
for i := 0; i < ncols; i++ {
bitName := fmt.Sprintf("%s:%d", prefix, i)
bitName := fmt.Sprintf("%s:%d", p.prefix, i)
tr.AddColumn(bitName, bit[i], &zero)
}
// Done.
Expand All @@ -243,5 +243,5 @@ func (p *lexicographicSortExpander) ExpandTrace(tr table.Trace) error {
// String returns a string representation of this constraint. This is primarily
// used for debugging.
func (p *lexicographicSortExpander) String() string {
return fmt.Sprintf("(lexer (%s) (%v) :%d))", any(p.columns), p.signs, p.bitwidth)
return fmt.Sprintf("(lexer (%v) (%v) :%d))", any(p.columns), p.signs, p.bitwidth)
}
14 changes: 6 additions & 8 deletions pkg/air/gadgets/normalisation.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,15 +30,17 @@ func ApplyPseudoInverseGadget(e air.Expr, tbl *air.Schema) air.Expr {
ie := &Inverse{Expr: e}
// Determine computed column name
name := ie.String()
// Look up column
index, ok := tbl.ColumnIndex(name)
// Add new column (if it does not already exist)
if !tbl.HasColumn(name) {
if !ok {
// Add (synthetic) computed column
tbl.AddColumn(name, true)
index = tbl.AddColumn(name, true)
tbl.AddComputation(table.NewComputedColumn(name, ie))
}

// Construct 1/e
inv_e := air.NewColumnAccess(name, 0)
inv_e := air.NewColumnAccess(index, 0)
// Construct e/e
e_inv_e := e.Mul(inv_e)
// Construct 1 == e/e
Expand All @@ -54,7 +56,7 @@ func ApplyPseudoInverseGadget(e air.Expr, tbl *air.Schema) air.Expr {
r_name := fmt.Sprintf("[%s =>]", ie.String())
tbl.AddVanishingConstraint(r_name, nil, inv_e_implies_one_e_e)
// Done
return air.NewColumnAccess(name, 0)
return air.NewColumnAccess(index, 0)
}

// Inverse represents a computation which computes the multiplicative
Expand All @@ -66,10 +68,6 @@ type Inverse struct{ Expr air.Expr }
func (e *Inverse) EvalAt(k int, tbl table.Trace) *fr.Element {
inv := new(fr.Element)
val := e.Expr.EvalAt(k, tbl)
// Catch undefined case
if val == nil {
return nil
}
// Go syntax huh?
return inv.Inverse(val)
}
Expand Down
Loading

0 comments on commit e117dae

Please sign in to comment.