mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
Don't uselessly recompute axiom dict every static eval call (#135429)
Signed-off-by: Edward Z. Yang <ezyang@meta.com> Pull Request resolved: https://github.com/pytorch/pytorch/pull/135429 Approved by: https://github.com/isuruf
This commit is contained in:
committed by
PyTorch MergeBot
parent
6ecb73bafd
commit
1d6e0412f5
@ -10002,6 +10002,9 @@ ShapeEnv not equal: field values don't match:
|
||||
"""\
|
||||
ShapeEnv not equal: field values don't match:
|
||||
|
||||
==> axioms: values don't match.
|
||||
> Left: {0 < Mod(s0, 3): False, Eq(0, Mod(s0, 3)): True, Eq(Mod(s0, 3), 0): True, False: False, Mod(s0, 3) <= 0: True, Ne(0, Mod(s0, 3)): False, Ne(Mod(s0, 3), 0): False, True: True}
|
||||
> Right: {}
|
||||
==> divisible: values don't match.
|
||||
> Left: {Mod(s0, 3)}
|
||||
> Right: {}
|
||||
@ -10039,6 +10042,9 @@ ShapeEnv not equal: field values don't match:
|
||||
"""\
|
||||
ShapeEnv not equal: field values don't match:
|
||||
|
||||
==> axioms: values don't match.
|
||||
> Left: {False: False, True: True}
|
||||
> Right: {}
|
||||
==> guards: values don't match.
|
||||
> Left: [Eq(s0, 3)]
|
||||
> Right: []
|
||||
@ -10080,6 +10086,9 @@ ShapeEnv not equal: field values don't match:
|
||||
"""\
|
||||
ShapeEnv not equal: field values don't match:
|
||||
|
||||
==> axioms: values don't match.
|
||||
> Left: {3 <= s0: True, s0 < 3: False}
|
||||
> Right: {}
|
||||
==> guards: values don't match.
|
||||
> Left: [s0 >= 3]
|
||||
> Right: []
|
||||
@ -10112,6 +10121,9 @@ ShapeEnv not equal: field values don't match:
|
||||
"""\
|
||||
ShapeEnv not equal: field values don't match:
|
||||
|
||||
==> axioms: values don't match.
|
||||
> Left: {0 < PythonMod(u0, 3): False, Eq(0, PythonMod(u0, 3)): True, Eq(PythonMod(u0, 3), 0): True, False: False, Ne(0, PythonMod(u0, 3)): False, Ne(PythonMod(u0, 3), 0): False, PythonMod(u0, 3) <= 0: True, True: True}
|
||||
> Right: {}
|
||||
==> deferred_runtime_asserts: values don't match.
|
||||
> Left: {u0: [Eq(PythonMod(u0, 3), 0)]}
|
||||
> Right: {}
|
||||
|
@ -2637,6 +2637,7 @@ class ShapeEnv:
|
||||
)
|
||||
|
||||
self.guards: List[ShapeGuard] = []
|
||||
self.axioms: Dict[sympy.Expr, sympy.Expr] = {}
|
||||
# Maps symbolic ints to their original concrete values
|
||||
# Currently populated from tensors
|
||||
self.var_to_val: Dict[sympy.Symbol, sympy.Integer] = {}
|
||||
@ -4658,15 +4659,16 @@ class ShapeEnv:
|
||||
expr = canonicalize_bool_expr(expr)
|
||||
|
||||
# Pattern matching
|
||||
symbols = tuple(expr.free_symbols)
|
||||
if axioms is None:
|
||||
axioms = self.get_axioms(symbols, compute_hint=compute_hint)
|
||||
subst = {}
|
||||
for e in axioms:
|
||||
if e.free_symbols.issubset(expr.free_symbols):
|
||||
subst.update(dict(self.get_implications(self.simplify(e))))
|
||||
subst = self.axioms
|
||||
else:
|
||||
subst = {}
|
||||
for e in axioms:
|
||||
if e.free_symbols.issubset(expr.free_symbols):
|
||||
subst.update(dict(self.get_implications(self.simplify(e))))
|
||||
|
||||
expr = expr.xreplace(subst)
|
||||
# TODO: compute hint might have gotten broken here
|
||||
|
||||
fs = expr.free_symbols
|
||||
|
||||
@ -5421,6 +5423,7 @@ class ShapeEnv:
|
||||
stack = CapturedTraceback.extract(skip=1)
|
||||
guard = ShapeGuard(g, stack)
|
||||
self.guards.append(guard)
|
||||
self.axioms.update(dict(self.get_implications(self.simplify(g))))
|
||||
else:
|
||||
# it's fine to defer simple guards here without checking,
|
||||
# the _maybe_guard_rel() call above will set replacements if possible,
|
||||
@ -5532,6 +5535,7 @@ class ShapeEnv:
|
||||
# and the guard in question has no unbacked SymInts in front
|
||||
ix = cands[-1] if cands else None
|
||||
self.deferred_runtime_asserts.setdefault(ix, []).append(ra)
|
||||
self.axioms.update(dict(self.get_implications(self.simplify(expr))))
|
||||
self.num_deferred_runtime_asserts += 1
|
||||
self._update_version_counter()
|
||||
self._log_guard("runtime_assert", orig_expr, forcing_spec=False)
|
||||
|
Reference in New Issue
Block a user