diff --git a/pkg/corset/binding.go b/pkg/corset/binding.go index 5e8067e..b218d0d 100644 --- a/pkg/corset/binding.go +++ b/pkg/corset/binding.go @@ -4,6 +4,7 @@ import ( "math" "reflect" + "github.com/consensys/go-corset/pkg/sexp" tr "github.com/consensys/go-corset/pkg/trace" ) @@ -128,14 +129,14 @@ func (p *FunctionSignature) SubtypeOf(other *FunctionSignature) bool { // Apply a set of concreate arguments to this function. This substitutes // them through the body of the function producing a single expression. -func (p *FunctionSignature) Apply(args []Expr) Expr { +func (p *FunctionSignature) Apply(args []Expr, srcmap *sexp.SourceMaps[Node]) Expr { mapping := make(map[uint]Expr) // Setup the mapping for i, e := range args { mapping[uint(i)] = e } // Substitute through - return p.body.Substitute(mapping) + return Substitute(p.body, mapping, srcmap) } // ============================================================================ diff --git a/pkg/corset/compiler.go b/pkg/corset/compiler.go index 7af5b8c..c084d98 100644 --- a/pkg/corset/compiler.go +++ b/pkg/corset/compiler.go @@ -87,10 +87,14 @@ func (p *Compiler) Compile() (*hir.Schema, []SyntaxError) { if len(errs) != 0 { return nil, errs } + // Preprocess circuit to remove invocations, reductions, etc. + if errs := PreprocessCircuit(p.debug, p.srcmap, &p.circuit); len(errs) > 0 { + return nil, errs + } // Convert global scope into an environment by allocating all columns. environment := scope.ToEnvironment() // Finally, translate everything and add it to the schema. - return TranslateCircuit(environment, p.debug, p.srcmap, &p.circuit) + return TranslateCircuit(environment, p.srcmap, &p.circuit) } func includeStdlib(stdlib bool, srcfiles []*sexp.SourceFile) []*sexp.SourceFile { diff --git a/pkg/corset/expression.go b/pkg/corset/expression.go index 986c156..cb22732 100644 --- a/pkg/corset/expression.go +++ b/pkg/corset/expression.go @@ -3,6 +3,7 @@ package corset import ( "fmt" "math/big" + "reflect" "github.com/consensys/go-corset/pkg/sexp" tr "github.com/consensys/go-corset/pkg/trace" @@ -26,16 +27,10 @@ type Expr interface { // lists return one value for each element in the list. Note, every // expression must return at least one value. Multiplicity() uint - // Context returns the context for this expression. Observe that the // expression must have been resolved for this to be defined (i.e. it may // panic if it has not been resolved yet). Context() Context - - // Substitute all variables (such as for function parameters) arising in - // this expression. - Substitute(mapping map[uint]Expr) Expr - // Return set of columns on which this declaration depends. Dependencies() []Symbol } @@ -76,12 +71,6 @@ func (e *Add) Lisp() sexp.SExp { return ListOfExpressions(sexp.NewSymbol("+"), e.Args) } -// Substitute all variables (such as for function parameters) arising in -// this expression. -func (e *Add) Substitute(mapping map[uint]Expr) Expr { - return &Add{SubstituteExpressions(e.Args, mapping)} -} - // Dependencies needed to signal declaration. func (e *Add) Dependencies() []Symbol { return DependenciesOfExpressions(e.Args) @@ -164,12 +153,6 @@ func (e *ArrayAccess) Lisp() sexp.SExp { }) } -// Substitute all variables (such as for function parameters) arising in -// this expression. -func (e *ArrayAccess) Substitute(mapping map[uint]Expr) Expr { - return &ArrayAccess{e.name, e.arg.Substitute(mapping), e.binding} -} - // Resolve this symbol by associating it with the binding associated with // the definition of the symbol to which this refers. func (e *ArrayAccess) Resolve(binding Binding) bool { @@ -222,12 +205,6 @@ func (e *Constant) Lisp() sexp.SExp { return sexp.NewSymbol(e.Val.String()) } -// Substitute all variables (such as for function parameters) arising in -// this expression. -func (e *Constant) Substitute(mapping map[uint]Expr) Expr { - return e -} - // Dependencies needed to signal declaration. func (e *Constant) Dependencies() []Symbol { return nil @@ -271,12 +248,6 @@ func (e *Debug) Lisp() sexp.SExp { e.Arg.Lisp()}) } -// Substitute all variables (such as for function parameters) arising in -// this expression. -func (e *Debug) Substitute(mapping map[uint]Expr) Expr { - return &Debug{e.Arg.Substitute(mapping)} -} - // Dependencies needed to signal declaration. func (e *Debug) Dependencies() []Symbol { return e.Arg.Dependencies() @@ -327,12 +298,6 @@ func (e *Exp) Lisp() sexp.SExp { e.Pow.Lisp()}) } -// Substitute all variables (such as for function parameters) arising in -// this expression. -func (e *Exp) Substitute(mapping map[uint]Expr) Expr { - return &Exp{e.Arg.Substitute(mapping), e.Pow.Substitute(mapping)} -} - // Dependencies needed to signal declaration. func (e *Exp) Dependencies() []Symbol { return DependenciesOfExpressions([]Expr{e.Arg, e.Pow}) @@ -385,13 +350,6 @@ func (e *For) Lisp() sexp.SExp { panic("todo") } -// Substitute all variables (such as for function parameters) arising in -// this expression. -func (e *For) Substitute(mapping map[uint]Expr) Expr { - body := e.Body.Substitute(mapping) - return &For{e.Binding, e.Start, e.End, body} -} - // Dependencies needed to signal declaration. func (e *For) Dependencies() []Symbol { // Remove occurrences of the index variable defined by this expression. In @@ -493,15 +451,6 @@ func (e *If) Lisp() sexp.SExp { e.TrueBranch.Lisp()}) } -// Substitute all variables (such as for function parameters) arising in -// this expression. -func (e *If) Substitute(mapping map[uint]Expr) Expr { - return &If{e.kind, e.Condition.Substitute(mapping), - SubstituteOptionalExpression(e.TrueBranch, mapping), - SubstituteOptionalExpression(e.FalseBranch, mapping), - } -} - // Dependencies needed to signal declaration. func (e *If) Dependencies() []Symbol { return DependenciesOfExpressions([]Expr{e.Condition, e.TrueBranch, e.FalseBranch}) @@ -525,7 +474,7 @@ func (e *Invoke) AsConstant() *big.Int { panic("unresolved invocation") } // Unroll body - body := e.signature.Apply(e.args) + body := e.signature.Apply(e.args, nil) // Attempt to evaluate as constant return body.AsConstant() } @@ -561,12 +510,6 @@ func (e *Invoke) Finalise(signature *FunctionSignature) { e.signature = signature } -// Substitute all variables (such as for function parameters) arising in -// this expression. -func (e *Invoke) Substitute(mapping map[uint]Expr) Expr { - return &Invoke{e.fn, e.signature, SubstituteExpressions(e.args, mapping)} -} - // Dependencies needed to signal declaration. func (e *Invoke) Dependencies() []Symbol { deps := DependenciesOfExpressions(e.args) @@ -608,12 +551,6 @@ func (e *List) Lisp() sexp.SExp { return ListOfExpressions(sexp.NewSymbol("begin"), e.Args) } -// Substitute all variables (such as for function parameters) arising in -// this expression. -func (e *List) Substitute(mapping map[uint]Expr) Expr { - return &List{SubstituteExpressions(e.Args, mapping)} -} - // Dependencies needed to signal declaration. func (e *List) Dependencies() []Symbol { return DependenciesOfExpressions(e.Args) @@ -652,12 +589,6 @@ func (e *Mul) Lisp() sexp.SExp { return ListOfExpressions(sexp.NewSymbol("*"), e.Args) } -// Substitute all variables (such as for function parameters) arising in -// this expression. -func (e *Mul) Substitute(mapping map[uint]Expr) Expr { - return &Mul{SubstituteExpressions(e.Args, mapping)} -} - // Dependencies needed to signal declaration. func (e *Mul) Dependencies() []Symbol { return DependenciesOfExpressions(e.Args) @@ -699,12 +630,6 @@ func (e *Normalise) Lisp() sexp.SExp { e.Arg.Lisp()}) } -// Substitute all variables (such as for function parameters) arising in -// this expression. -func (e *Normalise) Substitute(mapping map[uint]Expr) Expr { - return &Normalise{e.Arg.Substitute(mapping)} -} - // Dependencies needed to signal declaration. func (e *Normalise) Dependencies() []Symbol { return e.Arg.Dependencies() @@ -750,16 +675,6 @@ func (e *Reduce) Lisp() sexp.SExp { e.arg.Lisp()}) } -// Substitute all variables (such as for function parameters) arising in -// this expression. -func (e *Reduce) Substitute(mapping map[uint]Expr) Expr { - return &Reduce{ - e.fn, - e.signature, - e.arg.Substitute(mapping), - } -} - // Finalise the signature for this reduction. func (e *Reduce) Finalise(signature *FunctionSignature) { if signature == nil { @@ -810,12 +725,6 @@ func (e *Sub) Lisp() sexp.SExp { return ListOfExpressions(sexp.NewSymbol("-"), e.Args) } -// Substitute all variables (such as for function parameters) arising in -// this expression. -func (e *Sub) Substitute(mapping map[uint]Expr) Expr { - return &Sub{SubstituteExpressions(e.Args, mapping)} -} - // Dependencies needed to signal declaration. func (e *Sub) Dependencies() []Symbol { return DependenciesOfExpressions(e.Args) @@ -867,12 +776,6 @@ func (e *Shift) Lisp() sexp.SExp { e.Shift.Lisp()}) } -// Substitute all variables (such as for function parameters) arising in -// this expression. -func (e *Shift) Substitute(mapping map[uint]Expr) Expr { - return &Shift{e.Arg.Substitute(mapping), e.Shift.Substitute(mapping)} -} - // Dependencies needed to signal declaration. func (e *Shift) Dependencies() []Symbol { return DependenciesOfExpressions([]Expr{e.Arg, e.Shift}) @@ -989,18 +892,6 @@ func (e *VariableAccess) Lisp() sexp.SExp { return sexp.NewSymbol(name) } -// Substitute all variables (such as for function parameters) arising in -// this expression. -func (e *VariableAccess) Substitute(mapping map[uint]Expr) Expr { - if b, ok1 := e.binding.(*LocalVariableBinding); ok1 { - if e, ok2 := mapping[b.index]; ok2 { - return e - } - } - // Nothing to do here - return e -} - // Dependencies needed to signal declaration. func (e *VariableAccess) Dependencies() []Symbol { return []Symbol{e} @@ -1025,23 +916,98 @@ func ContextOfExpressions(exprs []Expr) Context { return context } -// SubstituteExpressions substitutes all variables found in a given set of +// Substitute variables (such as for function parameters) in this expression +// based on a mapping of said variables to expressions. Furthermore, an +// (optional) source map is provided which will be updated, such that the +// freshly created expressions are mapped to their corresponding nodes. +func Substitute(expr Expr, mapping map[uint]Expr, srcmap *sexp.SourceMaps[Node]) Expr { + var nexpr Expr + // + switch e := expr.(type) { + case *ArrayAccess: + arg := Substitute(e.arg, mapping, srcmap) + nexpr = &ArrayAccess{e.name, arg, e.binding} + case *Add: + args := SubstituteAll(e.Args, mapping, srcmap) + nexpr = &Add{args} + case *Constant: + return e + case *Debug: + arg := Substitute(e.Arg, mapping, srcmap) + nexpr = &Debug{arg} + case *Exp: + arg := Substitute(e.Arg, mapping, srcmap) + pow := Substitute(e.Pow, mapping, srcmap) + // Done + nexpr = &Exp{arg, pow} + case *For: + body := Substitute(e.Body, mapping, srcmap) + nexpr = &For{e.Binding, e.Start, e.End, body} + case *If: + condition := Substitute(e.Condition, mapping, srcmap) + trueBranch := SubstituteOptional(e.TrueBranch, mapping, srcmap) + falseBranch := SubstituteOptional(e.FalseBranch, mapping, srcmap) + // Construct appropriate if form + nexpr = &If{e.kind, condition, trueBranch, falseBranch} + case *Invoke: + args := SubstituteAll(e.args, mapping, srcmap) + nexpr = &Invoke{e.fn, e.signature, args} + case *List: + args := SubstituteAll(e.Args, mapping, srcmap) + nexpr = &List{args} + case *Mul: + args := SubstituteAll(e.Args, mapping, srcmap) + nexpr = &Mul{args} + case *Normalise: + arg := Substitute(e.Arg, mapping, srcmap) + nexpr = &Normalise{arg} + case *Reduce: + arg := Substitute(e.arg, mapping, srcmap) + nexpr = &Reduce{e.fn, e.signature, arg} + case *Sub: + args := SubstituteAll(e.Args, mapping, srcmap) + nexpr = &Sub{args} + case *Shift: + arg := Substitute(e.Arg, mapping, srcmap) + nexpr = &Shift{arg, e.Shift} + case *VariableAccess: + // + if b, ok1 := e.binding.(*LocalVariableBinding); !ok1 { + return e + } else if e2, ok2 := mapping[b.index]; !ok2 { + return e + } else { + return e2 + } + default: + panic(fmt.Sprintf("unknown expression (%s)", reflect.TypeOf(expr))) + } + // + if srcmap != nil { + // Copy over source information + srcmap.Copy(expr, nexpr) + } + // Done + return nexpr +} + +// SubstituteAll substitutes all variables found in a given set of // expressions. -func SubstituteExpressions(exprs []Expr, mapping map[uint]Expr) []Expr { +func SubstituteAll(exprs []Expr, mapping map[uint]Expr, srcmap *sexp.SourceMaps[Node]) []Expr { nexprs := make([]Expr, len(exprs)) // for i := 0; i < len(nexprs); i++ { - nexprs[i] = exprs[i].Substitute(mapping) + nexprs[i] = Substitute(exprs[i], mapping, srcmap) } // return nexprs } -// SubstituteOptionalExpression substitutes through an expression which is +// SubstituteOptional substitutes through an expression which is // optional (i.e. might be nil). In such case, nil is returned. -func SubstituteOptionalExpression(expr Expr, mapping map[uint]Expr) Expr { +func SubstituteOptional(expr Expr, mapping map[uint]Expr, srcmap *sexp.SourceMaps[Node]) Expr { if expr != nil { - expr = expr.Substitute(mapping) + expr = Substitute(expr, mapping, srcmap) } // return expr diff --git a/pkg/corset/preprocessor.go b/pkg/corset/preprocessor.go new file mode 100644 index 0000000..3dfa2e5 --- /dev/null +++ b/pkg/corset/preprocessor.go @@ -0,0 +1,332 @@ +package corset + +import ( + "math/big" + + "github.com/consensys/go-corset/pkg/sexp" +) + +// PreprocessCircuit performs preprocessing prior to final translation. +// Specifically, it expands all invocations, reductions and for loops. Thus, +// final translation is greatly simplified after this step. +func PreprocessCircuit(debug bool, srcmap *sexp.SourceMaps[Node], + circuit *Circuit) []SyntaxError { + // Construct fresh preprocessor + p := preprocessor{debug, srcmap} + // Preprocess all declarations + return p.preprocessDeclarations(circuit) +} + +// Preprocessor performs preprocessing prior to final translation. Specifically, +// it expands all invocations, reductions and for loops. Thus, final +// translation is greatly simplified after this step. +type preprocessor struct { + // Debug enables the use of debug constraints. + debug bool + // Source maps nodes in the circuit back to the spans in their original + // source files. This is needed when reporting syntax errors to generate + // highlights of the relevant source line(s) in question. + srcmap *sexp.SourceMaps[Node] +} + +// preprocess all assignment or constraint declarations in the circuit. +func (p *preprocessor) preprocessDeclarations(circuit *Circuit) []SyntaxError { + errors := p.preprocessDeclarationsInModule("", circuit.Declarations) + // preprocess each module + for _, m := range circuit.Modules { + errs := p.preprocessDeclarationsInModule(m.Name, m.Declarations) + errors = append(errors, errs...) + } + // Done + return errors +} + +// preprocess all assignment or constraint declarations in a given module within +// the circuit. +func (p *preprocessor) preprocessDeclarationsInModule(module string, decls []Declaration) []SyntaxError { + var errors []SyntaxError + // + for _, d := range decls { + errs := p.preprocessDeclaration(d, module) + errors = append(errors, errs...) + } + // Done + return errors +} + +// preprocess an assignment or constraint declarartion which occurs within a +// given module. +func (p *preprocessor) preprocessDeclaration(decl Declaration, module string) []SyntaxError { + var errors []SyntaxError + // + if _, ok := decl.(*DefAliases); ok { + // ignore + } else if _, ok := decl.(*DefColumns); ok { + // ignore + } else if _, ok := decl.(*DefConst); ok { + // ignore + } else if d, ok := decl.(*DefConstraint); ok { + errors = p.preprocessDefConstraint(d, module) + } else if _, ok := decl.(*DefFun); ok { + // ignore + } else if d, ok := decl.(*DefInRange); ok { + errors = p.preprocessDefInRange(d, module) + } else if _, Ok := decl.(*DefInterleaved); Ok { + // ignore + } else if d, ok := decl.(*DefLookup); ok { + errors = p.preprocessDefLookup(d, module) + } else if _, Ok := decl.(*DefPermutation); Ok { + // ignore + } else if d, ok := decl.(*DefProperty); ok { + errors = p.preprocessDefProperty(d, module) + } else { + // Error handling + panic("unknown declaration") + } + // + return errors +} + +// preprocess a "defconstraint" declaration. +func (p *preprocessor) preprocessDefConstraint(decl *DefConstraint, module string) []SyntaxError { + var ( + constraint_errors []SyntaxError + guard_errors []SyntaxError + ) + // preprocess constraint body + decl.Constraint, constraint_errors = p.preprocessExpressionInModule(decl.Constraint, module) + // preprocess (optional) guard + decl.Guard, guard_errors = p.preprocessOptionalExpressionInModule(decl.Guard, module) + // Combine errors + return append(constraint_errors, guard_errors...) +} + +// preprocess a "deflookup" declaration. +// +//nolint:staticcheck +func (p *preprocessor) preprocessDefLookup(decl *DefLookup, module string) []SyntaxError { + var ( + source_errs []SyntaxError + target_errs []SyntaxError + ) + // preprocess source expressions + decl.Sources, source_errs = p.preprocessExpressionsInModule(decl.Sources, module) + decl.Targets, target_errs = p.preprocessExpressionsInModule(decl.Targets, module) + // Combine errors + return append(source_errs, target_errs...) +} + +// preprocess a "definrange" declaration. +func (p *preprocessor) preprocessDefInRange(decl *DefInRange, module string) []SyntaxError { + var errors []SyntaxError + // preprocess constraint body + decl.Expr, errors = p.preprocessExpressionInModule(decl.Expr, module) + // Done + return errors +} + +// preprocess a "defproperty" declaration. +func (p *preprocessor) preprocessDefProperty(decl *DefProperty, module string) []SyntaxError { + var errors []SyntaxError + // preprocess constraint body + decl.Assertion, errors = p.preprocessExpressionInModule(decl.Assertion, module) + // Done + return errors +} + +// preprocess an optional expression in a given context. That is an expression +// which maybe nil (i.e. doesn't exist). In such case, nil is returned (i.e. +// without any errors). +func (p *preprocessor) preprocessOptionalExpressionInModule(expr Expr, module string) (Expr, []SyntaxError) { + // + if expr != nil { + return p.preprocessExpressionInModule(expr, module) + } + + return nil, nil +} + +// preprocess a sequence of zero or more expressions enclosed in a given module. +// All expressions are expected to be non-voidable (see below for more on +// voidability). +func (p *preprocessor) preprocessExpressionsInModule(exprs []Expr, module string) ([]Expr, []SyntaxError) { + // + errors := []SyntaxError{} + hirExprs := make([]Expr, len(exprs)) + // Iterate each expression in turn + for i, e := range exprs { + if e != nil { + var errs []SyntaxError + hirExprs[i], errs = p.preprocessExpressionInModule(e, module) + errors = append(errors, errs...) + // Check for non-voidability + if hirExprs[i] == nil { + errors = append(errors, *p.srcmap.SyntaxError(e, "void expression not permitted here")) + } + } + } + // + return hirExprs, errors +} + +// preprocess a sequence of zero or more expressions enclosed in a given module. +// A key aspect of this function is that it additionally accounts for "voidable" +// expressions. That is, essentially, to account for debug constraints which +// only exist in debug mode. Hence, when debug mode is not enabled, then a +// debug constraint is "void". +func (p *preprocessor) preprocessVoidableExpressionsInModule(exprs []Expr, module string) ([]Expr, []SyntaxError) { + // + errors := []SyntaxError{} + hirExprs := make([]Expr, len(exprs)) + nils := 0 + // Iterate each expression in turn + for i, e := range exprs { + if e != nil { + var errs []SyntaxError + hirExprs[i], errs = p.preprocessExpressionInModule(e, module) + errors = append(errors, errs...) + // Update dirty flag + if hirExprs[i] == nil { + nils++ + } + } + } + // Nil check. + if nils == 0 { + // Done + return hirExprs, errors + } + // Stip nils. Recall that nils can arise legitimately when we have debug + // constraints, but debug mode is not enabled. In such case, we want to + // strip them out. Since this is a rare occurrence, we try to keep the happy + // path efficient. + nHirExprs := make([]Expr, len(exprs)-nils) + i := 0 + // Strip out nils + for _, e := range hirExprs { + if e != nil { + nHirExprs[i] = e + i++ + } + } + // + return nHirExprs, errors +} + +// preprocess an expression situated in a given context. The context is +// necessary to resolve unqualified names (e.g. for column access, function +// invocations, etc). +func (p *preprocessor) preprocessExpressionInModule(expr Expr, module string) (Expr, []SyntaxError) { + var ( + nexpr Expr + errors []SyntaxError + ) + // + switch e := expr.(type) { + case *ArrayAccess: + arg, errs := p.preprocessExpressionInModule(e.arg, module) + nexpr, errors = &ArrayAccess{e.name, arg, e.binding}, errs + case *Add: + args, errs := p.preprocessExpressionsInModule(e.Args, module) + nexpr, errors = &Add{args}, errs + case *Constant: + return e, nil + case *Debug: + if p.debug { + return p.preprocessExpressionInModule(e.Arg, module) + } + // When debug is not enabled, return "void". + return nil, nil + case *Exp: + arg, errs1 := p.preprocessExpressionInModule(e.Arg, module) + pow, errs2 := p.preprocessExpressionInModule(e.Pow, module) + // Done + nexpr, errors = &Exp{arg, pow}, append(errs1, errs2...) + case *For: + return p.preprocessForInModule(e, module) + case *If: + args, errs := p.preprocessExpressionsInModule([]Expr{e.Condition, e.TrueBranch, e.FalseBranch}, module) + // Construct appropriate if form + nexpr, errors = &If{e.kind, args[0], args[1], args[2]}, errs + case *Invoke: + return p.preprocessInvokeInModule(e, module) + case *List: + args, errs := p.preprocessVoidableExpressionsInModule(e.Args, module) + nexpr, errors = &List{args}, errs + case *Mul: + args, errs := p.preprocessExpressionsInModule(e.Args, module) + nexpr, errors = &Mul{args}, errs + case *Normalise: + arg, errs := p.preprocessExpressionInModule(e.Arg, module) + nexpr, errors = &Normalise{arg}, errs + case *Reduce: + return p.preprocessReduceInModule(e, module) + case *Sub: + args, errs := p.preprocessExpressionsInModule(e.Args, module) + nexpr, errors = &Sub{args}, errs + case *Shift: + arg, errs := p.preprocessExpressionInModule(e.Arg, module) + nexpr, errors = &Shift{arg, e.Shift}, errs + case *VariableAccess: + return e, nil + default: + return nil, p.srcmap.SyntaxErrors(expr, "unknown expression encountered during translation") + } + // Copy over source information + p.srcmap.Copy(expr, nexpr) + // Done + return nexpr, errors +} + +func (p *preprocessor) preprocessForInModule(expr *For, module string) (Expr, []SyntaxError) { + var ( + errors []SyntaxError + mapping map[uint]Expr = make(map[uint]Expr) + ) + // Determine range for index variable + n := expr.End - expr.Start + 1 + args := make([]Expr, n) + // Expand body n times + for i := uint(0); i < n; i++ { + var errs []SyntaxError + // Substitute through for i + mapping[expr.Binding.index] = &Constant{*big.NewInt(int64(i + expr.Start))} + ith := Substitute(expr.Body, mapping, p.srcmap) + // preprocess subsituted expression + args[i], errs = p.preprocessExpressionInModule(ith, module) + errors = append(errors, errs...) + } + // Error check + if len(errors) != 0 { + return nil, errors + } + // Done + return &List{args}, nil +} + +func (p *preprocessor) preprocessInvokeInModule(expr *Invoke, module string) (Expr, []SyntaxError) { + if expr.signature != nil { + body := expr.signature.Apply(expr.Args(), p.srcmap) + return p.preprocessExpressionInModule(body, module) + } + // + return nil, p.srcmap.SyntaxErrors(expr, "unbound function") +} + +func (p *preprocessor) preprocessReduceInModule(expr *Reduce, module string) (Expr, []SyntaxError) { + body, errors := p.preprocessExpressionInModule(expr.arg, module) + // + if list, ok := body.(*List); !ok { + return nil, append(errors, *p.srcmap.SyntaxError(expr.arg, "expected list")) + } else if sig := expr.signature; sig == nil { + return nil, append(errors, *p.srcmap.SyntaxError(expr.arg, "unbound function")) + } else { + reduction := list.Args[0] + // Build reduction + for i := 1; i < len(list.Args); i++ { + reduction = sig.Apply([]Expr{reduction, list.Args[i]}, p.srcmap) + } + // done + return reduction, errors + } +} diff --git a/pkg/corset/resolver.go b/pkg/corset/resolver.go index aedb8d0..ac77764 100644 --- a/pkg/corset/resolver.go +++ b/pkg/corset/resolver.go @@ -649,7 +649,7 @@ func (r *resolver) finaliseInvokeInModule(scope LocalScope, expr *Invoke) (Type, // type information. Potentially, we could adjust the local scope to // provide the required type information. Or we could have a separate pass // which just determines the type. - body := signature.Apply(expr.Args()) + body := signature.Apply(expr.Args(), nil) // Dig out the type return r.finaliseExpressionInModule(scope, body) } diff --git a/pkg/corset/translator.go b/pkg/corset/translator.go index 68587c7..f68716c 100644 --- a/pkg/corset/translator.go +++ b/pkg/corset/translator.go @@ -2,7 +2,6 @@ package corset import ( "fmt" - "math/big" "github.com/consensys/gnark-crypto/ecc/bls12-377/fr" "github.com/consensys/go-corset/pkg/hir" @@ -18,10 +17,10 @@ import ( // easily. Thus, whilst syntax errors can be returned here, this should never // happen. The mechanism is supported, however, to simplify development of new // features, etc. -func TranslateCircuit(env Environment, debug bool, srcmap *sexp.SourceMaps[Node], +func TranslateCircuit(env Environment, srcmap *sexp.SourceMaps[Node], circuit *Circuit) (*hir.Schema, []SyntaxError) { // - t := translator{env, debug, srcmap, hir.EmptySchema()} + t := translator{env, srcmap, hir.EmptySchema()} // Allocate all modules into schema t.translateModules(circuit) // Translate input columns @@ -42,8 +41,6 @@ type translator struct { // Environment is needed for determining the identifiers for modules and // columns. env Environment - // Debug enables the use of debug constraints. - debug bool // Source maps nodes in the circuit back to the spans in their original // source files. This is needed when reporting syntax errors to generate // highlights of the relevant source line(s) in question. @@ -386,8 +383,6 @@ func (t *translator) translateUnitExpressionsInModule(exprs []Expr, module strin } // Translate a sequence of zero or more expressions enclosed in a given module. -// All expressions are expected to be non-voidable (see below for more on -// voidability). func (t *translator) translateExpressionsInModule(exprs []Expr, module string, shift int) ([]hir.Expr, []SyntaxError) { // @@ -409,108 +404,52 @@ func (t *translator) translateExpressionsInModule(exprs []Expr, module string, return hirExprs, errors } -// Translate a sequence of zero or more expressions enclosed in a given module. -// A key aspect of this function is that it additionally accounts for "voidable" -// expressions. That is, essentially, to account for debug constraints which -// only exist in debug mode. Hence, when debug mode is not enabled, then a -// debug constraint is "void". -func (t *translator) translateVoidableExpressionsInModule(exprs []Expr, module string, - shift int) ([]hir.Expr, []SyntaxError) { - // - errors := []SyntaxError{} - hirExprs := make([]hir.Expr, len(exprs)) - nils := 0 - // Iterate each expression in turn - for i, e := range exprs { - if e != nil { - var errs []SyntaxError - hirExprs[i], errs = t.translateExpressionInModule(e, module, shift) - errors = append(errors, errs...) - // Update dirty flag - if hirExprs[i] == nil { - nils++ - } - } - } - // Nil check. - if nils == 0 { - // Done - return hirExprs, errors - } - // Stip nils. Recall that nils can arise legitimately when we have debug - // constraints, but debug mode is not enabled. In such case, we want to - // strip them out. Since this is a rare occurrence, we try to keep the happy - // path efficient. - nHirExprs := make([]hir.Expr, len(exprs)-nils) - i := 0 - // Strip out nils - for _, e := range hirExprs { - if e != nil { - nHirExprs[i] = e - i++ - } - } - // - return nHirExprs, errors -} - // Translate an expression situated in a given context. The context is // necessary to resolve unqualified names (e.g. for column access, function // invocations, etc). func (t *translator) translateExpressionInModule(expr Expr, module string, shift int) (hir.Expr, []SyntaxError) { - if e, ok := expr.(*ArrayAccess); ok { + switch e := expr.(type) { + case *ArrayAccess: return t.translateArrayAccessInModule(e, shift) - } else if v, ok := expr.(*Add); ok { - args, errs := t.translateExpressionsInModule(v.Args, module, shift) + case *Add: + args, errs := t.translateExpressionsInModule(e.Args, module, shift) return &hir.Add{Args: args}, errs - } else if e, ok := expr.(*Constant); ok { + case *Constant: var val fr.Element // Initialise field from bigint val.SetBigInt(&e.Val) // return &hir.Constant{Val: val}, nil - } else if e, ok := expr.(*Debug); ok { - if t.debug { - return t.translateExpressionInModule(e.Arg, module, shift) - } - // When debug is not enabled, simply substitute for 0. - return nil, nil - } else if e, ok := expr.(*Exp); ok { + case *Exp: return t.translateExpInModule(e, module, shift) - } else if e, ok := expr.(*For); ok { - return t.translateForInModule(e, module, shift) - } else if v, ok := expr.(*If); ok { - args, errs := t.translateExpressionsInModule([]Expr{v.Condition, v.TrueBranch, v.FalseBranch}, module, shift) + case *If: + args, errs := t.translateExpressionsInModule([]Expr{e.Condition, e.TrueBranch, e.FalseBranch}, module, shift) // Construct appropriate if form - if v.IsIfZero() { + if e.IsIfZero() { return &hir.IfZero{Condition: args[0], TrueBranch: args[1], FalseBranch: args[2]}, errs - } else if v.IsIfNotZero() { + } else if e.IsIfNotZero() { // In this case, switch the ordering. return &hir.IfZero{Condition: args[0], TrueBranch: args[2], FalseBranch: args[1]}, errs } // Should be unreachable return nil, t.srcmap.SyntaxErrors(expr, "unresolved conditional encountered during translation") - } else if e, ok := expr.(*Invoke); ok { - return t.translateInvokeInModule(e, module, shift) - } else if v, ok := expr.(*List); ok { - args, errs := t.translateVoidableExpressionsInModule(v.Args, module, shift) + case *List: + args, errs := t.translateExpressionsInModule(e.Args, module, shift) return &hir.List{Args: args}, errs - } else if v, ok := expr.(*Mul); ok { - args, errs := t.translateExpressionsInModule(v.Args, module, shift) + case *Mul: + args, errs := t.translateExpressionsInModule(e.Args, module, shift) return &hir.Mul{Args: args}, errs - } else if v, ok := expr.(*Normalise); ok { - arg, errs := t.translateExpressionInModule(v.Arg, module, shift) + case *Normalise: + arg, errs := t.translateExpressionInModule(e.Arg, module, shift) return &hir.Normalise{Arg: arg}, errs - } else if v, ok := expr.(*Reduce); ok { - return t.translateReduceInModule(v, module, shift) - } else if v, ok := expr.(*Sub); ok { - args, errs := t.translateExpressionsInModule(v.Args, module, shift) + case *Sub: + args, errs := t.translateExpressionsInModule(e.Args, module, shift) return &hir.Sub{Args: args}, errs - } else if e, ok := expr.(*Shift); ok { + case *Shift: return t.translateShiftInModule(e, module, shift) - } else if e, ok := expr.(*VariableAccess); ok { + case *VariableAccess: return t.translateVariableAccessInModule(e, module, shift) - } else { + default: return nil, t.srcmap.SyntaxErrors(expr, "unknown expression encountered during translation") } } @@ -560,57 +499,6 @@ func (t *translator) translateExpInModule(expr *Exp, module string, shift int) ( return nil, errs } -func (t *translator) translateForInModule(expr *For, module string, shift int) (hir.Expr, []SyntaxError) { - var ( - errors []SyntaxError - mapping map[uint]Expr = make(map[uint]Expr) - ) - // Determine range for index variable - n := expr.End - expr.Start + 1 - args := make([]hir.Expr, n) - // Expand body n times - for i := uint(0); i < n; i++ { - var errs []SyntaxError - // Substitute through for i - mapping[expr.Binding.index] = &Constant{*big.NewInt(int64(i + expr.Start))} - ith := expr.Body.Substitute(mapping) - // Translate subsituted expression - args[i], errs = t.translateExpressionInModule(ith, module, shift) - errors = append(errors, errs...) - } - // Error check - if len(errors) != 0 { - return nil, errors - } - // Done - return &hir.List{Args: args}, nil -} - -func (t *translator) translateInvokeInModule(expr *Invoke, module string, shift int) (hir.Expr, []SyntaxError) { - if expr.signature != nil { - body := expr.signature.Apply(expr.Args()) - return t.translateExpressionInModule(body, module, shift) - } - // - return nil, t.srcmap.SyntaxErrors(expr, "unbound function") -} - -func (t *translator) translateReduceInModule(expr *Reduce, module string, shift int) (hir.Expr, []SyntaxError) { - if list, ok := expr.arg.(*List); !ok { - return nil, t.srcmap.SyntaxErrors(expr.arg, "expected list") - } else if sig := expr.signature; sig == nil { - return nil, t.srcmap.SyntaxErrors(expr.arg, "unbound function") - } else { - reduction := list.Args[0] - // Build reduction - for i := 1; i < len(list.Args); i++ { - reduction = sig.Apply([]Expr{reduction, list.Args[i]}) - } - // Translate reduction - return t.translateExpressionInModule(reduction, module, shift) - } -} - func (t *translator) translateShiftInModule(expr *Shift, module string, shift int) (hir.Expr, []SyntaxError) { constant := expr.Shift.AsConstant() // Determine the shift constant diff --git a/pkg/sexp/source_map.go b/pkg/sexp/source_map.go index e7a3d19..965743b 100644 --- a/pkg/sexp/source_map.go +++ b/pkg/sexp/source_map.go @@ -1,6 +1,8 @@ package sexp -import "fmt" +import ( + "fmt" +) // Span represents a contiguous slice of the original string. Instead of // representing this as a string slice, however, it is useful to retain the @@ -83,6 +85,20 @@ func (p *SourceMaps[T]) Join(srcmap *SourceMap[T]) { p.maps = append(p.maps, *srcmap) } +// Copy copies the source mapping for one node to the source mapping for +// another. The main use of this is when an existing node is expanded into some +// other nodes (e.g. during preprocessing). +func (p *SourceMaps[T]) Copy(from T, to T) { + for _, m := range p.maps { + if m.Has(from) { + span := m.Get(from) + m.Put(to, span) + // Done + return + } + } +} + // SourceMap maps terms from an AST to slices of their originating string. This // is important for error handling when we wish to highlight exactly where, in // the original source file, a given error has arisen. diff --git a/pkg/test/valid_corset_test.go b/pkg/test/valid_corset_test.go index d0d86f9..fc354e0 100644 --- a/pkg/test/valid_corset_test.go +++ b/pkg/test/valid_corset_test.go @@ -626,11 +626,9 @@ func Test_PureFun_02(t *testing.T) { Check(t, false, "purefun_02") } -/* - func Test_PureFun_03(t *testing.T) { - Check(t, false, "purefun_03") - } -*/ +func Test_PureFun_03(t *testing.T) { + Check(t, false, "purefun_03") +} func Test_PureFun_04(t *testing.T) { Check(t, false, "purefun_04") @@ -676,11 +674,10 @@ func Test_Array_01(t *testing.T) { Check(t, false, "array_01") } -/* - func Test_Array_02(t *testing.T) { - Check(t, false, "array_02") - } -*/ +func Test_Array_02(t *testing.T) { + Check(t, false, "array_02") +} + func Test_Array_03(t *testing.T) { Check(t, false, "array_03") }