fix set_() with functionalization (#90722)

This should fix https://github.com/pytorch/pytorch/issues/90573

Pull Request resolved: https://github.com/pytorch/pytorch/pull/90722
Approved by: https://github.com/ezyang
This commit is contained in:
Brian Hirsh
2022-12-16 15:04:17 +00:00
committed by PyTorch MergeBot
parent 548960f68e
commit 440a3f2398
3 changed files with 21 additions and 1 deletions

View File

@ -146,6 +146,17 @@ class TestFunctionalization(TestCase):
r = _functionalize(f, reapply_views=True, crossref=self.crossref)(torch.ones(2, 2))
self.assertEqual(r.stride(), (5, 1))
def test_set_(self):
def f(x):
y = torch.ones(2)
y.set_(x.storage())
return y
# We should probaby get the crossref test to work,
# but fixing it for Storage() objects is annoying.
r = _functionalize(f, reapply_views=True, crossref=False)(torch.ones(2))
self.assertEqual(str(r.device), 'cpu')
def test_view_clone_view_inplace(self):
def f(input):
shape = [1, 1024, 128, 128]