mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
Enable Extension Support (#142028)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/142028 Approved by: https://github.com/ezyang, https://github.com/eqy
This commit is contained in:
committed by
PyTorch MergeBot
parent
38d10a1b17
commit
0582b32f6c
@ -1991,10 +1991,12 @@ def _get_cuda_arch_flags(cflags: Optional[List[str]] = None) -> List[str]:
|
|||||||
('Ampere', '8.0;8.6+PTX'),
|
('Ampere', '8.0;8.6+PTX'),
|
||||||
('Ada', '8.9+PTX'),
|
('Ada', '8.9+PTX'),
|
||||||
('Hopper', '9.0+PTX'),
|
('Hopper', '9.0+PTX'),
|
||||||
|
('Blackwell', '10.0+PTX'),
|
||||||
])
|
])
|
||||||
|
|
||||||
supported_arches = ['3.5', '3.7', '5.0', '5.2', '5.3', '6.0', '6.1', '6.2',
|
supported_arches = ['3.5', '3.7', '5.0', '5.2', '5.3', '6.0', '6.1', '6.2',
|
||||||
'7.0', '7.2', '7.5', '8.0', '8.6', '8.7', '8.9', '9.0', '9.0a']
|
'7.0', '7.2', '7.5', '8.0', '8.6', '8.7', '8.9', '9.0', '9.0a'
|
||||||
|
'10.0']
|
||||||
valid_arch_strings = supported_arches + [s + "+PTX" for s in supported_arches]
|
valid_arch_strings = supported_arches + [s + "+PTX" for s in supported_arches]
|
||||||
|
|
||||||
# The default is sm_30 for CUDA 9.x and 10.x
|
# The default is sm_30 for CUDA 9.x and 10.x
|
||||||
@ -2040,7 +2042,10 @@ def _get_cuda_arch_flags(cflags: Optional[List[str]] = None) -> List[str]:
|
|||||||
if arch not in valid_arch_strings:
|
if arch not in valid_arch_strings:
|
||||||
raise ValueError(f"Unknown CUDA arch ({arch}) or GPU not supported")
|
raise ValueError(f"Unknown CUDA arch ({arch}) or GPU not supported")
|
||||||
else:
|
else:
|
||||||
num = arch[0] + arch[2:].split("+")[0]
|
# Handle both single and double-digit architecture versions
|
||||||
|
version = arch.split('+')[0] # Remove "+PTX" if present
|
||||||
|
major, minor = version.split('.')
|
||||||
|
num = f"{major}{minor}"
|
||||||
flags.append(f'-gencode=arch=compute_{num},code=sm_{num}')
|
flags.append(f'-gencode=arch=compute_{num},code=sm_{num}')
|
||||||
if arch.endswith('+PTX'):
|
if arch.endswith('+PTX'):
|
||||||
flags.append(f'-gencode=arch=compute_{num},code=compute_{num}')
|
flags.append(f'-gencode=arch=compute_{num},code=compute_{num}')
|
||||||
|
Reference in New Issue
Block a user