[MAIA] [Autocast] Enable autocast on MAIA device (#148511)

Fixes #148510.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/148511
Approved by: https://github.com/albanD
Wei-Sheng Chin
2025-03-18 03:46:19 +00:00
committed by PyTorch MergeBot
parent c43e35d6f7
commit bca75fe97a
11 changed files with 145 additions and 15 deletions
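At the Python level, this change wires the generic autocast machinery to the "maia" device type. A minimal usage sketch of what the change enables, assuming an out-of-tree MAIA backend (for example the test extension updated below) is built and loaded and provides the ops being called:

    import torch

    x = torch.empty((2, 4), dtype=torch.float, device="maia")
    w = torch.empty((4, 2), dtype=torch.float, device="maia")

    # MAIA's default autocast dtype is bfloat16 (see the dtype table below);
    # dtype=torch.float16 can be passed to override it.
    with torch.autocast(device_type="maia"):
        y = torch.ops.aten.matmul(x, w)
    print(y.dtype)  # torch.bfloat16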

View File

@@ -64,7 +64,7 @@ thread_local std::array<at::ScalarType, at::COMPILE_TIME_MAX_DEVICE_TYPES>
at::ScalarType::Undefined, // IDEEP.
at::kHalf, // AMD HIP
at::ScalarType::Undefined, // FPGA
at::ScalarType::Undefined, // ONNX Runtime / Microsoft
at::kBFloat16, // ONNX Runtime / Microsoft
at::kBFloat16, // XLA / TPU
at::ScalarType::Undefined, // Vulkan
at::ScalarType::Undefined, // Metal
@@ -500,6 +500,44 @@ TORCH_LIBRARY_IMPL(aten, AutocastMTIA, m) {
TORCH_FN((&at::autocast::binary_cross_entropy_banned)));
}
// MAIA
TORCH_LIBRARY_IMPL(_, AutocastMAIA, m) {
m.fallback(torch::CppFunction::makeFallthrough());
}
TORCH_LIBRARY_IMPL(aten, AutocastMAIA, m) {
// lower_precision_fp
#define _KERNEL_MAIA_LOW_PRECISION_FP(...) \
KERNEL_MAIA(__VA_ARGS__, lower_precision_fp)
AT_FORALL_LOWER_PRECISION_FP(_KERNEL_MAIA_LOW_PRECISION_FP)
// fp32
#define _KERNEL_MAIA_FP32(...) KERNEL_MAIA(__VA_ARGS__, fp32)
AT_FORALL_FP32(_KERNEL_MAIA_FP32)
// fp32_set_opt_dtype
#define _KERNEL_MAIA_FP32_SET_OPT_DTYPE(...) \
KERNEL_MAIA(__VA_ARGS__, fp32_set_opt_dtype)
AT_FORALL_FP32_SET_OPT_DTYPE(_KERNEL_MAIA_FP32_SET_OPT_DTYPE)
// fp32_append_dtype
// The fp32_append_dtype wrapper overrides implicit promotion behavior.
// norm does not implicitly promote, but be aware when adding new ops to this policy.
AT_FORALL_DIFFERENT_REDISPATCH_SIGNATURE(
KERNEL_DIFFERENT_REDISPATCH_SIGNATURE_MAIA)
// promote
#define _KERNEL_MAIA_PROMOTE(...) KERNEL_MAIA(__VA_ARGS__, promote)
AT_FORALL_PROMOTE(_KERNEL_MAIA_PROMOTE)
m.impl(TORCH_SELECTIVE_NAME("aten::binary_cross_entropy"),
TORCH_FN((&at::autocast::binary_cross_entropy_banned)));
}
// XPU
TORCH_LIBRARY_IMPL(_, AutocastXPU, m) {
m.fallback(torch::CppFunction::makeFallthrough());
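The MAIA block above registers the standard cast policies for AutocastMAIA: lower_precision_fp ops (such as matmul) have floating-point inputs cast down to the autocast dtype, fp32 ops keep running in float32, fp32_set_opt_dtype and fp32_append_dtype constrain the dtype argument, and promote ops run in the widest input dtype. As on CUDA and MTIA, binary_cross_entropy is banned inside an autocast region. A hedged sketch of the resulting behavior, assuming a MAIA backend that implements the ops involved:

    import torch

    pred = torch.empty((4,), dtype=torch.float, device="maia")
    target = torch.empty((4,), dtype=torch.float, device="maia")

    with torch.autocast(device_type="maia"):
        # binary_cross_entropy is banned here: calling it directly raises an error.
        # The usual workaround is binary_cross_entropy_with_logits or a locally
        # disabled sub-region, as sketched below.
        with torch.autocast(device_type="maia", enabled=False):
            loss = torch.nn.functional.binary_cross_entropy(pred, target)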

View File

@@ -126,10 +126,11 @@ TORCH_API inline void set_autocast_gpu_dtype(at::ScalarType dtype) {
// NOLINTNEXTLINE(misc-use-internal-linkage)
AT_FORALL_DEPRECATED_AUTOCAST_BACKENDS(DECLARE_DEPRECATED_AUTOCAST_APIS)
const std::array<at::DeviceType, 9> _AUTOCAST_SUPPORTED_DEVICES{
const std::array<at::DeviceType, 10> _AUTOCAST_SUPPORTED_DEVICES{
at::kCPU,
at::kCUDA,
at::kMTIA,
at::kMAIA,
at::kXPU,
at::kIPU,
at::kHPU,
@@ -150,6 +151,8 @@ inline bool is_autocast_eligible(
tensor.is_floating_point();
case c10::DeviceType::MTIA:
return tensor.is_mtia() && tensor.is_floating_point();
case c10::DeviceType::MAIA:
return tensor.is_maia() && tensor.is_floating_point();
case c10::DeviceType::XPU:
return tensor.is_xpu() && tensor.is_floating_point();
case c10::DeviceType::IPU:
@@ -177,6 +180,8 @@ inline DispatchKey get_autocast_dispatch_key_from_device_type(
return DispatchKey::AutocastCPU;
case c10::DeviceType::MTIA:
return DispatchKey::AutocastMTIA;
case c10::DeviceType::MAIA:
return DispatchKey::AutocastMAIA;
case c10::DeviceType::XPU:
return DispatchKey::AutocastXPU;
case c10::DeviceType::IPU:
@@ -748,6 +753,24 @@ copy pasted in from VariableTypeEverything.cpp with appropriate substitutions.
REDISPATCH_SIGNATURE, \
POLICY)
// KERNEL_MAIA/KERNEL_DIFFERENT_REDISPATCH_SIGNATURE_MAIA
// registration (OP, POLICY) or (OP, OVERLOAD, POLICY) for AutocastMAIA
#define KERNEL_MAIA(...) KERNEL(c10::DeviceType::MAIA, __VA_ARGS__)
#define KERNEL_DIFFERENT_REDISPATCH_SIGNATURE_MAIA( \
REDISPATCH_FUNC, \
REGISTER_NAME, \
REGISTER_SIGNATURE, \
REDISPATCH_SIGNATURE, \
POLICY) \
KERNEL_DIFFERENT_REDISPATCH_SIGNATURE( \
c10::DeviceType::MAIA, \
REDISPATCH_FUNC, \
REGISTER_NAME, \
REGISTER_SIGNATURE, \
REDISPATCH_SIGNATURE, \
POLICY)
// KERNEL_XPU/KERNEL_DIFFERENT_REDISPATCH_SIGNATURE_XPU
// registration (OP, POLICY) or (OP, OVERLOAD, POLICY) for AutocastXPU
#define KERNEL_XPU(...) KERNEL(c10::DeviceType::XPU, __VA_ARGS__)
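With MAIA added to _AUTOCAST_SUPPORTED_DEVICES and mapped to DispatchKey::AutocastMAIA, the generic autocast queries recognize the "maia" device string. A small sketch using the same calls the new tests below make:

    import torch

    assert torch._C._is_autocast_available("maia")
    assert torch.get_autocast_dtype("maia") == torch.bfloat16  # default from the dtype table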

View File

@@ -80,6 +80,10 @@ TORCH_LIBRARY_IMPL(_, AutogradMTIA, m) {
m.fallback(AUTOGRAD_FALLBACK);
}
TORCH_LIBRARY_IMPL(_, AutogradMAIA, m) {
m.fallback(AUTOGRAD_FALLBACK);
}
TORCH_LIBRARY_IMPL(_, AutogradXLA, m) {
m.fallback(AUTOGRAD_FALLBACK);
}

View File

@@ -76,7 +76,7 @@ inline Backend dispatchKeyToBackend(DispatchKey t) {
return Backend::VE;
} else if (t == DispatchKey::FPGA) {
return Backend::FPGA;
} else if (t == DispatchKey::MAIA) {
} else if (t == DispatchKey::MAIA || t == DispatchKey::AutogradMAIA) {
return Backend::MAIA;
} else if (t == DispatchKey::XLA || t == DispatchKey::AutogradXLA) {
return Backend::XLA;

View File

@@ -32,6 +32,8 @@ const char* toString(BackendComponent t) {
return "VEBit";
case BackendComponent::MTIABit:
return "MTIA";
case BackendComponent::MAIABit:
return "MAIA";
case BackendComponent::PrivateUse1Bit:
return "PrivateUse1Bit";
case BackendComponent::PrivateUse2Bit:
@@ -142,6 +144,8 @@ const char* toString(DispatchKey t) {
return "AutocastCPU";
case DispatchKey::AutocastMTIA:
return "AutocastMTIA";
case DispatchKey::AutocastMAIA:
return "AutocastMAIA";
case DispatchKey::AutocastXPU:
return "AutocastXPU";
case DispatchKey::AutocastIPU:
@@ -299,6 +303,7 @@ c10::DispatchKey parseDispatchKey(const std::string& k) {
{"Tracer", c10::DispatchKey::Tracer},
{"AutocastCPU", c10::DispatchKey::AutocastCPU},
{"AutocastMTIA", c10::DispatchKey::AutocastMTIA},
{"AutocastMAIA", c10::DispatchKey::AutocastMAIA},
{"AutocastXPU", c10::DispatchKey::AutocastXPU},
{"AutocastIPU", c10::DispatchKey::AutocastIPU},
{"AutocastHPU", c10::DispatchKey::AutocastHPU},

View File

@@ -45,6 +45,7 @@ namespace c10 {
_(VE, extra) \
_(Lazy, extra) \
_(MTIA, extra) \
_(MAIA, extra) \
_(PrivateUse1, extra) \
_(PrivateUse2, extra) \
_(PrivateUse3, extra) \
@@ -180,13 +181,6 @@ enum class DispatchKey : uint16_t {
FPGA, // Xilinx support lives out of tree at
// https://gitlab.com/pytorch-complex/vitis_kernels
// TODO: put this in BackendComponents
// MAIA backend lives out of tree
// - test/cpp_extensions/maia_extension.cpp
// - test/test_torch.py
// - aten/src/ATen/test/extension_backend_test.cpp
MAIA,
Vulkan, // TODO: put this in BackendComponents
Metal, // TODO: put this in BackendComponents
@@ -354,6 +348,7 @@ enum class DispatchKey : uint16_t {
// and inputs are saved for backward in the post-autocast type.
AutocastCPU,
AutocastMTIA,
AutocastMAIA,
AutocastXPU,
AutocastIPU,
AutocastHPU,

View File

@@ -131,6 +131,8 @@ DispatchKeySet getBackendKeySetFromAutograd(DispatchKey t) {
return DispatchKeySet(DispatchKey::IPU);
case DispatchKey::AutogradXPU:
return DispatchKeySet(DispatchKey::XPU);
case DispatchKey::AutogradMAIA:
return DispatchKeySet(DispatchKey::MAIA);
case DispatchKey::AutogradPrivateUse1:
return DispatchKeySet(DispatchKey::PrivateUse1);
case DispatchKey::AutogradPrivateUse2:

View File

@@ -663,6 +663,7 @@ constexpr DispatchKeySet autocast_dispatch_keyset = DispatchKeySet({
DispatchKey::AutocastXLA,
DispatchKey::AutocastPrivateUse1,
DispatchKey::AutocastMTIA,
DispatchKey::AutocastMAIA,
});
// See Note [TLS Initialization]
@@ -681,6 +682,7 @@ constexpr DispatchKeySet default_excluded_set = DispatchKeySet({
DispatchKey::AutocastXLA,
DispatchKey::AutocastPrivateUse1,
DispatchKey::AutocastMTIA,
DispatchKey::AutocastMAIA,
});
constexpr DispatchKeySet autograd_dispatch_keyset_with_ADInplaceOrView =
@@ -706,7 +708,6 @@ constexpr DispatchKeySet autogradother_backends =
// Technically, HIP will now redispatch to its own custom AutogradHIP
// slot in the runtime table.
{DispatchKey::FPGA,
DispatchKey::MAIA,
DispatchKey::Vulkan,
DispatchKey::Metal,
DispatchKey::CustomRNGKeyId,
@@ -756,6 +757,7 @@ constexpr auto inplace_or_view_ks =
constexpr auto autograd_cpu_ks = DispatchKeySet(DispatchKey::AutogradCPU);
constexpr auto autograd_ipu_ks = DispatchKeySet(DispatchKey::AutogradIPU);
constexpr auto autograd_mtia_ks = DispatchKeySet(DispatchKey::AutogradMTIA);
constexpr auto autograd_maia_ks = DispatchKeySet(DispatchKey::AutogradMAIA);
constexpr auto autograd_xpu_ks = DispatchKeySet(DispatchKey::AutogradXPU);
constexpr auto autograd_cuda_ks = DispatchKeySet(DispatchKey::AutogradCUDA);
constexpr auto autograd_xla_ks = DispatchKeySet(DispatchKey::AutogradXLA);
@@ -835,6 +837,8 @@ inline DispatchKeySet getAutogradRelatedKeySetFromBackend(BackendComponent t) {
return inplace_or_view_ks | autograd_ipu_ks;
case BackendComponent::MTIABit:
return inplace_or_view_ks | autograd_mtia_ks;
case BackendComponent::MAIABit:
return inplace_or_view_ks | autograd_maia_ks;
case BackendComponent::XPUBit:
return inplace_or_view_ks | autograd_xpu_ks;
case BackendComponent::CUDABit:
@@ -864,6 +868,7 @@ inline DispatchKeySet getAutogradRelatedKeySetFromBackend(BackendComponent t) {
inline DispatchKeySet getAutocastRelatedKeySetFromBackend(BackendComponent t) {
constexpr auto autocast_cpu_ks = DispatchKeySet(DispatchKey::AutocastCPU);
constexpr auto autocast_mtia_ks = DispatchKeySet(DispatchKey::AutocastMTIA);
constexpr auto autocast_maia_ks = DispatchKeySet(DispatchKey::AutocastMAIA);
constexpr auto autocast_xpu_ks = DispatchKeySet(DispatchKey::AutocastXPU);
constexpr auto autocast_ipu_ks = DispatchKeySet(DispatchKey::AutocastIPU);
constexpr auto autocast_hpu_ks = DispatchKeySet(DispatchKey::AutocastHPU);
@@ -877,6 +882,8 @@ inline DispatchKeySet getAutocastRelatedKeySetFromBackend(BackendComponent t) {
return autocast_cpu_ks;
case BackendComponent::MTIABit:
return autocast_mtia_ks;
case BackendComponent::MAIABit:
return autocast_maia_ks;
case BackendComponent::XPUBit:
return autocast_xpu_ks;
case BackendComponent::IPUBit:
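AutocastMAIA joins both autocast_dispatch_keyset and the default excluded set, so autocast dispatch for MAIA stays disabled until torch.autocast enables it for the current thread. A sketch of the observable effect:

    import torch

    assert not torch.is_autocast_enabled("maia")      # excluded by default
    with torch.autocast(device_type="maia"):
        assert torch.is_autocast_enabled("maia")
    assert not torch.is_autocast_enabled("maia")      # restored on exit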

View File

@@ -52,11 +52,28 @@ std::tuple<Tensor,Tensor,Tensor> fake_convolution_backward(
get_tensor(input.dtype(), {}));
}
at::Tensor maia_to_dtype_override(
const at::Tensor & self, at::ScalarType dtype, bool non_blocking,
bool copy, ::std::optional<at::MemoryFormat> memory_format
) {
return get_tensor(scalarTypeToTypeMeta(dtype), self.sizes());
}
at::Tensor maia_matmul_override(const at::Tensor & self, const at::Tensor & other) {
AT_ASSERT(self.dim() == 2);
AT_ASSERT(other.dim() == 2);
AT_ASSERT(self.dtype() == other.dtype());
AT_ASSERT(self.device() == other.device());
return get_tensor(self.dtype(), {self.size(0), other.size(1)});
}
TORCH_LIBRARY_IMPL(aten, MAIA, m) {
m.impl("empty.memory_format", empty_override);
m.impl("add.out", add_out_override);
m.impl("convolution_overrideable", fake_convolution);
m.impl("convolution_backward_overrideable", fake_convolution_backward);
m.impl("to.dtype", maia_to_dtype_override);
m.impl("matmul", maia_matmul_override);
}
// TODO: Extend this to exercise multi-device setting. In that case,
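The two new overrides are exactly what the autocast wrappers exercise: the cached-cast path converts inputs with Tensor.to (dispatching to to.dtype) before redispatching matmul to the MAIA backend kernel. A sketch of the difference this makes, assuming the test extension is built and loaded as in the test suite:

    import torch

    x = torch.empty((2, 4), dtype=torch.float, device="maia")
    w = torch.empty((4, 2), dtype=torch.float, device="maia")

    # Without autocast the extension's matmul override keeps the input dtype.
    assert torch.ops.aten.matmul(x, w).dtype == torch.float32

    with torch.autocast(device_type="maia"):
        # The autocast wrapper first casts x and w via to.dtype, then calls matmul,
        # so the result comes back in the autocast dtype.
        assert torch.ops.aten.matmul(x, w).dtype == torch.bfloat16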

View File

@@ -458,6 +458,38 @@ class TestMAIATensor(common.TestCase):
        self.assertEqual(maia_extension.get_test_int(), 3)
        self.assertEqual(grad[0].shape, input.shape)

    def test_autocast_apis_for_maia_device(self):
        # Default low-precision type in MAIA's autocast.
        fast_dtype = torch.get_autocast_dtype("maia")
        self.assertEqual(fast_dtype, torch.bfloat16)
        self.assertTrue(torch._C._is_autocast_available("maia"))

    @skipIfTorchDynamo(
        "dynamo cannot handle maia device. Output tensor may have wrong dtype."
    )
    def test_matmul_autocast_float16_precision(self):
        # Ensure we can change low precision dtype.
        x = torch.empty((2, 4), dtype=torch.float, device="maia")
        w = torch.empty((4, 2), dtype=torch.float, device="maia")
        with torch.autocast(device_type="maia", dtype=torch.float16):
            self.assertTrue(torch.is_autocast_enabled("maia"))
            y = torch.ops.aten.matmul(x, w)
            self.assertEqual(y.dtype, torch.float16)
            self.assertEqual(y.shape, (2, 2))

    @skipIfTorchDynamo(
        "dynamo cannot handle maia device. Output tensor may have wrong dtype."
    )
    def test_matmul_autocast_default_precision(self):
        # Use default lower precision dtype, bfloat16.
        x = torch.empty((2, 4), dtype=torch.float, device="maia")
        w = torch.empty((4, 2), dtype=torch.float, device="maia")
        with torch.autocast(device_type="maia"):
            self.assertTrue(torch.is_autocast_enabled("maia"))
            y = torch.ops.aten.matmul(x, w)
            self.assertEqual(y.dtype, torch.bfloat16)
            self.assertEqual(y.shape, (2, 2))


@torch.testing._internal.common_utils.markDynamoStrictTest
class TestRNGExtension(common.TestCase):

View File

@@ -30,7 +30,7 @@ def is_autocast_available(device_type: str) -> bool:
Return a bool indicating if autocast is available on :attr:`device_type`.
Args:
device_type(str): Device type to use. Possible values are: 'cuda', 'cpu', 'mtia', 'xpu' and so on.
device_type(str): Device type to use. Possible values are: 'cuda', 'cpu', 'mtia', 'maia', 'xpu', and so on.
The type is the same as the `type` attribute of a :class:`torch.device`.
Thus, you may obtain the device type of a tensor using `Tensor.device.type`.
"""
@@ -202,7 +202,7 @@ class autocast:
(see :ref:`Working with Multiple GPUs<amp-multigpu>`).
Args:
device_type(str, required): Device type to use. Possible values are: 'cuda', 'cpu', 'mtia', 'xpu', and 'hpu'.
device_type(str, required): Device type to use. Possible values are: 'cuda', 'cpu', 'mtia', 'maia', 'xpu', and 'hpu'.
The type is the same as the `type` attribute of a :class:`torch.device`.
Thus, you may obtain the device type of a tensor using `Tensor.device.type`.
enabled(bool, optional): Whether autocasting should be enabled in the region.
@@ -289,6 +289,13 @@ class autocast:
                error_message += "MTIA Autocast only supports dtypes of torch.bfloat16 and torch.float16 currently."
                warnings.warn(error_message)
                enabled = False
        elif self.device == "maia":
            supported_dtype = [torch.bfloat16, torch.float16]
            if self.fast_dtype not in supported_dtype:
                error_message = "In MAIA autocast, but the target dtype is not supported. Disabling autocast.\n"
                error_message += "MAIA Autocast only supports dtypes of torch.bfloat16 and torch.float16 currently."
                warnings.warn(error_message)
                enabled = False
        elif self.device == "xpu":
            supported_dtype = [torch.bfloat16, torch.float16]
            if self.fast_dtype not in supported_dtype:
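The new "maia" branch above mirrors the other backends: when an unsupported dtype is requested, autocast warns and enters with autocast disabled rather than raising. A sketch of that behavior:

    import warnings

    import torch

    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        with torch.autocast(device_type="maia", dtype=torch.float64):
            assert not torch.is_autocast_enabled("maia")   # autocast was disabled
    assert any("MAIA Autocast" in str(w.message) for w in caught)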
@@ -480,7 +487,7 @@ def custom_fwd(
See the :ref:`example page<amp-custom-examples>` for more detail.
Args:
device_type(str): Device type to use. 'cuda', 'cpu', 'mtia', 'xpu' and so on.
device_type(str): Device type to use. 'cuda', 'cpu', 'mtia', 'maia', 'xpu' and so on.
The type is the same as the `type` attribute of a :class:`torch.device`.
Thus, you may obtain the device type of a tensor using `Tensor.device.type`.
cast_inputs (:class:`torch.dtype` or None, optional, default=None): If not ``None``,
@@ -534,7 +541,7 @@ def custom_bwd(bwd=None, *, device_type: str):
See the :ref:`example page<amp-custom-examples>` for more detail.
Args:
device_type(str): Device type to use. 'cuda', 'cpu', 'mtia', 'xpu' and so on.
device_type(str): Device type to use. 'cuda', 'cpu', 'mtia', 'maia', 'xpu' and so on.
The type is the same as the `type` attribute of a :class:`torch.device`.
Thus, you may obtain the device type of a tensor using `Tensor.device.type`.
"""