Skip to content

Commit

Permalink
Merge pull request #25395 from gnecula:poly_better_eq
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 705105803
  • Loading branch information
Google-ML-Automation committed Dec 11, 2024
2 parents 98c4055 + 60f9da5 commit 01206f8
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 38 deletions.
37 changes: 21 additions & 16 deletions jax/_src/export/shape_poly.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,8 +92,11 @@ class _SymbolicConstraint:
# Either e1 == e2 if cmp == Comparator.EQ else e1 >= e2
cmp: Comparator
debug_str: str # The form in which the user expressed it, for error messages
e1: DimSize # This has been normalized w.r.t. previous constraints only
e2: DimSize # This has been normalized w.r.t. previous constraints only
# e1, e2, and diff == e1 - e2, are normalized w.r.t. previous constraints only
e1: DimSize
e2: DimSize
# we pre-compute diff to avoid having the normalization rule kick in later.
diff: DimSize

def __repr__(self):
return f"Constraint({self.debug_str})"
Expand Down Expand Up @@ -1061,29 +1064,33 @@ def _parse_and_process_explicit_constraint(self, c_str: str):
if cmp == Comparator.GEQ and not is_geq:
e1, e2 = e2, e1

diff = e1 - e2
if (diff_const := _DimExpr._to_constant(diff)) is not None:
if ((cmp == Comparator.EQ and diff_const != 0) or
(cmp == Comparator.GEQ and diff_const < 0)):
raise ValueError(f"Unsatisfiable explicit constraint: {c_str}")
# Compute e1 - e2 before we add to normalization rules
constr = _SymbolicConstraint(debug_str=c_str, cmp=cmp, e1=e1, e2=e2,
diff=e1 - e2)
self._process_explicit_constraint(constr)

def _process_explicit_constraint(self, constr: _SymbolicConstraint):
if (diff_const := _DimExpr._to_constant(constr.diff)) is not None:
if ((constr.cmp == Comparator.EQ and diff_const != 0) or
(constr.cmp == Comparator.GEQ and diff_const < 0)):
raise ValueError(f"Unsatisfiable explicit constraint: {constr.debug_str}")
return

if cmp == Comparator.EQ:
if not isinstance(e1, _DimExpr):
if constr.cmp == Comparator.EQ:
if not isinstance(constr.e1, _DimExpr):
raise ValueError("Invalid equality constraint: {e1} == {e2}. "
"The left-hand-side must be of the form `term * coefficient`.")
(before, before_k), *rest = e1._sorted_terms
(before, before_k), *rest = constr.e1._sorted_terms
if rest:
raise ValueError("Invalid equality constraint: {e1} == {e2}. "
"The left-hand-side must be of the form `term * coefficient`.")

after = _ensure_poly(e2, "parse_constraint", e1.scope) # type: ignore[name-error,unused-ignore]
after = _ensure_poly(constr.e2, "parse_constraint", constr.e1.scope) # type: ignore[name-error,unused-ignore]
if before in self._normalization_rules:
raise NotImplementedError(
f"Found multiple equality constraints with the same left-hand-side: {before}")
self._normalization_rules[before] = (after, before_k)

constr = _SymbolicConstraint(debug_str=c_str, cmp=cmp, e1=e1, e2=e2)
self._explicit_constraints.append(constr)

