Refactor autocast C++ APIs to be device-agnostic (#124359)

# Motivation
This PR refactors the autocast **C++** APIs to be device-agnostic and deprecates the device-specific autocast **C++** APIs.
On the C++ side:
- `is_enabled()` -> `is_enabled(device_type)`
- `set_enabled(new_enabled)` -> `set_enabled(device_type, new_enabled)`
- `get_autocast_dtype()` -> `get_autocast_dtype(device_type)`
- `set_autocast_dtype(dtype)` -> `set_autocast_dtype(device_type, dtype)`
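
As a rough illustration of the signature change listed above, the following self-contained sketch keys autocast state by device type instead of keeping one set of functions per backend. The `DeviceType` and `ScalarType` enums, the `sketch` namespace, the default dtype, and the process-wide table are all assumptions made for this example; they are not PyTorch's types or storage. Note that the diff in this commit calls `at::autocast::is_autocast_enabled(at::kCUDA)`, so the real function names may differ from the shorthand in the list.

```cpp
#include <array>
#include <cassert>
#include <cstddef>

// Hypothetical stand-ins for illustration only; not PyTorch's types.
enum class DeviceType : std::size_t { CPU = 0, CUDA, XPU, Count };
enum class ScalarType { Half, BFloat16 };

namespace sketch {

constexpr std::size_t kNumDeviceTypes =
    static_cast<std::size_t>(DeviceType::Count);

// Autocast settings for a single device type.
struct AutocastState {
  bool enabled = false;
  ScalarType dtype = ScalarType::BFloat16;  // arbitrary default for the sketch
};

// One slot per device type replaces one pair of globals per backend.
inline AutocastState& state_for(DeviceType device_type) {
  static std::array<AutocastState, kNumDeviceTypes> table{};
  const auto index = static_cast<std::size_t>(device_type);
  assert(index < table.size() && "unknown device type");
  return table[index];
}

// Generic accessors: the device type is an argument, not part of the name.
inline bool is_enabled(DeviceType device_type) {
  return state_for(device_type).enabled;
}

inline void set_enabled(DeviceType device_type, bool new_enabled) {
  state_for(device_type).enabled = new_enabled;
}

inline ScalarType get_autocast_dtype(DeviceType device_type) {
  return state_for(device_type).dtype;
}

inline void set_autocast_dtype(DeviceType device_type, ScalarType dtype) {
  state_for(device_type).dtype = dtype;
}

}  // namespace sketch
```

With this shape, a caller that previously picked a backend-specific function passes the device type as an argument instead, for example `sketch::set_enabled(DeviceType::CUDA, true)`.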

The following C++ APIs are deprecated and should be removed in PyTorch 2.5:
- `is_cpu_enabled`
- `set_cpu_enabled`
- `get_autocast_cpu_dtype`
- `set_autocast_cpu_dtype`
- `is_xpu_enabled`
- `set_xpu_enabled`
- `get_autocast_xpu_dtype`
- `set_autocast_xpu_dtype`
- `is_ipu_enabled`
- `set_ipu_enabled`
- `get_autocast_ipu_dtype`
- `set_autocast_ipu_dtype`
- `is_hpu_enabled`
- `set_hpu_enabled`
- `get_autocast_hpu_dtype`
- `set_autocast_hpu_dtype`
- `is_xla_enabled`
- `set_xla_enabled`
- `get_autocast_xla_dtype`
- `set_autocast_xla_dtype`
- `is_privateuseone_enabled`
- `set_privateuseone_enabled`
- `get_autocast_privateuseone_dtype`
- `set_autocast_privateuseone_dtype`
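
One way the deprecated device-specific functions could stay source-compatible until removal is to forward each one to the generic API. The sketch below assumes a hypothetical `SKETCH_DEPRECATED_AUTOCAST_API` macro and a stand-in `DeviceType` enum, and it uses the standard `[[deprecated]]` attribute; it is not taken from the PyTorch sources and covers only the enabled flag for two backends.

```cpp
#include <array>
#include <cassert>
#include <cstddef>

// Hypothetical stand-in for illustration only; not PyTorch's enum.
enum class DeviceType : std::size_t { CPU = 0, XPU, Count };

namespace sketch {

// Generic, device-keyed storage for the enabled flag.
inline bool& enabled_flag(DeviceType device_type) {
  static std::array<bool, static_cast<std::size_t>(DeviceType::Count)> flags{};
  const auto index = static_cast<std::size_t>(device_type);
  assert(index < flags.size() && "unknown device type");
  return flags[index];
}

inline bool is_enabled(DeviceType device_type) {
  return enabled_flag(device_type);
}

inline void set_enabled(DeviceType device_type, bool new_enabled) {
  enabled_flag(device_type) = new_enabled;
}

// Hypothetical helper: generates the old per-backend names as thin,
// deprecated wrappers that forward to the generic functions above.
#define SKETCH_DEPRECATED_AUTOCAST_API(name, device)                     \
  [[deprecated("use is_enabled(device_type) instead")]]                 \
  inline bool is_##name##_enabled() { return is_enabled(device); }      \
  [[deprecated("use set_enabled(device_type, new_enabled) instead")]]   \
  inline void set_##name##_enabled(bool new_enabled) {                  \
    set_enabled(device, new_enabled);                                   \
  }

SKETCH_DEPRECATED_AUTOCAST_API(cpu, DeviceType::CPU)
SKETCH_DEPRECATED_AUTOCAST_API(xpu, DeviceType::XPU)

#undef SKETCH_DEPRECATED_AUTOCAST_API

}  // namespace sketch
```

Existing callers of `sketch::is_xpu_enabled()` keep compiling but get a compiler warning that points them at the generic replacement.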

On the Python side, this PR provides four generic autocast APIs:
- `torch.is_autocast_enabled(device_type)`
- `torch.set_autocast_enabled(device_type, new_enabled)`
- `torch.get_autocast_dtype(device_type)`
- `torch.set_autocast_dtype(device_type, dtype)`

# Additional Context
We will submit another PR to refactor autocast **Python** APIs based on this PR.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124359
Approved by: https://github.com/jgong5, https://github.com/albanD
Yu, Guangye
2024-04-23 09:34:45 +00:00
committed by PyTorch MergeBot
parent 3c964ad1ca
commit 25f321b84f
12 changed files with 335 additions and 234 deletions

@@ -799,7 +799,7 @@ static const std::vector<OperatorGeneratorArgs> opGenArgs{
 #if defined BUILD_LITE_INTERPRETER || defined C10_MOBILE
 bool enabled = false;
 #else
-bool enabled = at::autocast::is_enabled();
+bool enabled = at::autocast::is_autocast_enabled(at::kCUDA);
 #endif
 push(stack, enabled);
 },
@@ -810,7 +810,7 @@ static const std::vector<OperatorGeneratorArgs> opGenArgs{
 #if defined BUILD_LITE_INTERPRETER || defined C10_MOBILE
 bool enabled = false;
 #else
-bool enabled = at::autocast::is_cpu_enabled();
+bool enabled = at::autocast::is_autocast_enabled(at::kCPU);
 #endif
 push(stack, enabled);
 },