diff --git a/test/test_dynamic_shapes.py b/test/test_dynamic_shapes.py index 6baaaf26b9c5..fcc45521fbb1 100644 --- a/test/test_dynamic_shapes.py +++ b/test/test_dynamic_shapes.py @@ -4301,6 +4301,34 @@ def forward(self, arg0_1: "i64[1][1]cpu", arg1_1: "Sym(u1)", arg2_1: "i64[u1][1] accumulate(X0, torch.tensor([1])), compiled(X0, torch.tensor([1])) ) + @torch._dynamo.config.patch("capture_scalar_outputs", True) + def test_unbacked_item_set_item3(self): + def func(x, y): + u0 = y.item() + x[u0] = 0 + return x + + compiled = torch.compile(func, fullgraph=True, disable=False) + b = torch.tensor([0]) + a = torch.ones(9, dtype=torch.int32) + + compiled(a, b) + self.assertEqual(compiled(a, b), func(a, b)) + + @torch._dynamo.config.patch("capture_scalar_outputs", True) + def test_select_scatter_unbacked_index(self): + def func(x, y): + u0 = y.item() + # Create a scalar tensor to scatter into the selected index + scalar_src = torch.tensor(42, dtype=x.dtype) + return x.select_scatter(scalar_src, 0, u0) + + compiled = torch.compile(func, fullgraph=True, dynamic=True, backend="inductor") + b = torch.tensor([0]) + a = torch.ones(9, dtype=torch.int32) + + self.assertEqual(compiled(a, b), func(a, b)) + instantiate_parametrized_tests(TestUnbacked) diff --git a/torch/_inductor/fx_passes/reinplace.py b/torch/_inductor/fx_passes/reinplace.py index 3a4900900540..8b9deac6ba5a 100644 --- a/torch/_inductor/fx_passes/reinplace.py +++ b/torch/_inductor/fx_passes/reinplace.py @@ -24,7 +24,10 @@ from torch._inductor.lowering import ( inplaceable_foreach_ops as inplaceable_foreach_ops_lowerings, ) from torch._inductor.virtualized import V -from torch.fx.experimental.symbolic_shapes import GuardOnDataDependentSymNode +from torch.fx.experimental.symbolic_shapes import ( + compute_unbacked_bindings, + GuardOnDataDependentSymNode, +) from torch.fx.immutable_collections import immutable_dict, immutable_list from torch.fx.passes.reinplace import _is_view_op from torch.utils import _pytree as pytree @@ -60,7 +63,9 @@ def graph_call_function(graph: torch.fx.Graph, fn, *args, **kwargs): fake_result = fn(*fake_args, **fake_kwargs) node = graph.call_function(fn, args, kwargs) + node.meta["val"] = fake_result + return node @@ -171,6 +176,13 @@ def _decompose_scatter_mutating( tmp = inp for view in view_ops: # type: ignore[union-attr] tmp = graph_call_function(graph, view.target, tmp, *view.args, **view.kwargs) # type: ignore[union-attr] + # we need to set unbacked bindings that could have been created in the view ops. + if (V.fake_mode.shape_env) and ( + symbol_to_path := compute_unbacked_bindings( + V.fake_mode.shape_env, tmp.meta["val"] + ) + ): + tmp.meta["unbacked_bindings"] = symbol_to_path graph_call_function(graph, aten.copy_.default, tmp, src) return inp # type: ignore[return-value] diff --git a/torch/_inductor/ir.py b/torch/_inductor/ir.py index 4c28ee8faf59..56a88caf6c7d 100644 --- a/torch/_inductor/ir.py +++ b/torch/_inductor/ir.py @@ -4542,9 +4542,7 @@ class ComputedBuffer(OperationBuffer): unbacked_only ) | self.data.get_free_symbol_uses(unbacked_only) - if self.has_store_function() and isinstance( - self.get_store_function(), LoopBody - ): + if self.has_store_function(): result |= self.get_read_writes().get_free_symbol_uses(unbacked_only) return result