Add swap_tensors path to nn.Module._apply (#117167)

Added `torch.__future__.{get/set}_swap_module_params_on_conversion`, which defaults to `False` for now, but we will likely want `nn.Module._apply` to override this and default to `True` when the input is a tensor subclass.
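
As a quick illustration (the module and dtype below are arbitrary; only the `torch.__future__` calls are from this PR):

```python
import torch
import torch.nn as nn
from torch.__future__ import set_swap_module_params_on_conversion

# Opt in to the swap_tensors path for module conversions (the default is False).
set_swap_module_params_on_conversion(True)

m = nn.Linear(4, 4)
ids_before = [id(p) for p in m.parameters()]

m.half()  # _apply now swaps each Parameter's payload via torch.utils.swap_tensors

# The Parameter objects themselves are reused; only their contents were swapped.
assert [id(p) for p in m.parameters()] == ids_before
assert all(p.dtype == torch.float16 for p in m.parameters())
```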

From offline discussion, for now we are **not** allowing `swap_tensors` after the first module forward has been run*** if the autograd graph is still alive. The reason is that `torch.utils.swap_tensors(t1, t2)` requires the `use_count` of both `TensorImpl`s associated with `t1` and `t2` to be 1. The first forward pass will install `AccumulateGrad` nodes on each param, which [bump the refcount of the associated TensorImpl](6cf1fc66e3/torch/csrc/autograd/variable.cpp (L307)). **Future work might be to swap the refs that the `AccumulateGrad` nodes hold if it is necessary.**
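
For reference, a minimal sketch of what `torch.utils.swap_tensors` does on plain tensors (the values are arbitrary):

```python
import torch

t1 = torch.ones(3)
t2 = torch.zeros(3)

# swap_tensors requires use_count == 1 on both TensorImpls; extra C++ references
# (e.g. AccumulateGrad nodes installed by a forward pass) would make it raise.
torch.utils.swap_tensors(t1, t2)

print(t1)  # tensor([0., 0., 0.])  -- t1 now holds what t2 used to hold
print(t2)  # tensor([1., 1., 1.])
```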

***From this, it might seem like we don't need to handle gradients. However, the grads are still handled for the edge cases where grads are set directly via `p.grad = grad`, or where the autograd graph is no longer alive because the output has been garbage collected.
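
A small sketch of that edge case, assuming an arbitrary module and dtype: grads assigned by hand, with no autograd graph ever built, are expected to be carried through the swap-based conversion as well.

```python
import torch
import torch.nn as nn
from torch.__future__ import set_swap_module_params_on_conversion

set_swap_module_params_on_conversion(True)

m = nn.Linear(2, 2)
# Assign grads by hand; no forward/backward has run, so no AccumulateGrad nodes exist.
for p in m.parameters():
    p.grad = torch.zeros_like(p)

m.double()  # the swap path is expected to convert the manually set grads as well
assert all(p.grad.dtype == torch.float64 for p in m.parameters())
```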

If `swap_tensors` fails for any of the parameters in the `nn.Module`, we raise an error.
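
For example, converting while the autograd graph from a forward pass is still alive should fail; the exact error type is not spelled out here, so `RuntimeError` below is an assumption:

```python
import torch
import torch.nn as nn
from torch.__future__ import set_swap_module_params_on_conversion

set_swap_module_params_on_conversion(True)

m = nn.Linear(2, 2)
out = m(torch.randn(1, 2)).sum()  # the forward pass installs AccumulateGrad nodes

# While `out` (and hence the autograd graph) is alive, the params' TensorImpls
# have use_count > 1, so the swap cannot succeed and the conversion errors out.
try:
    m.double()
except RuntimeError:  # assumed error type; the PR only says "an error"
    print("conversion rejected while the autograd graph is alive")
```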

**`RNNBase` overrides `nn.Module._apply()` and installs weakrefs on some parameters. As a result, all modules that inherit from `RNNBase` (`RNN`, `GRU`, and `LSTM`) cannot use the `swap_tensors` path for now.**

Pull Request resolved: https://github.com/pytorch/pytorch/pull/117167
Approved by: https://github.com/albanD
ghstack dependencies: #118028
Authored by Mikayla Gawarecki on 2024-02-05 14:14:52 -08:00; committed by PyTorch MergeBot
parent 91d1d2c421
commit d5a718d27b
7 changed files with 208 additions and 30 deletions


@@ -1,21 +1,60 @@
-"""
-This global flag controls whether to assign new tensors to the parameters
-instead of changing the existing parameters in-place when converting an `nn.Module`
-using the following methods:
-1. `module.cuda()` / `.cpu()` (for moving `module` between devices)
-2. `module.float()` / `.double()` / `.half()` (for converting `module` to a different dtype)
-3. `module.to()` / `.type()` (for changing `module`'s device or dtype)
-4. `module._apply(fn)` (for generic functions applied to `module`)
-Default: False
-"""
-_overwrite_module_params_on_conversion = False
+_overwrite_module_params_on_conversion: bool = False
+_swap_module_params_on_conversion: bool = False


-def set_overwrite_module_params_on_conversion(value):
+def set_overwrite_module_params_on_conversion(value: bool) -> None:
+    """
+    Sets whether to assign new tensors to the parameters instead of changing the
+    existing parameters in-place when converting an ``nn.Module``.
+
+    When enabled, the following methods will assign new parameters to the module:
+
+    #. ``module.{device}()`` (e.g. ``module.cuda()``) for moving a module between devices
+    #. ``module.{dtype}()`` (e.g. ``module.float()``) for converting a module to a different dtype
+    #. ``module.to()``
+    """
     global _overwrite_module_params_on_conversion
     _overwrite_module_params_on_conversion = value


-def get_overwrite_module_params_on_conversion():
+def get_overwrite_module_params_on_conversion() -> bool:
+    """
+    Returns whether to assign new tensors to the parameters instead of changing the
+    existing parameters in-place when converting an ``nn.Module``. Defaults to ``False``.
+
+    See :func:`~torch.nn.utils.set_overwrite_module_params_on_conversion` for more information.
+    """
     return _overwrite_module_params_on_conversion
+
+
+def set_swap_module_params_on_conversion(value: bool) -> None:
+    """
+    Sets whether to use :func:`~torch.utils.swap_tensors` instead of setting ``.data`` to
+    change the existing parameters in-place when converting an ``nn.Module``.
+
+    .. note::
+        If :func:`~torch.__future__.get_overwrite_module_params_on_conversion` returns ``True``,
+        no swapping will occur.
+
+    When enabled, the following methods will swap the existing parameters in-place:
+
+    #. ``module.{device}()`` (e.g. ``module.cuda()``) for moving a module between devices
+    #. ``module.{dtype}()`` (e.g. ``module.float()``) for converting a module to a different dtype
+    #. ``module.to()``
+    """
+    global _swap_module_params_on_conversion
+    _swap_module_params_on_conversion = value
+
+
+def get_swap_module_params_on_conversion() -> bool:
+    """
+    Returns whether to use :func:`~torch.utils.swap_tensors` instead of setting ``.data`` to
+    change the existing parameters in-place when converting an ``nn.Module``. Defaults to ``False``.
+
+    See :func:`~torch.nn.utils.set_swap_module_params_on_conversion` for more information.
+    """
+    return _swap_module_params_on_conversion
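
As a usage note for the functions above (not part of this PR, just a sketch): one pattern is to enable the swap path only around a specific conversion and restore the previous setting afterwards.

```python
import torch
import torch.nn as nn
from torch.__future__ import (
    get_swap_module_params_on_conversion,
    set_swap_module_params_on_conversion,
)

model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 2))

# Enable the swap path only for this conversion, then restore the previous setting.
prev = get_swap_module_params_on_conversion()
set_swap_module_params_on_conversion(True)
try:
    model.to(torch.bfloat16)  # parameters are swapped in place via torch.utils.swap_tensors
finally:
    set_swap_module_params_on_conversion(prev)
```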