mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[Inductor] Generalize is_cuda
to specific device_type to make cpp_wrapper mode be extensible (#134693)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/134693 Approved by: https://github.com/ezyang, https://github.com/EikanWang, https://github.com/jansel
This commit is contained in:
committed by
PyTorch MergeBot
parent
6e13f5eb38
commit
67735d1ee8
@ -140,6 +140,21 @@ def _find_rocm_home() -> Optional[str]:
|
||||
file=sys.stderr)
|
||||
return rocm_home
|
||||
|
||||
def _find_sycl_home() -> Optional[str]:
|
||||
"""Find the OneAPI install path."""
|
||||
# Guess #1
|
||||
sycl_home = os.environ.get('ONEAPI_ROOT')
|
||||
if sycl_home is None:
|
||||
# Guess #2
|
||||
icpx_path = shutil.which('icpx')
|
||||
if icpx_path is not None:
|
||||
sycl_home = os.path.dirname(os.path.dirname(
|
||||
os.path.realpath(icpx_path)))
|
||||
|
||||
if sycl_home and not torch.xpu.is_available():
|
||||
print(f"No XPU runtime is found, using ONEAPI_ROOT='{sycl_home}'",
|
||||
file=sys.stderr)
|
||||
return sycl_home
|
||||
|
||||
def _join_rocm_home(*paths) -> str:
|
||||
"""
|
||||
@ -156,6 +171,20 @@ def _join_rocm_home(*paths) -> str:
|
||||
'ROCm and Windows is not supported.')
|
||||
return os.path.join(ROCM_HOME, *paths)
|
||||
|
||||
def _join_sycl_home(*paths) -> str:
|
||||
"""
|
||||
Join paths with SYCL_HOME, or raises an error if it SYCL_HOME is not set.
|
||||
|
||||
This is basically a lazy way of raising an error for missing $SYCL_HOME
|
||||
only once we need to get any SYCL-specific path.
|
||||
"""
|
||||
if SYCL_HOME is None:
|
||||
raise OSError('SYCL_HOME environment variable is not set. '
|
||||
'Please set it to your OneAPI install root.')
|
||||
|
||||
return os.path.join(SYCL_HOME, *paths)
|
||||
|
||||
|
||||
|
||||
ABI_INCOMPATIBILITY_WARNING = '''
|
||||
|
||||
@ -207,6 +236,8 @@ if torch.version.hip is not None:
|
||||
|
||||
CUDA_HOME = _find_cuda_home() if torch.cuda._is_compiled() else None
|
||||
CUDNN_HOME = os.environ.get('CUDNN_HOME') or os.environ.get('CUDNN_PATH')
|
||||
SYCL_HOME = _find_sycl_home() if torch.xpu._is_compiled() else None
|
||||
|
||||
# PyTorch releases have the version pattern major.minor.patch, whereas when
|
||||
# PyTorch is built from source, we append the git commit hash, which gives
|
||||
# it the below pattern.
|
||||
@ -1075,7 +1106,7 @@ def CUDAExtension(name, sources, *args, **kwargs):
|
||||
... 'nvcc': ['-O2', '-rdc=true']})
|
||||
"""
|
||||
library_dirs = kwargs.get('library_dirs', [])
|
||||
library_dirs += library_paths(cuda=True)
|
||||
library_dirs += library_paths(device_type="cuda")
|
||||
kwargs['library_dirs'] = library_dirs
|
||||
|
||||
libraries = kwargs.get('libraries', [])
|
||||
@ -1119,7 +1150,7 @@ def CUDAExtension(name, sources, *args, **kwargs):
|
||||
|
||||
sources = list(hipified_sources)
|
||||
|
||||
include_dirs += include_paths(cuda=True)
|
||||
include_dirs += include_paths(device_type="cuda")
|
||||
kwargs['include_dirs'] = include_dirs
|
||||
|
||||
kwargs['language'] = 'c++'
|
||||
@ -1144,9 +1175,9 @@ def CUDAExtension(name, sources, *args, **kwargs):
|
||||
return setuptools.Extension(name, sources, *args, **kwargs)
|
||||
|
||||
|
||||
def include_paths(cuda: bool = False) -> List[str]:
|
||||
def include_paths(device_type: str = "cpu") -> List[str]:
|
||||
"""
|
||||
Get the include paths required to build a C++ or CUDA extension.
|
||||
Get the include paths required to build a C++ or CUDA or SYCL extension.
|
||||
|
||||
Args:
|
||||
cuda: If `True`, includes CUDA-specific include paths.
|
||||
@ -1164,10 +1195,10 @@ def include_paths(cuda: bool = False) -> List[str]:
|
||||
os.path.join(lib_include, 'TH'),
|
||||
os.path.join(lib_include, 'THC')
|
||||
]
|
||||
if cuda and IS_HIP_EXTENSION:
|
||||
if device_type == "cuda" and IS_HIP_EXTENSION:
|
||||
paths.append(os.path.join(lib_include, 'THH'))
|
||||
paths.append(_join_rocm_home('include'))
|
||||
elif cuda:
|
||||
elif device_type == "cuda":
|
||||
cuda_home_include = _join_cuda_home('include')
|
||||
# if we have the Debian/Ubuntu packages for cuda, we get /usr as cuda home.
|
||||
# but gcc doesn't like having /usr/include passed explicitly
|
||||
@ -1180,10 +1211,12 @@ def include_paths(cuda: bool = False) -> List[str]:
|
||||
paths.append(cuda_inc_path)
|
||||
if CUDNN_HOME is not None:
|
||||
paths.append(os.path.join(CUDNN_HOME, 'include'))
|
||||
elif device_type == "xpu":
|
||||
paths.append(_join_sycl_home('include'))
|
||||
return paths
|
||||
|
||||
|
||||
def library_paths(cuda: bool = False) -> List[str]:
|
||||
def library_paths(device_type: str = "cpu") -> List[str]:
|
||||
"""
|
||||
Get the library paths required to build a C++ or CUDA extension.
|
||||
|
||||
@ -1196,12 +1229,12 @@ def library_paths(cuda: bool = False) -> List[str]:
|
||||
# We need to link against libtorch.so
|
||||
paths = [TORCH_LIB_PATH]
|
||||
|
||||
if cuda and IS_HIP_EXTENSION:
|
||||
if device_type == "cuda" and IS_HIP_EXTENSION:
|
||||
lib_dir = 'lib'
|
||||
paths.append(_join_rocm_home(lib_dir))
|
||||
if HIP_HOME is not None:
|
||||
paths.append(os.path.join(HIP_HOME, 'lib'))
|
||||
elif cuda:
|
||||
elif device_type == "cuda":
|
||||
if IS_WINDOWS:
|
||||
lib_dir = os.path.join('lib', 'x64')
|
||||
else:
|
||||
@ -1216,6 +1249,17 @@ def library_paths(cuda: bool = False) -> List[str]:
|
||||
paths.append(_join_cuda_home(lib_dir))
|
||||
if CUDNN_HOME is not None:
|
||||
paths.append(os.path.join(CUDNN_HOME, lib_dir))
|
||||
elif device_type == "xpu":
|
||||
if IS_WINDOWS:
|
||||
lib_dir = os.path.join('lib', 'x64')
|
||||
else:
|
||||
lib_dir = 'lib64'
|
||||
if (not os.path.exists(_join_sycl_home(lib_dir)) and
|
||||
os.path.exists(_join_sycl_home('lib'))):
|
||||
lib_dir = 'lib'
|
||||
|
||||
paths.append(_join_sycl_home(lib_dir))
|
||||
|
||||
return paths
|
||||
|
||||
|
||||
@ -2165,7 +2209,11 @@ def _write_ninja_file_to_build_library(path,
|
||||
user_includes = [os.path.abspath(file) for file in extra_include_paths]
|
||||
|
||||
# include_paths() gives us the location of torch/extension.h
|
||||
system_includes = include_paths(with_cuda)
|
||||
# TODO generalize with_cuda as specific device type.
|
||||
if with_cuda:
|
||||
system_includes = include_paths("cuda")
|
||||
else:
|
||||
system_includes = include_paths("cpu")
|
||||
# sysconfig.get_path('include') gives us the location of Python.h
|
||||
# Explicitly specify 'posix_prefix' scheme on non-Windows platforms to workaround error on some MacOS
|
||||
# installations where default `get_path` points to non-existing `/Library/Python/M.m/include` folder
|
||||
|
Reference in New Issue
Block a user