mirror of
				https://github.com/pytorch/pytorch.git
				synced 2025-11-04 08:00:58 +08:00 
			
		
		
		
	Add support for XPU accumulate type (#128579)
Provide an accumulate type interface specifically for XPU, similar to what was done for MPS. Pull Request resolved: https://github.com/pytorch/pytorch/pull/128579 Approved by: https://github.com/EikanWang, https://github.com/albanD
This commit is contained in:
		
				
					committed by
					
						
						PyTorch MergeBot
					
				
			
			
				
	
			
			
			
						parent
						
							76169cf691
						
					
				
				
					commit
					32995dec28
				
			@ -82,6 +82,7 @@ using acc_type = typename AccumulateType<T, is_cuda>::type;
 | 
				
			|||||||
    using type = acc_t;                         \
 | 
					    using type = acc_t;                         \
 | 
				
			||||||
  };
 | 
					  };
 | 
				
			||||||
#define MPS_ACC_TYPE(t, acc_t) ACC_TYPE(t, acc_t, c10::DeviceType::MPS)
 | 
					#define MPS_ACC_TYPE(t, acc_t) ACC_TYPE(t, acc_t, c10::DeviceType::MPS)
 | 
				
			||||||
 | 
					#define XPU_ACC_TYPE(t, acc_t) ACC_TYPE(t, acc_t, c10::DeviceType::XPU)
 | 
				
			||||||
#define CUDA_ACC_TYPE(t, acc_t) ACC_TYPE(t, acc_t, c10::DeviceType::CUDA)
 | 
					#define CUDA_ACC_TYPE(t, acc_t) ACC_TYPE(t, acc_t, c10::DeviceType::CUDA)
 | 
				
			||||||
#define CPU_ACC_TYPE(t, acc_t) ACC_TYPE(t, acc_t, c10::DeviceType::CPU)
 | 
					#define CPU_ACC_TYPE(t, acc_t) ACC_TYPE(t, acc_t, c10::DeviceType::CPU)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -104,6 +105,25 @@ MPS_ACC_TYPE(c10::complex<Half>, c10::complex<float>);
 | 
				
			|||||||
MPS_ACC_TYPE(c10::complex<float>, c10::complex<float>);
 | 
					MPS_ACC_TYPE(c10::complex<float>, c10::complex<float>);
 | 
				
			||||||
MPS_ACC_TYPE(c10::complex<double>, c10::complex<float>);
 | 
					MPS_ACC_TYPE(c10::complex<double>, c10::complex<float>);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					XPU_ACC_TYPE(BFloat16, float);
 | 
				
			||||||
 | 
					XPU_ACC_TYPE(Half, float);
 | 
				
			||||||
 | 
					XPU_ACC_TYPE(Float8_e5m2, float);
 | 
				
			||||||
 | 
					XPU_ACC_TYPE(Float8_e4m3fn, float);
 | 
				
			||||||
 | 
					XPU_ACC_TYPE(Float8_e5m2fnuz, float);
 | 
				
			||||||
 | 
					XPU_ACC_TYPE(Float8_e4m3fnuz, float);
 | 
				
			||||||
 | 
					XPU_ACC_TYPE(float, float);
 | 
				
			||||||
 | 
					XPU_ACC_TYPE(double, double);
 | 
				
			||||||
 | 
					XPU_ACC_TYPE(int8_t, int64_t);
 | 
				
			||||||
 | 
					XPU_ACC_TYPE(uint8_t, int64_t);
 | 
				
			||||||
 | 
					XPU_ACC_TYPE(char, int64_t);
 | 
				
			||||||
 | 
					XPU_ACC_TYPE(int16_t, int64_t);
 | 
				
			||||||
 | 
					XPU_ACC_TYPE(int32_t, int64_t);
 | 
				
			||||||
 | 
					XPU_ACC_TYPE(int64_t, int64_t);
 | 
				
			||||||
 | 
					XPU_ACC_TYPE(bool, bool);
 | 
				
			||||||
 | 
					XPU_ACC_TYPE(c10::complex<Half>, c10::complex<float>);
 | 
				
			||||||
 | 
					XPU_ACC_TYPE(c10::complex<float>, c10::complex<float>);
 | 
				
			||||||
 | 
					XPU_ACC_TYPE(c10::complex<double>, c10::complex<double>);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
#if defined(__CUDACC__) || defined(__HIPCC__)
 | 
					#if defined(__CUDACC__) || defined(__HIPCC__)
 | 
				
			||||||
CUDA_ACC_TYPE(half, float);
 | 
					CUDA_ACC_TYPE(half, float);
 | 
				
			||||||
#endif
 | 
					#endif
 | 
				
			||||||
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user