mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Refactor autocast C++ APIs to be device-agnostic (#124359)
# Motivation This PR aims to refactor autocast **C++** APIs to be device-agnostic and deprecate the device-specific autocast **C++** APIs. In C++ side, - `is_enabled()` -> `is_enabled(device_type)`. - `set_enabled(new_enabled)` -> `set_enabled(device_type, new_enabled)`. - `get_autocast_dtype()` -> `get_autocast_dtype(device_type)` - `set_autocast_dtype(dtype)` -> `set_autocast_dtype(device_type, dtype)` These following C++ APIs are deprecated and should be removed in PyTorch 2.5 - `is_cpu_enabled` - `set_cpu_enabled` - `get_autocast_cpu_dtype` - `set_autocast_cpu_dtype` - `is_xpu_enabled` - `set_xpu_enabled` - `get_autocast_xpu_dtype` - `set_autocast_xpu_dtype` - `is_ipu_enabled` - `set_ipu_enabled` - `get_autocast_ipu_dtype` - `set_autocast_ipu_dtype` - `is_hpu_enabled` - `set_hpu_enabled` - `get_autocast_hpu_dtype` - `set_autocast_hpu_dtype` - `is_xla_enabled` - `set_xla_enabled` - `get_autocast_xla_dtype` - `set_autocast_xla_dtype` - `is_privateuseone_enabled` - `set_privateuseone_enabled` - `get_autocast_privateuseone_dtype` - `set_autocast_privateuseone_dtype` In Python side, provide 4 generic autocast APIs: - `torch.is_autocast_enabled(device_type)` - `torch.set_autocast_enabled(device_type, new_enabled)` - `torch.get_autocast_dtype(device_type)` - `torch.set_autocast_dtype(device_type, dtype)` # Additional Context We will submit another PR to refactor autocast **Python** APIs based on this PR. Pull Request resolved: https://github.com/pytorch/pytorch/pull/124359 Approved by: https://github.com/jgong5, https://github.com/albanD
This commit is contained in:
committed by
PyTorch MergeBot
parent
3c964ad1ca
commit
25f321b84f
@ -799,7 +799,7 @@ static const std::vector<OperatorGeneratorArgs> opGenArgs{
|
||||
#if defined BUILD_LITE_INTERPRETER || defined C10_MOBILE
|
||||
bool enabled = false;
|
||||
#else
|
||||
bool enabled = at::autocast::is_enabled();
|
||||
bool enabled = at::autocast::is_autocast_enabled(at::kCUDA);
|
||||
#endif
|
||||
push(stack, enabled);
|
||||
},
|
||||
@ -810,7 +810,7 @@ static const std::vector<OperatorGeneratorArgs> opGenArgs{
|
||||
#if defined BUILD_LITE_INTERPRETER || defined C10_MOBILE
|
||||
bool enabled = false;
|
||||
#else
|
||||
bool enabled = at::autocast::is_cpu_enabled();
|
||||
bool enabled = at::autocast::is_autocast_enabled(at::kCPU);
|
||||
#endif
|
||||
push(stack, enabled);
|
||||
},
|
||||
|
||||
Reference in New Issue
Block a user