[dynamic shapes] unbacked-safe slicing (#157944)

Generates new unbacked symbols for the slice output size and storage offset when the appropriate slicing semantics cannot be determined at trace time. Teaches inductor to codegen the slice with flexible semantics.
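As a rough illustration, here is a minimal sketch of the behavior this enables, adapted from the new tests added in this PR (the compile settings and tensor values mirror the test code; treat it as a sketch rather than the exact test):

```python
import torch

# Slice with data-dependent start/end read out of a tensor ("unbacked" symints).
# In the general case there is no known relationship between u0, u1 and x.size(0),
# so the slice's output size and storage offset become fresh unbacked symbols and
# the generated code resolves negative / out-of-bounds indices at runtime with
# standard Python slicing semantics.
def f(x, xs):
    u0, u1 = xs.tolist()
    return x[u0:u1]

torch._dynamo.config.capture_scalar_outputs = True
fn = torch.compile(f, fullgraph=True, backend="inductor")

x = torch.randn(10)
for xs in (
    torch.tensor([3, 6]),           # standard slice
    torch.tensor([-9, -1]),         # negative indices
    torch.tensor([-10000, 10000]),  # out-of-bounds indices, clamped
):
    assert torch.allclose(fn(x, xs), f(x, xs))
```

When the relationship between the indices and the input size is known (for example via torch._check / torch._check_is_size), the lowering can still compute the output size directly instead of allocating a new symbol.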

Pull Request resolved: https://github.com/pytorch/pytorch/pull/157944
Approved by: https://github.com/laithsakka
Pian Pawakapan
2025-08-18 22:38:12 +00:00
committed by PyTorch MergeBot
parent 0254646654
commit 56218d85e2
10 changed files with 488 additions and 35 deletions


@ -3028,6 +3028,32 @@ def forward(self, causal_mask, fill_value):
},
)
def test_unbacked_slice_forward(self):
class Foo(torch.nn.Module):
def forward(self, x, xs):
u0, u1 = xs.tolist()
out = x[u0:u1]
return out
x = torch.randn(10)
idxs = torch.tensor([3, 6])
mod = Foo()
ep = export(mod, (x, idxs))
for xs in [
idxs,
torch.tensor([-9, -1]),
torch.tensor([-10000, 10000]),
torch.tensor([0, -10]),
]:
self.assertTrue(torch.allclose(ep.module()(x, xs), mod(x, xs)))
# check unbacked bindings
# should be 4 symbols: u0, u1, output size, output storage offset
bound_unbacked = set()
for node in ep.graph.nodes:
bound_unbacked |= node.meta.get("unbacked_bindings", {}).keys()
self.assertEqual(len(bound_unbacked), 4)
def test_dim_hint_ranges(self):
class Foo(torch.nn.Module):
def forward(self, x, y):
@ -5704,7 +5730,7 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x):
}
self._test_export_same_as_eager(kw_func, args, kwargs)
-def test_unbacked_slice(self):
+def test_unbacked_slice_simple(self):
class M(torch.nn.Module):
def forward(self, scores, score_thr, topk: torch.Tensor, results=None):
valid_mask = scores > score_thr


@ -3391,6 +3391,119 @@ def forward(self, arg0_1: "i64[2][1]cpu", arg1_1: "Sym(u2)", arg2_1: "Sym(u3)",
self.assertEqual(result_compiled, result_eager)
self.assertEqual(cnt.frame_count, 2)
@fresh_cache()
@torch._dynamo.config.patch("capture_scalar_outputs", True)
def test_unbacked_slice(self):
from torch.fx.experimental.symbolic_shapes import statically_known_true
# standard slice
def f1(x, xs):
u0, u1 = xs.tolist()
torch._check_is_size(u0, max=x.size(0))
torch._check_is_size(u1, max=x.size(0))
torch._check(u0 <= u1)
out = x[u0:u1]
assert statically_known_true(out.size(0) == (u1 - u0))
return out
x, xs = torch.randn(10), torch.tensor([3, 6])
fn1 = torch.compile(f1, fullgraph=True, backend="inductor")
self.assertEqual(fn1(x, xs).size(0), 3)
self.assertTrue(torch.allclose(fn1(x, xs), f1(x, xs)))
with self.assertRaises(RuntimeError):
fn1(x, torch.tensor([-1, 5]))
# known negative slice
def f2(x, n):
u0 = n.item()
torch._check(u0 > 1)
torch._check(u0 <= x.size(0))
out = x[-u0:]
assert statically_known_true(out.size(0) == u0)
return out
x, n = torch.randn(10), torch.tensor([5])
fn2 = torch.compile(f2, fullgraph=True, backend="inductor")
self.assertEqual(fn2(x, n).size(0), 5)
self.assertTrue(torch.allclose(fn2(x, n), f2(x, n)))
with self.assertRaises(RuntimeError):
fn2(x, torch.tensor([-5]))
# general case: no known info
def f3(x, xs):
u0, u1 = xs.tolist()
return x[u0:u1]
log_stream, ctx = logs_to_string(
"torch._inductor.compile_fx", "post_grad_graphs"
)
cnts = CompileCounterWithBackend("inductor")
x, xs = torch.randn(10), torch.tensor([3, 6])
with ctx():
fn3 = torch.compile(f3, fullgraph=True, backend=cnts)
xs = torch.tensor([-9, -1]) # negative case
self.assertTrue(torch.allclose(fn3(x, xs), f3(x, xs)))
xs = torch.tensor([-1000, 1000]) # out of bounds
self.assertTrue(torch.allclose(fn3(x, xs), f3(x, xs)))
xs = torch.tensor([2, -2]) # mixed
self.assertTrue(torch.allclose(fn3(x, xs), f3(x, xs)))
self.assertEqual(cnts.frame_count, 1)
aot_graphs = "\n".join(log_stream.getvalue().strip().split("\n")[4:]).strip()
self.assertExpectedInline(
aot_graphs,
"""\
select: "i64[][]cpu" = torch.ops.aten.select.int(arg0_1, 0, 0)
_local_scalar_dense: "Sym(u0)" = torch.ops.aten._local_scalar_dense.default(select); select = None
select_1: "i64[][]cpu" = torch.ops.aten.select.int(arg0_1, 0, 1); arg0_1 = None
_local_scalar_dense_1: "Sym(u1)" = torch.ops.aten._local_scalar_dense.default(select_1); select_1 = None
slice_1: "f32[u2][1]cpu" = torch.ops.aten.slice.Tensor(arg1_1, 0, _local_scalar_dense, _local_scalar_dense_1); arg1_1 = _local_scalar_dense = _local_scalar_dense_1 = None
sym_size_int: "Sym(u2)" = torch.ops.aten.sym_size.int(slice_1, 0)
ge_2: "Sym(u2 >= 0)" = sym_size_int >= 0
_assert_scalar = torch.ops.aten._assert_scalar.default(ge_2, "Runtime assertion failed for expression u2 >= 0 on node 'ge'"); ge_2 = _assert_scalar = None
le: "Sym(u2 <= 10)" = sym_size_int <= 10; sym_size_int = None
_assert_scalar_1 = torch.ops.aten._assert_scalar.default(le, "Runtime assertion failed for expression u2 <= 10 on node 'le'"); le = _assert_scalar_1 = None
sym_storage_offset_default: "Sym(u3)" = torch.ops.aten.sym_storage_offset.default(slice_1)
ge_3: "Sym(u3 >= 0)" = sym_storage_offset_default >= 0; sym_storage_offset_default = None
_assert_scalar_2 = torch.ops.aten._assert_scalar.default(ge_3, "Runtime assertion failed for expression u3 >= 0 on node 'ge_1'"); ge_3 = _assert_scalar_2 = None
return (slice_1,)""", # noqa: B950
ignore_comments=True,
ignore_empty_lines=True,
)
@fresh_cache()
@torch._dynamo.config.patch("capture_scalar_outputs", True)
@torch._inductor.config.patch("cpp_wrapper", True)
def test_unbacked_slice_cpp_wrapper(self):
self.test_unbacked_slice()
@fresh_cache()
@torch._dynamo.config.patch("capture_scalar_outputs", True)
def test_tensor_split(self):
def f1(x, xs):
xs = torch.tensor(xs.tolist())
return torch.tensor_split(x, xs)
x = torch.randn(20)
xs = torch.tensor([5, 10, 15])
fn = torch.compile(f1, fullgraph=True, backend="inductor")
def compare(x, xs):
for i, j in zip(f1(x, xs), fn(x, xs)):
self.assertTrue(torch.allclose(i, j))
compare(x, xs)
xs = torch.tensor([-15, 9, 10, 11])
compare(x, xs)
xs = torch.tensor([-15, -10, -5, -2])
compare(x, xs)
@fresh_cache()
@torch._dynamo.config.patch("capture_scalar_outputs", True)
@torch._inductor.config.patch("cpp_wrapper", True)
def test_tensor_split_cpp_wrapper(self):
self.test_tensor_split()
@unittest.skip("this test fails due to inductor/autograd issue #153041")
@torch._dynamo.config.patch("capture_scalar_outputs", True)
def test_unbacked_non_contigious_reshape_failing(self):


@ -1973,7 +1973,6 @@ make_fx_failures = {
skip('item'),
xfail('cov'),
xfail('nn.functional.gaussian_nll_loss'),
-xfail('tensor_split'),
xfail('corrcoef'),
xfail('quantile'),
xfail('nanquantile'),
@ -1993,10 +1992,12 @@ make_fx_failures = {
only_real_tensor_failures = {
xfail('narrow'),
xfail('tensor_split'),
}
only_fake_tensor_failures = {
xfail('narrow'),
xfail('tensor_split'),
}
fake_tensor_failures = set()


@ -6,6 +6,7 @@ import numbers
import operator
import sys
from collections.abc import Iterable
from contextlib import nullcontext
from enum import Enum
from functools import partial, reduce
from itertools import chain, product
@ -721,10 +722,7 @@ def slice_forward(
end: Optional[int] = None,
step: int = 1,
):
-from torch.fx.experimental.symbolic_shapes import (
-guard_size_oblivious,
-statically_known_true,
-)
+from torch.fx.experimental.symbolic_shapes import statically_known_true
ndim = self.dim()
if ndim == 0:
@ -739,22 +737,22 @@ def slice_forward(
start_val = start if start is not None else 0
end_val = end if end is not None else sys.maxsize # 2^63 - 1
-if guard_size_oblivious(start_val < 0):
+if start_val < 0:
start_val += sizes[dim]
-if guard_size_oblivious(end_val < 0):
+if end_val < 0:
end_val += sizes[dim]
-if guard_size_oblivious(start_val < 0):
+if start_val < 0:
start_val = 0
-elif guard_size_oblivious(start_val > sizes[dim]):
+elif start_val > sizes[dim]:
start_val = sizes[dim]
if statically_known_true(end_val == sys.maxsize):
end_val = sizes[dim]
-elif guard_size_oblivious(end_val < start_val):
+elif end_val < start_val:
end_val = start_val
-elif guard_size_oblivious(end_val > sizes[dim]):
+elif end_val > sizes[dim]:
end_val = sizes[dim]
storage_offset = self.storage_offset() + start_val * strides[dim]
@ -1438,7 +1436,17 @@ def tensor_split_tensor_indices_or_sections_py_impl(
assert isinstance(sections, IntLike)
return self.tensor_split(sections, dim)
else:
-indices = [i.item() for i in tensor_indices_or_sections]
+ctx = nullcontext
if (fake_mode := torch._guards.detect_fake_mode()) and (
shape_env := fake_mode.shape_env
):
ctx = shape_env.ignore_fresh_unbacked_symbols # type: ignore[assignment]
# In fake tensor prop, we end up calling slice() with these unbacked indices.
# Because slice has flexible semantics, the unbacked handling generates new output sizes
# for each slice, effectively clobbering over these index symbols.
# To avoid PendingUnbackedSymbolNotFound errors, we tell the compiler it's fine to not bind these.
with ctx():
indices = [i.item() for i in tensor_indices_or_sections]
# WARNING: Tempted to torch._check_is_size on the indices here? You
# can't: tensor_split works with negative values in indices:
#


@ -1456,19 +1456,51 @@ class CppWrapperCpu(PythonWrapperCodegen):
# record in unbacked_symbol_decls so we won't generate a declaration of the symbol again
self.unbacked_symbol_decls.add(str(node.sym))
-def codegen_dynamic_select_index(self, node):
+def codegen_dynamic_select_index(self, node, clamp):
index_cpp_str = self.val_to_arg_str_for_prim_type(node.index, int)
size_cpp_str = self.val_to_arg_str_for_prim_type(node.size, int)
-index_compute_str = (
+# codegen index
sym = node.unbacked_offset_symbol
index_str = (
f"{index_cpp_str} < 0 ? {index_cpp_str} + "
f"{self.val_to_arg_str_for_prim_type(node.size, int)}: {index_cpp_str}"
)
self.writeline(f"auto {sym}_index = {index_str};")
index_str_clamped = (
f"{sym}_index < 0 ? 0 : ({sym}_index > {size_cpp_str} ? {size_cpp_str} : {sym}_index)"
if clamp
else f"{sym}_index"
)
self.writeline(f"auto {sym}_index_clamped = {index_str_clamped};")
self.writeline(
-f"auto {node.unbacked_offset_symbol} = {self.val_to_arg_str_for_prim_type(node.base_offset, int)} + "
+f"auto {sym} = {self.val_to_arg_str_for_prim_type(node.base_offset, int)} + "
-f"{self.val_to_arg_str_for_prim_type(node.base_dim_stride, int)} * ({index_compute_str});"
+f"{self.val_to_arg_str_for_prim_type(node.base_dim_stride, int)} * {sym}_index_clamped;"
)
# record in unbacked_symbol_decls so we won't generate a declaration of the symbol again
-self.unbacked_symbol_decls.add(str(node.unbacked_offset_symbol))
+self.unbacked_symbol_decls.add(str(sym))
def codegen_dynamic_slice_size(self, node):
start_cpp_str = self.val_to_arg_str_for_prim_type(node.start, int)
end_cpp_str = self.val_to_arg_str_for_prim_type(node.end, int)
size_cpp_str = self.val_to_arg_str_for_prim_type(node.size, int)
sym = node.unbacked_size_symbol
def codegen_clamp(index_str, start=True):
suf = "start" if start else "end"
index_ = f"{sym}_{suf}_index"
self.writeline(
f"auto {index_} = {index_str} < 0 ? {index_str} + {size_cpp_str} : {index_str};"
)
self.writeline(
f"auto {sym}_{suf}_clamped = {index_} < 0 ? 0 : ({index_} > {size_cpp_str} ? {size_cpp_str} : {index_});"
)
codegen_clamp(start_cpp_str, start=True)
codegen_clamp(end_cpp_str, start=False)
self.writeline(f"auto {sym}_raw = {sym}_end_clamped - {sym}_start_clamped;")
self.writeline(f"auto {sym} = {sym}_raw < 0 ? 0 : {sym}_raw;")
self.unbacked_symbol_decls.add(str(sym))
def make_buffer_free(self, buffer):
return (


@ -1817,14 +1817,33 @@ class PythonWrapperCodegen(CodeGen):
arg_name = node.input_name(0)
self.writeline(MultiOutputLine(self, result_name, arg_name, node.indices))
-def codegen_dynamic_select_index(self, node):
+def codegen_dynamic_select_index(self, node, clamp):
index_str = f"{node.index} + {node.size} if {node.index} < 0 else {node.index}"
if clamp:
index_str = f"max(0, min({node.size}, {index_str}))"
self.writeline(
f"{node.unbacked_offset_symbol} = {node.base_offset} + {node.base_dim_stride} * ({index_str})"
)
# record in unbacked_symbol_decls so we won't generate a declaration of the symbol again
self.unbacked_symbol_decls.add(str(node.unbacked_offset_symbol))
def codegen_dynamic_slice_size(self, node):
def clamp_index(x):
pos = self.codegen_sizevar(sympy.Max(0, sympy.Min(x, node.size)))
neg = self.codegen_sizevar(
sympy.Max(0, sympy.Min(x + node.size, node.size))
)
return f"{pos} if {x} >= 0 else {neg}"
# codegen start, end
sym = node.unbacked_size_symbol
start = clamp_index(node.start)
end = clamp_index(node.end)
self.writeline(f"{sym}_start = {start}")
self.writeline(f"{sym}_end = {end}")
self.writeline(f"{sym} = max(0, {sym}_end - {sym}_start)")
self.unbacked_symbol_decls.add(str(node.unbacked_size_symbol))
def codegen_dynamic_scalar(self, node):
(data,) = (t.codegen_reference() for t in node.inputs)
if len(node.keypath) == 0:


@ -3437,7 +3437,6 @@ class SliceView(View):
if val is None:
# TODO(rec): can this really happen?
return default
-val = cls.handle_negative_index(val, dim_size)
return clamp(val, lower, upper)
start = clamp_wrap(start, 0, dim_size, 0)
@ -3454,14 +3453,6 @@ class SliceView(View):
step: int = 1,
clamp: bool = True,
) -> IRNode:
-step = sympy.expand(step)
-assert isinstance(step, Expr) or step > 0, step
-try:
-if start == 0 and end >= 2**63 - 1 and step == 1:
-return x
-except TypeError:
-pass
new_size = list(x.get_size())
# NB: Ordinarily we default to clamping.
@ -7221,6 +7212,7 @@ class DynamicSelectStorageOffset(ExternKernel):
base_offset: Union[sympy.Symbol, int],
base_dim_stride: Union[sympy.Symbol, int],
size: Union[sympy.Symbol, int],
clamp: bool,
) -> None:
super().__init__(None, NoneLayout(device=torch.device("cpu")), [])
# This node codegen the following:
@ -7230,6 +7222,7 @@ class DynamicSelectStorageOffset(ExternKernel):
self.base_offset = base_offset
self.base_dim_stride = base_dim_stride
self.size = size
self.clamp = clamp
def get_unbacked_symbol_defs(self) -> OrderedSet[sympy.Symbol]:
return OrderedSet([self.unbacked_offset_symbol])
@ -7240,7 +7233,57 @@ class DynamicSelectStorageOffset(ExternKernel):
return get_free_symbols(self.index, unbacked_only)
def codegen(self, wrapper: PythonWrapperCodegen) -> None:
-wrapper.codegen_dynamic_select_index(self)
+wrapper.codegen_dynamic_select_index(self, clamp=self.clamp)
class DynamicSliceSize(ExternKernel):
"""
Computes the output size of a slice call, handling the correct semantics in codegen.
We do this for flexible handling for unbacked indices (to not data-dependent error).
Slicing has 4 semantics for indices, i.e. x[start:] could be:
1) start < -x.size(0) -> x[0:] # negative out-of-bounds
2) start in [-x.size(0), 0) -> x[x.size(0) + start:] # negative slicing
3) start in [0, x.size(0)) -> x[start:] # standard slicing
4) start >= x.size(0) -> empty slice # positive out-of-bounds
If the appropriate semantics are known beforehand, the output size is computed based on
the start & end indices. If not (with unbacked indices), a new unbacked symbol is created
to represent the output size, and codegen handles computing the correct case.
"""
def get_reads(self) -> OrderedSet[Dep]:
return OrderedSet()
def should_allocate(self) -> bool:
return False
def __init__(
self,
unbacked_size_symbol: sympy.Symbol,
start: sympy.Symbol,
end: Union[sympy.Symbol, int],
size: Union[sympy.Symbol, int],
):
super().__init__(None, NoneLayout(device=torch.device("cpu")), [])
# This node codegen
self.unbacked_size_symbol = unbacked_size_symbol
self.start = start
self.end = end
self.size = size
def get_unbacked_symbol_defs(self) -> OrderedSet[sympy.Symbol]:
return OrderedSet([self.unbacked_size_symbol])
def get_free_symbol_uses(
self, unbacked_only: bool = False
) -> OrderedSet[sympy.Symbol]:
return get_free_symbols(self.start, unbacked_only).union(
get_free_symbols(self.end, unbacked_only)
)
def codegen(self, wrapper: PythonWrapperCodegen) -> None:
wrapper.codegen_dynamic_slice_size(self)
class DynamicScalar(ExternKernel):


@ -1172,9 +1172,130 @@ def permute(x, dims):
@register_lowering(aten.slice, type_promotion_kind=None)
def slice_(x, dim=0, start=0, end=2**63, step=1, clamp=True):
"""
Lowers a slice call, creating ExternKernels for the output size & storage offset symbols,
if the indices are unbacked and appropriate semantics aren't known.
If they are known (indices are static/backed/unbacked with info), a SliceView is created.
"""
from torch.fx.experimental.symbolic_shapes import (
CallMethodKey,
resolve_unbacked_bindings,
)
assert isinstance(x, TensorBox)
dim = _validate_dim(x, dim, 0)
-return TensorBox(ir.SliceView.create(x.data, dim, start, end, step, clamp=clamp))
+size = x.get_size()[dim]
step = sympy.expand(step)
assert isinstance(step, sympy.Expr) or step > 0, step
# maybe apply slice optimization
try:
if (
start == 0
and V.graph.sizevars.statically_known_leq(size, end)
and step == 1
):
return x
except TypeError:
pass
# try to avoid dynamic slice
def handle_negative_index(idx, size, default):
if idx is None:
return default
idx = sympy.expand(idx)
size = sympy.expand(size)
if V.graph.sizevars.guard_or_false(idx >= 0):
return idx
elif V.graph.sizevars.guard_or_false(idx < 0):
return size + idx
return None
ambiguous_slice = clamp
if ambiguous_slice:
start_index = handle_negative_index(start, size, 0)
end_index = handle_negative_index(end, size, size)
if start_index is not None and end_index is not None:
start, end = start_index, end_index
ambiguous_slice = False
# ambiguous_slice=False means we know what semantics this slice call follows,
# and don't need to generate an extern kernel to represent the output size.
# This is assumed True for clamp=False
# (meant to follow standard indexing semantics: 0 <= index < size)
if not ambiguous_slice:
return TensorBox(
ir.SliceView.create(x.data, dim, start, end, step, clamp=clamp)
) # go to SliceView/ReinterpretView
# unbacked territory: create DynamicSlice ExternKernel
# clamp is True, unbacked start / end
assert clamp
unbacked_bindings = resolve_unbacked_bindings(
V.graph.sizevars.shape_env, V.graph.current_node.meta["unbacked_bindings"]
)
assert unbacked_bindings is not None
assert len(unbacked_bindings) <= 2, unbacked_bindings
sym_size, sym_storage = None, None
for sym, keypath in unbacked_bindings.items():
if keypath == (CallMethodKey("size"), pytree.SequenceKey(dim)):
sym_size = sym
elif keypath == (CallMethodKey("storage_offset"),):
sym_storage = sym
def compute_slice_index(index, size):
fn = lambda x: V.graph.sizevars.guard_or_false(x) # noqa: E731
if fn(sympy.Ge(index, 0)) and fn(sympy.Le(index, size)):
return index
elif fn(sympy.Lt(index, 0)) and fn(sympy.Ge(index, -size)):
return -index
elif fn(sympy.Gt(index, size)):
return size
elif fn(sympy.Lt(index, -size)):
return 0
return None
start_index = compute_slice_index(start, size)
end_index = compute_slice_index(end, size)
if start_index is not None and end_index is not None:
# we shouldn't have allocated size symbol, if output size was determinable from input indices
assert sym_size is None
new_size = sympy.Max(0, end_index - start_index)
else:
b_size = ir.DynamicSliceSize(
sym_size,
start,
end,
x.get_size()[dim],
)
b_size.name = V.graph.register_buffer(b_size)
V.graph.register_operation(b_size)
new_size = sym_size
if start_index is not None:
# we shouldn't have allocated storage offset symbol if start index was determinable
assert sym_storage is None
new_storage_offset = x.get_layout().offset + start_index * x.get_stride()[dim]
else:
b_storage = ir.DynamicSelectStorageOffset(
sym_storage,
start,
x.get_layout().offset,
x.get_stride()[dim],
x.get_size()[dim],
clamp=True,
)
b_storage.name = V.graph.register_buffer(b_storage)
V.graph.register_operation(b_storage)
new_storage_offset = sym_storage
new_sizes = list(x.get_size())
new_strides = list(x.get_stride())
new_sizes[dim] = new_size
new_strides[dim] *= step
return as_strided(x, new_sizes, new_strides, new_storage_offset)
@register_lowering(aten.as_strided, type_promotion_kind=None)
@ -1800,6 +1921,7 @@ def select(x, dim, idx):
x.get_layout().offset,
new_stride[dim],
x.get_size()[dim],
clamp=False,
)
buffer.name = V.graph.register_buffer(buffer)
V.graph.register_operation(buffer)
@ -2991,6 +3113,8 @@ def slice_scatter(x, src, dim=0, start=None, end=None, step=1):
dim = _validate_dim(x, dim, 0)
dim_size = x.get_size()[dim]
start = ir.SliceView.handle_negative_index(start, dim_size)
end = ir.SliceView.handle_negative_index(end, dim_size)
start, end = ir.SliceView.normalize_start_end(x, dim, start, end)
src_size = list(x.get_size())


@ -6,7 +6,7 @@ import math
import operator
import sys
from functools import reduce
-from typing import Callable, Union
+from typing import Callable, Optional, Union
import torch
import torch._custom_op
@ -15,6 +15,7 @@ import torch._prims_common as utils
from torch._dispatch.python import no_python_dispatcher
from torch._ops import OpOverload
from torch._prims_common import (
canonicalize_dim,
contiguous_for_memory_format_or_false,
elementwise_dtypes,
ELEMENTWISE_TYPE_PROMOTION_KIND,
@ -746,6 +747,88 @@ def _padded_dense_to_jagged_forward(fake_mode, func, padded, offsets, total_L=No
return padded.new_empty(output_shape)
def _compute_slice_index(size, index):
from torch.fx.experimental.symbolic_shapes import guard_or_false, sym_and
if guard_or_false(sym_and(index >= 0, index <= size)):
return index
elif guard_or_false(sym_and(index < 0, index >= -size)):
return index + size
elif guard_or_false(index < -size):
return 0
elif guard_or_false(index > size):
return size
return None
@register_op_impl(torch.ops.aten.slice.Tensor)
def slice_forward(
fake_mode,
func,
self,
dim: int = 0,
start: Optional[int] = None,
end: Optional[int] = None,
step: int = 1,
):
from torch.fx.experimental.symbolic_shapes import (
guard_or_false,
statically_known_true,
)
shape_env = fake_mode.shape_env
ndim = self.dim()
if ndim == 0:
raise RuntimeError("slice() cannot be applied to a 0-dim tensor.")
dim = canonicalize_dim(self.dim(), dim)
sizes = list(self.size())
strides = list(self.stride())
if step <= 0:
raise RuntimeError("slice step must be positive")
# start, end
start_index = 0 if start is None else _compute_slice_index(sizes[dim], start)
end_index = (
sizes[dim]
if statically_known_true(end == sys.maxsize) or end is None
else _compute_slice_index(sizes[dim], end)
)
# size
new_size = None
if start_index is not None and end_index is not None:
if guard_or_false(end_index >= start_index):
new_size = (end_index - start_index + step - 1) // step
elif guard_or_false(start_index >= end_index):
new_size = 0
# create unbacked if case unknown
if new_size is None:
new_size = shape_env.create_unbacked_symint()
torch._check_is_size(new_size, max=sizes[dim])
# stride
new_stride = strides[dim] * step
# storage offset
if start_index is not None:
storage_offset = self.storage_offset() + start_index * strides[dim]
else:
storage_offset = shape_env.create_unbacked_symint()
torch._check(storage_offset >= 0)
sizes[dim] = new_size
strides[dim] = new_stride
if self.is_quantized:
raise NotImplementedError(
"Slice decomposition for quantized tensors aren't implemented"
)
else:
return self.as_strided(sizes, strides, storage_offset)
@register_op_impl(torch.ops.aten.masked_select.default)
def masked_select(fake_mode, func, self, mask):
if (


@ -2616,7 +2616,9 @@ class FakeTensorMode(TorchDispatchMode):
if (
func not in meta_table
and not self.cpp_meta_supports_symint(func)
-and not (has_symbolic_sizes and func in self._view_fake_tensor_impl_ops)
+and not (
has_symbolic_sizes and func in self._unbacked_special_fake_handling_ops
)
):
from torch._decomp import decomposition_table
@ -2925,8 +2927,10 @@ class FakeTensorMode(TorchDispatchMode):
aten._sparse_coo_tensor_with_dims_and_tensors.default,
)
-_view_fake_tensor_impl_ops = ordered_set(
+_unbacked_special_fake_handling_ops = ordered_set(
-aten.view.default, aten._unsafe_view.default
+aten.view.default,
aten._unsafe_view.default,
aten.slice.Tensor,
)
def cpp_meta_supports_symint(self, func: OpOverload) -> bool: