[Intel GPU] oneDNN GPU GEMM support (#117202)

# Motivation

This PR is part of RFC #114848 and is a successor to #116249 and #116019. It depends on the oneDNN compilation work in #116249 and on the runtime support added in #116019.

ATen operators such as `addmm` and `baddbmm` are defined in `Blas.cpp` under `aten/src/ATen/native/mkldnn/xpu/`.
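
Once these kernels are registered, the usual PyTorch operators dispatch to the oneDNN GEMM whenever their inputs live on the XPU device. A minimal usage sketch, assuming a build with `USE_XPU` enabled and an Intel GPU exposed as the `"xpu"` device:

```python
import torch

# Sketch: assumes PyTorch was built with USE_XPU and an Intel GPU is available as "xpu".
device = "xpu"

mat1 = torch.randn(4, 3, device=device)
mat2 = torch.randn(3, 5, device=device)
bias = torch.randn(4, 5, device=device)

# addmm: out = beta * bias + alpha * (mat1 @ mat2), dispatched to the oneDNN GEMM on XPU.
out = torch.addmm(bias, mat1, mat2, beta=0.5, alpha=2.0)

# baddbmm: batched variant, out = beta * input + alpha * (batch1 @ batch2).
batch1 = torch.randn(8, 4, 3, device=device)
batch2 = torch.randn(8, 3, 5, device=device)
inp = torch.randn(8, 4, 5, device=device)
out_batched = torch.baddbmm(inp, batch1, batch2)
```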

Alongside the files that provide the core functionality, `BlasImpl.h`, `Utils.h`, and other headers provide basic utilities for them. For instance, `Utils.h` provides common memory-descriptor query utilities for `Matmul.h`, and these utilities will also be used by other primitives, such as `convolution`. `BlasImpl.h` is a header that provides helpers for shape-info processing in matmul-related operators. It serves not only basic GEMM operators like `addmm` and `baddbmm` but also fusion operators used in `torch.compile`, such as `linear_pointwise` in #117824.
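
One detail worth noting: oneDNN's binary-add post-op is applied with a scale of 1.0, so `Blas.cpp` rewrites `alpha * (mat1 @ mat2) + beta * self` as `beta * ((alpha / beta) * (mat1 @ mat2) + self)` when `beta != 0` (see the comments in `addmm_out` in the diff below). A small sketch verifying that algebraic rewrite in plain PyTorch, not the oneDNN post-op API itself:

```python
import torch

alpha, beta = 2.0, 0.5
mat1 = torch.randn(4, 3)
mat2 = torch.randn(3, 5)
self_t = torch.randn(4, 5)

# Reference definition of addmm.
reference = beta * self_t + alpha * (mat1 @ mat2)

# Rewrite used so the binary-add post-op can run with scale 1.0:
#   linear(alpha/beta) -> binary_add(self) -> linear(beta)
rewritten = beta * ((alpha / beta) * (mat1 @ mat2) + self_t)

torch.testing.assert_close(reference, rewritten)
```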

In the next stage, we will continue to extend the oneDNN support by enabling matmul fusion and the `convolution`-related code.

Co-authored-by: xiaolil1 <xiaoli.liu@intel.com>
Co-authored-by: lei,zhenyuan <zhenyuan.lei@intel.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/117202
Approved by: https://github.com/EikanWang, https://github.com/jgong5, https://github.com/malfet
ghstack dependencies: #117098, #117112
Author: ZhiweiYan-96
Date: 2024-04-01 12:21:05 +00:00
Committed by: PyTorch MergeBot
Commit: 9875a834e4 (parent 6330acae76)
6 changed files with 1623 additions and 6 deletions

View File

@ -18,6 +18,8 @@ cmake_policy(SET CMP0012 NEW)
#############################################
set(ATen_CPU_SRCS)
set(ATen_XPU_SRCS)
set(ATen_XPU_INCLUDE)
set(ATen_CPU_TEST_SRCS)
set(ATen_CPU_INCLUDE)
set(ATen_THIRD_PARTY_INCLUDE)
@ -39,6 +41,7 @@ set(ATen_XPU_INCLUDE)
set(ATen_XPU_TEST_SRCS)
set(ATen_VULKAN_TEST_SRCS)
set(ATen_CPU_DEPENDENCY_LIBS)
set(ATen_XPU_DEPENDENCY_LIBS)
set(ATen_CUDA_DEPENDENCY_LIBS)
set(ATen_HIP_DEPENDENCY_LIBS)
set(ATen_PUBLIC_CUDA_DEPENDENCY_LIBS)
@ -105,6 +108,8 @@ add_subdirectory(src/ATen)
# Pass source, includes, and libs to parent
set(ATen_CPU_SRCS ${ATen_CPU_SRCS} PARENT_SCOPE)
set(ATen_CORE_SRCS ${ATen_CORE_SRCS} PARENT_SCOPE)
set(ATen_XPU_SRCS ${ATen_XPU_SRCS} PARENT_SCOPE)
set(ATen_XPU_INCLUDE ${ATen_XPU_INCLUDE} PARENT_SCOPE)
set(ATen_CUDA_CU_SRCS ${ATen_CUDA_CU_SRCS} PARENT_SCOPE)
set(ATen_CUDA_CPP_SRCS ${ATen_CUDA_CPP_SRCS} PARENT_SCOPE)
set(ATen_CUDA_LINALG_SRCS ${ATen_CUDA_LINALG_SRCS} PARENT_SCOPE)
@ -130,6 +135,7 @@ set(ATen_HIP_INCLUDE ${ATen_HIP_INCLUDE} PARENT_SCOPE)
set(ATen_XPU_INCLUDE ${ATen_XPU_INCLUDE} PARENT_SCOPE)
set(ATen_THIRD_PARTY_INCLUDE ${ATen_THIRD_PARTY_INCLUDE} PARENT_SCOPE)
set(ATen_CPU_DEPENDENCY_LIBS ${ATen_CPU_DEPENDENCY_LIBS} PARENT_SCOPE)
set(ATen_XPU_DEPENDENCY_LIBS ${ATen_XPU_DEPENDENCY_LIBS} PARENT_SCOPE)
set(ATen_CUDA_DEPENDENCY_LIBS ${ATen_CUDA_DEPENDENCY_LIBS} PARENT_SCOPE)
set(ATen_HIP_DEPENDENCY_LIBS ${ATen_HIP_DEPENDENCY_LIBS} PARENT_SCOPE)
set(ATen_CORE_TEST_SRCS ${ATen_CORE_TEST_SRCS} PARENT_SCOPE)

View File

@ -85,6 +85,8 @@ file(GLOB miopen_cpp "miopen/*.cpp")
file(GLOB mkl_cpp "mkl/*.cpp")
file(GLOB mkldnn_cpp "mkldnn/*.cpp")
file(GLOB mkldnn_xpu_cpp "native/mkldnn/xpu/*.cpp" "native/mkldnn/xpu/detail/*.cpp")
file(GLOB native_cpp "native/*.cpp")
file(GLOB native_mkl_cpp "native/mkl/*.cpp")
file(GLOB native_mkldnn_cpp "native/mkldnn/*.cpp")
@ -238,6 +240,20 @@ else()
set(all_cpu_cpp ${all_cpu_cpp} ${vulkan_cpp})
endif()
if(USE_XPU)
list(APPEND ATen_XPU_SRCS ${mkldnn_xpu_cpp})
list(APPEND ATen_XPU_DEPENDENCY_LIBS xpu_mkldnn)
list(APPEND ATen_XPU_DEPENDENCY_LIBS ${OCL_LIBRARY})
list(APPEND ATen_XPU_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/native/mkldnn/xpu)
list(APPEND ATen_XPU_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/native/mkldnn/xpu/detail)
list(APPEND ATen_XPU_INCLUDE ${PROJECT_SOURCE_DIR}/third_party/ideep/mkl-dnn/include)
list(APPEND ATen_XPU_INCLUDE ${XPU_MKLDNN_INCLUDE})
list(APPEND ATen_XPU_INCLUDE ${SYCL_INCLUDE_DIR})
list(APPEND ATen_XPU_DEPENDENCY_LIBS ${SYCL_LIBRARY})
endif()
# Metal
if(USE_PYTORCH_METAL_EXPORT)
# Add files needed from exporting metal models(optimized_for_mobile)
@ -629,6 +645,7 @@ list(APPEND ATen_MOBILE_BENCHMARK_SRCS
# Pass source, includes, and libs to parent
set(ATen_CORE_SRCS ${ATen_CORE_SRCS} PARENT_SCOPE)
set(ATen_CPU_SRCS ${ATen_CPU_SRCS} PARENT_SCOPE)
set(ATen_XPU_SRCS ${ATen_XPU_SRCS} PARENT_SCOPE)
set(ATen_CUDA_CU_SRCS ${ATen_CUDA_CU_SRCS} PARENT_SCOPE)
set(ATen_CUDA_CPP_SRCS ${ATen_CUDA_CPP_SRCS} PARENT_SCOPE)
set(ATen_CUDA_LINALG_SRCS ${ATen_CUDA_LINALG_SRCS} PARENT_SCOPE)
@ -658,6 +675,7 @@ set(ATen_XPU_INCLUDE ${ATen_XPU_INCLUDE} PARENT_SCOPE)
set(ATen_VULKAN_INCLUDE ${ATen_VULKAN_INCLUDE} PARENT_SCOPE)
set(ATen_CPU_DEPENDENCY_LIBS ${ATen_CPU_DEPENDENCY_LIBS} PARENT_SCOPE)
set(ATen_CUDA_DEPENDENCY_LIBS ${ATen_CUDA_DEPENDENCY_LIBS} PARENT_SCOPE)
set(ATen_XPU_DEPENDENCY_LIBS ${ATen_XPU_DEPENDENCY_LIBS} PARENT_SCOPE)
set(ATen_HIP_DEPENDENCY_LIBS ${ATen_HIP_DEPENDENCY_LIBS} PARENT_SCOPE)
set(FLASH_ATTENTION_CUDA_SOURCES ${FLASH_ATTENTION_CUDA_SOURCES} PARENT_SCOPE)
set(MEM_EFF_ATTENTION_CUDA_SOURCES ${MEM_EFF_ATTENTION_CUDA_SOURCES} PARENT_SCOPE)

View File

@ -0,0 +1,436 @@
#include <ATen/WrapDimUtilsMulti.h>
#include <ATen/native/Resize.h>
#include <torch/library.h>
#include <ATen/native/mkldnn/xpu/detail/oneDNN.h>
namespace at::native::xpu {
// result = beta * self + alpha * (mat1 * mat2)
Tensor& addmm_out(
const Tensor& self,
const Tensor& mat1,
const Tensor& mat2,
const Scalar& beta,
const Scalar& alpha,
at::Tensor& result) {
checkBackend("addmm_out", {result, self, mat1, mat2}, Backend::XPU);
TORCH_CHECK(
mat1.dim() == 2, "mat1 must be a matrix, got ", mat1.dim(), "-D tensor");
TORCH_CHECK(
mat2.dim() == 2, "mat2 must be a matrix, got ", mat2.dim(), "-D tensor");
TORCH_CHECK(
mat1.sizes()[1] == mat2.sizes()[0],
"mat1 and mat2 shapes cannot be multiplied (",
mat1.sizes()[0],
"x",
mat1.sizes()[1],
" and ",
mat2.sizes()[0],
"x",
mat2.sizes()[1],
")");
std::vector<int64_t> result_shape = {mat1.size(0), mat2.size(1)};
result.resize_(result_shape);
IntArrayRef result_sizes = result.sizes();
if ((result_sizes[0] == 0) || (result_sizes[1] == 0)) {
return result;
}
if (mat1.numel() == 0){
if(beta.to<float>() == 0.f){
return result.zero_();
}
return at::mul_out(
result,
self.expand(result.sizes()),
at::native::scalar_tensor(
beta,
self.scalar_type(),
c10::nullopt,
at::kCPU,
c10::nullopt
)
);
}
TORCH_CHECK(
are_expandable(self.sizes(), result_shape),
"addmm_out input must be expanable to:",
result_shape,
" but got:",
self.sizes());
// complex/double case
if (mat1.is_complex() || mat1.scalar_type() == ScalarType::Double) {
AT_ERROR(
"Double and complex datatype matmul is not supported in oneDNN");
}
// general case
Tensor bias = Tensor();
onednn::Attr attr;
float beta_ = beta.to<float>();
if (beta_ == 0.f) {
if (alpha.to<float>() != 1.f) {
attr.append_post_eltwise(
1.f, alpha.to<float>(), 0.f, attr.kind_with_linear);
}
} else {
if (alpha.to<float>() == 1.f && beta_ == 1.f) {
bias = self;
} else {
Tensor binary = self.dim() == 1 ? self.unsqueeze(0) : self;
// Tensor binary = self.expand_as(result);
// For post-binary-add, onednn needs binary scale=1.f
// Thus we need the following transformation
// alpha * matmul(mat1, mat2) + beta * binary
// beta * (alpha/beta * matmul(src, wei) + binary)
float alpha_ = alpha.to<float>() / beta_;
if (alpha_ != 1.f)
attr.append_post_eltwise(1.f, alpha_, 0.f, attr.kind_with_linear);
attr.append_post_binary(attr.kind_with_binary_add, binary);
if (beta_ != 1.f)
attr.append_post_eltwise(1.f, beta_, 0.f, attr.kind_with_linear);
}
}
onednn::matmul(result, mat1, mat2, bias, true, attr);
return result;
}
Tensor& _addmm_activation_out(
const Tensor& self,
const Tensor& mat1,
const Tensor& mat2,
const Scalar& beta,
const Scalar& alpha,
bool use_gelu,
at::Tensor& result) {
addmm_out(self, mat1, mat2, beta, alpha, result);
if (use_gelu) {
at::gelu_(result);
} else {
at::relu_(result);
}
return result;
}
Tensor& mm_out(const Tensor& self, const Tensor& mat2, Tensor& result) {
checkBackend("mm_out", {result, self, mat2}, Backend::XPU);
TORCH_CHECK(self.dim() == 2, "self must be a matrix");
TORCH_CHECK(mat2.dim() == 2, "mat2 must be a matrix");
TORCH_CHECK(
self.sizes()[1] == mat2.sizes()[0],
"mat1 and mat2 shapes cannot be multiplied (",
self.sizes()[0],
"x",
self.sizes()[1],
" and ",
mat2.sizes()[0],
"x",
mat2.sizes()[1],
")");
result.resize_({self.size(0), mat2.size(1)});
if (self.numel() == 0 || mat2.numel() == 0) {
if (result.numel() > 0)
result.zero_();
return result;
}
if (self.is_complex() || self.scalar_type() == ScalarType::Double) {
AT_ERROR(
"Double and complex datatype matmul is not supported in oneDNN");
}
onednn::matmul(result, self, mat2, Tensor(), true, onednn::Attr());
return result;
}
Tensor mm(const Tensor& self, const Tensor& mat2) {
auto result = at::empty({0}, self.options());
xpu::mm_out(self, mat2, result);
return result;
}
Tensor mv(const Tensor& self, const Tensor& vec) {
Tensor result = at::empty({self.size(0)}, self.options());
return at::addmv_(result, self, vec, 0, 1);
}
// result = beta * input + alpha * (batch1 @ batch2)
Tensor& baddbmm_out(
const Tensor& input,
const Tensor& batch1,
const Tensor& batch2,
const Scalar& beta,
const Scalar& alpha,
Tensor& result) {
checkBackend("baddbmm_out", {input, batch1, batch2}, Backend::XPU);
TORCH_CHECK(batch1.dim() == 3, "expected 3D tensor");
TORCH_CHECK(batch2.dim() == 3, "expected 3D tensor");
std::vector<int64_t> result_shape = {
batch1.size(0), batch1.size(1), batch2.size(2)};
result.resize_(result_shape);
if (result.numel() == 0){
return result;
} else if (batch1.size(2) == 0){
if (beta.to<c10::complex<double>>() == 0.0){
return result.zero_();
}else{
at::mul_out(result, input, beta);
return result;
}
}
TORCH_CHECK(
are_expandable(input.sizes(), result_shape),
"baddbmm_out input must be expanable to:",
result_shape,
" but got:",
input.sizes());
// complex and double case
if (batch1.is_complex() || batch2.scalar_type() == ScalarType::Double) {
AT_ERROR(
"Double and complex datatype matmul is not supported in oneDNN");
}
// general case
onednn::Attr attr;
float beta_ = beta.to<float>();
Tensor binary;
if (beta_ == 0.f) {
if (alpha.to<float>() != 1.f) {
attr.append_post_eltwise(
1.f, alpha.to<float>(), 0.f, attr.kind_with_linear);
}
} else {
binary = input.dim() < 3 ? input.unsqueeze(0) : input;
binary = binary.dim() < 3 ? binary.unsqueeze_(0) : binary;
float alpha_ = alpha.to<float>() / beta_;
if (alpha_ != 1.f)
attr.append_post_eltwise(1.f, alpha_, 0.f, attr.kind_with_linear);
attr.append_post_binary(attr.kind_with_binary_add, binary);
if (beta_ != 1.f)
attr.append_post_eltwise(1.f, beta_, 0.f, attr.kind_with_linear);
}
onednn::matmul(result, batch1, batch2, at::Tensor(), true, attr);
return result;
}
Tensor& baddbmm_(
Tensor& self,
const Tensor& batch1,
const Tensor& batch2,
const Scalar& beta,
const Scalar& alpha) {
TORCH_CHECK(self.dtype() == batch1.dtype(), "Input dtypes must be the same, got: input ", self.dtype(), ", batch1: ", batch1.dtype(), ", batch2: ", batch2.dtype());
return at::native::xpu::baddbmm_out(
self, batch1, batch2, beta, alpha, self);
}
Tensor baddbmm(
const Tensor& input,
const Tensor& batch1,
const Tensor& batch2,
const Scalar& beta,
const Scalar& alpha) {
Tensor r = at::empty({0}, input.options());
TORCH_CHECK(input.dtype() == batch1.dtype(), "Input dtypes must be the same, got: input ", input.dtype(), ", batch1: ", batch1.dtype(), ", batch2: ", batch2.dtype());
r = at::native::xpu::baddbmm_out(input, batch1, batch2, beta, alpha, r);
return r;
}
Tensor& addbmm_out(
const Tensor& self,
const Tensor& batch1,
const Tensor& batch2,
const Scalar& beta,
const Scalar& alpha,
Tensor& out) {
checkBackend("addbmm_out", {out, self, batch1, batch2}, Backend::XPU);
TORCH_CHECK(
batch1.dim() == 3 && batch2.dim() == 3,
"Batch tensors should be 3D, got dimensions ",
batch1.dim(),
" and ",
batch2.dim());
out.resize_({batch1.size(1), batch2.size(2)});
if (alpha.to<float>() == 0.f || batch1.numel() == 0 || batch2.numel() == 0) {
out.resize_({batch1.size(1), batch2.size(2)});
if (out.numel() == 0)
return out;
if (self.defined() && beta.to<float>() != 0.f) {
out = at::mul_out(
out, self, at::native::wrapped_scalar_tensor(at::Scalar(beta)));
} else {
out.zero_();
}
return out;
}
Tensor b1;
if (batch1.size(0) > 1) {
b1 = batch1.transpose(0, 1).contiguous().view({batch1.size(1), -1});
} else {
b1 = batch1.contiguous().view({batch1.size(1), -1});
}
auto b2 = batch2.contiguous().view({-1, batch2.size(2)});
at::native::xpu::addmm_out(self, b1, b2, beta, alpha, out);
return out;
}
Tensor& addbmm_(
Tensor& self,
const Tensor& batch1,
const Tensor& batch2,
const Scalar& beta,
const Scalar& alpha) {
at::native::xpu::addbmm_out(self, batch1, batch2, beta, alpha, self);
return self;
}
Tensor addbmm(
const Tensor& self,
const Tensor& batch1,
const Tensor& batch2,
const Scalar& beta,
const Scalar& alpha) {
Tensor out = at::empty({0}, self.options());
at::native::xpu::addbmm_out(self, batch1, batch2, beta, alpha, out);
return out;
}
Tensor& bmm_out(const Tensor& self, const Tensor& batch2, Tensor& result) {
checkBackend("bmm_out", {result, self, batch2}, Backend::XPU);
TORCH_CHECK(self.dim() == 3, "expected 3D tensor");
TORCH_CHECK(batch2.dim() == 3, "expected 3D tensor");
result.resize_({self.size(0), self.size(1), batch2.size(2)});
if (self.numel() == 0 || batch2.numel() == 0) {
if (result.numel() > 0)
result.zero_();
return result;
}
if (self.is_complex() || self.scalar_type() == ScalarType::Double) {
AT_ERROR(
"Double and complex datatype matmul is not supported in oneDNN");
}
onednn::matmul(result, self, batch2, at::Tensor(), true, onednn::Attr());
return result;
}
Tensor bmm(const Tensor& self, const Tensor& batch2) {
auto result = at::empty({0}, self.options());
at::native::xpu::bmm_out(self, batch2, result);
return result;
}
Tensor& addmv_out(
const Tensor& self,
const Tensor& mat,
const Tensor& vec,
const Scalar& beta,
const Scalar& alpha,
Tensor& out) {
Tensor self_v;
TORCH_CHECK(
(mat.dim() == 2 && vec.dim() == 1 && self.dim() <= 1),
"vector + matrix @ vector expected, got ",
self.dim(),
", ",
mat.dim(),
", ",
vec.dim());
if (self.dim() == 1 && self.size(0) != 1) {
TORCH_CHECK(
(mat.size(1) == vec.size(0) && mat.size(0) == self.size(0)),
"size mismatch, get ",
self.size(0),
", ",
mat.size(0),
"x",
mat.size(1),
",",
vec.size(0));
self_v = self.view({self.size(0), 1});
} else {
TORCH_CHECK(
(mat.size(1) == vec.size(0)),
"size mismatch, get ",
mat.size(0),
"x",
mat.size(1),
",",
vec.size(0));
self_v = self;
}
Tensor vec_v = vec.view({vec.size(0), 1});
at::native::xpu::addmm_out(self_v, mat, vec_v, beta, alpha, out);
out.resize_({mat.size(0)});
return out;
}
Tensor& tensordot_out(
const Tensor& input1,
const Tensor& input2,
IntArrayRef dims1,
IntArrayRef dims2,
Tensor& result) {
Tensor result_tmp = at::tensordot(input1, input2, dims1, dims2);
auto result_dtype = result_tmp.scalar_type();
auto output_tensor_dtype = result.scalar_type();
auto output_device = result.device();
auto input1_device = input1.device();
auto input2_device = input2.device();
// check if the input & output tensors are on the same device.
TORCH_CHECK(
(output_device == input1_device) && (input1_device == input2_device),
"tensordot: Expected the output and input tensors to be on the "
"same device, but got the output tensor on ",
output_device,
", input tensor a on ",
input1_device,
", and input tensor b on ",
input2_device);
// check if the computed result has the same dtype as the out tensor
// (because tensordot does not support type promotion)
TORCH_CHECK(
result_dtype == output_tensor_dtype,
"tensordot",
": Expected the output tensor to have dtype ",
result_dtype,
", but got an output tensor with dtype ",
output_tensor_dtype);
at::native::resize_output(result, result_tmp.sizes());
result.copy_(result_tmp);
return result;
}
TORCH_LIBRARY_IMPL(aten, XPU, m){
m.impl("addmm.out", TORCH_FN(addmm_out));
m.impl("_addmm_activation.out", TORCH_FN(_addmm_activation_out));
m.impl("mm.out", TORCH_FN(mm_out));
m.impl("mm", TORCH_FN(mm));
m.impl("baddbmm.out", TORCH_FN(baddbmm_out));
m.impl("baddbmm_", TORCH_FN(baddbmm_));
m.impl("baddbmm", TORCH_FN(baddbmm));
m.impl("addbmm.out", TORCH_FN(addbmm_out));
m.impl("addbmm_", TORCH_FN(addbmm_));
m.impl("addbmm", TORCH_FN(addbmm));
m.impl("bmm.out", TORCH_FN(bmm_out));
m.impl("bmm", TORCH_FN(bmm));
m.impl("addmv.out", TORCH_FN(addmv_out));
m.impl("tensordot.out", TORCH_FN(tensordot_out));
}
} // namespace at::native::xpu

View File

@ -80,6 +80,9 @@ if(INTERN_BUILD_ATEN_OPS)
# Add source, includes, and libs to lists
list(APPEND Caffe2_CPU_SRCS ${ATen_CPU_SRCS})
list(APPEND Caffe2_GPU_SRCS ${ATen_CUDA_CPP_SRCS})
list(APPEND Caffe2_XPU_SRCS ${ATen_XPU_SRCS})
list(APPEND Caffe2_XPU_INCLUDE ${ATen_XPU_INCLUDE})
list(APPEND Caffe2_XPU_DEPENDENCY_LIBS ${ATen_XPU_DEPENDENCY_LIBS})
list(APPEND Caffe2_GPU_SRCS_W_SORT_BY_KEY ${ATen_CUDA_SRCS_W_SORT_BY_KEY})
list(APPEND Caffe2_GPU_CU_SRCS ${ATen_CUDA_CU_SRCS})
list(APPEND Caffe2_GPU_CU_SRCS_W_SORT_BY_KEY ${ATen_CUDA_CU_SRCS_W_SORT_BY_KEY})
@ -174,6 +177,7 @@ endif()
if(CAFFE2_ALLOWLISTED_FILES)
caffe2_do_allowlist(Caffe2_CPU_SRCS CAFFE2_ALLOWLISTED_FILES)
caffe2_do_allowlist(Caffe2_GPU_SRCS CAFFE2_ALLOWLISTED_FILES)
caffe2_do_allowlist(Caffe2_XPU_SRCS CAFFE2_ALLOWLISTED_FILES)
caffe2_do_allowlist(Caffe2_GPU_SRCS_W_SORT_BY_KEY CAFFE2_ALLOWLISTED_FILES)
caffe2_do_allowlist(Caffe2_GPU_CU_SRCS CAFFE2_ALLOWLISTED_FILES)
caffe2_do_allowlist(Caffe2_GPU_CU_SRCS_W_SORT_BY_KEY CAFFE2_ALLOWLISTED_FILES)
@ -1607,9 +1611,7 @@ if(USE_CUDA)
caffe2_interface_library(torch_cuda torch_cuda_library)
elseif(USE_ROCM)
caffe2_interface_library(torch_hip torch_hip_library)
endif()
if(USE_XPU)
elseif(USE_XPU)
caffe2_interface_library(torch_xpu torch_xpu_library)
endif()
@ -1621,9 +1623,7 @@ if(USE_CUDA)
install(TARGETS torch_cuda torch_cuda_library EXPORT Caffe2Targets DESTINATION "${TORCH_INSTALL_LIB_DIR}")
elseif(USE_ROCM)
install(TARGETS torch_hip torch_hip_library EXPORT Caffe2Targets DESTINATION "${TORCH_INSTALL_LIB_DIR}")
endif()
if(USE_XPU)
elseif(USE_XPU)
install(TARGETS torch_xpu torch_xpu_library EXPORT Caffe2Targets DESTINATION "${TORCH_INSTALL_LIB_DIR}")
endif()
@ -1689,6 +1689,8 @@ if(USE_XPU)
torch_xpu INTERFACE $<INSTALL_INTERFACE:include>)
target_include_directories(
torch_xpu PRIVATE ${Caffe2_XPU_INCLUDE})
target_link_libraries(
torch_xpu PRIVATE ${Caffe2_XPU_DEPENDENCY_LIBS})
target_link_libraries(torch_xpu PUBLIC torch_cpu_library)
endif()

View File

@ -56,6 +56,13 @@ find_library(
NO_DEFAULT_PATH
)
find_library(
OCL_LIBRARY
NAMES OpenCL
HINTS ${SYCL_LIBRARY_DIR}
NO_DEFAULT_PATH
)
if((NOT SYCL_INCLUDE_DIR) OR (NOT SYCL_LIBRARY_DIR) OR (NOT SYCL_LIBRARY))
set(SYCL_FOUND False)
set(SYCL_REASON_FAILURE "SYCL library is incomplete!!")

test/xpu/test_gemm.py (new file, 1148 lines)
File diff suppressed because it is too large.