def _check_same_scope(self, other: _DimExpr,
Expand Down Expand Up @@ -2120,14 +2127,12 @@ def add_explicit_symbolic_constraints(shape_env: DimVarEnv):
for constr in scope._explicit_constraints:
# We can't just construct constr.e1 - constr.e2 because for an equality
# constraint it would be reduced to 0.
c_e1 = constr.e1._evaluate(shape_env) if not core.is_constant_dim(constr.e1) else constr.e1 # type: ignore
c_e2 = constr.e2._evaluate(shape_env) if not core.is_constant_dim(constr.e2) else constr.e2 # type: ignore
c_diff = c_e1 - c_e2
c_diff = constr.diff._evaluate(shape_env) if not core.is_constant_dim(constr.diff) else constr.diff # type: ignore
shape_constraints.add_constraint(
constr.cmp, c_diff, 0,
error_message_pieces=[
f"Input shapes do not match the symbolic shape constraint {constr.debug_str}. "
f"Expected '{constr.e1} - {constr.e2}' to be "
f"Expected '{constr.diff}' to be "
f"{'greater or equal' if constr.cmp == Comparator.GEQ else 'equal'} to 0, "
"but found ", c_diff,

Expand Down
22 changes: 7 additions & 15 deletions jax/_src/export/shape_poly_decision.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,19 +85,11 @@ def initialize(self) -> _DecisionByElimination:
# the result (albeit, for now, without a good feedback loop to understand
# how the order matters for inequalities).
for constr in self.scope._explicit_constraints:
if not core.is_constant_dim(constr.e1):
self.add_implicit_constraints_expr(constr.e1) # type: ignore
if not core.is_constant_dim(constr.e2):
self.add_implicit_constraints_expr(constr.e2) # type: ignore
# The equality constraints are not needed for inequality decisions,
# because the LHS should always be rewritten in terms of the RHS.
# In fact, adding them may break the assumption that if we eliminate
# the leading term we end up with only smaller terms, because the LHS
# may appear in the rest and may be rewritten to something larger.
# However, we want to add the implicit constraints within.
if constr.cmp == Comparator.GEQ:
self.combine_and_add_constraint(constr.cmp, constr.e1 - constr.e2, 0,
constr.debug_str)
if not core.is_constant_dim(constr.diff):
self.add_implicit_constraints_expr(constr.diff) # type: ignore

self.combine_and_add_constraint(constr.cmp, constr.diff, 0,
constr.debug_str)


# Clear the cache, since we have added constraints.
Expand Down Expand Up @@ -197,7 +189,7 @@ def combine_term_with_existing(self, t: _DimTerm, t_k: int, *,
Combine a term with existing constraints.
For input (t, t_k) the tuple (c_eq, c, c_s, t_s) is among the returned
tuples if there exists a constraint `c =[c_eq] 0` that can be combined
with `t*t_k` to eliminate `t`.
with `t*t_k` to eliminate `t`, and:
* `c =[c_eq] 0`
* The term `comb = t*t_k*t_s + c*c_s` does not contain `t`, and if
Expand All @@ -207,7 +199,7 @@ def combine_term_with_existing(self, t: _DimTerm, t_k: int, *,
"""
# TODO: maybe a generator is useful here instead of materializing the list
acc: list[tuple[Comparator, _DimExpr, int, int]] = []
# First combine with the existing term constraints
# First combine with the existing term bounds
t_lb, t_ub = self._term_bounds.get(t, (-np.inf, np.inf))
if t_lb == t_ub:
acc.append((Comparator.EQ, _DimExpr(((t, 1),), scope) - int(t_lb),
Expand Down
25 changes: 18 additions & 7 deletions tests/shape_poly_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1114,27 +1114,38 @@ def test_constraints_eq_threefry(self):
self.assertEqual(x_reshaped, (a + a % 2) // -2)
self.assertEqual(2 * x_reshaped, a)

def test_constraints_a_minus_4d_eq(self):
def test_constraints_eq_a_minus_4d(self):
# simulates d = div(a, 4) and m = mod(a, 4)
assumptions = ["4*d == a - m", "m >= 0", "m <= 3"]
scope = shape_poly.SymbolicScope(assumptions)
constraints = ["4*d == a - m", "m >= 0", "m <= 3"]
scope = shape_poly.SymbolicScope(constraints)
a, d = shape_poly.symbolic_shape("a, d", scope=scope)
self.assertEqual(_bounds(a - 4*d), (1, 3)) # a - 4d = m >= 1
# TODO: The incompleteness is due to the way we combine external constraints
self.assertEqual(_bounds(a - 2*d),
_expect(best=(3, np.inf), current=(-np.inf, np.inf))) # a - 2d = m + 2d >= 3
# TODO: The incompleteness is due to the way we combine external constraints
self.assertEqual(_bounds(a),
_expect(best=(5, np.inf), current=(1, np.inf))) # a >= 4d + m >= 5
_expect(best=(5, np.inf), current=(4, np.inf))) # a >= 4d + m >= 5

# Now with a different order of constraints
assumptions1 = ["m1 >= 0", "m1 <= 3", "a1 == 4*d1 + m1"]
scope1 = shape_poly.SymbolicScope(assumptions1)
constraints1 = ["m1 >= 0", "m1 <= 3", "a1 == 4*d1 + m1"]
scope1 = shape_poly.SymbolicScope(constraints1)
a1, d1, m1 = shape_poly.symbolic_shape("a1, d1, m1", scope=scope1)
self.assertEqual(_bounds(a1 - 4*d1), (1, 3)) # a - 4d = m >= 1
self.assertEqual(_bounds(a1 - 2*d1), (3, np.inf)) # a - 2d = m + 2d >= 3
self.assertEqual(_bounds(a1), (5, np.inf)) # a >= 4d + m >= 5

def test_constraints_eq_geq(self):
# We ensure that an equality constraint it is usable not just for
# normalization but also for inequality reasoning.
a, b = export.symbolic_shape(
"a, b", constraints=["4 * a == b"])
self.assertGreaterEqual(b, a)
self.assertGreaterEqual(b, 3*a)
self.assertGreaterEqual(b, 4 * a)
self.assertGreaterEqual(5 * a, b)
self.assertGreaterEqual(9 * a, 2*b)

def test_constraints_error_msg(self):
a, b = shape_poly.symbolic_shape("a, b",
constraints=("a >= 5",))
Expand Down Expand Up @@ -1713,7 +1724,7 @@ def f(x): # x: i32[a]

with self.assertRaisesRegex(
ValueError,
re.escape("Expected '4 - a' to be greater or equal to 0, but found -1")):
re.escape("Expected '- a + 4' to be greater or equal to 0, but found -1")):
exp.call(np.arange(5, dtype=np.int32))

def test_constraints_eq_0_compile_time_check(self):
Expand Down

0 comments on commit 01206f8

Please sign in to comment.