mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Add max kwarg to torch._check with alternate size oblivious semantics (#144471)
Fixes https://github.com/pytorch/pytorch/issues/120288 for the static bound case I had been tying myself in knots in the original issue about the fact that we can't really do symbolic bounds like u0 < s0. But then I realized, "Wait, but the static bounds are easy!" So this makes it so you can also exclude a specific upper bound when doing size oblivious tests, which is enough to solve https://github.com/pytorch/pytorch/issues/123592#issuecomment-2574556708 It's written very dirtily, maybe there's some cleanup. Bikeshed on the public API name also welcome. Signed-off-by: Edward Z. Yang <ezyang@meta.com> Pull Request resolved: https://github.com/pytorch/pytorch/pull/144471 Approved by: https://github.com/avikchaudhuri
This commit is contained in:
committed by
PyTorch MergeBot
parent
95b41d2aa4
commit
ffb3f32693
@ -1637,11 +1637,17 @@ def _check(cond, message=None): # noqa: F811
|
||||
_check_with(RuntimeError, cond, message)
|
||||
|
||||
|
||||
def _check_is_size(i, message=None):
|
||||
def _check_is_size(i, message=None, *, max=None):
|
||||
"""Checks that a given integer is a valid size (i.e., is non-negative).
|
||||
You should use this over _check(i >= 0) because we can use the semantic
|
||||
information (that i is a size) to make some further inferences in case
|
||||
i is an unbacked SymInt.
|
||||
You should use this over ``_check(i >= 0)`` because it can prevent
|
||||
``GuardOnDataDependentSymNode`` exceptions by opting yourself into alternate
|
||||
semantics for ``guard_size_oblivious`` tests that treat values 0 and 1
|
||||
equivalently to all other values.
|
||||
|
||||
When max is not None, this specifies an upper bound equivalent to
|
||||
``_check(i <= max)``. This bound is also subject to alternate semantics:
|
||||
in ``guard_size_oblivious`` tests, we assume that the max bound is treated
|
||||
equivalently to all other values.
|
||||
|
||||
NB: Do NOT use this in contexts where a -1 size would be valid (indicating
|
||||
to infer the size from context, or if you should wrap-around or truncate).
|
||||
@ -1653,6 +1659,13 @@ def _check_is_size(i, message=None):
|
||||
|
||||
_advise_is_size(i)
|
||||
|
||||
if max is not None:
|
||||
_check(i <= max, message)
|
||||
|
||||
from torch.fx.experimental.symbolic_shapes import _advise_is_bounded
|
||||
|
||||
_advise_is_bounded(i, max)
|
||||
|
||||
|
||||
def _check_index(cond, message=None): # noqa: F811
|
||||
r"""Throws error containing an optional message if the specified condition
|
||||
|
Reference in New Issue
Block a user