Skip to content

Commit

Permalink
Prevent adding constraints which always vanish
Browse files Browse the repository at this point in the history
  • Loading branch information
DavePearce committed Oct 7, 2024
1 parent cfa4977 commit 8b969de
Show file tree
Hide file tree
Showing 4 changed files with 76 additions and 2 deletions.
30 changes: 30 additions & 0 deletions pkg/air/expr.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,11 @@ type Expr interface {

// Equate one expression with another
Equate(Expr) Expr

// AsConstant determines whether or not this is a constant expression. If
// so, the constant is returned; otherwise, nil is returned. NOTE: this
// does not perform any form of simplification to determine this.
AsConstant() *fr.Element
}

// ============================================================================
Expand Down Expand Up @@ -77,6 +82,11 @@ func (p *Add) Equate(other Expr) Expr { return &Sub{Args: []Expr{p, other}} }
// direction (right).
func (p *Add) Bounds() util.Bounds { return util.BoundsForArray(p.Args) }

// AsConstant determines whether or not this is a constant expression. If
// so, the constant is returned; otherwise, nil is returned. NOTE: this
// does not perform any form of simplification to determine this.
func (p *Add) AsConstant() *fr.Element { return nil }

// ============================================================================
// Subtraction
// ============================================================================
Expand Down Expand Up @@ -124,6 +134,11 @@ func (p *Sub) Equate(other Expr) Expr { return &Sub{Args: []Expr{p, other}} }
// direction (right).
func (p *Sub) Bounds() util.Bounds { return util.BoundsForArray(p.Args) }

// AsConstant determines whether or not this is a constant expression. If
// so, the constant is returned; otherwise, nil is returned. NOTE: this
// does not perform any form of simplification to determine this.
func (p *Sub) AsConstant() *fr.Element { return nil }

// ============================================================================
// Multiplication
// ============================================================================
Expand Down Expand Up @@ -171,6 +186,11 @@ func (p *Mul) Equate(other Expr) Expr { return &Sub{Args: []Expr{p, other}} }
// direction (right).
func (p *Mul) Bounds() util.Bounds { return util.BoundsForArray(p.Args) }

// AsConstant determines whether or not this is a constant expression. If
// so, the constant is returned; otherwise, nil is returned. NOTE: this
// does not perform any form of simplification to determine this.
func (p *Mul) AsConstant() *fr.Element { return nil }

// ============================================================================
// Constant
// ============================================================================
Expand Down Expand Up @@ -225,6 +245,11 @@ func (p *Constant) Equate(other Expr) Expr { return &Sub{Args: []Expr{p, other}}
// direction (right). A constant has zero shift.
func (p *Constant) Bounds() util.Bounds { return util.EMPTY_BOUND }

// AsConstant determines whether or not this is a constant expression. If
// so, the constant is returned; otherwise, nil is returned. NOTE: this
// does not perform any form of simplification to determine this.
func (p *Constant) AsConstant() *fr.Element { return &p.Value }

// ColumnAccess represents reading the value held at a given column in the
// tabular context. Furthermore, the current row maybe shifted up (or down) by
// a given amount. Suppose we are evaluating a constraint on row k=5 which
Expand Down Expand Up @@ -290,3 +315,8 @@ func (p *ColumnAccess) Bounds() util.Bounds {
// Negative shift
return util.NewBounds(uint(-p.Shift), 0)
}

// AsConstant determines whether or not this is a constant expression. If
// so, the constant is returned; otherwise, nil is returned. NOTE: this
// does not perform any form of simplification to determine this.
func (p *ColumnAccess) AsConstant() *fr.Element { return nil }
12 changes: 11 additions & 1 deletion pkg/mir/const.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ func applyConstantPropagationMul(es []Expr, schema sc.Schema) Expr {
is_const := true
prod := one
rs := make([]Expr, len(es))
ones := 0
//
for i, e := range es {
rs[i] = applyConstantPropagation(e, schema)
Expand All @@ -97,6 +98,9 @@ func applyConstantPropagationMul(es []Expr, schema sc.Schema) Expr {
if ok && c.Value.IsZero() {
// No matter what, outcome is zero.
return &Constant{c.Value}
} else if ok && c.Value.IsOne() {
ones++
rs[i] = nil
} else if ok && is_const {
// Continue building constant
prod.Mul(&prod, &c.Value)
Expand All @@ -107,8 +111,14 @@ func applyConstantPropagationMul(es []Expr, schema sc.Schema) Expr {
// Check if constant
if is_const {
return &Constant{prod}
} else if ones > 0 {
rs = util.RemoveMatching[Expr](rs, func(item Expr) bool { return item == nil })
}
//
// Sanity check what's left.
if len(rs) == 1 {
return rs[0]
}
// Done
return &Mul{rs}
}

Expand Down
9 changes: 8 additions & 1 deletion pkg/mir/lower.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,14 @@ func lowerConstraintToAir(c sc.Constraint, schema *air.Schema) {
lowerLookupConstraintToAir(v, schema)
} else if v, ok := c.(VanishingConstraint); ok {
air_expr := lowerExprTo(v.Context(), v.Constraint().Expr, schema)
schema.AddVanishingConstraint(v.Handle(), v.Context(), v.Domain(), air_expr)
// Check whether this is a constant
constant := air_expr.AsConstant()
// Check for compile-time constants
if constant != nil && !constant.IsZero() {
panic(fmt.Sprintf("constraint %s cannot vanish!", v.Handle()))
} else if constant == nil {
schema.AddVanishingConstraint(v.Handle(), v.Context(), v.Domain(), air_expr)
}
} else if v, ok := c.(*constraint.TypeConstraint); ok {
if t := v.Type().AsUint(); t != nil {
// Yes, a constraint is implied. Now, decide whether to use a range
Expand Down
27 changes: 27 additions & 0 deletions pkg/util/arrays.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,33 @@ func ReplaceFirstOrPanic[T comparable](columns []T, from T, to T) {
panic(fmt.Sprintf("invalid replace (item %s not found)", any(from)))
}

// RemoveMatching removes all elements from an array matching the given item.
func RemoveMatching[T any](items []T, predicate Predicate[T]) []T {
count := 0
// Check how many matches we have
for _, r := range items {
if !predicate(r) {
count++
}
}
// Check for stuff to remove
if count != len(items) {
nitems := make([]T, count)
j := 0
// Remove items
for i, r := range items {
if !predicate(r) {
nitems[j] = items[i]
j++
}
}
//
items = nitems
}
//
return items
}

// Equals returns true if both arrays contain equivalent elements.
func Equals(lhs []*fr.Element, rhs []*fr.Element) bool {
if len(lhs) != len(rhs) {
Expand Down

0 comments on commit 8b969de

Please sign in to comment.