set unbacked bindings in reinplace pass for newly created nodes during generalize_scatter decomp (#164948)

Two fixes:
1. in rein_place pass, set unbacked bindings for newly created nodes.
2. In inductor, ComputeBuffer used to miss detecting some used symbols, fixed that.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164948
Approved by: https://github.com/bobrenjc93
ghstack dependencies: #164341
This commit is contained in:
Laith Sakka
2025-10-17 11:01:15 -07:00
committed by PyTorch MergeBot
parent c6a8db0b9a
commit 017d2985f3
3 changed files with 42 additions and 4 deletions

View File

@ -4301,6 +4301,34 @@ def forward(self, arg0_1: "i64[1][1]cpu", arg1_1: "Sym(u1)", arg2_1: "i64[u1][1]
accumulate(X0, torch.tensor([1])), compiled(X0, torch.tensor([1]))
)
@torch._dynamo.config.patch("capture_scalar_outputs", True)
def test_unbacked_item_set_item3(self):
def func(x, y):
u0 = y.item()
x[u0] = 0
return x
compiled = torch.compile(func, fullgraph=True, disable=False)
b = torch.tensor([0])
a = torch.ones(9, dtype=torch.int32)
compiled(a, b)
self.assertEqual(compiled(a, b), func(a, b))
@torch._dynamo.config.patch("capture_scalar_outputs", True)
def test_select_scatter_unbacked_index(self):
def func(x, y):
u0 = y.item()
# Create a scalar tensor to scatter into the selected index
scalar_src = torch.tensor(42, dtype=x.dtype)
return x.select_scatter(scalar_src, 0, u0)
compiled = torch.compile(func, fullgraph=True, dynamic=True, backend="inductor")
b = torch.tensor([0])
a = torch.ones(9, dtype=torch.int32)
self.assertEqual(compiled(a, b), func(a, b))
instantiate_parametrized_tests(TestUnbacked)