Improve CUDA extension building error/warning messages (#59665)

Summary:
See https://github.com/pytorch/pytorch/issues/55267

Pull Request resolved: https://github.com/pytorch/pytorch/pull/59665

Reviewed By: mruberry

Differential Revision: D29462248

Pulled By: ezyang

fbshipit-source-id: 9de13a284a14a7cd24200b9684151ce652e1eb1e
This commit is contained in:
Edgar Andrés Margffoy Tuay
2021-06-29 13:01:42 -07:00
committed by Facebook GitHub Bot
parent 12b63f4046
commit d46eb77b04

View File

@ -20,7 +20,7 @@ from .hipify.hipify_python import get_hip_file_path, GeneratedFileCleaner
from typing import List, Optional, Union
from setuptools.command.build_ext import build_ext
from pkg_resources import packaging # type: ignore[attr-defined]
from pkg_resources import packaging, parse_version # type: ignore[attr-defined]
IS_WINDOWS = sys.platform == 'win32'
LIB_EXT = '.pyd' if IS_WINDOWS else '.so'
@ -153,6 +153,15 @@ with compiling PyTorch from source.
!! WARNING !!
'''
CUDA_MISMATCH_MESSAGE = '''
The detected CUDA version ({0}) mismatches the version that was used to compile
PyTorch ({1}). Please make sure to use the same CUDA versions.
'''
CUDA_MISMATCH_WARN = "The detected CUDA version ({0}) has a minor version mismatch with the version that was used to compile PyTorch ({1}). Most likely this shouldn't be a problem."
CUDA_NOT_FOUND_MESSAGE = '''
CUDA was not found on the system, please set the CUDA_HOME or the CUDA_PATH
environment variable or add NVCC to your system PATH. The extension compilation will fail.
'''
ROCM_HOME = _find_rocm_home()
MIOPEN_HOME = _join_rocm_home('miopen') if ROCM_HOME else None
IS_HIP_EXTENSION = True if ((ROCM_HOME is not None) and (torch.version.hip is not None)) else False
@ -377,6 +386,21 @@ class BuildExtension(build_ext, object):
def build_extensions(self) -> None:
self._check_abi()
cuda_ext = False
extension_iter = iter(self.extensions)
extension = next(extension_iter, None)
while not cuda_ext and extension:
for source in extension.sources:
_, ext = os.path.splitext(source)
if ext == '.cu':
cuda_ext = True
break
extension = next(extension_iter, None)
if cuda_ext and not IS_HIP_EXTENSION:
self._check_cuda_version()
for extension in self.extensions:
# Ensure at least an empty list of flags for 'cxx' and 'nvcc' when
# extra_compile_args is a dict. Otherwise, default torch flags do
@ -740,6 +764,24 @@ class BuildExtension(build_ext, object):
'Please set `DISTUTILS_USE_SDK=1` and try again.')
raise UserWarning(msg)
def _check_cuda_version(self):
if CUDA_HOME:
nvcc = os.path.join(CUDA_HOME, 'bin', 'nvcc')
cuda_version_str = subprocess.check_output([nvcc, '--version']).strip().decode()
cuda_version = re.search(r'release (\d+[.]\d+)', cuda_version_str)
if cuda_version is not None:
cuda_str_version = cuda_version.group(1)
cuda_ver = parse_version(cuda_str_version)
torch_cuda_version = parse_version(torch.version.cuda) # type: ignore[arg-type]
if cuda_ver.major != torch_cuda_version.major: # type: ignore[attr-defined]
raise RuntimeError(CUDA_MISMATCH_MESSAGE.format(
cuda_str_version, torch.version.cuda))
elif cuda_ver.minor != torch_cuda_version.minor: # type: ignore[attr-defined]
warnings.warn(CUDA_MISMATCH_WARN.format(
cuda_str_version, torch.version.cuda))
else:
raise RuntimeError(CUDA_NOT_FOUND_MESSAGE)
def _add_compile_flag(self, extension, flag):
extension.extra_compile_args = copy.deepcopy(extension.extra_compile_args)
if isinstance(extension.extra_compile_args, dict):