mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Integrate swap_tensors into nn.Module.load_state_dict (#117913)
Added a `torch.Tensor` method that defines how to transform `other`, a value in the state dictionary, to be loaded into `self`, a param/buffer in an `nn.Module` before swapping via `torch.utils.swap_tensors` * `param.module_load(sd[key])` This method can be overridden using `__torch_function__`. Pull Request resolved: https://github.com/pytorch/pytorch/pull/117913 Approved by: https://github.com/albanD
This commit is contained in:
committed by
PyTorch MergeBot
parent
a7f82b7d62
commit
3372aa51b4
@ -4,7 +4,7 @@ _swap_module_params_on_conversion: bool = False
|
||||
|
||||
def set_overwrite_module_params_on_conversion(value: bool) -> None:
|
||||
"""
|
||||
Sets whether to assign new tensors to the parameters instead of changing the
|
||||
Sets whether to assign new tensors to the parameters instead of changing the
|
||||
existing parameters in-place when converting an ``nn.Module``.
|
||||
|
||||
When enabled, the following methods will assign new parameters to the module:
|
||||
@ -14,6 +14,9 @@ def set_overwrite_module_params_on_conversion(value: bool) -> None:
|
||||
(for converting a module to a different dtype)
|
||||
#. ``module.to()``
|
||||
|
||||
Args:
|
||||
value (bool): Whether to assign new tensors or not.
|
||||
|
||||
"""
|
||||
global _overwrite_module_params_on_conversion
|
||||
_overwrite_module_params_on_conversion = value
|
||||
@ -22,9 +25,9 @@ def set_overwrite_module_params_on_conversion(value: bool) -> None:
|
||||
def get_overwrite_module_params_on_conversion() -> bool:
|
||||
"""
|
||||
Returns whether to assign new tensors to the parameters instead of changing the
|
||||
existing parameters in-place when converting an ``nn.Module`. Defaults to ``False``.
|
||||
existing parameters in-place when converting an ``nn.Module``. Defaults to ``False``.
|
||||
|
||||
See :func:`~torch.nn.utils.set_overwrite_module_params_on_conversion` for more information.
|
||||
See :func:`~torch.__future__.set_overwrite_module_params_on_conversion` for more information.
|
||||
"""
|
||||
return _overwrite_module_params_on_conversion
|
||||
|
||||
@ -32,11 +35,12 @@ def get_overwrite_module_params_on_conversion() -> bool:
|
||||
def set_swap_module_params_on_conversion(value: bool) -> None:
|
||||
"""
|
||||
Sets whether to use :func:`~torch.utils.swap_tensors` instead of setting ``.data`` to
|
||||
change the existing parameters in-place when converting an ``nn.Module``.
|
||||
change the existing parameters in-place when converting an ``nn.Module`` and instead
|
||||
of ``param.copy_(state_dict[key])`` when loading a state dict into an ``nn.Module``.
|
||||
|
||||
.. note::
|
||||
If :func:`~torch.__future__.get_overwrite_module_params_on_conversion` returns ``True``,
|
||||
no swapping will occur.
|
||||
for methods other than :meth:`~nn.Module.load_state_dict` no swapping will occur.
|
||||
|
||||
When enabled, the following methods will swap the existing parameters in-place:
|
||||
|
||||
@ -44,6 +48,18 @@ def set_swap_module_params_on_conversion(value: bool) -> None:
|
||||
#. ``module.{dtype}()`` (e.g. ``module.float()``) for converting a module to a different dtype
|
||||
(for converting a module to a different dtype)
|
||||
#. ``module.to()``
|
||||
#. ``module.load_state_dict(state_dict)``
|
||||
|
||||
The semantics for :meth:`~nn.Module.load_state_dict` when this is set are as follows:
|
||||
|
||||
#. For each parameter/buffer, its corresponding``state_dict['key']`` is transformed via
|
||||
:meth:`~torch.Tensor.module_load` (i.e. ``res = param.module_load(state_dict['key'])``)
|
||||
#. If necessary, ``res`` will be wrapped in an :class:`~nn.Parameter`
|
||||
#. The parameter/buffer in the module will be swapped via :func:`~torch.utils.swap_tensors`
|
||||
with ``res``
|
||||
|
||||
Args:
|
||||
value (bool): Whether to use :func:`~torch.utils.swap_tensors` or not.
|
||||
|
||||
"""
|
||||
global _swap_module_params_on_conversion
|
||||
@ -53,8 +69,8 @@ def set_swap_module_params_on_conversion(value: bool) -> None:
|
||||
def get_swap_module_params_on_conversion() -> bool:
|
||||
"""
|
||||
Returns whether to use :func:`~torch.utils.swap_tensors` instead of setting .data to
|
||||
change the existing parameters in-place when converting an nn.Module. Defaults to ``False``.
|
||||
change the existing parameters in-place when converting an ``nn.Module``. Defaults to ``False``.
|
||||
|
||||
See :func:`~torch.nn.utils.set_swap_module_params_on_conversion` for more information.
|
||||
See :func:`~torch.__future__.set_swap_module_params_on_conversion` for more information.
|
||||
"""
|
||||
return _swap_module_params_on_conversion
|
||||
|
Reference in New Issue
Block a user