Update nn.Module._apply to not gate on should_use_set_data when swap_tensors is set (#120659)

This updates the nesting of if statements in `nn.Module._apply` such that if

`torch.__future__.set_swap_module_params_on_conversion(True)`, we always try to swap regardless of whether
- `torch._has_compatible_shallow_copy_type(param, fn(param)`
- `torch.__future__.set_overwrite_module_params_on_conversion` is set

This means that `meta_module.to_empty('device')` can now use the swap_tensors path cc @awgu

Pull Request resolved: https://github.com/pytorch/pytorch/pull/120659
Approved by: https://github.com/albanD
This commit is contained in:
Mikayla Gawarecki
2024-02-26 14:20:21 -08:00
committed by PyTorch MergeBot
parent 213b3ac3f2
commit 677e67c399
4 changed files with 89 additions and 46 deletions

View File

@ -869,7 +869,6 @@ class TestModule(TestCase):
for module_input in module_inputs:
c_args, c_kwargs = module_input.constructor_input.args, module_input.constructor_input.kwargs
fw_args, fw_kwargs = module_input.forward_input.args, module_input.forward_input.kwargs
m = module_cls(*c_args, **c_kwargs)
@ -904,7 +903,7 @@ class TestModule(TestCase):
m.to(device=device_, dtype=dtype_)
self.assertTrue(isinstance(p, torch.nn.Parameter) for p in m.parameters())
self.assertTrue(all(isinstance(p, torch.nn.Parameter) for p in m.parameters()))
self.assertTrue(all(p.device.type == device_ for p in m.parameters()))
self.assertTrue(all(p.dtype == dtype_ for p in m.parameters()))
p_ids_after = [id(p) for p in m.parameters()]
@ -932,6 +931,47 @@ class TestModule(TestCase):
self.assertTrue(all(a == b for a, b in zip(g_ids_before, g_ids_after)))
@modules([module for module in module_db if not module.is_lazy], allowed_dtypes=[torch.float32])
@parametrize('swap', [True, False])
@wrapSwapTensorsTest()
def test_to_empty(self, device, dtype, module_info, swap, training):
module_cls = module_info.module_cls
with torch.device("meta"):
module_inputs = module_info.module_inputs_func(module_info, device=None, dtype=dtype,
requires_grad=False, training=training)
torch.__future__.set_swap_module_params_on_conversion(swap)
device_ = torch.device(device)
for module_input in module_inputs:
c_args, c_kwargs = module_input.constructor_input.args, module_input.constructor_input.kwargs
with torch.device("meta"):
m = module_cls(*c_args, **c_kwargs)
p_ids_before = [id(p) for p in m.parameters()]
p_cdatas_before = [p._cdata for p in m.parameters()]
m.to_empty(device=device_)
self.assertTrue(all(isinstance(p, torch.nn.Parameter) for p in m.parameters()))
self.assertTrue(all(p.device == device_ for p in m.parameters()))
self.assertTrue(all(p.dtype == dtype for p in m.parameters()))
p_ids_after = [id(p) for p in m.parameters()]
p_cdatas_after = [p._cdata for p in m.parameters()]
if swap:
# id same, ._cdata differs --> swapped cdata of THPVariable
self.assertTrue(all(a == b for a, b in zip(p_ids_before, p_ids_after)))
self.assertTrue(all(a != b for a, b in zip(p_cdatas_before, p_cdatas_after)))
else:
# id and ._cdata differ
# meta and device have different shallow copy types, so this will create a new
# parameter and assign it to the module
self.assertTrue(all(a != b for a, b in zip(p_ids_before, p_ids_after)))
self.assertTrue(all(a != b for a, b in zip(p_cdatas_before, p_cdatas_after)))
instantiate_device_type_tests(TestModule, globals(), allow_mps=True)
if __name__ == '__main__':

View File

@ -9,10 +9,10 @@ def set_overwrite_module_params_on_conversion(value: bool) -> None:
When enabled, the following methods will assign new parameters to the module:
#. ``module.{device}()`` (e.g. ``module.cuda()``) for moving a module between devices
#. ``module.{dtype}()`` (e.g. ``module.float()``) for converting a module to a different dtype
(for converting a module to a different dtype)
#. ``module.to()``
#. ``module.{device}()`` (e.g. :meth:`nn.Module.cuda()`) for moving a module between devices
#. ``module.{dtype}()`` (e.g. :meth:`nn.Module.float()`) for converting a module to a different dtype
#. :meth:`nn.Module.to`
#. :meth:`nn.Module.to_empty`
Args:
value (bool): Whether to assign new tensors or not.
@ -25,7 +25,7 @@ def set_overwrite_module_params_on_conversion(value: bool) -> None:
def get_overwrite_module_params_on_conversion() -> bool:
"""
Returns whether to assign new tensors to the parameters instead of changing the
existing parameters in-place when converting an ``nn.Module``. Defaults to ``False``.
existing parameters in-place when converting an :class:`torch.nn.Module`. Defaults to ``False``.
See :func:`~torch.__future__.set_overwrite_module_params_on_conversion` for more information.
"""
@ -39,20 +39,19 @@ def set_swap_module_params_on_conversion(value: bool) -> None:
of ``param.copy_(state_dict[key])`` when loading a state dict into an ``nn.Module``.
.. note::
If :func:`~torch.__future__.get_overwrite_module_params_on_conversion` returns ``True``,
for methods other than :meth:`~nn.Module.load_state_dict` no swapping will occur.
This function takes precedence over :func:`~torch.__future__.get_overwrite_module_params_on_conversion`
When enabled, the following methods will swap the existing parameters in-place:
#. ``module.{device}()`` (e.g. ``module.cuda()``) for moving a module between devices
#. ``module.{dtype}()`` (e.g. ``module.float()``) for converting a module to a different dtype
(for converting a module to a different dtype)
#. ``module.to()``
#. ``module.load_state_dict(state_dict)``
#. ``module.{device}()`` (e.g. :meth:`nn.Module.cuda()`) for moving a module between devices
#. ``module.{dtype}()`` (e.g. :meth:`nn.Module.float()`) for converting a module to a different dtype
#. :meth:`nn.Module.to`
#. :meth:`nn.Module.to_empty`
#. :meth:`nn.Module.load_state_dict`
The semantics for :meth:`~nn.Module.load_state_dict` when this is set are as follows:
#. For each parameter/buffer, its corresponding``state_dict['key']`` is transformed via
#. For each parameter/buffer, its corresponding ``state_dict['key']`` is transformed via
:meth:`~torch.Tensor.module_load` (i.e. ``res = param.module_load(state_dict['key'])``)
#. If necessary, ``res`` will be wrapped in an :class:`~nn.Parameter`
#. The parameter/buffer in the module will be swapped via :func:`~torch.utils.swap_tensors`

View File

@ -803,23 +803,22 @@ class Module:
param_applied = fn(param)
p_should_use_set_data = compute_should_use_set_data(param, param_applied)
param_grad = param.grad
if p_should_use_set_data:
if should_use_swap_tensors:
try:
if param_grad is not None:
# Accessing param.grad makes its at::Tensor's use_count 2, which will prevent swapping.
# Decrement use count of the gradient by setting to None
param.grad = None
param_applied = torch.nn.Parameter(param_applied, requires_grad=param.requires_grad)
torch.utils.swap_tensors(param, param_applied)
except Exception as e:
if param_grad is not None:
param.grad = param_grad
raise RuntimeError(f"_apply(): Couldn't swap {self._get_name()}.{key}") from e
out_param = param
else:
param.data = param_applied
out_param = param
if should_use_swap_tensors:
try:
if param_grad is not None:
# Accessing param.grad makes its at::Tensor's use_count 2, which will prevent swapping.
# Decrement use count of the gradient by setting to None
param.grad = None
param_applied = torch.nn.Parameter(param_applied, requires_grad=param.requires_grad)
torch.utils.swap_tensors(param, param_applied)
except Exception as e:
if param_grad is not None:
param.grad = param_grad
raise RuntimeError(f"_apply(): Couldn't swap {self._get_name()}.{key}") from e
out_param = param
elif p_should_use_set_data:
param.data = param_applied
out_param = param
else:
assert isinstance(param, Parameter)
assert param.is_leaf
@ -830,17 +829,16 @@ class Module:
with torch.no_grad():
grad_applied = fn(param_grad)
g_should_use_set_data = compute_should_use_set_data(param_grad, grad_applied)
if g_should_use_set_data:
if should_use_swap_tensors:
grad_applied.requires_grad_(param_grad.requires_grad)
try:
torch.utils.swap_tensors(param_grad, grad_applied)
except Exception as e:
raise RuntimeError(f"_apply(): Couldn't swap {self._get_name()}.{key}.grad") from e
out_param.grad = param_grad
else:
assert out_param.grad is not None
out_param.grad.data = grad_applied
if should_use_swap_tensors:
grad_applied.requires_grad_(param_grad.requires_grad)
try:
torch.utils.swap_tensors(param_grad, grad_applied)
except Exception as e:
raise RuntimeError(f"_apply(): Couldn't swap {self._get_name()}.{key}.grad") from e
out_param.grad = param_grad
elif g_should_use_set_data:
assert out_param.grad is not None
out_param.grad.data = grad_applied
else:
assert param_grad.is_leaf
out_param.grad = grad_applied.requires_grad_(param_grad.requires_grad)

View File

@ -4295,7 +4295,9 @@ module_db: List[ModuleInfo] = [
module_error_inputs_func=module_error_inputs_torch_nn_RNN_GRU,
skips=(
# RNNBase overrides `_apply` and adds weakrefs to params
DecorateInfo(unittest.expectedFailure, 'TestModule', 'test_to', active_if=lambda p: p['swap']),),
DecorateInfo(unittest.expectedFailure, 'TestModule', 'test_to', active_if=lambda p: p['swap']),
# RNNBase overrides `_apply` and adds weakrefs to params
DecorateInfo(unittest.expectedFailure, 'TestModule', 'test_to_empty', active_if=lambda p: p['swap']),),
decorators=rnn_gru_lstm_module_info_decorators
),
ModuleInfo(torch.nn.GRU,
@ -4304,7 +4306,9 @@ module_db: List[ModuleInfo] = [
module_error_inputs_func=module_error_inputs_torch_nn_RNN_GRU,
skips=(
# RNNBase overrides `_apply` and adds weakrefs to params
DecorateInfo(unittest.expectedFailure, 'TestModule', 'test_to', active_if=lambda p: p['swap']),),
DecorateInfo(unittest.expectedFailure, 'TestModule', 'test_to', active_if=lambda p: p['swap']),
# RNNBase overrides `_apply` and adds weakrefs to params
DecorateInfo(unittest.expectedFailure, 'TestModule', 'test_to_empty', active_if=lambda p: p['swap']),),
decorators=rnn_gru_lstm_module_info_decorators),
ModuleInfo(torch.nn.LSTM,
train_and_eval_differ=True,
@ -4314,7 +4318,9 @@ module_db: List[ModuleInfo] = [
# LSTM with projections is not currently supported with MPS
DecorateInfo(skipMPS),
# RNNBase overrides `_apply` and adds weakrefs to params
DecorateInfo(unittest.expectedFailure, 'TestModule', 'test_to', active_if=lambda p: p['swap'])),
DecorateInfo(unittest.expectedFailure, 'TestModule', 'test_to', active_if=lambda p: p['swap']),
# RNNBase overrides `_apply` and adds weakrefs to params
DecorateInfo(unittest.expectedFailure, 'TestModule', 'test_to_empty', active_if=lambda p: p['swap']),),
decorators=rnn_gru_lstm_module_info_decorators),
ModuleInfo(torch.nn.ReflectionPad1d,
module_inputs_func=module_inputs_torch_nn_ReflectionPad1d,