Add a new API torch.xpu.is_tf32_supported for Intel GPU (#163141)

# Motivation
To align with other backends, this PR introduces a new API `torch.xpu.is_tf32_supported`, which can be used before setting `torch.backends.mkldnn.allow_tf32 = True` or to provide hardware capability information to Triton.
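
As a minimal usage sketch (assuming the API is used the way this motivation describes), the check can gate TF32 enablement:

```python
import torch

# Sketch: enable TF32 for oneDNN-backed ops only when the current
# XPU device actually reports TF32 support via the new API.
if torch.xpu.is_available() and torch.xpu.is_tf32_supported():
    torch.backends.mkldnn.allow_tf32 = True
else:
    # Fall back to full-precision float32 matmuls/convolutions.
    torch.backends.mkldnn.allow_tf32 = False
```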

# Additional Context
On Intel Xe architecture and newer, TF32 operations can be accelerated through DPAS (Dot Product Accumulate Systolic) instructions. Therefore, TF32 support can be determined by checking whether the device supports subgroup matrix multiply-accumulate operations.
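
For illustration, the same capability can be read from the device properties directly; the attribute name below is taken from the implementation in this PR:

```python
import torch

if torch.xpu.is_available():
    props = torch.xpu.get_device_properties()
    # Subgroup matrix multiply-accumulate (DPAS) support implies TF32 support.
    print(props.has_subgroup_matrix_multiply_accumulate)
```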
Pull Request resolved: https://github.com/pytorch/pytorch/pull/163141
Approved by: https://github.com/EikanWang
Author: Yu, Guangye
Date: 2025-10-12 09:03:25 +00:00
Committed by: PyTorch MergeBot
Parent: 5dbca58bd0
Commit: 3a110c9bb2
3 changed files with 17 additions and 0 deletions


@@ -28,6 +28,7 @@
 is_available
 is_bf16_supported
 is_initialized
+is_tf32_supported
 set_device
 set_stream
 stream


@@ -776,6 +776,10 @@ class TestXPUAPISanity(TestCase):
             torch.xpu.is_available(),
         )
 
+    def test_is_tf32_supported(self):
+        if not torch.xpu.is_available():
+            self.assertFalse(torch.xpu.is_tf32_supported())
+
     def test_get_arch_list(self):
         if not torch.xpu._is_compiled():
             self.assertEqual(len(torch.xpu.get_arch_list()), 0)


@@ -78,6 +78,17 @@ def is_bf16_supported(including_emulation: bool = True) -> bool:
     )
 
 
+def is_tf32_supported() -> bool:
+    r"""Return a bool indicating if the current XPU device supports dtype tf32."""
+    if not is_available():
+        return False
+    # On Intel Xe architecture and newer, TF32 operations can be accelerated
+    # through DPAS (Dot Product Accumulate Systolic) instructions. Therefore,
+    # TF32 support can be determined by checking whether the device supports
+    # subgroup matrix multiply-accumulate operations.
+    return torch.xpu.get_device_properties().has_subgroup_matrix_multiply_accumulate
+
+
 def is_initialized():
     r"""Return whether PyTorch's XPU state has been initialized."""
     return _initialized and not _is_in_bad_fork()
@@ -559,6 +570,7 @@ __all__ = [
     "is_available",
     "is_bf16_supported",
     "is_initialized",
+    "is_tf32_supported",
     "manual_seed",
     "manual_seed_all",
     "max_memory_allocated",