From 32995dec28190730baf1b80eb5ffe446cd62ca04 Mon Sep 17 00:00:00 2001
From: Yutao Xu
Date: Wed, 17 Jul 2024 14:33:53 +0000
Subject: [PATCH] Add support for XPU accumulate type (#128579)

Provide an accumulate type interface specifically for XPU, similar to
what was done for MPS.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/128579
Approved by: https://github.com/EikanWang, https://github.com/albanD
---
 aten/src/ATen/AccumulateType.h | 20 ++++++++++++++++++++
 1 file changed, 20 insertions(+)

diff --git a/aten/src/ATen/AccumulateType.h b/aten/src/ATen/AccumulateType.h
index 0275ef099b03..b1f120e48176 100644
--- a/aten/src/ATen/AccumulateType.h
+++ b/aten/src/ATen/AccumulateType.h
@@ -82,6 +82,7 @@ using acc_type = typename AccumulateType<T, is_cuda>::type;
     using type = acc_t;                         \
   };
 #define MPS_ACC_TYPE(t, acc_t) ACC_TYPE(t, acc_t, c10::DeviceType::MPS)
+#define XPU_ACC_TYPE(t, acc_t) ACC_TYPE(t, acc_t, c10::DeviceType::XPU)
 #define CUDA_ACC_TYPE(t, acc_t) ACC_TYPE(t, acc_t, c10::DeviceType::CUDA)
 #define CPU_ACC_TYPE(t, acc_t) ACC_TYPE(t, acc_t, c10::DeviceType::CPU)
 
@@ -104,6 +105,25 @@ MPS_ACC_TYPE(c10::complex<Half>, c10::complex<float>);
 MPS_ACC_TYPE(c10::complex<float>, c10::complex<float>);
 MPS_ACC_TYPE(c10::complex<double>, c10::complex<double>);
 
+XPU_ACC_TYPE(BFloat16, float);
+XPU_ACC_TYPE(Half, float);
+XPU_ACC_TYPE(Float8_e5m2, float);
+XPU_ACC_TYPE(Float8_e4m3fn, float);
+XPU_ACC_TYPE(Float8_e5m2fnuz, float);
+XPU_ACC_TYPE(Float8_e4m3fnuz, float);
+XPU_ACC_TYPE(float, float);
+XPU_ACC_TYPE(double, double);
+XPU_ACC_TYPE(int8_t, int64_t);
+XPU_ACC_TYPE(uint8_t, int64_t);
+XPU_ACC_TYPE(char, int64_t);
+XPU_ACC_TYPE(int16_t, int64_t);
+XPU_ACC_TYPE(int32_t, int64_t);
+XPU_ACC_TYPE(int64_t, int64_t);
+XPU_ACC_TYPE(bool, bool);
+XPU_ACC_TYPE(c10::complex<Half>, c10::complex<float>);
+XPU_ACC_TYPE(c10::complex<float>, c10::complex<float>);
+XPU_ACC_TYPE(c10::complex<double>, c10::complex<double>);
+
 #if defined(__CUDACC__) || defined(__HIPCC__)
 CUDA_ACC_TYPE(half, float);
 #endif
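
For context, the mappings this patch adds can be exercised at compile time
through the acc_type_device alias defined earlier in AccumulateType.h. A
minimal sketch, not part of the patch, assuming the at::acc_type_device alias
and a build with ATen headers on the include path; the example types are
chosen for illustration:

// illustrative_check.cpp -- hypothetical file, not from the patch
#include <ATen/AccumulateType.h>
#include <type_traits>

// Reduced-precision floats accumulate in float on XPU, matching the
// XPU_ACC_TYPE(Half, float) and XPU_ACC_TYPE(BFloat16, float) entries above.
static_assert(std::is_same_v<
    at::acc_type_device<at::Half, c10::DeviceType::XPU>, float>);
static_assert(std::is_same_v<
    at::acc_type_device<at::BFloat16, c10::DeviceType::XPU>, float>);

// Narrow integer types widen to int64_t for accumulation, per the
// XPU_ACC_TYPE(int32_t, int64_t) entry, which guards reductions against
// overflow in the intermediate sum.
static_assert(std::is_same_v<
    at::acc_type_device<int32_t, c10::DeviceType::XPU>, int64_t>);

Because each XPU_ACC_TYPE line expands to an AccumulateTypeDevice
specialization for c10::DeviceType::XPU, these checks resolve entirely at
compile time; no XPU device is needed to build them.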