mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 21:49:24 +08:00
[Distributed][CI] Add SM guard for compiled tests involving BF16 (#138245)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/138245 Approved by: https://github.com/yf225
This commit is contained in:
@ -353,6 +353,22 @@ def skip_if_win32():
|
||||
)
|
||||
|
||||
|
||||
def sm_is_or_higher_than(device: torch.device, major: int, minor: int) -> bool:
    """
    Return True if the device's compute capability is (major, minor) or higher.

    Args:
        device: The device to query. Must be a CUDA device.
        major: Required major compute-capability version.
        minor: Required minor compute-capability version.

    Returns:
        True if the CUDA device's compute capability is >= ``(major, minor)``.
        False if the device is a ROCm (HIP) device, since ROCm devices use a
        different compute-capability numbering scheme.

    Raises:
        ValueError: If ``device`` is not a CUDA device.
    """
    if device.type != "cuda":
        # Fix: the message previously named a non-existent "sm_is_or_later()";
        # report the actual function name so the error is actionable.
        raise ValueError(
            "sm_is_or_higher_than() is only supported for CUDA devices"
        )

    if torch.version.hip is not None:
        # ROCm devices may have different compute capability codes,
        # so an SM comparison is not meaningful there.
        return False

    # Tuple comparison: (major, minor) pairs compare lexicographically.
    return torch.cuda.get_device_capability(device) >= (major, minor)
|
||||
|
||||
|
||||
@retry_on_connect_failures
|
||||
def create_tcp_store(
|
||||
addr="localhost",
|
||||
|
Reference in New Issue
Block a user