Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 21:14:14 +08:00
Correct torch.xpu.is_bf16_supported to return False if no XPU is detected (#152317)
# Motivation

Fixes https://github.com/pytorch/pytorch/issues/152301. When XPU is not available, calling `torch.xpu.is_bf16_supported()` still returns `True`, which is inconsistent with the expected behavior (it should return `False`).

# Solution

Align with other backends by adding an `including_emulation` parameter to `torch.xpu.is_bf16_supported`, which now:
- returns `False` if XPU is not available;
- returns `True` if `including_emulation` is `True`;
- returns `torch.xpu.get_device_properties().has_bfloat16_conversions` if `including_emulation` is `False`, i.e. whether the device can generate SPIR-V code for bf16.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/152317
Approved by: https://github.com/EikanWang
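A minimal sketch of the corrected semantics (assumes a PyTorch build that includes this patch; `pick_dtype` is a hypothetical helper, not part of PyTorch):

```python
import torch

def pick_dtype() -> torch.dtype:
    # Before this fix, is_bf16_supported() returned True even when no XPU
    # was present; with the fix it returns False, so the float32 fallback
    # branch is taken on machines without an XPU.
    if torch.xpu.is_bf16_supported():  # including_emulation defaults to True
        return torch.bfloat16
    return torch.float32

print(pick_dtype())
```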
Committed by: PyTorch MergeBot
Parent: 8904ba6387
Commit: e32a16a9da
--- a/test/test_xpu.py
+++ b/test/test_xpu.py
@@ -22,7 +22,6 @@ from torch.testing._internal.common_utils import (
     find_library_location,
     IS_LINUX,
     IS_WINDOWS,
-    NoTest,
     run_tests,
     suppress_warnings,
     TEST_XPU,
@@ -31,10 +30,6 @@ from torch.testing._internal.common_utils import (
 from torch.utils.checkpoint import checkpoint_sequential
 
 
-if not TEST_XPU:
-    print("XPU not available, skipping tests", file=sys.stderr)
-    TestCase = NoTest  # noqa: F811
-
 TEST_MULTIXPU = torch.xpu.device_count() > 1
 
 cpu_device = torch.device("cpu")
@@ -74,6 +69,7 @@ _xpu_computation_ops = [
 ]
 
 
+@unittest.skipIf(not TEST_XPU, "XPU not available, skipping tests")
 class TestXpu(TestCase):
     def test_device_behavior(self):
         current_device = torch.xpu.current_device()
@@ -581,6 +577,7 @@ if __name__ == "__main__":
 instantiate_device_type_tests(TestXpu, globals(), only_for="xpu", allow_xpu=True)
 
 
+@unittest.skipIf(not TEST_XPU, "XPU not available, skipping tests")
 class TestXpuAutocast(TestAutocast):
     # These operators are not implemented on XPU backend and we can NOT fall back
     # them to CPU. So we have to skip them at this moment.
@@ -661,6 +658,7 @@ class TestXpuAutocast(TestAutocast):
         self.assertEqual(result.dtype, torch.float16)
 
 
+@unittest.skipIf(not TEST_XPU, "XPU not available, skipping tests")
 class TestXpuTrace(TestCase):
     def setUp(self):
         torch._C._activate_gpu_trace()
@@ -723,5 +721,17 @@ class TestXpuTrace(TestCase):
         self.mock.assert_called_once_with(event._as_parameter_.value)
 
 
+class TestXPUAPISanity(TestCase):
+    def test_is_bf16_supported(self):
+        self.assertEqual(
+            torch.xpu.is_bf16_supported(including_emulation=True),
+            torch.xpu.is_available(),
+        )
+
+    def test_get_arch_list(self):
+        if not torch.xpu._is_compiled():
+            self.assertEqual(len(torch.xpu.get_arch_list()), 0)
+
+
 if __name__ == "__main__":
     run_tests()
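The new `TestXPUAPISanity` class above is deliberately not decorated with `skipIf`, so it runs even without an XPU device. A standalone check mirroring `test_get_arch_list`, runnable outside the test suite (a sketch, assuming a build with this patch):

```python
import torch

# On builds compiled without XPU support, get_arch_list() is expected to
# return an empty list, per the sanity test above.
if not torch.xpu._is_compiled():
    assert torch.xpu.get_arch_list() == []
```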
--- a/torch/xpu/__init__.py
+++ b/torch/xpu/__init__.py
@@ -66,9 +66,14 @@ def is_available() -> bool:
     return device_count() > 0
 
 
-def is_bf16_supported():
+def is_bf16_supported(including_emulation: bool = True) -> bool:
     r"""Return a bool indicating if the current XPU device supports dtype bfloat16."""
-    return True
+    if not is_available():
+        return False
+    return (
+        including_emulation
+        or torch.xpu.get_device_properties().has_bfloat16_conversions
+    )
 
 
 def is_initialized():
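How a caller might use the new parameter to distinguish native from emulated bf16 support (a sketch, assuming a patched build; per the PR description, `has_bfloat16_conversions` indicates the device can generate SPIR-V code for bf16):

```python
import torch

if torch.xpu.is_bf16_supported(including_emulation=False):
    print("native bf16 support (device can generate SPIR-V bf16 code)")
elif torch.xpu.is_bf16_supported():
    print("bf16 available through emulation only")
else:
    print("no XPU available; bf16 unsupported")
```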