mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Update nn.Module._apply to not gate on should_use_set_data when swap_tensors is set (#120659)
This updates the nesting of if statements in `nn.Module._apply` such that if `torch.__future__.set_swap_module_params_on_conversion(True)`, we always try to swap regardless of whether - `torch._has_compatible_shallow_copy_type(param, fn(param)` - `torch.__future__.set_overwrite_module_params_on_conversion` is set This means that `meta_module.to_empty('device')` can now use the swap_tensors path cc @awgu Pull Request resolved: https://github.com/pytorch/pytorch/pull/120659 Approved by: https://github.com/albanD
This commit is contained in:
committed by
PyTorch MergeBot
parent
213b3ac3f2
commit
677e67c399
@ -869,7 +869,6 @@ class TestModule(TestCase):
|
||||
|
||||
for module_input in module_inputs:
|
||||
c_args, c_kwargs = module_input.constructor_input.args, module_input.constructor_input.kwargs
|
||||
fw_args, fw_kwargs = module_input.forward_input.args, module_input.forward_input.kwargs
|
||||
|
||||
m = module_cls(*c_args, **c_kwargs)
|
||||
|
||||
@ -904,7 +903,7 @@ class TestModule(TestCase):
|
||||
|
||||
m.to(device=device_, dtype=dtype_)
|
||||
|
||||
self.assertTrue(isinstance(p, torch.nn.Parameter) for p in m.parameters())
|
||||
self.assertTrue(all(isinstance(p, torch.nn.Parameter) for p in m.parameters()))
|
||||
self.assertTrue(all(p.device.type == device_ for p in m.parameters()))
|
||||
self.assertTrue(all(p.dtype == dtype_ for p in m.parameters()))
|
||||
p_ids_after = [id(p) for p in m.parameters()]
|
||||
@ -932,6 +931,47 @@ class TestModule(TestCase):
|
||||
self.assertTrue(all(a == b for a, b in zip(g_ids_before, g_ids_after)))
|
||||
|
||||
|
||||
@modules([module for module in module_db if not module.is_lazy], allowed_dtypes=[torch.float32])
|
||||
@parametrize('swap', [True, False])
|
||||
@wrapSwapTensorsTest()
|
||||
def test_to_empty(self, device, dtype, module_info, swap, training):
|
||||
module_cls = module_info.module_cls
|
||||
|
||||
with torch.device("meta"):
|
||||
module_inputs = module_info.module_inputs_func(module_info, device=None, dtype=dtype,
|
||||
requires_grad=False, training=training)
|
||||
|
||||
torch.__future__.set_swap_module_params_on_conversion(swap)
|
||||
device_ = torch.device(device)
|
||||
|
||||
for module_input in module_inputs:
|
||||
c_args, c_kwargs = module_input.constructor_input.args, module_input.constructor_input.kwargs
|
||||
|
||||
with torch.device("meta"):
|
||||
m = module_cls(*c_args, **c_kwargs)
|
||||
|
||||
p_ids_before = [id(p) for p in m.parameters()]
|
||||
p_cdatas_before = [p._cdata for p in m.parameters()]
|
||||
m.to_empty(device=device_)
|
||||
|
||||
self.assertTrue(all(isinstance(p, torch.nn.Parameter) for p in m.parameters()))
|
||||
self.assertTrue(all(p.device == device_ for p in m.parameters()))
|
||||
self.assertTrue(all(p.dtype == dtype for p in m.parameters()))
|
||||
p_ids_after = [id(p) for p in m.parameters()]
|
||||
p_cdatas_after = [p._cdata for p in m.parameters()]
|
||||
|
||||
if swap:
|
||||
# id same, ._cdata differs --> swapped cdata of THPVariable
|
||||
self.assertTrue(all(a == b for a, b in zip(p_ids_before, p_ids_after)))
|
||||
self.assertTrue(all(a != b for a, b in zip(p_cdatas_before, p_cdatas_after)))
|
||||
else:
|
||||
# id and ._cdata differ
|
||||
# meta and device have different shallow copy types, so this will create a new
|
||||
# parameter and assign it to the module
|
||||
self.assertTrue(all(a != b for a, b in zip(p_ids_before, p_ids_after)))
|
||||
self.assertTrue(all(a != b for a, b in zip(p_cdatas_before, p_cdatas_after)))
|
||||
|
||||
|
||||
instantiate_device_type_tests(TestModule, globals(), allow_mps=True)
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
@ -9,10 +9,10 @@ def set_overwrite_module_params_on_conversion(value: bool) -> None:
|
||||
|
||||
When enabled, the following methods will assign new parameters to the module:
|
||||
|
||||
#. ``module.{device}()`` (e.g. ``module.cuda()``) for moving a module between devices
|
||||
#. ``module.{dtype}()`` (e.g. ``module.float()``) for converting a module to a different dtype
|
||||
(for converting a module to a different dtype)
|
||||
#. ``module.to()``
|
||||
#. ``module.{device}()`` (e.g. :meth:`nn.Module.cuda()`) for moving a module between devices
|
||||
#. ``module.{dtype}()`` (e.g. :meth:`nn.Module.float()`) for converting a module to a different dtype
|
||||
#. :meth:`nn.Module.to`
|
||||
#. :meth:`nn.Module.to_empty`
|
||||
|
||||
Args:
|
||||
value (bool): Whether to assign new tensors or not.
|
||||
@ -25,7 +25,7 @@ def set_overwrite_module_params_on_conversion(value: bool) -> None:
|
||||
def get_overwrite_module_params_on_conversion() -> bool:
|
||||
"""
|
||||
Returns whether to assign new tensors to the parameters instead of changing the
|
||||
existing parameters in-place when converting an ``nn.Module``. Defaults to ``False``.
|
||||
existing parameters in-place when converting an :class:`torch.nn.Module`. Defaults to ``False``.
|
||||
|
||||
See :func:`~torch.__future__.set_overwrite_module_params_on_conversion` for more information.
|
||||
"""
|
||||
@ -39,20 +39,19 @@ def set_swap_module_params_on_conversion(value: bool) -> None:
|
||||
of ``param.copy_(state_dict[key])`` when loading a state dict into an ``nn.Module``.
|
||||
|
||||
.. note::
|
||||
If :func:`~torch.__future__.get_overwrite_module_params_on_conversion` returns ``True``,
|
||||
for methods other than :meth:`~nn.Module.load_state_dict` no swapping will occur.
|
||||
This function takes precedence over :func:`~torch.__future__.get_overwrite_module_params_on_conversion`
|
||||
|
||||
When enabled, the following methods will swap the existing parameters in-place:
|
||||
|
||||
#. ``module.{device}()`` (e.g. ``module.cuda()``) for moving a module between devices
|
||||
#. ``module.{dtype}()`` (e.g. ``module.float()``) for converting a module to a different dtype
|
||||
(for converting a module to a different dtype)
|
||||
#. ``module.to()``
|
||||
#. ``module.load_state_dict(state_dict)``
|
||||
#. ``module.{device}()`` (e.g. :meth:`nn.Module.cuda()`) for moving a module between devices
|
||||
#. ``module.{dtype}()`` (e.g. :meth:`nn.Module.float()`) for converting a module to a different dtype
|
||||
#. :meth:`nn.Module.to`
|
||||
#. :meth:`nn.Module.to_empty`
|
||||
#. :meth:`nn.Module.load_state_dict`
|
||||
|
||||
The semantics for :meth:`~nn.Module.load_state_dict` when this is set are as follows:
|
||||
|
||||
#. For each parameter/buffer, its corresponding``state_dict['key']`` is transformed via
|
||||
#. For each parameter/buffer, its corresponding ``state_dict['key']`` is transformed via
|
||||
:meth:`~torch.Tensor.module_load` (i.e. ``res = param.module_load(state_dict['key'])``)
|
||||
#. If necessary, ``res`` will be wrapped in an :class:`~nn.Parameter`
|
||||
#. The parameter/buffer in the module will be swapped via :func:`~torch.utils.swap_tensors`
|
||||
|
@ -803,23 +803,22 @@ class Module:
|
||||
param_applied = fn(param)
|
||||
p_should_use_set_data = compute_should_use_set_data(param, param_applied)
|
||||
param_grad = param.grad
|
||||
if p_should_use_set_data:
|
||||
if should_use_swap_tensors:
|
||||
try:
|
||||
if param_grad is not None:
|
||||
# Accessing param.grad makes its at::Tensor's use_count 2, which will prevent swapping.
|
||||
# Decrement use count of the gradient by setting to None
|
||||
param.grad = None
|
||||
param_applied = torch.nn.Parameter(param_applied, requires_grad=param.requires_grad)
|
||||
torch.utils.swap_tensors(param, param_applied)
|
||||
except Exception as e:
|
||||
if param_grad is not None:
|
||||
param.grad = param_grad
|
||||
raise RuntimeError(f"_apply(): Couldn't swap {self._get_name()}.{key}") from e
|
||||
out_param = param
|
||||
else:
|
||||
param.data = param_applied
|
||||
out_param = param
|
||||
if should_use_swap_tensors:
|
||||
try:
|
||||
if param_grad is not None:
|
||||
# Accessing param.grad makes its at::Tensor's use_count 2, which will prevent swapping.
|
||||
# Decrement use count of the gradient by setting to None
|
||||
param.grad = None
|
||||
param_applied = torch.nn.Parameter(param_applied, requires_grad=param.requires_grad)
|
||||
torch.utils.swap_tensors(param, param_applied)
|
||||
except Exception as e:
|
||||
if param_grad is not None:
|
||||
param.grad = param_grad
|
||||
raise RuntimeError(f"_apply(): Couldn't swap {self._get_name()}.{key}") from e
|
||||
out_param = param
|
||||
elif p_should_use_set_data:
|
||||
param.data = param_applied
|
||||
out_param = param
|
||||
else:
|
||||
assert isinstance(param, Parameter)
|
||||
assert param.is_leaf
|
||||
@ -830,17 +829,16 @@ class Module:
|
||||
with torch.no_grad():
|
||||
grad_applied = fn(param_grad)
|
||||
g_should_use_set_data = compute_should_use_set_data(param_grad, grad_applied)
|
||||
if g_should_use_set_data:
|
||||
if should_use_swap_tensors:
|
||||
grad_applied.requires_grad_(param_grad.requires_grad)
|
||||
try:
|
||||
torch.utils.swap_tensors(param_grad, grad_applied)
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"_apply(): Couldn't swap {self._get_name()}.{key}.grad") from e
|
||||
out_param.grad = param_grad
|
||||
else:
|
||||
assert out_param.grad is not None
|
||||
out_param.grad.data = grad_applied
|
||||
if should_use_swap_tensors:
|
||||
grad_applied.requires_grad_(param_grad.requires_grad)
|
||||
try:
|
||||
torch.utils.swap_tensors(param_grad, grad_applied)
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"_apply(): Couldn't swap {self._get_name()}.{key}.grad") from e
|
||||
out_param.grad = param_grad
|
||||
elif g_should_use_set_data:
|
||||
assert out_param.grad is not None
|
||||
out_param.grad.data = grad_applied
|
||||
else:
|
||||
assert param_grad.is_leaf
|
||||
out_param.grad = grad_applied.requires_grad_(param_grad.requires_grad)
|
||||
|
@ -4295,7 +4295,9 @@ module_db: List[ModuleInfo] = [
|
||||
module_error_inputs_func=module_error_inputs_torch_nn_RNN_GRU,
|
||||
skips=(
|
||||
# RNNBase overrides `_apply` and adds weakrefs to params
|
||||
DecorateInfo(unittest.expectedFailure, 'TestModule', 'test_to', active_if=lambda p: p['swap']),),
|
||||
DecorateInfo(unittest.expectedFailure, 'TestModule', 'test_to', active_if=lambda p: p['swap']),
|
||||
# RNNBase overrides `_apply` and adds weakrefs to params
|
||||
DecorateInfo(unittest.expectedFailure, 'TestModule', 'test_to_empty', active_if=lambda p: p['swap']),),
|
||||
decorators=rnn_gru_lstm_module_info_decorators
|
||||
),
|
||||
ModuleInfo(torch.nn.GRU,
|
||||
@ -4304,7 +4306,9 @@ module_db: List[ModuleInfo] = [
|
||||
module_error_inputs_func=module_error_inputs_torch_nn_RNN_GRU,
|
||||
skips=(
|
||||
# RNNBase overrides `_apply` and adds weakrefs to params
|
||||
DecorateInfo(unittest.expectedFailure, 'TestModule', 'test_to', active_if=lambda p: p['swap']),),
|
||||
DecorateInfo(unittest.expectedFailure, 'TestModule', 'test_to', active_if=lambda p: p['swap']),
|
||||
# RNNBase overrides `_apply` and adds weakrefs to params
|
||||
DecorateInfo(unittest.expectedFailure, 'TestModule', 'test_to_empty', active_if=lambda p: p['swap']),),
|
||||
decorators=rnn_gru_lstm_module_info_decorators),
|
||||
ModuleInfo(torch.nn.LSTM,
|
||||
train_and_eval_differ=True,
|
||||
@ -4314,7 +4318,9 @@ module_db: List[ModuleInfo] = [
|
||||
# LSTM with projections is not currently supported with MPS
|
||||
DecorateInfo(skipMPS),
|
||||
# RNNBase overrides `_apply` and adds weakrefs to params
|
||||
DecorateInfo(unittest.expectedFailure, 'TestModule', 'test_to', active_if=lambda p: p['swap'])),
|
||||
DecorateInfo(unittest.expectedFailure, 'TestModule', 'test_to', active_if=lambda p: p['swap']),
|
||||
# RNNBase overrides `_apply` and adds weakrefs to params
|
||||
DecorateInfo(unittest.expectedFailure, 'TestModule', 'test_to_empty', active_if=lambda p: p['swap']),),
|
||||
decorators=rnn_gru_lstm_module_info_decorators),
|
||||
ModuleInfo(torch.nn.ReflectionPad1d,
|
||||
module_inputs_func=module_inputs_torch_nn_ReflectionPad1d,
|
||||
|
Reference in New Issue
Block a user