[MAIA] [Autocast] Enable autocast on MAIA device (#148511)
Fixes #148510.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/148511
Approved by: https://github.com/albanD
Committed by: PyTorch MergeBot
Parent: c43e35d6f7
Commit: bca75fe97a
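A minimal usage sketch (not part of the diff), assuming a build where the out-of-tree MAIA backend is loaded so tensors can be placed on device="maia"; it mirrors the tests added below:

import torch

x = torch.empty((2, 4), dtype=torch.float, device="maia")
w = torch.empty((4, 2), dtype=torch.float, device="maia")

# bfloat16 is the default lower-precision dtype registered for MAIA autocast.
with torch.autocast(device_type="maia"):
    y = torch.ops.aten.matmul(x, w)
print(y.dtype)  # torch.bfloat16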
@@ -64,7 +64,7 @@ thread_local std::array<at::ScalarType, at::COMPILE_TIME_MAX_DEVICE_TYPES>
         at::ScalarType::Undefined, // IDEEP.
         at::kHalf, // AMD HIP
         at::ScalarType::Undefined, // FPGA
-        at::ScalarType::Undefined, // ONNX Runtime / Microsoft
+        at::kBFloat16, // ONNX Runtime / Microsoft
         at::kBFloat16, // XLA / TPU
         at::ScalarType::Undefined, // Vulkan
         at::ScalarType::Undefined, // Metal
@@ -500,6 +500,44 @@ TORCH_LIBRARY_IMPL(aten, AutocastMTIA, m) {
          TORCH_FN((&at::autocast::binary_cross_entropy_banned)));
 }
 
+// MAIA
+TORCH_LIBRARY_IMPL(_, AutocastMAIA, m) {
+  m.fallback(torch::CppFunction::makeFallthrough());
+}
+
+TORCH_LIBRARY_IMPL(aten, AutocastMAIA, m) {
+  // lower_precision_fp
+#define _KERNEL_MAIA_LOW_PRECISION_FP(...) \
+  KERNEL_MAIA(__VA_ARGS__, lower_precision_fp)
+
+  AT_FORALL_LOWER_PRECISION_FP(_KERNEL_MAIA_LOW_PRECISION_FP)
+
+  // fp32
+#define _KERNEL_MAIA_FP32(...) KERNEL_MAIA(__VA_ARGS__, fp32)
+
+  AT_FORALL_FP32(_KERNEL_MAIA_FP32)
+
+  // fp32_set_opt_dtype
+#define _KERNEL_MAIA_FP32_SET_OPT_DTYPE(...) \
+  KERNEL_MAIA(__VA_ARGS__, fp32_set_opt_dtype)
+
+  AT_FORALL_FP32_SET_OPT_DTYPE(_KERNEL_MAIA_FP32_SET_OPT_DTYPE)
+
+  // fp32_append_dtype
+  // The fp32_append_dtype wrapper overrides implicit promotion behavior.
+  // norm does not implicitly promote, but be aware when adding new ops to this policy.
+  AT_FORALL_DIFFERENT_REDISPATCH_SIGNATURE(
+      KERNEL_DIFFERENT_REDISPATCH_SIGNATURE_MAIA)
+
+  // promote
+#define _KERNEL_MAIA_PROMOTE(...) KERNEL_MAIA(__VA_ARGS__, promote)
+
+  AT_FORALL_PROMOTE(_KERNEL_MAIA_PROMOTE)
+
+  m.impl(TORCH_SELECTIVE_NAME("aten::binary_cross_entropy"),
+         TORCH_FN((&at::autocast::binary_cross_entropy_banned)));
+}
+
 // XPU
 TORCH_LIBRARY_IMPL(_, AutocastXPU, m) {
   m.fallback(torch::CppFunction::makeFallthrough());
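Each casting policy above rewrites an op's inputs differently: lower_precision_fp casts down to the autocast dtype, fp32 forces float32, fp32_set_opt_dtype pins an optional dtype argument, fp32_append_dtype appends an explicit dtype on redispatch, and promote casts all inputs to the widest input type. A hedged sketch of the lower_precision_fp policy (not part of the diff), reusing this PR's test setup and requiring the out-of-tree MAIA extension:

import torch

x = torch.empty((2, 4), dtype=torch.float, device="maia")
w = torch.empty((4, 2), dtype=torch.float, device="maia")

# matmul is registered under lower_precision_fp, so inputs are cast down.
with torch.autocast(device_type="maia", dtype=torch.float16):
    assert torch.ops.aten.matmul(x, w).dtype == torch.float16

# Outside the autocast region the same call keeps the input dtype.
assert torch.ops.aten.matmul(x, w).dtype == torch.float32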
@@ -126,10 +126,11 @@ TORCH_API inline void set_autocast_gpu_dtype(at::ScalarType dtype) {
 // NOLINTNEXTLINE(misc-use-internal-linkage)
 AT_FORALL_DEPRECATED_AUTOCAST_BACKENDS(DECLARE_DEPRECATED_AUTOCAST_APIS)
 
-const std::array<at::DeviceType, 9> _AUTOCAST_SUPPORTED_DEVICES{
+const std::array<at::DeviceType, 10> _AUTOCAST_SUPPORTED_DEVICES{
     at::kCPU,
     at::kCUDA,
     at::kMTIA,
+    at::kMAIA,
     at::kXPU,
     at::kIPU,
     at::kHPU,
@@ -150,6 +151,8 @@ inline bool is_autocast_eligible(
          tensor.is_floating_point();
     case c10::DeviceType::MTIA:
       return tensor.is_mtia() && tensor.is_floating_point();
+    case c10::DeviceType::MAIA:
+      return tensor.is_maia() && tensor.is_floating_point();
     case c10::DeviceType::XPU:
       return tensor.is_xpu() && tensor.is_floating_point();
     case c10::DeviceType::IPU:
@@ -177,6 +180,8 @@ inline DispatchKey get_autocast_dispatch_key_from_device_type(
       return DispatchKey::AutocastCPU;
     case c10::DeviceType::MTIA:
       return DispatchKey::AutocastMTIA;
+    case c10::DeviceType::MAIA:
+      return DispatchKey::AutocastMAIA;
     case c10::DeviceType::XPU:
       return DispatchKey::AutocastXPU;
     case c10::DeviceType::IPU:
@@ -748,6 +753,24 @@ copy pasted in from VariableTypeEverything.cpp with appropriate substitutions.
       REDISPATCH_SIGNATURE,                       \
       POLICY)
 
+// KERNEL_MAIA/KERNEL_DIFFERENT_REDISPATCH_SIGNATURE_MAIA
+// registration (OP, POLICY) or (OP, OVERLOAD, POLICY) for AutocastMAIA
+#define KERNEL_MAIA(...) KERNEL(c10::DeviceType::MAIA, __VA_ARGS__)
+
+#define KERNEL_DIFFERENT_REDISPATCH_SIGNATURE_MAIA( \
+    REDISPATCH_FUNC,                                \
+    REGISTER_NAME,                                  \
+    REGISTER_SIGNATURE,                             \
+    REDISPATCH_SIGNATURE,                           \
+    POLICY)                                         \
+  KERNEL_DIFFERENT_REDISPATCH_SIGNATURE(            \
+      c10::DeviceType::MAIA,                        \
+      REDISPATCH_FUNC,                              \
+      REGISTER_NAME,                                \
+      REGISTER_SIGNATURE,                           \
+      REDISPATCH_SIGNATURE,                         \
+      POLICY)
+
 // KERNEL_XPU/KERNEL_DIFFERENT_REDISPATCH_SIGNATURE_XPU
 // registration (OP, POLICY) or (OP, OVERLOAD, POLICY) for AutocastXPU
 #define KERNEL_XPU(...) KERNEL(c10::DeviceType::XPU, __VA_ARGS__)
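These macros expand to dispatcher registrations on the AutocastMAIA key. One way to observe them from Python is the private introspection helper PyTorch's own tests use (a sketch, not part of the diff; assumes that helper remains available):

import torch

# True in builds containing this patch: an AutocastMAIA kernel now backs matmul.
print(torch._C._dispatch_has_kernel_for_dispatch_key("aten::matmul", "AutocastMAIA"))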
@@ -80,6 +80,10 @@ TORCH_LIBRARY_IMPL(_, AutogradMTIA, m) {
   m.fallback(AUTOGRAD_FALLBACK);
 }
 
+TORCH_LIBRARY_IMPL(_, AutogradMAIA, m) {
+  m.fallback(AUTOGRAD_FALLBACK);
+}
+
 TORCH_LIBRARY_IMPL(_, AutogradXLA, m) {
   m.fallback(AUTOGRAD_FALLBACK);
 }
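With the AutogradMAIA fallback registered, autograd bookkeeping runs for MAIA tensors without per-op derivative kernels. A hedged sketch (not part of the diff; requires the out-of-tree MAIA extension):

import torch

x = torch.empty((2, 4), device="maia", requires_grad=True)
w = torch.empty((4, 2), device="maia")
y = torch.ops.aten.matmul(x, w)
print(y.requires_grad)  # True: the graph was recorded via the fallback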
@@ -76,7 +76,7 @@ inline Backend dispatchKeyToBackend(DispatchKey t) {
     return Backend::VE;
   } else if (t == DispatchKey::FPGA) {
     return Backend::FPGA;
-  } else if (t == DispatchKey::MAIA) {
+  } else if (t == DispatchKey::MAIA || t == DispatchKey::AutogradMAIA) {
     return Backend::MAIA;
   } else if (t == DispatchKey::XLA || t == DispatchKey::AutogradXLA) {
     return Backend::XLA;
@@ -32,6 +32,8 @@ const char* toString(BackendComponent t) {
       return "VEBit";
     case BackendComponent::MTIABit:
       return "MTIA";
+    case BackendComponent::MAIABit:
+      return "MAIA";
     case BackendComponent::PrivateUse1Bit:
       return "PrivateUse1Bit";
     case BackendComponent::PrivateUse2Bit:
@@ -142,6 +144,8 @@ const char* toString(DispatchKey t) {
       return "AutocastCPU";
     case DispatchKey::AutocastMTIA:
       return "AutocastMTIA";
+    case DispatchKey::AutocastMAIA:
+      return "AutocastMAIA";
     case DispatchKey::AutocastXPU:
       return "AutocastXPU";
     case DispatchKey::AutocastIPU:
@@ -299,6 +303,7 @@ c10::DispatchKey parseDispatchKey(const std::string& k) {
       {"Tracer", c10::DispatchKey::Tracer},
       {"AutocastCPU", c10::DispatchKey::AutocastCPU},
       {"AutocastMTIA", c10::DispatchKey::AutocastMTIA},
+      {"AutocastMAIA", c10::DispatchKey::AutocastMAIA},
       {"AutocastXPU", c10::DispatchKey::AutocastXPU},
       {"AutocastIPU", c10::DispatchKey::AutocastIPU},
       {"AutocastHPU", c10::DispatchKey::AutocastHPU},
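"AutocastMAIA" now round-trips through toString/parseDispatchKey, which internal test utilities rely on. The key set of a MAIA tensor can be inspected from Python with a private helper (a sketch, not part of the diff; assumes the out-of-tree extension is loaded):

import torch

t = torch.empty(2, 2, device="maia")
# Prints a DispatchKeySet including MAIA and AutogradMAIA.
print(torch._C._dispatch_keys(t))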
@@ -45,6 +45,7 @@ namespace c10 {
   _(VE, extra)          \
   _(Lazy, extra)        \
   _(MTIA, extra)        \
+  _(MAIA, extra)        \
   _(PrivateUse1, extra) \
   _(PrivateUse2, extra) \
   _(PrivateUse3, extra) \
@@ -180,13 +181,6 @@ enum class DispatchKey : uint16_t {
   FPGA, // Xilinx support lives out of tree at
         // https://gitlab.com/pytorch-complex/vitis_kernels
 
-  // TODO: put this in BackendComponents
-  // MAIA backend lives out of tree
-  // - test/cpp_extensions/maia_extension.cpp
-  // - test/test_torch.py
-  // - aten/src/ATen/test/extension_backend_test.cpp
-  MAIA,
-
   Vulkan, // TODO: put this in BackendComponents
   Metal, // TODO: put this in BackendComponents
 
@@ -354,6 +348,7 @@ enum class DispatchKey : uint16_t {
   // and inputs are saved for backward in the post-autocast type.
   AutocastCPU,
   AutocastMTIA,
+  AutocastMAIA,
   AutocastXPU,
   AutocastIPU,
   AutocastHPU,
@@ -131,6 +131,8 @@ DispatchKeySet getBackendKeySetFromAutograd(DispatchKey t) {
       return DispatchKeySet(DispatchKey::IPU);
     case DispatchKey::AutogradXPU:
       return DispatchKeySet(DispatchKey::XPU);
+    case DispatchKey::AutogradMAIA:
+      return DispatchKeySet(DispatchKey::MAIA);
     case DispatchKey::AutogradPrivateUse1:
       return DispatchKeySet(DispatchKey::PrivateUse1);
     case DispatchKey::AutogradPrivateUse2:
@@ -663,6 +663,7 @@ constexpr DispatchKeySet autocast_dispatch_keyset = DispatchKeySet({
     DispatchKey::AutocastXLA,
     DispatchKey::AutocastPrivateUse1,
     DispatchKey::AutocastMTIA,
+    DispatchKey::AutocastMAIA,
 });
 
 // See Note [TLS Initialization]
@@ -681,6 +682,7 @@ constexpr DispatchKeySet default_excluded_set = DispatchKeySet({
     DispatchKey::AutocastXLA,
     DispatchKey::AutocastPrivateUse1,
     DispatchKey::AutocastMTIA,
+    DispatchKey::AutocastMAIA,
 });
 
 constexpr DispatchKeySet autograd_dispatch_keyset_with_ADInplaceOrView =
@@ -706,7 +708,6 @@ constexpr DispatchKeySet autogradother_backends =
     // Technically, HIP will now redispatch to its own custom AutogradHIP
    // slot in the runtime table.
     {DispatchKey::FPGA,
-     DispatchKey::MAIA,
      DispatchKey::Vulkan,
      DispatchKey::Metal,
      DispatchKey::CustomRNGKeyId,
@@ -756,6 +757,7 @@ constexpr auto inplace_or_view_ks =
 constexpr auto autograd_cpu_ks = DispatchKeySet(DispatchKey::AutogradCPU);
 constexpr auto autograd_ipu_ks = DispatchKeySet(DispatchKey::AutogradIPU);
 constexpr auto autograd_mtia_ks = DispatchKeySet(DispatchKey::AutogradMTIA);
+constexpr auto autograd_maia_ks = DispatchKeySet(DispatchKey::AutogradMAIA);
 constexpr auto autograd_xpu_ks = DispatchKeySet(DispatchKey::AutogradXPU);
 constexpr auto autograd_cuda_ks = DispatchKeySet(DispatchKey::AutogradCUDA);
 constexpr auto autograd_xla_ks = DispatchKeySet(DispatchKey::AutogradXLA);
@@ -835,6 +837,8 @@ inline DispatchKeySet getAutogradRelatedKeySetFromBackend(BackendComponent t) {
       return inplace_or_view_ks | autograd_ipu_ks;
     case BackendComponent::MTIABit:
       return inplace_or_view_ks | autograd_mtia_ks;
+    case BackendComponent::MAIABit:
+      return inplace_or_view_ks | autograd_maia_ks;
     case BackendComponent::XPUBit:
       return inplace_or_view_ks | autograd_xpu_ks;
     case BackendComponent::CUDABit:
@@ -864,6 +868,7 @@ inline DispatchKeySet getAutogradRelatedKeySetFromBackend(BackendComponent t) {
 inline DispatchKeySet getAutocastRelatedKeySetFromBackend(BackendComponent t) {
   constexpr auto autocast_cpu_ks = DispatchKeySet(DispatchKey::AutocastCPU);
   constexpr auto autocast_mtia_ks = DispatchKeySet(DispatchKey::AutocastMTIA);
+  constexpr auto autocast_maia_ks = DispatchKeySet(DispatchKey::AutocastMAIA);
   constexpr auto autocast_xpu_ks = DispatchKeySet(DispatchKey::AutocastXPU);
   constexpr auto autocast_ipu_ks = DispatchKeySet(DispatchKey::AutocastIPU);
   constexpr auto autocast_hpu_ks = DispatchKeySet(DispatchKey::AutocastHPU);
@@ -877,6 +882,8 @@ inline DispatchKeySet getAutocastRelatedKeySetFromBackend(BackendComponent t) {
       return autocast_cpu_ks;
     case BackendComponent::MTIABit:
       return autocast_mtia_ks;
+    case BackendComponent::MAIABit:
+      return autocast_maia_ks;
     case BackendComponent::XPUBit:
       return autocast_xpu_ks;
     case BackendComponent::IPUBit:
@@ -52,11 +52,28 @@ std::tuple<Tensor,Tensor,Tensor> fake_convolution_backward(
       get_tensor(input.dtype(), {}));
 }
 
+at::Tensor maia_to_dtype_override(
+    const at::Tensor & self, at::ScalarType dtype, bool non_blocking,
+    bool copy, ::std::optional<at::MemoryFormat> memory_format
+) {
+  return get_tensor(scalarTypeToTypeMeta(dtype), self.sizes());
+}
+
+at::Tensor maia_matmul_override(const at::Tensor & self, const at::Tensor & other) {
+  AT_ASSERT(self.dim() == 2);
+  AT_ASSERT(other.dim() == 2);
+  AT_ASSERT(self.dtype() == other.dtype());
+  AT_ASSERT(self.device() == other.device());
+  return get_tensor(self.dtype(), {self.size(0), other.size(1)});
+}
+
 TORCH_LIBRARY_IMPL(aten, MAIA, m) {
   m.impl("empty.memory_format", empty_override);
   m.impl("add.out", add_out_override);
   m.impl("convolution_overrideable", fake_convolution);
   m.impl("convolution_backward_overrideable", fake_convolution_backward);
+  m.impl("to.dtype", maia_to_dtype_override);
+  m.impl("matmul", maia_matmul_override);
 }
 
 // TODO: Extend this to exercise multi-device setting. In that case,
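A hedged sketch of exercising the new overrides (not part of the diff), assuming the extension is built and importable as maia_extension (the module name PyTorch's test suite uses; the registrations above take effect at import):

import torch
import maia_extension  # assumed module name, built from maia_extension.cpp

x = torch.empty(2, 4, device="maia")
w = torch.empty(4, 2, device="maia")

y = torch.ops.aten.matmul(x, w)   # routed to maia_matmul_override
print(y.shape)                    # torch.Size([2, 2]); no real kernel runs

z = x.to(torch.bfloat16)          # routed to maia_to_dtype_override
print(z.dtype)                    # torch.bfloat16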
@@ -458,6 +458,38 @@ class TestMAIATensor(common.TestCase):
         self.assertEqual(maia_extension.get_test_int(), 3)
         self.assertEqual(grad[0].shape, input.shape)
 
+    def test_autocast_apis_for_maia_device(self):
+        # Default low-precision type in MAIA's autocast.
+        fast_dtype = torch.get_autocast_dtype("maia")
+        self.assertEqual(fast_dtype, torch.bfloat16)
+        self.assertTrue(torch._C._is_autocast_available("maia"))
+
+    @skipIfTorchDynamo(
+        "dynamo cannot handle maia device. Output tensor may have wrong dtype."
+    )
+    def test_matmul_autocast_float16_precision(self):
+        # Ensure we can change low precision dtype.
+        x = torch.empty((2, 4), dtype=torch.float, device="maia")
+        w = torch.empty((4, 2), dtype=torch.float, device="maia")
+        with torch.autocast(device_type="maia", dtype=torch.float16):
+            self.assertTrue(torch.is_autocast_enabled("maia"))
+            y = torch.ops.aten.matmul(x, w)
+            self.assertEqual(y.dtype, torch.float16)
+            self.assertEqual(y.shape, (2, 2))
+
+    @skipIfTorchDynamo(
+        "dynamo cannot handle maia device. Output tensor may have wrong dtype."
+    )
+    def test_matmul_autocast_default_precision(self):
+        # Use default lower precision dtype, bfloat16.
+        x = torch.empty((2, 4), dtype=torch.float, device="maia")
+        w = torch.empty((4, 2), dtype=torch.float, device="maia")
+        with torch.autocast(device_type="maia"):
+            self.assertTrue(torch.is_autocast_enabled("maia"))
+            y = torch.ops.aten.matmul(x, w)
+            self.assertEqual(y.dtype, torch.bfloat16)
+            self.assertEqual(y.shape, (2, 2))
+
 
 @torch.testing._internal.common_utils.markDynamoStrictTest
 class TestRNGExtension(common.TestCase):
@@ -30,7 +30,7 @@ def is_autocast_available(device_type: str) -> bool:
     Return a bool indicating if autocast is available on :attr:`device_type`.
 
     Args:
-        device_type(str): Device type to use. Possible values are: 'cuda', 'cpu', 'mtia', 'xpu' and so on.
+        device_type(str): Device type to use. Possible values are: 'cuda', 'cpu', 'mtia', 'maia', 'xpu', and so on.
             The type is the same as the `type` attribute of a :class:`torch.device`.
             Thus, you may obtain the device type of a tensor using `Tensor.device.type`.
     """
@@ -202,7 +202,7 @@ class autocast:
     (see :ref:`Working with Multiple GPUs<amp-multigpu>`).
 
     Args:
-        device_type(str, required): Device type to use. Possible values are: 'cuda', 'cpu', 'mtia', 'xpu', and 'hpu'.
+        device_type(str, required): Device type to use. Possible values are: 'cuda', 'cpu', 'mtia', 'maia', 'xpu', and 'hpu'.
             The type is the same as the `type` attribute of a :class:`torch.device`.
             Thus, you may obtain the device type of a tensor using `Tensor.device.type`.
         enabled(bool, optional): Whether autocasting should be enabled in the region.
@@ -289,6 +289,13 @@ class autocast:
                 error_message += "MTIA Autocast only supports dtypes of torch.bfloat16 and torch.float16 currently."
                 warnings.warn(error_message)
                 enabled = False
+        elif self.device == "maia":
+            supported_dtype = [torch.bfloat16, torch.float16]
+            if self.fast_dtype not in supported_dtype:
+                error_message = "In MAIA autocast, but the target dtype is not supported. Disabling autocast.\n"
+                error_message += "MAIA Autocast only supports dtypes of torch.bfloat16 and torch.float16 currently."
+                warnings.warn(error_message)
+                enabled = False
         elif self.device == "xpu":
             supported_dtype = [torch.bfloat16, torch.float16]
             if self.fast_dtype not in supported_dtype:
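An unsupported dtype now warns and disables autocast for the region rather than raising. A minimal sketch (not part of the diff; runs on a stock build, since only the dtype check is exercised):

import warnings
import torch

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    with torch.autocast(device_type="maia", dtype=torch.float64):
        print(torch.is_autocast_enabled("maia"))  # False: autocast was disabled
print(caught[0].message)  # "In MAIA autocast, but the target dtype is not supported. ..."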
@@ -480,7 +487,7 @@ def custom_fwd(
     See the :ref:`example page<amp-custom-examples>` for more detail.
 
     Args:
-        device_type(str): Device type to use. 'cuda', 'cpu', 'mtia', 'xpu' and so on.
+        device_type(str): Device type to use. 'cuda', 'cpu', 'mtia', 'maia', 'xpu' and so on.
             The type is the same as the `type` attribute of a :class:`torch.device`.
             Thus, you may obtain the device type of a tensor using `Tensor.device.type`.
         cast_inputs (:class:`torch.dtype` or None, optional, default=None): If not ``None``,
|
||||
See the :ref:`example page<amp-custom-examples>` for more detail.
|
||||
|
||||
Args:
|
||||
device_type(str): Device type to use. 'cuda', 'cpu', 'mtia', 'xpu' and so on.
|
||||
device_type(str): Device type to use. 'cuda', 'cpu', 'mtia', 'maia', 'xpu' and so on.
|
||||
The type is the same as the `type` attribute of a :class:`torch.device`.
|
||||
Thus, you may obtain the device type of a tensor using `Tensor.device.type`.
|
||||
"""
|
||||