Hopper (sm90) support (#87736)

Essentially a followup of #87436

CC @xwang233 @ptrblck
Pull Request resolved: https://github.com/pytorch/pytorch/pull/87736
Approved by: https://github.com/xwang233, https://github.com/malfet
This commit is contained in:
Eddie Yan
2022-11-09 01:49:50 +00:00
committed by PyTorch MergeBot
parent 19d7941e37
commit a7420d2ccb
2 changed files with 14 additions and 2 deletions

View File

@ -1730,10 +1730,11 @@ def _get_cuda_arch_flags(cflags: Optional[List[str]] = None) -> List[str]:
('Turing', '7.5+PTX'),
('Ampere', '8.0;8.6+PTX'),
('Ada', '8.9+PTX'),
('Hopper', '9.0+PTX'),
])
supported_arches = ['3.5', '3.7', '5.0', '5.2', '5.3', '6.0', '6.1', '6.2',
'7.0', '7.2', '7.5', '8.0', '8.6', '8.9']
'7.0', '7.2', '7.5', '8.0', '8.6', '8.9', '9.0']
valid_arch_strings = supported_arches + [s + "+PTX" for s in supported_arches]
# The default is sm_30 for CUDA 9.x and 10.x