Add max kwarg to torch._check with alternate size oblivious semantics (#144471)

Fixes https://github.com/pytorch/pytorch/issues/120288 for the static bound case

I had been tying myself in knots in the original issue about the fact that we can't really do symbolic bounds like u0 < s0. But then I realized, "Wait, but the static bounds are easy!" So this makes it so you can also exclude a specific upper bound when doing size oblivious tests, which is enough to solve https://github.com/pytorch/pytorch/issues/123592#issuecomment-2574556708

It's written very dirtily, maybe there's some cleanup. Bikeshed on the public API name also welcome.

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144471
Approved by: https://github.com/avikchaudhuri
This commit is contained in:
Edward Z. Yang
2025-01-13 19:38:52 -08:00
committed by PyTorch MergeBot
parent 95b41d2aa4
commit ffb3f32693
3 changed files with 67 additions and 7 deletions

View File

@ -297,6 +297,18 @@ class TestDynamismExpression(TestCase):
dynamic_shapes=dynamic_shapes,
)
def test_export_slice_unbacked_dim1(self):
class MySlice(torch.nn.Module):
def forward(self, x, seq_len):
l = seq_len.item()
torch._check_is_size(l, max=x.size(1))
x = x.narrow(1, 0, l)
return x
x = torch.randn(10, 7)
seq_len = torch.tensor(5)
torch.export.export(MySlice(), args=(x, seq_len))
def test_export_constraints_error(self):
class ConflictingConstraints(torch.nn.Module):
def forward(self, x):

View File

@ -1637,11 +1637,17 @@ def _check(cond, message=None): # noqa: F811
_check_with(RuntimeError, cond, message)
def _check_is_size(i, message=None):
def _check_is_size(i, message=None, *, max=None):
"""Checks that a given integer is a valid size (i.e., is non-negative).
You should use this over _check(i >= 0) because we can use the semantic
information (that i is a size) to make some further inferences in case
i is an unbacked SymInt.
You should use this over ``_check(i >= 0)`` because it can prevent
``GuardOnDataDependentSymNode`` exceptions by opting yourself into alternate
semantics for ``guard_size_oblivious`` tests that treat values 0 and 1
equivalently to all other values.
When max is not None, this specifies an upper bound equivalent to
``_check(i <= max)``. This bound is also subject to alternate semantics:
in ``guard_size_oblivious`` tests, we assume that the max bound is treated
equivalently to all other values.
NB: Do NOT use this in contexts where a -1 size would be valid (indicating
to infer the size from context, or if you should wrap-around or truncate).
@ -1653,6 +1659,13 @@ def _check_is_size(i, message=None):
_advise_is_size(i)
if max is not None:
_check(i <= max, message)
from torch.fx.experimental.symbolic_shapes import _advise_is_bounded
_advise_is_bounded(i, max)
def _check_index(cond, message=None): # noqa: F811
r"""Throws error containing an optional message if the specified condition

View File

@ -1294,6 +1294,17 @@ def _advise_is_size(a: SymInt) -> None:
_constrain_range_for_size(a)
def _advise_is_bounded(a: SymInt, upper_bound: Union[int, SymInt]) -> None:
if (
isinstance(a, SymInt)
and isinstance(a.node, SymNode)
and isinstance(a.node.expr, sympy.Symbol)
and a.node.shape_env.is_unbacked_symint(a.node.expr)
and isinstance(upper_bound, int) # TODO: relax
):
a.node.shape_env._constrain_is_bounded(a.node.expr, upper_bound)
def _constrain_range_for_size(
a: SymInt, min: Optional[int] = None, max: Optional[int] = None
) -> None:
@ -1940,7 +1951,9 @@ def safe_expand(r: _SympyT) -> _SympyT:
@lru_cache(None)
def _maybe_evaluate_static_worker(
expr: _SympyT,
symbol_info: Tuple[Tuple[sympy.Symbol, ValueRanges, sympy.Integer, bool], ...],
symbol_info: Tuple[
Tuple[sympy.Symbol, ValueRanges, sympy.Integer, bool, sympy.Integer], ...
],
unbacked_only: bool,
size_oblivious: bool,
) -> Optional[_SympyT]:
@ -1955,7 +1968,7 @@ def _maybe_evaluate_static_worker(
new_shape_env = {}
new_range_env = {}
for idx, sinfo in enumerate(symbol_info):
k, vr, val, is_size_like = sinfo
k, vr, val, is_size_like, oblivious_upper_bound_exclusive = sinfo
if isinstance(val, SingletonInt):
# Skip var_ranges logic for SingletonInt which is only used
# for jagged layout NestedTensors today
@ -1969,6 +1982,8 @@ def _maybe_evaluate_static_worker(
# This is similar to the flavor where size oblivious omits
# 0/1, it changes semantics but in a benign way.
upper = min(2**48, vr.upper)
if oblivious_upper_bound_exclusive is not None:
upper = min(upper, oblivious_upper_bound_exclusive - 1)
# This is a bit dodgy: what this means is that there was a
# size-like unbacked symbol whose upper bound < 2. This
# causes... problems.
@ -3159,6 +3174,13 @@ class ShapeEnv:
# practice
self.var_to_range: Dict[sympy.Symbol, ValueRanges] = {}
self.var_to_range_sloc: Dict[sympy.Symbol, ValueRangesSLoc] = {}
# When doing a size-oblivious test, exclude this integer and
# everything higher than it from the acceptable range. This solves
# https://github.com/pytorch/pytorch/issues/120288 for constant range
# case
# TODO: generalize this to work with expressions (in that case, we
# need to maintain a SET and we need extra symbolic reasoning on top)
self.oblivious_upper_bound_exclusive: Dict[sympy.Symbol, sympy.Integer] = {}
self.source_name_to_debug_name: Dict[str, str] = {}
self.var_to_sources: Dict[sympy.Symbol, List[Source]] = {}
self.var_to_stack: Dict[sympy.Symbol, CapturedTraceback] = {}
@ -3465,6 +3487,13 @@ class ShapeEnv:
if dest is not None:
self._set_replacement(new_s, dest, "rename_unbacked_to_dest")
@record_shapeenv_event()
def _constrain_is_bounded(self, a: sympy.Symbol, upper_bound: int) -> None:
self.oblivious_upper_bound_exclusive[a] = min(
self.oblivious_upper_bound_exclusive.get(a, int_oo),
sympy.Integer(upper_bound),
)
@record_shapeenv_event()
def _constrain_range_for_size(
self, a: sympy.Symbol, min: Optional[int] = None, max: Optional[int] = None
@ -5592,7 +5621,13 @@ class ShapeEnv:
var_ranges = dict(var_to_range)
symbol_info = tuple(
(s, var_ranges.get(s), self.var_to_val.get(s), s in self.size_like)
(
s,
var_ranges.get(s),
self.var_to_val.get(s),
s in self.size_like,
self.oblivious_upper_bound_exclusive.get(s),
)
for s in sorted(fs, key=lambda s: str(s)) # TODO: speed up sort?
)