[Inductor] Generalize is_cuda to specific device_type to make cpp_wrapper mode be extensible (#134693)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/134693
Approved by: https://github.com/ezyang, https://github.com/EikanWang, https://github.com/jansel
This commit is contained in:
xinan.lin
2024-09-09 23:40:32 -07:00
committed by PyTorch MergeBot
parent 6e13f5eb38
commit 67735d1ee8
8 changed files with 156 additions and 93 deletions

View File

@ -140,6 +140,21 @@ def _find_rocm_home() -> Optional[str]:
file=sys.stderr)
return rocm_home
def _find_sycl_home() -> Optional[str]:
"""Find the OneAPI install path."""
# Guess #1
sycl_home = os.environ.get('ONEAPI_ROOT')
if sycl_home is None:
# Guess #2
icpx_path = shutil.which('icpx')
if icpx_path is not None:
sycl_home = os.path.dirname(os.path.dirname(
os.path.realpath(icpx_path)))
if sycl_home and not torch.xpu.is_available():
print(f"No XPU runtime is found, using ONEAPI_ROOT='{sycl_home}'",
file=sys.stderr)
return sycl_home
def _join_rocm_home(*paths) -> str:
"""
@ -156,6 +171,20 @@ def _join_rocm_home(*paths) -> str:
'ROCm and Windows is not supported.')
return os.path.join(ROCM_HOME, *paths)
def _join_sycl_home(*paths) -> str:
    """
    Join paths with SYCL_HOME, or raise an error if SYCL_HOME is not set.

    This is basically a lazy way of raising an error for a missing $SYCL_HOME
    only once we need to get any SYCL-specific path.

    Args:
        *paths: path components to append under the oneAPI install root.

    Returns:
        The joined path rooted at ``SYCL_HOME``.

    Raises:
        OSError: if the module-level ``SYCL_HOME`` is ``None`` (i.e. no
            oneAPI installation was detected at import time).
    """
    if SYCL_HOME is None:
        raise OSError('SYCL_HOME environment variable is not set. '
                      'Please set it to your OneAPI install root.')
    return os.path.join(SYCL_HOME, *paths)
ABI_INCOMPATIBILITY_WARNING = '''
@ -207,6 +236,8 @@ if torch.version.hip is not None:
CUDA_HOME = _find_cuda_home() if torch.cuda._is_compiled() else None
CUDNN_HOME = os.environ.get('CUDNN_HOME') or os.environ.get('CUDNN_PATH')
SYCL_HOME = _find_sycl_home() if torch.xpu._is_compiled() else None
# PyTorch releases have the version pattern major.minor.patch, whereas when
# PyTorch is built from source, we append the git commit hash, which gives
# it the below pattern.
@ -1075,7 +1106,7 @@ def CUDAExtension(name, sources, *args, **kwargs):
... 'nvcc': ['-O2', '-rdc=true']})
"""
library_dirs = kwargs.get('library_dirs', [])
library_dirs += library_paths(cuda=True)
library_dirs += library_paths(device_type="cuda")
kwargs['library_dirs'] = library_dirs
libraries = kwargs.get('libraries', [])
@ -1119,7 +1150,7 @@ def CUDAExtension(name, sources, *args, **kwargs):
sources = list(hipified_sources)
include_dirs += include_paths(cuda=True)
include_dirs += include_paths(device_type="cuda")
kwargs['include_dirs'] = include_dirs
kwargs['language'] = 'c++'
@ -1144,9 +1175,9 @@ def CUDAExtension(name, sources, *args, **kwargs):
return setuptools.Extension(name, sources, *args, **kwargs)
def include_paths(cuda: bool = False) -> List[str]:
def include_paths(device_type: str = "cpu") -> List[str]:
"""
Get the include paths required to build a C++ or CUDA extension.
Get the include paths required to build a C++ or CUDA or SYCL extension.
Args:
cuda: If `True`, includes CUDA-specific include paths.
@ -1164,10 +1195,10 @@ def include_paths(cuda: bool = False) -> List[str]:
os.path.join(lib_include, 'TH'),
os.path.join(lib_include, 'THC')
]
if cuda and IS_HIP_EXTENSION:
if device_type == "cuda" and IS_HIP_EXTENSION:
paths.append(os.path.join(lib_include, 'THH'))
paths.append(_join_rocm_home('include'))
elif cuda:
elif device_type == "cuda":
cuda_home_include = _join_cuda_home('include')
# if we have the Debian/Ubuntu packages for cuda, we get /usr as cuda home.
# but gcc doesn't like having /usr/include passed explicitly
@ -1180,10 +1211,12 @@ def include_paths(cuda: bool = False) -> List[str]:
paths.append(cuda_inc_path)
if CUDNN_HOME is not None:
paths.append(os.path.join(CUDNN_HOME, 'include'))
elif device_type == "xpu":
paths.append(_join_sycl_home('include'))
return paths
def library_paths(cuda: bool = False) -> List[str]:
def library_paths(device_type: str = "cpu") -> List[str]:
"""
Get the library paths required to build a C++ or CUDA extension.
@ -1196,12 +1229,12 @@ def library_paths(cuda: bool = False) -> List[str]:
# We need to link against libtorch.so
paths = [TORCH_LIB_PATH]
if cuda and IS_HIP_EXTENSION:
if device_type == "cuda" and IS_HIP_EXTENSION:
lib_dir = 'lib'
paths.append(_join_rocm_home(lib_dir))
if HIP_HOME is not None:
paths.append(os.path.join(HIP_HOME, 'lib'))
elif cuda:
elif device_type == "cuda":
if IS_WINDOWS:
lib_dir = os.path.join('lib', 'x64')
else:
@ -1216,6 +1249,17 @@ def library_paths(cuda: bool = False) -> List[str]:
paths.append(_join_cuda_home(lib_dir))
if CUDNN_HOME is not None:
paths.append(os.path.join(CUDNN_HOME, lib_dir))
elif device_type == "xpu":
if IS_WINDOWS:
lib_dir = os.path.join('lib', 'x64')
else:
lib_dir = 'lib64'
if (not os.path.exists(_join_sycl_home(lib_dir)) and
os.path.exists(_join_sycl_home('lib'))):
lib_dir = 'lib'
paths.append(_join_sycl_home(lib_dir))
return paths
@ -2165,7 +2209,11 @@ def _write_ninja_file_to_build_library(path,
user_includes = [os.path.abspath(file) for file in extra_include_paths]
# include_paths() gives us the location of torch/extension.h
system_includes = include_paths(with_cuda)
# TODO generalize with_cuda as specific device type.
if with_cuda:
system_includes = include_paths("cuda")
else:
system_includes = include_paths("cpu")
# sysconfig.get_path('include') gives us the location of Python.h
# Explicitly specify 'posix_prefix' scheme on non-Windows platforms to workaround error on some MacOS
# installations where default `get_path` points to non-existing `/Library/Python/M.m/include` folder