Revert "[dynamic shapes] unbacked-safe slicing (#157944)"

This reverts commit 2f0cba934de7094a66c6ce68f5e937254f23142a.

Reverted https://github.com/pytorch/pytorch/pull/157944 on behalf of https://github.com/seemethere due to This is blocking internal sync due to merge conflicts ([comment](https://github.com/pytorch/pytorch/pull/157944#issuecomment-3206833193))
This commit is contained in:
PyTorch MergeBot
2025-08-20 15:16:45 +00:00
parent a818fa77e3
commit 6ea4be1e2e
11 changed files with 39 additions and 493 deletions

View File

@ -6,7 +6,6 @@ import numbers
import operator
import sys
from collections.abc import Iterable
from contextlib import nullcontext
from enum import Enum
from functools import partial, reduce
from itertools import chain, product
@ -722,7 +721,10 @@ def slice_forward(
end: Optional[int] = None,
step: int = 1,
):
from torch.fx.experimental.symbolic_shapes import statically_known_true
from torch.fx.experimental.symbolic_shapes import (
guard_size_oblivious,
statically_known_true,
)
ndim = self.dim()
if ndim == 0:
@ -737,22 +739,22 @@ def slice_forward(
start_val = start if start is not None else 0
end_val = end if end is not None else sys.maxsize # 2^63 - 1
if start_val < 0:
if guard_size_oblivious(start_val < 0):
start_val += sizes[dim]
if end_val < 0:
if guard_size_oblivious(end_val < 0):
end_val += sizes[dim]
if start_val < 0:
if guard_size_oblivious(start_val < 0):
start_val = 0
elif start_val > sizes[dim]:
elif guard_size_oblivious(start_val > sizes[dim]):
start_val = sizes[dim]
if statically_known_true(end_val == sys.maxsize):
end_val = sizes[dim]
elif end_val < start_val:
elif guard_size_oblivious(end_val < start_val):
end_val = start_val
elif end_val > sizes[dim]:
elif guard_size_oblivious(end_val > sizes[dim]):
end_val = sizes[dim]
storage_offset = self.storage_offset() + start_val * strides[dim]
@ -1436,17 +1438,7 @@ def tensor_split_tensor_indices_or_sections_py_impl(
assert isinstance(sections, IntLike)
return self.tensor_split(sections, dim)
else:
ctx = nullcontext
if (fake_mode := torch._guards.detect_fake_mode()) and (
shape_env := fake_mode.shape_env
):
ctx = shape_env.ignore_fresh_unbacked_symbols # type: ignore[assignment]
# In fake tensor prop, we end up calling slice() with these unbacked indices.
# Because slice has flexible semantics, the unbacked handling generates new output sizes
# for each slice, effectively clobbering over these index symbols.
# To avoid PendingUnbackedSymbolNotFound errors, we tell the compiler it's fine to not bind these.
with ctx():
indices = [i.item() for i in tensor_indices_or_sections]
indices = [i.item() for i in tensor_indices_or_sections]
# WARNING: Tempted to torch._check_is_size on the indices here? You
# can't: tensor_split works with negative values in indices:
#