[CUDA] Compare major version of the runtime device arch against the built version of the pytorch binary (#161299)

Fixes misleading warning messages when running on sm12x devices using binaries built with sm120.
A PyTorch binary built with sm120 is compatible with, e.g., sm121, so the incompatibility warning is unnecessary.

Also allow the 'matched_cuda_warn' message to show when, e.g., the user is running a binary built with only sm90 on an sm12x device, so that the user is prompted to get a build which supports e.g. sm120.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/161299
Approved by: https://github.com/eqy, https://github.com/atalman
This commit is contained in:
Wei Wang
2025-09-24 23:59:19 +00:00
committed by PyTorch MergeBot
parent 4ac4a7351e
commit 7163dce1e0


@@ -270,7 +270,7 @@ def _check_capability():
major = capability[0]
minor = capability[1]
name = get_device_name(d)
-            current_arch = major * 10 + minor
+            cur_arch_major = major * 10
min_arch = min(
(_extract_arch_version(arch) for arch in torch.cuda.get_arch_list()),
default=50,
@@ -279,7 +279,7 @@ def _check_capability():
(_extract_arch_version(arch) for arch in torch.cuda.get_arch_list()),
default=50,
)
-            if current_arch < min_arch or current_arch > max_arch:
+            if cur_arch_major < min_arch or cur_arch_major > max_arch:
warnings.warn(
incompatible_gpu_warn
% (
@@ -295,10 +295,7 @@ def _check_capability():
)
matched_arches = ""
for arch, arch_info in CUDA_ARCHES_SUPPORTED.items():
-                    if (
-                        current_arch >= arch_info["min"]
-                        and current_arch <= arch_info["max"]
-                    ):
+                    if arch_info["min"] <= cur_arch_major <= arch_info["max"]:
matched_arches += f" {arch}"
if matched_arches != "":
warnings.warn(matched_cuda_warn.format(matched_arches))