diff --git a/test/dynamo/test_error_messages.py b/test/dynamo/test_error_messages.py
index 2a22098a54d7..cf598791a634 100644
--- a/test/dynamo/test_error_messages.py
+++ b/test/dynamo/test_error_messages.py
@@ -528,30 +528,6 @@ Attempted to call function marked as skipped
         f(x)
         self.assertEqual(len(ws), 2)
 
-    def test_slice_with_tensor(self):
-        def fn(x, y):
-            return x[:y]
-
-        self.assertExpectedInlineMunged(
-            Unsupported,
-            lambda: torch.compile(fn, backend="eager", fullgraph=True)(
-                torch.randn(10),
-                torch.tensor([3]),
-            ),
-            """\
-Dynamic slicing with Tensor arguments
-  Explanation: Creating slices with Tensor arguments is not supported. e.g. `l[:x]`, where `x` is a 1-element tensor.
-  Hint: It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues.
-
-  Developer debug context: SliceVariable start: ConstantVariable(NoneType: None), stop: LazyVariableTracker(realized: TensorVariable()), step: ConstantVariable(NoneType: None)
-
- For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0038.html
-
-from user code:
-   File "test_error_messages.py", line N, in fn
-    return x[:y]""",
-        )
-
     def test_observed_exception(self):
         def fn():
             raise RuntimeError("test")
diff --git a/test/test_dynamic_shapes.py b/test/test_dynamic_shapes.py
index fcc45521fbb1..27b375de851f 100644
--- a/test/test_dynamic_shapes.py
+++ b/test/test_dynamic_shapes.py
@@ -3799,6 +3799,118 @@ def forward(self, arg0_1: "i64[2][1]cpu", arg1_1: "Sym(u2)", arg2_1: "Sym(u3)",
     def test_unbacked_slice_with_step_cpp_wrapper(self):
         self.test_unbacked_slice_with_step()
 
+    @fresh_cache()
+    @torch._dynamo.config.patch("capture_scalar_outputs", True)
+    def test_slice_with_tensor_indices(self):
+        # Test slicing with tensor start/stop/step on RHS (reading)
+
+        # Test 1: Basic slice with tensor start and stop
+        def f1(x, start_t, stop_t):
+            return x[start_t:stop_t]
+
+        x = torch.randn(20)
+        start_t = torch.tensor(5)
+        stop_t = torch.tensor(15)
+        fn1 = torch.compile(f1, fullgraph=True, backend="inductor")
+        self.assertTrue(torch.allclose(fn1(x, start_t, stop_t), f1(x, start_t, stop_t)))
+
+        # Test 2: Slice with tensor step
+        def f2(x, start_t, stop_t, step_t):
+            return x[start_t:stop_t:step_t]
+
+        step_t = torch.tensor(2)
+        fn2 = torch.compile(f2, fullgraph=True, backend="inductor")
+        self.assertTrue(
+            torch.allclose(
+                fn2(x, start_t, stop_t, step_t), f2(x, start_t, stop_t, step_t)
+            )
+        )
+
+        # Test 3: Slice with only tensor start
+        def f3(x, start_t):
+            return x[start_t:]
+
+        fn3 = torch.compile(f3, fullgraph=True, backend="inductor")
+        self.assertTrue(torch.allclose(fn3(x, start_t), f3(x, start_t)))
+
+        # Test 4: Slice with only tensor stop
+        def f4(x, stop_t):
+            return x[:stop_t]
+
+        fn4 = torch.compile(f4, fullgraph=True, backend="inductor")
+        self.assertTrue(torch.allclose(fn4(x, stop_t), f4(x, stop_t)))
+
+        # Test 5: Negative indices with tensors
+        def f5(x, start_t):
+            return x[start_t:-1]
+
+        start_t_neg = torch.tensor(-10)
+        fn5 = torch.compile(f5, fullgraph=True, backend="inductor")
+        self.assertTrue(torch.allclose(fn5(x, start_t_neg), f5(x, start_t_neg)))
+
+        # Test 6: Multidimensional slice with tensor indices
+        def f6(x, start_t, stop_t):
+            return x[:, start_t:stop_t]
+
+        x_2d = torch.randn(10, 20)
+        fn6 = torch.compile(f6, fullgraph=True, backend="inductor")
+        self.assertTrue(
+            torch.allclose(fn6(x_2d, start_t, stop_t), f6(x_2d, start_t, stop_t))
+        )
+
+    @fresh_cache()
+    @torch._dynamo.config.patch("capture_scalar_outputs", True)
+    @torch._inductor.config.patch("cpp_wrapper", True)
+    def test_slice_with_tensor_indices_cpp_wrapper(self):
+        self.test_slice_with_tensor_indices()
+
+    @fresh_cache()
+    @torch._dynamo.config.patch("capture_scalar_outputs", True)
+    def test_select_with_tensor_index(self):
+        # Test direct tensor indexing (select) without calling .item()
+
+        # Test 1: Simple 0-d tensor as index
+        def f1(x, idx_tensor):
+            return x[idx_tensor]
+
+        x = torch.randn(10)
+        idx_tensor = torch.tensor(5)
+        fn1 = torch.compile(f1, fullgraph=True, backend="inductor")
+        self.assertTrue(torch.allclose(fn1(x, idx_tensor), f1(x, idx_tensor)))
+
+        # Test 2: Negative tensor index
+        def f2(x, idx_tensor):
+            return x[idx_tensor]
+
+        idx_tensor_neg = torch.tensor(-2)
+        fn2 = torch.compile(f2, fullgraph=True, backend="inductor")
+        self.assertTrue(torch.allclose(fn2(x, idx_tensor_neg), f2(x, idx_tensor_neg)))
+
+        # TODO support those less common patterns
+        # # Test 3: Multidimensional select with tensor index
+        # def f3(x, idx_tensor):
+        #     return x[:, idx_tensor]
+
+        # x_2d = torch.randn(5, 10)
+        # fn3 = torch.compile(f3, fullgraph=True, backend="inductor")
+        # self.assertTrue(torch.allclose(fn3(x_2d, idx_tensor), f3(x_2d, idx_tensor)))
+
+        # Test 4: Multiple tensor indices
+        # def f4(x, idx1, idx2):
+        #     return x[idx1, idx2]
+
+        # x_2d = torch.randn(8, 12)
+        # idx1 = torch.tensor(3)
+        # idx2 = torch.tensor(7)
+        # fn4 = torch.compile(f4, fullgraph=True, backend="inductor")
+        # self.assertTrue(torch.allclose(fn4(x_2d, idx1, idx2), f4(x_2d, idx1, idx2)))
+
+    @fresh_cache()
+    @torch._dynamo.config.patch("capture_scalar_outputs", True)
+    @torch._inductor.config.patch("cpp_wrapper", True)
+    def test_select_with_tensor_index_cpp_wrapper(self):
+        self.test_select_with_tensor_index()
+
     @fresh_cache()
     @torch._dynamo.config.patch("capture_scalar_outputs", True)
     def test_tensor_split(self):
diff --git a/torch/_dynamo/symbolic_convert.py b/torch/_dynamo/symbolic_convert.py
index 5815473d41f9..c3169369c6d1 100644
--- a/torch/_dynamo/symbolic_convert.py
+++ b/torch/_dynamo/symbolic_convert.py
@@ -3138,7 +3138,7 @@ class InstructionTranslatorBase(
 
     def BUILD_SLICE(self, inst: Instruction) -> None:
         items = self.popn(inst.argval)
-        self.push(SliceVariable(items))
+        self.push(SliceVariable(items, tx=self))
 
     def BUILD_LIST(self, inst: Instruction) -> None:
         items = self.popn(inst.argval)
diff --git a/torch/_dynamo/variables/builder.py b/torch/_dynamo/variables/builder.py
index 5fab51234d74..48e7164e567f 100644
--- a/torch/_dynamo/variables/builder.py
+++ b/torch/_dynamo/variables/builder.py
@@ -1795,7 +1795,7 @@ class VariableBuilder:
             ]
             self.install_guards(GuardBuilder.TYPE_MATCH)
             if isinstance(value, slice):
-                return SliceVariable(items, source=self.source)
+                return SliceVariable(items, self.tx, source=self.source)
             else:
                 return RangeVariable(items, source=self.source)
diff --git a/torch/_dynamo/variables/builtin.py b/torch/_dynamo/variables/builtin.py
index 09bdb81150e6..71a98e1bef1c 100644
--- a/torch/_dynamo/variables/builtin.py
+++ b/torch/_dynamo/variables/builtin.py
@@ -1235,6 +1235,20 @@ class BuiltinVariable(VariableTracker):
             # scenario is to de-sugar eagerly.
             fn, args = IN_PLACE_DESUGARING_MAP[fn], [args[0], args[1]]
 
+        # Convert size-1 TensorVariable indices to SymIntVariable by calling .item()
+        # This decomposes tensor[t] to u=t.item(); tensor[u] at the dynamo level
+        if (
+            fn is operator.getitem
+            and len(args) == 2
+            and isinstance(args[1], variables.TensorVariable)
+        ):
+            tensor_idx = args[1]
+            # Only convert if we know it's size-1 (not for advanced indexing)
+            if tensor_idx.size is not None and all(s == 1 for s in tensor_idx.size):
+                args = list(args)
+                args[1] = tensor_idx.call_method(tx, "item", [], {})
+                args = tuple(args)
+
         if fn is operator.getitem and isinstance(args[1], SymNodeVariable):
             # Standard indexing will force specialization due to
             # __index__. Rewrite as a regular torch op which will
@@ -1745,7 +1759,7 @@ class BuiltinVariable(VariableTracker):
         )
 
     def call_slice(self, tx: "InstructionTranslator", *args):
-        return variables.SliceVariable(args)
+        return variables.SliceVariable(args, tx)
 
     def _dyn_proxy(self, tx: "InstructionTranslator", *args, **kwargs):
         from .builder import wrap_fx_proxy
diff --git a/torch/_dynamo/variables/lists.py b/torch/_dynamo/variables/lists.py
index c6b9434b6f05..1498aeb9c564 100644
--- a/torch/_dynamo/variables/lists.py
+++ b/torch/_dynamo/variables/lists.py
@@ -1386,7 +1386,7 @@ class NamedTupleVariable(TupleVariable):
 
 
 class SliceVariable(VariableTracker):
-    def __init__(self, items, **kwargs) -> None:
+    def __init__(self, items, tx=None, **kwargs) -> None:
         items_to_map = items
         start, stop, step = [variables.ConstantVariable.create(None)] * 3
 
@@ -1399,18 +1399,27 @@ class SliceVariable(VariableTracker):
         else:
             raise AssertionError
 
-        if isinstance(start, variables.TensorVariable) or isinstance(
-            stop, variables.TensorVariable
-        ):
-            unimplemented_v2(
-                gb_type="Dynamic slicing with Tensor arguments",
-                context=f"SliceVariable start: {start}, stop: {stop}, step: {step}",
-                explanation="Creating slices with Tensor arguments is not supported. "
-                "e.g. `l[:x]`, where `x` is a 1-element tensor.",
-                hints=[
-                    *graph_break_hints.SUPPORTABLE,
-                ],
+        # Convert TensorVariable to SymIntVariable by calling .item()
+        # This decomposes a[:t] to u=t.item(); a[:u] at the dynamo level
+        if isinstance(start, variables.TensorVariable):
+            assert tx is not None, (
+                "tx is required when slice indices are TensorVariables"
             )
+            assert start.size is None or all(s == 1 for s in start.size)
+            start = start.call_method(tx, "item", [], {})
+        if isinstance(stop, variables.TensorVariable):
+            assert tx is not None, (
+                "tx is required when slice indices are TensorVariables"
+            )
+            assert stop.size is None or all(s == 1 for s in stop.size)
+            stop = stop.call_method(tx, "item", [], {})
+        if isinstance(step, variables.TensorVariable):
+            assert tx is not None, (
+                "tx is required when slice indices are TensorVariables"
+            )
+            assert step.size is None or all(s == 1 for s in step.size)
+            step = step.call_method(tx, "item", [], {})
+
         self.items = (start, stop, step)
         super().__init__(**kwargs)
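
A minimal end-to-end sketch of the user-facing pattern this change is meant to enable, mirroring the new tests in test_dynamic_shapes.py above. The snippet is illustrative only and not part of the diff; the names and the capture_scalar_outputs toggle are taken from those tests, and the decomposition in the comments is the one described by the comments added in builtin.py and lists.py.

    import torch

    # The new tests patch this config; enable it so the implicit .item()
    # call Dynamo now inserts for tensor slice bounds can be captured
    # as an unbacked SymInt instead of falling back to eager.
    torch._dynamo.config.capture_scalar_outputs = True

    def f(x, start_t, stop_t):
        # Previously this graph-broke with "Dynamic slicing with Tensor arguments".
        # With this change Dynamo traces it roughly as:
        #   u0 = start_t.item(); u1 = stop_t.item(); x[u0:u1]
        return x[start_t:stop_t]

    x = torch.randn(20)
    compiled = torch.compile(f, fullgraph=True, backend="inductor")
    ref = f(x, torch.tensor(5), torch.tensor(15))
    out = compiled(x, torch.tensor(5), torch.tensor(15))
    assert torch.allclose(out, ref)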