mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Add max kwarg to torch._check with alternate size oblivious semantics (#144471)
Fixes https://github.com/pytorch/pytorch/issues/120288 for the static bound case I had been tying myself in knots in the original issue about the fact that we can't really do symbolic bounds like u0 < s0. But then I realized, "Wait, but the static bounds are easy!" So this makes it so you can also exclude a specific upper bound when doing size oblivious tests, which is enough to solve https://github.com/pytorch/pytorch/issues/123592#issuecomment-2574556708 It's written very dirtily, maybe there's some cleanup. Bikeshed on the public API name also welcome. Signed-off-by: Edward Z. Yang <ezyang@meta.com> Pull Request resolved: https://github.com/pytorch/pytorch/pull/144471 Approved by: https://github.com/avikchaudhuri
This commit is contained in:
committed by
PyTorch MergeBot
parent
95b41d2aa4
commit
ffb3f32693
@ -297,6 +297,18 @@ class TestDynamismExpression(TestCase):
|
||||
dynamic_shapes=dynamic_shapes,
|
||||
)
|
||||
|
||||
def test_export_slice_unbacked_dim1(self):
|
||||
class MySlice(torch.nn.Module):
|
||||
def forward(self, x, seq_len):
|
||||
l = seq_len.item()
|
||||
torch._check_is_size(l, max=x.size(1))
|
||||
x = x.narrow(1, 0, l)
|
||||
return x
|
||||
|
||||
x = torch.randn(10, 7)
|
||||
seq_len = torch.tensor(5)
|
||||
torch.export.export(MySlice(), args=(x, seq_len))
|
||||
|
||||
def test_export_constraints_error(self):
|
||||
class ConflictingConstraints(torch.nn.Module):
|
||||
def forward(self, x):
|
||||
|
@ -1637,11 +1637,17 @@ def _check(cond, message=None): # noqa: F811
|
||||
_check_with(RuntimeError, cond, message)
|
||||
|
||||
|
||||
def _check_is_size(i, message=None):
|
||||
def _check_is_size(i, message=None, *, max=None):
|
||||
"""Checks that a given integer is a valid size (i.e., is non-negative).
|
||||
You should use this over _check(i >= 0) because we can use the semantic
|
||||
information (that i is a size) to make some further inferences in case
|
||||
i is an unbacked SymInt.
|
||||
You should use this over ``_check(i >= 0)`` because it can prevent
|
||||
``GuardOnDataDependentSymNode`` exceptions by opting yourself into alternate
|
||||
semantics for ``guard_size_oblivious`` tests that treat values 0 and 1
|
||||
equivalently to all other values.
|
||||
|
||||
When max is not None, this specifies an upper bound equivalent to
|
||||
``_check(i <= max)``. This bound is also subject to alternate semantics:
|
||||
in ``guard_size_oblivious`` tests, we assume that the max bound is treated
|
||||
equivalently to all other values.
|
||||
|
||||
NB: Do NOT use this in contexts where a -1 size would be valid (indicating
|
||||
to infer the size from context, or if you should wrap-around or truncate).
|
||||
@ -1653,6 +1659,13 @@ def _check_is_size(i, message=None):
|
||||
|
||||
_advise_is_size(i)
|
||||
|
||||
if max is not None:
|
||||
_check(i <= max, message)
|
||||
|
||||
from torch.fx.experimental.symbolic_shapes import _advise_is_bounded
|
||||
|
||||
_advise_is_bounded(i, max)
|
||||
|
||||
|
||||
def _check_index(cond, message=None): # noqa: F811
|
||||
r"""Throws error containing an optional message if the specified condition
|
||||
|
@ -1294,6 +1294,17 @@ def _advise_is_size(a: SymInt) -> None:
|
||||
_constrain_range_for_size(a)
|
||||
|
||||
|
||||
def _advise_is_bounded(a: SymInt, upper_bound: Union[int, SymInt]) -> None:
|
||||
if (
|
||||
isinstance(a, SymInt)
|
||||
and isinstance(a.node, SymNode)
|
||||
and isinstance(a.node.expr, sympy.Symbol)
|
||||
and a.node.shape_env.is_unbacked_symint(a.node.expr)
|
||||
and isinstance(upper_bound, int) # TODO: relax
|
||||
):
|
||||
a.node.shape_env._constrain_is_bounded(a.node.expr, upper_bound)
|
||||
|
||||
|
||||
def _constrain_range_for_size(
|
||||
a: SymInt, min: Optional[int] = None, max: Optional[int] = None
|
||||
) -> None:
|
||||
@ -1940,7 +1951,9 @@ def safe_expand(r: _SympyT) -> _SympyT:
|
||||
@lru_cache(None)
|
||||
def _maybe_evaluate_static_worker(
|
||||
expr: _SympyT,
|
||||
symbol_info: Tuple[Tuple[sympy.Symbol, ValueRanges, sympy.Integer, bool], ...],
|
||||
symbol_info: Tuple[
|
||||
Tuple[sympy.Symbol, ValueRanges, sympy.Integer, bool, sympy.Integer], ...
|
||||
],
|
||||
unbacked_only: bool,
|
||||
size_oblivious: bool,
|
||||
) -> Optional[_SympyT]:
|
||||
@ -1955,7 +1968,7 @@ def _maybe_evaluate_static_worker(
|
||||
new_shape_env = {}
|
||||
new_range_env = {}
|
||||
for idx, sinfo in enumerate(symbol_info):
|
||||
k, vr, val, is_size_like = sinfo
|
||||
k, vr, val, is_size_like, oblivious_upper_bound_exclusive = sinfo
|
||||
if isinstance(val, SingletonInt):
|
||||
# Skip var_ranges logic for SingletonInt which is only used
|
||||
# for jagged layout NestedTensors today
|
||||
@ -1969,6 +1982,8 @@ def _maybe_evaluate_static_worker(
|
||||
# This is similar to the flavor where size oblivious omits
|
||||
# 0/1, it changes semantics but in a benign way.
|
||||
upper = min(2**48, vr.upper)
|
||||
if oblivious_upper_bound_exclusive is not None:
|
||||
upper = min(upper, oblivious_upper_bound_exclusive - 1)
|
||||
# This is a bit dodgy: what this means is that there was a
|
||||
# size-like unbacked symbol whose upper bound < 2. This
|
||||
# causes... problems.
|
||||
@ -3159,6 +3174,13 @@ class ShapeEnv:
|
||||
# practice
|
||||
self.var_to_range: Dict[sympy.Symbol, ValueRanges] = {}
|
||||
self.var_to_range_sloc: Dict[sympy.Symbol, ValueRangesSLoc] = {}
|
||||
# When doing a size-oblivious test, exclude this integer and
|
||||
# everything higher than it from the acceptable range. This solves
|
||||
# https://github.com/pytorch/pytorch/issues/120288 for constant range
|
||||
# case
|
||||
# TODO: generalize this to work with expressions (in that case, we
|
||||
# need to maintain a SET and we need extra symbolic reasoning on top)
|
||||
self.oblivious_upper_bound_exclusive: Dict[sympy.Symbol, sympy.Integer] = {}
|
||||
self.source_name_to_debug_name: Dict[str, str] = {}
|
||||
self.var_to_sources: Dict[sympy.Symbol, List[Source]] = {}
|
||||
self.var_to_stack: Dict[sympy.Symbol, CapturedTraceback] = {}
|
||||
@ -3465,6 +3487,13 @@ class ShapeEnv:
|
||||
if dest is not None:
|
||||
self._set_replacement(new_s, dest, "rename_unbacked_to_dest")
|
||||
|
||||
@record_shapeenv_event()
|
||||
def _constrain_is_bounded(self, a: sympy.Symbol, upper_bound: int) -> None:
|
||||
self.oblivious_upper_bound_exclusive[a] = min(
|
||||
self.oblivious_upper_bound_exclusive.get(a, int_oo),
|
||||
sympy.Integer(upper_bound),
|
||||
)
|
||||
|
||||
@record_shapeenv_event()
|
||||
def _constrain_range_for_size(
|
||||
self, a: sympy.Symbol, min: Optional[int] = None, max: Optional[int] = None
|
||||
@ -5592,7 +5621,13 @@ class ShapeEnv:
|
||||
var_ranges = dict(var_to_range)
|
||||
|
||||
symbol_info = tuple(
|
||||
(s, var_ranges.get(s), self.var_to_val.get(s), s in self.size_like)
|
||||
(
|
||||
s,
|
||||
var_ranges.get(s),
|
||||
self.var_to_val.get(s),
|
||||
s in self.size_like,
|
||||
self.oblivious_upper_bound_exclusive.get(s),
|
||||
)
|
||||
for s in sorted(fs, key=lambda s: str(s)) # TODO: speed up sort?
|
||||
)
|
||||
|
||||
|
Reference in New Issue
Block a user