WIP: Support Python slicing with data-dependent input tensors

ghstack-source-id: 4abcd9a5a4de65fa0d205c0b101998f48f6d9655
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165074
Laith Sakka
2025-10-19 17:27:06 -07:00
parent e595136187
commit a2c2c3295c
6 changed files with 262 additions and 39 deletions
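
In short, the goal of this stack is to let torch.compile trace Python slicing whose bounds are runtime tensor values instead of raising the "Dynamic slicing with Tensor arguments" graph break. A minimal sketch of the intended usage, modeled on the new tests below (the exact set of supported patterns is still work in progress):

import torch

# capture_scalar_outputs mirrors the config patch used in the new tests below;
# it lets Dynamo turn t.item() into an unbacked SymInt instead of graph breaking.
torch._dynamo.config.capture_scalar_outputs = True

def f(x, stop_t):
    # stop_t is a 0-d tensor whose value is only known at runtime
    return x[:stop_t]

compiled = torch.compile(f, fullgraph=True, backend="inductor")
x = torch.randn(20)
stop_t = torch.tensor(15)
print(torch.allclose(compiled(x, stop_t), f(x, stop_t)))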


@@ -528,30 +528,6 @@ Attempted to call function marked as skipped
                f(x)
        self.assertEqual(len(ws), 2)

    def test_slice_with_tensor(self):
        def fn(x, y):
            return x[:y]

        self.assertExpectedInlineMunged(
            Unsupported,
            lambda: torch.compile(fn, backend="eager", fullgraph=True)(
                torch.randn(10),
                torch.tensor([3]),
            ),
            """\
Dynamic slicing with Tensor arguments
Explanation: Creating slices with Tensor arguments is not supported. e.g. `l[:x]`, where `x` is a 1-element tensor.
Hint: It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues.
Developer debug context: SliceVariable start: ConstantVariable(NoneType: None), stop: LazyVariableTracker(realized: TensorVariable()), step: ConstantVariable(NoneType: None)
For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0038.html

from user code:
   File "test_error_messages.py", line N, in fn
    return x[:y]""",
        )

    def test_observed_exception(self):
        def fn():
            raise RuntimeError("test")


@@ -3799,6 +3799,120 @@ def forward(self, arg0_1: "i64[2][1]cpu", arg1_1: "Sym(u2)", arg2_1: "Sym(u3)",
    def test_unbacked_slice_with_step_cpp_wrapper(self):
        self.test_unbacked_slice_with_step()

    @fresh_cache()
    @torch._dynamo.config.patch("capture_scalar_outputs", True)
    def test_slice_with_tensor_indices(self):
        # Test slicing with tensor start/stop/step on RHS (reading)

        # Test 1: Basic slice with tensor start and stop
        def f1(x, start_t, stop_t):
            return x[start_t:stop_t]

        x = torch.randn(20)
        start_t = torch.tensor(5)
        stop_t = torch.tensor(15)
        fn1 = torch.compile(f1, fullgraph=True, backend="inductor")
        self.assertTrue(torch.allclose(fn1(x, start_t, stop_t), f1(x, start_t, stop_t)))

        # Test 2: Slice with tensor step
        def f2(x, start_t, stop_t, step_t):
            return x[start_t:stop_t:step_t]

        step_t = torch.tensor(2)
        fn2 = torch.compile(f2, fullgraph=True, backend="inductor")
        self.assertTrue(
            torch.allclose(
                fn2(x, start_t, stop_t, step_t), f2(x, start_t, stop_t, step_t)
            )
        )

        # Test 3: Slice with only tensor start
        def f3(x, start_t):
            return x[start_t:]

        fn3 = torch.compile(f3, fullgraph=True, backend="inductor")
        self.assertTrue(torch.allclose(fn3(x, start_t), f3(x, start_t)))

        # Test 4: Slice with only tensor stop
        def f4(x, stop_t):
            return x[:stop_t]

        fn4 = torch.compile(f4, fullgraph=True, backend="inductor")
        self.assertTrue(torch.allclose(fn4(x, stop_t), f4(x, stop_t)))

        # Test 5: Negative indices with tensors
        def f5(x, start_t):
            return x[start_t:-1]

        start_t_neg = torch.tensor(-10)
        fn5 = torch.compile(f5, fullgraph=True, backend="inductor")
        self.assertTrue(torch.allclose(fn5(x, start_t_neg), f5(x, start_t_neg)))

        # Test 6: Multidimensional slice with tensor indices
        def f6(x, start_t, stop_t):
            return x[:, start_t:stop_t]

        x_2d = torch.randn(10, 20)
        fn6 = torch.compile(f6, fullgraph=True, backend="inductor")
        self.assertTrue(
            torch.allclose(fn6(x_2d, start_t, stop_t), f6(x_2d, start_t, stop_t))
        )

    @fresh_cache()
    @torch._dynamo.config.patch("capture_scalar_outputs", True)
    @torch._inductor.config.patch("cpp_wrapper", True)
    def test_slice_with_tensor_indices_cpp_wrapper(self):
        self.test_slice_with_tensor_indices()

    @fresh_cache()
    @torch._dynamo.config.patch("capture_scalar_outputs", True)
    def test_select_with_tensor_index(self):
        # Test direct tensor indexing (select) without calling .item()

        # # Test 1: Simple 0-d tensor as index
        # def f1(x, idx_tensor):
        #     return x[idx_tensor]

        # x = torch.randn(10)
        idx_tensor = torch.tensor(5)
        # fn1 = torch.compile(f1, fullgraph=True, backend="inductor")
        # self.assertTrue(torch.allclose(fn1(x, idx_tensor), f1(x, idx_tensor)))

        # # Test 2: Negative tensor index
        # def f2(x, idx_tensor):
        #     return x[idx_tensor]

        # idx_tensor_neg = torch.tensor(-2)
        # fn2 = torch.compile(f2, fullgraph=True, backend="inductor")
        # self.assertTrue(torch.allclose(fn2(x, idx_tensor_neg), f2(x, idx_tensor_neg)))

        # Test 3: Multidimensional select with tensor index
        def f3(x, idx_tensor):
            return x[:, idx_tensor]

        x_2d = torch.randn(5, 10)
        fn3 = torch.compile(f3, fullgraph=True, backend="inductor")
        self.assertTrue(torch.allclose(fn3(x_2d, idx_tensor), f3(x_2d, idx_tensor)))

        # # Hit inductor error
        # u7 = u6 + 1 * (u5 + 12 if u5 < 0 else u5)
        # NameError: name 'u6' is not defined
        # # Test 4: Multiple tensor indices
        # def f4(x, idx1, idx2):
        #     return x[idx1, idx2]

        # x_2d = torch.randn(8, 12)
        # idx1 = torch.tensor(3)
        # idx2 = torch.tensor(7)
        # fn4 = torch.compile(f4, fullgraph=True, backend="inductor")
        # self.assertTrue(torch.allclose(fn4(x_2d, idx1, idx2), f4(x_2d, idx1, idx2)))

    @fresh_cache()
    @torch._dynamo.config.patch("capture_scalar_outputs", True)
    @torch._inductor.config.patch("cpp_wrapper", True)
    def test_select_with_tensor_index_cpp_wrapper(self):
        self.test_select_with_tensor_index()

    @fresh_cache()
    @torch._dynamo.config.patch("capture_scalar_outputs", True)
    def test_tensor_split(self):
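
The multidimensional select test above works because indexing one dimension with an integer is equivalent to torch.select on that dimension, which is what the Dynamo rewrite further down in this diff targets. A quick eager sanity check of that equivalence (plain ints and names chosen just for illustration):

import torch

x = torch.randn(5, 10)
idx = 7  # the compiled path reaches a plain index like this via idx_tensor.item()

# x[:, idx] and torch.select(x, 1, idx) produce the same result,
# so rewriting the subscript as torch.select preserves semantics.
assert torch.equal(x[:, idx], torch.select(x, 1, idx))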


@@ -3138,7 +3138,7 @@ class InstructionTranslatorBase(
    def BUILD_SLICE(self, inst: Instruction) -> None:
        items = self.popn(inst.argval)
        self.push(SliceVariable(items))
        self.push(SliceVariable(items, tx=self))

    def BUILD_LIST(self, inst: Instruction) -> None:
        items = self.popn(inst.argval)
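
Passing tx=self matters because SliceVariable may now trace a .item() call on tensor slice bounds (see the lists.py change at the bottom of this diff), and that requires the active InstructionTranslator. BUILD_SLICE is the bytecode CPython emits when a subscript uses slice syntax; a small way to see it, keeping in mind that exact opcodes vary by Python version (3.12+ uses BINARY_SLICE for simple two-argument slices):

import dis

def f(x, a, b):
    return x[a:b:2]

# On current CPython versions this disassembly contains a BUILD_SLICE
# instruction (the handler patched above) before the subscript operation.
dis.dis(f)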


@@ -1795,7 +1795,7 @@ class VariableBuilder:
            ]
            self.install_guards(GuardBuilder.TYPE_MATCH)
            if isinstance(value, slice):
                return SliceVariable(items, source=self.source)
                return SliceVariable(items, self.tx, source=self.source)
            else:
                return RangeVariable(items, source=self.source)


@@ -1235,6 +1235,53 @@ class BuiltinVariable(VariableTracker):
            # scenario is to de-sugar eagerly.
            fn, args = IN_PLACE_DESUGARING_MAP[fn], [args[0], args[1]]

        # Convert size-1 TensorVariable indices to SymIntVariable by calling .item()
        # This decomposes tensor[t] to u = t.item(); tensor[u] at the dynamo level
        # Only convert if the tensor doesn't contain unbacked symints (data-dependent values)
        if fn is operator.getitem and len(args) == 2:
            from torch.fx.experimental.symbolic_shapes import has_free_symbols

            def should_convert_to_item(tensor_idx):
                """Check if we should convert size-1 tensor to scalar."""
                if not isinstance(tensor_idx, variables.TensorVariable):
                    return False
                # Only convert if size-1 or 0-d
                if tensor_idx._size is None or not all(
                    s == 1 for s in tensor_idx._size
                ):
                    return False
                # Don't convert if it has unbacked symints (data-dependent)
                example_value = tensor_idx.proxy.node.meta.get("example_value")
                return example_value is None or not has_free_symbols(example_value)

            index_arg = args[1]
            if isinstance(
                index_arg, variables.TensorVariable
            ) and should_convert_to_item(index_arg):
                args = list(args)
                args[1] = index_arg.call_method(tx, "item", [], {})
                args = tuple(args)
            elif isinstance(
                index_arg, (variables.TupleVariable, variables.ListVariable)
            ):
                # Multi-dimensional indexing: tensor[:, idx] or tensor[idx1, idx2]
                new_items = []
                changed = False
                for item in index_arg.items:
                    if should_convert_to_item(item):
                        new_items.append(item.call_method(tx, "item", [], {}))
                        changed = True
                    else:
                        new_items.append(item)
                if changed:
                    args = list(args)
                    args[1] = (
                        variables.TupleVariable(new_items)
                        if isinstance(index_arg, variables.TupleVariable)
                        else variables.ListVariable(new_items)
                    )
                    args = tuple(args)

        if fn is operator.getitem and isinstance(args[1], SymNodeVariable):
            # Standard indexing will force specialization due to
            # __index__. Rewrite as a regular torch op which will
@@ -1247,6 +1294,83 @@ class BuiltinVariable(VariableTracker):
                    args[1],
                ],
            )
        elif fn is operator.getitem and isinstance(
            args[1], variables.TupleVariable
        ):
            # Handle a tuple containing SymNodeVariables: x[symint1, symint2] or x[:, symint]
            # Decompose into sequential operations, tracking dimension changes
            index_items = args[1].items
            if any(isinstance(item, SymNodeVariable) for item in index_items):
                result = args[0]
                dims_removed = 0  # Track how many dimensions have been removed
                for original_dim, item in enumerate(index_items):
                    # Current dimension in the result tensor
                    current_dim = original_dim - dims_removed
                    if isinstance(item, SymNodeVariable):
                        # Apply torch.select at current_dim (removes this dimension)
                        result = variables.TorchInGraphFunctionVariable(
                            torch.select
                        ).call_function(
                            tx,
                            [
                                result,
                                variables.ConstantVariable.create(current_dim),
                                item,
                            ],
                            {},
                        )
                        dims_removed += 1
                    elif isinstance(item, variables.SliceVariable):
                        # Slicing keeps the dimension
                        result = variables.BuiltinVariable(
                            operator.getitem
                        ).call_function(tx, [result, item], {})
                    else:
                        # Regular scalar index (also removes a dimension)
                        result = variables.BuiltinVariable(
                            operator.getitem
                        ).call_function(tx, [result, item], {})
                        dims_removed += 1
                return result
        elif fn is operator.getitem and isinstance(
            args[1], (variables.TupleVariable, variables.ListVariable)
        ):
            # Check if we have a SymNodeVariable inside the tuple: tensor[:, symnode]
            # Rewrite as torch.select to avoid a data-dependent error (DDE)
            index_items = args[1].items
            symnode_indices = [
                i
                for i, item in enumerate(index_items)
                if isinstance(item, SymNodeVariable)
            ]
            if len(symnode_indices) == 1:
                # Single SymNode in the tuple - rewrite as torch.select
                symnode_idx = symnode_indices[0]
                symnode = index_items[symnode_idx]
                # Check that all other indices are slices or ellipsis
                non_symnode_indices = [
                    i for i in range(len(index_items)) if i != symnode_idx
                ]
                if all(
                    isinstance(
                        index_items[i],
                        (variables.SliceVariable, variables.ConstantVariable),
                    )
                    for i in non_symnode_indices
                ):
                    fn, args = (
                        torch.select,
                        [
                            args[0],
                            variables.ConstantVariable.create(symnode_idx),
                            symnode,
                        ],
                    )

        # Interaction between ndarray and tensors:
        #   We prefer the tensor op whenever there are tensors involved
@@ -1745,7 +1869,7 @@ class BuiltinVariable(VariableTracker):
        )

    def call_slice(self, tx: "InstructionTranslator", *args):
        return variables.SliceVariable(args)
        return variables.SliceVariable(args, tx)

    def _dyn_proxy(self, tx: "InstructionTranslator", *args, **kwargs):
        from .builder import wrap_fx_proxy
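
Taken together, the builtin.py changes rewrite data-dependent subscripts in two steps: a size-1 tensor index is turned into an unbacked SymInt via .item(), and a SymInt sitting inside a tuple index is routed through torch.select instead of __index__. Roughly, and only as a sketch of the intent (the helper names below are illustrative, not Dynamo APIs), the traced program behaves like:

import torch

def user_code(x, idx_t):
    return x[:, idx_t]  # idx_t is a 0-d tensor index

def rewritten(x, idx_t):
    u = idx_t.item()  # becomes an unbacked SymInt under capture_scalar_outputs
    return torch.select(x, 1, u)  # tuple index containing a SymInt -> torch.select

x = torch.randn(5, 10)
idx_t = torch.tensor(3)
assert torch.equal(user_code(x, idx_t), rewritten(x, idx_t))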


@@ -1386,7 +1386,7 @@ class NamedTupleVariable(TupleVariable):
class SliceVariable(VariableTracker):
    def __init__(self, items, **kwargs) -> None:
    def __init__(self, items, tx=None, **kwargs) -> None:
        items_to_map = items
        start, stop, step = [variables.ConstantVariable.create(None)] * 3
@@ -1399,18 +1399,27 @@ class SliceVariable(VariableTracker):
        else:
            raise AssertionError

        if isinstance(start, variables.TensorVariable) or isinstance(
            stop, variables.TensorVariable
        ):
            unimplemented_v2(
                gb_type="Dynamic slicing with Tensor arguments",
                context=f"SliceVariable start: {start}, stop: {stop}, step: {step}",
                explanation="Creating slices with Tensor arguments is not supported. "
                "e.g. `l[:x]`, where `x` is a 1-element tensor.",
                hints=[
                    *graph_break_hints.SUPPORTABLE,
                ],
            )
        # Convert TensorVariable to SymIntVariable by calling .item()
        # This decomposes a[:t] to u = t.item(); a[:u] at the dynamo level
        if isinstance(start, variables.TensorVariable):
            assert tx is not None, (
                "tx is required when slice indices are TensorVariables"
            )
            assert start.size is None or all(s == 1 for s in start.size)
            start = start.call_method(tx, "item", [], {})
        if isinstance(stop, variables.TensorVariable):
            assert tx is not None, (
                "tx is required when slice indices are TensorVariables"
            )
            assert stop.size is None or all(s == 1 for s in stop.size)
            stop = stop.call_method(tx, "item", [], {})
        if isinstance(step, variables.TensorVariable):
            assert tx is not None, (
                "tx is required when slice indices are TensorVariables"
            )
            assert step.size is None or all(s == 1 for s in step.size)
            step = step.call_method(tx, "item", [], {})

        self.items = (start, stop, step)
        super().__init__(**kwargs)
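
At the user level, the SliceVariable change corresponds to decomposing a slice whose bound is a 0-d (or size-1) tensor into an explicit .item() call, as the comment above says. A small eager sketch of that decomposition (function names are illustrative only):

import torch

def user_code(a, t):
    return a[:t]  # t is a 0-d tensor used as a slice bound

def decomposed(a, t):
    u = t.item()  # traced as an unbacked SymInt with capture_scalar_outputs=True
    return a[:u]

a = torch.randn(10)
t = torch.tensor(3)
assert torch.equal(user_code(a, t), decomposed(a, t))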