Restore fake device (#157972)

Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/157972
Approved by: https://github.com/ezyang
This commit is contained in:
angelayi
2025-07-11 16:11:57 +00:00
committed by PyTorch MergeBot
parent 27c50799c1
commit 1f1f22991d
2 changed files with 15 additions and 0 deletions

View File

@ -253,6 +253,19 @@ class FakeTensorTest(TestCase):
assert x.copy_(y).device.type == "cpu"
assert y.copy_(x).device.type == "cuda"
def test_fake_device(self):
t = torch.ones(3)
t = t.view(1, 3)
fake_mode1 = FakeTensorMode(allow_non_fake_inputs=True)
fake_t = fake_mode1.from_tensor(t)
fake_t.fake_device = torch.device("cuda")
fake_mode2 = FakeTensorMode(allow_non_fake_inputs=True)
new_fake_t = fake_mode2.from_tensor(fake_t)
self.assertEqual(new_fake_t.device, fake_t.device)
def test_fake_dispatch_keys(self):
with FakeTensorMode():
x = torch.rand([4])