Use swap_tensors path in nn.Module.to for FakeTensor (#152539)

Fixes https://github.com/pytorch/pytorch/issues/148977

Differential Revision: [D76458023](https://our.internmc.facebook.com/intern/diff/D76458023)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/152539
Approved by: https://github.com/albanD
Authored by Mikayla Gawarecki, 2025-06-12 08:25:52 -07:00; committed by PyTorch MergeBot
parent db01f1032f
commit 38bfd462b8
2 changed files with 19 additions and 1 deletion
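
For context, a minimal sketch of the behavior this change is meant to enable; it mirrors the new test below and assumes a CUDA-enabled build, since the test itself is skipped without CUDA:

import torch
from torch._subclasses.fake_tensor import FakeTensorMode

with FakeTensorMode():
    m = torch.nn.Linear(2, 2)   # parameters are FakeTensors, initially on "cpu"
    m.to("cuda")                # the new test asserts the state_dict reports "cuda" after this call
    print({k: v.device.type for k, v in m.state_dict().items()})

Before this fix, the conversion did not take effect for the fake parameters (see the linked issue #148977).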


@@ -1489,6 +1489,20 @@ class FakeTensorOperatorInvariants(TestCase):
         self.assertEqual(mode.count, 0)
 
+    # PropagateRealTensors installs weakrefs
+    @expectedFailurePropagateRealTensors
+    @unittest.skipIf(not RUN_CUDA, "requires cuda")
+    def test_module_to(self):
+        def _check_device(sd, device_type):
+            for v in sd.values():
+                self.assertEqual(v.device.type, device_type)
+
+        with FakeTensorMode():
+            m = torch.nn.Linear(2, 2)
+            _check_device(m.state_dict(), "cpu")
+            m.to("cuda")
+            _check_device(m.state_dict(), "cuda")
+
 
 make_propagate_real_tensors_cls(FakeTensorOperatorInvariants)


@@ -954,9 +954,13 @@ class Module:
                 param_applied = fn(param)
             p_should_use_set_data = compute_should_use_set_data(param, param_applied)
 
+            from torch._subclasses.fake_tensor import FakeTensor
+
             # subclasses may have multiple child tensors so we need to use swap_tensors
             p_should_use_swap_tensors = (
-                should_use_swap_tensors or is_traceable_wrapper_subclass(param_applied)
+                should_use_swap_tensors
+                or is_traceable_wrapper_subclass(param_applied)
+                or isinstance(param, FakeTensor)
             )
 
             param_grad = param.grad
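
The "swap_tensors path" named in the title refers to torch.utils.swap_tensors, which exchanges the full content of two tensor objects in place instead of writing through .data; it cannot swap tensors that have weak references attached, which is what the expectedFailurePropagateRealTensors marker on the new test alludes to ("PropagateRealTensors installs weakrefs"). A standalone illustration of the primitive, separate from the Module._apply internals shown above:

import torch

# swap_tensors makes t1 take over t2's data, shape, and metadata, and vice versa
t1 = torch.zeros(2)
t2 = torch.ones(3)
torch.utils.swap_tensors(t1, t2)
print(t1)  # tensor([1., 1., 1.])
print(t2)  # tensor([0., 0.])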