diff --git a/docs/source/xpu.md b/docs/source/xpu.md
index 2018bc6c994f..08e0299480e4 100644
--- a/docs/source/xpu.md
+++ b/docs/source/xpu.md
@@ -28,6 +28,7 @@
     is_available
     is_bf16_supported
     is_initialized
+    is_tf32_supported
     set_device
     set_stream
     stream
diff --git a/test/test_xpu.py b/test/test_xpu.py
index 3474e4031ef2..93524286d788 100644
--- a/test/test_xpu.py
+++ b/test/test_xpu.py
@@ -776,6 +776,10 @@ class TestXPUAPISanity(TestCase):
             torch.xpu.is_available(),
         )

+    def test_is_tf32_supported(self):
+        if not torch.xpu.is_available():
+            self.assertFalse(torch.xpu.is_tf32_supported())
+
     def test_get_arch_list(self):
         if not torch.xpu._is_compiled():
             self.assertEqual(len(torch.xpu.get_arch_list()), 0)
diff --git a/torch/xpu/__init__.py b/torch/xpu/__init__.py
index d1ceb8df2b00..137e960afabb 100644
--- a/torch/xpu/__init__.py
+++ b/torch/xpu/__init__.py
@@ -78,6 +78,17 @@ def is_bf16_supported(including_emulation: bool = True) -> bool:
     )


+def is_tf32_supported() -> bool:
+    r"""Return a bool indicating if the current XPU device supports dtype tf32."""
+    if not is_available():
+        return False
+    # On Intel Xe architecture and newer, TF32 operations can be accelerated
+    # through DPAS (Dot Product Accumulate Systolic) instructions. Therefore,
+    # TF32 support can be determined by checking whether the device supports
+    # subgroup matrix multiply-accumulate operations.
+    return torch.xpu.get_device_properties().has_subgroup_matrix_multiply_accumulate
+
+
 def is_initialized():
     r"""Return whether PyTorch's XPU state has been initialized."""
     return _initialized and not _is_in_bad_fork()
@@ -559,6 +570,7 @@ __all__ = [
     "is_available",
     "is_bf16_supported",
     "is_initialized",
+    "is_tf32_supported",
     "manual_seed",
     "manual_seed_all",
     "max_memory_allocated",
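
Usage sketch: a minimal, hypothetical example of how downstream code might consume the new API once this patch lands. Only `torch.xpu.is_available()` and the `torch.xpu.is_tf32_supported()` added above are used; the precision-selection logic itself is an illustrative assumption, not part of the patch.

    import torch

    # Prefer TF32-accelerated matmuls only when the XPU device exposes the
    # subgroup matrix multiply-accumulate (DPAS) capability checked by the
    # new API; otherwise stay on plain fp32. The is_available() guard is
    # redundant (is_tf32_supported() already returns False without a device)
    # but makes the intent explicit.
    precision = (
        "tf32"
        if torch.xpu.is_available() and torch.xpu.is_tf32_supported()
        else "fp32"
    )
    print(f"matmul precision hint: {precision}")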