Add max kwarg to torch._check_is_size with alternate size oblivious semantics (#144471)

Fixes https://github.com/pytorch/pytorch/issues/120288 for the static bound case

I had been tying myself in knots in the original issue over the fact that we can't really do symbolic bounds like u0 < s0. But then I realized, "Wait, but the static bounds are easy!" So this change also lets you exclude a specific static upper bound when doing size-oblivious tests, which is enough to solve https://github.com/pytorch/pytorch/issues/123592#issuecomment-2574556708
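For a concrete picture of what this enables, here is a minimal usage sketch (mine, not from the PR; the bound 16 and the shapes are made up):

```python
import torch


@torch.compile(fullgraph=True)
def f(x):
    u0 = x.item()  # unbacked SymInt: the value is data-dependent
    # Besides u0 >= 0, record the static upper bound u0 <= 16.  With this
    # change, size-oblivious tests additionally exclude the bound itself,
    # mirroring how 0 and 1 are excluded at the lower end.
    torch._check_is_size(u0, max=16)
    return x.new_zeros(u0)


f(torch.tensor([4]))
```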

The implementation is pretty dirty; there is probably some cleanup to do. Bikeshedding on the public API name is also welcome.

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144471
Approved by: https://github.com/avikchaudhuri
commit ffb3f32693 (parent 95b41d2aa4)
Author: Edward Z. Yang
Date:   2025-01-13 19:38:52 -08:00
Committed by: PyTorch MergeBot

3 changed files with 67 additions and 7 deletions


@@ -1637,11 +1637,17 @@ def _check(cond, message=None):  # noqa: F811
     _check_with(RuntimeError, cond, message)
 
 
-def _check_is_size(i, message=None):
+def _check_is_size(i, message=None, *, max=None):
     """Checks that a given integer is a valid size (i.e., is non-negative).
-    You should use this over _check(i >= 0) because we can use the semantic
-    information (that i is a size) to make some further inferences in case
-    i is an unbacked SymInt.
+    You should use this over ``_check(i >= 0)`` because it can prevent
+    ``GuardOnDataDependentSymNode`` exceptions by opting yourself into alternate
+    semantics for ``guard_size_oblivious`` tests that treat values 0 and 1
+    equivalently to all other values.
+
+    When max is not None, this specifies an upper bound equivalent to
+    ``_check(i <= max)``. This bound is also subject to alternate semantics:
+    in ``guard_size_oblivious`` tests, we assume that the max bound is treated
+    equivalently to all other values.
 
     NB: Do NOT use this in contexts where a -1 size would be valid (indicating
     to infer the size from context, or if you should wrap-around or truncate).
@@ -1653,6 +1659,13 @@ def _check_is_size(i, message=None):
     _advise_is_size(i)
+    if max is not None:
+        _check(i <= max, message)
+
+        from torch.fx.experimental.symbolic_shapes import _advise_is_bounded
+
+        _advise_is_bounded(i, max)
+
 
 
 def _check_index(cond, message=None):  # noqa: F811
     r"""Throws error containing an optional message if the specified condition