Update on "[WIP] Support python slicing with tensor inputs."

Allow slicing and indexing with tensor inputs under `torch.compile`, e.g.:
```python
#!/usr/bin/env python
import torch

print("="*60)
print("Testing tensor slicing with torch.compile")
print("="*60)

# Test 1: Simple eager mode
print("\n1. Eager mode test:")
x = torch.randn(10)
idx = torch.tensor(4)
result = x[:idx]
print(f"   x[:idx] where idx=4: result.shape = {result.shape}")
assert result.shape[0] == 4
print("   ✓ Eager mode works!")

# Test 2: With torch.compile
print("\n2. Compiled mode test:")
def slice_fn(x, idx):
    return x[:idx]

try:
    compiled_fn = torch.compile(slice_fn)
    x = torch.randn(10)
    idx = torch.tensor(4)
    result = compiled_fn(x, idx)
    print(f"   Compiled x[:idx] where idx=4: result.shape = {result.shape}")
    assert result.shape[0] == 4
    print("   ✓ Compiled mode works!")
except Exception as e:
    print(f"   ✗ Compiled mode failed: {e}")
    import traceback
    traceback.print_exc()

# Test 3: With dynamic slicing from sum
print("\n3. Dynamic slicing with sum:")
def dynamic_slice_fn(x, lengths):
    idx = lengths.sum()
    return x[:idx]

try:
    compiled_fn = torch.compile(dynamic_slice_fn)
    x = torch.randn(10)
    lengths = torch.tensor([1, 1, 1, 1])
    result = compiled_fn(x, lengths)
    print(f"   Compiled x[:lengths.sum()] where sum=4: result.shape = {result.shape}")
    assert result.shape[0] == 4
    print("   ✓ Dynamic slicing works!")
except Exception as e:
    print(f"   ✗ Dynamic slicing failed: {e}")
    import traceback
    traceback.print_exc()

print("\n" + "="*60)
print("SUMMARY: Check results above")
print("="*60)

```
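
At the Dynamo level (see the `BuiltinVariable` changes below), this works by rewriting the indexing at trace time: a size-1 / 0-d tensor index is converted to a scalar via `.item()`, and a symbolic index appearing inside a tuple like `x[:, i]` is routed through `torch.select` to avoid data-dependent errors (DDE). A rough eager-mode sketch of the equivalent rewrites, for illustration only (the real transformation operates on Dynamo VariableTrackers, not tensors):

```python
import torch

x = torch.randn(5, 10)
idx = torch.tensor(3)

# 0-d tensor index: x[idx] is handled as u = idx.item(); x[u]
assert torch.equal(x[idx], x[idx.item()])

# Tensor index inside a tuple: x[:, idx] is handled as
# torch.select(x, 1, idx.item()), keeping the dimension bookkeeping
# explicit instead of relying on data-dependent advanced indexing.
assert torch.equal(x[:, idx], torch.select(x, 1, idx.item()))
```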


cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames Lucaskabela

[ghstack-poisoned]
Author: Laith Sakka
Date: 2025-10-19 17:27:06 -07:00

2 changed files with 141 additions and 29 deletions


@@ -3869,33 +3869,35 @@ def forward(self, arg0_1: "i64[2][1]cpu", arg1_1: "Sym(u2)", arg2_1: "Sym(u3)",
def test_select_with_tensor_index(self):
# Test direct tensor indexing (select) without calling .item()
# Test 1: Simple 0-d tensor as index
def f1(x, idx_tensor):
return x[idx_tensor]
# # Test 1: Simple 0-d tensor as index
# def f1(x, idx_tensor):
# return x[idx_tensor]
x = torch.randn(10)
# x = torch.randn(10)
idx_tensor = torch.tensor(5)
fn1 = torch.compile(f1, fullgraph=True, backend="inductor")
self.assertTrue(torch.allclose(fn1(x, idx_tensor), f1(x, idx_tensor)))
# fn1 = torch.compile(f1, fullgraph=True, backend="inductor")
# self.assertTrue(torch.allclose(fn1(x, idx_tensor), f1(x, idx_tensor)))
# Test 2: Negative tensor index
def f2(x, idx_tensor):
return x[idx_tensor]
# # Test 2: Negative tensor index
# def f2(x, idx_tensor):
# return x[idx_tensor]
idx_tensor_neg = torch.tensor(-2)
fn2 = torch.compile(f2, fullgraph=True, backend="inductor")
self.assertTrue(torch.allclose(fn2(x, idx_tensor_neg), f2(x, idx_tensor_neg)))
# idx_tensor_neg = torch.tensor(-2)
# fn2 = torch.compile(f2, fullgraph=True, backend="inductor")
# self.assertTrue(torch.allclose(fn2(x, idx_tensor_neg), f2(x, idx_tensor_neg)))
# TODO support those less common patterns
# # Test 3: Multidimensional select with tensor index
# def f3(x, idx_tensor):
# return x[:, idx_tensor]
# Test 3: Multidimensional select with tensor index
def f3(x, idx_tensor):
return x[:, idx_tensor]
# x_2d = torch.randn(5, 10)
# fn3 = torch.compile(f3, fullgraph=True, backend="inductor")
# self.assertTrue(torch.allclose(fn3(x_2d, idx_tensor), f3(x_2d, idx_tensor)))
x_2d = torch.randn(5, 10)
fn3 = torch.compile(f3, fullgraph=True, backend="inductor")
self.assertTrue(torch.allclose(fn3(x_2d, idx_tensor), f3(x_2d, idx_tensor)))
# Test 4: Multiple tensor indices
# # Hit inductor error
# u7 = u6 + 1 * (u5 + 12 if u5 < 0 else u5)
# NameError: name 'u6' is not defined
# # Test 4: Multiple tensor indices
# def f4(x, idx1, idx2):
# return x[idx1, idx2]


@@ -1237,17 +1237,50 @@ class BuiltinVariable(VariableTracker):
# Convert size-1 TensorVariable indices to SymIntVariable by calling .item()
# This decomposes tensor[t] to u=t.item(); tensor[u] at the dynamo level
if (
fn is operator.getitem
and len(args) == 2
and isinstance(args[1], variables.TensorVariable)
):
tensor_idx = args[1]
# Only convert if we know it's size-1 (not for advanced indexing)
if tensor_idx.size is not None and all(s == 1 for s in tensor_idx.size):
# Only convert if the tensor doesn't contain unbacked symints (data-dependent values)
if fn is operator.getitem and len(args) == 2:
from torch.fx.experimental.symbolic_shapes import has_free_symbols
def should_convert_to_item(tensor_idx):
"""Check if we should convert size-1 tensor to scalar."""
if not isinstance(tensor_idx, variables.TensorVariable):
return False
# Only convert if size-1 or 0-d
if tensor_idx._size is None or not all(
s == 1 for s in tensor_idx._size
):
return False
# Don't convert if it has unbacked symints (data-dependent)
example_value = tensor_idx.proxy.node.meta.get("example_value")
return example_value is None or not has_free_symbols(example_value)
index_arg = args[1]
if isinstance(
index_arg, variables.TensorVariable
) and should_convert_to_item(index_arg):
args = list(args)
args[1] = tensor_idx.call_method(tx, "item", [], {})
args[1] = index_arg.call_method(tx, "item", [], {})
args = tuple(args)
elif isinstance(
index_arg, (variables.TupleVariable, variables.ListVariable)
):
# Multi-dimensional indexing: tensor[:, idx] or tensor[idx1, idx2]
new_items = []
changed = False
for item in index_arg.items:
if should_convert_to_item(item):
new_items.append(item.call_method(tx, "item", [], {}))
changed = True
else:
new_items.append(item)
if changed:
args = list(args)
args[1] = (
variables.TupleVariable(new_items)
if isinstance(index_arg, variables.TupleVariable)
else variables.ListVariable(new_items)
)
args = tuple(args)
if fn is operator.getitem and isinstance(args[1], SymNodeVariable):
# Standard indexing will force specialization due to
@@ -1261,6 +1294,83 @@ class BuiltinVariable(VariableTracker):
args[1],
],
)
elif fn is operator.getitem and isinstance(
args[1], variables.TupleVariable
):
# Handle tuple with SymNodeVariables: x[symint1, symint2] or x[:, symint]
# Decompose into sequential operations, tracking dimension changes
index_items = args[1].items
if any(isinstance(item, SymNodeVariable) for item in index_items):
result = args[0]
dims_removed = 0 # Track how many dimensions have been removed
for original_dim, item in enumerate(index_items):
# Current dimension in the result tensor
current_dim = original_dim - dims_removed
if isinstance(item, SymNodeVariable):
# Apply torch.select at current_dim (removes this dimension)
result = variables.TorchInGraphFunctionVariable(
torch.select
).call_function(
tx,
[
result,
variables.ConstantVariable.create(current_dim),
item,
],
{},
)
dims_removed += 1
elif isinstance(item, variables.SliceVariable):
# Slicing keeps the dimension
result = variables.BuiltinVariable(
operator.getitem
).call_function(tx, [result, item], {})
else:
# Regular scalar index (also removes dimension)
result = variables.BuiltinVariable(
operator.getitem
).call_function(tx, [result, item], {})
dims_removed += 1
return result
elif fn is operator.getitem and isinstance(
args[1], (variables.TupleVariable, variables.ListVariable)
):
# Check if we have SymNodeVariable inside tuple: tensor[:, symnode]
# Rewrite as torch.select to avoid DDE
index_items = args[1].items
symnode_indices = [
i
for i, item in enumerate(index_items)
if isinstance(item, SymNodeVariable)
]
if len(symnode_indices) == 1:
# Single SymNode in tuple - rewrite as torch.select
symnode_idx = symnode_indices[0]
symnode = index_items[symnode_idx]
# Check that all other indices are slices or ellipsis
non_symnode_indices = [
i for i in range(len(index_items)) if i != symnode_idx
]
if all(
isinstance(
index_items[i],
(variables.SliceVariable, variables.ConstantVariable),
)
for i in non_symnode_indices
):
fn, args = (
torch.select,
[
args[0],
variables.ConstantVariable.create(symnode_idx),
symnode,
],
)
# Interaction between ndarray and tensors:
# We prefer the tensor op whenever there are tensors involved