Add support for XPU accumulate type (#128579)

Provide an accumulate type interface specifically for XPU, similar to what was done for MPS.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/128579
Approved by: https://github.com/EikanWang, https://github.com/albanD
This commit is contained in:
Yutao Xu
2024-07-17 14:33:53 +00:00
committed by PyTorch MergeBot
parent 76169cf691
commit 32995dec28

View File

@ -82,6 +82,7 @@ using acc_type = typename AccumulateType<T, is_cuda>::type;
using type = acc_t; \ using type = acc_t; \
}; };
#define MPS_ACC_TYPE(t, acc_t) ACC_TYPE(t, acc_t, c10::DeviceType::MPS) #define MPS_ACC_TYPE(t, acc_t) ACC_TYPE(t, acc_t, c10::DeviceType::MPS)
#define XPU_ACC_TYPE(t, acc_t) ACC_TYPE(t, acc_t, c10::DeviceType::XPU)
#define CUDA_ACC_TYPE(t, acc_t) ACC_TYPE(t, acc_t, c10::DeviceType::CUDA) #define CUDA_ACC_TYPE(t, acc_t) ACC_TYPE(t, acc_t, c10::DeviceType::CUDA)
#define CPU_ACC_TYPE(t, acc_t) ACC_TYPE(t, acc_t, c10::DeviceType::CPU) #define CPU_ACC_TYPE(t, acc_t) ACC_TYPE(t, acc_t, c10::DeviceType::CPU)
@ -104,6 +105,25 @@ MPS_ACC_TYPE(c10::complex<Half>, c10::complex<float>);
MPS_ACC_TYPE(c10::complex<float>, c10::complex<float>); MPS_ACC_TYPE(c10::complex<float>, c10::complex<float>);
MPS_ACC_TYPE(c10::complex<double>, c10::complex<float>); MPS_ACC_TYPE(c10::complex<double>, c10::complex<float>);
XPU_ACC_TYPE(BFloat16, float);
XPU_ACC_TYPE(Half, float);
XPU_ACC_TYPE(Float8_e5m2, float);
XPU_ACC_TYPE(Float8_e4m3fn, float);
XPU_ACC_TYPE(Float8_e5m2fnuz, float);
XPU_ACC_TYPE(Float8_e4m3fnuz, float);
XPU_ACC_TYPE(float, float);
XPU_ACC_TYPE(double, double);
XPU_ACC_TYPE(int8_t, int64_t);
XPU_ACC_TYPE(uint8_t, int64_t);
XPU_ACC_TYPE(char, int64_t);
XPU_ACC_TYPE(int16_t, int64_t);
XPU_ACC_TYPE(int32_t, int64_t);
XPU_ACC_TYPE(int64_t, int64_t);
XPU_ACC_TYPE(bool, bool);
XPU_ACC_TYPE(c10::complex<Half>, c10::complex<float>);
XPU_ACC_TYPE(c10::complex<float>, c10::complex<float>);
XPU_ACC_TYPE(c10::complex<double>, c10::complex<double>);
#if defined(__CUDACC__) || defined(__HIPCC__) #if defined(__CUDACC__) || defined(__HIPCC__)
CUDA_ACC_TYPE(half, float); CUDA_ACC_TYPE(half, float);
#endif #endif