This patch adds support for building SYCL kernels via `torch.utils.cpp_extension.load`, `torch.utils.cpp_extension.load_inline`, and the (new) `SyclExtension` class API. Files with the `.sycl` extension are treated as containing SYCL kernels and are compiled with `icpx` (Intel's DPC++ SYCL compiler). Files with other extensions (`.cpp`, `.cu`) are handled as before. The API supports building SYCL sources together with other file types into a single extension.

Note that the `.sycl` file extension is a PyTorch convention for files containing SYCL code, which I propose to adopt. We did follow up with the compiler team about introducing such a file extension in the compiler, but they are opposed to it. At the same time, discussion around a SYCL file extension and adding SYCL language support to tools such as CMake is ongoing, and CMake may eventually introduce its own file extension convention for SYCL. I hope we can further influence the CMake and compiler communities to adopt the `.sycl` file extension more broadly.

By default, SYCL kernels are compiled for all Intel GPU devices for which PyTorch's native ATen SYCL kernels are compiled (currently `pvc,xe-lpg`). This behavior can be overridden by setting the `TORCH_XPU_ARCH_LIST` environment variable to a comma-separated list of target devices.

Fixes: #132944
CC: @gujinghui @EikanWang @fengyuan14 @guangyey @jgong5

Pull Request resolved: https://github.com/pytorch/pytorch/pull/132945
Approved by: https://github.com/albanD, https://github.com/guangyey, https://github.com/malfet
Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com>
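As a rough illustration of the JIT path described above (not part of the patch itself; the file name sigmoid_add.sycl and the module name are hypothetical, and an XPU device must be available), a .sycl source is passed to torch.utils.cpp_extension.load like any other source file:

    import torch
    from torch.utils.cpp_extension import load

    # .sycl sources are compiled with icpx; any .cpp/.cu sources in the same
    # list are handled by the usual host/CUDA toolchain, as before.
    module = load(
        name="sycl_sigmoid_add",
        sources=["sigmoid_add.sycl"],
        verbose=True,
    )

    x = torch.randn(1024, device="xpu")
    y = torch.randn(1024, device="xpu")
    out = module.sigmoid_add(x, y)

Ahead-of-time builds go through the new SyclExtension class together with BuildExtension in setup.py, analogous to CppExtension/CUDAExtension; exporting, for example, TORCH_XPU_ARCH_LIST=pvc before the build restricts the set of target devices (see the sketch after the source file below).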
#include <c10/xpu/XPUStream.h>
#include <torch/extension.h>
#include <sycl/sycl.hpp>

// Element-wise kernel body: computes sigmoid(x[i]) + sigmoid(y[i]).
void sigmoid_add_kernel(const float* x,
                        const float* y,
                        float* output,
                        const int size,
                        const sycl::nd_item<3> &item_ct1) {
  const int index = item_ct1.get_group(2) * item_ct1.get_local_range(2) +
                    item_ct1.get_local_id(2);
  if (index < size) {
    const float sigmoid_x = 1.0f / (1.0f + sycl::native::exp(-x[index]));
    const float sigmoid_y = 1.0f / (1.0f + sycl::native::exp(-y[index]));
    output[index] = sigmoid_x + sigmoid_y;
  }
}

// Function object wrapping the kernel so it can be passed to parallel_for.
class SigmoidAddKernel {
 public:
  void operator()(const sycl::nd_item<3> &item_ct1) const {
    sigmoid_add_kernel(x, y, output, size, item_ct1);
  }
  SigmoidAddKernel(const float* _x, const float* _y, float* _output, int _size):
    x(_x),
    y(_y),
    output(_output),
    size(_size)
  {}
 private:
  const float* x;
  const float* y;
  float* output;
  int size;
};

// Launches the kernel on the current XPU stream, expressing a 1D
// decomposition in the last dimension of a 3D nd_range.
void sigmoid_add_xpu(const float* x, const float* y, float* output, int size) {
  SigmoidAddKernel krn(x, y, output, size);
  const int threads = 1024;
  const int blocks = (size + threads - 1) / threads;

  sycl::queue& queue = c10::xpu::getCurrentXPUStream().queue();
  queue.submit([&](sycl::handler &cgh) {
    cgh.parallel_for<SigmoidAddKernel>(
        sycl::nd_range<3>(
            sycl::range<3>(1, 1, blocks) * sycl::range<3>(1, 1, threads),
            sycl::range<3>(1, 1, threads)),
        krn);
  });
}

// Entry point exposed to Python: validates that inputs are XPU tensors,
// allocates the output, and dispatches to the SYCL launcher.
torch::Tensor sigmoid_add(torch::Tensor x, torch::Tensor y) {
  TORCH_CHECK(x.device().is_xpu(), "x must be a XPU tensor");
  TORCH_CHECK(y.device().is_xpu(), "y must be a XPU tensor");
  auto output = torch::zeros_like(x);
  sigmoid_add_xpu(
      x.data_ptr<float>(), y.data_ptr<float>(), output.data_ptr<float>(), output.numel());
  return output;
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("sigmoid_add", &sigmoid_add, "sigmoid(x) + sigmoid(y)");
}
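For an ahead-of-time build of the file above, a minimal setup.py sketch using the new SyclExtension class might look as follows. The module and file names are hypothetical, and the positional (name, sources) form is an assumption based on the CppExtension/CUDAExtension convention; consult the torch.utils.cpp_extension documentation for the exact signature.

    from setuptools import setup
    from torch.utils.cpp_extension import BuildExtension, SyclExtension

    # Optionally export TORCH_XPU_ARCH_LIST=pvc,xe-lpg (for example) before
    # running the build to override the default set of target Intel GPUs.
    setup(
        name="sycl_sigmoid_add",
        ext_modules=[
            SyclExtension("sycl_sigmoid_add", ["sigmoid_add.sycl"]),
        ],
        cmdclass={"build_ext": BuildExtension},
    )

Running `python setup.py build_ext` (or installing the package) would then compile the .sycl source with icpx and link it into a regular Python extension module exposing sigmoid_add.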