WIP: Support Python slicing with data-dependent input tensors

ghstack-source-id: 5363cf5565c2c024260b6ae504ba12ffca2f9984
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165074
Laith Sakka
2025-10-19 08:41:38 -07:00
parent e595136187
commit 0df34492ef
6 changed files with 150 additions and 39 deletions

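The user-facing pattern this change is after is slicing with data-dependent (tensor) bounds under torch.compile; a minimal sketch mirroring the new tests below (backend and flags here are illustrative, not prescribed by this PR):

    import torch

    torch._dynamo.config.capture_scalar_outputs = True  # needed: bounds are read via .item()

    def f(x, stop_t):
        return x[:stop_t]  # stop is a 0-d tensor, i.e. data-dependent

    x = torch.randn(20)
    stop_t = torch.tensor(15)

    compiled = torch.compile(f, fullgraph=True)
    assert torch.equal(compiled(x, stop_t), f(x, stop_t))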

@@ -528,30 +528,6 @@ Attempted to call function marked as skipped
             f(x)
         self.assertEqual(len(ws), 2)
 
-    def test_slice_with_tensor(self):
-        def fn(x, y):
-            return x[:y]
-
-        self.assertExpectedInlineMunged(
-            Unsupported,
-            lambda: torch.compile(fn, backend="eager", fullgraph=True)(
-                torch.randn(10),
-                torch.tensor([3]),
-            ),
-            """\
-Dynamic slicing with Tensor arguments
-  Explanation: Creating slices with Tensor arguments is not supported. e.g. `l[:x]`, where `x` is a 1-element tensor.
-  Hint: It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues.
-  Developer debug context: SliceVariable start: ConstantVariable(NoneType: None), stop: LazyVariableTracker(realized: TensorVariable()), step: ConstantVariable(NoneType: None)
- For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0038.html
-
-from user code:
-   File "test_error_messages.py", line N, in fn
-    return x[:y]""",
-        )
-
     def test_observed_exception(self):
         def fn():
             raise RuntimeError("test")


@@ -3799,6 +3799,118 @@ def forward(self, arg0_1: "i64[2][1]cpu", arg1_1: "Sym(u2)", arg2_1: "Sym(u3)",
     def test_unbacked_slice_with_step_cpp_wrapper(self):
         self.test_unbacked_slice_with_step()
 
+    @fresh_cache()
+    @torch._dynamo.config.patch("capture_scalar_outputs", True)
+    def test_slice_with_tensor_indices(self):
+        # Test slicing with tensor start/stop/step on the RHS (reading)
+
+        # Test 1: Basic slice with tensor start and stop
+        def f1(x, start_t, stop_t):
+            return x[start_t:stop_t]
+
+        x = torch.randn(20)
+        start_t = torch.tensor(5)
+        stop_t = torch.tensor(15)
+        fn1 = torch.compile(f1, fullgraph=True, backend="inductor")
+        self.assertTrue(torch.allclose(fn1(x, start_t, stop_t), f1(x, start_t, stop_t)))
+
+        # Test 2: Slice with tensor step
+        def f2(x, start_t, stop_t, step_t):
+            return x[start_t:stop_t:step_t]
+
+        step_t = torch.tensor(2)
+        fn2 = torch.compile(f2, fullgraph=True, backend="inductor")
+        self.assertTrue(
+            torch.allclose(
+                fn2(x, start_t, stop_t, step_t), f2(x, start_t, stop_t, step_t)
+            )
+        )
+
+        # Test 3: Slice with only tensor start
+        def f3(x, start_t):
+            return x[start_t:]
+
+        fn3 = torch.compile(f3, fullgraph=True, backend="inductor")
+        self.assertTrue(torch.allclose(fn3(x, start_t), f3(x, start_t)))
+
+        # Test 4: Slice with only tensor stop
+        def f4(x, stop_t):
+            return x[:stop_t]
+
+        fn4 = torch.compile(f4, fullgraph=True, backend="inductor")
+        self.assertTrue(torch.allclose(fn4(x, stop_t), f4(x, stop_t)))
+
+        # Test 5: Negative indices with tensors
+        def f5(x, start_t):
+            return x[start_t:-1]
+
+        start_t_neg = torch.tensor(-10)
+        fn5 = torch.compile(f5, fullgraph=True, backend="inductor")
+        self.assertTrue(torch.allclose(fn5(x, start_t_neg), f5(x, start_t_neg)))
+
+        # Test 6: Multidimensional slice with tensor indices
+        def f6(x, start_t, stop_t):
+            return x[:, start_t:stop_t]
+
+        x_2d = torch.randn(10, 20)
+        fn6 = torch.compile(f6, fullgraph=True, backend="inductor")
+        self.assertTrue(
+            torch.allclose(fn6(x_2d, start_t, stop_t), f6(x_2d, start_t, stop_t))
+        )
+
+    @fresh_cache()
+    @torch._dynamo.config.patch("capture_scalar_outputs", True)
+    @torch._inductor.config.patch("cpp_wrapper", True)
+    def test_slice_with_tensor_indices_cpp_wrapper(self):
+        self.test_slice_with_tensor_indices()
+    @fresh_cache()
+    @torch._dynamo.config.patch("capture_scalar_outputs", True)
+    def test_select_with_tensor_index(self):
+        # Test direct tensor indexing (select) without calling .item()
+
+        # Test 1: Simple 0-d tensor as index
+        def f1(x, idx_tensor):
+            return x[idx_tensor]
+
+        x = torch.randn(10)
+        idx_tensor = torch.tensor(5)
+        fn1 = torch.compile(f1, fullgraph=True, backend="inductor")
+        self.assertTrue(torch.allclose(fn1(x, idx_tensor), f1(x, idx_tensor)))
+
+        # Test 2: Negative tensor index
+        def f2(x, idx_tensor):
+            return x[idx_tensor]
+
+        idx_tensor_neg = torch.tensor(-2)
+        fn2 = torch.compile(f2, fullgraph=True, backend="inductor")
+        self.assertTrue(torch.allclose(fn2(x, idx_tensor_neg), f2(x, idx_tensor_neg)))
+
+        # TODO: support these less common patterns.
+        # Test 3: Multidimensional select with tensor index
+        # def f3(x, idx_tensor):
+        #     return x[:, idx_tensor]
+        # x_2d = torch.randn(5, 10)
+        # fn3 = torch.compile(f3, fullgraph=True, backend="inductor")
+        # self.assertTrue(torch.allclose(fn3(x_2d, idx_tensor), f3(x_2d, idx_tensor)))
+
+        # Test 4: Multiple tensor indices
+        # def f4(x, idx1, idx2):
+        #     return x[idx1, idx2]
+        # x_2d = torch.randn(8, 12)
+        # idx1 = torch.tensor(3)
+        # idx2 = torch.tensor(7)
+        # fn4 = torch.compile(f4, fullgraph=True, backend="inductor")
+        # self.assertTrue(torch.allclose(fn4(x_2d, idx1, idx2), f4(x_2d, idx1, idx2)))
+
+    @fresh_cache()
+    @torch._dynamo.config.patch("capture_scalar_outputs", True)
+    @torch._inductor.config.patch("cpp_wrapper", True)
+    def test_select_with_tensor_index_cpp_wrapper(self):
+        self.test_select_with_tensor_index()
 
     @fresh_cache()
     @torch._dynamo.config.patch("capture_scalar_outputs", True)
     def test_tensor_split(self):

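The tests above compile each function with Inductor and compare against eager; a condensed sketch of what the *_cpp_wrapper variants additionally exercise (the config knobs mirror the decorators above, everything else is illustrative):

    import torch

    torch._dynamo.config.capture_scalar_outputs = True

    def f(x, start_t, stop_t):
        return x[start_t:stop_t]

    x, start_t, stop_t = torch.randn(20), torch.tensor(5), torch.tensor(15)
    eager = f(x, start_t, stop_t)

    # Re-run the same compile with Inductor's C++ wrapper codegen enabled,
    # mirroring the *_cpp_wrapper test variants.
    with torch._inductor.config.patch(cpp_wrapper=True):
        compiled = torch.compile(f, fullgraph=True, backend="inductor")
        assert torch.allclose(compiled(x, start_t, stop_t), eager)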

@@ -3138,7 +3138,7 @@ class InstructionTranslatorBase(
     def BUILD_SLICE(self, inst: Instruction) -> None:
         items = self.popn(inst.argval)
-        self.push(SliceVariable(items))
+        self.push(SliceVariable(items, tx=self))
 
     def BUILD_LIST(self, inst: Instruction) -> None:
         items = self.popn(inst.argval)

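For context, BUILD_SLICE is the opcode CPython emits when an explicit slice object has to be built (e.g. for a three-element slice); the standard-library dis module shows it, though the exact bytecode varies across Python versions:

    import dis

    def g(x, a, b, c):
        return x[a:b:c]

    # On most CPython versions this prints BUILD_SLICE 3 followed by a subscript
    # instruction; the handler above now threads the tracer (tx) into SliceVariable
    # so that tensor-valued bounds can be desugared via .item().
    dis.dis(g)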

@@ -1795,7 +1795,7 @@ class VariableBuilder:
             ]
             self.install_guards(GuardBuilder.TYPE_MATCH)
             if isinstance(value, slice):
-                return SliceVariable(items, source=self.source)
+                return SliceVariable(items, self.tx, source=self.source)
             else:
                 return RangeVariable(items, source=self.source)


@@ -1235,6 +1235,20 @@ class BuiltinVariable(VariableTracker):
                 # scenario is to de-sugar eagerly.
                 fn, args = IN_PLACE_DESUGARING_MAP[fn], [args[0], args[1]]
 
+        # Convert size-1 TensorVariable indices to SymIntVariable by calling .item()
+        # This decomposes tensor[t] to u=t.item(); tensor[u] at the dynamo level
+        if (
+            fn is operator.getitem
+            and len(args) == 2
+            and isinstance(args[1], variables.TensorVariable)
+        ):
+            tensor_idx = args[1]
+            # Only convert if we know it's size-1 (not for advanced indexing)
+            if tensor_idx.size is not None and all(s == 1 for s in tensor_idx.size):
+                args = list(args)
+                args[1] = tensor_idx.call_method(tx, "item", [], {})
+                args = tuple(args)
+
         if fn is operator.getitem and isinstance(args[1], SymNodeVariable):
             # Standard indexing will force specialization due to
             # __index__. Rewrite as a regular torch op which will
@@ -1745,7 +1759,7 @@ class BuiltinVariable(VariableTracker):
         )
 
     def call_slice(self, tx: "InstructionTranslator", *args):
-        return variables.SliceVariable(args)
+        return variables.SliceVariable(args, tx)
 
     def _dyn_proxy(self, tx: "InstructionTranslator", *args, **kwargs):
         from .builder import wrap_fx_proxy

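The getitem branch added above rewrites a size-1 tensor index into its scalar via .item() before indexing; in plain eager PyTorch the two spellings already agree, which is the equivalence the desugaring relies on (a minimal check, index value illustrative):

    import torch

    x = torch.randn(10)
    idx = torch.tensor(-2)

    # tensor[t] and tensor[t.item()] select the same element for a 0-d index,
    # so rewriting the former into the latter preserves semantics.
    assert torch.equal(x[idx], x[idx.item()])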

@@ -1386,7 +1386,7 @@ class NamedTupleVariable(TupleVariable):
 
 class SliceVariable(VariableTracker):
-    def __init__(self, items, **kwargs) -> None:
+    def __init__(self, items, tx=None, **kwargs) -> None:
         items_to_map = items
         start, stop, step = [variables.ConstantVariable.create(None)] * 3
 
@@ -1399,18 +1399,27 @@ class SliceVariable(VariableTracker):
         else:
             raise AssertionError
 
-        if isinstance(start, variables.TensorVariable) or isinstance(
-            stop, variables.TensorVariable
-        ):
-            unimplemented_v2(
-                gb_type="Dynamic slicing with Tensor arguments",
-                context=f"SliceVariable start: {start}, stop: {stop}, step: {step}",
-                explanation="Creating slices with Tensor arguments is not supported. "
-                "e.g. `l[:x]`, where `x` is a 1-element tensor.",
-                hints=[
-                    *graph_break_hints.SUPPORTABLE,
-                ],
-            )
+        # Convert TensorVariable to SymIntVariable by calling .item()
+        # This decomposes a[:t] to u=t.item(); a[:u] at the dynamo level
+        if isinstance(start, variables.TensorVariable):
+            assert tx is not None, (
+                "tx is required when slice indices are TensorVariables"
+            )
+            assert start.size is None or all(s == 1 for s in start.size)
+            start = start.call_method(tx, "item", [], {})
+        if isinstance(stop, variables.TensorVariable):
+            assert tx is not None, (
+                "tx is required when slice indices are TensorVariables"
+            )
+            assert stop.size is None or all(s == 1 for s in stop.size)
+            stop = stop.call_method(tx, "item", [], {})
+        if isinstance(step, variables.TensorVariable):
+            assert tx is not None, (
+                "tx is required when slice indices are TensorVariables"
+            )
+            assert step.size is None or all(s == 1 for s in step.size)
+            step = step.call_method(tx, "item", [], {})
 
         self.items = (start, stop, step)
         super().__init__(**kwargs)
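At the eager level, the decomposition described in the comments above amounts to reading each tensor bound with .item() before building the slice; a small sketch of that equivalence (values illustrative):

    import torch

    x = torch.randn(20)
    start_t, stop_t, step_t = torch.tensor(5), torch.tensor(15), torch.tensor(2)

    # What SliceVariable now does symbolically, written out eagerly:
    #   x[start_t:stop_t:step_t]  ->  u0 = start_t.item(); ...; x[u0:u1:u2]
    u0, u1, u2 = start_t.item(), stop_t.item(), step_t.item()
    assert torch.equal(x[start_t:stop_t:step_t], x[u0:u1:u2])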