Add cpu accelerator fp16 dtype support (#7207)

Add cpu accelerator fp16 dtype support

---------

Signed-off-by: Lai, Yejing <yejing.lai@intel.com>
Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>
Author: Yejing Lai
Date: 2025-04-22 03:21:37 +08:00
Committed by: GitHub
Parent: 9c9d32c2ca
Commit: d79bd930d6


@@ -229,10 +229,17 @@ class CPU_Accelerator(DeepSpeedAccelerator):
         return True
 
     def is_fp16_supported(self):
-        return False
+        try:
+            if torch.ops.mkldnn._is_mkldnn_fp16_supported():
+                return True
+        except:
+            return False
 
     def supported_dtypes(self):
-        return [torch.float, torch.bfloat16]
+        supported_dtypes = [torch.float, torch.bfloat16]
+        if self.is_fp16_supported():
+            supported_dtypes.append(torch.float16)
+        return supported_dtypes
 
     # Graph operations
     def create_graph(self):
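
For context, a minimal usage sketch (not part of the commit) of how downstream code might query these capabilities through DeepSpeed's public get_accelerator() entry point; the environment assumption (a machine where DeepSpeed resolves to the CPU accelerator) and the dtype-selection logic at the end are illustrative only:

    # Minimal sketch: querying dtype support via the accelerator abstraction.
    # Assumes DeepSpeed is installed and resolves to the CPU accelerator
    # (e.g. a host without GPUs); other accelerators report differently.
    import torch
    from deepspeed.accelerator import get_accelerator

    accel = get_accelerator()

    # With this change, fp16 is reported as supported only when the oneDNN
    # (mkldnn) backend exposes fp16 kernels on the host CPU; otherwise the
    # probe returns a falsy value and the dtype list is left unchanged.
    print("fp16 supported:", accel.is_fp16_supported())

    # supported_dtypes() now appends torch.float16 conditionally, so callers
    # that choose a training dtype from this list pick up fp16 automatically
    # on CPUs where oneDNN supports it.
    dtypes = accel.supported_dtypes()
    train_dtype = torch.float16 if torch.float16 in dtypes else torch.bfloat16
    print("chosen dtype:", train_dtype)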