## Description:
This PR refactors the autocast context manager in `autocast_mode.py` to simplify and centralize the logic for checking supported dtypes for each device. The previous implementation repeated similar checks for multiple device types. Now, a single mapping `device_supported_dtypes` is used to associate device types with their supported dtypes, and the validation logic is unified.
In my view, this makes the code easier to maintain and extend for new devices.
Please share any suggestions and comments with me.
BTW, in the original `xla` branch, the `supported_dtype` are `[torch.float16, torch.bfloat16]`, 5d8a226e23/torch/amp/autocast_mode.py (L358-L363) but the warning message has only `torch.bfloat16`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/163446
Approved by: https://github.com/FFFrog, https://github.com/albanD
Print out amp target dtype and let custom backend easier find out expected dtype while integration.
## Test Result
### Before
```python
In [1]: import torch
...: import torch_openreg
...:
...: a = torch.randn(3, 4)
...: b = torch.randn(4, 2)
...: with torch.autocast("openreg", dtype=torch.float16):
...: torch.mm(a, b)
...:
/home/coder/code/pytorch/torch/amp/autocast_mode.py:332: UserWarning: In openreg autocast, but the target dtype is not supported. Disabling autocast.
openreg Autocast only supports dtypes of torch.float32 currently.
warnings.warn(error_message
```
### After
```python
In [1]: import torch
...: import torch_openreg
...:
...: a = torch.randn(3, 4)
...: b = torch.randn(4, 2)
...: with torch.autocast("openreg", dtype=torch.float16):
...: torch.mm(a, b)
...:
/home/coder/code/pytorch/torch/amp/autocast_mode.py:332: UserWarning: In openreg autocast, but the target dtype torch.float16 is not supported. Disabling autocast.
openreg Autocast only supports dtypes of torch.float32 currently.
warnings.warn(error_message)
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/162037
Approved by: https://github.com/zou3519
Re-raising of #129959 as that was closed.
Warning message before:
```
/home/admin/.local/share/hatch/env/virtual/toms-project-1/Qv9k_r_5/dev/lib/python3.10/site-packages/torch/cuda/amp/grad_scaler.py:120: UserWarning: torch.cuda.amp.GradScaler is enabled, but CUDA is not available. Disabling.
```
Warning message after:
```
/path/to/my/code:91: UserWarning: torch.cuda.amp.GradScaler is enabled, but CUDA is not available. Disabling.
```
Helps the user find where the issue stems from in their code. What do you think?
(Looks like "skip_file_prefixes" is not available until Python 3.12 minimum...)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/155112
Approved by: https://github.com/Skylion007, https://github.com/cyyever
This can save ~0.2ms on non cuda devices by skip calling `amp_definitely_not_available()`. It can improve small models in torchbench like lennard_jones on xpu 10% on both eager and inductor in dynamo benchmarks.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/151111
Approved by: https://github.com/soulitzer
Fixes#142397
Basic implementation is done. What's left:
- [x] Different dtype/device tensors in the TensorList
- [x] fast path for grouping the foreach kernel
- [x] Tests
Regarding tests, I found some tests in `test/test_torch.py` for GradScaler but I couldn't figure out what is the best way to enable the test for MPS device.
By removing `@onlyNativeDeviceTypes`, one enables the tests for MPS but also enables tests for all other devices which are not included in the native device types. If I put:
`instantiate_device_type_tests(TestTorchDeviceType, globals(), allow_mps=True)`
This enables lots of tests in that class for MPS which were not(?) being tested before? This part needs some clarification
Pull Request resolved: https://github.com/pytorch/pytorch/pull/150255
Approved by: https://github.com/malfet
Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com>
Summary:
add testing for autocast and set_grad nodes for export_for_training. In export_for_training, we do not wrap the autocast and set_grad node in to HOP, but we should still have the set_grad_enabled/autocast nodes.
add support for autocast in non-strict export. Previously, `_enter_autocast` and `_exit_autocast` nodes don't show up in the export graph when we use `strict=False`.
- In autocast's enter and exit function, we dispatch to `PreDispatchTorchFunctionMode.__torch_function__`.
if we have PreDispatchTorchFunctionMode in our function_mode_stack, the call stack looks like below. This is mostly the same call stack as strict mode, except strict mode enters [here](https://www.internalfb.com/code/fbsource/[0d4f1135cacdb26c6e01d5dce1ce52a15d61ee48]/xplat/caffe2/torch/_dynamo/variables/ctx_manager.py?lines=806).
```
- torch.amp.autocast.__enter__()'s torch.overrides.handle_torch_function
- torch.fx.experimental.proxy_tensor.TorchFunctionMetadataMode.__torch_function__
- torch.amp._enter_autocast()'s torch.overrides.handle_torch_function
- PreDispatchTorchFunctionMode.__torch_function__
```
- in `PreDispatchTorchFunctionMode.__torch_function__`, we create the autocast nodes.
- to match the strict mode behavior, we let the input node to the `_exist_autocast` node be the corresponding `_enter_autocast` node. This requires us to maintain a stack in `PreDispatchTorchFunctionMode`.
Test Plan:
```
buck2 run 'fbcode//mode/dev-nosan' fbcode//caffe2/test:test_export -- -r test_export_with_autocast
buck2 run 'fbcode//mode/dev-nosan' fbcode//caffe2/test:test_export -- -r test_export_with_set_grad
```
Differential Revision: D64016023
Pull Request resolved: https://github.com/pytorch/pytorch/pull/137495
Approved by: https://github.com/bdhirsh
# Motivation
As discussed in [#124479](https://github.com/pytorch/pytorch/pull/124479), `torch.amp.autocast` can NOT be completely equivalent to `torch.cuda.amp.autocast` and `torch.cpu.amp.autocast` since `torch.amp.autocast` has NOT the default `dtype` for CPU (`torch.bfloat16` by default) and CUDA (`torch.float16` by default) respectively. We would like `torch.amp.autocast` to be more generic to help the developer/customer write the device-agnostic code. Because there are not enough reasons to add device-specific autocast `torch.xxx.amp.autocast` for each device backend.
# Solution
When `None` is passed to `dtype`, we should use `torch.get_autocast_dtype` to get the related dtype for each backend. Meanwhile, `torch.get_autocast_dtype` is necessary to be supported in JIT path for BC.
# Additional Context
With this PR, `torch.amp.autocast(device_type='cuda')` is equivalent to `torch.cuda.amp.autocast`.
Add two new UTs to cover this change in eager and jit path respectively.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/125103
Approved by: https://github.com/albanD, https://github.com/jgong5, https://github.com/gujinghui
`torch.autocast` with `xla` backend has been restricted to `torch.bfloat16`. This shouldn't be the case anymore.
This works with `xla::cast( ..., type=f16)`
```
IR {
%0 = f32[] prim::Constant(), xla_shape=f32[], value=1
%1 = f32[3,2]{1,0} aten::expand(%0), xla_shape=f32[3,2]{1,0}, size=(3, 2), dynamic_dims=(0, 0)
%2 = f16[3,2]{1,0} xla::cast(%1), xla_shape=f16[3,2]{1,0}, type=f16, dtype=Half, stype=Float
%3 = f32[] prim::Constant(), xla_shape=f32[], value=1
%4 = f32[2,3]{1,0} aten::expand(%3), xla_shape=f32[2,3]{1,0}, size=(2, 3), dynamic_dims=(0, 0)
%5 = f16[2,3]{1,0} xla::cast(%4), xla_shape=f16[2,3]{1,0}, type=f16, dtype=Half, stype=Float
%6 = f16[2,2]{1,0} aten::mm(%5, %2), xla_shape=f16[2,2]{1,0}, ROOT=0
}
```
This will allow PyTorch/XLA to extend its autocast implementation to use `xla` backend for `float16` type as well.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/109554
Approved by: https://github.com/JackCaoG, https://github.com/bdhirsh
**Summary**
Fix the https://github.com/pytorch/pytorch/issues/100565 by allowing float32 data type when Autocast CPU is disabled. Current behavior is:
- When autocast is disabled and user passes in float data type, it works well.
- When autocast is enabled and user passes in float data type, a warn message throws `UserWarning: In CPU autocast, but the target dtype is not supported. Disabling autocast.` to disable autocast automatically
**TestPlan**
```
python -u -m pytest -s -v test_autocast.py -k test_autocast_disabled_with_fp32_dtype
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/107348
Approved by: https://github.com/jgong5, https://github.com/Neilblaze, https://github.com/albanD
This PR adds support for `enable_grad`/`no_grad`/`autocast` context managers getting properly traced in `pre_dispatch` tracing. The stuff in this PR includes:
- I added a torch function mode that runs during make_fx pre_dispatch tracing, `ProxyTorchFunctionMode`. It directly intercepts the torch ops that run during the above context managers, and adds them to the current graph instead of executing them
- `enable_grad` and `no_grad` currently desugar into `torch._C.set_grad_enabled(bool)`, but this API isn't currently overrideable by torch function so I added the ability to interpose there
- the `torch.amp` context managers don't currently have a nice equivalent, like `set_autocast_enabled(state)`, so I ended up adding two new API's: `torch.amp._set_autocast_enabled` and `torch.amp._set_autocast_disabled`. If you look at how the context manager is implemented, it ends up calling several different state-changing functions, some of which depend on the backend - so I figured that it would be cleaner just to add a new API (that should probably only be used by tracing) - but open to feedback
- I added a new dynamo backend, `compile(backend="pre_dispatch_eager")`. When pre_dispatch tracing becomes always-on in inductor, it will be another potential surface for bugs. I also added a test file for it (`test/dynamo/test_pre_dispatch.py`).
Pull Request resolved: https://github.com/pytorch/pytorch/pull/103024
Approved by: https://github.com/ezyang
As part of this, a new `AutocastIPU` dispatch key has been added.
There's an existing PR, #85043, to make `Autocast` a proper per-backend functionality key, but it ran into issues with layering with other functionality keys and went stale.
This has been tested in the out-of-tree IPU PyTorch backend.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/103890
Approved by: https://github.com/albanD
Fixes #ISSUE_NUMBER
1、optimize the func name of AMP in custom device module,use `torch.foo.set_autocast_enable` instead of `torch.foo.set_autocast_foo_enable`.
2、In AMP with custom device,use `custom_device_mod.set_autocast_enable` instead of `getattr(custom_device_mod, "set_autocast_enable"`, because we have check that `custom_device_mod` hasattr `set_autocast_enable` before.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/98052
Approved by: https://github.com/bdhirsh
I am trying to use bfloat16 AMP on a range of devices, using the `enabled` argument to actually enable/disable AMP, like this:
```python
with torch.cuda.amp.autocast(enabled=use_amp, dtype=torch.bfloat16):
```
However, this raises a RuntimeError even if enabled=False.
```
File "/venv/lib/python3.8/site-packages/torch/amp/autocast_mode.py", line 221, in __init__
raise RuntimeError('Current CUDA Device does not support bfloat16. Please switch dtype to float16.')
RuntimeError: Current CUDA Device does not support bfloat16. Please switch dtype to float16.
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/96097
Approved by: https://github.com/ngimel, https://github.com/kit1980
Fixes #ISSUE_NUMBER
1、add amp support for custom backend
2、optimize the file `backend_registration.py`, and rename it with `custom_backend_registration.py`. And then we would register other funcs for custom backend.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/96188
Approved by: https://github.com/bdhirsh
Fixes #ISSUE_NUMBER
1、add amp support for custom backend
2、optimize the file `backend_registration.py`, and rename it with `custom_backend_registration.py`. And then we would register other funcs for custom backend.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/96188
Approved by: https://github.com/bdhirsh