[autocast][pytorch] Support autocast for MTIA (#145627)

Summary: Add autocast support for MTIA: register the AutocastMTIA dispatch key (string mapping, enum entry, and the autocast/autograd-related dispatch key sets), and extend torch.amp.autocast so device_type="mtia" is accepted with torch.bfloat16 and torch.float16 as the supported dtypes.
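
For context, a minimal usage sketch of what this enables (hypothetical snippet, not part of the diff; it assumes a PyTorch build with MTIA support and an available MTIA device):

import torch

# Hypothetical example; requires an MTIA-enabled build and device.
x = torch.randn(8, 8, device="mtia")
w = torch.randn(8, 8, device="mtia")
# bfloat16 and float16 are the dtypes accepted by the new MTIA branch in torch/amp/autocast_mode.py below.
with torch.autocast(device_type="mtia", dtype=torch.bfloat16):
    y = x @ w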

Reviewed By: egienvalue

Differential Revision: D68572548

Pull Request resolved: https://github.com/pytorch/pytorch/pull/145627
Approved by: https://github.com/egienvalue
Authored by Simon Mahns on 2025-01-25 03:24:57 +00:00; committed by PyTorch MergeBot.
Parent: ef60de07a0
Commit: 6939a56e13
4 changed files with 23 additions and 4 deletions


@@ -140,6 +140,8 @@ const char* toString(DispatchKey t) {
case DispatchKey::AutocastCPU:
return "AutocastCPU";
+case DispatchKey::AutocastMTIA:
+return "AutocastMTIA";
case DispatchKey::AutocastXPU:
return "AutocastXPU";
case DispatchKey::AutocastIPU:
@@ -296,6 +298,7 @@ c10::DispatchKey parseDispatchKey(const std::string& k) {
{"AutogradNestedTensor", c10::DispatchKey::AutogradNestedTensor},
{"Tracer", c10::DispatchKey::Tracer},
{"AutocastCPU", c10::DispatchKey::AutocastCPU},
{"AutocastMTIA", c10::DispatchKey::AutocastMTIA},
{"AutocastXPU", c10::DispatchKey::AutocastXPU},
{"AutocastIPU", c10::DispatchKey::AutocastIPU},
{"AutocastHPU", c10::DispatchKey::AutocastHPU},


@@ -353,6 +353,7 @@ enum class DispatchKey : uint16_t {
// Autocasting precedes VariableTypeId, to ensure casts are autograd-exposed
// and inputs are saved for backward in the post-autocast type.
AutocastCPU,
+AutocastMTIA,
AutocastXPU,
AutocastIPU,
AutocastHPU,


@@ -662,6 +662,7 @@ constexpr DispatchKeySet autocast_dispatch_keyset = DispatchKeySet({
DispatchKey::AutocastHPU,
DispatchKey::AutocastXLA,
DispatchKey::AutocastPrivateUse1,
+DispatchKey::AutocastMTIA,
});
// See Note [TLS Initialization]
@@ -679,6 +680,7 @@ constexpr DispatchKeySet default_excluded_set = DispatchKeySet({
DispatchKey::AutocastHPU,
DispatchKey::AutocastXLA,
DispatchKey::AutocastPrivateUse1,
+DispatchKey::AutocastMTIA,
});
constexpr DispatchKeySet autograd_dispatch_keyset_with_ADInplaceOrView =
@@ -753,6 +755,7 @@ constexpr auto inplace_or_view_ks =
DispatchKeySet(DispatchKey::ADInplaceOrView);
constexpr auto autograd_cpu_ks = DispatchKeySet(DispatchKey::AutogradCPU);
constexpr auto autograd_ipu_ks = DispatchKeySet(DispatchKey::AutogradIPU);
+constexpr auto autograd_mtia_ks = DispatchKeySet(DispatchKey::AutogradMTIA);
constexpr auto autograd_xpu_ks = DispatchKeySet(DispatchKey::AutogradXPU);
constexpr auto autograd_cuda_ks = DispatchKeySet(DispatchKey::AutogradCUDA);
constexpr auto autograd_xla_ks = DispatchKeySet(DispatchKey::AutogradXLA);
@@ -830,6 +833,8 @@ inline DispatchKeySet getAutogradRelatedKeySetFromBackend(BackendComponent t) {
return inplace_or_view_ks | autograd_cpu_ks;
case BackendComponent::IPUBit:
return inplace_or_view_ks | autograd_ipu_ks;
+case BackendComponent::MTIABit:
+return inplace_or_view_ks | autograd_mtia_ks;
case BackendComponent::XPUBit:
return inplace_or_view_ks | autograd_xpu_ks;
case BackendComponent::CUDABit:
@@ -858,6 +863,7 @@ inline DispatchKeySet getAutogradRelatedKeySetFromBackend(BackendComponent t) {
// Returns a DispatchKeySet of autocast related keys mapped to backend.
inline DispatchKeySet getAutocastRelatedKeySetFromBackend(BackendComponent t) {
constexpr auto autocast_cpu_ks = DispatchKeySet(DispatchKey::AutocastCPU);
+constexpr auto autocast_mtia_ks = DispatchKeySet(DispatchKey::AutocastMTIA);
constexpr auto autocast_xpu_ks = DispatchKeySet(DispatchKey::AutocastXPU);
constexpr auto autocast_ipu_ks = DispatchKeySet(DispatchKey::AutocastIPU);
constexpr auto autocast_hpu_ks = DispatchKeySet(DispatchKey::AutocastHPU);
@@ -869,6 +875,8 @@ inline DispatchKeySet getAutocastRelatedKeySetFromBackend(BackendComponent t) {
switch (t) {
case BackendComponent::CPUBit:
return autocast_cpu_ks;
+case BackendComponent::MTIABit:
+return autocast_mtia_ks;
case BackendComponent::XPUBit:
return autocast_xpu_ks;
case BackendComponent::IPUBit:

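With AutocastMTIA wired into these dispatch key sets, the autocast thread-local state now covers MTIA tensors. A hedged introspection sketch (assumes the device-generic torch.is_autocast_enabled / torch.get_autocast_dtype APIs available in recent releases and an MTIA-enabled build):

import torch

# Hypothetical check of the autocast TLS state for MTIA inside the context manager.
with torch.autocast(device_type="mtia", dtype=torch.float16):
    print(torch.is_autocast_enabled("mtia"))   # expected: True
    print(torch.get_autocast_dtype("mtia"))    # expected: torch.float16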

@@ -30,7 +30,7 @@ def is_autocast_available(device_type: str) -> bool:
Return a bool indicating if autocast is available on :attr:`device_type`.
Args:
-device_type(str): Device type to use. Possible values are: 'cuda', 'cpu', 'xpu' and so on.
+device_type(str): Device type to use. Possible values are: 'cuda', 'cpu', 'mtia', 'xpu' and so on.
The type is the same as the `type` attribute of a :class:`torch.device`.
Thus, you may obtain the device type of a tensor using `Tensor.device.type`.
"""
@@ -202,7 +202,7 @@ class autocast:
(see :ref:`Working with Multiple GPUs<amp-multigpu>`).
Args:
-device_type(str, required): Device type to use. Possible values are: 'cuda', 'cpu', 'xpu' and 'hpu'.
+device_type(str, required): Device type to use. Possible values are: 'cuda', 'cpu', 'mtia', 'xpu', and 'hpu'.
The type is the same as the `type` attribute of a :class:`torch.device`.
Thus, you may obtain the device type of a tensor using `Tensor.device.type`.
enabled(bool, optional): Whether autocasting should be enabled in the region.
@@ -282,6 +282,13 @@ class autocast:
)
warnings.warn(error_message)
enabled = False
elif self.device == "mtia":
supported_dtype = [torch.bfloat16, torch.float16]
if self.fast_dtype not in supported_dtype:
error_message = "In MTIA autocast, but the target dtype is not supported. Disabling autocast.\n"
error_message += "MTIA Autocast only supports dtypes of torch.bfloat16 and torch.float16 currently."
warnings.warn(error_message)
enabled = False
elif self.device == "xpu":
supported_dtype = [torch.bfloat16, torch.float16]
if self.fast_dtype not in supported_dtype:
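
The new branch mirrors the existing CPU/XPU handling: requesting an unsupported dtype does not raise, it warns and leaves autocast disabled. A hedged sketch of that behavior (hypothetical snippet; assumes an MTIA-enabled build):

import torch
import warnings

# Hypothetical: request an unsupported dtype and expect the warning added above,
# with the region effectively running without autocast rather than raising.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    with torch.autocast(device_type="mtia", dtype=torch.float64):
        pass
print([str(w.message) for w in caught])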
@@ -473,7 +480,7 @@ def custom_fwd(
See the :ref:`example page<amp-custom-examples>` for more detail.
Args:
-device_type(str): Device type to use. 'cuda', 'cpu', 'xpu' and so on.
+device_type(str): Device type to use. 'cuda', 'cpu', 'mtia', 'xpu' and so on.
The type is the same as the `type` attribute of a :class:`torch.device`.
Thus, you may obtain the device type of a tensor using `Tensor.device.type`.
cast_inputs (:class:`torch.dtype` or None, optional, default=None): If not ``None``,
@@ -527,7 +534,7 @@ def custom_bwd(bwd=None, *, device_type: str):
See the :ref:`example page<amp-custom-examples>` for more detail.
Args:
-device_type(str): Device type to use. 'cuda', 'cpu', 'xpu' and so on.
+device_type(str): Device type to use. 'cuda', 'cpu', 'mtia', 'xpu' and so on.
The type is the same as the `type` attribute of a :class:`torch.device`.
Thus, you may obtain the device type of a tensor using `Tensor.device.type`.
"""