mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Restore fake device (#157972)
Fixes #ISSUE_NUMBER Pull Request resolved: https://github.com/pytorch/pytorch/pull/157972 Approved by: https://github.com/ezyang
This commit is contained in:
committed by
PyTorch MergeBot
parent
27c50799c1
commit
1f1f22991d
@ -253,6 +253,19 @@ class FakeTensorTest(TestCase):
|
||||
assert x.copy_(y).device.type == "cpu"
|
||||
assert y.copy_(x).device.type == "cuda"
|
||||
|
||||
def test_fake_device(self):
|
||||
t = torch.ones(3)
|
||||
t = t.view(1, 3)
|
||||
|
||||
fake_mode1 = FakeTensorMode(allow_non_fake_inputs=True)
|
||||
fake_t = fake_mode1.from_tensor(t)
|
||||
fake_t.fake_device = torch.device("cuda")
|
||||
|
||||
fake_mode2 = FakeTensorMode(allow_non_fake_inputs=True)
|
||||
new_fake_t = fake_mode2.from_tensor(fake_t)
|
||||
|
||||
self.assertEqual(new_fake_t.device, fake_t.device)
|
||||
|
||||
def test_fake_dispatch_keys(self):
|
||||
with FakeTensorMode():
|
||||
x = torch.rand([4])
|
||||
|
Reference in New Issue
Block a user