From a2c2c3295ca8ef762e420634e4c325cb2ddd7749 Mon Sep 17 00:00:00 2001
From: Laith Sakka
Date: Sun, 19 Oct 2025 17:27:06 -0700
Subject: [PATCH] WIP Support Python slicing with data-dependent input tensors

Instead of graph breaking on slices and indices built from 0-d or size-1
tensors, decompose them into `.item()` calls at trace time, e.g.
`x[start_t:stop_t]` becomes `x[start_t.item():stop_t.item()]`.

ghstack-source-id: 4abcd9a5a4de65fa0d205c0b101998f48f6d9655
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165074
---
 test/dynamo/test_error_messages.py |  24 ------
 test/test_dynamic_shapes.py        | 114 ++++++++++++++++++++++++++
 torch/_dynamo/symbolic_convert.py  |   2 +-
 torch/_dynamo/variables/builder.py |   2 +-
 torch/_dynamo/variables/builtin.py | 126 ++++++++++++++++++++++++++++-
 torch/_dynamo/variables/lists.py   |  33 +++++---
 6 files changed, 262 insertions(+), 39 deletions(-)

diff --git a/test/dynamo/test_error_messages.py b/test/dynamo/test_error_messages.py
index 2a22098a54d7..cf598791a634 100644
--- a/test/dynamo/test_error_messages.py
+++ b/test/dynamo/test_error_messages.py
@@ -528,30 +528,6 @@ Attempted to call function marked as skipped
             f(x)
         self.assertEqual(len(ws), 2)
 
-    def test_slice_with_tensor(self):
-        def fn(x, y):
-            return x[:y]
-
-        self.assertExpectedInlineMunged(
-            Unsupported,
-            lambda: torch.compile(fn, backend="eager", fullgraph=True)(
-                torch.randn(10),
-                torch.tensor([3]),
-            ),
-            """\
-Dynamic slicing with Tensor arguments
-  Explanation: Creating slices with Tensor arguments is not supported. e.g. `l[:x]`, where `x` is a 1-element tensor.
-  Hint: It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues.
-
-  Developer debug context: SliceVariable start: ConstantVariable(NoneType: None), stop: LazyVariableTracker(realized: TensorVariable()), step: ConstantVariable(NoneType: None)
-
- For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0038.html
-
-from user code:
-   File "test_error_messages.py", line N, in fn
-    return x[:y]""",
-        )
-
     def test_observed_exception(self):
         def fn():
             raise RuntimeError("test")
diff --git a/test/test_dynamic_shapes.py b/test/test_dynamic_shapes.py
index fcc45521fbb1..ab1a24ed1bc4 100644
--- a/test/test_dynamic_shapes.py
+++ b/test/test_dynamic_shapes.py
@@ -3799,6 +3799,120 @@ def forward(self, arg0_1: "i64[2][1]cpu", arg1_1: "Sym(u2)", arg2_1: "Sym(u3)",
     def test_unbacked_slice_with_step_cpp_wrapper(self):
         self.test_unbacked_slice_with_step()
 
+    @fresh_cache()
+    @torch._dynamo.config.patch("capture_scalar_outputs", True)
+    def test_slice_with_tensor_indices(self):
+        # Test slicing with tensor start/stop/step on RHS (reading)
+
+        # Test 1: Basic slice with tensor start and stop
+        def f1(x, start_t, stop_t):
+            return x[start_t:stop_t]
+
+        x = torch.randn(20)
+        start_t = torch.tensor(5)
+        stop_t = torch.tensor(15)
+        fn1 = torch.compile(f1, fullgraph=True, backend="inductor")
+        self.assertTrue(torch.allclose(fn1(x, start_t, stop_t), f1(x, start_t, stop_t)))
+
+        # Test 2: Slice with tensor step
+        def f2(x, start_t, stop_t, step_t):
+            return x[start_t:stop_t:step_t]
+
+        step_t = torch.tensor(2)
+        fn2 = torch.compile(f2, fullgraph=True, backend="inductor")
+        self.assertTrue(
+            torch.allclose(
+                fn2(x, start_t, stop_t, step_t), f2(x, start_t, stop_t, step_t)
+            )
+        )
+
+        # Test 3: Slice with only tensor start
+        def f3(x, start_t):
+            return x[start_t:]
+
+        fn3 = torch.compile(f3, fullgraph=True, backend="inductor")
+        self.assertTrue(torch.allclose(fn3(x, start_t), f3(x, start_t)))
+
+        # Test 4: Slice with only tensor stop
+        def f4(x, stop_t):
+            return x[:stop_t]
+
+        fn4 = torch.compile(f4, fullgraph=True, backend="inductor")
+        self.assertTrue(torch.allclose(fn4(x, stop_t), f4(x, stop_t)))
+
+        # Test 5: Negative indices with tensors
+        def f5(x, start_t):
+            return x[start_t:-1]
+
+        start_t_neg = torch.tensor(-10)
+        fn5 = torch.compile(f5, fullgraph=True, backend="inductor")
+        self.assertTrue(torch.allclose(fn5(x, start_t_neg), f5(x, start_t_neg)))
+
+        # Test 6: Multidimensional slice with tensor indices
+        def f6(x, start_t, stop_t):
+            return x[:, start_t:stop_t]
+
+        x_2d = torch.randn(10, 20)
+        fn6 = torch.compile(f6, fullgraph=True, backend="inductor")
+        self.assertTrue(
+            torch.allclose(fn6(x_2d, start_t, stop_t), f6(x_2d, start_t, stop_t))
+        )
+
+    @fresh_cache()
+    @torch._dynamo.config.patch("capture_scalar_outputs", True)
+    @torch._inductor.config.patch("cpp_wrapper", True)
+    def test_slice_with_tensor_indices_cpp_wrapper(self):
+        self.test_slice_with_tensor_indices()
+
+    @fresh_cache()
+    @torch._dynamo.config.patch("capture_scalar_outputs", True)
+    def test_select_with_tensor_index(self):
+        # Test direct tensor indexing (select) without calling .item()
+
+        # # Test 1: Simple 0-d tensor as index
+        # def f1(x, idx_tensor):
+        #     return x[idx_tensor]
+
+        # x = torch.randn(10)
+        idx_tensor = torch.tensor(5)
+        # fn1 = torch.compile(f1, fullgraph=True, backend="inductor")
+        # self.assertTrue(torch.allclose(fn1(x, idx_tensor), f1(x, idx_tensor)))
+
+        # # Test 2: Negative tensor index
+        # def f2(x, idx_tensor):
+        #     return x[idx_tensor]
+
+        # idx_tensor_neg = torch.tensor(-2)
+        # fn2 = torch.compile(f2, fullgraph=True, backend="inductor")
+        # self.assertTrue(torch.allclose(fn2(x, idx_tensor_neg), f2(x, idx_tensor_neg)))
+
+        # Test 3: Multidimensional select with tensor index
+        def f3(x, idx_tensor):
+            return x[:, idx_tensor]
+
+        x_2d = torch.randn(5, 10)
+        fn3 = torch.compile(f3, fullgraph=True, backend="inductor")
+        self.assertTrue(torch.allclose(fn3(x_2d, idx_tensor), f3(x_2d, idx_tensor)))
+
+        # # Hit inductor error
+        # u7 = u6 + 1 * (u5 + 12 if u5 < 0 else u5)
+        # NameError: name 'u6' is not defined
+        # # Test 4: Multiple tensor indices
+        # def f4(x, idx1, idx2):
+        #     return x[idx1, idx2]
+
+        # x_2d = torch.randn(8, 12)
+        # idx1 = torch.tensor(3)
+        # idx2 = torch.tensor(7)
+        # fn4 = torch.compile(f4, fullgraph=True, backend="inductor")
+        # self.assertTrue(torch.allclose(fn4(x_2d, idx1, idx2), f4(x_2d, idx1, idx2)))
+
+    @fresh_cache()
+    @torch._dynamo.config.patch("capture_scalar_outputs", True)
+    @torch._inductor.config.patch("cpp_wrapper", True)
+    def test_select_with_tensor_index_cpp_wrapper(self):
+        self.test_select_with_tensor_index()
+
     @fresh_cache()
     @torch._dynamo.config.patch("capture_scalar_outputs", True)
     def test_tensor_split(self):
diff --git a/torch/_dynamo/symbolic_convert.py b/torch/_dynamo/symbolic_convert.py
index 5815473d41f9..c3169369c6d1 100644
--- a/torch/_dynamo/symbolic_convert.py
+++ b/torch/_dynamo/symbolic_convert.py
@@ -3138,7 +3138,7 @@ class InstructionTranslatorBase(
 
     def BUILD_SLICE(self, inst: Instruction) -> None:
         items = self.popn(inst.argval)
-        self.push(SliceVariable(items))
+        self.push(SliceVariable(items, tx=self))
 
     def BUILD_LIST(self, inst: Instruction) -> None:
         items = self.popn(inst.argval)
diff --git a/torch/_dynamo/variables/builder.py b/torch/_dynamo/variables/builder.py
index 5fab51234d74..48e7164e567f 100644
--- a/torch/_dynamo/variables/builder.py
+++ b/torch/_dynamo/variables/builder.py
@@ -1795,7 +1795,7 @@ class VariableBuilder:
         ]
         self.install_guards(GuardBuilder.TYPE_MATCH)
         if isinstance(value, slice):
-            return SliceVariable(items, source=self.source)
+            return SliceVariable(items, self.tx, source=self.source)
         else:
             return RangeVariable(items, source=self.source)
 
diff --git a/torch/_dynamo/variables/builtin.py b/torch/_dynamo/variables/builtin.py
index 09bdb81150e6..f0d172f068d0 100644
--- a/torch/_dynamo/variables/builtin.py
+++ b/torch/_dynamo/variables/builtin.py
@@ -1235,6 +1235,53 @@ class BuiltinVariable(VariableTracker):
             # scenario is to de-sugar eagerly.
             fn, args = IN_PLACE_DESUGARING_MAP[fn], [args[0], args[1]]
 
+        # Convert size-1 TensorVariable indices to SymIntVariable by calling .item()
+        # This decomposes tensor[t] to u=t.item(); tensor[u] at the dynamo level
+        # Only convert if the tensor doesn't contain unbacked symints (data-dependent values)
+        if fn is operator.getitem and len(args) == 2:
+            from torch.fx.experimental.symbolic_shapes import has_free_symbols
+
+            def should_convert_to_item(tensor_idx):
+                """Check if we should convert size-1 tensor to scalar."""
+                if not isinstance(tensor_idx, variables.TensorVariable):
+                    return False
+                # Only convert if size-1 or 0-d
+                if tensor_idx._size is None or not all(
+                    s == 1 for s in tensor_idx._size
+                ):
+                    return False
+                # Don't convert if it has unbacked symints (data-dependent)
+                example_value = tensor_idx.proxy.node.meta.get("example_value")
+                return example_value is None or not has_free_symbols(example_value)
+
+            index_arg = args[1]
+            if isinstance(
+                index_arg, variables.TensorVariable
+            ) and should_convert_to_item(index_arg):
+                args = list(args)
+                args[1] = index_arg.call_method(tx, "item", [], {})
+                args = tuple(args)
+            elif isinstance(
+                index_arg, (variables.TupleVariable, variables.ListVariable)
+            ):
+                # Multi-dimensional indexing: tensor[:, idx] or tensor[idx1, idx2]
+                new_items = []
+                changed = False
+                for item in index_arg.items:
+                    if should_convert_to_item(item):
+                        new_items.append(item.call_method(tx, "item", [], {}))
+                        changed = True
+                    else:
+                        new_items.append(item)
+                if changed:
+                    args = list(args)
+                    args[1] = (
+                        variables.TupleVariable(new_items)
+                        if isinstance(index_arg, variables.TupleVariable)
+                        else variables.ListVariable(new_items)
+                    )
+                    args = tuple(args)
+
         if fn is operator.getitem and isinstance(args[1], SymNodeVariable):
             # Standard indexing will force specialization due to
             # __index__. Rewrite as a regular torch op which will
@@ -1247,6 +1294,83 @@
                     args[1],
                 ],
             )
+        elif fn is operator.getitem and isinstance(
+            args[1], variables.TupleVariable
+        ):
+            # Handle tuple with SymNodeVariables: x[symint1, symint2] or x[:, symint]
+            # Decompose into sequential operations, tracking dimension changes
+            index_items = args[1].items
+            if any(isinstance(item, SymNodeVariable) for item in index_items):
+                result = args[0]
+                dims_removed = 0  # Track how many dimensions have been removed
+
+                for original_dim, item in enumerate(index_items):
+                    # Current dimension in the result tensor
+                    current_dim = original_dim - dims_removed
+
+                    if isinstance(item, SymNodeVariable):
+                        # Apply torch.select at current_dim (removes this dimension)
+                        result = variables.TorchInGraphFunctionVariable(
+                            torch.select
+                        ).call_function(
+                            tx,
+                            [
+                                result,
+                                variables.ConstantVariable.create(current_dim),
+                                item,
+                            ],
+                            {},
+                        )
+                        dims_removed += 1
+                    elif isinstance(item, variables.SliceVariable):
+                        # Slicing keeps the dimension
+                        result = variables.BuiltinVariable(
+                            operator.getitem
+                        ).call_function(tx, [result, item], {})
+                    else:
+                        # Regular scalar index (also removes dimension)
+                        result = variables.BuiltinVariable(
+                            operator.getitem
+                        ).call_function(tx, [result, item], {})
+                        dims_removed += 1
+                return result
+
+        elif fn is operator.getitem and isinstance(
+            args[1], (variables.TupleVariable, variables.ListVariable)
+        ):
+            # Check if we have SymNodeVariable inside tuple: tensor[:, symnode]
+            # Rewrite as torch.select to avoid DDE
+            index_items = args[1].items
+            symnode_indices = [
+                i
+                for i, item in enumerate(index_items)
+                if isinstance(item, SymNodeVariable)
+            ]
+
+            if len(symnode_indices) == 1:
+                # Single SymNode in tuple - rewrite as torch.select
+                symnode_idx = symnode_indices[0]
+                symnode = index_items[symnode_idx]
+
+                # Check that all other indices are slices or ellipsis
+                non_symnode_indices = [
+                    i for i in range(len(index_items)) if i != symnode_idx
+                ]
+                if all(
+                    isinstance(
+                        index_items[i],
+                        (variables.SliceVariable, variables.ConstantVariable),
+                    )
+                    for i in non_symnode_indices
+                ):
+                    fn, args = (
+                        torch.select,
+                        [
+                            args[0],
+                            variables.ConstantVariable.create(symnode_idx),
+                            symnode,
+                        ],
+                    )
 
         # Interaction between ndarray and tensors:
         # We prefer the tensor op whenever there are tensors involved
@@ -1745,7 +1869,7 @@ class BuiltinVariable(VariableTracker):
         )
 
     def call_slice(self, tx: "InstructionTranslator", *args):
-        return variables.SliceVariable(args)
+        return variables.SliceVariable(args, tx)
 
     def _dyn_proxy(self, tx: "InstructionTranslator", *args, **kwargs):
         from .builder import wrap_fx_proxy
diff --git a/torch/_dynamo/variables/lists.py b/torch/_dynamo/variables/lists.py
index c6b9434b6f05..1498aeb9c564 100644
--- a/torch/_dynamo/variables/lists.py
+++ b/torch/_dynamo/variables/lists.py
@@ -1386,7 +1386,7 @@ class NamedTupleVariable(TupleVariable):
 
 
 class SliceVariable(VariableTracker):
-    def __init__(self, items, **kwargs) -> None:
+    def __init__(self, items, tx=None, **kwargs) -> None:
         items_to_map = items
         start, stop, step = [variables.ConstantVariable.create(None)] * 3
 
@@ -1399,18 +1399,27 @@
         else:
             raise AssertionError
 
-        if isinstance(start, variables.TensorVariable) or isinstance(
-            stop, variables.TensorVariable
-        ):
-            unimplemented_v2(
-                gb_type="Dynamic slicing with Tensor arguments",
-                context=f"SliceVariable start: {start}, stop: {stop}, step: {step}",
-                explanation="Creating slices with Tensor arguments is not supported. "
-                "e.g. `l[:x]`, where `x` is a 1-element tensor.",
-                hints=[
-                    *graph_break_hints.SUPPORTABLE,
-                ],
+        # Convert TensorVariable to SymIntVariable by calling .item()
+        # This decomposes a[:t] to u=t.item(); a[:u] at the dynamo level
+        if isinstance(start, variables.TensorVariable):
+            assert tx is not None, (
+                "tx is required when slice indices are TensorVariables"
             )
+            assert start.size is None or all(s == 1 for s in start.size)
+            start = start.call_method(tx, "item", [], {})
+        if isinstance(stop, variables.TensorVariable):
+            assert tx is not None, (
+                "tx is required when slice indices are TensorVariables"
+            )
+            assert stop.size is None or all(s == 1 for s in stop.size)
+            stop = stop.call_method(tx, "item", [], {})
+        if isinstance(step, variables.TensorVariable):
+            assert tx is not None, (
+                "tx is required when slice indices are TensorVariables"
+            )
+            assert step.size is None or all(s == 1 for s in step.size)
+            step = step.call_method(tx, "item", [], {})
+
+        self.items = (start, stop, step)
         super().__init__(**kwargs)
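-- 
A minimal usage sketch of what the patch enables, condensed from Test 1 of the new test_slice_with_tensor_indices above; it assumes the patch is applied and, as in those tests, that capture_scalar_outputs is enabled so the .item() decomposition produces unbacked symints:

    import torch

    torch._dynamo.config.capture_scalar_outputs = True  # as patched in the tests

    def fn(x, start_t, stop_t):
        # start_t/stop_t are 0-d integer tensors; previously this hit the
        # "Dynamic slicing with Tensor arguments" graph break under fullgraph=True
        return x[start_t:stop_t]

    compiled = torch.compile(fn, backend="inductor", fullgraph=True)
    x = torch.randn(20)
    out = compiled(x, torch.tensor(5), torch.tensor(15))
    assert torch.allclose(out, x[5:15])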