Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support Non-Strict Trace Checking #222

Merged
merged 3 commits into from
Jul 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 12 additions & 6 deletions pkg/air/schema.go
Original file line number Diff line number Diff line change
Expand Up @@ -133,10 +133,13 @@ func (p *Schema) AddRangeConstraint(column uint, bound *fr.Element) {
// Schema Interface
// ============================================================================

// Inputs returns an array over the input declarations of this schema. That is,
// the subset of declarations whose trace values must be provided by the user.
func (p *Schema) Inputs() util.Iterator[schema.Declaration] {
return util.NewArrayIterator(p.inputs)
// InputColumns returns an array over the input columns of this schema. That
// is, the subset of columns whose trace values must be provided by the
// user.
func (p *Schema) InputColumns() util.Iterator[schema.Column] {
inputs := util.NewArrayIterator(p.inputs)
return util.NewFlattenIterator[schema.Declaration, schema.Column](inputs,
func(d schema.Declaration) util.Iterator[schema.Column] { return d.Columns() })
}

// Assignments returns an array over the assignments of this schema. That
Expand All @@ -149,7 +152,8 @@ func (p *Schema) Assignments() util.Iterator[schema.Assignment] {
// Columns returns an array over the underlying columns of this schema.
// Specifically, the index of a column in this array is its column index.
func (p *Schema) Columns() util.Iterator[schema.Column] {
is := util.NewFlattenIterator[schema.Declaration, schema.Column](p.Inputs(),
inputs := util.NewArrayIterator(p.inputs)
is := util.NewFlattenIterator[schema.Declaration, schema.Column](inputs,
func(d schema.Declaration) util.Iterator[schema.Column] { return d.Columns() })
ps := util.NewFlattenIterator[schema.Assignment, schema.Column](p.Assignments(),
func(d schema.Assignment) util.Iterator[schema.Column] { return d.Columns() })
Expand All @@ -166,8 +170,10 @@ func (p *Schema) Constraints() util.Iterator[schema.Constraint] {
// Declarations returns an array over the column declarations of this
// schema.
func (p *Schema) Declarations() util.Iterator[schema.Declaration] {
inputs := util.NewArrayIterator(p.inputs)
ps := util.NewCastIterator[schema.Assignment, schema.Declaration](p.Assignments())
return p.Inputs().Append(ps)

return inputs.Append(ps)
}

// Modules returns an iterator over the declared set of modules within this
Expand Down
46 changes: 44 additions & 2 deletions pkg/cmd/check.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ var checkCmd = &cobra.Command{
cfg.expand = !getFlag(cmd, "raw")
cfg.report = getFlag(cmd, "report")
cfg.spillage = getInt(cmd, "spillage")
cfg.strict = !getFlag(cmd, "warn")
cfg.padding.Right = getUint(cmd, "padding")
// TODO: support true ranges
cfg.padding.Left = cfg.padding.Right
Expand Down Expand Up @@ -59,6 +60,10 @@ type checkConfig struct {
spillage int
// Determines how much padding to use
padding util.Pair[uint, uint]
// Specified whether strict checking is performed or not. This is enabled
// by default, and ensures the tool fails with an error in any unexpected or
// unusual case.
strict bool
// Specifies whether or not to perform trace expansion. Trace expansion is
// not required when a "raw" trace is given which already includes all
// implied columns.
Expand Down Expand Up @@ -166,7 +171,7 @@ func checkTrace(tr trace.Trace, schema sc.Schema, cfg checkConfig) (trace.Trace,
trace.PadColumns(tr, sc.RequiredSpillage(schema))
}
// Perform Input Alignment
if err := sc.AlignInputs(tr, schema); err != nil {
if err := performAlignment(true, tr, schema, cfg); err != nil {
return tr, err
}
// Expand trace
Expand All @@ -175,7 +180,7 @@ func checkTrace(tr trace.Trace, schema sc.Schema, cfg checkConfig) (trace.Trace,
}
}
// Perform Alignment
if err := sc.Align(tr, schema); err != nil {
if err := performAlignment(false, tr, schema, cfg); err != nil {
return tr, err
}
// Check whether padding requested
Expand All @@ -198,6 +203,41 @@ func checkTrace(tr trace.Trace, schema sc.Schema, cfg checkConfig) (trace.Trace,
return nil, nil
}

// Run the alignment algorithm with optional checks determined by the configuration.
func performAlignment(inputs bool, tr trace.Trace, schema sc.Schema, cfg checkConfig) error {
var err error

var nSchemaCols uint
// Determine number of trace columns
nTraceCols := tr.Columns().Len()

if inputs {
nSchemaCols = schema.InputColumns().Count()
err = sc.AlignInputs(tr, schema)
} else {
nSchemaCols = schema.Columns().Count()
err = sc.Align(tr, schema)
}
// Sanity check error
if err != nil {
return err
} else if cfg.strict && nSchemaCols != nTraceCols {
col := tr.Columns().Get(nSchemaCols)
mod := tr.Modules().Get(col.Context().Module())
// Return error
return fmt.Errorf("unknown trace column %s", sc.QualifiedColumnName(mod.Name(), col.Name()))
} else if nSchemaCols != nTraceCols {
// Log warning
for i := nSchemaCols; i < nTraceCols; i++ {
col := tr.Columns().Get(i)
mod := tr.Modules().Get(col.Context().Module())
fmt.Printf("[WARNING] unknown trace column %s\n", sc.QualifiedColumnName(mod.Name(), col.Name()))
}
}

return nil
}

func toErrorString(err error) string {
if err == nil {
return ""
Expand Down Expand Up @@ -225,6 +265,8 @@ func init() {
checkCmd.Flags().Bool("hir", false, "check at HIR level")
checkCmd.Flags().Bool("mir", false, "check at MIR level")
checkCmd.Flags().Bool("air", false, "check at AIR level")
checkCmd.Flags().Bool("warn", false, "report warnings instead of failing for certain errors"+
"(e.g. unknown columns in the trace)")
checkCmd.Flags().Uint("padding", 0, "specify amount of (front) padding to apply")
checkCmd.Flags().Int("spillage", -1,
"specify amount of splillage to account for (where -1 indicates this should be inferred)")
Expand Down
19 changes: 13 additions & 6 deletions pkg/hir/schema.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package hir
import (
"fmt"

"github.com/consensys/go-corset/pkg/schema"
sc "github.com/consensys/go-corset/pkg/schema"
"github.com/consensys/go-corset/pkg/schema/assignment"
"github.com/consensys/go-corset/pkg/schema/constraint"
Expand Down Expand Up @@ -131,10 +132,13 @@ func (p *Schema) AddPropertyAssertion(module uint, handle string, property Expr)
// Schema Interface
// ============================================================================

// Inputs returns an array over the input declarations of this sc. That is,
// the subset of declarations whose trace values must be provided by the user.
func (p *Schema) Inputs() util.Iterator[sc.Declaration] {
return util.NewArrayIterator(p.inputs)
// InputColumns returns an array over the input columns of this schema. That
// is, the subset of columns whose trace values must be provided by the
// user.
func (p *Schema) InputColumns() util.Iterator[sc.Column] {
inputs := util.NewArrayIterator(p.inputs)
return util.NewFlattenIterator[schema.Declaration, schema.Column](inputs,
func(d schema.Declaration) util.Iterator[schema.Column] { return d.Columns() })
}

// Assignments returns an array over the assignments of this sc. That
Expand All @@ -147,7 +151,8 @@ func (p *Schema) Assignments() util.Iterator[sc.Assignment] {
// Columns returns an array over the underlying columns of this sc.
// Specifically, the index of a column in this array is its column index.
func (p *Schema) Columns() util.Iterator[sc.Column] {
is := util.NewFlattenIterator[sc.Declaration, sc.Column](p.Inputs(),
inputs := util.NewArrayIterator(p.inputs)
is := util.NewFlattenIterator[sc.Declaration, sc.Column](inputs,
func(d sc.Declaration) util.Iterator[sc.Column] { return d.Columns() })
ps := util.NewFlattenIterator[sc.Assignment, sc.Column](p.Assignments(),
func(d sc.Assignment) util.Iterator[sc.Column] { return d.Columns() })
Expand All @@ -164,8 +169,10 @@ func (p *Schema) Constraints() util.Iterator[sc.Constraint] {
// Declarations returns an array over the column declarations of this
// sc.
func (p *Schema) Declarations() util.Iterator[sc.Declaration] {
inputs := util.NewArrayIterator(p.inputs)
ps := util.NewCastIterator[sc.Assignment, sc.Declaration](p.Assignments())
return p.Inputs().Append(ps)

return inputs.Append(ps)
}

// Modules returns an iterator over the declared set of modules within this
Expand Down
18 changes: 12 additions & 6 deletions pkg/mir/schema.go
Original file line number Diff line number Diff line change
Expand Up @@ -128,10 +128,13 @@ func (p *Schema) AddPropertyAssertion(module uint, handle string, expr Expr) {
// Schema Interface
// ============================================================================

// Inputs returns an array over the input declarations of this schema. That is,
// the subset of declarations whose trace values must be provided by the user.
func (p *Schema) Inputs() util.Iterator[schema.Declaration] {
return util.NewArrayIterator(p.inputs)
// InputColumns returns an array over the input columns of this schema. That
// is, the subset of columns whose trace values must be provided by the
// user.
func (p *Schema) InputColumns() util.Iterator[schema.Column] {
inputs := util.NewArrayIterator(p.inputs)
return util.NewFlattenIterator[schema.Declaration, schema.Column](inputs,
func(d schema.Declaration) util.Iterator[schema.Column] { return d.Columns() })
}

// Assignments returns an array over the assignments of this schema. That
Expand All @@ -144,7 +147,8 @@ func (p *Schema) Assignments() util.Iterator[schema.Assignment] {
// Columns returns an array over the underlying columns of this schema.
// Specifically, the index of a column in this array is its column index.
func (p *Schema) Columns() util.Iterator[schema.Column] {
is := util.NewFlattenIterator[schema.Declaration, schema.Column](p.Inputs(),
inputs := util.NewArrayIterator(p.inputs)
is := util.NewFlattenIterator[schema.Declaration, schema.Column](inputs,
func(d schema.Declaration) util.Iterator[schema.Column] { return d.Columns() })
ps := util.NewFlattenIterator[schema.Assignment, schema.Column](p.Assignments(),
func(d schema.Assignment) util.Iterator[schema.Column] { return d.Columns() })
Expand All @@ -161,8 +165,10 @@ func (p *Schema) Constraints() util.Iterator[schema.Constraint] {
// Declarations returns an array over the column declarations of this
// schema.
func (p *Schema) Declarations() util.Iterator[schema.Declaration] {
inputs := util.NewArrayIterator(p.inputs)
ps := util.NewCastIterator[schema.Assignment, schema.Declaration](p.Assignments())
return p.Inputs().Append(ps)

return inputs.Append(ps)
}

// Modules returns an iterator over the declared set of modules within this
Expand Down
33 changes: 17 additions & 16 deletions pkg/schema/alignment.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,10 @@ func Align(p tr.Trace, schema Schema) error {
// unexpanded mode, the trace is only expected to contain input (i.e.
// non-computed) columns. Furthermore, in the schema these are expected to be
// allocated before computed columns. As such, alignment of these input
// columns is performed.
// columns is performed. Finally, it is worth noting that alignment can succeed
// when there are more trace columns than schema columns. In such case, the
// common columns are aligned at the beginning of the index space, whilst the
// remainder come at the end.
func alignWith(expand bool, p tr.Trace, schema Schema) error {
columns := p.Columns()
modules := p.Modules()
Expand Down Expand Up @@ -76,9 +79,10 @@ func alignWith(expand bool, p tr.Trace, schema Schema) error {
// Extract schema column & module
schemaCol := j.Next()
schemaMod := schema.Modules().Nth(schemaCol.Context().Module())
schemaQualifiedCol := QualifiedColumnName(schemaMod.Name(), schemaCol.Name())
// Sanity check column exists
if colIndex >= ncols {
return fmt.Errorf("trace missing column %s.%s (too few columns)", schemaMod.Name(), schemaCol.Name())
return fmt.Errorf("missing column %s (too few columns)", schemaQualifiedCol)
}
// Extract trace column and module
traceCol := columns.Get(colIndex)
Expand All @@ -89,7 +93,7 @@ func alignWith(expand bool, p tr.Trace, schema Schema) error {
k, ok := p.Columns().IndexOf(schemaCol.Context().Module(), schemaCol.Name())
// check exists
if !ok {
return fmt.Errorf("trace missing column %s.%s", schemaMod.Name(), schemaCol.Name())
return fmt.Errorf("missing column %s", schemaQualifiedCol)
}
// Swap columns
columns.Swap(colIndex, k)
Expand All @@ -99,18 +103,15 @@ func alignWith(expand bool, p tr.Trace, schema Schema) error {
}
}
}
// Check whether all columns matched
if colIndex == ncols {
// Yes, alignment complete.
return nil
}
// Error Case.
n := ncols - colIndex
unknowns := make([]string, n)
// Determine names of unknown columns.
for i := colIndex; i < ncols; i++ {
unknowns[i-colIndex] = columns.Get(i).Name()
// Alignment complete.
return nil
}

// QualifiedColumnName returns the fully qualified name of a given column.
func QualifiedColumnName(module string, column string) string {
if module == "" {
return column
}
//
return fmt.Errorf("trace contains unknown columns: %v", unknowns)

return fmt.Sprintf("%s.%s", module, column)
}
3 changes: 3 additions & 0 deletions pkg/schema/schema.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,9 @@ type Schema interface {
// schema.
Declarations() util.Iterator[Declaration]

// Iterator over the input (i.e. non-computed) columns of the schema.
InputColumns() util.Iterator[Column]

// Modules returns an iterator over the declared set of modules within this
// schema.
Modules() util.Iterator[Module]
Expand Down
36 changes: 28 additions & 8 deletions pkg/test/ir_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -487,8 +487,15 @@ func Check(t *testing.T, test string) {
// Check a given set of tests have an expected outcome (i.e. are
// either accepted or rejected) by a given set of constraints.
func CheckTraces(t *testing.T, test string, expected bool, traces []trace.Trace, hirSchema *hir.Schema) {
filename := getTestFileName(test, expected)
// Determine the number of expected input columns.
nSchemaInputs := hirSchema.InputColumns().Count()

//
for i, tr := range traces {
if tr != nil {
nTraceCols := tr.Columns().Len()
//
for padding := uint(0); padding <= MAX_PADDING; padding++ {
// Lower HIR => MIR
mirSchema := hirSchema.LowerToMir()
Expand All @@ -499,16 +506,21 @@ func CheckTraces(t *testing.T, test string, expected bool, traces []trace.Trace,
mirID := traceId{"MIR", test, expected, i + 1, padding, schema.RequiredSpillage(mirSchema)}
airID := traceId{"AIR", test, expected, i + 1, padding, schema.RequiredSpillage(airSchema)}
// Check whether trace is input compatible with schema
if err := sc.AlignInputs(tr, hirSchema); err != nil {
// Alignment failed. So, attempt alignment as expanded
// trace instead.
if err := sc.Align(tr, airSchema); err != nil {
// Still failed, hence trace must be malformed in some way
if expected {
t.Errorf("Trace malformed (%s.accepts, line %d): [%s]", test, i+1, err)
if err := sc.AlignInputs(tr, hirSchema); err != nil || nSchemaInputs != nTraceCols {
nSchemaCols := airSchema.Columns().Count()
// Alignment failed. So, attempt alignment as expanded trace instead.
if err := sc.Align(tr, airSchema); err != nil || nSchemaCols != nTraceCols {
var msg string
// Still failed, hence trace must be malformed in some way.
if err != nil {
msg = err.Error()
} else {
t.Errorf("Trace malformed (%s.rejects, line %d): [%s]", test, i+1, err)
col := tr.Columns().Get(nSchemaCols)
mod := tr.Modules().Get(col.Context().Module())
msg = fmt.Sprintf("unknown column %s", schema.QualifiedColumnName(mod.Name(), col.Name()))
}

t.Errorf("Malformed (expanded) trace (%s, line %d): %s", filename, i+1, msg)
} else {
// Aligned as expanded trace
checkExpandedTrace(t, tr, airID, airSchema)
Expand All @@ -524,6 +536,14 @@ func CheckTraces(t *testing.T, test string, expected bool, traces []trace.Trace,
}
}

func getTestFileName(test string, expected bool) string {
if expected {
return fmt.Sprintf("%s.accepts", test)
}

return fmt.Sprintf("%s.rejects", test)
}

func checkInputTrace(t *testing.T, tr trace.Trace, id traceId, schema sc.Schema) {
// Clone trace (to ensure expansion does not affect subsequent tests)
etr := tr.Clone()
Expand Down