Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 12:54:11 +08:00)
[Intel GPU] oneDNN GPU GEMM support (#117202)
# Motivation
This PR is part of RFC #114848 and is a successor of #116249 and #116019: it depends on the oneDNN compilation work in #116249, and some runtime support is needed from #116019.

ATen operators such as `addmm` and `baddbmm` are defined in `Blas.cpp` under `aten/src/ATen/native/mkldnn/xpu/`. Alongside the files that provide the core functionality, `BlasImpl.h`, `Utils.h`, and other files provide basic utilities for them. For instance, `Utils.h` provides common memory-descriptor query utilities for `Matmul.h`, and these utility functions will also be used by other primitives, such as `convolution`. `BlasImpl.h` is a header that provides helpers for shape-info processing in matmul-related operators; it supports not only basic GEMM operators like `addmm` and `baddbmm` but also the fusion operators used by `torch.compile`, such as `linear_pointwise` in #117824.

In the next stage, we will continue to complete the oneDNN support by enabling `matmul fusion` and `convolution` related code.

Co-authored-by: xiaolil1 <xiaoli.liu@intel.com>
Co-authored-by: lei,zhenyuan <zhenyuan.lei@intel.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/117202
Approved by: https://github.com/EikanWang, https://github.com/jgong5, https://github.com/malfet
ghstack dependencies: #117098, #117112
Committed by: PyTorch MergeBot
Parent: 6330acae76
Commit: 9875a834e4
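For context only (this sketch is not part of the commit): the operators registered below are the ones PyTorch dispatches for ordinary GEMM calls, so after this change they route to oneDNN on the XPU backend. The snippet assumes a PyTorch build with XPU support (the runtime pieces from #116019) and an available Intel GPU; shapes and names are illustrative.

```python
# Minimal usage sketch, assuming an XPU-enabled PyTorch build and an Intel GPU.
import torch

device = "xpu"  # assumption: XPU runtime support is present

m1 = torch.randn(64, 32, device=device)
m2 = torch.randn(32, 16, device=device)
bias = torch.randn(64, 16, device=device)

out_mm = torch.mm(m1, m2)                                    # mm -> onednn::matmul
out_addmm = torch.addmm(bias, m1, m2, beta=0.5, alpha=2.0)   # addmm with beta/alpha

b1 = torch.randn(8, 64, 32, device=device)
b2 = torch.randn(8, 32, 16, device=device)
out_bmm = torch.bmm(b1, b2)                                  # batched matmul
out_baddbmm = torch.baddbmm(torch.randn(8, 64, 16, device=device), b1, b2)
```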
@@ -18,6 +18,8 @@ cmake_policy(SET CMP0012 NEW)
#############################################

set(ATen_CPU_SRCS)
set(ATen_XPU_SRCS)
set(ATen_XPU_INCLUDE)
set(ATen_CPU_TEST_SRCS)
set(ATen_CPU_INCLUDE)
set(ATen_THIRD_PARTY_INCLUDE)
@@ -39,6 +41,7 @@ set(ATen_XPU_INCLUDE)
set(ATen_XPU_TEST_SRCS)
set(ATen_VULKAN_TEST_SRCS)
set(ATen_CPU_DEPENDENCY_LIBS)
set(ATen_XPU_DEPENDENCY_LIBS)
set(ATen_CUDA_DEPENDENCY_LIBS)
set(ATen_HIP_DEPENDENCY_LIBS)
set(ATen_PUBLIC_CUDA_DEPENDENCY_LIBS)
@@ -105,6 +108,8 @@ add_subdirectory(src/ATen)
# Pass source, includes, and libs to parent
set(ATen_CPU_SRCS ${ATen_CPU_SRCS} PARENT_SCOPE)
set(ATen_CORE_SRCS ${ATen_CORE_SRCS} PARENT_SCOPE)
set(ATen_XPU_SRCS ${ATen_XPU_SRCS} PARENT_SCOPE)
set(ATen_XPU_INCLUDE ${ATen_XPU_INCLUDE} PARENT_SCOPE)
set(ATen_CUDA_CU_SRCS ${ATen_CUDA_CU_SRCS} PARENT_SCOPE)
set(ATen_CUDA_CPP_SRCS ${ATen_CUDA_CPP_SRCS} PARENT_SCOPE)
set(ATen_CUDA_LINALG_SRCS ${ATen_CUDA_LINALG_SRCS} PARENT_SCOPE)
@@ -130,6 +135,7 @@ set(ATen_HIP_INCLUDE ${ATen_HIP_INCLUDE} PARENT_SCOPE)
set(ATen_XPU_INCLUDE ${ATen_XPU_INCLUDE} PARENT_SCOPE)
set(ATen_THIRD_PARTY_INCLUDE ${ATen_THIRD_PARTY_INCLUDE} PARENT_SCOPE)
set(ATen_CPU_DEPENDENCY_LIBS ${ATen_CPU_DEPENDENCY_LIBS} PARENT_SCOPE)
set(ATen_XPU_DEPENDENCY_LIBS ${ATen_XPU_DEPENDENCY_LIBS} PARENT_SCOPE)
set(ATen_CUDA_DEPENDENCY_LIBS ${ATen_CUDA_DEPENDENCY_LIBS} PARENT_SCOPE)
set(ATen_HIP_DEPENDENCY_LIBS ${ATen_HIP_DEPENDENCY_LIBS} PARENT_SCOPE)
set(ATen_CORE_TEST_SRCS ${ATen_CORE_TEST_SRCS} PARENT_SCOPE)
@@ -85,6 +85,8 @@ file(GLOB miopen_cpp "miopen/*.cpp")
file(GLOB mkl_cpp "mkl/*.cpp")
file(GLOB mkldnn_cpp "mkldnn/*.cpp")

file(GLOB mkldnn_xpu_cpp "native/mkldnn/xpu/*.cpp" "native/mkldnn/xpu/detail/*.cpp")

file(GLOB native_cpp "native/*.cpp")
file(GLOB native_mkl_cpp "native/mkl/*.cpp")
file(GLOB native_mkldnn_cpp "native/mkldnn/*.cpp")
@@ -238,6 +240,20 @@ else()
  set(all_cpu_cpp ${all_cpu_cpp} ${vulkan_cpp})
endif()

if(USE_XPU)
  list(APPEND ATen_XPU_SRCS ${mkldnn_xpu_cpp})
  list(APPEND ATen_XPU_DEPENDENCY_LIBS xpu_mkldnn)

  list(APPEND ATen_XPU_DEPENDENCY_LIBS ${OCL_LIBRARY})
  list(APPEND ATen_XPU_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/native/mkldnn/xpu)
  list(APPEND ATen_XPU_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/native/mkldnn/xpu/detail)
  list(APPEND ATen_XPU_INCLUDE ${PROJECT_SOURCE_DIR}/third_party/ideep/mkl-dnn/include)
  list(APPEND ATen_XPU_INCLUDE ${XPU_MKLDNN_INCLUDE})

  list(APPEND ATen_XPU_INCLUDE ${SYCL_INCLUDE_DIR})
  list(APPEND ATen_XPU_DEPENDENCY_LIBS ${SYCL_LIBRARY})
endif()

# Metal
if(USE_PYTORCH_METAL_EXPORT)
  # Add files needed from exporting metal models(optimized_for_mobile)
@@ -629,6 +645,7 @@ list(APPEND ATen_MOBILE_BENCHMARK_SRCS
# Pass source, includes, and libs to parent
set(ATen_CORE_SRCS ${ATen_CORE_SRCS} PARENT_SCOPE)
set(ATen_CPU_SRCS ${ATen_CPU_SRCS} PARENT_SCOPE)
set(ATen_XPU_SRCS ${ATen_XPU_SRCS} PARENT_SCOPE)
set(ATen_CUDA_CU_SRCS ${ATen_CUDA_CU_SRCS} PARENT_SCOPE)
set(ATen_CUDA_CPP_SRCS ${ATen_CUDA_CPP_SRCS} PARENT_SCOPE)
set(ATen_CUDA_LINALG_SRCS ${ATen_CUDA_LINALG_SRCS} PARENT_SCOPE)
@@ -658,6 +675,7 @@ set(ATen_XPU_INCLUDE ${ATen_XPU_INCLUDE} PARENT_SCOPE)
set(ATen_VULKAN_INCLUDE ${ATen_VULKAN_INCLUDE} PARENT_SCOPE)
set(ATen_CPU_DEPENDENCY_LIBS ${ATen_CPU_DEPENDENCY_LIBS} PARENT_SCOPE)
set(ATen_CUDA_DEPENDENCY_LIBS ${ATen_CUDA_DEPENDENCY_LIBS} PARENT_SCOPE)
set(ATen_XPU_DEPENDENCY_LIBS ${ATen_XPU_DEPENDENCY_LIBS} PARENT_SCOPE)
set(ATen_HIP_DEPENDENCY_LIBS ${ATen_HIP_DEPENDENCY_LIBS} PARENT_SCOPE)
set(FLASH_ATTENTION_CUDA_SOURCES ${FLASH_ATTENTION_CUDA_SOURCES} PARENT_SCOPE)
set(MEM_EFF_ATTENTION_CUDA_SOURCES ${MEM_EFF_ATTENTION_CUDA_SOURCES} PARENT_SCOPE)
aten/src/ATen/native/mkldnn/xpu/Blas.cpp (new file, 436 lines)
@@ -0,0 +1,436 @@
#include <ATen/WrapDimUtilsMulti.h>
#include <ATen/native/Resize.h>
#include <torch/library.h>
#include <ATen/native/mkldnn/xpu/detail/oneDNN.h>

namespace at::native::xpu {

// result = beta * self + alpha * (mat1 * mat2)
Tensor& addmm_out(
    const Tensor& self,
    const Tensor& mat1,
    const Tensor& mat2,
    const Scalar& beta,
    const Scalar& alpha,
    at::Tensor& result) {
  checkBackend("addmm_out", {result, self, mat1, mat2}, Backend::XPU);
  TORCH_CHECK(
      mat1.dim() == 2, "mat1 must be a matrix, got ", mat1.dim(), "-D tensor");
  TORCH_CHECK(
      mat2.dim() == 2, "mat2 must be a matrix, got ", mat2.dim(), "-D tensor");
  TORCH_CHECK(
      mat1.sizes()[1] == mat2.sizes()[0],
      "mat1 and mat2 shapes cannot be multiplied (",
      mat1.sizes()[0],
      "x",
      mat1.sizes()[1],
      " and ",
      mat2.sizes()[0],
      "x",
      mat2.sizes()[1],
      ")");

  std::vector<int64_t> result_shape = {mat1.size(0), mat2.size(1)};
  result.resize_(result_shape);

  IntArrayRef result_sizes = result.sizes();
  if ((result_sizes[0] == 0) || (result_sizes[1] == 0)) {
    return result;
  }

  if (mat1.numel() == 0){
    if(beta.to<float>() == 0.f){
      return result.zero_();
    }
    return at::mul_out(
        result,
        self.expand(result.sizes()),
        at::native::scalar_tensor(
            beta,
            self.scalar_type(),
            c10::nullopt,
            at::kCPU,
            c10::nullopt
        )
    );
  }

  TORCH_CHECK(
      are_expandable(self.sizes(), result_shape),
      "addmm_out input must be expanable to:",
      result_shape,
      " but got:",
      self.sizes());

  // complex/double case
  if (mat1.is_complex() || mat1.scalar_type() == ScalarType::Double) {
    AT_ERROR(
        "Double and complex datatype matmul is not supported in oneDNN");
  }

  // general case
  Tensor bias = Tensor();
  onednn::Attr attr;
  float beta_ = beta.to<float>();
  if (beta_ == 0.f) {
    if (alpha.to<float>() != 1.f) {
      attr.append_post_eltwise(
          1.f, alpha.to<float>(), 0.f, attr.kind_with_linear);
    }
  } else {
    if (alpha.to<float>() == 1.f && beta_ == 1.f) {
      bias = self;
    } else {
      Tensor binary = self.dim() == 1 ? self.unsqueeze(0) : self;
      // Tensor binary = self.expand_as(result);
      // For post-binary-add, onednn needs binary scale=1.f
      // Thus we need the following transformation
      // alpha * matmul(mat1, mat2) + beta * binary
      // beta * (alpha/beta * matmul(src, wei) + binary)
      float alpha_ = alpha.to<float>() / beta_;
      if (alpha_ != 1.f)
        attr.append_post_eltwise(1.f, alpha_, 0.f, attr.kind_with_linear);
      attr.append_post_binary(attr.kind_with_binary_add, binary);
      if (beta_ != 1.f)
        attr.append_post_eltwise(1.f, beta_, 0.f, attr.kind_with_linear);
    }
  }
  onednn::matmul(result, mat1, mat2, bias, true, attr);
  return result;
}

Tensor& _addmm_activation_out(
    const Tensor& self,
    const Tensor& mat1,
    const Tensor& mat2,
    const Scalar& beta,
    const Scalar& alpha,
    bool use_gelu,
    at::Tensor& result) {
  addmm_out(self, mat1, mat2, beta, alpha, result);
  if (use_gelu) {
    at::gelu_(result);
  } else {
    at::relu_(result);
  }
  return result;
}

Tensor& mm_out(const Tensor& self, const Tensor& mat2, Tensor& result) {
  checkBackend("mm_out", {result, self, mat2}, Backend::XPU);
  TORCH_CHECK(self.dim() == 2, "self must be a matrix");
  TORCH_CHECK(mat2.dim() == 2, "mat2 must be a matrix");
  TORCH_CHECK(
      self.sizes()[1] == mat2.sizes()[0],
      "mat1 and mat2 shapes cannot be multiplied (",
      self.sizes()[0],
      "x",
      self.sizes()[1],
      " and ",
      mat2.sizes()[0],
      "x",
      mat2.sizes()[1],
      ")");

  result.resize_({self.size(0), mat2.size(1)});
  if (self.numel() == 0 || mat2.numel() == 0) {
    if (result.numel() > 0)
      result.zero_();
    return result;
  }

  if (self.is_complex() || self.scalar_type() == ScalarType::Double) {
    AT_ERROR(
        "Double and complex datatype matmul is not supported in oneDNN");
  }

  onednn::matmul(result, self, mat2, Tensor(), true, onednn::Attr());
  return result;
}

Tensor mm(const Tensor& self, const Tensor& mat2) {
  auto result = at::empty({0}, self.options());
  xpu::mm_out(self, mat2, result);
  return result;
}

Tensor mv(const Tensor& self, const Tensor& vec) {
  Tensor result = at::empty({self.size(0)}, self.options());
  return at::addmv_(result, self, vec, 0, 1);
}


// result = beta * input + alpha * (batch1 @ batch2)
Tensor& baddbmm_out(
    const Tensor& input,
    const Tensor& batch1,
    const Tensor& batch2,
    const Scalar& beta,
    const Scalar& alpha,
    Tensor& result) {
  checkBackend("baddbmm_out", {input, batch1, batch2}, Backend::XPU);
  TORCH_CHECK(batch1.dim() == 3, "expected 3D tensor");
  TORCH_CHECK(batch2.dim() == 3, "expected 3D tensor");

  std::vector<int64_t> result_shape = {
      batch1.size(0), batch1.size(1), batch2.size(2)};
  result.resize_(result_shape);
  if (result.numel() == 0){
    return result;
  } else if (batch1.size(2) == 0){
    if (beta.to<c10::complex<double>>() == 0.0){
      return result.zero_();
    }else{
      at::mul_out(result, input, beta);
      return result;
    }
  }

  TORCH_CHECK(
      are_expandable(input.sizes(), result_shape),
      "baddbmm_out input must be expanable to:",
      result_shape,
      " but got:",
      input.sizes());

  // complex and double case
  if (batch1.is_complex() || batch2.scalar_type() == ScalarType::Double) {
    AT_ERROR(
        "Double and complex datatype matmul is not supported in oneDNN");
  }

  // general case
  onednn::Attr attr;
  float beta_ = beta.to<float>();
  Tensor binary;
  if (beta_ == 0.f) {
    if (alpha.to<float>() != 1.f) {
      attr.append_post_eltwise(
          1.f, alpha.to<float>(), 0.f, attr.kind_with_linear);
    }
  } else {
    binary = input.dim() < 3 ? input.unsqueeze(0) : input;
    binary = binary.dim() < 3 ? binary.unsqueeze_(0) : binary;
    float alpha_ = alpha.to<float>() / beta_;
    if (alpha_ != 1.f)
      attr.append_post_eltwise(1.f, alpha_, 0.f, attr.kind_with_linear);
    attr.append_post_binary(attr.kind_with_binary_add, binary);
    if (beta_ != 1.f)
      attr.append_post_eltwise(1.f, beta_, 0.f, attr.kind_with_linear);
  }
  onednn::matmul(result, batch1, batch2, at::Tensor(), true, attr);
  return result;
}

Tensor& baddbmm_(
    Tensor& self,
    const Tensor& batch1,
    const Tensor& batch2,
    const Scalar& beta,
    const Scalar& alpha) {
  TORCH_CHECK(self.dtype() == batch1.dtype(), "Input dtypes must be the same, got: input ", self.dtype(), ", batch1: ", batch1.dtype(), ", batch2: ", batch2.dtype());
  return at::native::xpu::baddbmm_out(
      self, batch1, batch2, beta, alpha, self);
}

Tensor baddbmm(
    const Tensor& input,
    const Tensor& batch1,
    const Tensor& batch2,
    const Scalar& beta,
    const Scalar& alpha) {
  Tensor r = at::empty({0}, input.options());
  TORCH_CHECK(input.dtype() == batch1.dtype(), "Input dtypes must be the same, got: input ", input.dtype(), ", batch1: ", batch1.dtype(), ", batch2: ", batch2.dtype());
  r = at::native::xpu::baddbmm_out(input, batch1, batch2, beta, alpha, r);
  return r;
}

Tensor& addbmm_out(
    const Tensor& self,
    const Tensor& batch1,
    const Tensor& batch2,
    const Scalar& beta,
    const Scalar& alpha,
    Tensor& out) {
  checkBackend("addbmm_out", {out, self, batch1, batch2}, Backend::XPU);
  TORCH_CHECK(
      batch1.dim() == 3 && batch2.dim() == 3,
      "Batch tensors should be 3D, got dimensions ",
      batch1.dim(),
      " and ",
      batch2.dim());

  out.resize_({batch1.size(1), batch2.size(2)});
  if (alpha.to<float>() == 0.f || batch1.numel() == 0 || batch2.numel() == 0) {
    out.resize_({batch1.size(1), batch2.size(2)});
    if (out.numel() == 0)
      return out;

    if (self.defined() && beta.to<float>() != 0.f) {
      out = at::mul_out(
          out, self, at::native::wrapped_scalar_tensor(at::Scalar(beta)));
    } else {
      out.zero_();
    }
    return out;
  }

  Tensor b1;
  if (batch1.size(0) > 1) {
    b1 = batch1.transpose(0, 1).contiguous().view({batch1.size(1), -1});
  } else {
    b1 = batch1.contiguous().view({batch1.size(1), -1});
  }
  auto b2 = batch2.contiguous().view({-1, batch2.size(2)});
  at::native::xpu::addmm_out(self, b1, b2, beta, alpha, out);

  return out;
}

Tensor& addbmm_(
    Tensor& self,
    const Tensor& batch1,
    const Tensor& batch2,
    const Scalar& beta,
    const Scalar& alpha) {
  at::native::xpu::addbmm_out(self, batch1, batch2, beta, alpha, self);
  return self;
}

Tensor addbmm(
    const Tensor& self,
    const Tensor& batch1,
    const Tensor& batch2,
    const Scalar& beta,
    const Scalar& alpha) {
  Tensor out = at::empty({0}, self.options());
  at::native::xpu::addbmm_out(self, batch1, batch2, beta, alpha, out);
  return out;
}

Tensor& bmm_out(const Tensor& self, const Tensor& batch2, Tensor& result) {
  checkBackend("bmm_out", {result, self, batch2}, Backend::XPU);
  TORCH_CHECK(self.dim() == 3, "expected 3D tensor");
  TORCH_CHECK(batch2.dim() == 3, "expected 3D tensor");

  result.resize_({self.size(0), self.size(1), batch2.size(2)});
  if (self.numel() == 0 || batch2.numel() == 0) {
    if (result.numel() > 0)
      result.zero_();
    return result;
  }

  if (self.is_complex() || self.scalar_type() == ScalarType::Double) {
    AT_ERROR(
        "Double and complex datatype matmul is not supported in oneDNN");
  }
  onednn::matmul(result, self, batch2, at::Tensor(), true, onednn::Attr());
  return result;
}

Tensor bmm(const Tensor& self, const Tensor& batch2) {
  auto result = at::empty({0}, self.options());
  at::native::xpu::bmm_out(self, batch2, result);
  return result;
}

Tensor& addmv_out(
    const Tensor& self,
    const Tensor& mat,
    const Tensor& vec,
    const Scalar& beta,
    const Scalar& alpha,
    Tensor& out) {
  Tensor self_v;
  TORCH_CHECK(
      (mat.dim() == 2 && vec.dim() == 1 && self.dim() <= 1),
      "vector + matrix @ vector expected, got ",
      self.dim(),
      ", ",
      mat.dim(),
      ", ",
      vec.dim());
  if (self.dim() == 1 && self.size(0) != 1) {
    TORCH_CHECK(
        (mat.size(1) == vec.size(0) && mat.size(0) == self.size(0)),
        "size mismatch, get ",
        self.size(0),
        ", ",
        mat.size(0),
        "x",
        mat.size(1),
        ",",
        vec.size(0));
    self_v = self.view({self.size(0), 1});
  } else {
    TORCH_CHECK(
        (mat.size(1) == vec.size(0)),
        "size mismatch, get ",
        mat.size(0),
        "x",
        mat.size(1),
        ",",
        vec.size(0));
    self_v = self;
  }

  Tensor vec_v = vec.view({vec.size(0), 1});
  at::native::xpu::addmm_out(self_v, mat, vec_v, beta, alpha, out);
  out.resize_({mat.size(0)});
  return out;
}

Tensor& tensordot_out(
    const Tensor& input1,
    const Tensor& input2,
    IntArrayRef dims1,
    IntArrayRef dims2,
    Tensor& result) {
  Tensor result_tmp = at::tensordot(input1, input2, dims1, dims2);
  auto result_dtype = result_tmp.scalar_type();
  auto output_tensor_dtype = result.scalar_type();
  auto output_device = result.device();
  auto input1_device = input1.device();
  auto input2_device = input2.device();
  // check if the input & output tensors are on the same device.
  TORCH_CHECK(
      (output_device == input1_device) && (input1_device == input2_device),
      "tensordot: Expected the output and input tensors to be on the "
      "same device, but got the output tensor on ",
      output_device,
      ", input tensor a on ",
      input1_device,
      ", and input tensor b on ",
      input2_device);
  // check if the computed result has the same dtype as the out tensor
  // (because tensordot does not support type promotion)
  TORCH_CHECK(
      result_dtype == output_tensor_dtype,
      "tensordot",
      ": Expected the output tensor to have dtype ",
      result_dtype,
      ", but got an output tensor with dtype ",
      output_tensor_dtype);
  at::native::resize_output(result, result_tmp.sizes());
  result.copy_(result_tmp);
  return result;
}

TORCH_LIBRARY_IMPL(aten, XPU, m){
  m.impl("addmm.out", TORCH_FN(addmm_out));
  m.impl("_addmm_activation.out", TORCH_FN(_addmm_activation_out));
  m.impl("mm.out", TORCH_FN(mm_out));
  m.impl("mm", TORCH_FN(mm));
  m.impl("baddbmm.out", TORCH_FN(baddbmm_out));
  m.impl("baddbmm_", TORCH_FN(baddbmm_));
  m.impl("baddbmm", TORCH_FN(baddbmm));
  m.impl("addbmm.out", TORCH_FN(addbmm_out));
  m.impl("addbmm_", TORCH_FN(addbmm_));
  m.impl("addbmm", TORCH_FN(addbmm));
  m.impl("bmm.out", TORCH_FN(bmm_out));
  m.impl("bmm", TORCH_FN(bmm));
  m.impl("addmv.out", TORCH_FN(addmv_out));
  m.impl("tensordot.out", TORCH_FN(tensordot_out));
}

} // namespace at::native::xpu
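A note on the beta/alpha handling in `addmm_out` and `baddbmm_out` above: because oneDNN applies the binary-add post-op with scale 1.f, the implementation folds the scalars by rewriting alpha * (mat1 @ mat2) + beta * self as beta * ((alpha / beta) * (mat1 @ mat2) + self) whenever beta != 0. The small CPU-only sketch below is not part of the commit; it simply checks that identity with ordinary tensors.

```python
# Illustrative check of the rescaling identity used for oneDNN post-ops.
import torch

mat1, mat2 = torch.randn(4, 3), torch.randn(3, 5)
binary = torch.randn(4, 5)
alpha, beta = 2.0, 0.5

reference = alpha * (mat1 @ mat2) + beta * binary
# linear post-op with scale alpha/beta, binary-add with scale 1, then scale by beta
rewritten = beta * ((alpha / beta) * (mat1 @ mat2) + binary)

torch.testing.assert_close(reference, rewritten)
```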

@@ -80,6 +80,9 @@ if(INTERN_BUILD_ATEN_OPS)
  # Add source, includes, and libs to lists
  list(APPEND Caffe2_CPU_SRCS ${ATen_CPU_SRCS})
  list(APPEND Caffe2_GPU_SRCS ${ATen_CUDA_CPP_SRCS})
  list(APPEND Caffe2_XPU_SRCS ${ATen_XPU_SRCS})
  list(APPEND Caffe2_XPU_INCLUDE ${ATen_XPU_INCLUDE})
  list(APPEND Caffe2_XPU_DEPENDENCY_LIBS ${ATen_XPU_DEPENDENCY_LIBS})
  list(APPEND Caffe2_GPU_SRCS_W_SORT_BY_KEY ${ATen_CUDA_SRCS_W_SORT_BY_KEY})
  list(APPEND Caffe2_GPU_CU_SRCS ${ATen_CUDA_CU_SRCS})
  list(APPEND Caffe2_GPU_CU_SRCS_W_SORT_BY_KEY ${ATen_CUDA_CU_SRCS_W_SORT_BY_KEY})
@@ -174,6 +177,7 @@ endif()
if(CAFFE2_ALLOWLISTED_FILES)
  caffe2_do_allowlist(Caffe2_CPU_SRCS CAFFE2_ALLOWLISTED_FILES)
  caffe2_do_allowlist(Caffe2_GPU_SRCS CAFFE2_ALLOWLISTED_FILES)
  caffe2_do_allowlist(Caffe2_XPU_SRCS CAFFE2_ALLOWLISTED_FILES)
  caffe2_do_allowlist(Caffe2_GPU_SRCS_W_SORT_BY_KEY CAFFE2_ALLOWLISTED_FILES)
  caffe2_do_allowlist(Caffe2_GPU_CU_SRCS CAFFE2_ALLOWLISTED_FILES)
  caffe2_do_allowlist(Caffe2_GPU_CU_SRCS_W_SORT_BY_KEY CAFFE2_ALLOWLISTED_FILES)
@@ -1607,9 +1611,7 @@ if(USE_CUDA)
  caffe2_interface_library(torch_cuda torch_cuda_library)
elseif(USE_ROCM)
  caffe2_interface_library(torch_hip torch_hip_library)
endif()

if(USE_XPU)
elseif(USE_XPU)
  caffe2_interface_library(torch_xpu torch_xpu_library)
endif()

@@ -1621,9 +1623,7 @@ if(USE_CUDA)
  install(TARGETS torch_cuda torch_cuda_library EXPORT Caffe2Targets DESTINATION "${TORCH_INSTALL_LIB_DIR}")
elseif(USE_ROCM)
  install(TARGETS torch_hip torch_hip_library EXPORT Caffe2Targets DESTINATION "${TORCH_INSTALL_LIB_DIR}")
endif()

if(USE_XPU)
elseif(USE_XPU)
  install(TARGETS torch_xpu torch_xpu_library EXPORT Caffe2Targets DESTINATION "${TORCH_INSTALL_LIB_DIR}")
endif()

@@ -1689,6 +1689,8 @@ if(USE_XPU)
    torch_xpu INTERFACE $<INSTALL_INTERFACE:include>)
  target_include_directories(
    torch_xpu PRIVATE ${Caffe2_XPU_INCLUDE})
  target_link_libraries(
    torch_xpu PRIVATE ${Caffe2_XPU_DEPENDENCY_LIBS})

  target_link_libraries(torch_xpu PUBLIC torch_cpu_library)
endif()

@@ -56,6 +56,13 @@ find_library(
  NO_DEFAULT_PATH
)

find_library(
  OCL_LIBRARY
  NAMES OpenCL
  HINTS ${SYCL_LIBRARY_DIR}
  NO_DEFAULT_PATH
)

if((NOT SYCL_INCLUDE_DIR) OR (NOT SYCL_LIBRARY_DIR) OR (NOT SYCL_LIBRARY))
  set(SYCL_FOUND False)
  set(SYCL_REASON_FAILURE "SYCL library is incomplete!!")

test/xpu/test_gemm.py (new file, 1148 lines)
File diff suppressed because it is too large
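The diff of `test/xpu/test_gemm.py` is suppressed above, so its contents are not shown here. Purely as a hypothetical illustration of the kind of check such a GEMM test performs, one might compare XPU results against a CPU reference; the function below is illustrative only and does not reproduce the actual test file.

```python
# Hypothetical illustration only; assumes an XPU-enabled PyTorch build.
import torch


def check_addmm_matches_cpu(m=64, k=32, n=16, dtype=torch.float32):
    a = torch.randn(m, k, dtype=dtype)
    b = torch.randn(k, n, dtype=dtype)
    bias = torch.randn(m, n, dtype=dtype)

    cpu_out = torch.addmm(bias, a, b, beta=0.5, alpha=2.0)
    xpu_out = torch.addmm(
        bias.to("xpu"), a.to("xpu"), b.to("xpu"), beta=0.5, alpha=2.0)

    # Compare the XPU result against the CPU reference within a loose tolerance.
    torch.testing.assert_close(xpu_out.cpu(), cpu_out, rtol=1e-4, atol=1e-4)
```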