pytorch/torch/__future__.py
Will Feng 6b972795e4 Add torch.__future__._overwrite_module_params_on_conversion global flag, and check it in nn.Module._apply() (#21613)
Summary:
https://github.com/pytorch/pytorch/pull/17072 breaks `model.to(xla_device)`: moving `model` to an XLA device involves changing its parameters' TensorImpl type, which the current implementation of `nn.Module.to()` doesn't support:
```python
# 6dc445e1a8/torch/nn/modules/module.py (L192-L208)
def _apply(self, fn):
    ...
    for param in self._parameters.values():
        if param is not None:
            # Tensors stored in modules are graph leaves, and we don't
            # want to create copy nodes, so we have to unpack the data.
            param.data = fn(param.data)  # NOTE: this doesn't allow changing `param.data`'s TensorImpl type
            if param._grad is not None:
                param._grad.data = fn(param._grad.data)  # NOTE: this doesn't allow changing `param._grad.data`'s TensorImpl type
    ...
```
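For illustration (a hypothetical snippet, not part of the PR): the `.data` setter rejects a tensor whose TensorImpl type differs from the parameter's, which is exactly the kind of change an XLA conversion needs. A dense-to-sparse swap shows the same failure mode:
```python
import torch

p = torch.nn.Parameter(torch.randn(3, 3))
try:
    # `to_sparse()` returns a tensor backed by SparseTensorImpl, which is
    # incompatible with the dense TensorImpl backing `p`, so the `.data`
    # setter raises instead of swapping it in.
    p.data = p.data.to_sparse()
except RuntimeError as err:
    print("incompatible TensorImpl:", err)
```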

To fix this problem, we introduce a global flag, `torch.__future__._overwrite_module_params_on_conversion` (toggled via `torch.__future__.set_overwrite_module_params_on_conversion()`), and check it in `nn.Module._apply()`. When the flag is set, `_apply()` assigns new tensors to the parameters (thus supporting changing the parameters to any TensorImpl type), and also bumps the version counter of the original parameters correctly, so that they are invalidated in any autograd graph they participate in.
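Concretely, the flag-gated conversion could look like the following sketch. This is a simplified, hypothetical stand-in for the real `nn.Module._apply()`: the helper name `convert_module_params` is invented, and the version-counter bump is only noted in a comment.
```python
import torch
from torch.__future__ import get_overwrite_module_params_on_conversion


def convert_module_params(module, fn):
    """Hypothetical, simplified stand-in for nn.Module._apply()."""
    for child in module.children():
        convert_module_params(child, fn)

    for key, param in module._parameters.items():
        if param is None:
            continue
        if get_overwrite_module_params_on_conversion():
            # Overwrite path: wrap fn's output in a brand-new Parameter, so
            # fn may return a tensor with a different TensorImpl type (e.g.
            # an XLA tensor). The actual PR also bumps the original
            # parameter's version counter so autograd invalidates it;
            # omitted here for brevity.
            with torch.no_grad():
                new_param = torch.nn.Parameter(
                    fn(param), requires_grad=param.requires_grad
                )
            module._parameters[key] = new_param
        else:
            # Legacy in-place path: cannot change the TensorImpl type.
            param.data = fn(param.data)
            if param._grad is not None:
                param._grad.data = fn(param._grad.data)
    return module
```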

We also add a warning to the current `model.to()` API to inform users about its upcoming behavior change: in future releases, it will create and return a new model instead of updating the current model in place.

This unblocks adding XLA to our CI test suite, which also allows XLA to catch up with other changes in our codebase, notably the c10 dispatcher.

[xla ci]

cc. resistor ailzhang
Pull Request resolved: https://github.com/pytorch/pytorch/pull/21613

Differential Revision: D15895387

Pulled By: yf225

fbshipit-source-id: b79f230fb06019122a37fdf0711bf2130a016fe6
2019-06-19 10:30:02 -07:00

"""
This global flag controls whether to assign new tensors to the parameters
instead of changing the existing parameters in-place when converting an `nn.Module`
using the following methods:
1. `module.cuda()` / `.cpu()` (for moving `module` between devices)
2. `module.float()` / `.double()` / `.half()` (for converting `module` to a different dtype)
3. `module.to()` / `.type()` (for changing `module`'s device or dtype)
4. `module._apply(fn)` (for generic functions applied to `module`)
Default: False
"""
_overwrite_module_params_on_conversion = False
def set_overwrite_module_params_on_conversion(value):
global _overwrite_module_params_on_conversion
_overwrite_module_params_on_conversion = value
def get_overwrite_module_params_on_conversion():
return _overwrite_module_params_on_conversion
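
For reference, a hypothetical usage sketch of the new flag (the model and conversion below are arbitrary examples, not from the PR):
```python
import torch
from torch.__future__ import (
    get_overwrite_module_params_on_conversion,
    set_overwrite_module_params_on_conversion,
)

# Opt in to the new behavior: conversions assign fresh parameter tensors.
set_overwrite_module_params_on_conversion(True)
assert get_overwrite_module_params_on_conversion()

model = torch.nn.Linear(4, 2)
old_weight = model.weight
model = model.double()  # parameters are replaced, not mutated in-place
assert model.weight is not old_weight
```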