mirror of
https://github.com/deepspeedai/DeepSpeed.git
synced 2025-10-20 15:33:51 +08:00
Allow NVIDIA Blackwell (#6991)
The NVIDIA Blackwell GPU generation has compute capability major version 10, so the SM code and architecture should be `100`. The current code instead generates `1.`, because it builds the number from fixed character positions (`cc[0] + cc[2]`) and therefore assumes a single-digit major version such as `8.6`. This change splits each compute-capability string on `.` and works with the resulting list of parts, so `10.0` now yields `100`. Signed-off-by: Fabien Dupont <fdupont@redhat.com>
This commit is contained in:
@ -612,8 +612,8 @@ class CUDAOpBuilder(OpBuilder):
|
||||
|
||||
- `TORCH_CUDA_ARCH_LIST` may use ; or whitespace separators. Examples:
|
||||
|
||||
TORCH_CUDA_ARCH_LIST="6.1;7.5;8.6" pip install ...
|
||||
TORCH_CUDA_ARCH_LIST="6.0 6.1 7.0 7.5 8.0 8.6+PTX" pip install ...
|
||||
TORCH_CUDA_ARCH_LIST="6.1;7.5;8.6;9.0;10.0" pip install ...
|
||||
TORCH_CUDA_ARCH_LIST="6.0 6.1 7.0 7.5 8.0 8.6 9.0 10.0+PTX" pip install ...
|
||||
|
||||
- `cross_compile_archs` uses ; separator.
|
||||
|
||||
@ -651,9 +651,9 @@ class CUDAOpBuilder(OpBuilder):
|
||||
args = []
|
||||
self.enable_bf16 = True
|
||||
for cc in ccs:
|
||||
num = cc[0] + cc[2]
|
||||
num = cc[0] + cc[1].split('+')[0]
|
||||
args.append(f'-gencode=arch=compute_{num},code=sm_{num}')
|
||||
if cc.endswith('+PTX'):
|
||||
if cc[1].endswith('+PTX'):
|
||||
args.append(f'-gencode=arch=compute_{num},code=compute_{num}')
|
||||
|
||||
if int(cc[0]) <= 7:
|
||||
@ -666,7 +666,7 @@ class CUDAOpBuilder(OpBuilder):
|
||||
Prune any compute capabilities that are not compatible with the builder. Should log
|
||||
which CCs have been pruned.
|
||||
"""
|
||||
return ccs
|
||||
return [cc.split('.') for cc in ccs]
|
||||
|
||||
def version_dependent_macros(self):
|
||||
# Fix from apex that might be relevant for us as well, related to https://github.com/NVIDIA/apex/issues/456
|
||||
|
@ -78,7 +78,7 @@ class FPQuantizerBuilder(CUDAOpBuilder):
|
||||
def filter_ccs(self, ccs):
|
||||
ccs_retained = []
|
||||
ccs_pruned = []
|
||||
for cc in ccs:
|
||||
for cc in [cc.split('.') for cc in ccs]:
|
||||
if int(cc[0]) >= 8:
|
||||
ccs_retained.append(cc)
|
||||
else:
|
||||
|
@ -46,7 +46,7 @@ class InferenceCoreBuilder(CUDAOpBuilder):
|
||||
def filter_ccs(self, ccs):
|
||||
ccs_retained = []
|
||||
ccs_pruned = []
|
||||
for cc in ccs:
|
||||
for cc in [cc.split('.') for cc in ccs]:
|
||||
if int(cc[0]) >= 6:
|
||||
ccs_retained.append(cc)
|
||||
else:
|
||||
|
@ -45,7 +45,7 @@ class InferenceCutlassBuilder(CUDAOpBuilder):
|
||||
def filter_ccs(self, ccs):
|
||||
ccs_retained = []
|
||||
ccs_pruned = []
|
||||
for cc in ccs:
|
||||
for cc in [cc.split('.') for cc in ccs]:
|
||||
if int(cc[0]) >= 8:
|
||||
# Only support Ampere and newer
|
||||
ccs_retained.append(cc)
|
||||
|
@ -46,7 +46,7 @@ class RaggedOpsBuilder(CUDAOpBuilder):
|
||||
def filter_ccs(self, ccs):
|
||||
ccs_retained = []
|
||||
ccs_pruned = []
|
||||
for cc in ccs:
|
||||
for cc in [cc.split('.') for cc in ccs]:
|
||||
if int(cc[0]) >= 8:
|
||||
# Blocked flash has a dependency on Ampere + newer
|
||||
ccs_retained.append(cc)
|
||||
|
@ -46,7 +46,7 @@ class RaggedUtilsBuilder(CUDAOpBuilder):
|
||||
def filter_ccs(self, ccs):
|
||||
ccs_retained = []
|
||||
ccs_pruned = []
|
||||
for cc in ccs:
|
||||
for cc in [cc.split('.') for cc in ccs]:
|
||||
if int(cc[0]) >= 6:
|
||||
ccs_retained.append(cc)
|
||||
else:
|
||||
|
@ -44,7 +44,7 @@ class InferenceBuilder(CUDAOpBuilder):
|
||||
def filter_ccs(self, ccs):
|
||||
ccs_retained = []
|
||||
ccs_pruned = []
|
||||
for cc in ccs:
|
||||
for cc in [cc.split('.') for cc in ccs]:
|
||||
if int(cc[0]) >= 6:
|
||||
ccs_retained.append(cc)
|
||||
else:
|
||||
|
Reference in New Issue
Block a user