mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Simplify BFLOAT16_AVAILABLE (#163445)
Simplify `BFLOAT16_AVAILABLE` by using `torch.cuda.is_bf16_supported()` and `torch.xpu.is_bf16_supported()`. Outdated comments are also removed.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163445
Approved by: https://github.com/Skylion007, https://github.com/kwen2501
committed by PyTorch MergeBot
parent edafc902d7
commit 96a3afb8ec
@@ -34,11 +34,7 @@ device_type = (
     acc.type if (acc := torch.accelerator.current_accelerator(True)) else "cpu"
 )
 
-# bfloat16 is only supported by CUDA 11+ or XPU
-BFLOAT16_AVAILABLE = (
-    torch.cuda.is_available()
-    and (torch.version.cuda is not None or torch.version.hip is not None)
-) or torch.xpu.is_available()
+BFLOAT16_AVAILABLE = torch.cuda.is_bf16_supported() or torch.xpu.is_bf16_supported()
 
 
 class Net(nn.Module):
@@ -83,7 +83,6 @@ if TEST_WITH_DEV_DBG_ASAN:
     )
     sys.exit(0)
 
-# bfloat16 is only supported by CUDA 11+
 BFLOAT16_AVAILABLE = torch.cuda.is_available() and (
     torch.version.cuda is not None or torch.version.hip is not None
 )
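For context, a minimal, self-contained sketch contrasting the old availability probe with the new one; the standalone script form and the print statements are illustrative assumptions, not part of the test files this diff touches:

import torch

# New check from this commit: ask each backend directly whether it supports
# bfloat16, instead of inferring support from build/version strings.
BFLOAT16_AVAILABLE = torch.cuda.is_bf16_supported() or torch.xpu.is_bf16_supported()

# Old check removed by this commit: any CUDA/ROCm build with a visible GPU,
# or any visible XPU, was assumed to be bfloat16-capable.
OLD_BFLOAT16_AVAILABLE = (
    torch.cuda.is_available()
    and (torch.version.cuda is not None or torch.version.hip is not None)
) or torch.xpu.is_available()

if __name__ == "__main__":
    print("bf16 available (new check):", BFLOAT16_AVAILABLE)
    print("bf16 available (old check):", OLD_BFLOAT16_AVAILABLE)

Querying the backend keeps the gating logic correct on machines where a CUDA or XPU build is present but the actual device lacks bf16 support, which the old build-string heuristic could not distinguish.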