Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)

[dynamic shapes] unbacked-safe slicing (#157944)

Generates new unbacked symbols for slice output size & storage offset, when appropriate semantics are unclear. Teaches inductor to codegen the slice with flexible semantics.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/157944
Approved by: https://github.com/laithsakka
Committed by: PyTorch MergeBot
Parent: 0254646654
Commit: 56218d85e2
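
For context, a minimal sketch of the behavior this change targets (illustrative only; the function name is not part of the PR and a build containing this commit is assumed). Slice bounds read out of a tensor become unbacked SymInts under capture_scalar_outputs, so the compiled slice has to stay correct whether those bounds turn out negative, in-bounds, or out-of-bounds:

import torch

@torch.compile(fullgraph=True)
def unbacked_slice(x, xs):
    u0, u1 = xs.tolist()  # u0, u1 are unbacked SymInts at trace time
    return x[u0:u1]       # output size & storage offset are unbacked too

x = torch.randn(10)
with torch._dynamo.config.patch(capture_scalar_outputs=True):
    for xs in (
        torch.tensor([3, 6]),         # standard slice
        torch.tensor([-9, -1]),       # negative indices
        torch.tensor([-1000, 1000]),  # out of bounds
        torch.tensor([2, -2]),        # mixed
    ):
        expected = x[xs[0].item() : xs[1].item()]  # eager reference
        assert torch.equal(unbacked_slice(x, xs), expected)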
@@ -3028,6 +3028,32 @@ def forward(self, causal_mask, fill_value):
                 },
             )
 
+    def test_unbacked_slice_forward(self):
+        class Foo(torch.nn.Module):
+            def forward(self, x, xs):
+                u0, u1 = xs.tolist()
+                out = x[u0:u1]
+                return out
+
+        x = torch.randn(10)
+        idxs = torch.tensor([3, 6])
+        mod = Foo()
+        ep = export(mod, (x, idxs))
+        for xs in [
+            idxs,
+            torch.tensor([-9, -1]),
+            torch.tensor([-10000, 10000]),
+            torch.tensor([0, -10]),
+        ]:
+            self.assertTrue(torch.allclose(ep.module()(x, xs), mod(x, xs)))
+
+        # check unbacked bindings
+        # should be 4 symbols: u0, u1, output size, output storage offset
+        bound_unbacked = set()
+        for node in ep.graph.nodes:
+            bound_unbacked |= node.meta.get("unbacked_bindings", {}).keys()
+        self.assertEqual(len(bound_unbacked), 4)
+
     def test_dim_hint_ranges(self):
         class Foo(torch.nn.Module):
             def forward(self, x, y):
@@ -5704,7 +5730,7 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x):
         }
         self._test_export_same_as_eager(kw_func, args, kwargs)
 
-    def test_unbacked_slice(self):
+    def test_unbacked_slice_simple(self):
         class M(torch.nn.Module):
             def forward(self, scores, score_thr, topk: torch.Tensor, results=None):
                 valid_mask = scores > score_thr
@@ -3391,6 +3391,119 @@ def forward(self, arg0_1: "i64[2][1]cpu", arg1_1: "Sym(u2)", arg2_1: "Sym(u3)",
         self.assertEqual(result_compiled, result_eager)
         self.assertEqual(cnt.frame_count, 2)
 
+    @fresh_cache()
+    @torch._dynamo.config.patch("capture_scalar_outputs", True)
+    def test_unbacked_slice(self):
+        from torch.fx.experimental.symbolic_shapes import statically_known_true
+
+        # standard slice
+        def f1(x, xs):
+            u0, u1 = xs.tolist()
+            torch._check_is_size(u0, max=x.size(0))
+            torch._check_is_size(u1, max=x.size(0))
+            torch._check(u0 <= u1)
+            out = x[u0:u1]
+            assert statically_known_true(out.size(0) == (u1 - u0))
+            return out
+
+        x, xs = torch.randn(10), torch.tensor([3, 6])
+        fn1 = torch.compile(f1, fullgraph=True, backend="inductor")
+        self.assertEqual(fn1(x, xs).size(0), 3)
+        self.assertTrue(torch.allclose(fn1(x, xs), f1(x, xs)))
+        with self.assertRaises(RuntimeError):
+            fn1(x, torch.tensor([-1, 5]))
+
+        # known negative slice
+        def f2(x, n):
+            u0 = n.item()
+            torch._check(u0 > 1)
+            torch._check(u0 <= x.size(0))
+            out = x[-u0:]
+            assert statically_known_true(out.size(0) == u0)
+            return out
+
+        x, n = torch.randn(10), torch.tensor([5])
+        fn2 = torch.compile(f2, fullgraph=True, backend="inductor")
+        self.assertEqual(fn2(x, n).size(0), 5)
+        self.assertTrue(torch.allclose(fn2(x, n), f2(x, n)))
+        with self.assertRaises(RuntimeError):
+            fn2(x, torch.tensor([-5]))
+
+        # general case: no known info
+        def f3(x, xs):
+            u0, u1 = xs.tolist()
+            return x[u0:u1]
+
+        log_stream, ctx = logs_to_string(
+            "torch._inductor.compile_fx", "post_grad_graphs"
+        )
+        cnts = CompileCounterWithBackend("inductor")
+        x, xs = torch.randn(10), torch.tensor([3, 6])
+        with ctx():
+            fn3 = torch.compile(f3, fullgraph=True, backend=cnts)
+            xs = torch.tensor([-9, -1])  # negative case
+            self.assertTrue(torch.allclose(fn3(x, xs), f3(x, xs)))
+            xs = torch.tensor([-1000, 1000])  # out of bounds
+            self.assertTrue(torch.allclose(fn3(x, xs), f3(x, xs)))
+            xs = torch.tensor([2, -2])  # mixed
+            self.assertTrue(torch.allclose(fn3(x, xs), f3(x, xs)))
+            self.assertEqual(cnts.frame_count, 1)
+
+        aot_graphs = "\n".join(log_stream.getvalue().strip().split("\n")[4:]).strip()
+        self.assertExpectedInline(
+            aot_graphs,
+            """\
+select: "i64[][]cpu" = torch.ops.aten.select.int(arg0_1, 0, 0)
+_local_scalar_dense: "Sym(u0)" = torch.ops.aten._local_scalar_dense.default(select); select = None
+select_1: "i64[][]cpu" = torch.ops.aten.select.int(arg0_1, 0, 1); arg0_1 = None
+_local_scalar_dense_1: "Sym(u1)" = torch.ops.aten._local_scalar_dense.default(select_1); select_1 = None
+slice_1: "f32[u2][1]cpu" = torch.ops.aten.slice.Tensor(arg1_1, 0, _local_scalar_dense, _local_scalar_dense_1); arg1_1 = _local_scalar_dense = _local_scalar_dense_1 = None
+sym_size_int: "Sym(u2)" = torch.ops.aten.sym_size.int(slice_1, 0)
+ge_2: "Sym(u2 >= 0)" = sym_size_int >= 0
+_assert_scalar = torch.ops.aten._assert_scalar.default(ge_2, "Runtime assertion failed for expression u2 >= 0 on node 'ge'"); ge_2 = _assert_scalar = None
+le: "Sym(u2 <= 10)" = sym_size_int <= 10; sym_size_int = None
+_assert_scalar_1 = torch.ops.aten._assert_scalar.default(le, "Runtime assertion failed for expression u2 <= 10 on node 'le'"); le = _assert_scalar_1 = None
+sym_storage_offset_default: "Sym(u3)" = torch.ops.aten.sym_storage_offset.default(slice_1)
+ge_3: "Sym(u3 >= 0)" = sym_storage_offset_default >= 0; sym_storage_offset_default = None
+_assert_scalar_2 = torch.ops.aten._assert_scalar.default(ge_3, "Runtime assertion failed for expression u3 >= 0 on node 'ge_1'"); ge_3 = _assert_scalar_2 = None
+return (slice_1,)""",  # noqa: B950
+            ignore_comments=True,
+            ignore_empty_lines=True,
+        )
+
+    @fresh_cache()
+    @torch._dynamo.config.patch("capture_scalar_outputs", True)
+    @torch._inductor.config.patch("cpp_wrapper", True)
+    def test_unbacked_slice_cpp_wrapper(self):
+        self.test_unbacked_slice()
+
+    @fresh_cache()
+    @torch._dynamo.config.patch("capture_scalar_outputs", True)
+    def test_tensor_split(self):
+        def f1(x, xs):
+            xs = torch.tensor(xs.tolist())
+            return torch.tensor_split(x, xs)
+
+        x = torch.randn(20)
+        xs = torch.tensor([5, 10, 15])
+        fn = torch.compile(f1, fullgraph=True, backend="inductor")
+
+        def compare(x, xs):
+            for i, j in zip(f1(x, xs), fn(x, xs)):
+                self.assertTrue(torch.allclose(i, j))
+
+        compare(x, xs)
+        xs = torch.tensor([-15, 9, 10, 11])
+        compare(x, xs)
+        xs = torch.tensor([-15, -10, -5, -2])
+        compare(x, xs)
+
+    @fresh_cache()
+    @torch._dynamo.config.patch("capture_scalar_outputs", True)
+    @torch._inductor.config.patch("cpp_wrapper", True)
+    def test_tensor_split_cpp_wrapper(self):
+        self.test_tensor_split()
+
     @unittest.skip("this test fails due to inductor/autograd issue #153041")
     @torch._dynamo.config.patch("capture_scalar_outputs", True)
     def test_unbacked_non_contigious_reshape_failing(self):
@@ -1973,7 +1973,6 @@ make_fx_failures = {
     skip('item'),
     xfail('cov'),
     xfail('nn.functional.gaussian_nll_loss'),
-    xfail('tensor_split'),
     xfail('corrcoef'),
     xfail('quantile'),
    xfail('nanquantile'),
@@ -1993,10 +1992,12 @@ make_fx_failures = {
 
 only_real_tensor_failures = {
     xfail('narrow'),
+    xfail('tensor_split'),
 }
 
 only_fake_tensor_failures = {
     xfail('narrow'),
+    xfail('tensor_split'),
 }
 
 fake_tensor_failures = set()
@@ -6,6 +6,7 @@ import numbers
 import operator
 import sys
 from collections.abc import Iterable
+from contextlib import nullcontext
 from enum import Enum
 from functools import partial, reduce
 from itertools import chain, product
@@ -721,10 +722,7 @@ def slice_forward(
     end: Optional[int] = None,
     step: int = 1,
 ):
-    from torch.fx.experimental.symbolic_shapes import (
-        guard_size_oblivious,
-        statically_known_true,
-    )
+    from torch.fx.experimental.symbolic_shapes import statically_known_true
 
     ndim = self.dim()
     if ndim == 0:
@@ -739,22 +737,22 @@ def slice_forward(
     start_val = start if start is not None else 0
     end_val = end if end is not None else sys.maxsize  # 2^63 - 1
 
-    if guard_size_oblivious(start_val < 0):
+    if start_val < 0:
         start_val += sizes[dim]
 
-    if guard_size_oblivious(end_val < 0):
+    if end_val < 0:
         end_val += sizes[dim]
 
-    if guard_size_oblivious(start_val < 0):
+    if start_val < 0:
         start_val = 0
-    elif guard_size_oblivious(start_val > sizes[dim]):
+    elif start_val > sizes[dim]:
         start_val = sizes[dim]
 
     if statically_known_true(end_val == sys.maxsize):
         end_val = sizes[dim]
-    elif guard_size_oblivious(end_val < start_val):
+    elif end_val < start_val:
         end_val = start_val
-    elif guard_size_oblivious(end_val > sizes[dim]):
+    elif end_val > sizes[dim]:
         end_val = sizes[dim]
 
     storage_offset = self.storage_offset() + start_val * strides[dim]
@@ -1438,7 +1436,17 @@ def tensor_split_tensor_indices_or_sections_py_impl(
         assert isinstance(sections, IntLike)
         return self.tensor_split(sections, dim)
     else:
-        indices = [i.item() for i in tensor_indices_or_sections]
+        ctx = nullcontext
+        if (fake_mode := torch._guards.detect_fake_mode()) and (
+            shape_env := fake_mode.shape_env
+        ):
+            ctx = shape_env.ignore_fresh_unbacked_symbols  # type: ignore[assignment]
+        # In fake tensor prop, we end up calling slice() with these unbacked indices.
+        # Because slice has flexible semantics, the unbacked handling generates new output sizes
+        # for each slice, effectively clobbering over these index symbols.
+        # To avoid PendingUnbackedSymbolNotFound errors, we tell the compiler it's fine to not bind these.
+        with ctx():
+            indices = [i.item() for i in tensor_indices_or_sections]
         # WARNING: Tempted to torch._check_is_size on the indices here? You
         # can't: tensor_split works with negative values in indices:
         #
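
As the WARNING in the surrounding code notes, tensor_split accepts negative indices, so the unbacked indices cannot be asserted to be sizes. A small eager example of that behavior (plain PyTorch, for illustration only):

import torch

x = torch.arange(20)
# negative indices are interpreted relative to the end of the dimension,
# so a torch._check_is_size() on them would be wrong
parts = torch.tensor_split(x, torch.tensor([5, -5]))
assert [p.numel() for p in parts] == [5, 10, 5]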
@@ -1456,19 +1456,51 @@ class CppWrapperCpu(PythonWrapperCodegen):
         # record in unbacked_symbol_decls so we won't generate a declaration of the symbol again
         self.unbacked_symbol_decls.add(str(node.sym))
 
-    def codegen_dynamic_select_index(self, node):
+    def codegen_dynamic_select_index(self, node, clamp):
         index_cpp_str = self.val_to_arg_str_for_prim_type(node.index, int)
+        size_cpp_str = self.val_to_arg_str_for_prim_type(node.size, int)
 
-        index_compute_str = (
+        # codegen index
+        sym = node.unbacked_offset_symbol
+        index_str = (
             f"{index_cpp_str} < 0 ? {index_cpp_str} + "
             f"{self.val_to_arg_str_for_prim_type(node.size, int)}: {index_cpp_str}"
         )
+        self.writeline(f"auto {sym}_index = {index_str};")
+        index_str_clamped = (
+            f"{sym}_index < 0 ? 0 : ({sym}_index > {size_cpp_str} ? {size_cpp_str} : {sym}_index)"
+            if clamp
+            else f"{sym}_index"
+        )
+        self.writeline(f"auto {sym}_index_clamped = {index_str_clamped};")
         self.writeline(
-            f"auto {node.unbacked_offset_symbol} = {self.val_to_arg_str_for_prim_type(node.base_offset, int)} + "
-            f"{self.val_to_arg_str_for_prim_type(node.base_dim_stride, int)} * ({index_compute_str});"
+            f"auto {sym} = {self.val_to_arg_str_for_prim_type(node.base_offset, int)} + "
+            f"{self.val_to_arg_str_for_prim_type(node.base_dim_stride, int)} * {sym}_index_clamped;"
         )
         # record in unbacked_symbol_decls so we won't generate a declaration of the symbol again
-        self.unbacked_symbol_decls.add(str(node.unbacked_offset_symbol))
+        self.unbacked_symbol_decls.add(str(sym))
+
+    def codegen_dynamic_slice_size(self, node):
+        start_cpp_str = self.val_to_arg_str_for_prim_type(node.start, int)
+        end_cpp_str = self.val_to_arg_str_for_prim_type(node.end, int)
+        size_cpp_str = self.val_to_arg_str_for_prim_type(node.size, int)
+        sym = node.unbacked_size_symbol
+
+        def codegen_clamp(index_str, start=True):
+            suf = "start" if start else "end"
+            index_ = f"{sym}_{suf}_index"
+            self.writeline(
+                f"auto {index_} = {index_str} < 0 ? {index_str} + {size_cpp_str} : {index_str};"
+            )
+            self.writeline(
+                f"auto {sym}_{suf}_clamped = {index_} < 0 ? 0 : ({index_} > {size_cpp_str} ? {size_cpp_str} : {index_});"
+            )
+
+        codegen_clamp(start_cpp_str, start=True)
+        codegen_clamp(end_cpp_str, start=False)
+        self.writeline(f"auto {sym}_raw = {sym}_end_clamped - {sym}_start_clamped;")
+        self.writeline(f"auto {sym} = {sym}_raw < 0 ? 0 : {sym}_raw;")
+        self.unbacked_symbol_decls.add(str(sym))
 
     def make_buffer_free(self, buffer):
         return (
@@ -1817,14 +1817,33 @@ class PythonWrapperCodegen(CodeGen):
         arg_name = node.input_name(0)
         self.writeline(MultiOutputLine(self, result_name, arg_name, node.indices))
 
-    def codegen_dynamic_select_index(self, node):
+    def codegen_dynamic_select_index(self, node, clamp):
         index_str = f"{node.index} + {node.size} if {node.index} < 0 else {node.index}"
+        if clamp:
+            index_str = f"max(0, min({node.size}, {index_str}))"
         self.writeline(
             f"{node.unbacked_offset_symbol} = {node.base_offset} + {node.base_dim_stride} * ({index_str})"
         )
         # record in unbacked_symbol_decls so we won't generate a declaration of the symbol again
         self.unbacked_symbol_decls.add(str(node.unbacked_offset_symbol))
 
+    def codegen_dynamic_slice_size(self, node):
+        def clamp_index(x):
+            pos = self.codegen_sizevar(sympy.Max(0, sympy.Min(x, node.size)))
+            neg = self.codegen_sizevar(
+                sympy.Max(0, sympy.Min(x + node.size, node.size))
+            )
+            return f"{pos} if {x} >= 0 else {neg}"
+
+        # codegen start, end
+        sym = node.unbacked_size_symbol
+        start = clamp_index(node.start)
+        end = clamp_index(node.end)
+        self.writeline(f"{sym}_start = {start}")
+        self.writeline(f"{sym}_end = {end}")
+        self.writeline(f"{sym} = max(0, {sym}_end - {sym}_start)")
+        self.unbacked_symbol_decls.add(str(node.unbacked_size_symbol))
+
     def codegen_dynamic_scalar(self, node):
         (data,) = (t.codegen_reference() for t in node.inputs)
         if len(node.keypath) == 0:
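
A plain-Python reference for what the emitted `{sym}_start` / `{sym}_end` / `{sym}` wrapper lines compute at runtime (the helper name here is illustrative, not part of the PR): each bound is shifted by the dim size when negative, clamped into [0, size], and the output length is the non-negative difference:

def dynamic_slice_size(start, end, size):
    # mirrors the generated wrapper arithmetic: normalize negatives, clamp, subtract
    def clamp_index(x):
        return max(0, min(x, size)) if x >= 0 else max(0, min(x + size, size))

    return max(0, clamp_index(end) - clamp_index(start))

# for a dimension of size 10:
assert dynamic_slice_size(3, 6, 10) == 3          # standard
assert dynamic_slice_size(-9, -1, 10) == 8        # negative indices
assert dynamic_slice_size(-1000, 1000, 10) == 10  # out of bounds, clamped
assert dynamic_slice_size(2, -2, 10) == 6         # mixed
assert dynamic_slice_size(6, 3, 10) == 0          # empty slice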
@@ -3437,7 +3437,6 @@ class SliceView(View):
             if val is None:
                 # TODO(rec): can this really happen?
                 return default
-            val = cls.handle_negative_index(val, dim_size)
             return clamp(val, lower, upper)
 
         start = clamp_wrap(start, 0, dim_size, 0)
@@ -3454,14 +3453,6 @@ class SliceView(View):
         step: int = 1,
         clamp: bool = True,
     ) -> IRNode:
-        step = sympy.expand(step)
-        assert isinstance(step, Expr) or step > 0, step
-        try:
-            if start == 0 and end >= 2**63 - 1 and step == 1:
-                return x
-        except TypeError:
-            pass
-
         new_size = list(x.get_size())
 
         # NB: Ordinarily we default to clamping.
@@ -7221,6 +7212,7 @@ class DynamicSelectStorageOffset(ExternKernel):
         base_offset: Union[sympy.Symbol, int],
         base_dim_stride: Union[sympy.Symbol, int],
         size: Union[sympy.Symbol, int],
+        clamp: bool,
     ) -> None:
         super().__init__(None, NoneLayout(device=torch.device("cpu")), [])
         # This node codegen the following:
@@ -7230,6 +7222,7 @@ class DynamicSelectStorageOffset(ExternKernel):
         self.base_offset = base_offset
         self.base_dim_stride = base_dim_stride
         self.size = size
+        self.clamp = clamp
 
     def get_unbacked_symbol_defs(self) -> OrderedSet[sympy.Symbol]:
         return OrderedSet([self.unbacked_offset_symbol])
@@ -7240,7 +7233,57 @@ class DynamicSelectStorageOffset(ExternKernel):
         return get_free_symbols(self.index, unbacked_only)
 
     def codegen(self, wrapper: PythonWrapperCodegen) -> None:
-        wrapper.codegen_dynamic_select_index(self)
+        wrapper.codegen_dynamic_select_index(self, clamp=self.clamp)
+
+
+class DynamicSliceSize(ExternKernel):
+    """
+    Computes the output size of a slice call, handling the correct semantics in codegen.
+    We do this for flexible handling for unbacked indices (to not data-dependent error).
+
+    Slicing has 4 semantics for indices, i.e. x[start:] could be:
+    1) start < -x.size(0) -> x[0:]  # negative out-of-bounds
+    2) start in [-x.size(0), 0) -> x[x.size(0) + start:]  # negative slicing
+    3) start in [0, x.size(0)) -> x[start:]  # standard slicing
+    4) start >= x.size(0) -> empty slice  # positive out-of-bounds
+
+    If the appropriate semantics are known beforehand, the output size is computed based on
+    the start & end indices. If not (with unbacked indices), a new unbacked symbol is created
+    to represent the output size, and codegen handles computing the correct case.
+    """
+
+    def get_reads(self) -> OrderedSet[Dep]:
+        return OrderedSet()
+
+    def should_allocate(self) -> bool:
+        return False
+
+    def __init__(
+        self,
+        unbacked_size_symbol: sympy.Symbol,
+        start: sympy.Symbol,
+        end: Union[sympy.Symbol, int],
+        size: Union[sympy.Symbol, int],
+    ):
+        super().__init__(None, NoneLayout(device=torch.device("cpu")), [])
+        # This node codegen
+        self.unbacked_size_symbol = unbacked_size_symbol
+        self.start = start
+        self.end = end
+        self.size = size
+
+    def get_unbacked_symbol_defs(self) -> OrderedSet[sympy.Symbol]:
+        return OrderedSet([self.unbacked_size_symbol])
+
+    def get_free_symbol_uses(
+        self, unbacked_only: bool = False
+    ) -> OrderedSet[sympy.Symbol]:
+        return get_free_symbols(self.start, unbacked_only).union(
+            get_free_symbols(self.end, unbacked_only)
+        )
+
+    def codegen(self, wrapper: PythonWrapperCodegen) -> None:
+        wrapper.codegen_dynamic_slice_size(self)
+
+
 class DynamicScalar(ExternKernel):
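
The four regimes listed in the new docstring are ordinary Python/PyTorch slicing behavior; a quick eager illustration (not part of the diff):

import torch

x = torch.arange(6)  # size 6 along dim 0

assert torch.equal(x[-100:], x)                        # 1) negative out-of-bounds -> x[0:]
assert torch.equal(x[-2:], x[4:])                      # 2) negative slicing -> x[size + start:]
assert torch.equal(x[2:], torch.tensor([2, 3, 4, 5]))  # 3) standard slicing
assert x[100:].numel() == 0                            # 4) positive out-of-bounds -> empty

With an unbacked start or end, none of these cases can be ruled out at compile time, which is why the node defers the size computation to wrapper codegen.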
@@ -1172,9 +1172,130 @@ def permute(x, dims):
 
 @register_lowering(aten.slice, type_promotion_kind=None)
 def slice_(x, dim=0, start=0, end=2**63, step=1, clamp=True):
+    """
+    Lowers a slice call, creating ExternKernels for the output size & storage offset symbols,
+    if the indices are unbacked and appropriate semantics aren't known.
+    If they are known (indices are static/backed/unbacked with info), a SliceView is created.
+    """
+
+    from torch.fx.experimental.symbolic_shapes import (
+        CallMethodKey,
+        resolve_unbacked_bindings,
+    )
+
     assert isinstance(x, TensorBox)
     dim = _validate_dim(x, dim, 0)
-    return TensorBox(ir.SliceView.create(x.data, dim, start, end, step, clamp=clamp))
+    size = x.get_size()[dim]
+    step = sympy.expand(step)
+    assert isinstance(step, sympy.Expr) or step > 0, step
+
+    # maybe apply slice optimization
+    try:
+        if (
+            start == 0
+            and V.graph.sizevars.statically_known_leq(size, end)
+            and step == 1
+        ):
+            return x
+    except TypeError:
+        pass
+
+    # try to avoid dynamic slice
+    def handle_negative_index(idx, size, default):
+        if idx is None:
+            return default
+        idx = sympy.expand(idx)
+        size = sympy.expand(size)
+        if V.graph.sizevars.guard_or_false(idx >= 0):
+            return idx
+        elif V.graph.sizevars.guard_or_false(idx < 0):
+            return size + idx
+        return None
+
+    ambiguous_slice = clamp
+    if ambiguous_slice:
+        start_index = handle_negative_index(start, size, 0)
+        end_index = handle_negative_index(end, size, size)
+        if start_index is not None and end_index is not None:
+            start, end = start_index, end_index
+            ambiguous_slice = False
+
+    # ambiguous_slice=False means we know what semantics this slice call follows,
+    # and don't need to generate an extern kernel to represent the output size.
+    # This is assumed True for clamp=False
+    # (meant to follow standard indexing semantics: 0 <= index < size)
+    if not ambiguous_slice:
+        return TensorBox(
+            ir.SliceView.create(x.data, dim, start, end, step, clamp=clamp)
+        )  # go to SliceView/ReinterpretView
+
+    # unbacked territory: create DynamicSlice ExternKernel
+    # clamp is True, unbacked start / end
+    assert clamp
+    unbacked_bindings = resolve_unbacked_bindings(
+        V.graph.sizevars.shape_env, V.graph.current_node.meta["unbacked_bindings"]
+    )
+    assert unbacked_bindings is not None
+    assert len(unbacked_bindings) <= 2, unbacked_bindings
+    sym_size, sym_storage = None, None
+    for sym, keypath in unbacked_bindings.items():
+        if keypath == (CallMethodKey("size"), pytree.SequenceKey(dim)):
+            sym_size = sym
+        elif keypath == (CallMethodKey("storage_offset"),):
+            sym_storage = sym
+
+    def compute_slice_index(index, size):
+        fn = lambda x: V.graph.sizevars.guard_or_false(x)  # noqa: E731
+
+        if fn(sympy.Ge(index, 0)) and fn(sympy.Le(index, size)):
+            return index
+        elif fn(sympy.Lt(index, 0)) and fn(sympy.Ge(index, -size)):
+            return -index
+        elif fn(sympy.Gt(index, size)):
+            return size
+        elif fn(sympy.Lt(index, -size)):
+            return 0
+        return None
+
+    start_index = compute_slice_index(start, size)
+    end_index = compute_slice_index(end, size)
+    if start_index is not None and end_index is not None:
+        # we shouldn't have allocated size symbol, if output size was determinable from input indices
+        assert sym_size is None
+        new_size = sympy.Max(0, end_index - start_index)
+    else:
+        b_size = ir.DynamicSliceSize(
+            sym_size,
+            start,
+            end,
+            x.get_size()[dim],
+        )
+        b_size.name = V.graph.register_buffer(b_size)
+        V.graph.register_operation(b_size)
+        new_size = sym_size
+
+    if start_index is not None:
+        # we shouldn't have allocated storage offset symbol if start index was determinable
+        assert sym_storage is None
+        new_storage_offset = x.get_layout().offset + start_index * x.get_stride()[dim]
+    else:
+        b_storage = ir.DynamicSelectStorageOffset(
+            sym_storage,
+            start,
+            x.get_layout().offset,
+            x.get_stride()[dim],
+            x.get_size()[dim],
+            clamp=True,
+        )
+        b_storage.name = V.graph.register_buffer(b_storage)
+        V.graph.register_operation(b_storage)
+        new_storage_offset = sym_storage
+
+    new_sizes = list(x.get_size())
+    new_strides = list(x.get_stride())
+    new_sizes[dim] = new_size
+    new_strides[dim] *= step
+    return as_strided(x, new_sizes, new_strides, new_storage_offset)
+
 
 @register_lowering(aten.as_strided, type_promotion_kind=None)
@@ -1800,6 +1921,7 @@ def select(x, dim, idx):
         x.get_layout().offset,
         new_stride[dim],
         x.get_size()[dim],
+        clamp=False,
     )
     buffer.name = V.graph.register_buffer(buffer)
     V.graph.register_operation(buffer)
@@ -2991,6 +3113,8 @@ def slice_scatter(x, src, dim=0, start=None, end=None, step=1):
     dim = _validate_dim(x, dim, 0)
     dim_size = x.get_size()[dim]
 
+    start = ir.SliceView.handle_negative_index(start, dim_size)
+    end = ir.SliceView.handle_negative_index(end, dim_size)
     start, end = ir.SliceView.normalize_start_end(x, dim, start, end)
 
     src_size = list(x.get_size())
@@ -6,7 +6,7 @@ import math
 import operator
 import sys
 from functools import reduce
-from typing import Callable, Union
+from typing import Callable, Optional, Union
 
 import torch
 import torch._custom_op
@@ -15,6 +15,7 @@ import torch._prims_common as utils
 from torch._dispatch.python import no_python_dispatcher
 from torch._ops import OpOverload
 from torch._prims_common import (
+    canonicalize_dim,
     contiguous_for_memory_format_or_false,
     elementwise_dtypes,
     ELEMENTWISE_TYPE_PROMOTION_KIND,
@@ -746,6 +747,88 @@ def _padded_dense_to_jagged_forward(fake_mode, func, padded, offsets, total_L=No
     return padded.new_empty(output_shape)
 
 
+def _compute_slice_index(size, index):
+    from torch.fx.experimental.symbolic_shapes import guard_or_false, sym_and
+
+    if guard_or_false(sym_and(index >= 0, index <= size)):
+        return index
+    elif guard_or_false(sym_and(index < 0, index >= -size)):
+        return index + size
+    elif guard_or_false(index < -size):
+        return 0
+    elif guard_or_false(index > size):
+        return size
+    return None
+
+
+@register_op_impl(torch.ops.aten.slice.Tensor)
+def slice_forward(
+    fake_mode,
+    func,
+    self,
+    dim: int = 0,
+    start: Optional[int] = None,
+    end: Optional[int] = None,
+    step: int = 1,
+):
+    from torch.fx.experimental.symbolic_shapes import (
+        guard_or_false,
+        statically_known_true,
+    )
+
+    shape_env = fake_mode.shape_env
+
+    ndim = self.dim()
+    if ndim == 0:
+        raise RuntimeError("slice() cannot be applied to a 0-dim tensor.")
+    dim = canonicalize_dim(self.dim(), dim)
+    sizes = list(self.size())
+    strides = list(self.stride())
+
+    if step <= 0:
+        raise RuntimeError("slice step must be positive")
+
+    # start, end
+    start_index = 0 if start is None else _compute_slice_index(sizes[dim], start)
+    end_index = (
+        sizes[dim]
+        if statically_known_true(end == sys.maxsize) or end is None
+        else _compute_slice_index(sizes[dim], end)
+    )
+
+    # size
+    new_size = None
+    if start_index is not None and end_index is not None:
+        if guard_or_false(end_index >= start_index):
+            new_size = (end_index - start_index + step - 1) // step
+        elif guard_or_false(start_index >= end_index):
+            new_size = 0
+
+    # create unbacked if case unknown
+    if new_size is None:
+        new_size = shape_env.create_unbacked_symint()
+        torch._check_is_size(new_size, max=sizes[dim])
+
+    # stride
+    new_stride = strides[dim] * step
+
+    # storage offset
+    if start_index is not None:
+        storage_offset = self.storage_offset() + start_index * strides[dim]
+    else:
+        storage_offset = shape_env.create_unbacked_symint()
+        torch._check(storage_offset >= 0)
+
+    sizes[dim] = new_size
+    strides[dim] = new_stride
+    if self.is_quantized:
+        raise NotImplementedError(
+            "Slice decomposition for quantized tensors aren't implemented"
+        )
+    else:
+        return self.as_strided(sizes, strides, storage_offset)
+
+
 @register_op_impl(torch.ops.aten.masked_select.default)
 def masked_select(fake_mode, func, self, mask):
     if (
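
For intuition, the case split `_compute_slice_index` performs on symbolic indices matches this plain-integer version (illustrative only); when none of the `guard_or_false` branches can be decided, it returns None and `slice_forward` falls back to a fresh unbacked symbol for the output size (and, if the start is also unknown, for the storage offset):

def compute_slice_index_concrete(size: int, index: int) -> int:
    # same case split as _compute_slice_index, but on plain ints
    if 0 <= index <= size:
        return index           # standard
    if -size <= index < 0:
        return index + size    # negative slicing
    if index < -size:
        return 0               # negative out-of-bounds, clamps to 0
    return size                # positive out-of-bounds, clamps to size

assert compute_slice_index_concrete(10, 3) == 3
assert compute_slice_index_concrete(10, -1) == 9
assert compute_slice_index_concrete(10, -1000) == 0
assert compute_slice_index_concrete(10, 1000) == 10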
@@ -2616,7 +2616,9 @@ class FakeTensorMode(TorchDispatchMode):
         if (
             func not in meta_table
             and not self.cpp_meta_supports_symint(func)
-            and not (has_symbolic_sizes and func in self._view_fake_tensor_impl_ops)
+            and not (
+                has_symbolic_sizes and func in self._unbacked_special_fake_handling_ops
+            )
         ):
             from torch._decomp import decomposition_table
 
@@ -2925,8 +2927,10 @@ class FakeTensorMode(TorchDispatchMode):
         aten._sparse_coo_tensor_with_dims_and_tensors.default,
     )
 
-    _view_fake_tensor_impl_ops = ordered_set(
-        aten.view.default, aten._unsafe_view.default
+    _unbacked_special_fake_handling_ops = ordered_set(
+        aten.view.default,
+        aten._unsafe_view.default,
+        aten.slice.Tensor,
     )
 
     def cpp_meta_supports_symint(self, func: OpOverload) -> bool: