Added `torch.__future__.{get/set}_swap_module_params_on_conversion`, which defaults to `False` for now, but we will probably want to override this and default to `True` in `nn.Module._apply` when the input is a tensor subclass.
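A minimal usage sketch (not taken from the PR itself) of the new future flag, assuming a PyTorch build that includes this change:

```python
import torch
import torch.nn as nn

# Opt in to the swap_tensors path for module conversions (defaults to False).
torch.__future__.set_swap_module_params_on_conversion(True)
assert torch.__future__.get_swap_module_params_on_conversion()

m = nn.Linear(4, 4)
# With the flag enabled, nn.Module._apply replaces each parameter via
# torch.utils.swap_tensors instead of overwriting param.data in place.
m.to(torch.float64)
print(m.weight.dtype)  # torch.float64
```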
From offline discussion, for now we are **not** allowing `swap_tensors` after the first module forward has been run*** if the autograd graph is still alive. The reason is that `torch.utils.swap_tensors(t1, t2)` requires the `use_count` of both `TensorImpl`s associated with `t1` and `t2` to be 1. The first forward pass installs an `AccumulateGrad` node on each param, which [bumps the refcount of the associated TensorImpl](6cf1fc66e3/torch/csrc/autograd/variable.cpp (L307)). **Future work might be to swap the refs that the `AccumulateGrad` nodes hold if it is necessary.**
***From this, it might seem like we don't need to handle gradients. However, I still handle the grads for the edge case where the grads are set via `p.grad = grad`, or where the autograd graph is no longer alive because the output has been garbage collected.
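A hedged sketch of the restriction above; the exact exception type and the recovery step after the graph is freed are assumptions, not something the PR spells out:

```python
import torch
import torch.nn as nn

torch.__future__.set_swap_module_params_on_conversion(True)

m = nn.Linear(2, 2)
out = m(torch.randn(1, 2))  # first forward installs AccumulateGrad nodes on the params

try:
    # While `out`'s autograd graph is alive, the parameters' TensorImpls have
    # use_count > 1, so the swap_tensors-based conversion is expected to fail.
    m.to(torch.float64)
except RuntimeError as e:
    print("swap_tensors path failed:", e)

del out  # drop the autograd graph; the AccumulateGrad references go away
m.to(torch.float64)  # conversion via swap_tensors should now succeed
```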
If `swap_tensors` fails for any of the parameters in the `nn.Module`, we raise an error.
**`RNNBase` overrides `nn.Module._apply()` and installs weakrefs on some parameters. As a result, modules that inherit from `RNNBase` (`RNN`, `GRU`, and `LSTM`) cannot currently use the `swap_tensors` path.**
Pull Request resolved: https://github.com/pytorch/pytorch/pull/117167
Approved by: https://github.com/albanD
ghstack dependencies: #118028
torch.__future__
===================================

.. automodule:: torch.__future__
.. currentmodule:: torch.__future__

.. autofunction:: set_overwrite_module_params_on_conversion
.. autofunction:: get_overwrite_module_params_on_conversion
.. autofunction:: set_swap_module_params_on_conversion
.. autofunction:: get_swap_module_params_on_conversion