[Distributed][CI] Add SM guard for compiled tests involving BF16 (#138245)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/138245
Approved by: https://github.com/yf225
This commit is contained in:
Ke Wen
2024-10-18 10:27:22 -07:00
committed by PyTorch MergeBot
parent 7faa1284ab
commit c88b77af9c
5 changed files with 52 additions and 5 deletions

View File

@ -353,6 +353,22 @@ def skip_if_win32():
)
def sm_is_or_higher_than(device: torch.device, major: int, minor: int) -> bool:
    """
    Return True if the device's CUDA compute capability is (major, minor) or higher.

    Args:
        device: the device to query; must be a CUDA device.
        major: required SM major version (e.g. 8 for Ampere).
        minor: required SM minor version (e.g. 0 for SM 8.0).

    Returns:
        True if the device's compute capability >= (major, minor).
        False if the device is a ROCm device (ROCm compute capability codes
        are not comparable to CUDA SM versions).

    Raises:
        ValueError: if ``device`` is not a CUDA device.
    """
    if device.type != "cuda":
        # Error message names this function so callers can find the source.
        raise ValueError("sm_is_or_higher_than() is only supported for CUDA devices")
    if torch.version.hip is not None:
        # ROCm devices may have different compute capability codes
        return False
    # Tuple comparison: (major, minor) pairs compare lexicographically.
    return torch.cuda.get_device_capability(device) >= (major, minor)
@retry_on_connect_failures
def create_tcp_store(
addr="localhost",