mirror of https://github.com/pytorch/pytorch.git
Summary:
https://github.com/pytorch/pytorch/pull/17072 breaks `model.to(xla_device)`, because moving `model` to an XLA device involves changing its parameters' TensorImpl type, and the current implementation of `nn.Module.to()` (which goes through `nn.Module._apply()`) doesn't support changing a module parameter's TensorImpl type:
```python
# 6dc445e1a8/torch/nn/modules/module.py (L192-L208)
def _apply(self, fn):
    ...
    for param in self._parameters.values():
        if param is not None:
            # Tensors stored in modules are graph leaves, and we don't
            # want to create copy nodes, so we have to unpack the data.
            param.data = fn(param.data)  # NOTE: this doesn't allow changing `param.data`'s TensorImpl type
            if param._grad is not None:
                param._grad.data = fn(param._grad.data)  # NOTE: this doesn't allow changing `param._grad.data`'s TensorImpl type
    ...
```
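For illustration only, here is a minimal sketch of a conversion loop that installs brand-new `Parameter` objects instead of writing through `param.data`. This is a hypothetical helper, not the code added by this PR; it just shows why assigning new tensors sidesteps the TensorImpl-type restriction:

```python
import torch
from torch.nn import Parameter


def apply_with_new_params(module, fn):
    # Sketch: replace each parameter with a brand-new Parameter wrapping fn(param),
    # instead of writing through `param.data`. Because a new tensor object is
    # installed, fn is free to return a tensor with a different TensorImpl type
    # (e.g. an XLA tensor), which `param.data = fn(param.data)` cannot express.
    for name, param in module._parameters.items():
        if param is None:
            continue
        with torch.no_grad():
            new_param = Parameter(fn(param), requires_grad=param.requires_grad)
            if param.grad is not None:
                new_param.grad = fn(param.grad)
        module._parameters[name] = new_param
    # Buffers are plain tensors, so they can simply be replaced.
    for buf_name, buf in module._buffers.items():
        if buf is not None:
            module._buffers[buf_name] = fn(buf)
    for child in module.children():
        apply_with_new_params(child, fn)
    return module


# Example: convert all parameters and buffers to float64.
# model = apply_with_new_params(model, lambda t: t.double())
```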
yf225 TODO: fix the description here when we finish the implementation
To fix this problem, we introduce a new API `model.to_()` that always assigns new tensors to the parameters (thus supporting changing a parameter to any TensorImpl type), and also bumps the version counter of the original parameters so that they are correctly invalidated in any autograd graph they participate in.
We also add a warning to the current `model.to()` API to inform users about its upcoming behavior change: in future releases, it will create and return a new model instead of updating the current model in place.
This unblocks adding XLA to our CI test suite, which also allows XLA to catch up with other changes in our codebase, notably the c10 dispatcher.
[xla ci]
cc. resistor ailzhang
Pull Request resolved: https://github.com/pytorch/pytorch/pull/21613
Differential Revision: D15895387
Pulled By: yf225
fbshipit-source-id: b79f230fb06019122a37fdf0711bf2130a016fe6
"""
|
|
This global flag controls whether to assign new tensors to the parameters
|
|
instead of changing the existing parameters in-place when converting an `nn.Module`
|
|
using the following methods:
|
|
1. `module.cuda()` / `.cpu()` (for moving `module` between devices)
|
|
2. `module.float()` / `.double()` / `.half()` (for converting `module` to a different dtype)
|
|
3. `module.to()` / `.type()` (for changing `module`'s device or dtype)
|
|
4. `module._apply(fn)` (for generic functions applied to `module`)
|
|
|
|
Default: False
|
|
"""
|
|
_overwrite_module_params_on_conversion = False
|
|
|
|
def set_overwrite_module_params_on_conversion(value):
|
|
global _overwrite_module_params_on_conversion
|
|
_overwrite_module_params_on_conversion = value
|
|
|
|
def get_overwrite_module_params_on_conversion():
|
|
return _overwrite_module_params_on_conversion
|
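A hedged usage sketch of the flag defined above. It assumes this file is exposed as `torch.__future__` (the module path is not shown in the snippet); adjust the import if the flag lives elsewhere:

```python
import torch
import torch.nn as nn
# Assumption: the setter/getter above are importable from `torch.__future__`.
from torch.__future__ import (
    get_overwrite_module_params_on_conversion,
    set_overwrite_module_params_on_conversion,
)

model = nn.Linear(4, 2)
old_weight = model.weight

# Opt in to the new behavior: conversions should install brand-new parameters.
set_overwrite_module_params_on_conversion(True)
try:
    model.double()  # convert parameters to float64
    print(get_overwrite_module_params_on_conversion())  # True
    # With the flag honored, `model.weight` is a new Parameter object rather
    # than the old one mutated in place.
    print(old_weight is model.weight)  # expected: False
finally:
    # Restore the documented default (False) for the rest of the program.
    set_overwrite_module_params_on_conversion(False)
```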