60 Commits

Author SHA1 Message Date
9420944033 Revert "[AMP][Refactor] Simplify dtype support logic in autocast context manager (#163446)"
This reverts commit 960b0d5f0d0efb1f1962bddcf62e2a698e26edd2.

Reverted https://github.com/pytorch/pytorch/pull/163446 on behalf of https://github.com/izaitsevfb due to breaks autocast tests on linux and mac ([comment](https://github.com/pytorch/pytorch/pull/163446#issuecomment-3390688642))
2025-10-10 15:12:46 +00:00
960b0d5f0d [AMP][Refactor] Simplify dtype support logic in autocast context manager (#163446)
## Description:

This PR refactors the autocast context manager in `autocast_mode.py` to simplify and centralize the logic for checking supported dtypes for each device. The previous implementation repeated similar checks for multiple device types. Now, a single mapping `device_supported_dtypes` is used to associate device types with their supported dtypes, and the validation logic is unified.

In my view, this makes the code easier to maintain and extend for new devices.
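
For illustration, a minimal sketch of the mapping-based validation described above; the mapping contents and helper name are assumptions, not the exact code in this PR:

```python
import warnings

import torch

# Hypothetical centralized mapping from device type to supported autocast dtypes.
device_supported_dtypes = {
    "cpu": (torch.bfloat16, torch.float16),
    "cuda": (torch.bfloat16, torch.float16),
    "xla": (torch.bfloat16, torch.float16),
}

def _check_autocast_dtype(device_type: str, dtype: torch.dtype) -> bool:
    """Return True if autocast should stay enabled for this device/dtype pair."""
    supported = device_supported_dtypes.get(device_type)
    if supported is None or dtype in supported:
        return True
    warnings.warn(
        f"In {device_type} autocast, but the target dtype {dtype} is not supported. "
        f"Disabling autocast. {device_type} autocast only supports "
        + ", ".join(str(d) for d in supported) + " currently."
    )
    return False
```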

Please share any suggestions and comments with me.

BTW, in the original `xla` branch, the supported dtypes are `[torch.float16, torch.bfloat16]` (5d8a226e23/torch/amp/autocast_mode.py (L358-L363)), but the warning message mentions only `torch.bfloat16`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163446
Approved by: https://github.com/FFFrog, https://github.com/albanD
2025-10-10 12:30:06 +00:00
5f18f240de Add initial suppressions for pyrefly (#164177)
Adds suppressions so that pyrefly will typecheck clean: https://github.com/pytorch/pytorch/issues/163283

Test plan:
`python3 scripts/lintrunner.py`
`pyrefly check`

---

Pyrefly check before: https://gist.github.com/maggiemoss/3a0aa0b6cdda0e449cd5743d5fce2c60
After:

```
 INFO Checking project configured at `/Users/maggiemoss/python_projects/pytorch/pyrefly.toml`
 INFO 0 errors (1,063 ignored)
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164177
Approved by: https://github.com/Lucaskabela
2025-10-02 20:57:41 +00:00
d8cbbc0f70 [Easy][AMP] Refactor the AMP logic for getting dtype (#162796)
As the title states.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162796
Approved by: https://github.com/ezyang
2025-09-21 06:32:35 +00:00
8fd3c9ce91 Optimize AMP custom_backend_name error message (#162037)
Print out the AMP target dtype so that custom backends can more easily find the expected dtype during integration.

## Test Result

### Before
```python
In [1]: import torch
   ...: import torch_openreg
   ...:
   ...: a = torch.randn(3, 4)
   ...: b = torch.randn(4, 2)
   ...: with torch.autocast("openreg", dtype=torch.float16):
   ...:     torch.mm(a, b)
   ...:
/home/coder/code/pytorch/torch/amp/autocast_mode.py:332: UserWarning: In openreg autocast, but the target dtype is not supported. Disabling autocast.
 openreg Autocast only supports dtypes of torch.float32 currently.
  warnings.warn(error_message)
```

### After
```python
In [1]: import torch
   ...: import torch_openreg
   ...:
   ...: a = torch.randn(3, 4)
   ...: b = torch.randn(4, 2)
   ...: with torch.autocast("openreg", dtype=torch.float16):
   ...:     torch.mm(a, b)
   ...:

/home/coder/code/pytorch/torch/amp/autocast_mode.py:332: UserWarning: In openreg autocast, but the target dtype torch.float16 is not supported. Disabling autocast.
 openreg Autocast only supports dtypes of torch.float32 currently.
  warnings.warn(error_message)
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162037
Approved by: https://github.com/zou3519
2025-09-04 08:27:56 +00:00
2ac45c2752 Fix autocast context manager when there is exception (#159565)
Summary: When an exception occurs inside the context manager, we need to either return False or properly propagate the exception via `__exit__(exc_type, exc_val, exc_tb)`. But previously, while tracing, we didn't actually run the exit node, so we ended up swallowing the exception in a very weird way, as outlined in https://github.com/pytorch/pytorch/issues/153202. This PR fixes it.
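
For context, a generic sketch of the `__exit__` contract this fix relies on (plain Python, not the actual autocast implementation):

```python
class DemoCtx:
    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Returning False (or None) tells Python to re-raise any exception
        # raised inside the `with` block instead of swallowing it.
        return False

try:
    with DemoCtx():
        raise ValueError("boom")
except ValueError:
    print("exception propagated as expected")
```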

Test Plan:
new test case

Differential Revision: D79348382

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159565
Approved by: https://github.com/zou3519, https://github.com/yushangdi
2025-08-01 02:12:24 +00:00
aa11628576 Issue warning with reference to user code rather than torch (#155112)
Re-raising of #129959 as that was closed.

Warning message before:
```
/home/admin/.local/share/hatch/env/virtual/toms-project-1/Qv9k_r_5/dev/lib/python3.10/site-packages/torch/cuda/amp/grad_scaler.py:120: UserWarning: torch.cuda.amp.GradScaler is enabled, but CUDA is not available.  Disabling.
```

Warning message after:
```
/path/to/my/code:91: UserWarning: torch.cuda.amp.GradScaler is enabled, but CUDA is not available.  Disabling.
```

Helps the user find where the issue stems from in their code. What do you think?

(Looks like "skip_file_prefixes" is not available until Python 3.12 minimum...)
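
For illustration, a minimal sketch of the standard-library mechanism (`stacklevel` in `warnings.warn`) that attributes the warning to the caller; the helper name and message are made up:

```python
import warnings

def _warn_scaler_disabled() -> None:
    # stacklevel=2 attributes the warning to the caller's line (the user's
    # code) rather than to this helper inside the library.
    warnings.warn(
        "GradScaler is enabled, but CUDA is not available.  Disabling.",
        stacklevel=2,
    )
```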

Pull Request resolved: https://github.com/pytorch/pytorch/pull/155112
Approved by: https://github.com/Skylion007, https://github.com/cyyever
2025-07-14 05:24:23 +00:00
3fd84a8592 [BE][PYFMT] migrate PYFMT for torch/[a-c]*/ to ruff format (#144554)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/144554
Approved by: https://github.com/soulitzer
2025-07-03 18:56:07 +00:00
b59f3d3ae0 [Intel GPU] skip a cuda api call in amp to save some host overhead on xpu (#151111)
This can save ~0.2 ms on non-CUDA devices by skipping the call to `amp_definitely_not_available()`. It improves small models in torchbench, such as lennard_jones, by 10% on XPU for both eager and inductor in the dynamo benchmarks.
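
A hedged sketch of the kind of guard this adds; `_cuda_amp_unavailable` is a stand-in for the real `amp_definitely_not_available()` helper:

```python
import torch

def _cuda_amp_unavailable() -> bool:
    # Stand-in for the CUDA-specific availability helper.
    return not torch.cuda.is_available()

def _should_disable_scaler(device_type: str, enabled: bool) -> bool:
    # Only pay for the CUDA availability query when the device is actually CUDA.
    return enabled and device_type == "cuda" and _cuda_amp_unavailable()
```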

Pull Request resolved: https://github.com/pytorch/pytorch/pull/151111
Approved by: https://github.com/soulitzer
2025-04-13 06:37:07 +00:00
49f6cce736 [MPS] grad scaler (#150255)
Fixes #142397

Basic implementation is done. What's left:
- [x] Different dtype/device tensors in the TensorList
- [x] fast path for grouping the foreach kernel
- [x] Tests

Regarding tests, I found some tests for GradScaler in `test/test_torch.py`, but I couldn't figure out the best way to enable them for the MPS device.

By removing `@onlyNativeDeviceTypes`, one enables the tests for MPS but also for all other devices that are not included in the native device types. If I put:
`instantiate_device_type_tests(TestTorchDeviceType, globals(), allow_mps=True)`

this enables lots of tests in that class for MPS which were not(?) being tested before. This part needs some clarification.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/150255
Approved by: https://github.com/malfet

Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com>
2025-04-06 17:06:55 +00:00
bca75fe97a [MAIA] [Autocast] Enable autocast on MAIA device (#148511)
Fixes #148510.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/148511
Approved by: https://github.com/albanD
2025-03-18 03:46:22 +00:00
6939a56e13 [autocast][pytorch] Support autocast for MTIA (#145627)
Summary: Add autocast support to MTIA

Reviewed By: egienvalue

Differential Revision: D68572548

Pull Request resolved: https://github.com/pytorch/pytorch/pull/145627
Approved by: https://github.com/egienvalue
2025-01-25 03:24:59 +00:00
f2cfe8b59f PEP585 update - mostly toplevels (#145178)
See #145101 for details.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/145178
Approved by: https://github.com/bobrenjc93
2025-01-22 02:21:14 +00:00
bc69a19139 [MPS] Add support for bf16 autocast (#139390)
This PR adds support for bf16 autocast. Most of the code and ideas are copied from #99272.

Most of the heavy lifting was done by AI.

Fixes #139386

Pull Request resolved: https://github.com/pytorch/pytorch/pull/139390
Approved by: https://github.com/malfet

Co-authored-by: Kulin Seth <kulin_seth@apple.com>
Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com>
2024-11-20 19:52:28 +00:00
4b83302585 [MPS] Update error message for supported autocast type (#139192)
Autocast in MPS currently only supports dtype of `torch.float16`. This PR updates the error message to reflect this.

This PR was created using [Copilot Workspace](https://copilot-workspace.githubnext.com/pytorch/pytorch/issues/139190?shareId=5b510fda-380c-4e86-8e91-6b67a078f180) with no human input other than clicking buttons.

Fixes #139190

Pull Request resolved: https://github.com/pytorch/pytorch/pull/139192
Approved by: https://github.com/malfet
2024-10-30 16:48:29 +00:00
a47bb4a393 Fix autocast for non-strict export (#137495)
Summary:

Add testing for autocast and set_grad nodes for export_for_training. In export_for_training, we do not wrap the autocast and set_grad nodes into a HOP, but we should still have the set_grad_enabled/autocast nodes.

Add support for autocast in non-strict export. Previously, the `_enter_autocast` and `_exit_autocast` nodes didn't show up in the export graph when we used `strict=False`.

- In autocast's enter and exit functions, we dispatch to `PreDispatchTorchFunctionMode.__torch_function__`. If we have PreDispatchTorchFunctionMode in our function_mode_stack, the call stack looks like the one below. This is mostly the same call stack as in strict mode, except that strict mode enters [here](https://www.internalfb.com/code/fbsource/[0d4f1135cacdb26c6e01d5dce1ce52a15d61ee48]/xplat/caffe2/torch/_dynamo/variables/ctx_manager.py?lines=806).
```
- torch.amp.autocast.__enter__()'s torch.overrides.handle_torch_function
- torch.fx.experimental.proxy_tensor.TorchFunctionMetadataMode.__torch_function__
- torch.amp._enter_autocast()'s torch.overrides.handle_torch_function
- PreDispatchTorchFunctionMode.__torch_function__
```
- In `PreDispatchTorchFunctionMode.__torch_function__`, we create the autocast nodes.
- To match the strict-mode behavior, we let the input node of the `_exit_autocast` node be the corresponding `_enter_autocast` node. This requires us to maintain a stack in `PreDispatchTorchFunctionMode`.
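
For reference, a small usage sketch of the behavior described above; the module, shapes, and the `strict=False` call are illustrative assumptions, not code from this PR:

```python
import torch
from torch.export import export

class M(torch.nn.Module):
    def forward(self, x):
        with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
            return torch.mm(x, x)

# With strict=False, the exported graph should now contain the
# _enter_autocast / _exit_autocast nodes as well.
ep = export(M(), (torch.randn(4, 4),), strict=False)
print(ep.graph)
```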

Test Plan:
```
buck2 run 'fbcode//mode/dev-nosan' fbcode//caffe2/test:test_export  -- -r  test_export_with_autocast
buck2 run 'fbcode//mode/dev-nosan' fbcode//caffe2/test:test_export  -- -r  test_export_with_set_grad
```

Differential Revision: D64016023

Pull Request resolved: https://github.com/pytorch/pytorch/pull/137495
Approved by: https://github.com/bdhirsh
2024-10-16 17:39:00 +00:00
144fde4fd2 [MPS] Add support for autocast in MPS (#99272)
Fixes https://github.com/pytorch/pytorch/issues/88415

Need to run inductor/test_cpu_select_algorithm

Pull Request resolved: https://github.com/pytorch/pytorch/pull/99272
Approved by: https://github.com/malfet

Co-authored-by: Siddharth Kotapati <skotapati@apple.com>
Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com>
Co-authored-by: Roy Hvaara <roy@lightyear.no>
2024-09-05 23:23:17 +00:00
2764bee942 Revert "[MPS] Add support for autocast in MPS (#99272)"
This reverts commit 6919e8baaba391ced7b4acaa553d6ea1f3b30e79.

Reverted https://github.com/pytorch/pytorch/pull/99272 on behalf of https://github.com/clee2000 due to Broke test/inductor/test_cpu_select_algorithm.py::TestSelectAlgorithmCPU::test_quantized_linear_amx_batch_size_3_in_features_128_out_features_64_bias_False_cpu on sm86 jobs [GH job link](https://github.com/pytorch/pytorch/actions/runs/10252979157/job/28367091621) [HUD commit link](6919e8baab) Not caught on PR due to bad TD ([comment](https://github.com/pytorch/pytorch/pull/99272#issuecomment-2269808857))
2024-08-05 19:59:04 +00:00
6919e8baab [MPS] Add support for autocast in MPS (#99272)
Fixes https://github.com/pytorch/pytorch/issues/88415

Co-authored-by: Siddharth Kotapati <skotapati@apple.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/99272
Approved by: https://github.com/malfet
2024-08-05 17:02:30 +00:00
07450e9713 Revert "[MPS] Add support for autocast in MPS (#99272)"
This reverts commit 6240cfd5c751bea6ca91dc765085e1d871b22345.

Reverted https://github.com/pytorch/pytorch/pull/99272 on behalf of https://github.com/jeanschmidt due to introduced breakages in trunk ([comment](https://github.com/pytorch/pytorch/pull/99272#issuecomment-2203033719))
2024-07-02 12:29:51 +00:00
6240cfd5c7 [MPS] Add support for autocast in MPS (#99272)
Fixes https://github.com/pytorch/pytorch/issues/88415

Co-authored-by: Siddharth Kotapati <skotapati@apple.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/99272
Approved by: https://github.com/malfet
2024-07-02 01:49:52 +00:00
d80939e5e9 [BE] enable UFMT for torch/storage.py (#127706)
Part of #123062

- #123062

Pull Request resolved: https://github.com/pytorch/pytorch/pull/127706
Approved by: https://github.com/ezyang
2024-06-27 23:16:24 +00:00
305ba62906 Add support to GradScaler for respecting an already set grad_scale value (#123429)
Fixes #123428

Co-authored-by: Yousuf Mohamed-Ahmed <youmed.tech@gmail.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/123429
Approved by: https://github.com/ezyang
2024-06-27 22:40:54 +00:00
5fba5d83f0 add xpu for amp (#127276)
As support for Intel GPU has been upstreamed, this PR adds the XPU-related content to the AMP docs.

Co-authored-by: Yu, Guangye <guangye.yu@intel.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/127276
Approved by: https://github.com/dvrogozh, https://github.com/albanD, https://github.com/malfet
2024-06-20 21:49:35 +00:00
afe15d2d2f Flip default value for mypy disallow_untyped_defs [3/11] (#127840)
See #127836 for details.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/127840
Approved by: https://github.com/oulgen
2024-06-08 18:28:01 +00:00
e7a42702f9 generalize custom_fwd&custom_bwd to be device-agnostic (#126531)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/126531
Approved by: https://github.com/jgong5, https://github.com/gujinghui, https://github.com/albanD, https://github.com/EikanWang
ghstack dependencies: #126527
2024-05-25 06:48:16 +00:00
d17be10df1 make torch.amp.autocast more generic (#125103)
# Motivation
As discussed in [#124479](https://github.com/pytorch/pytorch/pull/124479), `torch.amp.autocast` cannot be completely equivalent to `torch.cuda.amp.autocast` and `torch.cpu.amp.autocast`, since `torch.amp.autocast` does NOT have a default `dtype` for CPU (`torch.bfloat16` by default) and CUDA (`torch.float16` by default), respectively. We would like `torch.amp.autocast` to be more generic to help developers/customers write device-agnostic code, because there are not enough reasons to add a device-specific autocast `torch.xxx.amp.autocast` for each device backend.

# Solution
When `None` is passed as `dtype`, we use `torch.get_autocast_dtype` to get the related dtype for each backend. Meanwhile, `torch.get_autocast_dtype` also needs to be supported in the JIT path for BC.

# Additional Context
With this PR, `torch.amp.autocast(device_type='cuda')` is equivalent to `torch.cuda.amp.autocast`.
Add two new UTs to cover this change in the eager and JIT paths, respectively.
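
A small usage sketch (not from the PR) of the device-agnostic pattern this enables, with `dtype` left unset so each backend picks its default autocast dtype:

```python
import torch

device_type = "cuda" if torch.cuda.is_available() else "cpu"
a = torch.randn(8, 8, device=device_type)
b = torch.randn(8, 8, device=device_type)

# dtype defaults to None, so CPU falls back to torch.bfloat16 and CUDA to torch.float16.
with torch.amp.autocast(device_type=device_type):
    out = a @ b
print(out.dtype)
```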

Pull Request resolved: https://github.com/pytorch/pytorch/pull/125103
Approved by: https://github.com/albanD, https://github.com/jgong5, https://github.com/gujinghui
2024-05-08 12:13:26 +00:00
6761b49551 Ensure autocast device_type is a string + Unit test (#125014)
Reviving #124873 (already approved) to resolve CLA issues

Fixes #124738

(Marked as draft until I get local unit tests to run)

Edit: Tests passing

Pull Request resolved: https://github.com/pytorch/pytorch/pull/125014
Approved by: https://github.com/mikaylagawarecki, https://github.com/soulitzer
2024-04-28 16:27:30 +00:00
19a83eacb5 add new API torch.amp.is_autocast_available (#124938)
# Motivation
Expose `torch._is_autocast_available` as `torch.amp.is_autocast_available`, making it a public API.
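
A small usage sketch (illustrative, not from the PR):

```python
import torch

# Query autocast availability per device type without touching private APIs.
for device_type in ("cpu", "cuda"):
    if torch.amp.is_autocast_available(device_type):
        print(f"autocast is available for {device_type}")
```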

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124938
Approved by: https://github.com/albanD
2024-04-26 08:45:20 +00:00
cdc66e9dc3 refactor autocast python APIs (#124479)
# Motivation
Refactor the autocast usage in `torch/amp/autocast_mode.py` and `torch/utils/checkpoint.py` to fix a naming-convention conflict between `torch.xxx.get_autocast_xxx_dtype` defined in `autocast_mode.py` and `torch.xxx.get_autocast_dtype` defined in `checkpoint.py`.

# Solution
Use device-agnostic APIs like `torch.get_autocast_dtype`, ..., instead.
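
For illustration, a minimal sketch of the device-agnostic accessor mentioned above; the printed defaults assume stock autocast settings:

```python
import torch

# Device-agnostic query of the active autocast dtype per backend
# (assumed defaults: bfloat16 on CPU, float16 on CUDA).
print(torch.get_autocast_dtype("cpu"))
print(torch.get_autocast_dtype("cuda"))
```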

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124479
Approved by: https://github.com/jgong5, https://github.com/gujinghui, https://github.com/EikanWang, https://github.com/albanD
ghstack dependencies: #124359
2024-04-25 14:33:33 +00:00
9c4fc5fa34 [BE][Ez]: Fix minor potential perf regression from #123960 (#124013)
The `non_blocking` arg here is useless if the values are all eagerly consumed, so revert the change.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124013
Approved by: https://github.com/ezyang
2024-04-15 16:51:45 +00:00
1d6c5972c1 [BE]: Optimize min/max/sum comprehensions C419 (#123960)
Automatic fixes that replace certain list comprehensions with generator expressions where appropriate, so that they are consumed immediately. This is preview functionality in ruff for rule C419 and was applied automatically.
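
For illustration, a minimal example of the kind of C419 rewrite applied here (the variable names are made up):

```python
# The generator expression avoids building an intermediate list before
# sum() consumes the values.
values = range(1_000_000)

total_list = sum([v * v for v in values])  # before: allocates a full list
total_gen = sum(v * v for v in values)     # after: consumed lazily
assert total_list == total_gen
```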

Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/123960
Approved by: https://github.com/malfet
2024-04-12 23:54:15 +00:00
354a436d96 Remove device assert in Gradscaler (#119362)
Fixes #119358

Co-authored-by: Edward Z. Yang <ezyang@mit.edu>
Co-authored-by: ydwu4 <ydwu2014@gmail.com>
Co-authored-by: PyTorch UpdateBot <pytorchupdatebot@users.noreply.github.com>
Co-authored-by: Bin Bao <binbao@meta.com>
Co-authored-by: Shuqiang Zhang <sqzhang@meta.com>
Co-authored-by: Adnan Akhundov <aakhundov@meta.com>
Co-authored-by: Ting Lu <tingl@nvidia.com>
Co-authored-by: Yang Chen <yangche@fb.com>
Co-authored-by: cyy <cyyever@outlook.com>
Co-authored-by: Animesh Jain <anijain@umich.edu>
Co-authored-by: Jason Ansel <jansel@meta.com>
Co-authored-by: Eddie Yan <eddiey@nvidia.com>
Co-authored-by: wz337 <wz337@cornell.edu>
Co-authored-by: Xuehai Pan <XuehaiPan@pku.edu.cn>
Co-authored-by: Anthony Alayo <anthony.alayo@applovin.com>
Co-authored-by: leslie-fang-intel <leslie.fang@intel.com>
Co-authored-by: Yifu Wang <yifu@fb.com>
Co-authored-by: Yukio Siraichi <yukio.siraichi@gmail.com>
Co-authored-by: atalman <atalman@fb.com>
Co-authored-by: PyTorch MergeBot <pytorchmergebot@users.noreply.github.com>
Co-authored-by: Jeff Daily <jeff.daily@amd.com>
Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com>
Co-authored-by: haozhe.zhu <haozhe.zhu@intel.com>
Co-authored-by: lezcano <lezcano-93@hotmail.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/119362
Approved by: https://github.com/ezyang
2024-02-22 08:02:18 +00:00
bacbad5bc9 add GradScaler on CPU (#109993)
Step 2 of https://github.com/pytorch/pytorch/issues/111559.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/109993
Approved by: https://github.com/jgong5, https://github.com/ezyang
2024-01-29 23:42:35 +00:00
c47d2b8035 Add Half support for CPU autocast on eager mode (#112484)
Add Half support for CPU autocast in eager mode, since common operators have Half support on CPU.
https://github.com/pytorch/pytorch/issues/96093.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/112484
Approved by: https://github.com/leslie-fang-intel, https://github.com/ezyang
2023-11-21 20:08:28 +00:00
e2e9d15726 Unblock float16 dtype for xla autocasting (#109554)
`torch.autocast` with `xla` backend has been restricted to `torch.bfloat16`. This shouldn't be the case anymore.

This works with `xla::cast( ..., type=f16)`
```
IR {
  %0 = f32[] prim::Constant(), xla_shape=f32[], value=1
  %1 = f32[3,2]{1,0} aten::expand(%0), xla_shape=f32[3,2]{1,0}, size=(3, 2), dynamic_dims=(0, 0)
  %2 = f16[3,2]{1,0} xla::cast(%1), xla_shape=f16[3,2]{1,0}, type=f16, dtype=Half, stype=Float
  %3 = f32[] prim::Constant(), xla_shape=f32[], value=1
  %4 = f32[2,3]{1,0} aten::expand(%3), xla_shape=f32[2,3]{1,0}, size=(2, 3), dynamic_dims=(0, 0)
  %5 = f16[2,3]{1,0} xla::cast(%4), xla_shape=f16[2,3]{1,0}, type=f16, dtype=Half, stype=Float
  %6 = f16[2,2]{1,0} aten::mm(%5, %2), xla_shape=f16[2,2]{1,0}, ROOT=0
}
```

This will allow PyTorch/XLA to extend its autocast implementation to use `xla` backend for `float16` type as well.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/109554
Approved by: https://github.com/JackCaoG, https://github.com/bdhirsh
2023-09-21 03:19:44 +00:00
ee0e04ac48 Allow float dtype when Autocast CPU Disabled (#107348)
**Summary**
Fix https://github.com/pytorch/pytorch/issues/100565 by allowing the float32 data type when Autocast CPU is disabled. The current behavior is:
- When autocast is disabled and the user passes in a float32 dtype, it works well.
- When autocast is enabled and the user passes in a float32 dtype, a warning (`UserWarning: In CPU autocast, but the target dtype is not supported. Disabling autocast.`) is raised and autocast is disabled automatically.
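
A minimal usage sketch of the allowed case after this fix (shapes and names are illustrative):

```python
import torch

x = torch.randn(4, 4)
# Passing torch.float32 while autocast is disabled is accepted.
with torch.autocast(device_type="cpu", dtype=torch.float32, enabled=False):
    y = x @ x
print(y.dtype)  # torch.float32
```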

**TestPlan**

```
python -u -m pytest -s -v test_autocast.py -k test_autocast_disabled_with_fp32_dtype
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/107348
Approved by: https://github.com/jgong5, https://github.com/Neilblaze, https://github.com/albanD
2023-09-01 00:49:44 +00:00
3bf922a6ce Apply UFMT to low traffic torch modules (#106249)
Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/106249
Approved by: https://github.com/Skylion007
2023-07-29 23:37:30 +00:00
5afc2f5069 Documentation for torch.autocast (#95760)
- [x] Corrected examples for CUDA devices.
- [x] Information about availability of `torch.autocast`.

Fixes #95547

Pull Request resolved: https://github.com/pytorch/pytorch/pull/95760
Approved by: https://github.com/leslie-fang-intel, https://github.com/kit1980
2023-07-22 03:56:34 +00:00
875f60399e pre_dispatch tracing: support autocast and no_grad/enable_grad ctx managers, add a pre_dispatch_eager dynamo backend (#103024)
This PR adds support for the `enable_grad`/`no_grad`/`autocast` context managers being properly traced in `pre_dispatch` tracing. This PR includes:
- I added a torch function mode that runs during make_fx pre_dispatch tracing, `ProxyTorchFunctionMode`. It directly intercepts the torch ops that run during the above context managers, and adds them to the current graph instead of executing them
- `enable_grad` and `no_grad` currently desugar into `torch._C.set_grad_enabled(bool)`, but this API isn't currently overrideable by torch function so I added the ability to interpose there
- the `torch.amp` context managers don't currently have a nice equivalent, like `set_autocast_enabled(state)`, so I ended up adding two new APIs: `torch.amp._set_autocast_enabled` and `torch.amp._set_autocast_disabled`. If you look at how the context manager is implemented, it ends up calling several different state-changing functions, some of which depend on the backend, so I figured it would be cleaner to just add a new API (that should probably only be used by tracing) - but open to feedback
- I added a new dynamo backend, `compile(backend="pre_dispatch_eager")`. When pre_dispatch tracing becomes always-on in inductor, it will be another potential surface for bugs. I also added a test file for it (`test/dynamo/test_pre_dispatch.py`).
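
A hedged usage sketch of the new backend (assuming a build where `pre_dispatch_eager` is registered as a dynamo backend; the function is illustrative):

```python
import torch

def fn(x):
    with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
        return torch.mm(x, x)

# Compile through the pre_dispatch_eager backend added in this PR.
compiled = torch.compile(fn, backend="pre_dispatch_eager")
print(compiled(torch.randn(4, 4)).dtype)
```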

Pull Request resolved: https://github.com/pytorch/pytorch/pull/103024
Approved by: https://github.com/ezyang
2023-06-29 14:17:42 +00:00
41866a2ead Fix missing mandatory device_type argument in autocast docstring (#97223)
Fixes [#92803](https://github.com/pytorch/pytorch/issues/92803)
![Screenshot from 2023-03-21 12-28-14](https://user-images.githubusercontent.com/100136654/226538769-141f3b9e-0de2-4e86-8e42-d5a4a7413c6f.png)
![Screenshot from 2023-03-21 12-28-29](https://user-images.githubusercontent.com/100136654/226538777-9e719090-75c0-46f7-8594-5efcb0a46df6.png)
![Screenshot from 2023-03-21 12-29-36](https://user-images.githubusercontent.com/100136654/226538783-399a0e60-ffc9-4d73-801c-8cfce366d142.png)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/97223
Approved by: https://github.com/albanD, https://github.com/malfet
2023-06-27 01:54:54 +00:00
6ff4548b6e [AMP] Support XLA:TPU (#96370)
With https://github.com/pytorch/xla/pull/5148, https://github.com/pytorch/xla/pull/4740

With these changes:
- XLA:GPU users should use `torch.cuda.amp.autocast()` for AMP with float16.
- XLA:TPU users should use `torch.amp.autocast('xla')` for AMP with bfloat16.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/96370
Approved by: https://github.com/bdhirsh, https://github.com/malfet
2023-06-23 19:46:42 +00:00
5eb7325bc7 Add autocast support for IPU (#103890)
As part of this, a new `AutocastIPU` dispatch key has been added.

There's an existing PR, #85043, to make `Autocast` a proper per-backend functionality key, but it ran into issues with layering with other functionality keys and went stale.

This has been tested in the out-of-tree IPU PyTorch backend.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/103890
Approved by: https://github.com/albanD
2023-06-22 15:38:45 +00:00
faa7eb81c6 change error_message for XPU Autocast data type check (#102073)
XPU autocast supports the bf16 and fp16 data types, so we change the error message accordingly.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/102073
Approved by: https://github.com/jgong5
2023-05-24 08:36:43 +00:00
48463f687a refactor macro with AMP (#99285)
Fixes #ISSUE_NUMBER
As the title says, optimize the macros used with AMP and move them into a `.hpp` file so that they can be used for custom devices.  @bdhirsh  @albanD
As we discussed in the thread below, optimize the macros so that a new macro can be added for other devices, and move these macros to `.hpp` so that custom devices can include them to configure the ops.
https://dev-discuss.pytorch.org/t/improve-the-extension-with-privateuse1-for-custom-device/1196/7

Pull Request resolved: https://github.com/pytorch/pytorch/pull/99285
Approved by: https://github.com/albanD, https://github.com/bdhirsh
2023-04-19 01:00:00 +00:00
d03799f9a5 optimize the AMP func name in custom_device_mod (#98052)
Fixes #ISSUE_NUMBER
1. Optimize the AMP function names in the custom device module: use `torch.foo.set_autocast_enable` instead of `torch.foo.set_autocast_foo_enable`.
2. In AMP with a custom device, use `custom_device_mod.set_autocast_enable` instead of `getattr(custom_device_mod, "set_autocast_enable")`, because we have already checked that `custom_device_mod` has the `set_autocast_enable` attribute.
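
For illustration, a runnable sketch of the simplification in point 2; `foo_mod` is a stand-in for a hypothetical custom device module:

```python
import types

# Dummy module standing in for a custom backend (e.g. torch.foo).
foo_mod = types.SimpleNamespace(set_autocast_enable=lambda enabled: None)

# Before the change, the call went through getattr():
getattr(foo_mod, "set_autocast_enable")(True)
# After the change, the attribute is accessed directly, since its presence
# has already been verified:
foo_mod.set_autocast_enable(True)
```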

Pull Request resolved: https://github.com/pytorch/pytorch/pull/98052
Approved by: https://github.com/bdhirsh
2023-03-31 17:04:32 +00:00
ee6b19bd4c Error only if autocast actually enabled (#96097)
I am trying to use bfloat16 AMP on a range of devices, using the `enabled` argument to actually enable/disable AMP, like this:
```python
with torch.cuda.amp.autocast(enabled=use_amp, dtype=torch.bfloat16):
```

However, this raises a RuntimeError even if enabled=False.

```
  File "/venv/lib/python3.8/site-packages/torch/amp/autocast_mode.py", line 221, in __init__
    raise RuntimeError('Current CUDA Device does not support bfloat16. Please switch dtype to float16.')
RuntimeError: Current CUDA Device does not support bfloat16. Please switch dtype to float16.
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/96097
Approved by: https://github.com/ngimel, https://github.com/kit1980
2023-03-21 03:13:13 +00:00
6b691b99da add amp support for custom backend (#96188)
Fixes #ISSUE_NUMBER
1. Add AMP support for custom backends.
2. Optimize the file `backend_registration.py` and rename it to `custom_backend_registration.py`, so that other functions for custom backends can be registered there later.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/96188
Approved by: https://github.com/bdhirsh
2023-03-20 20:27:35 +00:00
a8f36dd646 Revert "add amp support for custom backend (#96188)"
This reverts commit cf12edee02a44009c4f06e36efa97d9a7372ab35.

Reverted https://github.com/pytorch/pytorch/pull/96188 on behalf of https://github.com/kit1980 due to Broke some linalg tests : https://github.com/pytorch/pytorch/actions/runs/4420037607/jobs/7750708339
2023-03-15 00:03:19 +00:00
cf12edee02 add amp support for custom backend (#96188)
Fixes #ISSUE_NUMBER
1. Add AMP support for custom backends.
2. Optimize the file `backend_registration.py` and rename it to `custom_backend_registration.py`, so that other functions for custom backends can be registered there later.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/96188
Approved by: https://github.com/bdhirsh
2023-03-14 20:43:21 +00:00