mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
This patch adds support for sycl kernels build via `torch.utils.cpp_extension.load`, `torch.utils.cpp_extension.load_inline` and (new) `class SyclExtension` APIs. Files having `.sycl` extension are considered to have sycl kernels and are compiled with `icpx` (dpc++ sycl compiler from Intel). Files with other extensions, `.cpp`, `.cu`, are handled as before. API supports building sycl along with other file types into single extension. Note that `.sycl` file extension is a PyTorch convention for files containing sycl code which I propose to adopt. We did follow up with compiler team to introduce such file extension in the compiler, but they are opposed to this. At the same time discussion around sycl file extension and adding sycl language support into such tools as cmake is ongoing. Eventually cmake also considers to introduce some file extension convention for sycl. I hope we can further influence cmake and compiler communities to broader adopt `.sycl` file extension. By default SYCL kernels are compiled for all Intel GPU devices for which pytorch native aten SYCL kernels are compiled. At the moment `pvc,xe-lpg`. This behavior can be overridden by setting `TORCH_XPU_ARCH_LIST` environment variables to the comma separated list of desired devices to compile for. Fixes: #132944 CC: @gujinghui @EikanWang @fengyuan14 @guangyey @jgong5 Pull Request resolved: https://github.com/pytorch/pytorch/pull/132945 Approved by: https://github.com/albanD, https://github.com/guangyey, https://github.com/malfet Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com>
129 lines · 3.6 KiB · Python
import os
|
|
import sys
|
|
|
|
from setuptools import setup
|
|
|
|
import torch.cuda
|
|
from torch.testing._internal.common_utils import IS_WINDOWS
|
|
from torch.utils.cpp_extension import (
|
|
BuildExtension,
|
|
CppExtension,
|
|
CUDA_HOME,
|
|
CUDAExtension,
|
|
ROCM_HOME,
|
|
SyclExtension,
|
|
)
|
|
|
|
|
|
# Compiler flags differ per platform: MSVC hardening switches on Windows,
# debug symbols everywhere else.
if sys.platform == "win32":
    # /permissive- (strict standards conformance) is not usable on the
    # VS 2017 toolset (VCToolsVersion 14.16.x), so only add it on newer ones.
    vc_version = os.getenv("VCToolsVersion", "")
    legacy_toolset = vc_version.startswith("14.16.")
    CXX_FLAGS = ["/sdl"] if legacy_toolset else ["/sdl", "/permissive-"]
else:
    CXX_FLAGS = ["-g"]
|
# Opt into the ninja build backend only when explicitly requested via env var.
USE_NINJA = os.environ.get("USE_NINJA") == "1"
|
# Baseline CPU-only test extensions; these are built unconditionally.
# Each entry is (extension name, single source file).
_cpp_extension_specs = [
    ("torch_test_cpp_extension.cpp", "extension.cpp"),
    ("torch_test_cpp_extension.maia", "maia_extension.cpp"),
    ("torch_test_cpp_extension.rng", "rng_extension.cpp"),
]
ext_modules = [
    CppExtension(ext_name, [ext_source], extra_compile_args=CXX_FLAGS)
    for ext_name, ext_source in _cpp_extension_specs
]
|
# Plain CUDA/ROCm kernel extension; only buildable when a GPU is visible
# and a CUDA or ROCm toolchain is installed.
if torch.cuda.is_available() and (CUDA_HOME is not None or ROCM_HOME is not None):
    cuda_sources = [
        "cuda_extension.cpp",
        "cuda_extension_kernel.cu",
        "cuda_extension_kernel2.cu",
    ]
    ext_modules.append(
        CUDAExtension(
            "torch_test_cpp_extension.cuda",
            cuda_sources,
            extra_compile_args={"cxx": CXX_FLAGS, "nvcc": ["-O2"]},
        )
    )
|
# TORCH_LIBRARY registration test, compiled directly from a .cu file;
# gated by the same CUDA/ROCm availability check as above.
if torch.cuda.is_available() and (CUDA_HOME is not None or ROCM_HOME is not None):
    torch_library_ext = CUDAExtension(
        "torch_test_cpp_extension.torch_library",
        ["torch_library.cu"],
        extra_compile_args={"cxx": CXX_FLAGS, "nvcc": ["-O2"]},
    )
    ext_modules.append(torch_library_ext)
|
# Objective-C++ extension exercising the MPS backend (Apple GPUs only).
if torch.backends.mps.is_available():
    mps_ext = CppExtension(
        "torch_test_cpp_extension.mps",
        ["mps_extension.mm"],
        extra_compile_args=CXX_FLAGS,
    )
    ext_modules.append(mps_ext)
|
# SYCL kernel extension for Intel XPU devices. The .sycl sources are a
# PyTorch convention and are compiled via the SYCL toolchain; this build
# path requires the ninja backend.
if torch.xpu.is_available() and USE_NINJA:
    sycl_ext = SyclExtension(
        "torch_test_cpp_extension.sycl",
        ["xpu_extension.sycl"],
        extra_compile_args={"cxx": CXX_FLAGS, "sycl": ["-O2"]},
    )
    ext_modules.append(sycl_ext)
|
# todo(mkozuki): Figure out the root cause
if (not IS_WINDOWS) and torch.cuda.is_available() and CUDA_HOME is not None:
    # malfet: One should not assume that PyTorch re-exports CUDA dependencies
    # Under ROCm (torch.version.hip set) the BLAS/solver symbols come from
    # the hipified libraries, so no extra link libraries are requested.
    needs_cuda_libs = torch.version.hip is None
    ext_modules.append(
        CUDAExtension(
            name="torch_test_cpp_extension.cublas_extension",
            sources=["cublas_extension.cpp"],
            libraries=["cublas"] if needs_cuda_libs else [],
        )
    )
    ext_modules.append(
        CUDAExtension(
            name="torch_test_cpp_extension.cusolver_extension",
            sources=["cusolver_extension.cpp"],
            libraries=["cusolver"] if needs_cuda_libs else [],
        )
    )
|
# Device-linked CUDA extension: kernels compiled relocatable (-dc) and
# joined in a separate device-link step (dlink=True). Only supported via
# the ninja backend and not on Windows.
if USE_NINJA and not IS_WINDOWS and torch.cuda.is_available() and CUDA_HOME is not None:
    dlink_ext = CUDAExtension(
        name="torch_test_cpp_extension.cuda_dlink",
        sources=[
            "cuda_dlink_extension.cpp",
            "cuda_dlink_extension_kernel.cu",
            "cuda_dlink_extension_add.cu",
        ],
        dlink=True,
        extra_compile_args={"cxx": CXX_FLAGS, "nvcc": ["-O2", "-dc"]},
    )
    ext_modules.append(dlink_ext)
|
# Hand everything to setuptools; BuildExtension drives the actual
# compilation (optionally through ninja, per USE_NINJA).
setup_options = {
    "name": "torch_test_cpp_extension",
    "packages": ["torch_test_cpp_extension"],
    "ext_modules": ext_modules,
    # NOTE(review): a bare string rather than a list; judging by the value
    # this looks deliberate (testing the compiler helper's handling of a
    # string include_dirs) — confirm before "fixing".
    "include_dirs": "self_compiler_include_dirs_test",
    "cmdclass": {"build_ext": BuildExtension.with_options(use_ninja=USE_NINJA)},
    # torch.backends entry point — presumably lets torch discover and
    # autoload this backend on import; verify against the autoload test.
    "entry_points": {
        "torch.backends": [
            "device_backend = torch_test_cpp_extension:_autoload",
        ],
    },
}
setup(**setup_options)