mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
set unbacked bindings in reinplace pass for newly created nodes during generalize_scatter decomp (#164948)
Two fixes: 1. in rein_place pass, set unbacked bindings for newly created nodes. 2. In inductor, ComputeBuffer used to miss detecting some used symbols, fixed that. Pull Request resolved: https://github.com/pytorch/pytorch/pull/164948 Approved by: https://github.com/bobrenjc93 ghstack dependencies: #164341
This commit is contained in:
committed by
PyTorch MergeBot
parent
c6a8db0b9a
commit
017d2985f3
@ -4301,6 +4301,34 @@ def forward(self, arg0_1: "i64[1][1]cpu", arg1_1: "Sym(u1)", arg2_1: "i64[u1][1]
|
||||
accumulate(X0, torch.tensor([1])), compiled(X0, torch.tensor([1]))
|
||||
)
|
||||
|
||||
@torch._dynamo.config.patch("capture_scalar_outputs", True)
|
||||
def test_unbacked_item_set_item3(self):
|
||||
def func(x, y):
|
||||
u0 = y.item()
|
||||
x[u0] = 0
|
||||
return x
|
||||
|
||||
compiled = torch.compile(func, fullgraph=True, disable=False)
|
||||
b = torch.tensor([0])
|
||||
a = torch.ones(9, dtype=torch.int32)
|
||||
|
||||
compiled(a, b)
|
||||
self.assertEqual(compiled(a, b), func(a, b))
|
||||
|
||||
@torch._dynamo.config.patch("capture_scalar_outputs", True)
|
||||
def test_select_scatter_unbacked_index(self):
|
||||
def func(x, y):
|
||||
u0 = y.item()
|
||||
# Create a scalar tensor to scatter into the selected index
|
||||
scalar_src = torch.tensor(42, dtype=x.dtype)
|
||||
return x.select_scatter(scalar_src, 0, u0)
|
||||
|
||||
compiled = torch.compile(func, fullgraph=True, dynamic=True, backend="inductor")
|
||||
b = torch.tensor([0])
|
||||
a = torch.ones(9, dtype=torch.int32)
|
||||
|
||||
self.assertEqual(compiled(a, b), func(a, b))
|
||||
|
||||
|
||||
instantiate_parametrized_tests(TestUnbacked)
|
||||
|
||||
|
Reference in New Issue
Block a user