diff --git a/accelerator/cpu_accelerator.py b/accelerator/cpu_accelerator.py index de711f731..4b3d89e6c 100644 --- a/accelerator/cpu_accelerator.py +++ b/accelerator/cpu_accelerator.py @@ -229,10 +229,17 @@ class CPU_Accelerator(DeepSpeedAccelerator): return True def is_fp16_supported(self): - return False + try: + if torch.ops.mkldnn._is_mkldnn_fp16_supported(): + return True + except: + return False def supported_dtypes(self): - return [torch.float, torch.bfloat16] + supported_dtypes = [torch.float, torch.bfloat16] + if self.is_fp16_supported(): + supported_dtypes.append(torch.float16) + return supported_dtypes # Graph operations def create_graph(self):