Use swap_tensors path in nn.Module.to for all subclasses that override __torch_dispatch__ (#152539)

Fixes https://github.com/pytorch/pytorch/issues/148977

Pull Request resolved: https://github.com/pytorch/pytorch/pull/152539
Approved by: https://github.com/albanD
Author: Mikayla Gawarecki
Date: 2025-04-30 14:53:00 -07:00
Committed by: PyTorch MergeBot
Parent: 4b5b1adb21
Commit: 037343657e

2 changed files with 18 additions and 3 deletions
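
For context on the change itself: `Module._apply` previously switched to the swap_tensors path only for traceable wrapper subclasses (`is_traceable_wrapper_subclass`); after this commit it does so for any parameter whose class overrides `__torch_dispatch__`, i.e. whose `__torch_dispatch__` is not the disabled stub. A minimal sketch of that condition, with a hypothetical subclass name (`MyDispatchTensor` is illustrative, not part of the change):

```python
import torch

# Hypothetical tensor subclass that overrides __torch_dispatch__ but is not a
# traceable wrapper subclass; after this commit, nn.Module.to() also routes its
# parameters through the swap_tensors path.
class MyDispatchTensor(torch.Tensor):
    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        # Forward every op unchanged; a real subclass would add behavior here.
        return func(*args, **(kwargs or {}))

plain = torch.randn(2, 2)

# Plain tensors carry the disabled stub, so the new condition is False for them:
print(plain.__torch_dispatch__ is torch._C._disabled_torch_dispatch_impl)  # True

# Any class that actually overrides __torch_dispatch__ fails the identity check,
# which is what the new condition in Module._apply detects:
print(MyDispatchTensor.__torch_dispatch__
      is torch._C._disabled_torch_dispatch_impl)  # False
```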

test/test_fake_tensor.py

@@ -1487,6 +1487,20 @@ class FakeTensorOperatorInvariants(TestCase):
         self.assertEqual(mode.count, 0)
 
+    # PropagateRealTensors installs weakrefs
+    @expectedFailurePropagateRealTensors
+    @unittest.skipIf(not RUN_CUDA, "requires cuda")
+    def test_module_to(self):
+        def _check_device(sd, device_type):
+            for v in sd.values():
+                self.assertEqual(v.device.type, device_type)
+
+        with FakeTensorMode():
+            m = torch.nn.Linear(2, 2)
+            _check_device(m.state_dict(), 'cpu')
+            m.to('cuda')
+            _check_device(m.state_dict(), 'cuda')
+
 
 make_propagate_real_tensors_cls(FakeTensorOperatorInvariants)

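The `# PropagateRealTensors installs weakrefs` comment and the `@expectedFailurePropagateRealTensors` decorator above relate to a limitation of `torch.utils.swap_tensors`, the mechanism behind the swap_tensors path: it refuses to swap a tensor that has weak references pointing at it. A small standalone sketch of that limitation, assuming the weakref restriction described here (the exact error message may differ across versions):

```python
import weakref
import torch

a, b = torch.randn(2), torch.randn(2)
torch.utils.swap_tensors(a, b)  # fine: no weak references involved

c, d = torch.randn(2), torch.randn(2)
ref = weakref.ref(c)
try:
    torch.utils.swap_tensors(c, d)
except RuntimeError as e:
    # swap_tensors cannot swap a tensor that has weak references to it
    print(e)
```
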
torch/nn/modules/module.py

@@ -14,7 +14,6 @@ import torch
 from torch import device, dtype, Tensor
 from torch._prims_common import DeviceLikeType
 from torch.nn.parameter import Buffer, Parameter
-from torch.utils._python_dispatch import is_traceable_wrapper_subclass
 from torch.utils.hooks import BackwardHook, RemovableHandle
@@ -943,8 +942,10 @@ class Module:
             p_should_use_set_data = compute_should_use_set_data(param, param_applied)
             # subclasses may have multiple child tensors so we need to use swap_tensors
-            p_should_use_swap_tensors = (
-                should_use_swap_tensors or is_traceable_wrapper_subclass(param_applied)
+            p_should_use_swap_tensors = should_use_swap_tensors or (
+                hasattr(param, "__torch_dispatch__")
+                and param.__torch_dispatch__  # type: ignore[misc]
+                is not torch._C._disabled_torch_dispatch_impl
             )
             param_grad = param.grad
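
For readers less familiar with the two paths being chosen between in `_apply`: the legacy path keeps the existing `Parameter` object and overwrites `param.data`, while the swap_tensors path exchanges the two tensor objects wholesale via `torch.utils.swap_tensors`, so subclasses that carry extra child tensors or per-instance state get converted consistently. A rough, simplified sketch of the two strategies; this is not the actual `_apply` code, and the helper names are made up for illustration:

```python
import torch
import torch.nn as nn

def apply_via_set_data(param: nn.Parameter, param_applied: torch.Tensor) -> None:
    # Legacy path: keep the Parameter object and rewrite its data in place.
    param.data = param_applied

def apply_via_swap_tensors(param: nn.Parameter, param_applied: torch.Tensor) -> None:
    # swap_tensors path: wrap the converted tensor and exchange the two tensor
    # objects entirely, so every reference to `param` now sees the converted data.
    new_param = nn.Parameter(param_applied, requires_grad=param.requires_grad)
    torch.utils.swap_tensors(param, new_param)

# Minimal usage, assuming a CUDA device is available:
# m = nn.Linear(2, 2)
# apply_via_swap_tensors(m.weight, m.weight.to("cuda"))
```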