mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Update base for Update on "[WIP] Support python slicing with tensor inputs."
allow things like
```
#!/usr/bin/env python
import torch
print("="*60)
print("Testing tensor slicing with torch.compile")
print("="*60)
# Test 1: Simple eager mode
print("\n1. Eager mode test:")
x = torch.randn(10)
idx = torch.tensor(4)
result = x[:idx]
print(f" x[:idx] where idx=4: result.shape = {result.shape}")
assert result.shape[0] == 4
print(" ✓ Eager mode works!")
# Test 2: With torch.compile
print("\n2. Compiled mode test:")
def slice_fn(x, idx):
return x[:idx]
try:
compiled_fn = torch.compile(slice_fn)
x = torch.randn(10)
idx = torch.tensor(4)
result = compiled_fn(x, idx)
print(f" Compiled x[:idx] where idx=4: result.shape = {result.shape}")
assert result.shape[0] == 4
print(" ✓ Compiled mode works!")
except Exception as e:
print(f" ✗ Compiled mode failed: {e}")
import traceback
traceback.print_exc()
# Test 3: With dynamic slicing from sum
print("\n3. Dynamic slicing with sum:")
def dynamic_slice_fn(x, lengths):
idx = lengths.sum()
return x[:idx]
try:
compiled_fn = torch.compile(dynamic_slice_fn)
x = torch.randn(10)
lengths = torch.tensor([1, 1, 1, 1])
result = compiled_fn(x, lengths)
print(f" Compiled x[:lengths.sum()] where sum=4: result.shape = {result.shape}")
assert result.shape[0] == 4
print(" ✓ Dynamic slicing works!")
except Exception as e:
print(f" ✗ Dynamic slicing failed: {e}")
import traceback
traceback.print_exc()
print("\n" + "="*60)
print("SUMMARY: Check results above")
print("="*60)
```
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames Lucaskabela
[ghstack-poisoned]
This commit is contained in: