mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[CUDA] Compare major version of the runtime device arch against the built version of the pytorch binary (#161299)
Fixes misleading warning messages when running on sm12x devices using binaries built with sm120. PyTorch binary built with sm120 is compatible with e.g. sm121, so no need for the warning of incompatibility. Also allow the 'matched_cuda_warn' message to show when e.g. the user is running a binary built with only sm90 on sm12x, so that the user would be prompted to get a build which supports e.g. sm120. Pull Request resolved: https://github.com/pytorch/pytorch/pull/161299 Approved by: https://github.com/eqy, https://github.com/atalman
This commit is contained in:
committed by
PyTorch MergeBot
parent
4ac4a7351e
commit
7163dce1e0
@ -270,7 +270,7 @@ def _check_capability():
|
||||
major = capability[0]
|
||||
minor = capability[1]
|
||||
name = get_device_name(d)
|
||||
current_arch = major * 10 + minor
|
||||
cur_arch_major = major * 10
|
||||
min_arch = min(
|
||||
(_extract_arch_version(arch) for arch in torch.cuda.get_arch_list()),
|
||||
default=50,
|
||||
@ -279,7 +279,7 @@ def _check_capability():
|
||||
(_extract_arch_version(arch) for arch in torch.cuda.get_arch_list()),
|
||||
default=50,
|
||||
)
|
||||
if current_arch < min_arch or current_arch > max_arch:
|
||||
if cur_arch_major < min_arch or cur_arch_major > max_arch:
|
||||
warnings.warn(
|
||||
incompatible_gpu_warn
|
||||
% (
|
||||
@ -295,10 +295,7 @@ def _check_capability():
|
||||
)
|
||||
matched_arches = ""
|
||||
for arch, arch_info in CUDA_ARCHES_SUPPORTED.items():
|
||||
if (
|
||||
current_arch >= arch_info["min"]
|
||||
and current_arch <= arch_info["max"]
|
||||
):
|
||||
if arch_info["min"] <= cur_arch_major <= arch_info["max"]:
|
||||
matched_arches += f" {arch}"
|
||||
if matched_arches != "":
|
||||
warnings.warn(matched_cuda_warn.format(matched_arches))
|
||||
|
||||
Reference in New Issue
Block a user