mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
Improve CUDA extension building error/warning messages (#59665)
Summary: See https://github.com/pytorch/pytorch/issues/55267 Pull Request resolved: https://github.com/pytorch/pytorch/pull/59665 Reviewed By: mruberry Differential Revision: D29462248 Pulled By: ezyang fbshipit-source-id: 9de13a284a14a7cd24200b9684151ce652e1eb1e
This commit is contained in:
committed by
Facebook GitHub Bot
parent
12b63f4046
commit
d46eb77b04
@ -20,7 +20,7 @@ from .hipify.hipify_python import get_hip_file_path, GeneratedFileCleaner
|
||||
from typing import List, Optional, Union
|
||||
|
||||
from setuptools.command.build_ext import build_ext
|
||||
from pkg_resources import packaging # type: ignore[attr-defined]
|
||||
from pkg_resources import packaging, parse_version # type: ignore[attr-defined]
|
||||
|
||||
IS_WINDOWS = sys.platform == 'win32'
|
||||
LIB_EXT = '.pyd' if IS_WINDOWS else '.so'
|
||||
@ -153,6 +153,15 @@ with compiling PyTorch from source.
|
||||
|
||||
!! WARNING !!
|
||||
'''
|
||||
CUDA_MISMATCH_MESSAGE = '''
|
||||
The detected CUDA version ({0}) mismatches the version that was used to compile
|
||||
PyTorch ({1}). Please make sure to use the same CUDA versions.
|
||||
'''
|
||||
CUDA_MISMATCH_WARN = "The detected CUDA version ({0}) has a minor version mismatch with the version that was used to compile PyTorch ({1}). Most likely this shouldn't be a problem."
|
||||
CUDA_NOT_FOUND_MESSAGE = '''
|
||||
CUDA was not found on the system, please set the CUDA_HOME or the CUDA_PATH
|
||||
environment variable or add NVCC to your system PATH. The extension compilation will fail.
|
||||
'''
|
||||
ROCM_HOME = _find_rocm_home()
|
||||
MIOPEN_HOME = _join_rocm_home('miopen') if ROCM_HOME else None
|
||||
IS_HIP_EXTENSION = True if ((ROCM_HOME is not None) and (torch.version.hip is not None)) else False
|
||||
@ -377,6 +386,21 @@ class BuildExtension(build_ext, object):
|
||||
|
||||
def build_extensions(self) -> None:
|
||||
self._check_abi()
|
||||
|
||||
cuda_ext = False
|
||||
extension_iter = iter(self.extensions)
|
||||
extension = next(extension_iter, None)
|
||||
while not cuda_ext and extension:
|
||||
for source in extension.sources:
|
||||
_, ext = os.path.splitext(source)
|
||||
if ext == '.cu':
|
||||
cuda_ext = True
|
||||
break
|
||||
extension = next(extension_iter, None)
|
||||
|
||||
if cuda_ext and not IS_HIP_EXTENSION:
|
||||
self._check_cuda_version()
|
||||
|
||||
for extension in self.extensions:
|
||||
# Ensure at least an empty list of flags for 'cxx' and 'nvcc' when
|
||||
# extra_compile_args is a dict. Otherwise, default torch flags do
|
||||
@ -740,6 +764,24 @@ class BuildExtension(build_ext, object):
|
||||
'Please set `DISTUTILS_USE_SDK=1` and try again.')
|
||||
raise UserWarning(msg)
|
||||
|
||||
def _check_cuda_version(self):
    """Check that the local CUDA toolkit matches the CUDA used to build PyTorch.

    Runs ``nvcc --version`` from ``CUDA_HOME/bin`` and compares the reported
    ``release X.Y`` number against ``torch.version.cuda``.

    Raises:
        RuntimeError: if ``CUDA_HOME`` is not set, or if the major CUDA
            version reported by nvcc differs from the one PyTorch was
            compiled with.

    Emits a warning (and continues) when only the minor versions differ.
    Returns silently when the nvcc output contains no recognizable release
    number, or when this PyTorch build reports no CUDA version at all.
    """
    if not CUDA_HOME:
        raise RuntimeError(CUDA_NOT_FOUND_MESSAGE)

    nvcc = os.path.join(CUDA_HOME, 'bin', 'nvcc')
    nvcc_output = subprocess.check_output([nvcc, '--version']).strip().decode()
    release_match = re.search(r'release (\d+[.]\d+)', nvcc_output)
    if release_match is None:
        # Unrecognized nvcc banner: nothing to compare against.
        return

    # torch.version.cuda is None when PyTorch was built without CUDA; passing
    # None to parse_version would fail with an unrelated TypeError, so skip
    # the comparison instead.
    if torch.version.cuda is None:
        return

    cuda_str_version = release_match.group(1)
    cuda_ver = parse_version(cuda_str_version)
    torch_cuda_version = parse_version(torch.version.cuda)

    if cuda_ver.major != torch_cuda_version.major:  # type: ignore[attr-defined]
        # A major-version mismatch is treated as fatal.
        raise RuntimeError(CUDA_MISMATCH_MESSAGE.format(
            cuda_str_version, torch.version.cuda))
    if cuda_ver.minor != torch_cuda_version.minor:  # type: ignore[attr-defined]
        # A minor-version mismatch is only reported, not enforced.
        warnings.warn(CUDA_MISMATCH_WARN.format(
            cuda_str_version, torch.version.cuda))
|
||||
|
||||
def _add_compile_flag(self, extension, flag):
|
||||
extension.extra_compile_args = copy.deepcopy(extension.extra_compile_args)
|
||||
if isinstance(extension.extra_compile_args, dict):
|
||||
|
Reference in New Issue
Block a user