From 1d6e0412f5205b1cd709e034526d7f21d6f2d56f Mon Sep 17 00:00:00 2001 From: "Edward Z. Yang" Date: Fri, 27 Sep 2024 07:45:45 -0700 Subject: [PATCH] Don't uselessly recompute axiom dict every static eval call (#135429) Signed-off-by: Edward Z. Yang Pull Request resolved: https://github.com/pytorch/pytorch/pull/135429 Approved by: https://github.com/isuruf --- test/dynamo/test_misc.py | 12 ++++++++++++ torch/fx/experimental/symbolic_shapes.py | 16 ++++++++++------ 2 files changed, 22 insertions(+), 6 deletions(-) diff --git a/test/dynamo/test_misc.py b/test/dynamo/test_misc.py index 7cc8da7b01dd..f0374be7a6e6 100644 --- a/test/dynamo/test_misc.py +++ b/test/dynamo/test_misc.py @@ -10002,6 +10002,9 @@ ShapeEnv not equal: field values don't match: """\ ShapeEnv not equal: field values don't match: +==> axioms: values don't match. + > Left: {0 < Mod(s0, 3): False, Eq(0, Mod(s0, 3)): True, Eq(Mod(s0, 3), 0): True, False: False, Mod(s0, 3) <= 0: True, Ne(0, Mod(s0, 3)): False, Ne(Mod(s0, 3), 0): False, True: True} + > Right: {} ==> divisible: values don't match. > Left: {Mod(s0, 3)} > Right: {} @@ -10039,6 +10042,9 @@ ShapeEnv not equal: field values don't match: """\ ShapeEnv not equal: field values don't match: +==> axioms: values don't match. + > Left: {False: False, True: True} + > Right: {} ==> guards: values don't match. > Left: [Eq(s0, 3)] > Right: [] @@ -10080,6 +10086,9 @@ ShapeEnv not equal: field values don't match: """\ ShapeEnv not equal: field values don't match: +==> axioms: values don't match. + > Left: {3 <= s0: True, s0 < 3: False} + > Right: {} ==> guards: values don't match. > Left: [s0 >= 3] > Right: [] @@ -10112,6 +10121,9 @@ ShapeEnv not equal: field values don't match: """\ ShapeEnv not equal: field values don't match: +==> axioms: values don't match. + > Left: {0 < PythonMod(u0, 3): False, Eq(0, PythonMod(u0, 3)): True, Eq(PythonMod(u0, 3), 0): True, False: False, Ne(0, PythonMod(u0, 3)): False, Ne(PythonMod(u0, 3), 0): False, PythonMod(u0, 3) <= 0: True, True: True} + > Right: {} ==> deferred_runtime_asserts: values don't match. > Left: {u0: [Eq(PythonMod(u0, 3), 0)]} > Right: {} diff --git a/torch/fx/experimental/symbolic_shapes.py b/torch/fx/experimental/symbolic_shapes.py index 1e80e6ec45a1..5ae10bc83842 100644 --- a/torch/fx/experimental/symbolic_shapes.py +++ b/torch/fx/experimental/symbolic_shapes.py @@ -2637,6 +2637,7 @@ class ShapeEnv: ) self.guards: List[ShapeGuard] = [] + self.axioms: Dict[sympy.Expr, sympy.Expr] = {} # Maps symbolic ints to their original concrete values # Currently populated from tensors self.var_to_val: Dict[sympy.Symbol, sympy.Integer] = {} @@ -4658,15 +4659,16 @@ class ShapeEnv: expr = canonicalize_bool_expr(expr) # Pattern matching - symbols = tuple(expr.free_symbols) if axioms is None: - axioms = self.get_axioms(symbols, compute_hint=compute_hint) - subst = {} - for e in axioms: - if e.free_symbols.issubset(expr.free_symbols): - subst.update(dict(self.get_implications(self.simplify(e)))) + subst = self.axioms + else: + subst = {} + for e in axioms: + if e.free_symbols.issubset(expr.free_symbols): + subst.update(dict(self.get_implications(self.simplify(e)))) expr = expr.xreplace(subst) + # TODO: compute hint might have gotten broken here fs = expr.free_symbols @@ -5421,6 +5423,7 @@ class ShapeEnv: stack = CapturedTraceback.extract(skip=1) guard = ShapeGuard(g, stack) self.guards.append(guard) + self.axioms.update(dict(self.get_implications(self.simplify(g)))) else: # it's fine to defer simple guards here without checking, # the _maybe_guard_rel() call above will set replacements if possible, @@ -5532,6 +5535,7 @@ class ShapeEnv: # and the guard in question has no unbacked SymInts in front ix = cands[-1] if cands else None self.deferred_runtime_asserts.setdefault(ix, []).append(ra) + self.axioms.update(dict(self.get_implications(self.simplify(expr)))) self.num_deferred_runtime_asserts += 1 self._update_version_counter() self._log_guard("runtime_assert", orig_expr, forcing_spec=False)