diff --git a/c10/core/DispatchKey.cpp b/c10/core/DispatchKey.cpp
index 479cb134af39..e9e18dcbd588 100644
--- a/c10/core/DispatchKey.cpp
+++ b/c10/core/DispatchKey.cpp
@@ -140,6 +140,8 @@ const char* toString(DispatchKey t) {
     case DispatchKey::AutocastCPU:
       return "AutocastCPU";
+    case DispatchKey::AutocastMTIA:
+      return "AutocastMTIA";
     case DispatchKey::AutocastXPU:
       return "AutocastXPU";
     case DispatchKey::AutocastIPU:
       return "AutocastIPU";
@@ -296,6 +298,7 @@ c10::DispatchKey parseDispatchKey(const std::string& k) {
       {"AutogradNestedTensor", c10::DispatchKey::AutogradNestedTensor},
       {"Tracer", c10::DispatchKey::Tracer},
       {"AutocastCPU", c10::DispatchKey::AutocastCPU},
+      {"AutocastMTIA", c10::DispatchKey::AutocastMTIA},
       {"AutocastXPU", c10::DispatchKey::AutocastXPU},
       {"AutocastIPU", c10::DispatchKey::AutocastIPU},
       {"AutocastHPU", c10::DispatchKey::AutocastHPU},
diff --git a/c10/core/DispatchKey.h b/c10/core/DispatchKey.h
index fc5bdabd18fd..13c2b1ca2658 100644
--- a/c10/core/DispatchKey.h
+++ b/c10/core/DispatchKey.h
@@ -353,6 +353,7 @@ enum class DispatchKey : uint16_t {
   // Autocasting precedes VariableTypeId, to ensure casts are autograd-exposed
   // and inputs are saved for backward in the post-autocast type.
   AutocastCPU,
+  AutocastMTIA,
   AutocastXPU,
   AutocastIPU,
   AutocastHPU,
diff --git a/c10/core/DispatchKeySet.h b/c10/core/DispatchKeySet.h
index 289a88312c91..c702737f055e 100644
--- a/c10/core/DispatchKeySet.h
+++ b/c10/core/DispatchKeySet.h
@@ -662,6 +662,7 @@ constexpr DispatchKeySet autocast_dispatch_keyset = DispatchKeySet({
     DispatchKey::AutocastHPU,
     DispatchKey::AutocastXLA,
     DispatchKey::AutocastPrivateUse1,
+    DispatchKey::AutocastMTIA,
 });
 
 // See Note [TLS Initialization]
@@ -679,6 +680,7 @@ constexpr DispatchKeySet default_excluded_set = DispatchKeySet({
     DispatchKey::AutocastHPU,
     DispatchKey::AutocastXLA,
     DispatchKey::AutocastPrivateUse1,
+    DispatchKey::AutocastMTIA,
 });
 
 constexpr DispatchKeySet autograd_dispatch_keyset_with_ADInplaceOrView =
@@ -753,6 +755,7 @@ constexpr auto inplace_or_view_ks =
     DispatchKeySet(DispatchKey::ADInplaceOrView);
 constexpr auto autograd_cpu_ks = DispatchKeySet(DispatchKey::AutogradCPU);
 constexpr auto autograd_ipu_ks = DispatchKeySet(DispatchKey::AutogradIPU);
+constexpr auto autograd_mtia_ks = DispatchKeySet(DispatchKey::AutogradMTIA);
 constexpr auto autograd_xpu_ks = DispatchKeySet(DispatchKey::AutogradXPU);
 constexpr auto autograd_cuda_ks = DispatchKeySet(DispatchKey::AutogradCUDA);
 constexpr auto autograd_xla_ks = DispatchKeySet(DispatchKey::AutogradXLA);
@@ -830,6 +833,8 @@ inline DispatchKeySet getAutogradRelatedKeySetFromBackend(BackendComponent t) {
       return inplace_or_view_ks | autograd_cpu_ks;
     case BackendComponent::IPUBit:
       return inplace_or_view_ks | autograd_ipu_ks;
+    case BackendComponent::MTIABit:
+      return inplace_or_view_ks | autograd_mtia_ks;
     case BackendComponent::XPUBit:
       return inplace_or_view_ks | autograd_xpu_ks;
     case BackendComponent::CUDABit:
@@ -858,6 +863,7 @@ inline DispatchKeySet getAutogradRelatedKeySetFromBackend(BackendComponent t) {
 // Returns a DispatchKeySet of autocast related keys mapped to backend.
 inline DispatchKeySet getAutocastRelatedKeySetFromBackend(BackendComponent t) {
   constexpr auto autocast_cpu_ks = DispatchKeySet(DispatchKey::AutocastCPU);
+  constexpr auto autocast_mtia_ks = DispatchKeySet(DispatchKey::AutocastMTIA);
   constexpr auto autocast_xpu_ks = DispatchKeySet(DispatchKey::AutocastXPU);
   constexpr auto autocast_ipu_ks = DispatchKeySet(DispatchKey::AutocastIPU);
   constexpr auto autocast_hpu_ks = DispatchKeySet(DispatchKey::AutocastHPU);
@@ -869,6 +875,8 @@ inline DispatchKeySet getAutocastRelatedKeySetFromBackend(BackendComponent t) {
   switch (t) {
     case BackendComponent::CPUBit:
       return autocast_cpu_ks;
+    case BackendComponent::MTIABit:
+      return autocast_mtia_ks;
     case BackendComponent::XPUBit:
       return autocast_xpu_ks;
     case BackendComponent::IPUBit:
diff --git a/torch/amp/autocast_mode.py b/torch/amp/autocast_mode.py
index b7e278503ec6..7e6c61c2660d 100644
--- a/torch/amp/autocast_mode.py
+++ b/torch/amp/autocast_mode.py
@@ -30,7 +30,7 @@ def is_autocast_available(device_type: str) -> bool:
     Return a bool indicating if autocast is available on :attr:`device_type`.
 
     Args:
-        device_type(str): Device type to use. Possible values are: 'cuda', 'cpu', 'xpu' and so on.
+        device_type(str): Device type to use. Possible values are: 'cuda', 'cpu', 'mtia', 'xpu' and so on.
             The type is the same as the `type` attribute of a :class:`torch.device`.
             Thus, you may obtain the device type of a tensor using `Tensor.device.type`.
     """
@@ -202,7 +202,7 @@ class autocast:
     (see :ref:`Working with Multiple GPUs<amp-multigpu>`).
 
     Args:
-        device_type(str, required): Device type to use. Possible values are: 'cuda', 'cpu', 'xpu' and 'hpu'.
+        device_type(str, required): Device type to use. Possible values are: 'cuda', 'cpu', 'mtia', 'xpu', and 'hpu'.
             The type is the same as the `type` attribute of a :class:`torch.device`.
             Thus, you may obtain the device type of a tensor using `Tensor.device.type`.
         enabled(bool, optional): Whether autocasting should be enabled in the region.
@@ -282,6 +282,13 @@ class autocast:
                 )
                 warnings.warn(error_message)
                 enabled = False
+        elif self.device == "mtia":
+            supported_dtype = [torch.bfloat16, torch.float16]
+            if self.fast_dtype not in supported_dtype:
+                error_message = "In MTIA autocast, but the target dtype is not supported. Disabling autocast.\n"
+                error_message += "MTIA Autocast only supports dtypes of torch.bfloat16 and torch.float16 currently."
+                warnings.warn(error_message)
+                enabled = False
         elif self.device == "xpu":
             supported_dtype = [torch.bfloat16, torch.float16]
             if self.fast_dtype not in supported_dtype:
@@ -473,7 +480,7 @@ def custom_fwd(
     See the :ref:`example page<amp-custom-examples>` for more detail.
 
     Args:
-        device_type(str): Device type to use. 'cuda', 'cpu', 'xpu' and so on.
+        device_type(str): Device type to use. 'cuda', 'cpu', 'mtia', 'xpu' and so on.
             The type is the same as the `type` attribute of a :class:`torch.device`.
             Thus, you may obtain the device type of a tensor using `Tensor.device.type`.
         cast_inputs (:class:`torch.dtype` or None, optional, default=None): If not ``None``,
@@ -527,7 +534,7 @@ def custom_bwd(bwd=None, *, device_type: str):
     See the :ref:`example page<amp-custom-examples>` for more detail.
 
     Args:
-        device_type(str): Device type to use. 'cuda', 'cpu', 'xpu' and so on.
+        device_type(str): Device type to use. 'cuda', 'cpu', 'mtia', 'xpu' and so on.
             The type is the same as the `type` attribute of a :class:`torch.device`.
             Thus, you may obtain the device type of a tensor using `Tensor.device.type`.
     """
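For context, here is a minimal sketch of the user-facing path this patch enables. It is illustrative only: it assumes a build with MTIA support and an attached MTIA device, and the linear model and tensor shapes are placeholders.

```python
import torch

# Sketch only: assumes an MTIA-enabled build with a device attached.
if torch.amp.is_autocast_available("mtia"):
    model = torch.nn.Linear(8, 8).to("mtia")
    inp = torch.randn(4, 8, device="mtia")

    # Eligible ops inside the region run in the low-precision dtype.
    # Per the dtype check added to autocast_mode.py above, only
    # torch.bfloat16 and torch.float16 are accepted for "mtia".
    with torch.autocast(device_type="mtia", dtype=torch.bfloat16):
        out = model(inp)

    # An unsupported dtype warns and disables autocast for the region
    # instead of raising, matching the behavior of the other backends.
    with torch.autocast(device_type="mtia", dtype=torch.float64):
        out = model(inp)  # runs without autocasting
```

Note that `torch.autocast` itself is device-agnostic; the per-op casting policy for MTIA comes from the kernels registered under the new `AutocastMTIA` key, which is outside the scope of this diff.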