Update on "[WIP] Support python slicing with tensor inputs."

allow things like
``` 
#!/usr/bin/env python
import torch

print("="*60)
print("Testing tensor slicing with torch.compile")
print("="*60)

# Test 1: Simple eager mode
print("\n1. Eager mode test:")
x = torch.randn(10)
idx = torch.tensor(4)
result = x[:idx]
print(f"   x[:idx] where idx=4: result.shape = {result.shape}")
assert result.shape[0] == 4
print("   ✓ Eager mode works!")

# Test 2: With torch.compile
print("\n2. Compiled mode test:")
def slice_fn(x, idx):
    return x[:idx]

try:
    compiled_fn = torch.compile(slice_fn)
    x = torch.randn(10)
    idx = torch.tensor(4)
    result = compiled_fn(x, idx)
    print(f"   Compiled x[:idx] where idx=4: result.shape = {result.shape}")
    assert result.shape[0] == 4
    print("   ✓ Compiled mode works!")
except Exception as e:
    print(f"   ✗ Compiled mode failed: {e}")
    import traceback
    traceback.print_exc()

# Test 3: With dynamic slicing from sum
print("\n3. Dynamic slicing with sum:")
def dynamic_slice_fn(x, lengths):
    idx = lengths.sum()
    return x[:idx]

try:
    compiled_fn = torch.compile(dynamic_slice_fn)
    x = torch.randn(10)
    lengths = torch.tensor([1, 1, 1, 1])
    result = compiled_fn(x, lengths)
    print(f"   Compiled x[:lengths.sum()] where sum=4: result.shape = {result.shape}")
    assert result.shape[0] == 4
    print("   ✓ Dynamic slicing works!")
except Exception as e:
    print(f"   ✗ Dynamic slicing failed: {e}")
    import traceback
    traceback.print_exc()

print("\n" + "="*60)
print("SUMMARY: Check results above")
print("="*60)

```


cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames Lucaskabela

[ghstack-poisoned]
This commit is contained in:
Laith Sakka
2025-10-19 08:41:38 -07:00
5 changed files with 42 additions and 39 deletions

View File

@ -528,30 +528,6 @@ Attempted to call function marked as skipped
f(x)
self.assertEqual(len(ws), 2)
def test_slice_with_tensor(self):
def fn(x, y):
return x[:y]
self.assertExpectedInlineMunged(
Unsupported,
lambda: torch.compile(fn, backend="eager", fullgraph=True)(
torch.randn(10),
torch.tensor([3]),
),
"""\
Dynamic slicing with Tensor arguments
Explanation: Creating slices with Tensor arguments is not supported. e.g. `l[:x]`, where `x` is a 1-element tensor.
Hint: It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues.
Developer debug context: SliceVariable start: ConstantVariable(NoneType: None), stop: LazyVariableTracker(realized: TensorVariable()), step: ConstantVariable(NoneType: None)
For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0038.html
from user code:
File "test_error_messages.py", line N, in fn
return x[:y]""",
)
def test_observed_exception(self):
def fn():
raise RuntimeError("test")

View File

@ -3886,23 +3886,24 @@ def forward(self, arg0_1: "i64[2][1]cpu", arg1_1: "Sym(u2)", arg2_1: "Sym(u3)",
fn2 = torch.compile(f2, fullgraph=True, backend="inductor")
self.assertTrue(torch.allclose(fn2(x, idx_tensor_neg), f2(x, idx_tensor_neg)))
# Test 3: Multidimensional select with tensor index
def f3(x, idx_tensor):
return x[:, idx_tensor]
# TODO support those less common patterns
# # Test 3: Multidimensional select with tensor index
# def f3(x, idx_tensor):
# return x[:, idx_tensor]
x_2d = torch.randn(5, 10)
fn3 = torch.compile(f3, fullgraph=True, backend="inductor")
self.assertTrue(torch.allclose(fn3(x_2d, idx_tensor), f3(x_2d, idx_tensor)))
# x_2d = torch.randn(5, 10)
# fn3 = torch.compile(f3, fullgraph=True, backend="inductor")
# self.assertTrue(torch.allclose(fn3(x_2d, idx_tensor), f3(x_2d, idx_tensor)))
# Test 4: Multiple tensor indices
def f4(x, idx1, idx2):
return x[idx1, idx2]
# def f4(x, idx1, idx2):
# return x[idx1, idx2]
x_2d = torch.randn(8, 12)
idx1 = torch.tensor(3)
idx2 = torch.tensor(7)
fn4 = torch.compile(f4, fullgraph=True, backend="inductor")
self.assertTrue(torch.allclose(fn4(x_2d, idx1, idx2), f4(x_2d, idx1, idx2)))
# x_2d = torch.randn(8, 12)
# idx1 = torch.tensor(3)
# idx2 = torch.tensor(7)
# fn4 = torch.compile(f4, fullgraph=True, backend="inductor")
# self.assertTrue(torch.allclose(fn4(x_2d, idx1, idx2), f4(x_2d, idx1, idx2)))
@fresh_cache()
@torch._dynamo.config.patch("capture_scalar_outputs", True)

View File

@ -1795,7 +1795,7 @@ class VariableBuilder:
]
self.install_guards(GuardBuilder.TYPE_MATCH)
if isinstance(value, slice):
return SliceVariable(items, source=self.source)
return SliceVariable(items, self.tx, source=self.source)
else:
return RangeVariable(items, source=self.source)

View File

@ -1235,6 +1235,20 @@ class BuiltinVariable(VariableTracker):
# scenario is to de-sugar eagerly.
fn, args = IN_PLACE_DESUGARING_MAP[fn], [args[0], args[1]]
# Convert size-1 TensorVariable indices to SymIntVariable by calling .item()
# This decomposes tensor[t] to u=t.item(); tensor[u] at the dynamo level
if (
fn is operator.getitem
and len(args) == 2
and isinstance(args[1], variables.TensorVariable)
):
tensor_idx = args[1]
# Only convert if we know it's size-1 (not for advanced indexing)
if tensor_idx.size is not None and all(s == 1 for s in tensor_idx.size):
args = list(args)
args[1] = tensor_idx.call_method(tx, "item", [], {})
args = tuple(args)
if fn is operator.getitem and isinstance(args[1], SymNodeVariable):
# Standard indexing will force specialization due to
# __index__. Rewrite as a regular torch op which will

View File

@ -1386,7 +1386,7 @@ class NamedTupleVariable(TupleVariable):
class SliceVariable(VariableTracker):
def __init__(self, items, tx, **kwargs) -> None:
def __init__(self, items, tx=None, **kwargs) -> None:
items_to_map = items
start, stop, step = [variables.ConstantVariable.create(None)] * 3
@ -1402,10 +1402,22 @@ class SliceVariable(VariableTracker):
# Convert TensorVariable to SymIntVariable by calling .item()
# This decomposes a[:t] to u=t.item(); a[:u] at the dynamo level
if isinstance(start, variables.TensorVariable):
assert tx is not None, (
"tx is required when slice indices are TensorVariables"
)
assert start.size is None or all(s == 1 for s in start.size)
start = start.call_method(tx, "item", [], {})
if isinstance(stop, variables.TensorVariable):
assert tx is not None, (
"tx is required when slice indices are TensorVariables"
)
assert stop.size is None or all(s == 1 for s in stop.size)
stop = stop.call_method(tx, "item", [], {})
if isinstance(step, variables.TensorVariable):
assert tx is not None, (
"tx is required when slice indices are TensorVariables"
)
assert step.size is None or all(s == 1 for s in step.size)
step = step.call_method(tx, "item", [], {})
self.items = (start, stop, step)