From 323fb4dad0effe80b36abe60bc47c6894e1fe8b2 Mon Sep 17 00:00:00 2001 From: "Edward Z. Yang" Date: Tue, 21 Jan 2025 07:39:46 -0800 Subject: [PATCH] Unconditionally exclude upper bound in all size oblivious tests (#144867) I was thinking about https://github.com/pytorch/pytorch/pull/144471 some more and I thought, "Hmm, why not just always exclude the constant upper bound." So here it is. Signed-off-by: Edward Z. Yang Pull Request resolved: https://github.com/pytorch/pytorch/pull/144867 Approved by: https://github.com/bobrenjc93 --- torch/__init__.py | 5 +++-- torch/fx/experimental/symbolic_shapes.py | 22 ++++++---------------- 2 files changed, 9 insertions(+), 18 deletions(-) diff --git a/torch/__init__.py b/torch/__init__.py index faedf38c6673..2708984c61ac 100644 --- a/torch/__init__.py +++ b/torch/__init__.py @@ -1664,8 +1664,9 @@ def _check_is_size(i, message=None, *, max=None): When max is not None, this specifies an upper bound equivalent to ``_check(i <= max)``. This bound is also subject to alternate semantics: - in ``guard_size_oblivious`` tests, we assume that the max bound is treated - equivalently to all other values. + in ``guard_size_oblivious`` tests, we assume that a constant max bound is + treated equivalently to all other values. Symbolic max bounds are not yet + supported. NB: Do NOT use this in contexts where a -1 size would be valid (indicating to infer the size from context, or if you should wrap-around or truncate). diff --git a/torch/fx/experimental/symbolic_shapes.py b/torch/fx/experimental/symbolic_shapes.py index 576dd1b90841..5469fdfe4675 100644 --- a/torch/fx/experimental/symbolic_shapes.py +++ b/torch/fx/experimental/symbolic_shapes.py @@ -1945,7 +1945,6 @@ class _SymbolInfo(NamedTuple): vr: Optional[ValueRanges] val: Optional[sympy.Integer] is_size_like: bool - oblivious_upper_bound_exclusive: sympy.Integer @lru_cache(None) @@ -1967,7 +1966,7 @@ def _maybe_evaluate_static_worker( new_shape_env = {} new_range_env = {} for idx, sinfo in enumerate(symbol_info): - k, vr, val, is_size_like, oblivious_upper_bound_exclusive = sinfo + k, vr, val, is_size_like = sinfo if isinstance(val, SingletonInt): # Skip var_ranges logic for SingletonInt which is only used # for jagged layout NestedTensors today @@ -1982,8 +1981,9 @@ def _maybe_evaluate_static_worker( # This is similar to the flavor where size oblivious omits # 0/1, it changes semantics but in a benign way. upper = min(2**48, vr.upper) - if oblivious_upper_bound_exclusive is not None: - upper = min(upper, oblivious_upper_bound_exclusive - 1) + # Excluding the very upper bound can be helpful + if upper > lower: + upper = upper - 1 # This is a bit dodgy: what this means is that there was a # size-like unbacked symbol whose upper bound < 2. This # causes... problems. @@ -3175,13 +3175,6 @@ class ShapeEnv: # practice self.var_to_range: dict[sympy.Symbol, ValueRanges] = {} self.var_to_range_sloc: dict[sympy.Symbol, ValueRangesSLoc] = {} - # When doing a size-oblivious test, exclude this integer and - # everything higher than it from the acceptable range. This solves - # https://github.com/pytorch/pytorch/issues/120288 for constant range - # case - # TODO: generalize this to work with expressions (in that case, we - # need to maintain a SET and we need extra symbolic reasoning on top) - self.oblivious_upper_bound_exclusive: dict[sympy.Symbol, sympy.Integer] = {} self.source_name_to_debug_name: dict[str, str] = {} self.var_to_sources: dict[sympy.Symbol, list[Source]] = {} self.var_to_stack: dict[sympy.Symbol, CapturedTraceback] = {} @@ -3490,10 +3483,8 @@ class ShapeEnv: @record_shapeenv_event() def _constrain_is_bounded(self, a: sympy.Symbol, upper_bound: int) -> None: - self.oblivious_upper_bound_exclusive[a] = min( - self.oblivious_upper_bound_exclusive.get(a, int_oo), - sympy.Integer(upper_bound), - ) + # TODO: Do something nontrivial when upper_bound is expression + pass @record_shapeenv_event() def _constrain_range_for_size( @@ -5627,7 +5618,6 @@ class ShapeEnv: var_ranges.get(s), self.var_to_val.get(s), s in self.size_like, - self.oblivious_upper_bound_exclusive.get(s), ) for s in sorted(fs, key=str) # TODO: speed up sort? )