Allow NVIDIA Blackwell (#6991)

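The NVIDIA Blackwell GPU generation is numbered 10 (compute capability
`10.0`). The SM code and architecture should be `100`, but the current code
generates `1.`, because it reads fixed character positions (`cc[0] + cc[2]`)
and therefore assumes a single-digit major version.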

This change instead treats each compute capability as a string that contains
a `.`: it splits the string on `.` and uses the resulting list of parts
(the major version, and the minor version with any `+PTX` suffix stripped).

Signed-off-by: Fabien Dupont <fdupont@redhat.com>
This commit is contained in:
Fabien Dupont
2025-02-04 18:47:44 +01:00
committed by GitHub
parent 241bffded3
commit f4caed6d4f
7 changed files with 11 additions and 11 deletions

View File

@ -612,8 +612,8 @@ class CUDAOpBuilder(OpBuilder):
- `TORCH_CUDA_ARCH_LIST` may use ; or whitespace separators. Examples:
TORCH_CUDA_ARCH_LIST="6.1;7.5;8.6" pip install ...
TORCH_CUDA_ARCH_LIST="6.0 6.1 7.0 7.5 8.0 8.6+PTX" pip install ...
TORCH_CUDA_ARCH_LIST="6.1;7.5;8.6;9.0;10.0" pip install ...
TORCH_CUDA_ARCH_LIST="6.0 6.1 7.0 7.5 8.0 8.6 9.0 10.0+PTX" pip install ...
- `cross_compile_archs` uses ; separator.
@ -651,9 +651,9 @@ class CUDAOpBuilder(OpBuilder):
args = []
self.enable_bf16 = True
for cc in ccs:
num = cc[0] + cc[2]
num = cc[0] + cc[1].split('+')[0]
args.append(f'-gencode=arch=compute_{num},code=sm_{num}')
if cc.endswith('+PTX'):
if cc[1].endswith('+PTX'):
args.append(f'-gencode=arch=compute_{num},code=compute_{num}')
if int(cc[0]) <= 7:
@ -666,7 +666,7 @@ class CUDAOpBuilder(OpBuilder):
Prune any compute capabilities that are not compatible with the builder. Should log
which CCs have been pruned.
"""
return ccs
return [cc.split('.') for cc in ccs]
def version_dependent_macros(self):
# Fix from apex that might be relevant for us as well, related to https://github.com/NVIDIA/apex/issues/456

View File

@ -78,7 +78,7 @@ class FPQuantizerBuilder(CUDAOpBuilder):
def filter_ccs(self, ccs):
ccs_retained = []
ccs_pruned = []
for cc in ccs:
for cc in [cc.split('.') for cc in ccs]:
if int(cc[0]) >= 8:
ccs_retained.append(cc)
else:

View File

@ -46,7 +46,7 @@ class InferenceCoreBuilder(CUDAOpBuilder):
def filter_ccs(self, ccs):
ccs_retained = []
ccs_pruned = []
for cc in ccs:
for cc in [cc.split('.') for cc in ccs]:
if int(cc[0]) >= 6:
ccs_retained.append(cc)
else:

View File

@ -45,7 +45,7 @@ class InferenceCutlassBuilder(CUDAOpBuilder):
def filter_ccs(self, ccs):
ccs_retained = []
ccs_pruned = []
for cc in ccs:
for cc in [cc.split('.') for cc in ccs]:
if int(cc[0]) >= 8:
# Only support Ampere and newer
ccs_retained.append(cc)

View File

@ -46,7 +46,7 @@ class RaggedOpsBuilder(CUDAOpBuilder):
def filter_ccs(self, ccs):
ccs_retained = []
ccs_pruned = []
for cc in ccs:
for cc in [cc.split('.') for cc in ccs]:
if int(cc[0]) >= 8:
# Blocked flash has a dependency on Ampere + newer
ccs_retained.append(cc)

View File

@ -46,7 +46,7 @@ class RaggedUtilsBuilder(CUDAOpBuilder):
def filter_ccs(self, ccs):
ccs_retained = []
ccs_pruned = []
for cc in ccs:
for cc in [cc.split('.') for cc in ccs]:
if int(cc[0]) >= 6:
ccs_retained.append(cc)
else:

View File

@ -44,7 +44,7 @@ class InferenceBuilder(CUDAOpBuilder):
def filter_ccs(self, ccs):
ccs_retained = []
ccs_pruned = []
for cc in ccs:
for cc in [cc.split('.') for cc in ccs]:
if int(cc[0]) >= 6:
ccs_retained.append(cc)
else: