Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
Use swap_tensors path in nn.Module.to for all subclasses that override __torch_dispatch__ (#152539)
Fixes https://github.com/pytorch/pytorch/issues/148977

Pull Request resolved: https://github.com/pytorch/pytorch/pull/152539
Approved by: https://github.com/albanD
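For context: nn.Module.to() rebuilds parameters via nn.Module._apply(), and for tensor subclasses that override __torch_dispatch__ (such as FakeTensor) the legacy param.data assignment path does not behave correctly, so the module must instead exchange the old and new tensors in place with torch.utils.swap_tensors(). A minimal standalone sketch of that primitive (not taken from this PR; the cpu/meta device pair is just an illustration):

    import torch

    # swap_tensors exchanges the identities of two tensors in place, so any
    # existing reference (e.g. a module's registered Parameter) afterwards
    # sees the other tensor's storage, device, and dtype. It refuses to swap
    # tensors that have live weakrefs, which is why the test below expects
    # failure under PropagateRealTensors.
    a = torch.empty(2, device="cpu")
    b = torch.empty(2, device="meta")
    torch.utils.swap_tensors(a, b)
    assert a.device.type == "meta"
    assert b.device.type == "cpu"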
committed by PyTorch MergeBot
parent 4b5b1adb21
commit 037343657e
@@ -1487,6 +1487,20 @@ class FakeTensorOperatorInvariants(TestCase):
         self.assertEqual(mode.count, 0)
 
+    # PropagateRealTensors installs weakrefs
+    @expectedFailurePropagateRealTensors
+    @unittest.skipIf(not RUN_CUDA, "requires cuda")
+    def test_module_to(self):
+        def _check_device(sd, device_type):
+            for v in sd.values():
+                self.assertEqual(v.device.type, device_type)
+
+        with FakeTensorMode():
+            m = torch.nn.Linear(2, 2)
+            _check_device(m.state_dict(), 'cpu')
+            m.to('cuda')
+            _check_device(m.state_dict(), 'cuda')
+
 
 make_propagate_real_tensors_cls(FakeTensorOperatorInvariants)
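Outside of the test harness, the behavior the new test locks in can be sketched as a standalone repro (an assumption-laden sketch: FakeTensorMode lives in a private module, so the import path may vary across versions, and a CUDA-enabled build is required, mirroring the test's skipIf):

    import torch
    from torch._subclasses.fake_tensor import FakeTensorMode  # private module; path is an assumption

    # Under FakeTensorMode every parameter is a FakeTensor, which overrides
    # __torch_dispatch__, so nn.Module.to() must take the swap_tensors path
    # for the device change to be reflected in the module's state_dict.
    with FakeTensorMode():
        m = torch.nn.Linear(2, 2)
        m.to("cuda")
        assert all(v.device.type == "cuda" for v in m.state_dict().values())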