Correct torch.xpu.is_bf16_supported to return False if no XPU is detected (#152317)

# Motivation
Fix https://github.com/pytorch/pytorch/issues/152301
When XPU is not available, calling `torch.xpu.is_bf16_supported()` still returns `True`, which is inconsistent with the expected behavior (it should return `False`).

# Solution
Align with other backends by adding an `including_emulation` argument to `torch.xpu.is_bf16_supported`, so that it:
- returns `False` if XPU is not available;
- returns `True` if `including_emulation` is `True`;
- returns `torch.xpu.get_device_properties().has_bfloat16_conversions` if `including_emulation` is `False`, i.e. whether the device can generate SPIR-V code for bf16 (see the usage sketch below).
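For illustration, a minimal sketch of the resulting behavior; this is not code from the PR, and it assumes a PyTorch build that already contains this change:

```python
import torch

# Sketch of the new semantics of torch.xpu.is_bf16_supported (illustrative only).
if not torch.xpu.is_available():
    # Without an XPU, both variants now report no bf16 support.
    assert not torch.xpu.is_bf16_supported()
    assert not torch.xpu.is_bf16_supported(including_emulation=False)
else:
    # With an XPU, the default (including_emulation=True) remains True, while
    # including_emulation=False reflects native bf16 conversion support.
    assert torch.xpu.is_bf16_supported()
    native = torch.xpu.get_device_properties().has_bfloat16_conversions
    assert torch.xpu.is_bf16_supported(including_emulation=False) == native
```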

Pull Request resolved: https://github.com/pytorch/pytorch/pull/152317
Approved by: https://github.com/EikanWang
Author: Yu, Guangye
Date: 2025-05-06 05:19:50 +00:00
Committed by: PyTorch MergeBot
Parent: 8904ba6387
Commit: e32a16a9da
2 changed files with 22 additions and 7 deletions

test/test_xpu.py

@@ -22,7 +22,6 @@ from torch.testing._internal.common_utils import (
     find_library_location,
     IS_LINUX,
     IS_WINDOWS,
-    NoTest,
     run_tests,
     suppress_warnings,
     TEST_XPU,
@@ -31,10 +30,6 @@ from torch.testing._internal.common_utils import (
 from torch.utils.checkpoint import checkpoint_sequential
 
-if not TEST_XPU:
-    print("XPU not available, skipping tests", file=sys.stderr)
-    TestCase = NoTest  # noqa: F811
-
 TEST_MULTIXPU = torch.xpu.device_count() > 1
 
 cpu_device = torch.device("cpu")
@@ -74,6 +69,7 @@ _xpu_computation_ops = [
 ]
 
 
+@unittest.skipIf(not TEST_XPU, "XPU not available, skipping tests")
 class TestXpu(TestCase):
     def test_device_behavior(self):
         current_device = torch.xpu.current_device()
@@ -581,6 +577,7 @@ if __name__ == "__main__":
 instantiate_device_type_tests(TestXpu, globals(), only_for="xpu", allow_xpu=True)
 
 
+@unittest.skipIf(not TEST_XPU, "XPU not available, skipping tests")
 class TestXpuAutocast(TestAutocast):
     # These operators are not implemented on XPU backend and we can NOT fall back
     # them to CPU. So we have to skip them at this moment.
@@ -661,6 +658,7 @@ class TestXpuAutocast(TestAutocast):
         self.assertEqual(result.dtype, torch.float16)
 
 
+@unittest.skipIf(not TEST_XPU, "XPU not available, skipping tests")
 class TestXpuTrace(TestCase):
     def setUp(self):
         torch._C._activate_gpu_trace()
@@ -723,5 +721,17 @@ class TestXpuTrace(TestCase):
         self.mock.assert_called_once_with(event._as_parameter_.value)
 
 
+class TestXPUAPISanity(TestCase):
+    def test_is_bf16_supported(self):
+        self.assertEqual(
+            torch.xpu.is_bf16_supported(including_emulation=True),
+            torch.xpu.is_available(),
+        )
+
+    def test_get_arch_list(self):
+        if not torch.xpu._is_compiled():
+            self.assertEqual(len(torch.xpu.get_arch_list()), 0)
+
+
 if __name__ == "__main__":
     run_tests()

torch/xpu/__init__.py

@@ -66,9 +66,14 @@ def is_available() -> bool:
     return device_count() > 0
 
 
-def is_bf16_supported():
+def is_bf16_supported(including_emulation: bool = True) -> bool:
     r"""Return a bool indicating if the current XPU device supports dtype bfloat16."""
-    return True
+    if not is_available():
+        return False
+    return (
+        including_emulation
+        or torch.xpu.get_device_properties().has_bfloat16_conversions
+    )
 
 
 def is_initialized():
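As a caller-side illustration (not part of this commit), the new flag lets code distinguish native bf16 support from emulated support; `pick_dtype` below is a hypothetical helper:

```python
import torch

def pick_dtype(require_native_bf16: bool = False) -> torch.dtype:
    # Hypothetical helper: choose bfloat16 only when the XPU backend reports
    # support; require_native_bf16=True excludes emulated bf16 conversions.
    if torch.xpu.is_bf16_supported(including_emulation=not require_native_bf16):
        return torch.bfloat16
    return torch.float32

# Example: on a machine without an XPU this now falls back to float32,
# whereas before this fix is_bf16_supported() incorrectly returned True.
dtype = pick_dtype()
```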