Skip to content

Commit

Permalink
Merge pull request #194 from Consensys/189-support-modules-in-trace
Browse files Browse the repository at this point in the history
Support Modules in `Trace`
  • Loading branch information
DavePearce authored Jul 1, 2024
2 parents 61ff183 + 162f2b2 commit 59c74ca
Show file tree
Hide file tree
Showing 32 changed files with 648 additions and 260 deletions.
2 changes: 1 addition & 1 deletion pkg/air/eval.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import (
// value at that row of the column in question or nil is that row is
// out-of-bounds.
func (e *ColumnAccess) EvalAt(k int, tr trace.Trace) *fr.Element {
val := tr.Column(e.Column).Get(k + e.Shift)
val := tr.Columns().Get(e.Column).Get(k + e.Shift)

var clone fr.Element
// Clone original value
Expand Down
8 changes: 4 additions & 4 deletions pkg/cmd/check.go
Original file line number Diff line number Diff line change
Expand Up @@ -160,10 +160,10 @@ func checkTrace(tr trace.Trace, schema sc.Schema, cfg checkConfig) (trace.Trace,
// Apply spillage
if cfg.spillage >= 0 {
// Apply user-specified spillage
tr.Pad(uint(cfg.spillage))
trace.PadColumns(tr, uint(cfg.spillage))
} else {
// Apply default inferred spillage
tr.Pad(sc.RequiredSpillage(schema))
trace.PadColumns(tr, sc.RequiredSpillage(schema))
}
// Perform Input Alignment
if err := sc.AlignInputs(tr, schema); err != nil {
Expand All @@ -188,7 +188,7 @@ func checkTrace(tr trace.Trace, schema sc.Schema, cfg checkConfig) (trace.Trace,
// Prevent interference
ptr := tr.Clone()
// Apply padding
ptr.Pad(n)
trace.PadColumns(ptr, n)
// Check whether accepted or not.
if err := sc.Accepts(schema, ptr); err != nil {
return ptr, err
Expand All @@ -207,7 +207,7 @@ func toErrorString(err error) string {
}

func reportError(ir string, tr trace.Trace, err error, cfg checkConfig) {
if cfg.report {
if cfg.report && tr != nil {
trace.PrintTrace(tr)
}

Expand Down
70 changes: 46 additions & 24 deletions pkg/cmd/trace.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ var traceCmd = &cobra.Command{
os.Exit(1)
}
// Parse trace
trace := readTraceFile(args[0])
tr := readTraceFile(args[0])
list := getFlag(cmd, "list")
print := getFlag(cmd, "print")
padding := getUint(cmd, "pad")
Expand All @@ -33,50 +33,71 @@ var traceCmd = &cobra.Command{
max_width := getUint(cmd, "max-width")
filter := getString(cmd, "filter")
output := getString(cmd, "out")
//
// construct filters
if filter != "" {
trace = filterColumns(trace, filter)
tr = filterColumns(tr, filter)
}
if padding != 0 {
trace.Pad(padding)
trace.PadColumns(tr, padding)
}
if list {
listColumns(trace)
listColumns(tr)
}
//
if output != "" {
writeTraceFile(output, trace)
writeTraceFile(output, tr)
}

if print {
printTrace(start, end, max_width, trace)
printTrace(start, end, max_width, tr)
}
},
}

// Construct a new trace containing only those columns from the original who
// name begins with the given prefix.
func filterColumns(tr trace.Trace, prefix string) trace.Trace {
ntr := trace.EmptyArrayTrace()
//
for i := uint(0); i < tr.Width(); i++ {
ith := tr.Column(i)
if strings.HasPrefix(ith.Name(), prefix) {
ntr.Add(ith)
n := tr.Columns().Len()
builder := trace.NewBuilder()
// Initialise modules in the builder to ensure module indices are preserved
// across traces.
for i := uint(0); i < n; i++ {
ith := tr.Columns().Get(i)
name := tr.Modules().Get(ith.Module()).Name()

if !builder.HasModule(name) {
if _, err := builder.Register(name, ith.Height()); err != nil {
panic(err)
}
}
}
// Now create the columns.
for i := uint(0); i < n; i++ {
qName := QualifiedColumnName(i, tr)
//
if strings.HasPrefix(qName, prefix) {
ith := tr.Columns().Get(i)

err := builder.Add(qName, ith.Padding(), ith.Data())
// Sanity check
if err != nil {
panic(err)
}
}
}
// Done
return ntr
return builder.Build()
}

func listColumns(tr trace.Trace) {
tbl := util.NewTablePrinter(3, tr.Width())
n := tr.Columns().Len()
tbl := util.NewTablePrinter(3, n)

for i := uint(0); i < tr.Width(); i++ {
ith := tr.Column(i)
for i := uint(0); i < n; i++ {
ith := tr.Columns().Get(i)
elems := fmt.Sprintf("%d rows", ith.Height())
bytes := fmt.Sprintf("%d bytes", ith.Width()*ith.Height())
tbl.SetRow(i, ith.Name(), elems, bytes)
tbl.SetRow(i, QualifiedColumnName(i, tr), elems, bytes)
}

//
Expand All @@ -85,16 +106,18 @@ func listColumns(tr trace.Trace) {
}

func printTrace(start uint, end uint, max_width uint, tr trace.Trace) {
height := min(tr.Height(), end) - start
tbl := util.NewTablePrinter(1+height, 1+tr.Width())
cols := tr.Columns()
n := tr.Columns().Len()
height := min(trace.MaxHeight(tr), end) - start
tbl := util.NewTablePrinter(1+height, 1+n)

for j := uint(0); j < height; j++ {
tbl.Set(j+1, 0, fmt.Sprintf("#%d", j+start))
}

for i := uint(0); i < tr.Width(); i++ {
ith := tr.Column(i)
tbl.Set(0, i+1, ith.Name())
for i := uint(0); i < n; i++ {
ith := cols.Get(i)
tbl.Set(0, i+1, QualifiedColumnName(i, tr))

if start < ith.Height() {
ith_height := min(ith.Height(), end) - start
Expand All @@ -103,7 +126,6 @@ func printTrace(start uint, end uint, max_width uint, tr trace.Trace) {
}
}
}

//
tbl.SetMaxWidth(max_width)
tbl.Print()
Expand Down
15 changes: 14 additions & 1 deletion pkg/cmd/util.go
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ func readSchemaFile(filename string) *hir.Schema {
return schema
}
default:
err = fmt.Errorf("Unknown trace file format: %s\n", ext)
err = fmt.Errorf("Unknown schema file format: %s\n", ext)
}
}
// Handle error
Expand Down Expand Up @@ -182,3 +182,16 @@ func printSyntaxError(filename string, err *sexp.SyntaxError, text string) {
// Print highlight
fmt.Println(strings.Repeat("^", length))
}

// QualifiedColumnName returns a fully qualified column name based on its column
// index.
func QualifiedColumnName(cid uint, tr trace.Trace) string {
col := tr.Columns().Get(cid)
mod := tr.Modules().Get(col.Module())
// Check whether qualification required
if mod.Name() != "" {
return fmt.Sprintf("%s.%s", mod.Name(), col.Name())
}
// Prelude module
return col.Name()
}
2 changes: 1 addition & 1 deletion pkg/hir/eval.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import (
// value at that row of the column in question or nil is that row is
// out-of-bounds.
func (e *ColumnAccess) EvalAllAt(k int, tr trace.Trace) []*fr.Element {
val := tr.Column(e.Column).Get(k + e.Shift)
val := tr.Columns().Get(e.Column).Get(k + e.Shift)

var clone fr.Element
// Clone original value
Expand Down
2 changes: 1 addition & 1 deletion pkg/mir/eval.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import (
// value at that row of the column in question or nil is that row is
// out-of-bounds.
func (e *ColumnAccess) EvalAt(k int, tr trace.Trace) *fr.Element {
val := tr.Column(e.Column).Get(k + e.Shift)
val := tr.Columns().Get(e.Column).Get(k + e.Shift)

var clone fr.Element
// Clone original value
Expand Down
9 changes: 5 additions & 4 deletions pkg/schema/alignment.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,8 @@ func Align(p tr.Trace, schema Schema) error {
// allocated before computed columns. As such, alignment of these input
// columns is performed.
func alignWith(expand bool, p tr.Trace, schema Schema) error {
ncols := p.Width()
columns := p.Columns()
ncols := p.Columns().Len()
index := uint(0)
// Check each column described in this schema is present in the trace.
for i := schema.Declarations(); i.HasNext(); {
Expand All @@ -47,7 +48,7 @@ func alignWith(expand bool, p tr.Trace, schema Schema) error {
return fmt.Errorf("trace missing column %s", schemaName)
}

traceName := p.Column(index).Name()
traceName := columns.Get(index).Name()
// Check alignment
if traceName != schemaName {
// Not aligned --- so fix
Expand All @@ -57,7 +58,7 @@ func alignWith(expand bool, p tr.Trace, schema Schema) error {
return fmt.Errorf("trace missing column %s", schemaName)
}
// Swap columns
p.Swap(index, k)
columns.Swap(index, k)
}
// Continue
index++
Expand All @@ -74,7 +75,7 @@ func alignWith(expand bool, p tr.Trace, schema Schema) error {
unknowns := make([]string, n)
// Determine names of unknown columns.
for i := index; i < ncols; i++ {
unknowns[i-index] = p.Column(i).Name()
unknowns[i-index] = columns.Get(i).Name()
}
//
return fmt.Errorf("trace contains unknown columns: %v", unknowns)
Expand Down
11 changes: 9 additions & 2 deletions pkg/schema/assertion.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@ type PropertyAssertion[T Testable] struct {
// A unique identifier for this constraint. This is primarily
// useful for debugging.
Handle string
// Enclosing module for this assertion. This restricts the asserted
// property to access only columns from within this module.
module uint
// The actual assertion itself, namely an expression which
// should hold (i.e. vanish) for every row of a trace.
// Observe that this can be any function which is computable
Expand All @@ -28,7 +31,8 @@ type PropertyAssertion[T Testable] struct {

// NewPropertyAssertion constructs a new property assertion!
func NewPropertyAssertion[T Testable](handle string, property T) *PropertyAssertion[T] {
return &PropertyAssertion[T]{handle, property}
// FIXME: determine correct module index
return &PropertyAssertion[T]{handle, 0, property}
}

// GetHandle returns the handle associated with this constraint.
Expand All @@ -43,7 +47,10 @@ func (p *PropertyAssertion[T]) GetHandle() string {
//
//nolint:revive
func (p *PropertyAssertion[T]) Accepts(tr tr.Trace) error {
for k := uint(0); k < tr.Height(); k++ {
// Determine height of enclosing module
height := tr.Modules().Get(p.module).Height()
// Iterate every row in the module
for k := uint(0); k < height; k++ {
// Check whether property holds (or was undefined)
if !p.Property.TestAt(int(k), tr) {
// Construct useful error message
Expand Down
8 changes: 5 additions & 3 deletions pkg/schema/assignment/byte_decomposition.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,12 +63,13 @@ func (p *ByteDecomposition) IsComputed() bool {
// ByteDecomposition. This requires computing the value of each byte column in
// the decomposition.
func (p *ByteDecomposition) ExpandTrace(tr trace.Trace) error {
columns := tr.Columns()
// Calculate how many bytes required.
n := len(p.targets)
// Identify target column
target := tr.Column(p.source)
target := columns.Get(p.source)
// Extract column data to decompose
data := tr.Column(p.source).Data()
data := columns.Get(p.source).Data()
// Construct byte column data
cols := make([][]*fr.Element, n)
// Initialise columns
Expand All @@ -86,7 +87,8 @@ func (p *ByteDecomposition) ExpandTrace(tr trace.Trace) error {
padding := decomposeIntoBytes(target.Padding(), n)
// Finally, add byte columns to trace
for i := 0; i < n; i++ {
tr.Add(trace.NewFieldColumn(p.targets[i].Name(), cols[i], padding[i]))
ith := p.targets[i]
columns.Add(trace.NewFieldColumn(ith.Module(), ith.Name(), cols[i], padding[i]))
}
// Done
return nil
Expand Down
14 changes: 10 additions & 4 deletions pkg/schema/assignment/computed.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@ import (
// give rise to "trace expansion". That is where the initial trace provided by
// the user is expanded by determining the value of all computed columns.
type ComputedColumn[E sc.Evaluable] struct {
// Module in which to locate new column
module uint
// Name of the new column
name string
// The computation which accepts a given trace and computes
// the value of this column at a given row.
Expand All @@ -27,7 +30,8 @@ type ComputedColumn[E sc.Evaluable] struct {
// determining expression. More specifically, that expression is used to
// compute the values for this column during trace expansion.
func NewComputedColumn[E sc.Evaluable](name string, expr E) *ComputedColumn[E] {
return &ComputedColumn[E]{name, expr}
// FIXME: module index should not always be zero!
return &ComputedColumn[E]{0, name, expr}
}

// nolint:revive
Expand Down Expand Up @@ -76,11 +80,13 @@ func (p *ComputedColumn[E]) RequiredSpillage() uint {
// of evaluating a given expression on each row. If the column already exists,
// then an error is flagged.
func (p *ComputedColumn[E]) ExpandTrace(tr trace.Trace) error {
if tr.HasColumn(p.name) {
columns := tr.Columns()
// Check whether a column already exists with the given name.
if tr.Columns().HasColumn(p.name) {
return fmt.Errorf("Computed column already exists ({%s})", p.name)
}

data := make([]*fr.Element, tr.Height())
data := make([]*fr.Element, tr.Modules().Get(p.module).Height())
// Expand the trace
for i := 0; i < len(data); i++ {
val := p.expr.EvalAt(i, tr)
Expand All @@ -96,7 +102,7 @@ func (p *ComputedColumn[E]) ExpandTrace(tr trace.Trace) error {
// the padding value for *this* column.
padding := p.expr.EvalAt(-1, tr)
// Colunm needs to be expanded.
tr.Add(trace.NewFieldColumn(p.name, data, padding))
columns.Add(trace.NewFieldColumn(p.module, p.name, data, padding))
// Done
return nil
}
Loading

0 comments on commit 59c74ca

Please sign in to comment.