diff --git a/aten/src/ATen/autocast_mode.cpp b/aten/src/ATen/autocast_mode.cpp
index 4fae147e2815..afd0a6b67674 100644
--- a/aten/src/ATen/autocast_mode.cpp
+++ b/aten/src/ATen/autocast_mode.cpp
@@ -64,7 +64,7 @@ thread_local std::array<at::ScalarType, at::COMPILE_TIME_MAX_DEVICE_TYPES>
         at::ScalarType::Undefined, // IDEEP.
         at::kHalf, // AMD HIP
         at::ScalarType::Undefined, // FPGA
-        at::ScalarType::Undefined, // ONNX Runtime / Microsoft
+        at::kBFloat16, // ONNX Runtime / Microsoft
         at::kBFloat16, // XLA / TPU
         at::ScalarType::Undefined, // Vulkan
         at::ScalarType::Undefined, // Metal
@@ -500,6 +500,44 @@ TORCH_LIBRARY_IMPL(aten, AutocastMTIA, m) {
          TORCH_FN((&at::autocast::binary_cross_entropy_banned)));
 }
 
+// MAIA
+TORCH_LIBRARY_IMPL(_, AutocastMAIA, m) {
+  m.fallback(torch::CppFunction::makeFallthrough());
+}
+
+TORCH_LIBRARY_IMPL(aten, AutocastMAIA, m) {
+  // lower_precision_fp
+#define _KERNEL_MAIA_LOW_PRECISION_FP(...) \
+  KERNEL_MAIA(__VA_ARGS__, lower_precision_fp)
+
+  AT_FORALL_LOWER_PRECISION_FP(_KERNEL_MAIA_LOW_PRECISION_FP)
+
+  // fp32
+#define _KERNEL_MAIA_FP32(...) KERNEL_MAIA(__VA_ARGS__, fp32)
+
+  AT_FORALL_FP32(_KERNEL_MAIA_FP32)
+
+  // fp32_set_opt_dtype
+#define _KERNEL_MAIA_FP32_SET_OPT_DTYPE(...) \
+  KERNEL_MAIA(__VA_ARGS__, fp32_set_opt_dtype)
+
+  AT_FORALL_FP32_SET_OPT_DTYPE(_KERNEL_MAIA_FP32_SET_OPT_DTYPE)
+
+  // fp32_append_dtype
+  // The fp32_append_dtype wrapper overrides implicit promotion behavior.
+  // norm does not implicitly promote, but be aware when adding new ops to this policy.
+  AT_FORALL_DIFFERENT_REDISPATCH_SIGNATURE(
+      KERNEL_DIFFERENT_REDISPATCH_SIGNATURE_MAIA)
+
+  // promote
+#define _KERNEL_MAIA_PROMOTE(...) KERNEL_MAIA(__VA_ARGS__, promote)
+
+  AT_FORALL_PROMOTE(_KERNEL_MAIA_PROMOTE)
+
+  m.impl(TORCH_SELECTIVE_NAME("aten::binary_cross_entropy"),
+         TORCH_FN((&at::autocast::binary_cross_entropy_banned)));
+}
+
 // XPU
 TORCH_LIBRARY_IMPL(_, AutocastXPU, m) {
   m.fallback(torch::CppFunction::makeFallthrough());
diff --git a/aten/src/ATen/autocast_mode.h b/aten/src/ATen/autocast_mode.h
index 551dff55e1de..56f5e2fe5511 100644
--- a/aten/src/ATen/autocast_mode.h
+++ b/aten/src/ATen/autocast_mode.h
@@ -126,10 +126,11 @@ TORCH_API inline void set_autocast_gpu_dtype(at::ScalarType dtype) {
 // NOLINTNEXTLINE(misc-use-internal-linkage)
 AT_FORALL_DEPRECATED_AUTOCAST_BACKENDS(DECLARE_DEPRECATED_AUTOCAST_APIS)
 
-const std::array<at::DeviceType, 9> _AUTOCAST_SUPPORTED_DEVICES{
+const std::array<at::DeviceType, 10> _AUTOCAST_SUPPORTED_DEVICES{
     at::kCPU,
     at::kCUDA,
     at::kMTIA,
+    at::kMAIA,
     at::kXPU,
     at::kIPU,
     at::kHPU,
@@ -150,6 +151,8 @@ inline bool is_autocast_eligible(
           tensor.is_floating_point();
     case c10::DeviceType::MTIA:
       return tensor.is_mtia() && tensor.is_floating_point();
+    case c10::DeviceType::MAIA:
+      return tensor.is_maia() && tensor.is_floating_point();
     case c10::DeviceType::XPU:
       return tensor.is_xpu() && tensor.is_floating_point();
     case c10::DeviceType::IPU:
@@ -177,6 +180,8 @@ inline DispatchKey get_autocast_dispatch_key_from_device_type(
       return DispatchKey::AutocastCPU;
     case c10::DeviceType::MTIA:
       return DispatchKey::AutocastMTIA;
+    case c10::DeviceType::MAIA:
+      return DispatchKey::AutocastMAIA;
     case c10::DeviceType::XPU:
       return DispatchKey::AutocastXPU;
     case c10::DeviceType::IPU:
@@ -748,6 +753,24 @@ copy pasted in from VariableTypeEverything.cpp with appropriate substitutions.
       REDISPATCH_SIGNATURE, \
       POLICY)
 
+// KERNEL_MAIA/KERNEL_DIFFERENT_REDISPATCH_SIGNATURE_MAIA
+// registration (OP, POLICY) or (OP, OVERLOAD, POLICY) for AutocastMAIA
+#define KERNEL_MAIA(...) KERNEL(c10::DeviceType::MAIA, __VA_ARGS__)
+
+#define KERNEL_DIFFERENT_REDISPATCH_SIGNATURE_MAIA( \
+    REDISPATCH_FUNC, \
+    REGISTER_NAME, \
+    REGISTER_SIGNATURE, \
+    REDISPATCH_SIGNATURE, \
+    POLICY) \
+  KERNEL_DIFFERENT_REDISPATCH_SIGNATURE( \
+      c10::DeviceType::MAIA, \
+      REDISPATCH_FUNC, \
+      REGISTER_NAME, \
+      REGISTER_SIGNATURE, \
+      REDISPATCH_SIGNATURE, \
+      POLICY)
+
 // KERNEL_XPU/KERNEL_DIFFERENT_REDISPATCH_SIGNATURE_XPU
 // registration (OP, POLICY) or (OP, OVERLOAD, POLICY) for AutocastXPU
 #define KERNEL_XPU(...) KERNEL(c10::DeviceType::XPU, __VA_ARGS__)
diff --git a/aten/src/ATen/core/VariableFallbackKernel.cpp b/aten/src/ATen/core/VariableFallbackKernel.cpp
index 390d9189190e..2ae1f5f8f0c9 100644
--- a/aten/src/ATen/core/VariableFallbackKernel.cpp
+++ b/aten/src/ATen/core/VariableFallbackKernel.cpp
@@ -80,6 +80,10 @@ TORCH_LIBRARY_IMPL(_, AutogradMTIA, m) {
   m.fallback(AUTOGRAD_FALLBACK);
 }
 
+TORCH_LIBRARY_IMPL(_, AutogradMAIA, m) {
+  m.fallback(AUTOGRAD_FALLBACK);
+}
+
 TORCH_LIBRARY_IMPL(_, AutogradXLA, m) {
   m.fallback(AUTOGRAD_FALLBACK);
 }
diff --git a/c10/core/Backend.h b/c10/core/Backend.h
index 8ecaa7be7377..409c837c5908 100644
--- a/c10/core/Backend.h
+++ b/c10/core/Backend.h
@@ -76,7 +76,7 @@ inline Backend dispatchKeyToBackend(DispatchKey t) {
     return Backend::VE;
   } else if (t == DispatchKey::FPGA) {
     return Backend::FPGA;
-  } else if (t == DispatchKey::MAIA) {
+  } else if (t == DispatchKey::MAIA || t == DispatchKey::AutogradMAIA) {
     return Backend::MAIA;
   } else if (t == DispatchKey::XLA || t == DispatchKey::AutogradXLA) {
     return Backend::XLA;
diff --git a/c10/core/DispatchKey.cpp b/c10/core/DispatchKey.cpp
index e9e18dcbd588..9d543db2e555 100644
--- a/c10/core/DispatchKey.cpp
+++ b/c10/core/DispatchKey.cpp
@@ -32,6 +32,8 @@ const char* toString(BackendComponent t) {
       return "VEBit";
     case BackendComponent::MTIABit:
       return "MTIA";
+    case BackendComponent::MAIABit:
+      return "MAIA";
     case BackendComponent::PrivateUse1Bit:
       return "PrivateUse1Bit";
     case BackendComponent::PrivateUse2Bit:
@@ -142,6 +144,8 @@ const char* toString(DispatchKey t) {
       return "AutocastCPU";
     case DispatchKey::AutocastMTIA:
       return "AutocastMTIA";
+    case DispatchKey::AutocastMAIA:
+      return "AutocastMAIA";
     case DispatchKey::AutocastXPU:
       return "AutocastXPU";
     case DispatchKey::AutocastIPU:
@@ -299,6 +303,7 @@ c10::DispatchKey parseDispatchKey(const std::string& k) {
       {"Tracer", c10::DispatchKey::Tracer},
       {"AutocastCPU", c10::DispatchKey::AutocastCPU},
       {"AutocastMTIA", c10::DispatchKey::AutocastMTIA},
+      {"AutocastMAIA", c10::DispatchKey::AutocastMAIA},
       {"AutocastXPU", c10::DispatchKey::AutocastXPU},
       {"AutocastIPU", c10::DispatchKey::AutocastIPU},
       {"AutocastHPU", c10::DispatchKey::AutocastHPU},
diff --git a/c10/core/DispatchKey.h b/c10/core/DispatchKey.h
index 13c2b1ca2658..30aad0aeb00a 100644
--- a/c10/core/DispatchKey.h
+++ b/c10/core/DispatchKey.h
@@ -45,6 +45,7 @@ namespace c10 {
   _(VE, extra) \
   _(Lazy, extra) \
   _(MTIA, extra) \
+  _(MAIA, extra) \
   _(PrivateUse1, extra) \
   _(PrivateUse2, extra) \
   _(PrivateUse3, extra) \
@@ -180,13 +181,6 @@ enum class DispatchKey : uint16_t {
   FPGA, // Xilinx support lives out of tree at
   // https://gitlab.com/pytorch-complex/vitis_kernels
 
-  // TODO: put this in BackendComponents
-  // MAIA backend lives out of tree
-  // - test/cpp_extensions/maia_extension.cpp
-  // - test/test_torch.py
-  // - aten/src/ATen/test/extension_backend_test.cpp
-  MAIA,
-
   Vulkan, // TODO: put this in BackendComponents
   Metal, // TODO: put this in BackendComponents
 
@@ -354,6 +348,7 @@ enum class DispatchKey : uint16_t {
   // and inputs are saved for backward in the post-autocast type.
   AutocastCPU,
   AutocastMTIA,
+  AutocastMAIA,
   AutocastXPU,
   AutocastIPU,
   AutocastHPU,
diff --git a/c10/core/DispatchKeySet.cpp b/c10/core/DispatchKeySet.cpp
index f8f0b755e17e..4cbd0cea8571 100644
--- a/c10/core/DispatchKeySet.cpp
+++ b/c10/core/DispatchKeySet.cpp
@@ -131,6 +131,8 @@ DispatchKeySet getBackendKeySetFromAutograd(DispatchKey t) {
       return DispatchKeySet(DispatchKey::IPU);
     case DispatchKey::AutogradXPU:
       return DispatchKeySet(DispatchKey::XPU);
+    case DispatchKey::AutogradMAIA:
+      return DispatchKeySet(DispatchKey::MAIA);
     case DispatchKey::AutogradPrivateUse1:
       return DispatchKeySet(DispatchKey::PrivateUse1);
     case DispatchKey::AutogradPrivateUse2:
diff --git a/c10/core/DispatchKeySet.h b/c10/core/DispatchKeySet.h
index c702737f055e..82a8d2486eb3 100644
--- a/c10/core/DispatchKeySet.h
+++ b/c10/core/DispatchKeySet.h
@@ -663,6 +663,7 @@ constexpr DispatchKeySet autocast_dispatch_keyset = DispatchKeySet({
     DispatchKey::AutocastXLA,
     DispatchKey::AutocastPrivateUse1,
     DispatchKey::AutocastMTIA,
+    DispatchKey::AutocastMAIA,
 });
 
 // See Note [TLS Initialization]
@@ -681,6 +682,7 @@ constexpr DispatchKeySet default_excluded_set = DispatchKeySet({
     DispatchKey::AutocastXLA,
     DispatchKey::AutocastPrivateUse1,
     DispatchKey::AutocastMTIA,
+    DispatchKey::AutocastMAIA,
 });
 
 constexpr DispatchKeySet autograd_dispatch_keyset_with_ADInplaceOrView =
@@ -706,7 +708,6 @@ constexpr DispatchKeySet autogradother_backends =
     // Technically, HIP will now redispatch to its own custom AutogradHIP
     // slot in the runtime table.
     {DispatchKey::FPGA,
-     DispatchKey::MAIA,
      DispatchKey::Vulkan,
      DispatchKey::Metal,
      DispatchKey::CustomRNGKeyId,
@@ -756,6 +757,7 @@ constexpr auto inplace_or_view_ks =
 constexpr auto autograd_cpu_ks = DispatchKeySet(DispatchKey::AutogradCPU);
 constexpr auto autograd_ipu_ks = DispatchKeySet(DispatchKey::AutogradIPU);
 constexpr auto autograd_mtia_ks = DispatchKeySet(DispatchKey::AutogradMTIA);
+constexpr auto autograd_maia_ks = DispatchKeySet(DispatchKey::AutogradMAIA);
 constexpr auto autograd_xpu_ks = DispatchKeySet(DispatchKey::AutogradXPU);
 constexpr auto autograd_cuda_ks = DispatchKeySet(DispatchKey::AutogradCUDA);
 constexpr auto autograd_xla_ks = DispatchKeySet(DispatchKey::AutogradXLA);
@@ -835,6 +837,8 @@ inline DispatchKeySet getAutogradRelatedKeySetFromBackend(BackendComponent t) {
       return inplace_or_view_ks | autograd_ipu_ks;
     case BackendComponent::MTIABit:
       return inplace_or_view_ks | autograd_mtia_ks;
+    case BackendComponent::MAIABit:
+      return inplace_or_view_ks | autograd_maia_ks;
     case BackendComponent::XPUBit:
       return inplace_or_view_ks | autograd_xpu_ks;
     case BackendComponent::CUDABit:
@@ -864,6 +868,7 @@
 inline DispatchKeySet getAutocastRelatedKeySetFromBackend(BackendComponent t) {
   constexpr auto autocast_cpu_ks = DispatchKeySet(DispatchKey::AutocastCPU);
   constexpr auto autocast_mtia_ks = DispatchKeySet(DispatchKey::AutocastMTIA);
+  constexpr auto autocast_maia_ks = DispatchKeySet(DispatchKey::AutocastMAIA);
   constexpr auto autocast_xpu_ks = DispatchKeySet(DispatchKey::AutocastXPU);
   constexpr auto autocast_ipu_ks = DispatchKeySet(DispatchKey::AutocastIPU);
   constexpr auto autocast_hpu_ks = DispatchKeySet(DispatchKey::AutocastHPU);
@@ -877,6 +882,8 @@ inline DispatchKeySet getAutocastRelatedKeySetFromBackend(BackendComponent t) {
       return autocast_cpu_ks;
     case BackendComponent::MTIABit:
       return autocast_mtia_ks;
+    case BackendComponent::MAIABit:
+      return autocast_maia_ks;
     case BackendComponent::XPUBit:
       return autocast_xpu_ks;
     case BackendComponent::IPUBit:
diff --git a/test/cpp_extensions/maia_extension.cpp b/test/cpp_extensions/maia_extension.cpp
index 2b8c001c0ab2..d6a4a78015e0 100644
--- a/test/cpp_extensions/maia_extension.cpp
+++ b/test/cpp_extensions/maia_extension.cpp
@@ -52,11 +52,28 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> fake_convolution_backward(
       get_tensor(input.dtype(), {}));
 }
 
+at::Tensor maia_to_dtype_override(
+    const at::Tensor& self, at::ScalarType dtype, bool non_blocking,
+    bool copy, ::std::optional<at::MemoryFormat> memory_format) {
+  return get_tensor(scalarTypeToTypeMeta(dtype), self.sizes());
+}
+
+at::Tensor maia_matmul_override(const at::Tensor& self, const at::Tensor& other) {
+  AT_ASSERT(self.dim() == 2);
+  AT_ASSERT(other.dim() == 2);
+  AT_ASSERT(self.dtype() == other.dtype());
+  AT_ASSERT(self.device() == other.device());
+  return get_tensor(self.dtype(), {self.size(0), other.size(1)});
+}
+
 TORCH_LIBRARY_IMPL(aten, MAIA, m) {
   m.impl("empty.memory_format", empty_override);
   m.impl("add.out", add_out_override);
   m.impl("convolution_overrideable", fake_convolution);
   m.impl("convolution_backward_overrideable", fake_convolution_backward);
+  m.impl("to.dtype", maia_to_dtype_override);
+  m.impl("matmul", maia_matmul_override);
 }
 
 // TODO: Extend this to exercise multi-device setting. In that case,
diff --git a/test/test_cpp_extensions_aot.py b/test/test_cpp_extensions_aot.py
index 36752a588d53..c9c3129adfad 100644
--- a/test/test_cpp_extensions_aot.py
+++ b/test/test_cpp_extensions_aot.py
@@ -458,6 +458,38 @@ class TestMAIATensor(common.TestCase):
         self.assertEqual(maia_extension.get_test_int(), 3)
         self.assertEqual(grad[0].shape, input.shape)
 
+    def test_autocast_apis_for_maia_device(self):
+        # Default low-precision dtype in MAIA autocast.
+        fast_dtype = torch.get_autocast_dtype("maia")
+        self.assertEqual(fast_dtype, torch.bfloat16)
+        self.assertTrue(torch._C._is_autocast_available("maia"))
+
+    @skipIfTorchDynamo(
+        "dynamo cannot handle maia device. Output tensor may have wrong dtype."
+    )
+    def test_matmul_autocast_float16_precision(self):
+        # Ensure the low-precision dtype can be changed.
+        x = torch.empty((2, 4), dtype=torch.float, device="maia")
+        w = torch.empty((4, 2), dtype=torch.float, device="maia")
+        with torch.autocast(device_type="maia", dtype=torch.float16):
+            self.assertTrue(torch.is_autocast_enabled("maia"))
+            y = torch.ops.aten.matmul(x, w)
+            self.assertEqual(y.dtype, torch.float16)
+            self.assertEqual(y.shape, (2, 2))
+
+    @skipIfTorchDynamo(
+        "dynamo cannot handle maia device. Output tensor may have wrong dtype."
+    )
+    def test_matmul_autocast_default_precision(self):
+        # Use the default low-precision dtype, bfloat16.
+        x = torch.empty((2, 4), dtype=torch.float, device="maia")
+        w = torch.empty((4, 2), dtype=torch.float, device="maia")
+        with torch.autocast(device_type="maia"):
+            self.assertTrue(torch.is_autocast_enabled("maia"))
+            y = torch.ops.aten.matmul(x, w)
+            self.assertEqual(y.dtype, torch.bfloat16)
+            self.assertEqual(y.shape, (2, 2))
+
 
 @torch.testing._internal.common_utils.markDynamoStrictTest
 class TestRNGExtension(common.TestCase):
diff --git a/torch/amp/autocast_mode.py b/torch/amp/autocast_mode.py
index 7e6c61c2660d..b5dff9eb6c44 100644
--- a/torch/amp/autocast_mode.py
+++ b/torch/amp/autocast_mode.py
@@ -30,7 +30,7 @@ def is_autocast_available(device_type: str) -> bool:
     Return a bool indicating if autocast is available on :attr:`device_type`.
 
     Args:
-        device_type(str): Device type to use. Possible values are: 'cuda', 'cpu', 'mtia', 'xpu' and so on.
+        device_type(str): Device type to use. Possible values are: 'cuda', 'cpu', 'mtia', 'maia', 'xpu', and so on.
             The type is the same as the `type` attribute of a :class:`torch.device`.
             Thus, you may obtain the device type of a tensor using `Tensor.device.type`.
     """
@@ -202,7 +202,7 @@ class autocast:
     (see :ref:`Working with Multiple GPUs`).
 
     Args:
-        device_type(str, required): Device type to use. Possible values are: 'cuda', 'cpu', 'mtia', 'xpu', and 'hpu'.
+        device_type(str, required): Device type to use. Possible values are: 'cuda', 'cpu', 'mtia', 'maia', 'xpu', and 'hpu'.
             The type is the same as the `type` attribute of a :class:`torch.device`.
             Thus, you may obtain the device type of a tensor using `Tensor.device.type`.
         enabled(bool, optional): Whether autocasting should be enabled in the region.
@@ -289,6 +289,13 @@ class autocast:
                 error_message += "MTIA Autocast only supports dtypes of torch.bfloat16 and torch.float16 currently."
                 warnings.warn(error_message)
                 enabled = False
+        elif self.device == "maia":
+            supported_dtype = [torch.bfloat16, torch.float16]
+            if self.fast_dtype not in supported_dtype:
+                error_message = "In MAIA autocast, but the target dtype is not supported. Disabling autocast.\n"
+                error_message += "MAIA Autocast only supports dtypes of torch.bfloat16 and torch.float16 currently."
+                warnings.warn(error_message)
+                enabled = False
         elif self.device == "xpu":
             supported_dtype = [torch.bfloat16, torch.float16]
             if self.fast_dtype not in supported_dtype:
@@ -480,7 +487,7 @@ def custom_fwd(
     See the :ref:`example page` for more detail.
 
     Args:
-        device_type(str): Device type to use. 'cuda', 'cpu', 'mtia', 'xpu' and so on.
+        device_type(str): Device type to use. 'cuda', 'cpu', 'mtia', 'maia', 'xpu' and so on.
            The type is the same as the `type` attribute of a :class:`torch.device`.
            Thus, you may obtain the device type of a tensor using `Tensor.device.type`.
        cast_inputs (:class:`torch.dtype` or None, optional, default=None): If not ``None``,
@@ -534,7 +541,7 @@ def custom_bwd(bwd=None, *, device_type: str):
     See the :ref:`example page` for more detail.
 
     Args:
-        device_type(str): Device type to use. 'cuda', 'cpu', 'mtia', 'xpu' and so on.
+        device_type(str): Device type to use. 'cuda', 'cpu', 'mtia', 'maia', 'xpu' and so on.
             The type is the same as the `type` attribute of a :class:`torch.device`.
             Thus, you may obtain the device type of a tensor using `Tensor.device.type`.
     """
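
Taken together, these changes register an AutocastMAIA dispatch key, route it through the standard autograd fallback, and make bfloat16 the default low-precision dtype for the MAIA ("ONNX Runtime / Microsoft") device. A minimal usage sketch mirroring the tests above, assuming the `maia_extension` test module from `test/cpp_extensions` has been built and is importable (a real MAIA backend would register far richer kernels than this fake one):

    import torch
    import maia_extension  # noqa: F401 -- registers the fake MAIA kernels

    # bfloat16 is the default low-precision dtype registered for MAIA.
    assert torch.get_autocast_dtype("maia") == torch.bfloat16

    x = torch.empty((2, 4), dtype=torch.float, device="maia")
    w = torch.empty((4, 2), dtype=torch.float, device="maia")

    # matmul is in the lower_precision_fp policy, so autocast redispatches
    # it at the region's low-precision dtype.
    with torch.autocast(device_type="maia"):
        y = torch.ops.aten.matmul(x, w)
    assert y.dtype == torch.bfloat16

    # The low-precision dtype can be overridden per autocast region.
    with torch.autocast(device_type="maia", dtype=torch.float16):
        y = torch.ops.aten.matmul(x, w)
    assert y.dtype == torch.float16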