From 8b1d43ea222c6149a09b92048c19be37f91f16f2 Mon Sep 17 00:00:00 2001 From: DavePearce Date: Wed, 10 Jul 2024 14:37:46 +1200 Subject: [PATCH 1/3] Update Alignment Algorithm This makes a small (but important) change to the alignment algorithm. Specifically, it no longer reports an error if there are unknown columns in the trace. Instead, the caller is responsible for checking this and reacting accordingly. This enables the caller to continue if it so wishes. --- pkg/air/schema.go | 18 ++++++++++++------ pkg/cmd/check.go | 15 +++++++++++++++ pkg/hir/schema.go | 19 +++++++++++++------ pkg/mir/schema.go | 18 ++++++++++++------ pkg/schema/alignment.go | 33 +++++++++++++++++---------------- pkg/schema/schema.go | 3 +++ pkg/test/ir_test.go | 36 ++++++++++++++++++++++++++++-------- 7 files changed, 100 insertions(+), 42 deletions(-) diff --git a/pkg/air/schema.go b/pkg/air/schema.go index bcfdefa..b0763ee 100644 --- a/pkg/air/schema.go +++ b/pkg/air/schema.go @@ -133,10 +133,13 @@ func (p *Schema) AddRangeConstraint(column uint, bound *fr.Element) { // Schema Interface // ============================================================================ -// Inputs returns an array over the input declarations of this schema. That is, -// the subset of declarations whose trace values must be provided by the user. -func (p *Schema) Inputs() util.Iterator[schema.Declaration] { - return util.NewArrayIterator(p.inputs) +// InputColumns returns an array over the input columns of this schema. That +// is, the subset of columns whose trace values must be provided by the +// user. +func (p *Schema) InputColumns() util.Iterator[schema.Column] { + inputs := util.NewArrayIterator(p.inputs) + return util.NewFlattenIterator[schema.Declaration, schema.Column](inputs, + func(d schema.Declaration) util.Iterator[schema.Column] { return d.Columns() }) } // Assignments returns an array over the assignments of this schema. That @@ -149,7 +152,8 @@ func (p *Schema) Assignments() util.Iterator[schema.Assignment] { // Columns returns an array over the underlying columns of this schema. // Specifically, the index of a column in this array is its column index. func (p *Schema) Columns() util.Iterator[schema.Column] { - is := util.NewFlattenIterator[schema.Declaration, schema.Column](p.Inputs(), + inputs := util.NewArrayIterator(p.inputs) + is := util.NewFlattenIterator[schema.Declaration, schema.Column](inputs, func(d schema.Declaration) util.Iterator[schema.Column] { return d.Columns() }) ps := util.NewFlattenIterator[schema.Assignment, schema.Column](p.Assignments(), func(d schema.Assignment) util.Iterator[schema.Column] { return d.Columns() }) @@ -166,8 +170,10 @@ func (p *Schema) Constraints() util.Iterator[schema.Constraint] { // Declarations returns an array over the column declarations of this // schema. func (p *Schema) Declarations() util.Iterator[schema.Declaration] { + inputs := util.NewArrayIterator(p.inputs) ps := util.NewCastIterator[schema.Assignment, schema.Declaration](p.Assignments()) - return p.Inputs().Append(ps) + + return inputs.Append(ps) } // Modules returns an iterator over the declared set of modules within this diff --git a/pkg/cmd/check.go b/pkg/cmd/check.go index 9303ed6..c4a84cf 100644 --- a/pkg/cmd/check.go +++ b/pkg/cmd/check.go @@ -154,6 +154,11 @@ func checkTraceWithLoweringDefault(tr trace.Trace, hirSchema *hir.Schema, cfg ch } func checkTrace(tr trace.Trace, schema sc.Schema, cfg checkConfig) (trace.Trace, error) { + // Determine the number of expected input columns. + nSchemaInputs := schema.InputColumns().Count() + // Determine the number of expected columns. + nSchemaCols := schema.Columns().Count() + if cfg.expand { // Clone to prevent interefence with subsequent checks tr = tr.Clone() @@ -168,6 +173,11 @@ func checkTrace(tr trace.Trace, schema sc.Schema, cfg checkConfig) (trace.Trace, // Perform Input Alignment if err := sc.AlignInputs(tr, schema); err != nil { return tr, err + } else if nSchemaInputs != tr.Columns().Len() { + col := tr.Columns().Get(nSchemaInputs) + mod := tr.Modules().Get(col.Context().Module()) + + return tr, fmt.Errorf("unknown trace column: %s", sc.QualifiedColumnName(mod.Name(), col.Name())) } // Expand trace if err := sc.ExpandTrace(schema, tr); err != nil { @@ -177,6 +187,11 @@ func checkTrace(tr trace.Trace, schema sc.Schema, cfg checkConfig) (trace.Trace, // Perform Alignment if err := sc.Align(tr, schema); err != nil { return tr, err + } else if nSchemaCols != tr.Columns().Len() { + col := tr.Columns().Get(nSchemaCols) + mod := tr.Modules().Get(col.Context().Module()) + + return tr, fmt.Errorf("unknown (expanded) trace column: %s", sc.QualifiedColumnName(mod.Name(), col.Name())) } // Check whether padding requested if cfg.padding.Left == 0 && cfg.padding.Right == 0 { diff --git a/pkg/hir/schema.go b/pkg/hir/schema.go index 9d8e560..8f6147f 100644 --- a/pkg/hir/schema.go +++ b/pkg/hir/schema.go @@ -3,6 +3,7 @@ package hir import ( "fmt" + "github.com/consensys/go-corset/pkg/schema" sc "github.com/consensys/go-corset/pkg/schema" "github.com/consensys/go-corset/pkg/schema/assignment" "github.com/consensys/go-corset/pkg/schema/constraint" @@ -131,10 +132,13 @@ func (p *Schema) AddPropertyAssertion(module uint, handle string, property Expr) // Schema Interface // ============================================================================ -// Inputs returns an array over the input declarations of this sc. That is, -// the subset of declarations whose trace values must be provided by the user. -func (p *Schema) Inputs() util.Iterator[sc.Declaration] { - return util.NewArrayIterator(p.inputs) +// InputColumns returns an array over the input columns of this schema. That +// is, the subset of columns whose trace values must be provided by the +// user. +func (p *Schema) InputColumns() util.Iterator[sc.Column] { + inputs := util.NewArrayIterator(p.inputs) + return util.NewFlattenIterator[schema.Declaration, schema.Column](inputs, + func(d schema.Declaration) util.Iterator[schema.Column] { return d.Columns() }) } // Assignments returns an array over the assignments of this sc. That @@ -147,7 +151,8 @@ func (p *Schema) Assignments() util.Iterator[sc.Assignment] { // Columns returns an array over the underlying columns of this sc. // Specifically, the index of a column in this array is its column index. func (p *Schema) Columns() util.Iterator[sc.Column] { - is := util.NewFlattenIterator[sc.Declaration, sc.Column](p.Inputs(), + inputs := util.NewArrayIterator(p.inputs) + is := util.NewFlattenIterator[sc.Declaration, sc.Column](inputs, func(d sc.Declaration) util.Iterator[sc.Column] { return d.Columns() }) ps := util.NewFlattenIterator[sc.Assignment, sc.Column](p.Assignments(), func(d sc.Assignment) util.Iterator[sc.Column] { return d.Columns() }) @@ -164,8 +169,10 @@ func (p *Schema) Constraints() util.Iterator[sc.Constraint] { // Declarations returns an array over the column declarations of this // sc. func (p *Schema) Declarations() util.Iterator[sc.Declaration] { + inputs := util.NewArrayIterator(p.inputs) ps := util.NewCastIterator[sc.Assignment, sc.Declaration](p.Assignments()) - return p.Inputs().Append(ps) + + return inputs.Append(ps) } // Modules returns an iterator over the declared set of modules within this diff --git a/pkg/mir/schema.go b/pkg/mir/schema.go index 0a2d635..82df9b3 100644 --- a/pkg/mir/schema.go +++ b/pkg/mir/schema.go @@ -128,10 +128,13 @@ func (p *Schema) AddPropertyAssertion(module uint, handle string, expr Expr) { // Schema Interface // ============================================================================ -// Inputs returns an array over the input declarations of this schema. That is, -// the subset of declarations whose trace values must be provided by the user. -func (p *Schema) Inputs() util.Iterator[schema.Declaration] { - return util.NewArrayIterator(p.inputs) +// InputColumns returns an array over the input columns of this schema. That +// is, the subset of columns whose trace values must be provided by the +// user. +func (p *Schema) InputColumns() util.Iterator[schema.Column] { + inputs := util.NewArrayIterator(p.inputs) + return util.NewFlattenIterator[schema.Declaration, schema.Column](inputs, + func(d schema.Declaration) util.Iterator[schema.Column] { return d.Columns() }) } // Assignments returns an array over the assignments of this schema. That @@ -144,7 +147,8 @@ func (p *Schema) Assignments() util.Iterator[schema.Assignment] { // Columns returns an array over the underlying columns of this schema. // Specifically, the index of a column in this array is its column index. func (p *Schema) Columns() util.Iterator[schema.Column] { - is := util.NewFlattenIterator[schema.Declaration, schema.Column](p.Inputs(), + inputs := util.NewArrayIterator(p.inputs) + is := util.NewFlattenIterator[schema.Declaration, schema.Column](inputs, func(d schema.Declaration) util.Iterator[schema.Column] { return d.Columns() }) ps := util.NewFlattenIterator[schema.Assignment, schema.Column](p.Assignments(), func(d schema.Assignment) util.Iterator[schema.Column] { return d.Columns() }) @@ -161,8 +165,10 @@ func (p *Schema) Constraints() util.Iterator[schema.Constraint] { // Declarations returns an array over the column declarations of this // schema. func (p *Schema) Declarations() util.Iterator[schema.Declaration] { + inputs := util.NewArrayIterator(p.inputs) ps := util.NewCastIterator[schema.Assignment, schema.Declaration](p.Assignments()) - return p.Inputs().Append(ps) + + return inputs.Append(ps) } // Modules returns an iterator over the declared set of modules within this diff --git a/pkg/schema/alignment.go b/pkg/schema/alignment.go index e44c0f9..5bdc295 100644 --- a/pkg/schema/alignment.go +++ b/pkg/schema/alignment.go @@ -30,7 +30,10 @@ func Align(p tr.Trace, schema Schema) error { // unexpanded mode, the trace is only expected to contain input (i.e. // non-computed) columns. Furthermore, in the schema these are expected to be // allocated before computed columns. As such, alignment of these input -// columns is performed. +// columns is performed. Finally, it is worth noting that alignment can succeed +// when there are more trace columns than schema columns. In such case, the +// common columns are aligned at the beginning of the index space, whilst the +// remainder come at the end. func alignWith(expand bool, p tr.Trace, schema Schema) error { columns := p.Columns() modules := p.Modules() @@ -76,9 +79,10 @@ func alignWith(expand bool, p tr.Trace, schema Schema) error { // Extract schema column & module schemaCol := j.Next() schemaMod := schema.Modules().Nth(schemaCol.Context().Module()) + schemaQualifiedCol := QualifiedColumnName(schemaMod.Name(), schemaCol.Name()) // Sanity check column exists if colIndex >= ncols { - return fmt.Errorf("trace missing column %s.%s (too few columns)", schemaMod.Name(), schemaCol.Name()) + return fmt.Errorf("missing column %s (too few columns)", schemaQualifiedCol) } // Extract trace column and module traceCol := columns.Get(colIndex) @@ -89,7 +93,7 @@ func alignWith(expand bool, p tr.Trace, schema Schema) error { k, ok := p.Columns().IndexOf(schemaCol.Context().Module(), schemaCol.Name()) // check exists if !ok { - return fmt.Errorf("trace missing column %s.%s", schemaMod.Name(), schemaCol.Name()) + return fmt.Errorf("missing column %s", schemaQualifiedCol) } // Swap columns columns.Swap(colIndex, k) @@ -99,18 +103,15 @@ func alignWith(expand bool, p tr.Trace, schema Schema) error { } } } - // Check whether all columns matched - if colIndex == ncols { - // Yes, alignment complete. - return nil - } - // Error Case. - n := ncols - colIndex - unknowns := make([]string, n) - // Determine names of unknown columns. - for i := colIndex; i < ncols; i++ { - unknowns[i-colIndex] = columns.Get(i).Name() + // Alignment complete. + return nil +} + +// QualifiedColumnName returns the fully qualified name of a given column. +func QualifiedColumnName(module string, column string) string { + if module == "" { + return column } - // - return fmt.Errorf("trace contains unknown columns: %v", unknowns) + + return fmt.Sprintf("%s.%s", module, column) } diff --git a/pkg/schema/schema.go b/pkg/schema/schema.go index d1599bf..f6538b7 100644 --- a/pkg/schema/schema.go +++ b/pkg/schema/schema.go @@ -27,6 +27,9 @@ type Schema interface { // schema. Declarations() util.Iterator[Declaration] + // Iterator over the input (i.e. non-computed) columns of the schema. + InputColumns() util.Iterator[Column] + // Modules returns an iterator over the declared set of modules within this // schema. Modules() util.Iterator[Module] diff --git a/pkg/test/ir_test.go b/pkg/test/ir_test.go index 01ac4dc..ace099d 100644 --- a/pkg/test/ir_test.go +++ b/pkg/test/ir_test.go @@ -487,8 +487,15 @@ func Check(t *testing.T, test string) { // Check a given set of tests have an expected outcome (i.e. are // either accepted or rejected) by a given set of constraints. func CheckTraces(t *testing.T, test string, expected bool, traces []trace.Trace, hirSchema *hir.Schema) { + filename := getTestFileName(test, expected) + // Determine the number of expected input columns. + nSchemaInputs := hirSchema.InputColumns().Count() + + // for i, tr := range traces { if tr != nil { + nTraceCols := tr.Columns().Len() + // for padding := uint(0); padding <= MAX_PADDING; padding++ { // Lower HIR => MIR mirSchema := hirSchema.LowerToMir() @@ -499,16 +506,21 @@ func CheckTraces(t *testing.T, test string, expected bool, traces []trace.Trace, mirID := traceId{"MIR", test, expected, i + 1, padding, schema.RequiredSpillage(mirSchema)} airID := traceId{"AIR", test, expected, i + 1, padding, schema.RequiredSpillage(airSchema)} // Check whether trace is input compatible with schema - if err := sc.AlignInputs(tr, hirSchema); err != nil { - // Alignment failed. So, attempt alignment as expanded - // trace instead. - if err := sc.Align(tr, airSchema); err != nil { - // Still failed, hence trace must be malformed in some way - if expected { - t.Errorf("Trace malformed (%s.accepts, line %d): [%s]", test, i+1, err) + if err := sc.AlignInputs(tr, hirSchema); err != nil || nSchemaInputs != nTraceCols { + nSchemaCols := airSchema.Columns().Count() + // Alignment failed. So, attempt alignment as expanded trace instead. + if err := sc.Align(tr, airSchema); err != nil || nSchemaCols != nTraceCols { + var msg string + // Still failed, hence trace must be malformed in some way. + if err != nil { + msg = err.Error() } else { - t.Errorf("Trace malformed (%s.rejects, line %d): [%s]", test, i+1, err) + col := tr.Columns().Get(nSchemaCols) + mod := tr.Modules().Get(col.Context().Module()) + msg = fmt.Sprintf("unknown column %s", schema.QualifiedColumnName(mod.Name(), col.Name())) } + + t.Errorf("Malformed (expanded) trace (%s, line %d): %s", filename, i+1, msg) } else { // Aligned as expanded trace checkExpandedTrace(t, tr, airID, airSchema) @@ -524,6 +536,14 @@ func CheckTraces(t *testing.T, test string, expected bool, traces []trace.Trace, } } +func getTestFileName(test string, expected bool) string { + if expected { + return fmt.Sprintf("%s.accepts", test) + } + + return fmt.Sprintf("%s.rejects", test) +} + func checkInputTrace(t *testing.T, tr trace.Trace, id traceId, schema sc.Schema) { // Clone trace (to ensure expansion does not affect subsequent tests) etr := tr.Clone() From 0c5b9a62bab344b243e29004f1b8be92a60893fe Mon Sep 17 00:00:00 2001 From: DavePearce Date: Wed, 10 Jul 2024 14:45:59 +1200 Subject: [PATCH 2/3] Support `--warn` in `check` command This adds the ability to override the default and reduce the level of strictness when it comes to unknown columns in the trace. Specifically, with `--warn` the tool will continue with a warning when an unknown trace column is encountered. --- pkg/cmd/check.go | 25 ++++++++++++++++++++----- 1 file changed, 20 insertions(+), 5 deletions(-) diff --git a/pkg/cmd/check.go b/pkg/cmd/check.go index c4a84cf..2177bf7 100644 --- a/pkg/cmd/check.go +++ b/pkg/cmd/check.go @@ -32,6 +32,7 @@ var checkCmd = &cobra.Command{ cfg.expand = !getFlag(cmd, "raw") cfg.report = getFlag(cmd, "report") cfg.spillage = getInt(cmd, "spillage") + cfg.strict = !getFlag(cmd, "warn") cfg.padding.Right = getUint(cmd, "padding") // TODO: support true ranges cfg.padding.Left = cfg.padding.Right @@ -59,6 +60,10 @@ type checkConfig struct { spillage int // Determines how much padding to use padding util.Pair[uint, uint] + // Specified whether strict checking is performed or not. This is enabled + // by default, and ensures the tool fails with an error in any unexpected or + // unusual case. + strict bool // Specifies whether or not to perform trace expansion. Trace expansion is // not required when a "raw" trace is given which already includes all // implied columns. @@ -176,8 +181,12 @@ func checkTrace(tr trace.Trace, schema sc.Schema, cfg checkConfig) (trace.Trace, } else if nSchemaInputs != tr.Columns().Len() { col := tr.Columns().Get(nSchemaInputs) mod := tr.Modules().Get(col.Context().Module()) - - return tr, fmt.Errorf("unknown trace column: %s", sc.QualifiedColumnName(mod.Name(), col.Name())) + // Choose error or warning + if cfg.strict { + return tr, fmt.Errorf("unknown trace column: %s", sc.QualifiedColumnName(mod.Name(), col.Name())) + } + // Log warning + fmt.Printf("[WARNING] unknown trace column: %s\n", sc.QualifiedColumnName(mod.Name(), col.Name())) } // Expand trace if err := sc.ExpandTrace(schema, tr); err != nil { @@ -187,11 +196,15 @@ func checkTrace(tr trace.Trace, schema sc.Schema, cfg checkConfig) (trace.Trace, // Perform Alignment if err := sc.Align(tr, schema); err != nil { return tr, err - } else if nSchemaCols != tr.Columns().Len() { + } else if cfg.strict && nSchemaCols != tr.Columns().Len() { col := tr.Columns().Get(nSchemaCols) mod := tr.Modules().Get(col.Context().Module()) - - return tr, fmt.Errorf("unknown (expanded) trace column: %s", sc.QualifiedColumnName(mod.Name(), col.Name())) + // Choose error or warning + if cfg.strict { + return tr, fmt.Errorf("unknown (expanded) trace column: %s", sc.QualifiedColumnName(mod.Name(), col.Name())) + } + // Log warning + fmt.Printf("[WARNING] unknown trace column: %s\n", sc.QualifiedColumnName(mod.Name(), col.Name())) } // Check whether padding requested if cfg.padding.Left == 0 && cfg.padding.Right == 0 { @@ -240,6 +253,8 @@ func init() { checkCmd.Flags().Bool("hir", false, "check at HIR level") checkCmd.Flags().Bool("mir", false, "check at MIR level") checkCmd.Flags().Bool("air", false, "check at AIR level") + checkCmd.Flags().Bool("warn", false, "report warnings instead of failing for certain errors"+ + "(e.g. unknown columns in the trace)") checkCmd.Flags().Uint("padding", 0, "specify amount of (front) padding to apply") checkCmd.Flags().Int("spillage", -1, "specify amount of splillage to account for (where -1 indicates this should be inferred)") From 0a5f9499e970e26206d7fc8866d19122d5a3a998 Mon Sep 17 00:00:00 2001 From: DavePearce Date: Wed, 10 Jul 2024 14:56:44 +1200 Subject: [PATCH 3/3] Tidy up `check` command This tidies up the check command and improves the warnings. --- pkg/cmd/check.go | 62 +++++++++++++++++++++++++++++------------------- 1 file changed, 37 insertions(+), 25 deletions(-) diff --git a/pkg/cmd/check.go b/pkg/cmd/check.go index 2177bf7..286349b 100644 --- a/pkg/cmd/check.go +++ b/pkg/cmd/check.go @@ -159,11 +159,6 @@ func checkTraceWithLoweringDefault(tr trace.Trace, hirSchema *hir.Schema, cfg ch } func checkTrace(tr trace.Trace, schema sc.Schema, cfg checkConfig) (trace.Trace, error) { - // Determine the number of expected input columns. - nSchemaInputs := schema.InputColumns().Count() - // Determine the number of expected columns. - nSchemaCols := schema.Columns().Count() - if cfg.expand { // Clone to prevent interefence with subsequent checks tr = tr.Clone() @@ -176,17 +171,8 @@ func checkTrace(tr trace.Trace, schema sc.Schema, cfg checkConfig) (trace.Trace, trace.PadColumns(tr, sc.RequiredSpillage(schema)) } // Perform Input Alignment - if err := sc.AlignInputs(tr, schema); err != nil { + if err := performAlignment(true, tr, schema, cfg); err != nil { return tr, err - } else if nSchemaInputs != tr.Columns().Len() { - col := tr.Columns().Get(nSchemaInputs) - mod := tr.Modules().Get(col.Context().Module()) - // Choose error or warning - if cfg.strict { - return tr, fmt.Errorf("unknown trace column: %s", sc.QualifiedColumnName(mod.Name(), col.Name())) - } - // Log warning - fmt.Printf("[WARNING] unknown trace column: %s\n", sc.QualifiedColumnName(mod.Name(), col.Name())) } // Expand trace if err := sc.ExpandTrace(schema, tr); err != nil { @@ -194,17 +180,8 @@ func checkTrace(tr trace.Trace, schema sc.Schema, cfg checkConfig) (trace.Trace, } } // Perform Alignment - if err := sc.Align(tr, schema); err != nil { + if err := performAlignment(false, tr, schema, cfg); err != nil { return tr, err - } else if cfg.strict && nSchemaCols != tr.Columns().Len() { - col := tr.Columns().Get(nSchemaCols) - mod := tr.Modules().Get(col.Context().Module()) - // Choose error or warning - if cfg.strict { - return tr, fmt.Errorf("unknown (expanded) trace column: %s", sc.QualifiedColumnName(mod.Name(), col.Name())) - } - // Log warning - fmt.Printf("[WARNING] unknown trace column: %s\n", sc.QualifiedColumnName(mod.Name(), col.Name())) } // Check whether padding requested if cfg.padding.Left == 0 && cfg.padding.Right == 0 { @@ -226,6 +203,41 @@ func checkTrace(tr trace.Trace, schema sc.Schema, cfg checkConfig) (trace.Trace, return nil, nil } +// Run the alignment algorithm with optional checks determined by the configuration. +func performAlignment(inputs bool, tr trace.Trace, schema sc.Schema, cfg checkConfig) error { + var err error + + var nSchemaCols uint + // Determine number of trace columns + nTraceCols := tr.Columns().Len() + + if inputs { + nSchemaCols = schema.InputColumns().Count() + err = sc.AlignInputs(tr, schema) + } else { + nSchemaCols = schema.Columns().Count() + err = sc.Align(tr, schema) + } + // Sanity check error + if err != nil { + return err + } else if cfg.strict && nSchemaCols != nTraceCols { + col := tr.Columns().Get(nSchemaCols) + mod := tr.Modules().Get(col.Context().Module()) + // Return error + return fmt.Errorf("unknown trace column %s", sc.QualifiedColumnName(mod.Name(), col.Name())) + } else if nSchemaCols != nTraceCols { + // Log warning + for i := nSchemaCols; i < nTraceCols; i++ { + col := tr.Columns().Get(i) + mod := tr.Modules().Get(col.Context().Module()) + fmt.Printf("[WARNING] unknown trace column %s\n", sc.QualifiedColumnName(mod.Name(), col.Name())) + } + } + + return nil +} + func toErrorString(err error) string { if err == nil { return ""