Add support for XPU accumulate type (#128579)
Provide an accumulate type interface specifically for XPU, similar to what was done for MPS.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/128579
Approved by: https://github.com/EikanWang, https://github.com/albanD
commit 32995dec28
parent 76169cf691
committed by PyTorch MergeBot
@@ -82,6 +82,7 @@ using acc_type = typename AccumulateType<T, is_cuda>::type;
     using type = acc_t; \
   };
 #define MPS_ACC_TYPE(t, acc_t) ACC_TYPE(t, acc_t, c10::DeviceType::MPS)
+#define XPU_ACC_TYPE(t, acc_t) ACC_TYPE(t, acc_t, c10::DeviceType::XPU)
 #define CUDA_ACC_TYPE(t, acc_t) ACC_TYPE(t, acc_t, c10::DeviceType::CUDA)
 #define CPU_ACC_TYPE(t, acc_t) ACC_TYPE(t, acc_t, c10::DeviceType::CPU)
 
@@ -104,6 +105,25 @@ MPS_ACC_TYPE(c10::complex<Half>, c10::complex<float>);
 MPS_ACC_TYPE(c10::complex<float>, c10::complex<float>);
 MPS_ACC_TYPE(c10::complex<double>, c10::complex<float>);
 
+XPU_ACC_TYPE(BFloat16, float);
+XPU_ACC_TYPE(Half, float);
+XPU_ACC_TYPE(Float8_e5m2, float);
+XPU_ACC_TYPE(Float8_e4m3fn, float);
+XPU_ACC_TYPE(Float8_e5m2fnuz, float);
+XPU_ACC_TYPE(Float8_e4m3fnuz, float);
+XPU_ACC_TYPE(float, float);
+XPU_ACC_TYPE(double, double);
+XPU_ACC_TYPE(int8_t, int64_t);
+XPU_ACC_TYPE(uint8_t, int64_t);
+XPU_ACC_TYPE(char, int64_t);
+XPU_ACC_TYPE(int16_t, int64_t);
+XPU_ACC_TYPE(int32_t, int64_t);
+XPU_ACC_TYPE(int64_t, int64_t);
+XPU_ACC_TYPE(bool, bool);
+XPU_ACC_TYPE(c10::complex<Half>, c10::complex<float>);
+XPU_ACC_TYPE(c10::complex<float>, c10::complex<float>);
+XPU_ACC_TYPE(c10::complex<double>, c10::complex<double>);
+
 #if defined(__CUDACC__) || defined(__HIPCC__)
 CUDA_ACC_TYPE(half, float);
 #endif
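For context, each XPU_ACC_TYPE(t, acc_t) line above specializes the per-device accumulate-type trait for c10::DeviceType::XPU, telling XPU kernels which dtype to accumulate in for a given source dtype (e.g. Half accumulates in float, narrow integers in int64_t). A minimal sketch of how such a mapping can be queried, assuming the at::acc_type_device alias that AccumulateType.h exposes alongside at::acc_type; this usage example is not part of the commit:

#include <type_traits>
#include <ATen/AccumulateType.h>

// Sketch only: assumes at::acc_type_device<T, DeviceType> resolves through the
// AccumulateTypeDevice specializations that the XPU_ACC_TYPE lines add.
static_assert(
    std::is_same_v<at::acc_type_device<at::Half, c10::DeviceType::XPU>, float>,
    "Half accumulates in float on XPU");
static_assert(
    std::is_same_v<at::acc_type_device<int8_t, c10::DeviceType::XPU>, int64_t>,
    "narrow integer types accumulate in int64_t on XPU");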