Don't uselessly recompute axiom dict every static eval call (#135429)

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135429
Approved by: https://github.com/isuruf
This commit is contained in:
Edward Z. Yang
2024-09-27 07:45:45 -07:00
committed by PyTorch MergeBot
parent 6ecb73bafd
commit 1d6e0412f5
2 changed files with 22 additions and 6 deletions

View File

@ -10002,6 +10002,9 @@ ShapeEnv not equal: field values don't match:
"""\
ShapeEnv not equal: field values don't match:
==> axioms: values don't match.
> Left: {0 < Mod(s0, 3): False, Eq(0, Mod(s0, 3)): True, Eq(Mod(s0, 3), 0): True, False: False, Mod(s0, 3) <= 0: True, Ne(0, Mod(s0, 3)): False, Ne(Mod(s0, 3), 0): False, True: True}
> Right: {}
==> divisible: values don't match.
> Left: {Mod(s0, 3)}
> Right: {}
@ -10039,6 +10042,9 @@ ShapeEnv not equal: field values don't match:
"""\
ShapeEnv not equal: field values don't match:
==> axioms: values don't match.
> Left: {False: False, True: True}
> Right: {}
==> guards: values don't match.
> Left: [Eq(s0, 3)]
> Right: []
@ -10080,6 +10086,9 @@ ShapeEnv not equal: field values don't match:
"""\
ShapeEnv not equal: field values don't match:
==> axioms: values don't match.
> Left: {3 <= s0: True, s0 < 3: False}
> Right: {}
==> guards: values don't match.
> Left: [s0 >= 3]
> Right: []
@ -10112,6 +10121,9 @@ ShapeEnv not equal: field values don't match:
"""\
ShapeEnv not equal: field values don't match:
==> axioms: values don't match.
> Left: {0 < PythonMod(u0, 3): False, Eq(0, PythonMod(u0, 3)): True, Eq(PythonMod(u0, 3), 0): True, False: False, Ne(0, PythonMod(u0, 3)): False, Ne(PythonMod(u0, 3), 0): False, PythonMod(u0, 3) <= 0: True, True: True}
> Right: {}
==> deferred_runtime_asserts: values don't match.
> Left: {u0: [Eq(PythonMod(u0, 3), 0)]}
> Right: {}

View File

@ -2637,6 +2637,7 @@ class ShapeEnv:
)
self.guards: List[ShapeGuard] = []
self.axioms: Dict[sympy.Expr, sympy.Expr] = {}
# Maps symbolic ints to their original concrete values
# Currently populated from tensors
self.var_to_val: Dict[sympy.Symbol, sympy.Integer] = {}
@ -4658,15 +4659,16 @@ class ShapeEnv:
expr = canonicalize_bool_expr(expr)
# Pattern matching
symbols = tuple(expr.free_symbols)
if axioms is None:
axioms = self.get_axioms(symbols, compute_hint=compute_hint)
subst = {}
for e in axioms:
if e.free_symbols.issubset(expr.free_symbols):
subst.update(dict(self.get_implications(self.simplify(e))))
subst = self.axioms
else:
subst = {}
for e in axioms:
if e.free_symbols.issubset(expr.free_symbols):
subst.update(dict(self.get_implications(self.simplify(e))))
expr = expr.xreplace(subst)
# TODO: compute hint might have gotten broken here
fs = expr.free_symbols
@ -5421,6 +5423,7 @@ class ShapeEnv:
stack = CapturedTraceback.extract(skip=1)
guard = ShapeGuard(g, stack)
self.guards.append(guard)
self.axioms.update(dict(self.get_implications(self.simplify(g))))
else:
# it's fine to defer simple guards here without checking,
# the _maybe_guard_rel() call above will set replacements if possible,
@ -5532,6 +5535,7 @@ class ShapeEnv:
# and the guard in question has no unbacked SymInts in front
ix = cands[-1] if cands else None
self.deferred_runtime_asserts.setdefault(ix, []).append(ra)
self.axioms.update(dict(self.get_implications(self.simplify(expr))))
self.num_deferred_runtime_asserts += 1
self._update_version_counter()
self._log_guard("runtime_assert", orig_expr, forcing_spec=False)