mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Use swap_tensors path in nn.Module.to for all subclasses that override __torch_dispatch__ (#152539)
Fixes https://github.com/pytorch/pytorch/issues/148977
Pull Request resolved: https://github.com/pytorch/pytorch/pull/152539
Approved by: https://github.com/albanD
This commit is contained in:
committed by
PyTorch MergeBot
parent
4b5b1adb21
commit
037343657e
@@ -1487,6 +1487,20 @@ class FakeTensorOperatorInvariants(TestCase):
|
||||
|
||||
self.assertEqual(mode.count, 0)
|
||||
|
||||
# PropagateRealTensors installs weakrefs, so this test is marked as an
# expected failure for that variant.
# NOTE(review): presumably the swap_tensors path used by nn.Module.to cannot
# swap tensors that still have live weakrefs -- confirm.
@expectedFailurePropagateRealTensors
@unittest.skipIf(not RUN_CUDA, "requires cuda")
def test_module_to(self):
    # Checks that a module created under FakeTensorMode reports its state on
    # 'cpu' before m.to('cuda') and on 'cuda' afterwards.

    def _assert_all_on(state, expected_type):
        # Every entry of the state dict must sit on the expected device type.
        for tensor in state.values():
            self.assertEqual(tensor.device.type, expected_type)

    with FakeTensorMode():
        linear = torch.nn.Linear(2, 2)

        # Freshly constructed module: all parameters start on the CPU.
        _assert_all_on(linear.state_dict(), 'cpu')

        # The move happens in place on the module; the return value of
        # .to() is intentionally ignored, as in the original test.
        linear.to('cuda')
        _assert_all_on(linear.state_dict(), 'cuda')
|
||||
|
||||
make_propagate_real_tensors_cls(FakeTensorOperatorInvariants)
|
||||
|
||||
|
||||
@@ -14,7 +14,6 @@ import torch
|
||||
from torch import device, dtype, Tensor
|
||||
from torch._prims_common import DeviceLikeType
|
||||
from torch.nn.parameter import Buffer, Parameter
|
||||
- from torch.utils._python_dispatch import is_traceable_wrapper_subclass
|
||||
from torch.utils.hooks import BackwardHook, RemovableHandle
|
||||
|
||||
|
||||
@@ -943,8 +942,10 @@ class Module:
|
||||
p_should_use_set_data = compute_should_use_set_data(param, param_applied)
|
||||
|
||||
# subclasses may have multiple child tensors so we need to use swap_tensors
|
||||
- p_should_use_swap_tensors = (
|
||||
- should_use_swap_tensors or is_traceable_wrapper_subclass(param_applied)
|
||||
+ p_should_use_swap_tensors = should_use_swap_tensors or (
|
||||
+ hasattr(param, "__torch_dispatch__")
|
||||
+ and param.__torch_dispatch__ # type: ignore[misc]
|
||||
+ is not torch._C._disabled_torch_dispatch_impl
|
||||
)
|
||||
|
||||
param_grad = param.grad
|
||||
|
||||
Reference in New Issue
Block a user