Revert "xpu: support sycl with torch.utils.cpp_extension APIs (#132945)"

This reverts commit 607379960bc5093a1fe51ff72c3e0fd39ac126ab.

Reverted https://github.com/pytorch/pytorch/pull/132945 on behalf of https://github.com/malfet due to It just broke all the tests, see b16ae97ad0/1 ([comment](https://github.com/pytorch/pytorch/pull/132945#issuecomment-2661498747))
PyTorch MergeBot
2025-02-16 16:03:41 +00:00
parent b16ae97ad0
commit dd5d0ea6bb
7 changed files with 29 additions and 519 deletions
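
For context, the reverted change (#132945) let torch.utils.cpp_extension JIT-compile SYCL sources for XPU, both via load() on .sycl files and via load_inline() with a sycl_sources argument. Below is a minimal sketch of that usage, reconstructed from the tests deleted in the diff that follows (the cos_add kernel, module names, and the .sycl path come from those tests; an XPU device and a SYCL toolchain are assumed):

import torch
import torch.utils.cpp_extension

# The C++ source only declares the signature; sycl_source (elided here) would hold
# the SYCL CosAddKernel and cos_add definition from the removed
# test_inline_jit_compile_extension_xpu test.
cpp_source = "torch::Tensor cos_add(torch::Tensor x, torch::Tensor y);"
sycl_source = "..."  # see the SYCL source in the removed test below

# load_inline() accepting sycl_sources is the API surface this revert removes;
# torch.utils.cpp_extension.load(sources=["cpp_extensions/xpu_extension.sycl"])
# from the removed test_jit_xpu_extension test is likewise reverted.
module = torch.utils.cpp_extension.load_inline(
    name="inline_jit_extension_xpu",
    cpp_sources=cpp_source,
    sycl_sources=sycl_source,
    functions=["cos_add"],
    verbose=True,
)

x = torch.randn(4, 4, device="xpu", dtype=torch.float32)
y = torch.randn(4, 4, device="xpu", dtype=torch.float32)
z = module.cos_add(x, y)  # expected to equal x.cos() + y.cos()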

@@ -17,7 +17,7 @@ import torch.multiprocessing as mp
 import torch.testing._internal.common_utils as common
 import torch.utils.cpp_extension
 from torch.testing._internal.common_cuda import TEST_CUDA, TEST_CUDNN
-from torch.testing._internal.common_utils import gradcheck, TEST_XPU
+from torch.testing._internal.common_utils import gradcheck
 from torch.utils.cpp_extension import (
     _TORCH_PATH,
     check_compiler_is_gcc,
@@ -116,26 +116,6 @@ class TestCppExtensionJIT(common.TestCase):
         # 2 * sigmoid(0) = 2 * 0.5 = 1
         self.assertEqual(z, torch.ones_like(z))
 
-    @unittest.skipIf(not (TEST_XPU), "XPU not found")
-    def test_jit_xpu_extension(self):
-        # NOTE: The name of the extension must equal the name of the module.
-        module = torch.utils.cpp_extension.load(
-            name="torch_test_xpu_extension",
-            sources=[
-                "cpp_extensions/xpu_extension.sycl",
-            ],
-            verbose=True,
-            keep_intermediates=False,
-        )
-
-        x = torch.zeros(100, device="xpu", dtype=torch.float32)
-        y = torch.zeros(100, device="xpu", dtype=torch.float32)
-
-        z = module.sigmoid_add(x, y).cpu()
-
-        # 2 * sigmoid(0) = 2 * 0.5 = 1
-        self.assertEqual(z, torch.ones_like(z))
-
     @unittest.skipIf(not TEST_MPS, "MPS not found")
     def test_mps_extension(self):
         module = torch.utils.cpp_extension.load(
@@ -462,80 +442,6 @@ class TestCppExtensionJIT(common.TestCase):
         z = torch.ops.inline_jit_extension_custom_op_cuda.cos_add(x, y)
         self.assertEqual(z, x.cos() + y.cos())
 
-    @unittest.skipIf(not TEST_XPU, "XPU not found")
-    def test_inline_jit_compile_extension_xpu(self):
-        sycl_source = """
-        #include <c10/xpu/XPUStream.h>
-
-        class CosAddKernel {
-         public:
-          void operator()(const sycl::nd_item<3> &item_ct1) const {
-            const int index = item_ct1.get_group(2) * item_ct1.get_local_range(2) +
-                              item_ct1.get_local_id(2);
-            if (index < size) {
-              output[index] = cosf(x[index]) + cosf(y[index]);
-            }
-          }
-          CosAddKernel(const float* _x, const float* _y, float* _output, int _size):
-            x(_x),
-            y(_y),
-            output(_output),
-            size(_size)
-          {}
-         private:
-          const float* x;
-          const float* y;
-          float* output;
-          int size;
-        };
-
-        void cos_add_kernel(
-            const float* x,
-            const float* y,
-            float* output,
-            int size) {
-          CosAddKernel krn(x, y, output, size);
-          const int threads = 1024;
-          const int blocks = (size + threads - 1) / threads;
-
-          sycl::queue& queue = c10::xpu::getCurrentXPUStream().queue();
-          queue.submit([&](sycl::handler &cgh) {
-            cgh.parallel_for<CosAddKernel>(
-                sycl::nd_range<3>(
-                    sycl::range<3>(1, 1, blocks) * sycl::range<3>(1, 1, threads),
-                    sycl::range<3>(1, 1, threads)),
-                krn);
-          });
-        }
-
-        torch::Tensor cos_add(torch::Tensor x, torch::Tensor y) {
-          auto output = torch::zeros_like(x);
-          const int threads = 1024;
-          const int blocks = (output.numel() + threads - 1) / threads;
-          cos_add_kernel(x.data_ptr<float>(), y.data_ptr<float>(), output.data_ptr<float>(), output.numel());
-          return output;
-        }
-        """
-
-        # Here, the C++ source need only declare the function signature.
-        cpp_source = "torch::Tensor cos_add(torch::Tensor x, torch::Tensor y);"
-
-        module = torch.utils.cpp_extension.load_inline(
-            name="inline_jit_extension_xpu",
-            cpp_sources=cpp_source,
-            sycl_sources=sycl_source,
-            functions=["cos_add"],
-            verbose=True,
-        )
-
-        self.assertEqual(module.cos_add.__doc__.split("\n")[2], "cos_add")
-
-        x = torch.randn(4, 4, device="xpu", dtype=torch.float32)
-        y = torch.randn(4, 4, device="xpu", dtype=torch.float32)
-
-        z = module.cos_add(x, y)
-        self.assertEqual(z, x.cos() + y.cos())
-
     def test_inline_jit_compile_extension_throws_when_functions_is_bad(self):
         with self.assertRaises(ValueError):
             torch.utils.cpp_extension.load_inline(