diff --git a/pkg/air/expr.go b/pkg/air/expr.go index ea55083..cbbb691 100644 --- a/pkg/air/expr.go +++ b/pkg/air/expr.go @@ -28,6 +28,11 @@ type Expr interface { // Equate one expression with another Equate(Expr) Expr + + // AsConstant determines whether or not this is a constant expression. If + // so, the constant is returned; otherwise, nil is returned. NOTE: this + // does not perform any form of simplification to determine this. + AsConstant() *fr.Element } // ============================================================================ @@ -77,6 +82,11 @@ func (p *Add) Equate(other Expr) Expr { return &Sub{Args: []Expr{p, other}} } // direction (right). func (p *Add) Bounds() util.Bounds { return util.BoundsForArray(p.Args) } +// AsConstant determines whether or not this is a constant expression. If +// so, the constant is returned; otherwise, nil is returned. NOTE: this +// does not perform any form of simplification to determine this. +func (p *Add) AsConstant() *fr.Element { return nil } + // ============================================================================ // Subtraction // ============================================================================ @@ -124,6 +134,11 @@ func (p *Sub) Equate(other Expr) Expr { return &Sub{Args: []Expr{p, other}} } // direction (right). func (p *Sub) Bounds() util.Bounds { return util.BoundsForArray(p.Args) } +// AsConstant determines whether or not this is a constant expression. If +// so, the constant is returned; otherwise, nil is returned. NOTE: this +// does not perform any form of simplification to determine this. +func (p *Sub) AsConstant() *fr.Element { return nil } + // ============================================================================ // Multiplication // ============================================================================ @@ -171,6 +186,11 @@ func (p *Mul) Equate(other Expr) Expr { return &Sub{Args: []Expr{p, other}} } // direction (right). func (p *Mul) Bounds() util.Bounds { return util.BoundsForArray(p.Args) } +// AsConstant determines whether or not this is a constant expression. If +// so, the constant is returned; otherwise, nil is returned. NOTE: this +// does not perform any form of simplification to determine this. +func (p *Mul) AsConstant() *fr.Element { return nil } + // ============================================================================ // Constant // ============================================================================ @@ -225,6 +245,11 @@ func (p *Constant) Equate(other Expr) Expr { return &Sub{Args: []Expr{p, other}} // direction (right). A constant has zero shift. func (p *Constant) Bounds() util.Bounds { return util.EMPTY_BOUND } +// AsConstant determines whether or not this is a constant expression. If +// so, the constant is returned; otherwise, nil is returned. NOTE: this +// does not perform any form of simplification to determine this. +func (p *Constant) AsConstant() *fr.Element { return &p.Value } + // ColumnAccess represents reading the value held at a given column in the // tabular context. Furthermore, the current row maybe shifted up (or down) by // a given amount. Suppose we are evaluating a constraint on row k=5 which @@ -290,3 +315,8 @@ func (p *ColumnAccess) Bounds() util.Bounds { // Negative shift return util.NewBounds(uint(-p.Shift), 0) } + +// AsConstant determines whether or not this is a constant expression. If +// so, the constant is returned; otherwise, nil is returned. NOTE: this +// does not perform any form of simplification to determine this. +func (p *ColumnAccess) AsConstant() *fr.Element { return nil } diff --git a/pkg/mir/const.go b/pkg/mir/const.go index 1dcf0a0..b5854cd 100644 --- a/pkg/mir/const.go +++ b/pkg/mir/const.go @@ -88,6 +88,7 @@ func applyConstantPropagationMul(es []Expr, schema sc.Schema) Expr { is_const := true prod := one rs := make([]Expr, len(es)) + ones := 0 // for i, e := range es { rs[i] = applyConstantPropagation(e, schema) @@ -97,6 +98,9 @@ func applyConstantPropagationMul(es []Expr, schema sc.Schema) Expr { if ok && c.Value.IsZero() { // No matter what, outcome is zero. return &Constant{c.Value} + } else if ok && c.Value.IsOne() { + ones++ + rs[i] = nil } else if ok && is_const { // Continue building constant prod.Mul(&prod, &c.Value) @@ -107,8 +111,14 @@ func applyConstantPropagationMul(es []Expr, schema sc.Schema) Expr { // Check if constant if is_const { return &Constant{prod} + } else if ones > 0 { + rs = util.RemoveMatching[Expr](rs, func(item Expr) bool { return item == nil }) } - // + // Sanity check what's left. + if len(rs) == 1 { + return rs[0] + } + // Done return &Mul{rs} } diff --git a/pkg/mir/lower.go b/pkg/mir/lower.go index 646039e..07a6a87 100644 --- a/pkg/mir/lower.go +++ b/pkg/mir/lower.go @@ -63,7 +63,14 @@ func lowerConstraintToAir(c sc.Constraint, schema *air.Schema) { lowerLookupConstraintToAir(v, schema) } else if v, ok := c.(VanishingConstraint); ok { air_expr := lowerExprTo(v.Context(), v.Constraint().Expr, schema) - schema.AddVanishingConstraint(v.Handle(), v.Context(), v.Domain(), air_expr) + // Check whether this is a constant + constant := air_expr.AsConstant() + // Check for compile-time constants + if constant != nil && !constant.IsZero() { + panic(fmt.Sprintf("constraint %s cannot vanish!", v.Handle())) + } else if constant == nil { + schema.AddVanishingConstraint(v.Handle(), v.Context(), v.Domain(), air_expr) + } } else if v, ok := c.(*constraint.TypeConstraint); ok { if t := v.Type().AsUint(); t != nil { // Yes, a constraint is implied. Now, decide whether to use a range diff --git a/pkg/util/arrays.go b/pkg/util/arrays.go index 0557c86..da8f683 100644 --- a/pkg/util/arrays.go +++ b/pkg/util/arrays.go @@ -21,6 +21,33 @@ func ReplaceFirstOrPanic[T comparable](columns []T, from T, to T) { panic(fmt.Sprintf("invalid replace (item %s not found)", any(from))) } +// RemoveMatching removes all elements from an array matching the given item. +func RemoveMatching[T any](items []T, predicate Predicate[T]) []T { + count := 0 + // Check how many matches we have + for _, r := range items { + if !predicate(r) { + count++ + } + } + // Check for stuff to remove + if count != len(items) { + nitems := make([]T, count) + j := 0 + // Remove items + for i, r := range items { + if !predicate(r) { + nitems[j] = items[i] + j++ + } + } + // + items = nitems + } + // + return items +} + // Equals returns true if both arrays contain equivalent elements. func Equals(lhs []*fr.Element, rhs []*fr.Element) bool { if len(lhs) != len(rhs) {