[easy] Add testing utilities for torch.nn.utils.set_swap_module_params_on_conversion (#118023)

Adds testing utilities so that the PR above this one in the stack can parametrize the existing `load_state_dict` tests over swap-on-conversion.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/118023
Approved by: https://github.com/albanD
ghstack dependencies: #118028, #117167
Commit 23b030a79c (parent d5a718d27b), committed by PyTorch MergeBot.
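The utilities added here, `wrapSwapTensorsTest` alongside `parametrize` in `torch.testing._internal.common_utils`, are what let the follow-up PR run the existing `load_state_dict` tests both with and without swap-on-conversion. A minimal usage sketch, following the decorator pattern shown in the diff below (the test class and its body are hypothetical; only the decorator names and the `swap` parameter come from this commit):

    # Hypothetical usage sketch; only `parametrize` and `wrapSwapTensorsTest`
    # come from this commit, the test itself is illustrative.
    import torch
    from torch.testing._internal.common_utils import (
        TestCase, run_tests, parametrize, wrapSwapTensorsTest,
        instantiate_parametrized_tests,
    )

    class TestLoadStateDictSwap(TestCase):
        @parametrize('swap', [True, False])
        @wrapSwapTensorsTest()
        def test_state_dict_round_trip(self, swap):
            # Runs once with swap-on-conversion enabled and once without.
            src, dst = torch.nn.Linear(3, 3), torch.nn.Linear(3, 3)
            dst.load_state_dict(src.state_dict())
            self.assertEqual(dst.weight, src.weight)
            self.assertEqual(dst.bias, src.bias)

    instantiate_parametrized_tests(TestLoadStateDictSwap)

    if __name__ == '__main__':
        run_tests()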
@@ -15,7 +15,7 @@ from torch.testing._internal.common_device_type import (
 from torch.testing._internal.common_modules import module_db, modules, ModuleErrorEnum, TrainEvalMode
 from torch.testing._internal.common_utils import (
     TestCase, run_tests, freeze_rng_state, mock_wrapper, get_tensors_from, gradcheck,
-    gradgradcheck, parametrize)
+    gradgradcheck, parametrize, wrapSwapTensorsTest)
 from unittest.mock import patch, call
@@ -856,6 +856,7 @@ class TestModule(TestCase):
     @modules([module for module in module_db if not module.is_lazy])
     @parametrize('swap', [True, False])
     @parametrize('set_grad', [True, False])
+    @wrapSwapTensorsTest()
     def test_to(self, device, dtype, module_info, training, swap, set_grad):
         module_cls = module_info.module_cls
         devices = ['cpu']
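In the next hunk, `test_to` no longer toggles `torch.__future__.set_swap_module_params_on_conversion` itself; that bookkeeping presumably moves into `wrapSwapTensorsTest`. The real implementation lives in `torch.testing._internal.common_utils` and is not shown in this diff; a rough sketch of an equivalent wrapper, assuming it picks the flag up from the test's parametrized `swap` argument (hypothetical name and logic), might look like:

    import functools
    import torch

    def wrap_swap_tensors_test():
        # Assumed behaviour: set the swap-on-conversion future flag from the
        # test's `swap` kwarg, run the test, and always restore the default
        # afterwards -- the same set / try / finally pattern this commit
        # removes from test_to below.
        def decorator(fn):
            @functools.wraps(fn)
            def wrapper(*args, **kwargs):
                torch.__future__.set_swap_module_params_on_conversion(kwargs.get('swap', False))
                try:
                    return fn(*args, **kwargs)
                finally:
                    torch.__future__.set_swap_module_params_on_conversion(False)
            return wrapper
        return decorator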
@@ -866,72 +867,69 @@ class TestModule(TestCase):
                                                        requires_grad=False, training=training)
-        torch.__future__.set_swap_module_params_on_conversion(swap)
-
-        try:
         for module_input in module_inputs:
             c_args, c_kwargs = module_input.constructor_input.args, module_input.constructor_input.kwargs
             fw_args, fw_kwargs = module_input.forward_input.args, module_input.forward_input.kwargs

             m = module_cls(*c_args, **c_kwargs)

             # Avoid using `module.to()` when constructing module since that is the method we are testing
             def _to(m, set_grad=False):
                 for c in m.children():
                     _to(c, set_grad=set_grad)
                 for n, p in m.named_parameters(recurse=False):
                     new_p = torch.nn.Parameter(p.detach().clone().to(device, dtype))
                     setattr(m, n, new_p)
                     if set_grad:
                         new_p.grad = torch.randn_like(new_p)
                 for n, b in m.named_buffers(recurse=False):
                     new_b = b.detach().clone().to(device, dtype)
                     setattr(m, n, new_b)
             _to(m, set_grad=set_grad)

             prev_device, prev_dtype = device, dtype
             for device_, dtype_ in product(devices, dtypes):
                 # if device/dtype do not change, grad.to(device, dtype) is a no-op so
                 # swapping will not change ._cdata
                 # parameters will be wrapped in an nn.Parameter before swapping
                 # which will cause the ._cdata to change
                 g_no_swap = device_ == prev_device and dtype_ == prev_dtype
                 prev_device, prev_dtype = device_, dtype_

                 p_ids_before = [id(p) for p in m.parameters()]
                 p_cdatas_before = [p._cdata for p in m.parameters()]
                 if set_grad:
                     g_ids_before = [id(p.grad) for p in m.parameters()]
                     g_cdatas_before = [p.grad._cdata for p in m.parameters()]

                 m.to(device=device_, dtype=dtype_)

                 self.assertTrue(all(isinstance(p, torch.nn.Parameter) for p in m.parameters()))
                 self.assertTrue(all(p.device.type == device_ for p in m.parameters()))
                 self.assertTrue(all(p.dtype == dtype_ for p in m.parameters()))
                 p_ids_after = [id(p) for p in m.parameters()]
                 p_cdatas_after = [p._cdata for p in m.parameters()]

                 if set_grad:
                     self.assertTrue(all(p.grad.device.type == device_ for p in m.parameters()))
                     self.assertTrue(all(p.grad.dtype == dtype_ for p in m.parameters()))
                     g_ids_after = [id(p.grad) for p in m.parameters()]
                     g_cdatas_after = [p.grad._cdata for p in m.parameters()]

                 if swap:
                     # id same, ._cdata differs --> swapped cdata of THPVariable
                     self.assertTrue(all(a == b for a, b in zip(p_ids_before, p_ids_after)))
                     self.assertTrue(all(a != b for a, b in zip(p_cdatas_before, p_cdatas_after)))
                     if set_grad:
                         self.assertTrue(
                             all(a == b if g_no_swap else a != b for a, b in zip(g_cdatas_before, g_cdatas_after)))
                 else:
                     # id and _cdata remain the same --> .data setting
                     self.assertTrue(all(a == b for a, b in zip(p_cdatas_before, p_cdatas_after)))
                     self.assertTrue(all(a == b for a, b in zip(p_ids_before, p_ids_after)))
                     if set_grad:
                         self.assertTrue(all(a == b for a, b in zip(g_cdatas_before, g_cdatas_after)))
                         self.assertTrue(all(a == b for a, b in zip(g_ids_before, g_ids_after)))
-        finally:
-            torch.__future__.set_swap_module_params_on_conversion(False)


 instantiate_device_type_tests(TestModule, globals(), allow_mps=True)
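The swap branch of these assertions captures the intended behaviour of the future flag: `Module.to()` keeps the very same `nn.Parameter` objects but exchanges their underlying C++ tensor (different `._cdata`), instead of the default path that sets `.data` and leaves both `id` and `._cdata` unchanged. A standalone illustration of that distinction, outside the module_db harness (the `Linear` module and the dtype change here are just an example):

    import torch

    # Illustrative only: with swap-on-conversion enabled, .to() keeps parameter
    # identity (same id) while replacing the C++ payload (new ._cdata).
    m = torch.nn.Linear(2, 2)
    ids_before = [id(p) for p in m.parameters()]
    cdatas_before = [p._cdata for p in m.parameters()]

    torch.__future__.set_swap_module_params_on_conversion(True)
    try:
        m.to(dtype=torch.float64)  # a real dtype change, so the conversion is not a no-op
    finally:
        torch.__future__.set_swap_module_params_on_conversion(False)

    assert ids_before == [id(p) for p in m.parameters()]  # same Parameter objects
    assert all(a != b for a, b in
               zip(cdatas_before, [p._cdata for p in m.parameters()]))  # swapped payload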