Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 21:14:14 +08:00
[autocast][pytorch] Support autocast for MTIA (#145627)
Summary: Add autocast support to MTIA.

Reviewed By: egienvalue
Differential Revision: D68572548
Pull Request resolved: https://github.com/pytorch/pytorch/pull/145627
Approved by: https://github.com/egienvalue
Committed by: PyTorch MergeBot
Parent: ef60de07a0
Commit: 6939a56e13
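For orientation before the diff: once this lands, "mtia" becomes a valid autocast device type from Python. A minimal usage sketch, assuming a PyTorch build with MTIA support and an available MTIA device (net and inp are placeholder names):

import torch

# Placeholder model and input; assumes an MTIA-enabled build and device.
net = torch.nn.Linear(8, 8).to("mtia")
inp = torch.randn(4, 8, device="mtia")

# With this change, "mtia" is accepted as an autocast device_type.
with torch.autocast(device_type="mtia", dtype=torch.bfloat16):
    out = net(inp)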
c10/core/DispatchKey.cpp

@@ -140,6 +140,8 @@ const char* toString(DispatchKey t) {
     case DispatchKey::AutocastCPU:
       return "AutocastCPU";
+    case DispatchKey::AutocastMTIA:
+      return "AutocastMTIA";
     case DispatchKey::AutocastXPU:
       return "AutocastXPU";
     case DispatchKey::AutocastIPU:
@@ -296,6 +298,7 @@ c10::DispatchKey parseDispatchKey(const std::string& k) {
       {"AutogradNestedTensor", c10::DispatchKey::AutogradNestedTensor},
       {"Tracer", c10::DispatchKey::Tracer},
       {"AutocastCPU", c10::DispatchKey::AutocastCPU},
+      {"AutocastMTIA", c10::DispatchKey::AutocastMTIA},
       {"AutocastXPU", c10::DispatchKey::AutocastXPU},
       {"AutocastIPU", c10::DispatchKey::AutocastIPU},
       {"AutocastHPU", c10::DispatchKey::AutocastHPU},
c10/core/DispatchKey.h

@@ -353,6 +353,7 @@ enum class DispatchKey : uint16_t {
   // Autocasting precedes VariableTypeId, to ensure casts are autograd-exposed
   // and inputs are saved for backward in the post-autocast type.
   AutocastCPU,
+  AutocastMTIA,
   AutocastXPU,
   AutocastIPU,
   AutocastHPU,
c10/core/DispatchKeySet.h

@@ -662,6 +662,7 @@ constexpr DispatchKeySet autocast_dispatch_keyset = DispatchKeySet({
     DispatchKey::AutocastHPU,
     DispatchKey::AutocastXLA,
     DispatchKey::AutocastPrivateUse1,
+    DispatchKey::AutocastMTIA,
 });

 // See Note [TLS Initialization]
@@ -679,6 +680,7 @@ constexpr DispatchKeySet default_excluded_set = DispatchKeySet({
     DispatchKey::AutocastHPU,
     DispatchKey::AutocastXLA,
     DispatchKey::AutocastPrivateUse1,
+    DispatchKey::AutocastMTIA,
 });

 constexpr DispatchKeySet autograd_dispatch_keyset_with_ADInplaceOrView =
@@ -753,6 +755,7 @@ constexpr auto inplace_or_view_ks =
     DispatchKeySet(DispatchKey::ADInplaceOrView);
 constexpr auto autograd_cpu_ks = DispatchKeySet(DispatchKey::AutogradCPU);
 constexpr auto autograd_ipu_ks = DispatchKeySet(DispatchKey::AutogradIPU);
+constexpr auto autograd_mtia_ks = DispatchKeySet(DispatchKey::AutogradMTIA);
 constexpr auto autograd_xpu_ks = DispatchKeySet(DispatchKey::AutogradXPU);
 constexpr auto autograd_cuda_ks = DispatchKeySet(DispatchKey::AutogradCUDA);
 constexpr auto autograd_xla_ks = DispatchKeySet(DispatchKey::AutogradXLA);
@@ -830,6 +833,8 @@ inline DispatchKeySet getAutogradRelatedKeySetFromBackend(BackendComponent t) {
       return inplace_or_view_ks | autograd_cpu_ks;
     case BackendComponent::IPUBit:
       return inplace_or_view_ks | autograd_ipu_ks;
+    case BackendComponent::MTIABit:
+      return inplace_or_view_ks | autograd_mtia_ks;
     case BackendComponent::XPUBit:
       return inplace_or_view_ks | autograd_xpu_ks;
     case BackendComponent::CUDABit:
@@ -858,6 +863,7 @@ inline DispatchKeySet getAutogradRelatedKeySetFromBackend(BackendComponent t) {
 // Returns a DispatchKeySet of autocast related keys mapped to backend.
 inline DispatchKeySet getAutocastRelatedKeySetFromBackend(BackendComponent t) {
   constexpr auto autocast_cpu_ks = DispatchKeySet(DispatchKey::AutocastCPU);
+  constexpr auto autocast_mtia_ks = DispatchKeySet(DispatchKey::AutocastMTIA);
   constexpr auto autocast_xpu_ks = DispatchKeySet(DispatchKey::AutocastXPU);
   constexpr auto autocast_ipu_ks = DispatchKeySet(DispatchKey::AutocastIPU);
   constexpr auto autocast_hpu_ks = DispatchKeySet(DispatchKey::AutocastHPU);
@@ -869,6 +875,8 @@ inline DispatchKeySet getAutocastRelatedKeySetFromBackend(BackendComponent t) {
   switch (t) {
     case BackendComponent::CPUBit:
       return autocast_cpu_ks;
+    case BackendComponent::MTIABit:
+      return autocast_mtia_ks;
     case BackendComponent::XPUBit:
       return autocast_xpu_ks;
     case BackendComponent::IPUBit:
torch/amp/autocast_mode.py

@@ -30,7 +30,7 @@ def is_autocast_available(device_type: str) -> bool:
     Return a bool indicating if autocast is available on :attr:`device_type`.

     Args:
-        device_type(str): Device type to use. Possible values are: 'cuda', 'cpu', 'xpu' and so on.
+        device_type(str): Device type to use. Possible values are: 'cuda', 'cpu', 'mtia', 'xpu' and so on.
         The type is the same as the `type` attribute of a :class:`torch.device`.
         Thus, you may obtain the device type of a tensor using `Tensor.device.type`.
     """
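The docstring change above reflects that is_autocast_available now reports on MTIA as well. A small sketch of the check; whether it returns True depends on the running build:

import torch

# True only if this build registers an autocast implementation for "mtia".
if torch.amp.is_autocast_available("mtia"):
    print("MTIA autocast is available")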
@@ -202,7 +202,7 @@ class autocast:
     (see :ref:`Working with Multiple GPUs<amp-multigpu>`).

     Args:
-        device_type(str, required): Device type to use. Possible values are: 'cuda', 'cpu', 'xpu' and 'hpu'.
+        device_type(str, required): Device type to use. Possible values are: 'cuda', 'cpu', 'mtia', 'xpu', and 'hpu'.
         The type is the same as the `type` attribute of a :class:`torch.device`.
         Thus, you may obtain the device type of a tensor using `Tensor.device.type`.
     enabled(bool, optional): Whether autocasting should be enabled in the region.
@@ -282,6 +282,13 @@ class autocast:
                 )
                 warnings.warn(error_message)
                 enabled = False
+        elif self.device == "mtia":
+            supported_dtype = [torch.bfloat16, torch.float16]
+            if self.fast_dtype not in supported_dtype:
+                error_message = "In MTIA autocast, but the target dtype is not supported. Disabling autocast.\n"
+                error_message += "MTIA Autocast only supports dtypes of torch.bfloat16 and torch.float16 currently."
+                warnings.warn(error_message)
+                enabled = False
         elif self.device == "xpu":
             supported_dtype = [torch.bfloat16, torch.float16]
             if self.fast_dtype not in supported_dtype:
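The new MTIA branch mirrors the existing per-device dtype checks: an unsupported dtype does not raise; it warns and leaves the region running with autocast disabled. A sketch of that behavior, assuming an MTIA device is present:

import torch

# torch.float32 is not in MTIA's supported list above, so entering this
# context is expected to warn and run the region with autocast disabled.
with torch.autocast(device_type="mtia", dtype=torch.float32):
    pass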
@@ -473,7 +480,7 @@ def custom_fwd(
     See the :ref:`example page<amp-custom-examples>` for more detail.

     Args:
-        device_type(str): Device type to use. 'cuda', 'cpu', 'xpu' and so on.
+        device_type(str): Device type to use. 'cuda', 'cpu', 'mtia', 'xpu' and so on.
         The type is the same as the `type` attribute of a :class:`torch.device`.
         Thus, you may obtain the device type of a tensor using `Tensor.device.type`.
     cast_inputs (:class:`torch.dtype` or None, optional, default=None): If not ``None``,
@@ -527,7 +534,7 @@ def custom_bwd(bwd=None, *, device_type: str):
     See the :ref:`example page<amp-custom-examples>` for more detail.

     Args:
-        device_type(str): Device type to use. 'cuda', 'cpu', 'xpu' and so on.
+        device_type(str): Device type to use. 'cuda', 'cpu', 'mtia', 'xpu' and so on.
         The type is the same as the `type` attribute of a :class:`torch.device`.
         Thus, you may obtain the device type of a tensor using `Tensor.device.type`.
     """
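Since custom_fwd and custom_bwd now document 'mtia' as a valid device_type, here is a minimal sketch of a custom autograd function opting into MTIA autocast (hypothetical MyMM, following the pattern from the AMP custom-function examples):

import torch
from torch.amp import custom_bwd, custom_fwd

class MyMM(torch.autograd.Function):
    @staticmethod
    @custom_fwd(device_type="mtia", cast_inputs=torch.bfloat16)
    def forward(ctx, a, b):
        # Inputs arrive cast to bfloat16 when MTIA autocast is active.
        ctx.save_for_backward(a, b)
        return a.mm(b)

    @staticmethod
    @custom_bwd(device_type="mtia")
    def backward(ctx, grad):
        # Runs under the same autocast state as the forward pass.
        a, b = ctx.saved_tensors
        return grad.mm(b.t()), a.t().mm(grad)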