Add swap_tensors path to nn.Module._apply (#117167)
Added `torch.__future__.{get/set}_swap_module_params_on_conversion`, which defaults to `False` for now; we will probably want to override this and default to `True` in `nn.Module._apply` when the input is a tensor subclass.
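As a rough usage sketch (the `nn.Linear` module and the `half()` conversion below are arbitrary choices for illustration, not part of this PR):

```python
import torch
import torch.nn as nn

# Opt in to the swap_tensors path for module conversion methods.
# Defaults to False, so existing behavior is unchanged unless enabled.
torch.__future__.set_swap_module_params_on_conversion(True)

m = nn.Linear(4, 4)
param_before = m.weight

# With the flag set, _apply (reached via .to(), .half(), .cuda(), ...) swaps the
# parameter tensors via torch.utils.swap_tensors instead of setting .data.
m.half()

# The Parameter object should be preserved; only its underlying TensorImpl was swapped.
print(m.weight is param_before, m.weight.dtype)
```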
From offline discussion, for now we are **not** allowing `swap_tensors` after the first module forward has been run*** if the autograd graph is still alive. The reason is that `torch.utils.swap_tensors(t1, t2)` requires the `use_count` of both `TensorImpl`s associated with `t1` and `t2` to be 1, while the first forward pass installs `AccumulateGrad` nodes on each param, which [bump the refcount of the associated TensorImpl](6cf1fc66e3/torch/csrc/autograd/variable.cpp#L307). **Future work might be to swap the refs that the `AccumulateGrad` nodes hold if it is necessary.**
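For context, a minimal sketch of the `torch.utils.swap_tensors` primitive this path builds on, using plain tensors so the `use_count == 1` requirement is trivially met:

```python
import torch

t1 = torch.zeros(3)
t2 = torch.ones(3)

# Swaps the two tensors "from the outside": after the call, every existing
# Python reference to t1 observes t2's data/metadata and vice versa.
torch.utils.swap_tensors(t1, t2)

print(t1)  # tensor([1., 1., 1.])
print(t2)  # tensor([0., 0., 0.])

# This only works while each TensorImpl's use_count is 1; once something like an
# AccumulateGrad node holds an extra reference (after a first forward with the
# autograd graph still alive), the swap is rejected.
```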
***From this, it might seem like we don't need to handle gradients. However, I still handle the grads for the edge cases where the grads are set via `p.grad = grad` or the autograd graph is no longer alive because the output has been garbage collected. See the sketch below.
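A small sketch of that `p.grad = grad` edge case, assuming the conversion also applies to the manually installed `.grad` tensors (the module and target dtype are arbitrary):

```python
import torch
import torch.nn as nn

torch.__future__.set_swap_module_params_on_conversion(True)

m = nn.Linear(2, 2)
# Edge case from above: grads installed directly, with no autograd graph alive.
for p in m.parameters():
    p.grad = torch.zeros_like(p)

# The conversion should handle the .grad tensors as well as the parameters.
m.double()
print(m.weight.dtype, m.weight.grad.dtype)  # expected: torch.float64 torch.float64
```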
If `swap_tensors` fails on any of the parameters in the `nn.Module`, we raise an error.
**`RNNBase` overrides `nn.Module._apply()` and installs weakrefs on some parameters. As a result, all modules that inherit from `RNNBase` (`RNN`, `GRU`, and `LSTM`) cannot use the `swap_tensors` path as of now.**
Pull Request resolved: https://github.com/pytorch/pytorch/pull/117167
Approved by: https://github.com/albanD
ghstack dependencies: #118028
Committed by: PyTorch MergeBot
Parent: 91d1d2c421
Commit: d5a718d27b
torch/__future__.py:

```diff
@@ -1,21 +1,60 @@
-"""
-This global flag controls whether to assign new tensors to the parameters
-instead of changing the existing parameters in-place when converting an `nn.Module`
-using the following methods:
-1. `module.cuda()` / `.cpu()` (for moving `module` between devices)
-2. `module.float()` / `.double()` / `.half()` (for converting `module` to a different dtype)
-3. `module.to()` / `.type()` (for changing `module`'s device or dtype)
-4. `module._apply(fn)` (for generic functions applied to `module`)
-
-Default: False
-"""
-_overwrite_module_params_on_conversion = False
+_overwrite_module_params_on_conversion: bool = False
+_swap_module_params_on_conversion: bool = False
 
 
-def set_overwrite_module_params_on_conversion(value):
+def set_overwrite_module_params_on_conversion(value: bool) -> None:
+    """
+    Sets whether to assign new tensors to the parameters instead of changing the
+    existing parameters in-place when converting an ``nn.Module``.
+
+    When enabled, the following methods will assign new parameters to the module:
+
+    #. ``module.{device}()`` (e.g. ``module.cuda()``) for moving a module between devices
+    #. ``module.{dtype}()`` (e.g. ``module.float()``) for converting a module to a different dtype
+    #. ``module.to()``
+
+    """
     global _overwrite_module_params_on_conversion
     _overwrite_module_params_on_conversion = value
 
 
-def get_overwrite_module_params_on_conversion():
+def get_overwrite_module_params_on_conversion() -> bool:
+    """
+    Returns whether to assign new tensors to the parameters instead of changing the
+    existing parameters in-place when converting an ``nn.Module``. Defaults to ``False``.
+
+    See :func:`~torch.__future__.set_overwrite_module_params_on_conversion` for more information.
+    """
     return _overwrite_module_params_on_conversion
+
+
+def set_swap_module_params_on_conversion(value: bool) -> None:
+    """
+    Sets whether to use :func:`~torch.utils.swap_tensors` instead of setting ``.data`` to
+    change the existing parameters in-place when converting an ``nn.Module``.
+
+    .. note::
+        If :func:`~torch.__future__.get_overwrite_module_params_on_conversion` returns ``True``,
+        no swapping will occur.
+
+    When enabled, the following methods will swap the existing parameters in-place:
+
+    #. ``module.{device}()`` (e.g. ``module.cuda()``) for moving a module between devices
+    #. ``module.{dtype}()`` (e.g. ``module.float()``) for converting a module to a different dtype
+    #. ``module.to()``
+
+    """
+    global _swap_module_params_on_conversion
+    _swap_module_params_on_conversion = value
+
+
+def get_swap_module_params_on_conversion() -> bool:
+    """
+    Returns whether to use :func:`~torch.utils.swap_tensors` instead of setting ``.data`` to
+    change the existing parameters in-place when converting an ``nn.Module``. Defaults to ``False``.
+
+    See :func:`~torch.__future__.set_swap_module_params_on_conversion` for more information.
+    """
+    return _swap_module_params_on_conversion
```
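Below is a hedged sketch of how the two flags interact per the note in the docstring above; `conversion_mode` is a hypothetical helper for illustration only, not the actual `nn.Module._apply` logic:

```python
import torch

def conversion_mode() -> str:
    """Illustrative helper: report which conversion strategy would apply."""
    if torch.__future__.get_overwrite_module_params_on_conversion():
        # Overwrite takes precedence: new tensors are assigned, no swapping occurs.
        return "overwrite (assign new parameters)"
    if torch.__future__.get_swap_module_params_on_conversion():
        return "swap (torch.utils.swap_tensors)"
    return "in-place (set .data)"

torch.__future__.set_swap_module_params_on_conversion(True)
print(conversion_mode())  # swap (torch.utils.swap_tensors)

torch.__future__.set_overwrite_module_params_on_conversion(True)
print(conversion_mode())  # overwrite (assign new parameters)
```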