mirror of
https://github.com/pytorch/pytorch.git
synced 2025-11-07 01:50:04 +08:00
fix set_() with functionalization (#90722)
This should fix https://github.com/pytorch/pytorch/issues/90573 Pull Request resolved: https://github.com/pytorch/pytorch/pull/90722 Approved by: https://github.com/ezyang
This commit is contained in:
committed by
PyTorch MergeBot
parent
548960f68e
commit
440a3f2398
@ -146,6 +146,17 @@ class TestFunctionalization(TestCase):
|
||||
r = _functionalize(f, reapply_views=True, crossref=self.crossref)(torch.ones(2, 2))
|
||||
self.assertEqual(r.stride(), (5, 1))
|
||||
|
||||
def test_set_(self):
|
||||
def f(x):
|
||||
y = torch.ones(2)
|
||||
y.set_(x.storage())
|
||||
return y
|
||||
|
||||
# We should probaby get the crossref test to work,
|
||||
# but fixing it for Storage() objects is annoying.
|
||||
r = _functionalize(f, reapply_views=True, crossref=False)(torch.ones(2))
|
||||
self.assertEqual(str(r.device), 'cpu')
|
||||
|
||||
def test_view_clone_view_inplace(self):
|
||||
def f(input):
|
||||
shape = [1, 1024, 128, 128]
|
||||
|
||||
Reference in New Issue
Block a user