Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 12:54:11 +08:00)
Update on "[WIP] Support python slicing with tensor inputs."
allow things like:

```
#!/usr/bin/env python
import torch

print("="*60)
print("Testing tensor slicing with torch.compile")
print("="*60)

# Test 1: Simple eager mode
print("\n1. Eager mode test:")
x = torch.randn(10)
idx = torch.tensor(4)
result = x[:idx]
print(f"   x[:idx] where idx=4: result.shape = {result.shape}")
assert result.shape[0] == 4
print("   ✓ Eager mode works!")

# Test 2: With torch.compile
print("\n2. Compiled mode test:")

def slice_fn(x, idx):
    return x[:idx]

try:
    compiled_fn = torch.compile(slice_fn)
    x = torch.randn(10)
    idx = torch.tensor(4)
    result = compiled_fn(x, idx)
    print(f"   Compiled x[:idx] where idx=4: result.shape = {result.shape}")
    assert result.shape[0] == 4
    print("   ✓ Compiled mode works!")
except Exception as e:
    print(f"   ✗ Compiled mode failed: {e}")
    import traceback
    traceback.print_exc()

# Test 3: With dynamic slicing from sum
print("\n3. Dynamic slicing with sum:")

def dynamic_slice_fn(x, lengths):
    idx = lengths.sum()
    return x[:idx]

try:
    compiled_fn = torch.compile(dynamic_slice_fn)
    x = torch.randn(10)
    lengths = torch.tensor([1, 1, 1, 1])
    result = compiled_fn(x, lengths)
    print(f"   Compiled x[:lengths.sum()] where sum=4: result.shape = {result.shape}")
    assert result.shape[0] == 4
    print("   ✓ Dynamic slicing works!")
except Exception as e:
    print(f"   ✗ Dynamic slicing failed: {e}")
    import traceback
    traceback.print_exc()

print("\n" + "="*60)
print("SUMMARY: Check results above")
print("="*60)
```

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames Lucaskabela

[ghstack-poisoned]
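For context, the Dynamo-level rewrite implemented in the hunks below is equivalent to pulling the scalar out of the tensor by hand before slicing. A minimal sketch of that manual decomposition (illustrative names, not code from this PR; assumes `capture_scalar_outputs` so that `.item()` yields an unbacked SymInt instead of a graph break):

```python
import torch

# Manual form of the decomposition the change below performs automatically:
# turn the tensor bound into a SymInt with .item(), then slice with it.
torch._dynamo.config.capture_scalar_outputs = True

def manual_dynamic_slice(x, lengths):
    u = lengths.sum().item()  # traced as an unbacked SymInt under torch.compile
    return x[:u]              # same pattern as the new x[:lengths.sum()] support

compiled = torch.compile(manual_dynamic_slice, fullgraph=True)
out = compiled(torch.randn(10), torch.tensor([1, 1, 1, 1]))
print(out.shape)  # expected: torch.Size([4])
```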
@@ -528,30 +528,6 @@ Attempted to call function marked as skipped
             f(x)
         self.assertEqual(len(ws), 2)

-    def test_slice_with_tensor(self):
-        def fn(x, y):
-            return x[:y]
-
-        self.assertExpectedInlineMunged(
-            Unsupported,
-            lambda: torch.compile(fn, backend="eager", fullgraph=True)(
-                torch.randn(10),
-                torch.tensor([3]),
-            ),
-            """\
-Dynamic slicing with Tensor arguments
-Explanation: Creating slices with Tensor arguments is not supported. e.g. `l[:x]`, where `x` is a 1-element tensor.
-Hint: It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues.
-
-Developer debug context: SliceVariable start: ConstantVariable(NoneType: None), stop: LazyVariableTracker(realized: TensorVariable()), step: ConstantVariable(NoneType: None)
-
-For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0038.html
-
-from user code:
-  File "test_error_messages.py", line N, in fn
-    return x[:y]""",
-        )
-
     def test_observed_exception(self):
         def fn():
             raise RuntimeError("test")
@@ -3886,23 +3886,24 @@ def forward(self, arg0_1: "i64[2][1]cpu", arg1_1: "Sym(u2)", arg2_1: "Sym(u3)",
         fn2 = torch.compile(f2, fullgraph=True, backend="inductor")
         self.assertTrue(torch.allclose(fn2(x, idx_tensor_neg), f2(x, idx_tensor_neg)))

-        # Test 3: Multidimensional select with tensor index
-        def f3(x, idx_tensor):
-            return x[:, idx_tensor]
+        # TODO support those less common patterns
+        # # Test 3: Multidimensional select with tensor index
+        # def f3(x, idx_tensor):
+        #     return x[:, idx_tensor]

-        x_2d = torch.randn(5, 10)
-        fn3 = torch.compile(f3, fullgraph=True, backend="inductor")
-        self.assertTrue(torch.allclose(fn3(x_2d, idx_tensor), f3(x_2d, idx_tensor)))
+        # x_2d = torch.randn(5, 10)
+        # fn3 = torch.compile(f3, fullgraph=True, backend="inductor")
+        # self.assertTrue(torch.allclose(fn3(x_2d, idx_tensor), f3(x_2d, idx_tensor)))

         # Test 4: Multiple tensor indices
-        def f4(x, idx1, idx2):
-            return x[idx1, idx2]
+        # def f4(x, idx1, idx2):
+        #     return x[idx1, idx2]

-        x_2d = torch.randn(8, 12)
-        idx1 = torch.tensor(3)
-        idx2 = torch.tensor(7)
-        fn4 = torch.compile(f4, fullgraph=True, backend="inductor")
-        self.assertTrue(torch.allclose(fn4(x_2d, idx1, idx2), f4(x_2d, idx1, idx2)))
+        # x_2d = torch.randn(8, 12)
+        # idx1 = torch.tensor(3)
+        # idx2 = torch.tensor(7)
+        # fn4 = torch.compile(f4, fullgraph=True, backend="inductor")
+        # self.assertTrue(torch.allclose(fn4(x_2d, idx1, idx2), f4(x_2d, idx1, idx2)))

     @fresh_cache()
     @torch._dynamo.config.patch("capture_scalar_outputs", True)
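The TODO in the hunk above defers two indexing forms; roughly, the distinction is between a single 0-d / size-1 tensor used as a plain index or slice bound (handled by this change) and tensor indices in advanced-indexing positions (left to follow-up). An illustrative sketch of the difference, not code from the PR:

```python
import torch

x = torch.randn(5, 10)
i = torch.tensor(3)
j = torch.tensor(7)

# Handled by this change: a single 0-d / size-1 tensor index, rewritten to
# x[i.item()] while tracing.
def covered(x, i):
    return x[i]

# Deferred per the TODO above: a tensor index on a trailing dimension and
# multiple tensor indices, which fall into advanced-indexing territory.
def deferred_select(x, i):
    return x[:, i]

def deferred_multi(x, i, j):
    return x[i, j]
```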
@@ -1795,7 +1795,7 @@ class VariableBuilder:
            ]
            self.install_guards(GuardBuilder.TYPE_MATCH)
            if isinstance(value, slice):
-               return SliceVariable(items, source=self.source)
+               return SliceVariable(items, self.tx, source=self.source)
            else:
                return RangeVariable(items, source=self.source)

@@ -1235,6 +1235,20 @@ class BuiltinVariable(VariableTracker):
                # scenario is to de-sugar eagerly.
                fn, args = IN_PLACE_DESUGARING_MAP[fn], [args[0], args[1]]

+            # Convert size-1 TensorVariable indices to SymIntVariable by calling .item()
+            # This decomposes tensor[t] to u=t.item(); tensor[u] at the dynamo level
+            if (
+                fn is operator.getitem
+                and len(args) == 2
+                and isinstance(args[1], variables.TensorVariable)
+            ):
+                tensor_idx = args[1]
+                # Only convert if we know it's size-1 (not for advanced indexing)
+                if tensor_idx.size is not None and all(s == 1 for s in tensor_idx.size):
+                    args = list(args)
+                    args[1] = tensor_idx.call_method(tx, "item", [], {})
+                    args = tuple(args)
+
            if fn is operator.getitem and isinstance(args[1], SymNodeVariable):
                # Standard indexing will force specialization due to
                # __index__. Rewrite as a regular torch op which will
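At the user level, the branch added above means a size-1 tensor index is traced as if `.item()` had been called on it first, so the index becomes a SymInt rather than triggering a graph break or advanced indexing. A small hedged sketch of that behaviour (illustrative names; `backend="eager"` is used only to exercise tracing):

```python
import torch

def select_with_tensor(x, t):
    # With the change above, this is traced roughly as: u = t.item(); x[u]
    return x[t]

compiled = torch.compile(select_with_tensor, fullgraph=True, backend="eager")
x = torch.randn(10)
with torch._dynamo.config.patch("capture_scalar_outputs", True):
    out = compiled(x, torch.tensor(4))
print(torch.equal(out, x[4]))  # expected: True
```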
@@ -1386,7 +1386,7 @@ class NamedTupleVariable(TupleVariable):


 class SliceVariable(VariableTracker):
-    def __init__(self, items, tx, **kwargs) -> None:
+    def __init__(self, items, tx=None, **kwargs) -> None:
        items_to_map = items
        start, stop, step = [variables.ConstantVariable.create(None)] * 3

@@ -1402,10 +1402,22 @@ class SliceVariable(VariableTracker):
        # Convert TensorVariable to SymIntVariable by calling .item()
        # This decomposes a[:t] to u=t.item(); a[:u] at the dynamo level
        if isinstance(start, variables.TensorVariable):
+            assert tx is not None, (
+                "tx is required when slice indices are TensorVariables"
+            )
+            assert start.size is None or all(s == 1 for s in start.size)
            start = start.call_method(tx, "item", [], {})
        if isinstance(stop, variables.TensorVariable):
+            assert tx is not None, (
+                "tx is required when slice indices are TensorVariables"
+            )
+            assert stop.size is None or all(s == 1 for s in stop.size)
            stop = stop.call_method(tx, "item", [], {})
        if isinstance(step, variables.TensorVariable):
+            assert tx is not None, (
+                "tx is required when slice indices are TensorVariables"
+            )
+            assert step.size is None or all(s == 1 for s in step.size)
            step = step.call_method(tx, "item", [], {})

        self.items = (start, stop, step)
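The same `.item()` rewrite is applied independently to the start, stop, and step of a slice, so tensor bounds become SymInts during tracing. A minimal usage sketch under that assumption (illustrative names, not code from the PR):

```python
import torch

def window(x, lo, hi):
    # Per the change above, each tensor bound is unpacked while tracing,
    # roughly: u0 = lo.item(); u1 = hi.item(); return x[u0:u1]
    return x[lo:hi]

compiled_window = torch.compile(window, fullgraph=True, backend="eager")
x = torch.arange(10)
with torch._dynamo.config.patch("capture_scalar_outputs", True):
    out = compiled_window(x, torch.tensor(2), torch.tensor(6))
print(out)  # expected: tensor([2, 3, 4, 5]), matching eager x[2:6]
```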