mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
xpu: support sycl with torch.utils.cpp_extension APIs (#132945)
This patch adds support for sycl kernels build via `torch.utils.cpp_extension.load`, `torch.utils.cpp_extension.load_inline` and (new) `class SyclExtension` APIs. Files having `.sycl` extension are considered to have sycl kernels and are compiled with `icpx` (dpc++ sycl compiler from Intel). Files with other extensions, `.cpp`, `.cu`, are handled as before. API supports building sycl along with other file types into single extension. Note that `.sycl` file extension is a PyTorch convention for files containing sycl code which I propose to adopt. We did follow up with compiler team to introduce such file extension in the compiler, but they are opposed to this. At the same time discussion around sycl file extension and adding sycl language support into such tools as cmake is ongoing. Eventually cmake also considers to introduce some file extension convention for sycl. I hope we can further influence cmake and compiler communities to broader adopt `.sycl` file extension. By default SYCL kernels are compiled for all Intel GPU devices for which pytorch native aten SYCL kernels are compiled. At the moment `pvc,xe-lpg`. This behavior can be overridden by setting `TORCH_XPU_ARCH_LIST` environment variables to the comma separated list of desired devices to compile for. Fixes: #132944 CC: @gujinghui @EikanWang @fengyuan14 @guangyey @jgong5 Pull Request resolved: https://github.com/pytorch/pytorch/pull/132945 Approved by: https://github.com/albanD, https://github.com/guangyey, https://github.com/malfet Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com>
This commit is contained in:
committed by
PyTorch MergeBot
parent
dd5d0ea6bb
commit
d27ecf85db
@ -4,6 +4,7 @@ torch.utils.cpp_extension
|
||||
.. currentmodule:: torch.utils.cpp_extension
|
||||
.. autofunction:: CppExtension
|
||||
.. autofunction:: CUDAExtension
|
||||
.. autofunction:: SyclExtension
|
||||
.. autofunction:: BuildExtension
|
||||
.. autofunction:: load
|
||||
.. autofunction:: load_inline
|
||||
|
@ -11,6 +11,7 @@ from torch.utils.cpp_extension import (
|
||||
CUDA_HOME,
|
||||
CUDAExtension,
|
||||
ROCM_HOME,
|
||||
SyclExtension,
|
||||
)
|
||||
|
||||
|
||||
@ -69,6 +70,15 @@ if torch.backends.mps.is_available():
|
||||
)
|
||||
ext_modules.append(extension)
|
||||
|
||||
if torch.xpu.is_available() and USE_NINJA:
|
||||
extension = SyclExtension(
|
||||
"torch_test_cpp_extension.sycl",
|
||||
["xpu_extension.sycl"],
|
||||
extra_compile_args={"cxx": CXX_FLAGS, "sycl": ["-O2"]},
|
||||
)
|
||||
ext_modules.append(extension)
|
||||
|
||||
|
||||
# todo(mkozuki): Figure out the root cause
|
||||
if (not IS_WINDOWS) and torch.cuda.is_available() and CUDA_HOME is not None:
|
||||
# malfet: One should not assume that PyTorch re-exports CUDA dependencies
|
||||
|
63
test/cpp_extensions/xpu_extension.sycl
Normal file
63
test/cpp_extensions/xpu_extension.sycl
Normal file
@ -0,0 +1,63 @@
|
||||
#include <c10/xpu/XPUStream.h>
|
||||
#include <torch/extension.h>
|
||||
#include <sycl/sycl.hpp>
|
||||
|
||||
void sigmoid_add_kernel(const float* x,
|
||||
const float* y,
|
||||
float* output,
|
||||
const int size,
|
||||
const sycl::nd_item<3> &item_ct1) {
|
||||
const int index = item_ct1.get_group(2) * item_ct1.get_local_range(2) +
|
||||
item_ct1.get_local_id(2);
|
||||
if (index < size) {
|
||||
const float sigmoid_x = 1.0f / (1.0f + sycl::native::exp(-x[index]));
|
||||
const float sigmoid_y = 1.0f / (1.0f + sycl::native::exp(-y[index]));
|
||||
output[index] = sigmoid_x + sigmoid_y;
|
||||
}
|
||||
}
|
||||
|
||||
class SigmoidAddKernel {
|
||||
public:
|
||||
void operator()(const sycl::nd_item<3> &item_ct1) const {
|
||||
sigmoid_add_kernel(x, y, output, size, item_ct1);
|
||||
}
|
||||
SigmoidAddKernel(const float* _x, const float* _y, float* _output, int _size):
|
||||
x(_x),
|
||||
y(_y),
|
||||
output(_output),
|
||||
size(_size)
|
||||
{}
|
||||
private:
|
||||
const float* x;
|
||||
const float* y;
|
||||
float* output;
|
||||
int size;
|
||||
};
|
||||
|
||||
void sigmoid_add_xpu(const float* x, const float* y, float* output, int size) {
|
||||
SigmoidAddKernel krn(x, y, output, size);
|
||||
const int threads = 1024;
|
||||
const int blocks = (size + threads - 1) / threads;
|
||||
|
||||
sycl::queue& queue = c10::xpu::getCurrentXPUStream().queue();
|
||||
queue.submit([&](sycl::handler &cgh) {
|
||||
cgh.parallel_for<SigmoidAddKernel>(
|
||||
sycl::nd_range<3>(
|
||||
sycl::range<3>(1, 1, blocks) * sycl::range<3>(1, 1, threads),
|
||||
sycl::range<3>(1, 1, threads)),
|
||||
krn);
|
||||
});
|
||||
}
|
||||
|
||||
torch::Tensor sigmoid_add(torch::Tensor x, torch::Tensor y) {
|
||||
TORCH_CHECK(x.device().is_xpu(), "x must be a XPU tensor");
|
||||
TORCH_CHECK(y.device().is_xpu(), "y must be a XPU tensor");
|
||||
auto output = torch::zeros_like(x);
|
||||
sigmoid_add_xpu(
|
||||
x.data_ptr<float>(), y.data_ptr<float>(), output.data_ptr<float>(), output.numel());
|
||||
return output;
|
||||
}
|
||||
|
||||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||
m.def("sigmoid_add", &sigmoid_add, "sigmoid(x) + sigmoid(y)");
|
||||
}
|
@ -18,6 +18,7 @@ from torch.testing._internal.common_utils import (
|
||||
IS_WINDOWS,
|
||||
shell,
|
||||
skipIfTorchDynamo,
|
||||
TEST_XPU,
|
||||
xfailIfTorchDynamo,
|
||||
)
|
||||
|
||||
@ -113,6 +114,22 @@ class TestCppExtensionAOT(common.TestCase):
|
||||
|
||||
self.assertEqual(cpu_output, mps_output.to("cpu"))
|
||||
|
||||
@unittest.skipIf(not TEST_XPU, "XPU not found")
|
||||
@unittest.skipIf(
|
||||
os.getenv("USE_NINJA", "0") == "0",
|
||||
"sycl extension requires ninja to build",
|
||||
)
|
||||
def test_sycl_extension(self):
|
||||
import torch_test_cpp_extension.sycl as sycl_extension
|
||||
|
||||
x = torch.zeros(100, device="xpu", dtype=torch.float32)
|
||||
y = torch.zeros(100, device="xpu", dtype=torch.float32)
|
||||
|
||||
z = sycl_extension.sigmoid_add(x, y).cpu()
|
||||
|
||||
# 2 * sigmoid(0) = 2 * 0.5 = 1
|
||||
self.assertEqual(z, torch.ones_like(z))
|
||||
|
||||
@common.skipIfRocm
|
||||
@unittest.skipIf(common.IS_WINDOWS, "Windows not supported")
|
||||
@unittest.skipIf(not TEST_CUDA, "CUDA not found")
|
||||
|
@ -17,7 +17,7 @@ import torch.multiprocessing as mp
|
||||
import torch.testing._internal.common_utils as common
|
||||
import torch.utils.cpp_extension
|
||||
from torch.testing._internal.common_cuda import TEST_CUDA, TEST_CUDNN
|
||||
from torch.testing._internal.common_utils import gradcheck
|
||||
from torch.testing._internal.common_utils import gradcheck, TEST_XPU
|
||||
from torch.utils.cpp_extension import (
|
||||
_TORCH_PATH,
|
||||
check_compiler_is_gcc,
|
||||
@ -116,6 +116,26 @@ class TestCppExtensionJIT(common.TestCase):
|
||||
# 2 * sigmoid(0) = 2 * 0.5 = 1
|
||||
self.assertEqual(z, torch.ones_like(z))
|
||||
|
||||
@unittest.skipIf(not (TEST_XPU), "XPU not found")
|
||||
def test_jit_xpu_extension(self):
|
||||
# NOTE: The name of the extension must equal the name of the module.
|
||||
module = torch.utils.cpp_extension.load(
|
||||
name="torch_test_xpu_extension",
|
||||
sources=[
|
||||
"cpp_extensions/xpu_extension.sycl",
|
||||
],
|
||||
verbose=True,
|
||||
keep_intermediates=False,
|
||||
)
|
||||
|
||||
x = torch.zeros(100, device="xpu", dtype=torch.float32)
|
||||
y = torch.zeros(100, device="xpu", dtype=torch.float32)
|
||||
|
||||
z = module.sigmoid_add(x, y).cpu()
|
||||
|
||||
# 2 * sigmoid(0) = 2 * 0.5 = 1
|
||||
self.assertEqual(z, torch.ones_like(z))
|
||||
|
||||
@unittest.skipIf(not TEST_MPS, "MPS not found")
|
||||
def test_mps_extension(self):
|
||||
module = torch.utils.cpp_extension.load(
|
||||
@ -442,6 +462,80 @@ class TestCppExtensionJIT(common.TestCase):
|
||||
z = torch.ops.inline_jit_extension_custom_op_cuda.cos_add(x, y)
|
||||
self.assertEqual(z, x.cos() + y.cos())
|
||||
|
||||
@unittest.skipIf(not TEST_XPU, "XPU not found")
|
||||
def test_inline_jit_compile_extension_xpu(self):
|
||||
sycl_source = """
|
||||
#include <c10/xpu/XPUStream.h>
|
||||
|
||||
class CosAddKernel {
|
||||
public:
|
||||
void operator()(const sycl::nd_item<3> &item_ct1) const {
|
||||
const int index = item_ct1.get_group(2) * item_ct1.get_local_range(2) +
|
||||
item_ct1.get_local_id(2);
|
||||
if (index < size) {
|
||||
output[index] = cosf(x[index]) + cosf(y[index]);
|
||||
}
|
||||
}
|
||||
CosAddKernel(const float* _x, const float* _y, float* _output, int _size):
|
||||
x(_x),
|
||||
y(_y),
|
||||
output(_output),
|
||||
size(_size)
|
||||
{}
|
||||
private:
|
||||
const float* x;
|
||||
const float* y;
|
||||
float* output;
|
||||
int size;
|
||||
};
|
||||
|
||||
void cos_add_kernel(
|
||||
const float* x,
|
||||
const float* y,
|
||||
float* output,
|
||||
int size) {
|
||||
CosAddKernel krn(x, y, output, size);
|
||||
const int threads = 1024;
|
||||
const int blocks = (size + threads - 1) / threads;
|
||||
|
||||
sycl::queue& queue = c10::xpu::getCurrentXPUStream().queue();
|
||||
queue.submit([&](sycl::handler &cgh) {
|
||||
cgh.parallel_for<CosAddKernel>(
|
||||
sycl::nd_range<3>(
|
||||
sycl::range<3>(1, 1, blocks) * sycl::range<3>(1, 1, threads),
|
||||
sycl::range<3>(1, 1, threads)),
|
||||
krn);
|
||||
});
|
||||
}
|
||||
|
||||
torch::Tensor cos_add(torch::Tensor x, torch::Tensor y) {
|
||||
auto output = torch::zeros_like(x);
|
||||
const int threads = 1024;
|
||||
const int blocks = (output.numel() + threads - 1) / threads;
|
||||
cos_add_kernel(x.data_ptr<float>(), y.data_ptr<float>(), output.data_ptr<float>(), output.numel());
|
||||
return output;
|
||||
}
|
||||
"""
|
||||
|
||||
# Here, the C++ source need only declare the function signature.
|
||||
cpp_source = "torch::Tensor cos_add(torch::Tensor x, torch::Tensor y);"
|
||||
|
||||
module = torch.utils.cpp_extension.load_inline(
|
||||
name="inline_jit_extension_xpu",
|
||||
cpp_sources=cpp_source,
|
||||
sycl_sources=sycl_source,
|
||||
functions=["cos_add"],
|
||||
verbose=True,
|
||||
)
|
||||
|
||||
self.assertEqual(module.cos_add.__doc__.split("\n")[2], "cos_add")
|
||||
|
||||
x = torch.randn(4, 4, device="xpu", dtype=torch.float32)
|
||||
y = torch.randn(4, 4, device="xpu", dtype=torch.float32)
|
||||
|
||||
z = module.cos_add(x, y)
|
||||
self.assertEqual(z, x.cos() + y.cos())
|
||||
|
||||
def test_inline_jit_compile_extension_throws_when_functions_is_bad(self):
|
||||
with self.assertRaises(ValueError):
|
||||
torch.utils.cpp_extension.load_inline(
|
||||
|
@ -40,6 +40,7 @@ class ExtensionVersioner:
|
||||
build_arguments,
|
||||
build_directory,
|
||||
with_cuda,
|
||||
with_sycl,
|
||||
is_python_module,
|
||||
is_standalone):
|
||||
hash_value = 0
|
||||
@ -47,6 +48,7 @@ class ExtensionVersioner:
|
||||
hash_value = hash_build_arguments(hash_value, build_arguments)
|
||||
hash_value = update_hash(hash_value, build_directory)
|
||||
hash_value = update_hash(hash_value, with_cuda)
|
||||
hash_value = update_hash(hash_value, with_sycl)
|
||||
hash_value = update_hash(hash_value, is_python_module)
|
||||
hash_value = update_hash(hash_value, is_standalone)
|
||||
|
||||
|
@ -75,7 +75,7 @@ CUDA_CLANG_VERSIONS: VersionMap = {
|
||||
}
|
||||
|
||||
__all__ = ["get_default_build_root", "check_compiler_ok_for_platform", "get_compiler_abi_compatibility_and_version", "BuildExtension",
|
||||
"CppExtension", "CUDAExtension", "include_paths", "library_paths", "load", "load_inline", "is_ninja_available",
|
||||
"CppExtension", "CUDAExtension", "SyclExtension", "include_paths", "library_paths", "load", "load_inline", "is_ninja_available",
|
||||
"verify_ninja_availability", "remove_extension_h_precompiler_headers", "get_cxx_compiler", "check_compiler_is_gcc"]
|
||||
# Taken directly from python stdlib < 3.9
|
||||
# See https://github.com/pytorch/pytorch/issues/48617
|
||||
@ -282,6 +282,30 @@ COMMON_HIPCC_FLAGS = [
|
||||
'-D__HIP_NO_HALF_CONVERSIONS__=1',
|
||||
]
|
||||
|
||||
_COMMON_SYCL_FLAGS = [
|
||||
'-fsycl',
|
||||
'-fsycl-targets=spir64_gen,spir64',
|
||||
]
|
||||
|
||||
def _get_sycl_arch_list():
|
||||
if 'TORCH_XPU_ARCH_LIST' in os.environ:
|
||||
return os.environ.get('TORCH_XPU_ARCH_LIST')
|
||||
if not torch.xpu.is_available():
|
||||
return ""
|
||||
arch_list = torch.xpu.get_arch_list()
|
||||
# Dropping dg2-* archs since they lack hardware support for fp64 and require
|
||||
# special consideration from the user. If needed these platforms can
|
||||
# be requested thru TORCH_XPU_ARCH_LIST environment variable.
|
||||
arch_list = [x for x in arch_list if not x.startswith('dg2-')]
|
||||
return ','.join(arch_list)
|
||||
|
||||
_SYCL_DLINK_FLAGS = [
|
||||
*_COMMON_SYCL_FLAGS,
|
||||
'-fsycl-link',
|
||||
'--offload-compress',
|
||||
f'-Xs "-device {_get_sycl_arch_list()}"',
|
||||
]
|
||||
|
||||
JIT_EXTENSION_VERSIONER = ExtensionVersioner()
|
||||
|
||||
PLAT_TO_VCVARS = {
|
||||
@ -490,19 +514,34 @@ def _check_cuda_version(compiler_name: str, compiler_version: TorchVersion) -> N
|
||||
)
|
||||
|
||||
|
||||
def _append_sycl_std_if_no_std_present(cflags):
|
||||
if not any(flag.startswith('-sycl-std=') for flag in cflags):
|
||||
cflags.append('-sycl-std=2020')
|
||||
|
||||
|
||||
def _wrap_sycl_host_flags(cflags):
|
||||
host_cxx = get_cxx_compiler()
|
||||
host_cflags = [
|
||||
f'-fsycl-host-compiler={host_cxx}',
|
||||
shlex.quote(f'-fsycl-host-compiler-options={cflags}'),
|
||||
]
|
||||
return host_cflags
|
||||
|
||||
|
||||
class BuildExtension(build_ext):
|
||||
"""
|
||||
A custom :mod:`setuptools` build extension .
|
||||
|
||||
This :class:`setuptools.build_ext` subclass takes care of passing the
|
||||
minimum required compiler flags (e.g. ``-std=c++17``) as well as mixed
|
||||
C++/CUDA compilation (and support for CUDA files in general).
|
||||
C++/CUDA/SYCL compilation (and support for CUDA/SYCL files in general).
|
||||
|
||||
When using :class:`BuildExtension`, it is allowed to supply a dictionary
|
||||
for ``extra_compile_args`` (rather than the usual list) that maps from
|
||||
languages (``cxx`` or ``nvcc``) to a list of additional compiler flags to
|
||||
supply to the compiler. This makes it possible to supply different flags to
|
||||
the C++ and CUDA compiler during mixed compilation.
|
||||
languages/compilers (the only expected values are ``cxx``, ``nvcc`` or
|
||||
``sycl``) to a list of additional compiler flags to supply to the compiler.
|
||||
This makes it possible to supply different flags to the C++, CUDA and SYCL
|
||||
compiler during mixed compilation.
|
||||
|
||||
``use_ninja`` (bool): If ``use_ninja`` is ``True`` (default), then we
|
||||
attempt to build using the Ninja backend. Ninja greatly speeds up
|
||||
@ -548,29 +587,41 @@ class BuildExtension(build_ext):
|
||||
compiler_name, compiler_version = self._check_abi()
|
||||
|
||||
cuda_ext = False
|
||||
sycl_ext = False
|
||||
extension_iter = iter(self.extensions)
|
||||
extension = next(extension_iter, None)
|
||||
while not cuda_ext and extension:
|
||||
while not (cuda_ext and sycl_ext) and extension:
|
||||
for source in extension.sources:
|
||||
_, ext = os.path.splitext(source)
|
||||
if ext == '.cu':
|
||||
cuda_ext = True
|
||||
elif ext == '.sycl':
|
||||
sycl_ext = True
|
||||
|
||||
# This check accounts on a case when cuda and sycl sources
|
||||
# are mixed in the same extension. We can stop checking
|
||||
# sources if both are found or there is no more sources.
|
||||
if cuda_ext and sycl_ext:
|
||||
break
|
||||
|
||||
extension = next(extension_iter, None)
|
||||
|
||||
if sycl_ext:
|
||||
assert self.use_ninja, "ninja is required to build sycl extensions."
|
||||
|
||||
if cuda_ext and not IS_HIP_EXTENSION:
|
||||
_check_cuda_version(compiler_name, compiler_version)
|
||||
|
||||
for extension in self.extensions:
|
||||
# Ensure at least an empty list of flags for 'cxx' and 'nvcc' when
|
||||
# Ensure at least an empty list of flags for 'cxx', 'nvcc' and 'sycl' when
|
||||
# extra_compile_args is a dict. Otherwise, default torch flags do
|
||||
# not get passed. Necessary when only one of 'cxx' and 'nvcc' is
|
||||
# passed to extra_compile_args in CUDAExtension, i.e.
|
||||
# not get passed. Necessary when only one of 'cxx', 'nvcc' or 'sycl' is
|
||||
# passed to extra_compile_args in CUDAExtension or SyclExtension, i.e.
|
||||
# CUDAExtension(..., extra_compile_args={'cxx': [...]})
|
||||
# or
|
||||
# CUDAExtension(..., extra_compile_args={'nvcc': [...]})
|
||||
if isinstance(extension.extra_compile_args, dict):
|
||||
for ext in ['cxx', 'nvcc']:
|
||||
for ext in ['cxx', 'nvcc', 'sycl']:
|
||||
if ext not in extension.extra_compile_args:
|
||||
extension.extra_compile_args[ext] = []
|
||||
|
||||
@ -597,8 +648,11 @@ class BuildExtension(build_ext):
|
||||
if 'nvcc_dlink' in extension.extra_compile_args:
|
||||
assert self.use_ninja, f"With dlink=True, ninja is required to build cuda extension {extension.name}."
|
||||
|
||||
# Register .cu, .cuh, .hip, and .mm as valid source extensions.
|
||||
self.compiler.src_extensions += ['.cu', '.cuh', '.hip']
|
||||
# Register .cu, .cuh, .hip, .mm and .sycl as valid source extensions.
|
||||
# NOTE: At the moment .sycl is not a standard extension for SYCL supported
|
||||
# by compiler. Here we introduce a torch level convention that SYCL sources
|
||||
# should have .sycl file extension.
|
||||
self.compiler.src_extensions += ['.cu', '.cuh', '.hip', '.sycl']
|
||||
if torch.backends.mps.is_built():
|
||||
self.compiler.src_extensions += ['.mm']
|
||||
# Save the original _compile method for later.
|
||||
@ -698,9 +752,10 @@ class BuildExtension(build_ext):
|
||||
common_cflags = self.compiler._get_cc_args(pp_opts, debug, extra_preargs)
|
||||
extra_cc_cflags = self.compiler.compiler_so[1:]
|
||||
with_cuda = any(map(_is_cuda_file, sources))
|
||||
with_sycl = any(map(_is_sycl_file, sources))
|
||||
|
||||
# extra_postargs can be either:
|
||||
# - a dict mapping cxx/nvcc to extra flags
|
||||
# - a dict mapping cxx/nvcc/sycl to extra flags
|
||||
# - a list of extra flags.
|
||||
if isinstance(extra_postargs, dict):
|
||||
post_cflags = extra_postargs['cxx']
|
||||
@ -731,6 +786,31 @@ class BuildExtension(build_ext):
|
||||
cuda_dlink_post_cflags = unix_cuda_flags(extra_postargs['nvcc_dlink'])
|
||||
else:
|
||||
cuda_dlink_post_cflags = None
|
||||
|
||||
sycl_post_cflags = None
|
||||
sycl_cflags = None
|
||||
sycl_dlink_post_cflags = None
|
||||
if with_sycl:
|
||||
sycl_cflags = extra_cc_cflags + common_cflags + _COMMON_SYCL_FLAGS
|
||||
if isinstance(extra_postargs, dict):
|
||||
sycl_post_cflags = extra_postargs['sycl']
|
||||
else:
|
||||
sycl_post_cflags = list(extra_postargs)
|
||||
append_std17_if_no_std_present(sycl_cflags)
|
||||
_append_sycl_std_if_no_std_present(sycl_cflags)
|
||||
host_cflags = extra_cc_cflags + common_cflags + post_cflags
|
||||
append_std17_if_no_std_present(host_cflags)
|
||||
# escaping quoted arguments to pass them thru SYCL compiler
|
||||
host_cflags = [item.replace('"', '\\\\"') for item in host_cflags]
|
||||
host_cflags = ' '.join(host_cflags)
|
||||
# Note the order: shlex.quote sycl_flags first, _wrap_sycl_host_flags
|
||||
# second. Reason is that sycl host flags are quoted, space containing
|
||||
# strings passed to SYCL compiler.
|
||||
sycl_cflags = [shlex.quote(f) for f in sycl_cflags]
|
||||
sycl_cflags += _wrap_sycl_host_flags(host_cflags)
|
||||
sycl_dlink_post_cflags = _SYCL_DLINK_FLAGS
|
||||
sycl_post_cflags = [shlex.quote(f) for f in sycl_post_cflags]
|
||||
|
||||
_write_ninja_file_and_compile_objects(
|
||||
sources=sources,
|
||||
objects=objects,
|
||||
@ -739,9 +819,13 @@ class BuildExtension(build_ext):
|
||||
cuda_cflags=cuda_cflags,
|
||||
cuda_post_cflags=cuda_post_cflags,
|
||||
cuda_dlink_post_cflags=cuda_dlink_post_cflags,
|
||||
sycl_cflags=sycl_cflags,
|
||||
sycl_post_cflags=sycl_post_cflags,
|
||||
sycl_dlink_post_cflags=sycl_dlink_post_cflags,
|
||||
build_directory=output_dir,
|
||||
verbose=True,
|
||||
with_cuda=with_cuda)
|
||||
with_cuda=with_cuda,
|
||||
with_sycl=with_sycl)
|
||||
|
||||
# Return *all* object filenames, not just the ones we just built.
|
||||
return objects
|
||||
@ -898,9 +982,13 @@ class BuildExtension(build_ext):
|
||||
cuda_cflags=cuda_cflags,
|
||||
cuda_post_cflags=cuda_post_cflags,
|
||||
cuda_dlink_post_cflags=cuda_dlink_post_cflags,
|
||||
sycl_cflags=None,
|
||||
sycl_post_cflags=None,
|
||||
sycl_dlink_post_cflags=None,
|
||||
build_directory=output_dir,
|
||||
verbose=True,
|
||||
with_cuda=with_cuda)
|
||||
with_cuda=with_cuda,
|
||||
with_sycl=False)
|
||||
|
||||
# Return *all* object filenames, not just the ones we just built.
|
||||
return objects
|
||||
@ -1235,6 +1323,78 @@ def CUDAExtension(name, sources, *args, **kwargs):
|
||||
return setuptools.Extension(name, sources, *args, **kwargs)
|
||||
|
||||
|
||||
def SyclExtension(name, sources, *args, **kwargs):
|
||||
r"""
|
||||
Creates a :class:`setuptools.Extension` for SYCL/C++.
|
||||
|
||||
Convenience method that creates a :class:`setuptools.Extension` with the
|
||||
bare minimum (but often sufficient) arguments to build a SYCL/C++
|
||||
extension.
|
||||
|
||||
All arguments are forwarded to the :class:`setuptools.Extension`
|
||||
constructor.
|
||||
|
||||
.. note::
|
||||
The PyTorch python API (as provided in libtorch_python) cannot be built
|
||||
with the flag ``py_limited_api=True``. When this flag is passed, it is
|
||||
the user's responsibility in their library to not use APIs from
|
||||
libtorch_python (in particular pytorch/python bindings) and to only use
|
||||
APIs from libtorch (aten objects, operators and the dispatcher). For
|
||||
example, to give access to custom ops from python, the library should
|
||||
register the ops through the dispatcher.
|
||||
|
||||
Example:
|
||||
>>> # xdoctest: +SKIP
|
||||
>>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CPP_EXT)
|
||||
>>> from torch.utils.cpp_extension import BuildExtension, SyclExtension
|
||||
>>> setup(
|
||||
... name='xpu_extension',
|
||||
... ext_modules=[
|
||||
... SyclExtension(
|
||||
... name='xpu_extension',
|
||||
... sources=['extension.cpp', 'extension_kernel.cpp'],
|
||||
... extra_compile_args={'cxx': ['-g', '-std=c++20', '-fPIC']})
|
||||
... ],
|
||||
... cmdclass={
|
||||
... 'build_ext': BuildExtension
|
||||
... })
|
||||
|
||||
By default the extension will be compiled to run on all archs of the cards visible during the
|
||||
building process of the extension. If down the road a new card is installed the
|
||||
extension may need to be recompiled. You can override the default behavior using
|
||||
`TORCH_XPU_ARCH_LIST` to explicitly specify which device architectures you want the extension
|
||||
to support:
|
||||
|
||||
``TORCH_XPU_ARCH_LIST="pvc,xe-lpg" python build_my_extension.py``
|
||||
|
||||
Note that while it's possible to include all supported archs, the more archs get included the
|
||||
slower the building process will be, as it will build a separate kernel image for each arch.
|
||||
|
||||
Note: Ninja is required to build SyclExtension.
|
||||
"""
|
||||
library_dirs = kwargs.get("library_dirs", [])
|
||||
library_dirs += library_paths()
|
||||
kwargs["library_dirs"] = library_dirs
|
||||
|
||||
libraries = kwargs.get("libraries", [])
|
||||
libraries.append("c10")
|
||||
libraries.append("c10_xpu")
|
||||
libraries.append("torch")
|
||||
libraries.append("torch_cpu")
|
||||
if not kwargs.get('py_limited_api', False):
|
||||
# torch_python uses more than the python limited api
|
||||
libraries.append("torch_python")
|
||||
libraries.append("torch_xpu")
|
||||
kwargs["libraries"] = libraries
|
||||
|
||||
include_dirs = kwargs.get("include_dirs", [])
|
||||
include_dirs += include_paths()
|
||||
kwargs["include_dirs"] = include_dirs
|
||||
|
||||
kwargs["language"] = "c++"
|
||||
|
||||
return setuptools.Extension(name, sources, *args, **kwargs)
|
||||
|
||||
def include_paths(device_type: str = "cpu") -> list[str]:
|
||||
"""
|
||||
Get the include paths required to build a C++ or CUDA or SYCL extension.
|
||||
@ -1323,11 +1483,13 @@ def load(name,
|
||||
sources: Union[str, list[str]],
|
||||
extra_cflags=None,
|
||||
extra_cuda_cflags=None,
|
||||
extra_sycl_cflags=None,
|
||||
extra_ldflags=None,
|
||||
extra_include_paths=None,
|
||||
build_directory=None,
|
||||
verbose=False,
|
||||
with_cuda: Optional[bool] = None,
|
||||
with_sycl: Optional[bool] = None,
|
||||
is_python_module=True,
|
||||
is_standalone=False,
|
||||
keep_intermediates=True):
|
||||
@ -1366,6 +1528,14 @@ def load(name,
|
||||
work fine. If not, setting the ``CUDA_HOME`` environment variable is the
|
||||
safest option.
|
||||
|
||||
SYCL support with mixed compilation is provided. Simply pass SYCL source
|
||||
files (``.sycl``) along with other sources. Such files will be detected
|
||||
and compiled with SYCL compiler (such as Intel DPC++ Compiler) rather
|
||||
than the C++ compiler. You can pass additional flags to SYCL compiler
|
||||
via ``extra_sycl_cflags``, just like with ``extra_cflags`` for C++.
|
||||
SYCL compiler is expected to be found via system PATH environment
|
||||
variable.
|
||||
|
||||
Args:
|
||||
name: The name of the extension to build. This MUST be the same as the
|
||||
name of the pybind11 module!
|
||||
@ -1373,6 +1543,8 @@ def load(name,
|
||||
extra_cflags: optional list of compiler flags to forward to the build.
|
||||
extra_cuda_cflags: optional list of compiler flags to forward to nvcc
|
||||
when building CUDA sources.
|
||||
extra_sycl_cflags: optional list of compiler flags to forward to SYCL
|
||||
compiler when building SYCL sources.
|
||||
extra_ldflags: optional list of linker flags to forward to the build.
|
||||
extra_include_paths: optional list of include directories to forward
|
||||
to the build.
|
||||
@ -1383,6 +1555,11 @@ def load(name,
|
||||
automatically determined based on the existence of ``.cu`` or
|
||||
``.cuh`` in ``sources``. Set it to `True`` to force CUDA headers
|
||||
and libraries to be included.
|
||||
with_sycl: Determines whether SYCL headers and libraries are added to
|
||||
the build. If set to ``None`` (default), this value is
|
||||
automatically determined based on the existence of ``.sycl`` in
|
||||
``sources``. Set it to `True`` to force SYCL headers and
|
||||
libraries to be included.
|
||||
is_python_module: If ``True`` (default), imports the produced shared
|
||||
library as a Python module. If ``False``, behavior depends on
|
||||
``is_standalone``.
|
||||
@ -1416,11 +1593,13 @@ def load(name,
|
||||
[sources] if isinstance(sources, str) else sources,
|
||||
extra_cflags,
|
||||
extra_cuda_cflags,
|
||||
extra_sycl_cflags,
|
||||
extra_ldflags,
|
||||
extra_include_paths,
|
||||
build_directory or _get_build_directory(name, verbose),
|
||||
verbose,
|
||||
with_cuda,
|
||||
with_sycl,
|
||||
is_python_module,
|
||||
is_standalone,
|
||||
keep_intermediates=keep_intermediates)
|
||||
@ -1608,14 +1787,17 @@ def remove_extension_h_precompiler_headers():
|
||||
def load_inline(name,
|
||||
cpp_sources,
|
||||
cuda_sources=None,
|
||||
sycl_sources=None,
|
||||
functions=None,
|
||||
extra_cflags=None,
|
||||
extra_cuda_cflags=None,
|
||||
extra_sycl_cflags=None,
|
||||
extra_ldflags=None,
|
||||
extra_include_paths=None,
|
||||
build_directory=None,
|
||||
verbose=False,
|
||||
with_cuda=None,
|
||||
with_sycl=None,
|
||||
is_python_module=True,
|
||||
with_pytorch_error_handling=True,
|
||||
keep_intermediates=True,
|
||||
@ -1653,11 +1835,21 @@ def load_inline(name,
|
||||
declare or define this C++ function in one of the ``cpp_sources`` (and
|
||||
include its name in ``functions``).
|
||||
|
||||
The sources in ``sycl_sources`` are concatenated into a separate ``.sycl``
|
||||
file and prepended with ``torch/types.h``, ``sycl/sycl.hpp`` includes.
|
||||
The ``.cpp`` and ``.sycl`` files are compiled separately, but ultimately
|
||||
linked into a single library. Note that no bindings are generated for
|
||||
functions in ``sycl_sources`` per se. To bind to a SYCL kernel, you must
|
||||
create a C++ function that calls it, and either declare or define this
|
||||
C++ function in one of the ``cpp_sources`` (and include its name
|
||||
in ``functions``).
|
||||
|
||||
See :func:`load` for a description of arguments omitted below.
|
||||
|
||||
Args:
|
||||
cpp_sources: A string, or list of strings, containing C++ source code.
|
||||
cuda_sources: A string, or list of strings, containing CUDA source code.
|
||||
sycl_sources: A string, or list of strings, containing SYCL source code.
|
||||
functions: A list of function names for which to generate function
|
||||
bindings. If a dictionary is given, it should map function names to
|
||||
docstrings (which are otherwise just the function names).
|
||||
@ -1666,6 +1858,11 @@ def load_inline(name,
|
||||
automatically determined based on whether ``cuda_sources`` is
|
||||
provided. Set it to ``True`` to force CUDA headers
|
||||
and libraries to be included.
|
||||
with_sycl: Determines whether SYCL headers and libraries are added to
|
||||
the build. If set to ``None`` (default), this value is
|
||||
automatically determined based on whether ``sycl_sources`` is
|
||||
provided. Set it to ``True`` to force SYCL headers
|
||||
and libraries to be included.
|
||||
with_pytorch_error_handling: Determines whether pytorch error and
|
||||
warning macros are handled by pytorch instead of pybind. To do
|
||||
this, each function ``foo`` is called via an intermediary ``_safe_foo``
|
||||
@ -1705,6 +1902,9 @@ def load_inline(name,
|
||||
cuda_sources = cuda_sources or []
|
||||
if isinstance(cuda_sources, str):
|
||||
cuda_sources = [cuda_sources]
|
||||
sycl_sources = sycl_sources or []
|
||||
if isinstance(sycl_sources, str):
|
||||
sycl_sources = [sycl_sources]
|
||||
|
||||
cpp_sources.insert(0, '#include <torch/extension.h>')
|
||||
|
||||
@ -1750,16 +1950,27 @@ def load_inline(name,
|
||||
|
||||
sources.append(cuda_source_path)
|
||||
|
||||
if sycl_sources:
|
||||
sycl_sources.insert(0, '#include <torch/types.h>')
|
||||
sycl_sources.insert(1, '#include <sycl/sycl.hpp>')
|
||||
|
||||
sycl_source_path = os.path.join(build_directory, 'sycl.sycl')
|
||||
_maybe_write(sycl_source_path, "\n".join(sycl_sources))
|
||||
|
||||
sources.append(sycl_source_path)
|
||||
|
||||
return _jit_compile(
|
||||
name,
|
||||
sources,
|
||||
extra_cflags,
|
||||
extra_cuda_cflags,
|
||||
extra_sycl_cflags,
|
||||
extra_ldflags,
|
||||
extra_include_paths,
|
||||
build_directory,
|
||||
verbose,
|
||||
with_cuda,
|
||||
with_sycl,
|
||||
is_python_module,
|
||||
is_standalone=False,
|
||||
keep_intermediates=keep_intermediates)
|
||||
@ -1769,11 +1980,13 @@ def _jit_compile(name,
|
||||
sources,
|
||||
extra_cflags,
|
||||
extra_cuda_cflags,
|
||||
extra_sycl_cflags,
|
||||
extra_ldflags,
|
||||
extra_include_paths,
|
||||
build_directory: str,
|
||||
verbose: bool,
|
||||
with_cuda: Optional[bool],
|
||||
with_sycl: Optional[bool],
|
||||
is_python_module,
|
||||
is_standalone,
|
||||
keep_intermediates=True) -> None:
|
||||
@ -1783,6 +1996,8 @@ def _jit_compile(name,
|
||||
if with_cuda is None:
|
||||
with_cuda = any(map(_is_cuda_file, sources))
|
||||
with_cudnn = any('cudnn' in f for f in extra_ldflags or [])
|
||||
if with_sycl is None:
|
||||
with_sycl = any(map(_is_sycl_file, sources))
|
||||
old_version = JIT_EXTENSION_VERSIONER.get_version(name)
|
||||
version = JIT_EXTENSION_VERSIONER.bump_version_if_changed(
|
||||
name,
|
||||
@ -1790,6 +2005,7 @@ def _jit_compile(name,
|
||||
build_arguments=[extra_cflags, extra_cuda_cflags, extra_ldflags, extra_include_paths],
|
||||
build_directory=build_directory,
|
||||
with_cuda=with_cuda,
|
||||
with_sycl=with_sycl,
|
||||
is_python_module=is_python_module,
|
||||
is_standalone=is_standalone,
|
||||
)
|
||||
@ -1830,11 +2046,13 @@ def _jit_compile(name,
|
||||
sources=sources,
|
||||
extra_cflags=extra_cflags or [],
|
||||
extra_cuda_cflags=extra_cuda_cflags or [],
|
||||
extra_sycl_cflags=extra_sycl_cflags or [],
|
||||
extra_ldflags=extra_ldflags or [],
|
||||
extra_include_paths=extra_include_paths or [],
|
||||
build_directory=build_directory,
|
||||
verbose=verbose,
|
||||
with_cuda=with_cuda,
|
||||
with_sycl=with_sycl,
|
||||
is_standalone=is_standalone)
|
||||
elif verbose:
|
||||
print('No modifications detected for re-loaded extension '
|
||||
@ -1861,9 +2079,13 @@ def _write_ninja_file_and_compile_objects(
|
||||
cuda_cflags,
|
||||
cuda_post_cflags,
|
||||
cuda_dlink_post_cflags,
|
||||
sycl_cflags,
|
||||
sycl_post_cflags,
|
||||
sycl_dlink_post_cflags,
|
||||
build_directory: str,
|
||||
verbose: bool,
|
||||
with_cuda: Optional[bool]) -> None:
|
||||
with_cuda: Optional[bool],
|
||||
with_sycl: Optional[bool]) -> None:
|
||||
verify_ninja_availability()
|
||||
|
||||
compiler = get_cxx_compiler()
|
||||
@ -1871,6 +2093,8 @@ def _write_ninja_file_and_compile_objects(
|
||||
get_compiler_abi_compatibility_and_version(compiler)
|
||||
if with_cuda is None:
|
||||
with_cuda = any(map(_is_cuda_file, sources))
|
||||
if with_sycl is None:
|
||||
with_sycl = any(map(_is_sycl_file, sources))
|
||||
build_file_path = os.path.join(build_directory, 'build.ninja')
|
||||
if verbose:
|
||||
print(f'Emitting ninja build file {build_file_path}...', file=sys.stderr)
|
||||
@ -1889,11 +2113,15 @@ def _write_ninja_file_and_compile_objects(
|
||||
cuda_cflags=cuda_cflags,
|
||||
cuda_post_cflags=cuda_post_cflags,
|
||||
cuda_dlink_post_cflags=cuda_dlink_post_cflags,
|
||||
sycl_cflags=sycl_cflags,
|
||||
sycl_post_cflags=sycl_post_cflags,
|
||||
sycl_dlink_post_cflags=sycl_dlink_post_cflags,
|
||||
sources=sources,
|
||||
objects=objects,
|
||||
ldflags=None,
|
||||
library_target=None,
|
||||
with_cuda=with_cuda)
|
||||
with_cuda=with_cuda,
|
||||
with_sycl=with_sycl)
|
||||
if verbose:
|
||||
print('Compiling objects...', file=sys.stderr)
|
||||
_run_ninja_build(
|
||||
@ -1909,11 +2137,13 @@ def _write_ninja_file_and_build_library(
|
||||
sources: list[str],
|
||||
extra_cflags,
|
||||
extra_cuda_cflags,
|
||||
extra_sycl_cflags,
|
||||
extra_ldflags,
|
||||
extra_include_paths,
|
||||
build_directory: str,
|
||||
verbose: bool,
|
||||
with_cuda: Optional[bool],
|
||||
with_sycl: Optional[bool],
|
||||
is_standalone: bool = False) -> None:
|
||||
verify_ninja_availability()
|
||||
|
||||
@ -1922,6 +2152,8 @@ def _write_ninja_file_and_build_library(
|
||||
get_compiler_abi_compatibility_and_version(compiler)
|
||||
if with_cuda is None:
|
||||
with_cuda = any(map(_is_cuda_file, sources))
|
||||
if with_sycl is None:
|
||||
with_sycl = any(map(_is_sycl_file, sources))
|
||||
extra_ldflags = _prepare_ldflags(
|
||||
extra_ldflags or [],
|
||||
with_cuda,
|
||||
@ -1946,9 +2178,11 @@ def _write_ninja_file_and_build_library(
|
||||
sources=sources,
|
||||
extra_cflags=extra_cflags or [],
|
||||
extra_cuda_cflags=extra_cuda_cflags or [],
|
||||
extra_sycl_cflags=extra_sycl_cflags or [],
|
||||
extra_ldflags=extra_ldflags or [],
|
||||
extra_include_paths=extra_include_paths or [],
|
||||
with_cuda=with_cuda,
|
||||
with_sycl=with_sycl,
|
||||
is_standalone=is_standalone)
|
||||
|
||||
if verbose:
|
||||
@ -2287,12 +2521,15 @@ def _write_ninja_file_to_build_library(path,
|
||||
sources,
|
||||
extra_cflags,
|
||||
extra_cuda_cflags,
|
||||
extra_sycl_cflags,
|
||||
extra_ldflags,
|
||||
extra_include_paths,
|
||||
with_cuda,
|
||||
with_sycl,
|
||||
is_standalone) -> None:
|
||||
extra_cflags = [flag.strip() for flag in extra_cflags]
|
||||
extra_cuda_cflags = [flag.strip() for flag in extra_cuda_cflags]
|
||||
extra_sycl_cflags = [flag.strip() for flag in extra_sycl_cflags]
|
||||
extra_ldflags = [flag.strip() for flag in extra_ldflags]
|
||||
extra_include_paths = [flag.strip() for flag in extra_include_paths]
|
||||
|
||||
@ -2360,6 +2597,20 @@ def _write_ninja_file_to_build_library(path,
|
||||
else:
|
||||
cuda_flags = None
|
||||
|
||||
if with_sycl:
|
||||
sycl_cflags = cflags + _COMMON_SYCL_FLAGS
|
||||
sycl_cflags += extra_sycl_cflags
|
||||
_append_sycl_std_if_no_std_present(sycl_cflags)
|
||||
host_cflags = cflags
|
||||
# escaping quoted arguments to pass them thru SYCL compiler
|
||||
host_cflags = [item.replace('\\"', '\\\\"') for item in host_cflags]
|
||||
host_cflags = ' '.join(host_cflags)
|
||||
sycl_cflags += _wrap_sycl_host_flags(host_cflags)
|
||||
sycl_dlink_post_cflags = _SYCL_DLINK_FLAGS
|
||||
else:
|
||||
sycl_cflags = None
|
||||
sycl_dlink_post_cflags = None
|
||||
|
||||
def object_file_path(source_file: str) -> str:
|
||||
# '/path/to/file.cpp' -> 'file'
|
||||
file_name = os.path.splitext(os.path.basename(source_file))[0]
|
||||
@ -2367,6 +2618,8 @@ def _write_ninja_file_to_build_library(path,
|
||||
# Use a different object filename in case a C++ and CUDA file have
|
||||
# the same filename but different extension (.cpp vs. .cu).
|
||||
target = f'{file_name}.cuda.o'
|
||||
elif _is_sycl_file(source_file) and with_sycl:
|
||||
target = f'{file_name}.sycl.o'
|
||||
else:
|
||||
target = f'{file_name}.o'
|
||||
return target
|
||||
@ -2390,11 +2643,15 @@ def _write_ninja_file_to_build_library(path,
|
||||
cuda_cflags=cuda_flags,
|
||||
cuda_post_cflags=None,
|
||||
cuda_dlink_post_cflags=None,
|
||||
sycl_cflags=sycl_cflags,
|
||||
sycl_post_cflags=[],
|
||||
sycl_dlink_post_cflags=sycl_dlink_post_cflags,
|
||||
sources=sources,
|
||||
objects=objects,
|
||||
ldflags=ldflags,
|
||||
library_target=library_target,
|
||||
with_cuda=with_cuda)
|
||||
with_cuda=with_cuda,
|
||||
with_sycl=with_sycl)
|
||||
|
||||
|
||||
def _write_ninja_file(path,
|
||||
@ -2403,18 +2660,27 @@ def _write_ninja_file(path,
|
||||
cuda_cflags,
|
||||
cuda_post_cflags,
|
||||
cuda_dlink_post_cflags,
|
||||
sycl_cflags,
|
||||
sycl_post_cflags,
|
||||
sycl_dlink_post_cflags,
|
||||
sources,
|
||||
objects,
|
||||
ldflags,
|
||||
library_target,
|
||||
with_cuda) -> None:
|
||||
with_cuda,
|
||||
with_sycl) -> None:
|
||||
r"""Write a ninja file that does the desired compiling and linking.
|
||||
|
||||
`path`: Where to write this file
|
||||
`cflags`: list of flags to pass to $cxx. Can be None.
|
||||
`post_cflags`: list of flags to append to the $cxx invocation. Can be None.
|
||||
`cuda_cflags`: list of flags to pass to $nvcc. Can be None.
|
||||
`cuda_postflags`: list of flags to append to the $nvcc invocation. Can be None.
|
||||
`cuda_post_cflags`: list of flags to append to the $nvcc invocation. Can be None.
|
||||
`cuda_dlink_post_cflags`: list of flags to append to the $nvcc device code link invocation. Can be None.
|
||||
`sycl_cflags`: list of flags to pass to SYCL compiler. Can be None.
|
||||
`sycl_post_cflags`: list of flags to append to the SYCL compiler invocation. Can be None.
|
||||
`sycl_dlink_post_cflags`: list of flags to append to the SYCL compiler device code link invocation. Can be None.
|
||||
e.
|
||||
`sources`: list of paths to source files
|
||||
`objects`: list of desired paths to objects, one per source.
|
||||
`ldflags`: list of flags to pass to linker. Can be None.
|
||||
@ -2433,6 +2699,9 @@ def _write_ninja_file(path,
|
||||
cuda_cflags = sanitize_flags(cuda_cflags)
|
||||
cuda_post_cflags = sanitize_flags(cuda_post_cflags)
|
||||
cuda_dlink_post_cflags = sanitize_flags(cuda_dlink_post_cflags)
|
||||
sycl_cflags = sanitize_flags(sycl_cflags)
|
||||
sycl_post_cflags = sanitize_flags(sycl_post_cflags)
|
||||
sycl_dlink_post_cflags = sanitize_flags(sycl_dlink_post_cflags)
|
||||
ldflags = sanitize_flags(ldflags)
|
||||
|
||||
# Sanity checks...
|
||||
@ -2453,6 +2722,9 @@ def _write_ninja_file(path,
|
||||
else:
|
||||
nvcc = _join_cuda_home('bin', 'nvcc')
|
||||
config.append(f'nvcc = {nvcc}')
|
||||
if with_sycl or sycl_dlink_post_cflags:
|
||||
sycl = 'icx' if IS_WINDOWS else 'icpx'
|
||||
config.append(f'sycl = {sycl}')
|
||||
|
||||
if IS_HIP_EXTENSION:
|
||||
post_cflags = COMMON_HIP_FLAGS + post_cflags
|
||||
@ -2462,6 +2734,10 @@ def _write_ninja_file(path,
|
||||
flags.append(f'cuda_cflags = {" ".join(cuda_cflags)}')
|
||||
flags.append(f'cuda_post_cflags = {" ".join(cuda_post_cflags)}')
|
||||
flags.append(f'cuda_dlink_post_cflags = {" ".join(cuda_dlink_post_cflags)}')
|
||||
if with_sycl:
|
||||
flags.append(f'sycl_cflags = {" ".join(sycl_cflags)}')
|
||||
flags.append(f'sycl_post_cflags = {" ".join(sycl_post_cflags)}')
|
||||
flags.append(f'sycl_dlink_post_cflags = {" ".join(sycl_dlink_post_cflags)}')
|
||||
flags.append(f'ldflags = {" ".join(ldflags)}')
|
||||
|
||||
# Turn into absolute paths so we can emit them into the ninja build
|
||||
@ -2495,11 +2771,25 @@ def _write_ninja_file(path,
|
||||
cuda_compile_rule.append(
|
||||
f' command = $nvcc {nvcc_gendeps} $cuda_cflags -c $in -o $out $cuda_post_cflags')
|
||||
|
||||
if with_sycl:
|
||||
sycl_compile_rule = ['rule sycl_compile']
|
||||
# SYCL compiler does not recognize .sycl extension automatically,
|
||||
# so we pass '-x c++' explicitly notifying compiler of file format
|
||||
sycl_compile_rule.append(
|
||||
' command = $sycl $sycl_cflags -c -x c++ $in -o $out $sycl_post_cflags')
|
||||
|
||||
|
||||
# Emit one build rule per source to enable incremental build.
|
||||
build = []
|
||||
for source_file, object_file in zip(sources, objects):
|
||||
is_cuda_source = _is_cuda_file(source_file) and with_cuda
|
||||
rule = 'cuda_compile' if is_cuda_source else 'compile'
|
||||
is_sycl_source = _is_sycl_file(source_file) and with_sycl
|
||||
if is_cuda_source:
|
||||
rule = 'cuda_compile'
|
||||
elif is_sycl_source:
|
||||
rule = 'sycl_compile'
|
||||
else:
|
||||
rule = 'compile'
|
||||
if IS_WINDOWS:
|
||||
source_file = source_file.replace(':', '$:')
|
||||
object_file = object_file.replace(':', '$:')
|
||||
@ -2508,13 +2798,22 @@ def _write_ninja_file(path,
|
||||
build.append(f'build {object_file}: {rule} {source_file}')
|
||||
|
||||
if cuda_dlink_post_cflags:
|
||||
devlink_out = os.path.join(os.path.dirname(objects[0]), 'dlink.o')
|
||||
devlink_rule = ['rule cuda_devlink']
|
||||
devlink_rule.append(' command = $nvcc $in -o $out $cuda_dlink_post_cflags')
|
||||
devlink = [f'build {devlink_out}: cuda_devlink {" ".join(objects)}']
|
||||
objects += [devlink_out]
|
||||
cuda_devlink_out = os.path.join(os.path.dirname(objects[0]), 'dlink.o')
|
||||
cuda_devlink_rule = ['rule cuda_devlink']
|
||||
cuda_devlink_rule.append(' command = $nvcc $in -o $out $cuda_dlink_post_cflags')
|
||||
cuda_devlink = [f'build {cuda_devlink_out}: cuda_devlink {" ".join(objects)}']
|
||||
objects += [cuda_devlink_out]
|
||||
else:
|
||||
devlink_rule, devlink = [], []
|
||||
cuda_devlink_rule, cuda_devlink = [], []
|
||||
|
||||
if sycl_dlink_post_cflags:
|
||||
sycl_devlink_out = os.path.join(os.path.dirname(objects[0]), 'sycl_dlink.o')
|
||||
sycl_devlink_rule = ['rule sycl_devlink']
|
||||
sycl_devlink_rule.append(' command = $sycl $in -o $out $sycl_dlink_post_cflags')
|
||||
sycl_devlink = [f'build {sycl_devlink_out}: sycl_devlink {" ".join(objects)}']
|
||||
objects += [sycl_devlink_out]
|
||||
else:
|
||||
sycl_devlink_rule, sycl_devlink = [], []
|
||||
|
||||
if library_target is not None:
|
||||
link_rule = ['rule link']
|
||||
@ -2539,7 +2838,9 @@ def _write_ninja_file(path,
|
||||
blocks = [config, flags, compile_rule]
|
||||
if with_cuda:
|
||||
blocks.append(cuda_compile_rule) # type: ignore[possibly-undefined]
|
||||
blocks += [devlink_rule, link_rule, build, devlink, link, default]
|
||||
if with_sycl:
|
||||
blocks.append(sycl_compile_rule) # type: ignore[possibly-undefined]
|
||||
blocks += [cuda_devlink_rule, sycl_devlink_rule, link_rule, build, cuda_devlink, sycl_devlink, link, default]
|
||||
content = "\n\n".join("\n".join(b) for b in blocks)
|
||||
# Ninja requires a new lines at the end of the .ninja file
|
||||
content += "\n"
|
||||
@ -2563,3 +2864,7 @@ def _is_cuda_file(path: str) -> bool:
|
||||
if IS_HIP_EXTENSION:
|
||||
valid_ext.append('.hip')
|
||||
return os.path.splitext(path)[1] in valid_ext
|
||||
|
||||
def _is_sycl_file(path: str) -> bool:
|
||||
valid_ext = ['.sycl']
|
||||
return os.path.splitext(path)[1] in valid_ext
|
||||
|
Reference in New Issue
Block a user