Mirror of https://github.com/pytorch/pytorch.git
Synced 2025-10-20 12:54:11 +08:00
Revert "[dynamic shapes] unbacked-safe slicing (#157944)"
This reverts commit 2f0cba934de7094a66c6ce68f5e937254f23142a. Reverted https://github.com/pytorch/pytorch/pull/157944 on behalf of https://github.com/seemethere due to This is blocking internal sync due to merge conflicts ([comment](https://github.com/pytorch/pytorch/pull/157944#issuecomment-3206833193))
@@ -6,7 +6,6 @@ import numbers
 import operator
 import sys
 from collections.abc import Iterable
-from contextlib import nullcontext
 from enum import Enum
 from functools import partial, reduce
 from itertools import chain, product
@@ -722,7 +721,10 @@ def slice_forward(
     end: Optional[int] = None,
     step: int = 1,
 ):
-    from torch.fx.experimental.symbolic_shapes import statically_known_true
+    from torch.fx.experimental.symbolic_shapes import (
+        guard_size_oblivious,
+        statically_known_true,
+    )
 
     ndim = self.dim()
     if ndim == 0:
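The restored import pulls guard_size_oblivious back in alongside statically_known_true. A minimal sketch (not part of the diff) of how the two helpers behave on plain Python bools, the only case exercised by concrete eager-mode slicing; the values are illustrative only:

from torch.fx.experimental.symbolic_shapes import (
    guard_size_oblivious,
    statically_known_true,
)

# On concrete (non-symbolic) inputs both helpers simply evaluate the condition.
# The difference only matters for SymInts from a ShapeEnv: guard_size_oblivious
# may install a size-oblivious guard, while statically_known_true never guards
# and returns False when the result cannot be proven.
start_val = -3
print(guard_size_oblivious(start_val < 0))   # True
print(statically_known_true(start_val < 0))  # True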
@@ -737,22 +739,22 @@ def slice_forward(
     start_val = start if start is not None else 0
     end_val = end if end is not None else sys.maxsize  # 2^63 - 1
 
-    if start_val < 0:
+    if guard_size_oblivious(start_val < 0):
         start_val += sizes[dim]
 
-    if end_val < 0:
+    if guard_size_oblivious(end_val < 0):
         end_val += sizes[dim]
 
-    if start_val < 0:
+    if guard_size_oblivious(start_val < 0):
         start_val = 0
-    elif start_val > sizes[dim]:
+    elif guard_size_oblivious(start_val > sizes[dim]):
         start_val = sizes[dim]
 
     if statically_known_true(end_val == sys.maxsize):
         end_val = sizes[dim]
-    elif end_val < start_val:
+    elif guard_size_oblivious(end_val < start_val):
         end_val = start_val
-    elif end_val > sizes[dim]:
+    elif guard_size_oblivious(end_val > sizes[dim]):
         end_val = sizes[dim]
 
     storage_offset = self.storage_offset() + start_val * strides[dim]
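The guard_size_oblivious branches restored above implement the usual clamping rules for slice bounds. A minimal sketch (not part of the diff) of the observable behavior these branches reproduce for concrete indices; the tensor is illustrative only:

import torch

x = torch.arange(5)
assert torch.equal(x[1:100], x[1:5])   # end past the dimension size is clamped to the size
assert x[4:2].numel() == 0             # end before start yields an empty slice
assert torch.equal(x[-2:], x[3:5])     # negative start wraps around by the dimension size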
@ -1436,17 +1438,7 @@ def tensor_split_tensor_indices_or_sections_py_impl(
|
||||
assert isinstance(sections, IntLike)
|
||||
return self.tensor_split(sections, dim)
|
||||
else:
|
||||
ctx = nullcontext
|
||||
if (fake_mode := torch._guards.detect_fake_mode()) and (
|
||||
shape_env := fake_mode.shape_env
|
||||
):
|
||||
ctx = shape_env.ignore_fresh_unbacked_symbols # type: ignore[assignment]
|
||||
# In fake tensor prop, we end up calling slice() with these unbacked indices.
|
||||
# Because slice has flexible semantics, the unbacked handling generates new output sizes
|
||||
# for each slice, effectively clobbering over these index symbols.
|
||||
# To avoid PendingUnbackedSymbolNotFound errors, we tell the compiler it's fine to not bind these.
|
||||
with ctx():
|
||||
indices = [i.item() for i in tensor_indices_or_sections]
|
||||
indices = [i.item() for i in tensor_indices_or_sections]
|
||||
# WARNING: Tempted to torch._check_is_size on the indices here? You
|
||||
# can't: tensor_split works with negative values in indices:
|
||||
#
|
||||
|
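The WARNING kept in the context above is why the extracted indices cannot be asserted to be sizes: tensor_split accepts negative entries in a 1-D index tensor, and they behave like negative slice bounds. A minimal sketch (not part of the diff), with illustrative values:

import torch

x = torch.arange(8)
parts = torch.tensor_split(x, torch.tensor([2, -2]))
# splits as x[:2], x[2:-2], x[-2:]
assert [p.tolist() for p in parts] == [[0, 1], [2, 3, 4, 5], [6, 7]]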