Revert "xpu: support sycl with torch.utils.cpp_extension APIs (#132945)"
This reverts commit 607379960bc5093a1fe51ff72c3e0fd39ac126ab.
Reverted https://github.com/pytorch/pytorch/pull/132945 on behalf of https://github.com/malfet due to It just broke all the tests, see b16ae97ad0/1 ([comment](https://github.com/pytorch/pytorch/pull/132945#issuecomment-2661498747))
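For context, the reverted commit let the torch.utils.cpp_extension JIT loader build SYCL sources for XPU. Below is a minimal sketch of that usage, mirroring the removed test_jit_xpu_extension further down; it assumes a build that still includes #132945, an available XPU device, and a SYCL compiler, and sigmoid_add is defined in the test's cpp_extensions/xpu_extension.sycl file.

import torch
import torch.utils.cpp_extension

# JIT-compile a standalone .sycl source; the extension name must equal the module name.
module = torch.utils.cpp_extension.load(
    name="torch_test_xpu_extension",
    sources=["cpp_extensions/xpu_extension.sycl"],
    verbose=True,
    keep_intermediates=False,
)

x = torch.zeros(100, device="xpu", dtype=torch.float32)
y = torch.zeros(100, device="xpu", dtype=torch.float32)
z = module.sigmoid_add(x, y).cpu()  # 2 * sigmoid(0) = 1 for every element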
@@ -17,7 +17,7 @@ import torch.multiprocessing as mp
 import torch.testing._internal.common_utils as common
 import torch.utils.cpp_extension
 from torch.testing._internal.common_cuda import TEST_CUDA, TEST_CUDNN
-from torch.testing._internal.common_utils import gradcheck, TEST_XPU
+from torch.testing._internal.common_utils import gradcheck
 from torch.utils.cpp_extension import (
     _TORCH_PATH,
     check_compiler_is_gcc,
@@ -116,26 +116,6 @@ class TestCppExtensionJIT(common.TestCase):
         # 2 * sigmoid(0) = 2 * 0.5 = 1
         self.assertEqual(z, torch.ones_like(z))
 
-    @unittest.skipIf(not (TEST_XPU), "XPU not found")
-    def test_jit_xpu_extension(self):
-        # NOTE: The name of the extension must equal the name of the module.
-        module = torch.utils.cpp_extension.load(
-            name="torch_test_xpu_extension",
-            sources=[
-                "cpp_extensions/xpu_extension.sycl",
-            ],
-            verbose=True,
-            keep_intermediates=False,
-        )
-
-        x = torch.zeros(100, device="xpu", dtype=torch.float32)
-        y = torch.zeros(100, device="xpu", dtype=torch.float32)
-
-        z = module.sigmoid_add(x, y).cpu()
-
-        # 2 * sigmoid(0) = 2 * 0.5 = 1
-        self.assertEqual(z, torch.ones_like(z))
-
     @unittest.skipIf(not TEST_MPS, "MPS not found")
     def test_mps_extension(self):
         module = torch.utils.cpp_extension.load(
@@ -462,80 +442,6 @@ class TestCppExtensionJIT(common.TestCase):
         z = torch.ops.inline_jit_extension_custom_op_cuda.cos_add(x, y)
         self.assertEqual(z, x.cos() + y.cos())
 
-    @unittest.skipIf(not TEST_XPU, "XPU not found")
-    def test_inline_jit_compile_extension_xpu(self):
-        sycl_source = """
-        #include <c10/xpu/XPUStream.h>
-
-        class CosAddKernel {
-        public:
-          void operator()(const sycl::nd_item<3> &item_ct1) const {
-            const int index = item_ct1.get_group(2) * item_ct1.get_local_range(2) +
-                item_ct1.get_local_id(2);
-            if (index < size) {
-              output[index] = cosf(x[index]) + cosf(y[index]);
-            }
-          }
-          CosAddKernel(const float* _x, const float* _y, float* _output, int _size):
-            x(_x),
-            y(_y),
-            output(_output),
-            size(_size)
-          {}
-        private:
-          const float* x;
-          const float* y;
-          float* output;
-          int size;
-        };
-
-        void cos_add_kernel(
-            const float* x,
-            const float* y,
-            float* output,
-            int size) {
-          CosAddKernel krn(x, y, output, size);
-          const int threads = 1024;
-          const int blocks = (size + threads - 1) / threads;
-
-          sycl::queue& queue = c10::xpu::getCurrentXPUStream().queue();
-          queue.submit([&](sycl::handler &cgh) {
-            cgh.parallel_for<CosAddKernel>(
-                sycl::nd_range<3>(
-                    sycl::range<3>(1, 1, blocks) * sycl::range<3>(1, 1, threads),
-                    sycl::range<3>(1, 1, threads)),
-                krn);
-          });
-        }
-
-        torch::Tensor cos_add(torch::Tensor x, torch::Tensor y) {
-          auto output = torch::zeros_like(x);
-          const int threads = 1024;
-          const int blocks = (output.numel() + threads - 1) / threads;
-          cos_add_kernel(x.data_ptr<float>(), y.data_ptr<float>(), output.data_ptr<float>(), output.numel());
-          return output;
-        }
-        """
-
-        # Here, the C++ source need only declare the function signature.
-        cpp_source = "torch::Tensor cos_add(torch::Tensor x, torch::Tensor y);"
-
-        module = torch.utils.cpp_extension.load_inline(
-            name="inline_jit_extension_xpu",
-            cpp_sources=cpp_source,
-            sycl_sources=sycl_source,
-            functions=["cos_add"],
-            verbose=True,
-        )
-
-        self.assertEqual(module.cos_add.__doc__.split("\n")[2], "cos_add")
-
-        x = torch.randn(4, 4, device="xpu", dtype=torch.float32)
-        y = torch.randn(4, 4, device="xpu", dtype=torch.float32)
-
-        z = module.cos_add(x, y)
-        self.assertEqual(z, x.cos() + y.cos())
-
     def test_inline_jit_compile_extension_throws_when_functions_is_bad(self):
         with self.assertRaises(ValueError):
             torch.utils.cpp_extension.load_inline(