From 9875a834e48f23aecb32c9128c8e51bd841b42d7 Mon Sep 17 00:00:00 2001 From: ZhiweiYan-96 Date: Mon, 1 Apr 2024 12:21:05 +0000 Subject: [PATCH] [Intel GPU] oneDNN GPU GEMM support (#117202) # Motivation This PR is part of RFC #114848 and is a successor to #116249 and #116019. It depends on the oneDNN compilation in #116249, and some runtime support it needs is provided in #116019. ATen operators like `addmm` and `baddbmm` are defined in `Blas.cpp` under `aten/src/ATen/native/mkldnn/xpu/`. Alongside the files that provide the core functionality, `BlasImpl.h`, `Utils.h`, and other files provide basic utilities for them. For instance, `Utils.h` provides common memory-descriptor query utilities for `Matmul.h`, and these utility functions will also be used by other primitives, like `convolution`. `BlasImpl.h` is a header file that provides helpers for handling shape-info processing in matmul-related operators. It helps not only basic GEMM operators like `addmm` and `baddbmm`, but also fusion operators used in `torch.compile`, like `linear_pointwise` in #117824. In the next stage, we will continue completing the oneDNN support by enabling the `matmul` fusion and `convolution` related code. Co-authored-by: xiaolil1 Co-authored-by: lei,zhenyuan Pull Request resolved: https://github.com/pytorch/pytorch/pull/117202 Approved by: https://github.com/EikanWang, https://github.com/jgong5, https://github.com/malfet ghstack dependencies: #117098, #117112 --- aten/CMakeLists.txt | 6 + aten/src/ATen/CMakeLists.txt | 18 + aten/src/ATen/native/mkldnn/xpu/Blas.cpp | 436 ++++++++ caffe2/CMakeLists.txt | 14 +- cmake/Modules/FindSYCLToolkit.cmake | 7 + test/xpu/test_gemm.py | 1148 ++++++++++++++++++++++ 6 files changed, 1623 insertions(+), 6 deletions(-) create mode 100644 aten/src/ATen/native/mkldnn/xpu/Blas.cpp create mode 100644 test/xpu/test_gemm.py diff --git a/aten/CMakeLists.txt b/aten/CMakeLists.txt index 427a1b87be1f..bda6aea32706 100644 --- a/aten/CMakeLists.txt +++ b/aten/CMakeLists.txt @@ -18,6 +18,8 @@ cmake_policy(SET CMP0012 NEW) ############################################# set(ATen_CPU_SRCS) +set(ATen_XPU_SRCS) +set(ATen_XPU_INCLUDE) set(ATen_CPU_TEST_SRCS) set(ATen_CPU_INCLUDE) set(ATen_THIRD_PARTY_INCLUDE) @@ -39,6 +41,7 @@ set(ATen_XPU_INCLUDE) set(ATen_XPU_TEST_SRCS) set(ATen_VULKAN_TEST_SRCS) set(ATen_CPU_DEPENDENCY_LIBS) +set(ATen_XPU_DEPENDENCY_LIBS) set(ATen_CUDA_DEPENDENCY_LIBS) set(ATen_HIP_DEPENDENCY_LIBS) set(ATen_PUBLIC_CUDA_DEPENDENCY_LIBS) @@ -105,6 +108,8 @@ add_subdirectory(src/ATen) # Pass source, includes, and libs to parent set(ATen_CPU_SRCS ${ATen_CPU_SRCS} PARENT_SCOPE) set(ATen_CORE_SRCS ${ATen_CORE_SRCS} PARENT_SCOPE) +set(ATen_XPU_SRCS ${ATen_XPU_SRCS} PARENT_SCOPE) +set(ATen_XPU_INCLUDE ${ATen_XPU_INCLUDE} PARENT_SCOPE) set(ATen_CUDA_CU_SRCS ${ATen_CUDA_CU_SRCS} PARENT_SCOPE) set(ATen_CUDA_CPP_SRCS ${ATen_CUDA_CPP_SRCS} PARENT_SCOPE) set(ATen_CUDA_LINALG_SRCS ${ATen_CUDA_LINALG_SRCS} PARENT_SCOPE) @@ -130,6 +135,7 @@ set(ATen_HIP_INCLUDE ${ATen_HIP_INCLUDE} PARENT_SCOPE) set(ATen_XPU_INCLUDE ${ATen_XPU_INCLUDE} PARENT_SCOPE) set(ATen_THIRD_PARTY_INCLUDE ${ATen_THIRD_PARTY_INCLUDE} PARENT_SCOPE) set(ATen_CPU_DEPENDENCY_LIBS ${ATen_CPU_DEPENDENCY_LIBS} PARENT_SCOPE) +set(ATen_XPU_DEPENDENCY_LIBS ${ATen_XPU_DEPENDENCY_LIBS} PARENT_SCOPE) set(ATen_CUDA_DEPENDENCY_LIBS ${ATen_CUDA_DEPENDENCY_LIBS} PARENT_SCOPE) set(ATen_HIP_DEPENDENCY_LIBS ${ATen_HIP_DEPENDENCY_LIBS} PARENT_SCOPE) set(ATen_CORE_TEST_SRCS ${ATen_CORE_TEST_SRCS} PARENT_SCOPE) diff --git a/aten/src/ATen/CMakeLists.txt
b/aten/src/ATen/CMakeLists.txt index 8d50e9c3721c..583662e6c63d 100644 --- a/aten/src/ATen/CMakeLists.txt +++ b/aten/src/ATen/CMakeLists.txt @@ -85,6 +85,8 @@ file(GLOB miopen_cpp "miopen/*.cpp") file(GLOB mkl_cpp "mkl/*.cpp") file(GLOB mkldnn_cpp "mkldnn/*.cpp") +file(GLOB mkldnn_xpu_cpp "native/mkldnn/xpu/*.cpp" "native/mkldnn/xpu/detail/*.cpp") + file(GLOB native_cpp "native/*.cpp") file(GLOB native_mkl_cpp "native/mkl/*.cpp") file(GLOB native_mkldnn_cpp "native/mkldnn/*.cpp") @@ -238,6 +240,20 @@ else() set(all_cpu_cpp ${all_cpu_cpp} ${vulkan_cpp}) endif() +if(USE_XPU) + list(APPEND ATen_XPU_SRCS ${mkldnn_xpu_cpp}) + list(APPEND ATen_XPU_DEPENDENCY_LIBS xpu_mkldnn) + + list(APPEND ATen_XPU_DEPENDENCY_LIBS ${OCL_LIBRARY}) + list(APPEND ATen_XPU_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/native/mkldnn/xpu) + list(APPEND ATen_XPU_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/native/mkldnn/xpu/detail) + list(APPEND ATen_XPU_INCLUDE ${PROJECT_SOURCE_DIR}/third_party/ideep/mkl-dnn/include) + list(APPEND ATen_XPU_INCLUDE ${XPU_MKLDNN_INCLUDE}) + + list(APPEND ATen_XPU_INCLUDE ${SYCL_INCLUDE_DIR}) + list(APPEND ATen_XPU_DEPENDENCY_LIBS ${SYCL_LIBRARY}) +endif() + # Metal if(USE_PYTORCH_METAL_EXPORT) # Add files needed from exporting metal models(optimized_for_mobile) @@ -629,6 +645,7 @@ list(APPEND ATen_MOBILE_BENCHMARK_SRCS # Pass source, includes, and libs to parent set(ATen_CORE_SRCS ${ATen_CORE_SRCS} PARENT_SCOPE) set(ATen_CPU_SRCS ${ATen_CPU_SRCS} PARENT_SCOPE) +set(ATen_XPU_SRCS ${ATen_XPU_SRCS} PARENT_SCOPE) set(ATen_CUDA_CU_SRCS ${ATen_CUDA_CU_SRCS} PARENT_SCOPE) set(ATen_CUDA_CPP_SRCS ${ATen_CUDA_CPP_SRCS} PARENT_SCOPE) set(ATen_CUDA_LINALG_SRCS ${ATen_CUDA_LINALG_SRCS} PARENT_SCOPE) @@ -658,6 +675,7 @@ set(ATen_XPU_INCLUDE ${ATen_XPU_INCLUDE} PARENT_SCOPE) set(ATen_VULKAN_INCLUDE ${ATen_VULKAN_INCLUDE} PARENT_SCOPE) set(ATen_CPU_DEPENDENCY_LIBS ${ATen_CPU_DEPENDENCY_LIBS} PARENT_SCOPE) set(ATen_CUDA_DEPENDENCY_LIBS ${ATen_CUDA_DEPENDENCY_LIBS} PARENT_SCOPE) +set(ATen_XPU_DEPENDENCY_LIBS ${ATen_XPU_DEPENDENCY_LIBS} PARENT_SCOPE) set(ATen_HIP_DEPENDENCY_LIBS ${ATen_HIP_DEPENDENCY_LIBS} PARENT_SCOPE) set(FLASH_ATTENTION_CUDA_SOURCES ${FLASH_ATTENTION_CUDA_SOURCES} PARENT_SCOPE) set(MEM_EFF_ATTENTION_CUDA_SOURCES ${MEM_EFF_ATTENTION_CUDA_SOURCES} PARENT_SCOPE) diff --git a/aten/src/ATen/native/mkldnn/xpu/Blas.cpp b/aten/src/ATen/native/mkldnn/xpu/Blas.cpp new file mode 100644 index 000000000000..6cba3f4c9fa1 --- /dev/null +++ b/aten/src/ATen/native/mkldnn/xpu/Blas.cpp @@ -0,0 +1,436 @@ +#include +#include +#include +#include + +namespace at::native::xpu { + +// result = beta * self + alpha * (mat1 * mat2) +Tensor& addmm_out( + const Tensor& self, + const Tensor& mat1, + const Tensor& mat2, + const Scalar& beta, + const Scalar& alpha, + at::Tensor& result) { + checkBackend("addmm_out", {result, self, mat1, mat2}, Backend::XPU); + TORCH_CHECK( + mat1.dim() == 2, "mat1 must be a matrix, got ", mat1.dim(), "-D tensor"); + TORCH_CHECK( + mat2.dim() == 2, "mat2 must be a matrix, got ", mat2.dim(), "-D tensor"); + TORCH_CHECK( + mat1.sizes()[1] == mat2.sizes()[0], + "mat1 and mat2 shapes cannot be multiplied (", + mat1.sizes()[0], + "x", + mat1.sizes()[1], + " and ", + mat2.sizes()[0], + "x", + mat2.sizes()[1], + ")"); + + std::vector result_shape = {mat1.size(0), mat2.size(1)}; + result.resize_(result_shape); + + IntArrayRef result_sizes = result.sizes(); + if ((result_sizes[0] == 0) || (result_sizes[1] == 0)) { + return result; + } + + if (mat1.numel() == 0){ + if(beta.to() == 0.f){ + return result.zero_(); + 
} + return at::mul_out( + result, + self.expand(result.sizes()), + at::native::scalar_tensor( + beta, + self.scalar_type(), + c10::nullopt, + at::kCPU, + c10::nullopt + ) + ); + } + + TORCH_CHECK( + are_expandable(self.sizes(), result_shape), + "addmm_out input must be expandable to:", + result_shape, + " but got:", + self.sizes()); + + // complex/double case + if (mat1.is_complex() || mat1.scalar_type() == ScalarType::Double) { + AT_ERROR( + "Double and complex datatype matmul is not supported in oneDNN"); + } + + // general case + Tensor bias = Tensor(); + onednn::Attr attr; + float beta_ = beta.to<float>(); + if (beta_ == 0.f) { + if (alpha.to<float>() != 1.f) { + attr.append_post_eltwise( + 1.f, alpha.to<float>(), 0.f, attr.kind_with_linear); + } + } else { + if (alpha.to<float>() == 1.f && beta_ == 1.f) { + bias = self; + } else { + Tensor binary = self.dim() == 1 ? self.unsqueeze(0) : self; + // Tensor binary = self.expand_as(result); + // For post-binary-add, onednn needs binary scale=1.f + // Thus we need the following transformation + // alpha * matmul(mat1, mat2) + beta * binary + // beta * (alpha/beta * matmul(src, wei) + binary) + float alpha_ = alpha.to<float>() / beta_; + if (alpha_ != 1.f) + attr.append_post_eltwise(1.f, alpha_, 0.f, attr.kind_with_linear); + attr.append_post_binary(attr.kind_with_binary_add, binary); + if (beta_ != 1.f) + attr.append_post_eltwise(1.f, beta_, 0.f, attr.kind_with_linear); + } + } + onednn::matmul(result, mat1, mat2, bias, true, attr); + return result; +} + +Tensor& _addmm_activation_out( + const Tensor& self, + const Tensor& mat1, + const Tensor& mat2, + const Scalar& beta, + const Scalar& alpha, + bool use_gelu, + at::Tensor& result) { + addmm_out(self, mat1, mat2, beta, alpha, result); + if (use_gelu) { + at::gelu_(result); + } else { + at::relu_(result); + } + return result; +} + +Tensor& mm_out(const Tensor& self, const Tensor& mat2, Tensor& result) { + checkBackend("mm_out", {result, self, mat2}, Backend::XPU); + TORCH_CHECK(self.dim() == 2, "self must be a matrix"); + TORCH_CHECK(mat2.dim() == 2, "mat2 must be a matrix"); + TORCH_CHECK( + self.sizes()[1] == mat2.sizes()[0], + "mat1 and mat2 shapes cannot be multiplied (", + self.sizes()[0], + "x", + self.sizes()[1], + " and ", + mat2.sizes()[0], + "x", + mat2.sizes()[1], + ")"); + + result.resize_({self.size(0), mat2.size(1)}); + if (self.numel() == 0 || mat2.numel() == 0) { + if (result.numel() > 0) + result.zero_(); + return result; + } + + if (self.is_complex() || self.scalar_type() == ScalarType::Double) { + AT_ERROR( + "Double and complex datatype matmul is not supported in oneDNN"); + } + + onednn::matmul(result, self, mat2, Tensor(), true, onednn::Attr()); + return result; +} + +Tensor mm(const Tensor& self, const Tensor& mat2) { + auto result = at::empty({0}, self.options()); + xpu::mm_out(self, mat2, result); + return result; +} + +Tensor mv(const Tensor& self, const Tensor& vec) { + Tensor result = at::empty({self.size(0)}, self.options()); + return at::addmv_(result, self, vec, 0, 1); +} + + +// result = beta * input + alpha * (batch1 @ batch2) +Tensor& baddbmm_out( + const Tensor& input, + const Tensor& batch1, + const Tensor& batch2, + const Scalar& beta, + const Scalar& alpha, + Tensor& result) { + checkBackend("baddbmm_out", {input, batch1, batch2}, Backend::XPU); + TORCH_CHECK(batch1.dim() == 3, "expected 3D tensor"); + TORCH_CHECK(batch2.dim() == 3, "expected 3D tensor"); + + std::vector<int64_t> result_shape = { + batch1.size(0), batch1.size(1), batch2.size(2)}; + result.resize_(result_shape); + if
(result.numel() == 0){ + return result; + } else if (batch1.size(2) == 0){ + if (beta.to<c10::complex<double>>() == 0.0){ + return result.zero_(); + }else{ + at::mul_out(result, input, beta); + return result; + } + } + + TORCH_CHECK( + are_expandable(input.sizes(), result_shape), + "baddbmm_out input must be expandable to:", + result_shape, + " but got:", + input.sizes()); + + // complex and double case + if (batch1.is_complex() || batch2.scalar_type() == ScalarType::Double) { + AT_ERROR( + "Double and complex datatype matmul is not supported in oneDNN"); + } + + // general case + onednn::Attr attr; + float beta_ = beta.to<float>(); + Tensor binary; + if (beta_ == 0.f) { + if (alpha.to<float>() != 1.f) { + attr.append_post_eltwise( + 1.f, alpha.to<float>(), 0.f, attr.kind_with_linear); + } + } else { + binary = input.dim() < 3 ? input.unsqueeze(0) : input; + binary = binary.dim() < 3 ? binary.unsqueeze_(0) : binary; + float alpha_ = alpha.to<float>() / beta_; + if (alpha_ != 1.f) + attr.append_post_eltwise(1.f, alpha_, 0.f, attr.kind_with_linear); + attr.append_post_binary(attr.kind_with_binary_add, binary); + if (beta_ != 1.f) + attr.append_post_eltwise(1.f, beta_, 0.f, attr.kind_with_linear); + } + onednn::matmul(result, batch1, batch2, at::Tensor(), true, attr); + return result; +} + +Tensor& baddbmm_( + Tensor& self, + const Tensor& batch1, + const Tensor& batch2, + const Scalar& beta, + const Scalar& alpha) { + TORCH_CHECK(self.dtype() == batch1.dtype(), "Input dtypes must be the same, got: input ", self.dtype(), ", batch1: ", batch1.dtype(), ", batch2: ", batch2.dtype()); + return at::native::xpu::baddbmm_out( + self, batch1, batch2, beta, alpha, self); +} + +Tensor baddbmm( + const Tensor& input, + const Tensor& batch1, + const Tensor& batch2, + const Scalar& beta, + const Scalar& alpha) { + Tensor r = at::empty({0}, input.options()); + TORCH_CHECK(input.dtype() == batch1.dtype(), "Input dtypes must be the same, got: input ", input.dtype(), ", batch1: ", batch1.dtype(), ", batch2: ", batch2.dtype()); + r = at::native::xpu::baddbmm_out(input, batch1, batch2, beta, alpha, r); + return r; +} + +Tensor& addbmm_out( + const Tensor& self, + const Tensor& batch1, + const Tensor& batch2, + const Scalar& beta, + const Scalar& alpha, + Tensor& out) { + checkBackend("addbmm_out", {out, self, batch1, batch2}, Backend::XPU); + TORCH_CHECK( + batch1.dim() == 3 && batch2.dim() == 3, + "Batch tensors should be 3D, got dimensions ", + batch1.dim(), + " and ", + batch2.dim()); + + out.resize_({batch1.size(1), batch2.size(2)}); + if (alpha.to<float>() == 0.f || batch1.numel() == 0 || batch2.numel() == 0) { + out.resize_({batch1.size(1), batch2.size(2)}); + if (out.numel() == 0) + return out; + + if (self.defined() && beta.to<float>() != 0.f) { + out = at::mul_out( + out, self, at::native::wrapped_scalar_tensor(at::Scalar(beta))); + } else { + out.zero_(); + } + return out; + } + + Tensor b1; + if (batch1.size(0) > 1) { + b1 = batch1.transpose(0, 1).contiguous().view({batch1.size(1), -1}); + } else { + b1 = batch1.contiguous().view({batch1.size(1), -1}); + } + auto b2 = batch2.contiguous().view({-1, batch2.size(2)}); + at::native::xpu::addmm_out(self, b1, b2, beta, alpha, out); + + return out; +} + +Tensor& addbmm_( + Tensor& self, + const Tensor& batch1, + const Tensor& batch2, + const Scalar& beta, + const Scalar& alpha) { + at::native::xpu::addbmm_out(self, batch1, batch2, beta, alpha, self); + return self; +} + +Tensor addbmm( + const Tensor& self, + const Tensor& batch1, + const Tensor& batch2, + const Scalar& beta, + const Scalar& alpha) { + Tensor out =
at::empty({0}, self.options()); + at::native::xpu::addbmm_out(self, batch1, batch2, beta, alpha, out); + return out; +} + +Tensor& bmm_out(const Tensor& self, const Tensor& batch2, Tensor& result) { + checkBackend("bmm_out", {result, self, batch2}, Backend::XPU); + TORCH_CHECK(self.dim() == 3, "expected 3D tensor"); + TORCH_CHECK(batch2.dim() == 3, "expected 3D tensor"); + + result.resize_({self.size(0), self.size(1), batch2.size(2)}); + if (self.numel() == 0 || batch2.numel() == 0) { + if (result.numel() > 0) + result.zero_(); + return result; + } + + if (self.is_complex() || self.scalar_type() == ScalarType::Double) { + AT_ERROR( + "Double and complex datatype matmul is not supported in oneDNN"); + } + onednn::matmul(result, self, batch2, at::Tensor(), true, onednn::Attr()); + return result; +} + +Tensor bmm(const Tensor& self, const Tensor& batch2) { + auto result = at::empty({0}, self.options()); + at::native::xpu::bmm_out(self, batch2, result); + return result; +} + +Tensor& addmv_out( + const Tensor& self, + const Tensor& mat, + const Tensor& vec, + const Scalar& beta, + const Scalar& alpha, + Tensor& out) { + Tensor self_v; + TORCH_CHECK( + (mat.dim() == 2 && vec.dim() == 1 && self.dim() <= 1), + "vector + matrix @ vector expected, got ", + self.dim(), + ", ", + mat.dim(), + ", ", + vec.dim()); + if (self.dim() == 1 && self.size(0) != 1) { + TORCH_CHECK( + (mat.size(1) == vec.size(0) && mat.size(0) == self.size(0)), + "size mismatch, get ", + self.size(0), + ", ", + mat.size(0), + "x", + mat.size(1), + ",", + vec.size(0)); + self_v = self.view({self.size(0), 1}); + } else { + TORCH_CHECK( + (mat.size(1) == vec.size(0)), + "size mismatch, get ", + mat.size(0), + "x", + mat.size(1), + ",", + vec.size(0)); + self_v = self; + } + + Tensor vec_v = vec.view({vec.size(0), 1}); + at::native::xpu::addmm_out(self_v, mat, vec_v, beta, alpha, out); + out.resize_({mat.size(0)}); + return out; +} + +Tensor& tensordot_out( + const Tensor& input1, + const Tensor& input2, + IntArrayRef dims1, + IntArrayRef dims2, + Tensor& result) { + Tensor result_tmp = at::tensordot(input1, input2, dims1, dims2); + auto result_dtype = result_tmp.scalar_type(); + auto output_tensor_dtype = result.scalar_type(); + auto output_device = result.device(); + auto input1_device = input1.device(); + auto input2_device = input2.device(); + // check if the input & output tensors are on the same device. 
+ TORCH_CHECK( + (output_device == input1_device) && (input1_device == input2_device), + "tensordot: Expected the output and input tensors to be on the " + "same device, but got the output tensor on ", + output_device, + ", input tensor a on ", + input1_device, + ", and input tensor b on ", + input2_device); + // check if the computed result has the same dtype as the out tensor + // (because tensordot does not support type promotion) + TORCH_CHECK( + result_dtype == output_tensor_dtype, + "tensordot", + ": Expected the output tensor to have dtype ", + result_dtype, + ", but got an output tensor with dtype ", + output_tensor_dtype); + at::native::resize_output(result, result_tmp.sizes()); + result.copy_(result_tmp); + return result; +} + +TORCH_LIBRARY_IMPL(aten, XPU, m){ + m.impl("addmm.out", TORCH_FN(addmm_out)); + m.impl("_addmm_activation.out", TORCH_FN(_addmm_activation_out)); + m.impl("mm.out", TORCH_FN(mm_out)); + m.impl("mm", TORCH_FN(mm)); + m.impl("baddbmm.out", TORCH_FN(baddbmm_out)); + m.impl("baddbmm_", TORCH_FN(baddbmm_)); + m.impl("baddbmm", TORCH_FN(baddbmm)); + m.impl("addbmm.out", TORCH_FN(addbmm_out)); + m.impl("addbmm_", TORCH_FN(addbmm_)); + m.impl("addbmm", TORCH_FN(addbmm)); + m.impl("bmm.out", TORCH_FN(bmm_out)); + m.impl("bmm", TORCH_FN(bmm)); + m.impl("addmv.out", TORCH_FN(addmv_out)); + m.impl("tensordot.out", TORCH_FN(tensordot_out)); +} + +} // namespace at::native::xpu diff --git a/caffe2/CMakeLists.txt b/caffe2/CMakeLists.txt index 0f3baa4543f9..d080ef6ce047 100644 --- a/caffe2/CMakeLists.txt +++ b/caffe2/CMakeLists.txt @@ -80,6 +80,9 @@ if(INTERN_BUILD_ATEN_OPS) # Add source, includes, and libs to lists list(APPEND Caffe2_CPU_SRCS ${ATen_CPU_SRCS}) list(APPEND Caffe2_GPU_SRCS ${ATen_CUDA_CPP_SRCS}) + list(APPEND Caffe2_XPU_SRCS ${ATen_XPU_SRCS}) + list(APPEND Caffe2_XPU_INCLUDE ${ATen_XPU_INCLUDE}) + list(APPEND Caffe2_XPU_DEPENDENCY_LIBS ${ATen_XPU_DEPENDENCY_LIBS}) list(APPEND Caffe2_GPU_SRCS_W_SORT_BY_KEY ${ATen_CUDA_SRCS_W_SORT_BY_KEY}) list(APPEND Caffe2_GPU_CU_SRCS ${ATen_CUDA_CU_SRCS}) list(APPEND Caffe2_GPU_CU_SRCS_W_SORT_BY_KEY ${ATen_CUDA_CU_SRCS_W_SORT_BY_KEY}) @@ -174,6 +177,7 @@ endif() if(CAFFE2_ALLOWLISTED_FILES) caffe2_do_allowlist(Caffe2_CPU_SRCS CAFFE2_ALLOWLISTED_FILES) caffe2_do_allowlist(Caffe2_GPU_SRCS CAFFE2_ALLOWLISTED_FILES) + caffe2_do_allowlist(Caffe2_XPU_SRCS CAFFE2_ALLOWLISTED_FILES) caffe2_do_allowlist(Caffe2_GPU_SRCS_W_SORT_BY_KEY CAFFE2_ALLOWLISTED_FILES) caffe2_do_allowlist(Caffe2_GPU_CU_SRCS CAFFE2_ALLOWLISTED_FILES) caffe2_do_allowlist(Caffe2_GPU_CU_SRCS_W_SORT_BY_KEY CAFFE2_ALLOWLISTED_FILES) @@ -1607,9 +1611,7 @@ if(USE_CUDA) caffe2_interface_library(torch_cuda torch_cuda_library) elseif(USE_ROCM) caffe2_interface_library(torch_hip torch_hip_library) -endif() - -if(USE_XPU) +elseif(USE_XPU) caffe2_interface_library(torch_xpu torch_xpu_library) endif() @@ -1621,9 +1623,7 @@ if(USE_CUDA) install(TARGETS torch_cuda torch_cuda_library EXPORT Caffe2Targets DESTINATION "${TORCH_INSTALL_LIB_DIR}") elseif(USE_ROCM) install(TARGETS torch_hip torch_hip_library EXPORT Caffe2Targets DESTINATION "${TORCH_INSTALL_LIB_DIR}") -endif() - -if(USE_XPU) +elseif(USE_XPU) install(TARGETS torch_xpu torch_xpu_library EXPORT Caffe2Targets DESTINATION "${TORCH_INSTALL_LIB_DIR}") endif() @@ -1689,6 +1689,8 @@ if(USE_XPU) torch_xpu INTERFACE $) target_include_directories( torch_xpu PRIVATE ${Caffe2_XPU_INCLUDE}) + target_link_libraries( + torch_xpu PRIVATE ${Caffe2_XPU_DEPENDENCY_LIBS}) target_link_libraries(torch_xpu PUBLIC torch_cpu_library) 
endif() diff --git a/cmake/Modules/FindSYCLToolkit.cmake b/cmake/Modules/FindSYCLToolkit.cmake index 758c4378636b..d9345bb2fe0d 100644 --- a/cmake/Modules/FindSYCLToolkit.cmake +++ b/cmake/Modules/FindSYCLToolkit.cmake @@ -56,6 +56,13 @@ find_library( NO_DEFAULT_PATH ) +find_library( + OCL_LIBRARY + NAMES OpenCL + HINTS ${SYCL_LIBRARY_DIR} + NO_DEFAULT_PATH +) + if((NOT SYCL_INCLUDE_DIR) OR (NOT SYCL_LIBRARY_DIR) OR (NOT SYCL_LIBRARY)) set(SYCL_FOUND False) set(SYCL_REASON_FAILURE "SYCL library is incomplete!!") diff --git a/test/xpu/test_gemm.py b/test/xpu/test_gemm.py new file mode 100644 index 000000000000..0157677a582f --- /dev/null +++ b/test/xpu/test_gemm.py @@ -0,0 +1,1148 @@ +# Owner(s): ["module: intel"] + +import itertools +import math +import random +from functools import partial +from itertools import product + +import numpy as np + +import torch +from torch.testing import make_tensor +from torch.testing._internal.common_device_type import ( + dtypes, + instantiate_device_type_tests, + precisionOverride, +) +from torch.testing._internal.common_utils import iter_indices, run_tests, TestCase + + +class TestBasicGEMM(TestCase): + def _test_addmm_addmv( + self, f, t, m, v, *, alpha=None, beta=None, transpose_out=False, activation=None + ): + dtype = t.dtype + numpy_dtype = dtype + if dtype in {torch.bfloat16, torch.half}: + numpy_dtype = torch.float + if dtype.is_complex: + alpha = 0.9 + 0.3j if alpha is None else alpha + beta = 0.5 + 0.6j if beta is None else beta + else: + alpha = 1.2 if alpha is None else alpha + beta = 0.8 if beta is None else beta + if activation == "gelu": + res1 = f(t, m, v, alpha=alpha, beta=beta, use_gelu=True) + else: + res1 = f(t, m, v, alpha=alpha, beta=beta) + res2 = torch.full_like(res1, math.nan) + if transpose_out: + res2 = res2.t().clone(memory_format=torch.contiguous_format).t() + if activation == "gelu": + f(t, m, v, alpha=alpha, beta=beta, out=res2, use_gelu=True) + else: + f(t, m, v, alpha=alpha, beta=beta, out=res2) + m.to(numpy_dtype).cpu().numpy() + v.to(numpy_dtype).cpu().numpy() + res3 = alpha * ( + m.to(numpy_dtype).cpu().numpy() @ v.to(numpy_dtype).cpu().numpy() + ) + if beta != 0: + res3 += (beta * t).to(numpy_dtype).cpu().numpy() + if activation == "relu": + res3 = res3 * (res3 > 0) + elif activation == "gelu": + res3_t = torch.from_numpy(res3).to(dtype) + approximate = "tanh" if t.is_cuda else "none" + res3_t = torch.nn.functional.gelu(res3_t, approximate=approximate) + res3 = res3_t.to(numpy_dtype).cpu().numpy() + else: + assert activation is None, f"unsupported activation {activation}" + res3 = torch.from_numpy(res3).to(dtype) + self.assertEqual(res1, res2) + self.assertEqual(res1, res3) + + def _test_addmm_impl(self, func, activation, device, dtype): + M = torch.randn(10, 25, device="cpu", dtype=torch.float32).to(dtype).to(device) + m1 = torch.randn(10, 50, device="cpu", dtype=torch.float32).to(dtype).to(device) + m2 = torch.randn(50, 25, device="cpu", dtype=torch.float32).to(dtype).to(device) + self._test_addmm_addmv(func, M, m1, m2, activation=activation) + + # vector-shaped bias and beta=1 result in epilogue fusion in CUDA + V = torch.randn(25, device="cpu", dtype=torch.float32).to(dtype).to(device) + self._test_addmm_addmv(func, V, m1, m2, beta=1, activation=activation) + + # Test 0-strided + M = ( + torch.randn(10, 1, device="cpu", dtype=torch.float32) + .to(dtype) + .expand(10, 25) + .to(device) + ) + m1 = ( + torch.randn(10, 1, device="cpu", dtype=torch.float32) + .to(dtype) + .expand(10, 50) + .to(device) + ) + m2 = 
torch.randn(50, 25, device="cpu", dtype=torch.float32).to(dtype).to(device) + self._test_addmm_addmv(func, M, m1, m2, activation=activation) + + # Test beta=0, M=nan + M = ( + torch.full((10, 25), math.nan, device="cpu", dtype=torch.float32) + .to(dtype) + .to(device) + ) + m1 = torch.randn(10, 50, device="cpu", dtype=torch.float32).to(dtype).to(device) + m2 = torch.randn(50, 25, device="cpu", dtype=torch.float32).to(dtype).to(device) + self._test_addmm_addmv(func, M, m1, m2, beta=0, activation=activation) + + # Test transpose + for t1, t2, t3, t4 in itertools.product([True, False], repeat=4): + + def maybe_transpose(cond, m): + if not cond: + return m + return m.t().clone(memory_format=torch.contiguous_format).t() + + M = maybe_transpose(t1, torch.randn(10, 25, device=device).to(dtype)) + m1 = maybe_transpose(t2, torch.randn(10, 50, device=device).to(dtype)) + m2 = maybe_transpose(t3, torch.randn(50, 25, device=device).to(dtype)) + self._test_addmm_addmv( + func, M, m1, m2, transpose_out=t4, activation=activation + ) + + if t1: + # use vector V instead of matrix M for epilogue fusion in CUDA (doesn't depend on t1) + self._test_addmm_addmv( + func, + V, + m1, + m2, + beta=1, + transpose_out=t4, + activation=activation, + ) + + @precisionOverride( + { + torch.float: 1e-4, + torch.half: 1e-1, + } + ) + @dtypes(torch.float32, torch.half) + def test_addmm(self, device, dtype): + self._test_addmm_impl(torch.addmm, None, device, dtype) + + @precisionOverride({torch.bfloat16: 1e-0, torch.half: 1e-3, torch.float: 1e-4}) + @dtypes(torch.bfloat16, torch.half, torch.float) + def test_addmv(self, device, dtype): + # have to use torch.randn(...).to(bfloat16) instead of + # torch.randn(..., dtype=bfloat16). randn does not support + # bfloat16 yet. + # "*0.2" to reduce errors for low precision + ts = [ + 0.2 * torch.randn(50, device=device).to(dtype), + 0.2 * torch.randn(1, device=device).to(dtype).expand(50), + ] + vs = [ + 0.2 * torch.randn(100, device=device).to(dtype), + 0.2 + * torch.ones(1, device=device) + .to(dtype) + .expand(100), # to reduce errors for low precision + ] + ms = [ + # 0d + 0.2 + * torch.ones((), device=device) + .to(dtype) + .expand(50, 100), # to reduce errors for low precision + # 1d + 0.2 * torch.randn((1, 100), device=device).to(dtype).expand(50, 100), + # this initialization reduces errors for low precision for broadcasted matrices + # by making sure that intermediate and result values are exactly representable + # in low precision type + 0.2 + * torch.randint(3, (50, 1), dtype=torch.float, device=device) + .to(dtype) + .expand(50, 100), + # 2d + 0.2 * torch.randn((50, 100), device=device).to(dtype), + 0.2 * torch.randn((100, 50), device=device).to(dtype).t(), + ] + for m, v, t in itertools.product(ms, vs, ts): + self._test_addmm_addmv(torch.addmv, t, m, v) + # Test beta=0, t=nan + t = torch.full((50,), math.nan, device=device).to(dtype) + for m, v in itertools.product(ms, vs): + self._test_addmm_addmv(torch.addmv, t, m, v, beta=0) + + @dtypes( + torch.half, + torch.float32, + ) + def test_mm(self, device, dtype): + def _test_mm(n, m, p, dtype, genf): + # helper function + def matrixmultiply(mat1, mat2): + n = mat1.size(0) + m = mat1.size(1) + p = mat2.size(1) + dtype_ = torch.float if dtype == torch.half else dtype + if dtype == torch.half: + mat1 = mat1.float() + mat2 = mat2.float() + res = torch.zeros(n, p, dtype=dtype_, device=device) + for i, j in iter_indices(res): + res[i, j] = sum(mat1[i, k] * mat2[k, j] for k in range(m)) + return res.half() if dtype == torch.half 
else res + + # contiguous case + mat1 = genf(n, m) + mat2 = genf(m, p) + res = torch.mm(mat1, mat2) + + res2 = matrixmultiply(mat1, mat2) + self.assertEqual(res, res2) + + # non contiguous case 1 + mat1 = genf(n, m) + mat2 = genf(p, m).t() + res = torch.mm(mat1, mat2) + + res2 = matrixmultiply(mat1, mat2) + self.assertEqual(res, res2) + + # non contiguous case 2 + mat1 = genf(m, n).t() + mat2 = genf(m, p) + res = torch.mm(mat1, mat2) + + res2 = matrixmultiply(mat1, mat2) + self.assertEqual(res, res2) + + # non contiguous case 3 + mat1 = genf(m, n).t() + mat2 = genf(p, m).t() + res = torch.mm(mat1, mat2) + + res2 = matrixmultiply(mat1, mat2) + self.assertEqual(res, res2) + + # test with zero stride + mat1 = genf(n, m) + mat2 = genf(m, 1).expand(m, p) + res = torch.mm(mat1, mat2) + + res2 = matrixmultiply(mat1, mat2) + self.assertEqual(res, res2) + + # explicitly exercise the _out variant in torch.mm(). + # contiguous case + mat1 = genf(n, m) + mat2 = genf(m, p) + res = genf(n, p) + torch.mm(mat1, mat2, out=res) + + res2 = matrixmultiply(mat1, mat2) + self.assertEqual(res, res2) + + # explicitly exercise the _out variant in torch.mm(). + # non contiguous case 3 + mat1 = genf(m, n).t() + mat2 = genf(p, m).t() + res = genf(n, p) + torch.mm(mat1, mat2, out=res) + + res2 = matrixmultiply(mat1, mat2) + self.assertEqual(res, res2) + + def genf_int(x, y): + return torch.randint(0, 100, (x, y), dtype=dtype, device=device) + + def genf_bfloat(x, y): + return torch.randn(x, y, dtype=torch.float32, device=device).to(dtype) * 0.1 + + def genf_float(x, y): + return torch.randn(x, y, dtype=dtype, device=device) + + def genf_Half(x, y): + return torch.randn(x, y, dtype=dtype, device=device) + + for n, m, p in [(20, 10, 15), (15, 20, 10), (25, 18, 10)]: + if (dtype == torch.int32) or (dtype == torch.int64): + genf = genf_int + elif dtype == torch.bfloat16: + genf = genf_bfloat + elif dtype == torch.half: + genf = genf_Half + else: + genf = genf_float + + _test_mm(n, m, p, dtype, genf) + + @precisionOverride({torch.half: 0.05, torch.bfloat16: 0.05}) + @dtypes(torch.float32, torch.bfloat16, torch.half) + def test_bmm(self, device, dtype): + batch_sizes = [1, 10] + M, N, O = 23, 15, 12 + numpy_dtype = dtype if dtype != torch.bfloat16 else torch.float32 + + def invert_perm(p): + d = {x: i for i, x in enumerate(p)} + return (d[0], d[1], d[2]) + + def generate_inputs(num_batches): + # transposed tensors + for perm1, perm2 in itertools.product( + itertools.permutations((0, 1, 2)), repeat=2 + ): + b1 = make_tensor( + (num_batches, M, N), dtype=dtype, device=device, low=-0.1, high=0.1 + ) + b2 = make_tensor( + (num_batches, N, O), dtype=dtype, device=device, low=-0.1, high=0.1 + ) + b1 = b1.permute(perm1).contiguous().permute(invert_perm(perm1)) + b2 = b2.permute(perm2).contiguous().permute(invert_perm(perm2)) + yield b1, b2 + # broadcasting tensors + for b1, b2, b3, b4, b5, b6 in itertools.product((True, False), repeat=6): + shape1 = (num_batches if b1 else 1, M if b2 else 1, N if b3 else 1) + shape2 = (num_batches if b4 else 1, N if b5 else 1, O if b6 else 1) + b1 = make_tensor( + shape1, dtype=dtype, device=device, low=-0.1, high=0.1 + ).expand(num_batches, M, N) + b2 = make_tensor( + shape2, dtype=dtype, device=device, low=-0.1, high=0.1 + ).expand(num_batches, N, O) + yield b1, b2 + # zero-sized tensors + for z1, z2, z3, z4 in itertools.product((True, False), repeat=4): + shape1 = (num_batches if z1 else 0, M if z2 else 0, N if z3 else 0) + shape2 = (num_batches if z1 else 0, N if z3 else 0, O if z4 else 0) + 
b1 = torch.randn(shape1, dtype=dtype, device=device) + b2 = torch.randn(shape2, dtype=dtype, device=device) + yield b1, b2 + + for num_batches in batch_sizes: + for (b1, b2), perm3 in itertools.product( + generate_inputs(num_batches), itertools.permutations((0, 1, 2)) + ): + res1 = torch.bmm(b1, b2) + res2 = ( + torch.full( + (num_batches, M, O), math.nan, dtype=dtype, device=device + ) + .permute(perm3) + .contiguous() + .permute(invert_perm(perm3)) + ) + torch.bmm(b1, b2, out=res2) + expect = torch.from_numpy( + b1.to(numpy_dtype).cpu().numpy() @ b2.to(numpy_dtype).cpu().numpy() + ).to(device=device, dtype=dtype) + self.assertEqual(expect, res1) + self.assertEqual(expect, res2) + + if self.device_type == "cuda": + # check that mixed arguments are rejected + self.assertRaises(RuntimeError, lambda: torch.bmm(b1, b2.cpu())) + self.assertRaises(RuntimeError, lambda: torch.bmm(b1.cpu(), b2)) + self.assertRaises( + RuntimeError, lambda: torch.bmm(b1, b2, out=res2.cpu()) + ) + + def _test_addbmm_baddbmm(self, func, b1, b2, ref, out_tensor): + getattr(out_tensor, func + "_")(b1, b2) + self.assertEqual(out_tensor, ref) + res3 = out_tensor.clone() + + with self.assertWarnsOnceRegex( + UserWarning, f"This overload of {func}_ is deprecated" + ): + getattr(out_tensor, func + "_")(1, b1, b2) + self.assertEqual(out_tensor, ref * 2), + getattr(res3, func + "_")(b1, b2, beta=1) + self.assertEqual(out_tensor, res3) + + with self.assertWarnsOnceRegex( + UserWarning, f"This overload of {func}_ is deprecated" + ): + getattr(out_tensor, func + "_")(1.0, 0.5, b1, b2) + self.assertEqual(out_tensor, ref * 2.5) + getattr(res3, func + "_")(b1, b2, beta=1.0, alpha=0.5) + self.assertEqual(out_tensor, res3) + + with self.assertWarnsOnceRegex( + UserWarning, f"This overload of {func} is deprecated" + ): + self.assertEqual(out_tensor, getattr(torch, func)(1, out_tensor, 0, b1, b2)) + + res4 = getattr(torch, func)(out_tensor, b1, b2, beta=1, alpha=0.5) + self.assertEqual(res4, ref * 3), + + nan = torch.full_like(out_tensor, math.nan) + res5 = getattr(torch, func)(nan, b1, b2, beta=0, alpha=1) + self.assertEqual(res5, ref) + + if b1.is_complex(): + res6 = getattr(torch, func)(out_tensor, b1, b2, beta=0.1j, alpha=0.5j) + self.assertEqual(res6, out_tensor * 0.1j + 0.5j * ref) + else: + res6 = getattr(torch, func)(out_tensor, b1, b2, beta=0.1, alpha=0.5) + self.assertEqual(res6, out_tensor * 0.1 + 0.5 * ref) + + res7 = torch.full_like(out_tensor, math.nan) + getattr(torch, func)(nan, b1, b2, beta=0, out=res7) + self.assertEqual(res7, ref) + + @precisionOverride({torch.half: 0.05, torch.bfloat16: 0.05}) + @dtypes(torch.float32, torch.bfloat16, torch.half) + def test_addbmm(self, device, dtype): + num_batches = 2 + M, N, O = 16, 17, 18 + + is_supported = True + + if not is_supported: + b1 = make_tensor( + (num_batches, M, N), dtype=dtype, device=device, low=-1, high=1 + ) + b2 = make_tensor( + (num_batches, N, O), dtype=dtype, device=device, low=-1, high=1 + ) + t = make_tensor((M, O), dtype=dtype, device=device, low=-1, high=1) + self.assertRaisesRegex( + RuntimeError, + "type|Type|not implemented|CUBLAS_STATUS_NOT_SUPPORTED", + lambda: torch.addbmm(t, b1, b2), + ) + return + + def invert_perm(p): + d = {x: i for i, x in enumerate(p)} + return (d[0], d[1], d[2]) + + def generate_tensor(): + numpy_dtype = dtype if dtype != torch.bfloat16 else torch.float32 + # transposed tensors + for perm1, perm2 in itertools.product( + itertools.permutations((0, 1, 2)), repeat=2 + ): + for perm3 in itertools.permutations((0, 1)): + b1 = ( + 
make_tensor( + (num_batches, M, N), + dtype=dtype, + device=device, + low=-1, + high=1, + ) + * 0.1 + ) + b2 = ( + make_tensor( + (num_batches, N, O), + dtype=dtype, + device=device, + low=-1, + high=1, + ) + * 0.1 + ) + b1 = b1.permute(perm1).contiguous().permute(invert_perm(perm1)) + b2 = b2.permute(perm2).contiguous().permute(invert_perm(perm2)) + ref = ( + torch.from_numpy( + b1.to(numpy_dtype).cpu().numpy() + @ b2.to(numpy_dtype).cpu().numpy() + ) + .to(device=device, dtype=dtype) + .sum(0) + ) + out_tensor = ( + torch.zeros_like(ref).permute(perm3).contiguous().permute(perm3) + ) + yield b1, b2, ref, out_tensor + # broadcasting tensors + for s1, s2, s3, s4, s5, s6 in itertools.product((True, False), repeat=6): + shape1 = (num_batches if s1 else 1, M if s2 else 1, N if s3 else 1) + shape2 = (num_batches if s4 else 1, N if s5 else 1, O if s6 else 1) + b1 = ( + make_tensor( + shape1, dtype=dtype, device=device, low=-1, high=1 + ).expand(num_batches, M, N) + * 0.1 + ) + b2 = ( + make_tensor( + shape2, dtype=dtype, device=device, low=-1, high=1 + ).expand(num_batches, N, O) + * 0.1 + ) + ref = ( + torch.from_numpy( + b1.to(numpy_dtype).cpu().numpy() + @ b2.to(numpy_dtype).cpu().numpy() + ) + .to(device=device, dtype=dtype) + .sum(0) + ) + out_tensor = torch.zeros_like(ref) + yield b1, b2, ref, out_tensor + # zero-sized tensors + for z1, z2, z3, z4 in itertools.product((True, False), repeat=4): + shape1 = (num_batches if z1 else 0, M if z2 else 0, N if z3 else 0) + shape2 = (num_batches if z1 else 0, N if z3 else 0, O if z4 else 0) + b1 = ( + make_tensor(shape1, dtype=dtype, device=device, low=-1, high=1) + * 0.1 + ) + b2 = ( + make_tensor(shape2, dtype=dtype, device=device, low=-1, high=1) + * 0.1 + ) + ref = ( + torch.from_numpy( + b1.to(numpy_dtype).cpu().numpy() + @ b2.to(numpy_dtype).cpu().numpy() + ) + .to(device=device, dtype=dtype) + .sum(0) + ) + out_tensor = torch.zeros_like(ref) + yield b1, b2, ref, out_tensor + + for b1, b2, ref, out_tensor in generate_tensor(): + self._test_addbmm_baddbmm("addbmm", b1, b2, ref, out_tensor) + + @precisionOverride({torch.half: 0.1, torch.bfloat16: 0.5}) + @dtypes(torch.float32, torch.bfloat16, torch.half) + def test_baddbmm(self, device, dtype): + num_batches = 10 + M, N, O = 12, 8, 50 + + def invert_perm(p): + d = {x: i for i, x in enumerate(p)} + return (d[0], d[1], d[2]) + + def generate_tensor(): + numpy_dtype = ( + dtype if dtype not in [torch.bfloat16, torch.half] else torch.float32 + ) + # transposed tensors + for perm1, perm2, perm3 in itertools.product( + itertools.permutations((0, 1, 2)), repeat=3 + ): + b1 = make_tensor( + (num_batches, M, N), dtype=dtype, device=device, low=-1, high=1 + ) + b2 = make_tensor( + (num_batches, N, O), dtype=dtype, device=device, low=-1, high=1 + ) + b1 = b1.permute(perm1).contiguous().permute(invert_perm(perm1)) + b2 = b2.permute(perm2).contiguous().permute(invert_perm(perm2)) + ref = torch.from_numpy( + b1.to(numpy_dtype).cpu().numpy() @ b2.to(numpy_dtype).cpu().numpy() + ).to(device=device, dtype=dtype) + out_tensor = torch.zeros_like(ref) + out_tensor = ( + out_tensor.permute(perm3).contiguous().permute(invert_perm(perm3)) + ) + yield b1, b2, ref, out_tensor + # broadcasting tensors + for s1, s2, s3, s4, s5, s6 in itertools.product((True, False), repeat=6): + shape1 = (num_batches if s1 else 1, M if s2 else 1, N if s3 else 1) + shape2 = (num_batches if s4 else 1, N if s5 else 1, O if s6 else 1) + b1 = make_tensor( + shape1, dtype=dtype, device=device, low=-1, high=1 + ).expand(num_batches, M, N) + b2 
= make_tensor( + shape2, dtype=dtype, device=device, low=-1, high=1 + ).expand(num_batches, N, O) + ref = torch.from_numpy( + b1.to(numpy_dtype).cpu().numpy() @ b2.to(numpy_dtype).cpu().numpy() + ).to(device=device, dtype=dtype) + out_tensor = torch.zeros_like(ref) + yield b1, b2, ref, out_tensor + # zero-sized tensors + for z1, z2, z3, z4 in itertools.product((True, False), repeat=4): + shape1 = (num_batches if z1 else 0, M if z2 else 0, N if z3 else 0) + shape2 = (num_batches if z1 else 0, N if z3 else 0, O if z4 else 0) + b1 = make_tensor(shape1, dtype=dtype, device=device, low=-2, high=2) + b2 = make_tensor(shape2, dtype=dtype, device=device, low=-2, high=2) + ref = torch.from_numpy( + b1.to(numpy_dtype).cpu().numpy() @ b2.to(numpy_dtype).cpu().numpy() + ).to(device=device, dtype=dtype) + out_tensor = torch.zeros_like(ref) + yield b1, b2, ref, out_tensor + + for b1, b2, ref, out_tensor in generate_tensor(): + self._test_addbmm_baddbmm("baddbmm", b1, b2, ref, out_tensor) + + def test_tensordot(self, device): + a = torch.arange(60.0, device=device).reshape(3, 4, 5) + b = torch.arange(24.0, device=device).reshape(4, 3, 2) + c = torch.tensordot(a, b, dims=([1, 0], [0, 1])).cpu() + cn = torch.from_numpy( + np.tensordot(a.cpu().numpy(), b.cpu().numpy(), axes=([1, 0], [0, 1])) + ) + self.assertEqual(c, cn) + + cout = torch.zeros((5, 2), device=device) + torch.tensordot(a, b, dims=([1, 0], [0, 1]), out=cout).cpu() + self.assertEqual(c, cout) + + a = torch.randn(2, 3, 4, 5, device=device) + b = torch.randn(4, 5, 6, 7, device=device) + c = torch.tensordot(a, b, dims=2).cpu() + cn = torch.from_numpy(np.tensordot(a.cpu().numpy(), b.cpu().numpy(), axes=2)) + + with self.assertRaisesRegex(RuntimeError, "expects dims >= 0"): + torch.tensordot(a, b, dims=-1) + + self.assertEqual(c, cn) + c = torch.tensordot(a, b).cpu() + cn = torch.from_numpy(np.tensordot(a.cpu().numpy(), b.cpu().numpy())) + self.assertEqual(c, cn) + + a = torch.tensordot(torch.tensor(0.0), torch.tensor(0.0), 0) + an = torch.from_numpy( + np.tensordot( + np.zeros((), dtype=np.float32), np.zeros((), dtype=np.float32), 0 + ) + ) + self.assertEqual(a, an) + + @dtypes(torch.float) + @precisionOverride({torch.float32: 1e-4}) + def test_1_sized_with_0_strided(self, device, dtype): + a = make_tensor((8, 1, 64), dtype=dtype, device=device) + a_strided = torch.as_strided(a, size=[8, 1, 64], stride=[64, 0, 1]) + b = make_tensor((8, 64, 512), dtype=dtype, device=device) + b_strided = torch.as_strided(b, size=[8, 64, 512], stride=[64, 1, 512]) + res = torch.bmm(a_strided, b_strided) + expect = torch.from_numpy(a_strided.cpu().numpy() @ b_strided.cpu().numpy()).to( + device=device, dtype=dtype + ) + self.assertEqual(expect, res) + + def _select_broadcastable_dims(self, dims_full=None): + # select full dimensionality + if dims_full is None: + dims_full = [] + ndims = random.randint(1, 4) + dims_full = [random.randint(1, 8) for _ in range(ndims)] + else: + ndims = len(dims_full) + + # select actual dimensions for ops: + # larger: full ndims, individual sizes may be reduced + # smaller: possibly reduced ndims, sizes may be reduced + smaller_ndims = random.randint(1, ndims) + dims_small = [] + dims_large = [] + for i in range(ndims - 1, -1, -1): + j = random.randint(1, 3) + if j == 1: # no reduced singleton dimension + ds = dims_full[i] + dl = dims_full[i] + elif j == 2: # larger may have reduced singleton dimension + ds = dims_full[i] + dl = 1 if len(dims_small) < smaller_ndims else dims_full[i] + elif j == 3: # smaller may have reduced singleton 
dimension + ds = 1 + dl = dims_full[i] + dims_large = [dl] + dims_large + if len(dims_small) < smaller_ndims: + dims_small = [ds] + dims_small + return (dims_small, dims_large, dims_full) + + def test_broadcast_fused_matmul(self, device): + fns = ["baddbmm", "addbmm", "addmm", "addmv", "addr"] + + for fn in fns: + batch_dim = random.randint(1, 8) + n_dim = random.randint(1, 8) + m_dim = random.randint(1, 8) + p_dim = random.randint(1, 8) + + def dims_full_for_fn(): + if fn == "baddbmm": + return ( + [batch_dim, n_dim, p_dim], + [batch_dim, n_dim, m_dim], + [batch_dim, m_dim, p_dim], + ) + elif fn == "addbmm": + return ( + [n_dim, p_dim], + [batch_dim, n_dim, m_dim], + [batch_dim, m_dim, p_dim], + ) + elif fn == "addmm": + return ([n_dim, p_dim], [n_dim, m_dim], [m_dim, p_dim]) + elif fn == "addmv": + return ([n_dim], [n_dim, m_dim], [m_dim]) + elif fn == "addr": + return ([n_dim, m_dim], [n_dim], [m_dim]) + else: + raise AssertionError("unknown function") + + (t0_dims_full, t1_dims, t2_dims) = dims_full_for_fn() + (t0_dims_small, _, _) = self._select_broadcastable_dims(t0_dims_full) + + t0_small = torch.randn(*t0_dims_small, device=device).float() + t1 = torch.randn(*t1_dims, device=device).float() + t2 = torch.randn(*t2_dims, device=device).float() + + t0_full = t0_small.expand(*t0_dims_full).to(device) + + fntorch = getattr(torch, fn) + r0 = fntorch(t0_small, t1, t2) + r1 = fntorch(t0_full, t1, t2) + self.assertEqual(r0, r1) + + @dtypes(torch.float32) + def test_strided_mm_bmm(self, device, dtype): + # Tests strided view case with stride smaller than corresponding dimension size + x = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], dtype=dtype, device=device) + new_shape = [2, 2, 2] + new_stride = [3, 1, 1] + sx = torch.as_strided(x, size=new_shape, stride=new_stride) + + torch_fn = lambda x: torch.bmm(x, x) # noqa: E731 + np_fn = lambda x: np.matmul(x, x) # noqa: E731 + self.compare_with_numpy(torch_fn, np_fn, sx) + + torch_fn = lambda x: torch.mm(x, x) # noqa: E731 + self.compare_with_numpy(torch_fn, np_fn, sx[0]) + + def test_mm_empty_inputs_mixed_dtype_errors(self, device): + a = torch.randint(0, 10, [1, 10], dtype=torch.int16, device=device) + b = torch.randn(10, 20, dtype=torch.float32, device=device) + with self.assertRaisesRegex( + RuntimeError, "expected .* and .* to have the same dtype, but got:" + ): + torch.mm(a, b) + + def test_matmul_45724(self, device): + # https://github.com/pytorch/pytorch/issues/45724 + a = torch.rand(65537, 22, 64, device=device, dtype=torch.half) + b = torch.rand(65537, 64, 22, device=device, dtype=torch.half) + c = torch.full((65537, 22, 22), math.nan, dtype=torch.half, device=device) + cpu_result = torch.matmul(a.cpu().float(), b.cpu().float()).half() + torch.matmul(a, b, out=c) + self.assertEqual(c, cpu_result) + + @dtypes( + torch.int16, + torch.int32, + torch.int64, + torch.float16, + torch.float32, + torch.float64, + ) + def test_baddbmm_input_dtypes_compatibility(self, device, dtype): + batch1 = torch.rand((1, 2, 2), dtype=torch.float32, device=device) + batch2 = torch.rand((1, 2, 2), dtype=torch.float32, device=device) + input_tensor = torch.rand((1, 2, 2), device=device).to(dtype) + if dtype != torch.float32: + with self.assertRaisesRegex(RuntimeError, "Input dtypes must be the same"): + y = torch.baddbmm(input_tensor, batch1, batch2, beta=0.0) + else: + out = torch.randn((1, 2, 2), dtype=dtype, device=device).fill_(torch.nan) + y_ref = torch.bmm(batch1, batch2) + y = torch.baddbmm(input_tensor, batch1, batch2, beta=0.0, out=out) + 
self.assertEqual(out, y_ref) + + @dtypes(torch.float) + def test_baddbmm_nan_input_with_zero_beta(self, device, dtype): + for shape in [[3, 2, 2], [2, 20, 20]]: + mat1, mat2 = ( + torch.randn(shape, dtype=dtype, device=device) for _ in range(2) + ) + inputs = [ + torch.randn(shape, dtype=dtype, device=device), + torch.randn(shape, dtype=dtype, device=device).fill_(torch.nan), + ] + outs = [ + None, + torch.randn(shape, dtype=dtype, device=device), + torch.randn(shape, dtype=dtype, device=device).fill_(torch.nan), + ] + options = itertools.product(inputs, outs) + for input, out in options: + y_ref = torch.bmm(mat1, mat2) + y = torch.baddbmm(input, mat1, mat2, beta=0.0, out=out) + self.assertEqual(y_ref, y) + + @dtypes(torch.float) + def test_addmm_sizes(self, device, dtype): + for m in [0, 1, 25]: + for n in [0, 1, 10]: + for k in [0, 1, 8]: + M = torch.randn(n, m, device=device).to(dtype) + m1 = torch.randn(n, k, device=device).to(dtype) + m2 = torch.randn(k, m, device=device).to(dtype) + self._test_addmm_addmv(torch.addmm, M, m1, m2) + + m1 = torch.randn(n, k + 1, device=device).to(dtype) + m2 = torch.randn(k, m, device=device).to(dtype) + self.assertRaisesRegex( + RuntimeError, + f"{n}x{k + 1}.*{k}x{m}", + lambda: torch.addmm(M, m1, m2), + ) + self.assertRaisesRegex( + RuntimeError, f"{n}x{k + 1}.*{k}x{m}", lambda: torch.mm(m1, m2) + ) + + @precisionOverride( + { + torch.double: 1e-8, + torch.float: 1e-4, + torch.bfloat16: 5e-2, + torch.half: 5e-2, + torch.cfloat: 1e-4, + torch.cdouble: 1e-8, + } + ) + @dtypes(torch.float32, torch.bfloat16, torch.half) + def test_addmm_gelu(self, device, dtype): + self._test_addmm_impl(torch._addmm_activation, "gelu", device, dtype) + + @precisionOverride( + { + torch.double: 1e-8, + torch.float: 1e-4, + torch.bfloat16: 5e-2, + torch.half: 5e-2, + torch.cfloat: 1e-4, + torch.cdouble: 1e-8, + } + ) + @dtypes(torch.float32, torch.bfloat16, torch.half) + def test_addmm_relu(self, device, dtype): + self._test_addmm_impl(torch._addmm_activation, "relu", device, dtype) + + @dtypes(torch.float, torch.bfloat16, torch.half) + def test_addmv_rowmajor_colmajor_incx_incy_lda(self, device, dtype): + # tests (o, s)*(s). o is output size, s is summed size. 
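+ # Note: the storages built in _test below are over-allocated and filled with NaN; only the strided + # views (with the requested incx/incy/lda_tail padding) are copied with real data, so an addmv kernel + # that mishandled the strides or leading dimension would read NaN and fail the comparison.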
+ o = 5 + s = 3 + a_data = torch.arange(1, o * s + 1, device=device, dtype=dtype).view(o, s) + x_data = torch.arange(1, s + 1, 1, device=device, dtype=dtype) + y_data = torch.ones(o, device=device, dtype=dtype) + control = torch.tensor( + [15.0, 33.0, 51.0, 69.0, 87.0], device=device, dtype=dtype + ) + + def _test(row_major, incx, incy, lda_tail): + if row_major: + a_storage = torch.full( + (o, s + lda_tail), float("nan"), device=device, dtype=dtype + ) + else: + a_storage = torch.full( + (s, o + lda_tail), float("nan"), device=device, dtype=dtype + ).permute(1, 0) + a = a_storage[:o, :s].copy_(a_data) + + x_storage = torch.full((s, incx), float("nan"), device=device, dtype=dtype) + x = x_storage[:, 0].copy_(x_data) + + y_storage = torch.full((o, incy), float("nan"), device=device, dtype=dtype) + y = y_storage[:, 0].copy_(y_data) + + self._test_addmm_addmv(torch.addmv, y, a, x) + + for row_major, incx, incy, lda_tail in itertools.product( + (False, True), (1, 2), (1, 2), (0, 1) + ): + _test(row_major, incx, incy, lda_tail) + + @precisionOverride( + { + torch.double: 1e-8, + torch.float: 1e-4, + torch.bfloat16: 0.6, + torch.half: 1e-1, + torch.cfloat: 1e-4, + torch.cdouble: 1e-8, + } + ) + @dtypes(torch.bfloat16, torch.half, torch.float32) + def test_corner_cases_of_cublasltmatmul(self, device, dtype): + # common case + M = torch.randn(128, device=device).to(dtype) + m1 = torch.randn(2048, 2400, device=device).to(dtype) + m2 = torch.randn(128, 2400, device=device).to(dtype) + torch.nn.functional.linear(m1, m2, M) + # Ntrans_B has ld >> rows + m1 = torch.rand([128, 2400]).to(dtype).to(device).t() + m2 = torch.rand([2048, 25272]).to(dtype).to(device).t()[21940:24340] + M = torch.rand([128]).to(dtype).to(device) + torch.addmm(M, m2.t(), m1) + # trans_A has ld >> rows + m1 = torch.rand([128, 25272]).to(dtype).to(device)[:, 21940:24340].t() + m2 = torch.randn(2048, 2400, device=device).to(dtype) + M = torch.rand([128]).to(dtype).to(device) + torch.addmm(M, m2, m1) + # large tensor dim > 65535 + M = torch.randn(16, device=device).to(dtype) + m1 = torch.randn(32, 131071, device=device).to(dtype) + m2 = torch.randn(16, 131071, device=device).to(dtype) + torch.nn.functional.linear(m1, m2, M) + + def test_blas_empty(self, device): + def fn(torchfn, *args, test_out=False, **kwargs): + def call_torch_fn(*args, **kwargs): + return torchfn( + *tuple( + torch.randn(shape, device=device) + if isinstance(shape, tuple) + else shape + for shape in args + ), + **kwargs, + ) + + result = call_torch_fn(*args, **kwargs) + if not test_out: + return result + else: + out = torch.full_like(result, math.nan) + out1 = call_torch_fn(*args, **kwargs, out=out) + return out + + # mm, addmm + self.assertEqual((0, 0), fn(torch.mm, (0, 0), (0, 0)).shape) + self.assertEqual((0, 5), fn(torch.mm, (0, 0), (0, 5)).shape) + self.assertEqual((5, 0), fn(torch.mm, (5, 0), (0, 0)).shape) + self.assertEqual((3, 0), fn(torch.mm, (3, 2), (2, 0)).shape) + self.assertEqual( + torch.zeros((5, 6), device=device), fn(torch.mm, (5, 0), (0, 6)) + ) + self.assertEqual( + torch.zeros((5, 6), device=device), + fn(torch.mm, (5, 0), (0, 6), test_out=True), + ) + + self.assertEqual((0, 0), fn(torch.addmm, (0, 0), (0, 0), (0, 0)).shape) + self.assertEqual((0, 1), fn(torch.addmm, (1,), (0, 17), (17, 1)).shape) + t = torch.randn((5, 6), device=device) + self.assertEqual(t, fn(torch.addmm, t, (5, 0), (0, 6))) + self.assertEqual(t, fn(torch.addmm, t, (5, 0), (0, 6), test_out=True)) + + # mv, addmv + self.assertEqual((0,), fn(torch.mv, (0, 0), (0,)).shape) 
+ self.assertEqual((0,), fn(torch.mv, (0, 2), (2,)).shape) + self.assertEqual(torch.zeros((3,), device=device), fn(torch.mv, (3, 0), (0,))) + self.assertEqual( + torch.zeros((3,), device=device), fn(torch.mv, (3, 0), (0,), test_out=True) + ) + + self.assertEqual((0,), fn(torch.addmv, (0,), (0, 0), (0,)).shape) + t = torch.randn((3,), device=device) + self.assertEqual(t, fn(torch.addmv, t, (3, 0), (0,))) + self.assertEqual(t, fn(torch.addmv, t, (3, 0), (0,), test_out=True)) + + # bmm, baddbmm + self.assertEqual((0, 0, 0), fn(torch.bmm, (0, 0, 0), (0, 0, 0)).shape) + self.assertEqual((3, 0, 5), fn(torch.bmm, (3, 0, 0), (3, 0, 5)).shape) + self.assertEqual((0, 5, 6), fn(torch.bmm, (0, 5, 0), (0, 0, 6)).shape) + self.assertEqual( + torch.zeros((3, 5, 6), device=device), fn(torch.bmm, (3, 5, 0), (3, 0, 6)) + ) + self.assertEqual( + torch.zeros((3, 5, 6), device=device), + fn(torch.bmm, (3, 5, 0), (3, 0, 6), test_out=True), + ) + + self.assertEqual( + (0, 0, 0), fn(torch.baddbmm, (0, 0, 0), (0, 0, 0), (0, 0, 0)).shape + ) + self.assertEqual( + (3, 0, 5), fn(torch.baddbmm, (3, 0, 5), (3, 0, 0), (3, 0, 5)).shape + ) + self.assertEqual( + (0, 5, 6), fn(torch.baddbmm, (0, 5, 6), (0, 5, 0), (0, 0, 6)).shape + ) + self.assertEqual( + (3, 5, 6), fn(torch.baddbmm, (3, 5, 6), (3, 5, 0), (3, 0, 6)).shape + ) + c = torch.arange(30, dtype=torch.float32, device=device).reshape(3, 2, 5) + self.assertEqual( + -2 * c, fn(torch.baddbmm, c, (3, 2, 0), (3, 0, 5), beta=-2) + ) # Issue #33467 + self.assertEqual( + -2 * c, fn(torch.baddbmm, c, (3, 2, 0), (3, 0, 5), beta=-2, test_out=True) + ) # Issue #33467 + + # addbmm + self.assertEqual((0, 0), fn(torch.addbmm, (0, 0), (0, 0, 0), (0, 0, 0)).shape) + self.assertEqual((0, 5), fn(torch.addbmm, (0, 5), (3, 0, 0), (3, 0, 5)).shape) + t = torch.randn((5, 6), device=device) + self.assertEqual(t, fn(torch.addbmm, t, (0, 5, 0), (0, 0, 6))) + self.assertEqual(t, fn(torch.addbmm, t, (0, 5, 0), (0, 0, 6), test_out=True)) + + # matmul + self.assertEqual(torch.tensor(0.0, device=device), fn(torch.matmul, (0,), (0,))) + self.assertEqual( + torch.tensor(0.0, device=device), + fn(torch.matmul, (0,), (0,), test_out=True), + ) + self.assertEqual((0, 0), fn(torch.matmul, (0, 0), (0, 0)).shape) + self.assertEqual((0, 0, 0), fn(torch.matmul, (0, 0, 0), (0, 0, 0)).shape) + self.assertEqual((5, 0, 0), fn(torch.matmul, (5, 0, 0), (5, 0, 0)).shape) + self.assertEqual( + torch.zeros((5, 3, 4), device=device), + fn(torch.matmul, (5, 3, 0), (5, 0, 4)), + ) + self.assertEqual( + torch.zeros((5, 3, 4), device=device), + fn(torch.matmul, (5, 3, 0), (5, 0, 4), test_out=True), + ) + + # dot + self.assertEqual(torch.tensor(0.0, device=device), fn(torch.dot, (0,), (0,))) + self.assertEqual( + torch.tensor(0.0, device=device), fn(torch.dot, (0,), (0,), test_out=True) + ) + + def test_large_bmm_backward(self, device): + A = torch.randn([1024, 2, 1024], device=device).mT.contiguous().mT + B = torch.randn([1, 1024, 65536], device=device, requires_grad=True) + G = torch.randn([1024, 2, 65536], device=device) + + # Should not create an intermediary tensor of size [1024, 1024, 65536] (256GB of memory) and OOM + (A @ B).backward(G) + + def test_large_bmm_mm_backward(self, device): + A = torch.randn([1024, 2, 1024], device=device).mT.contiguous().mT + B = torch.randn([1024, 65536], device=device, requires_grad=True) + G = torch.randn([1024, 2, 65536], device=device) + + # Should not create an intermediary tensor of size [1024, 1024, 65536] (256GB of memory) and OOM + (A @ B).backward(G) + + def 
check_single_matmul(self, x, y): + def assertEqual(answer, expected): + if x.dtype.is_floating_point or x.dtype.is_complex: + k = max(x.shape[-1], 1) # Scale the atol with the size of the matrix + self.assertEqual( + answer, + expected, + msg=f"{x.shape} x {y.shape} = {answer.shape}", + atol=k * 5e-5, + rtol=1e-4, + ) + else: + self.assertEqual( + answer, expected, msg=f"{x.shape} x {y.shape} = {answer.shape}" + ) + + # test x @ y + expected = np.matmul(x.cpu(), y.cpu()) + ans = torch.matmul(x, y) + self.assertTrue(ans.is_contiguous()) + assertEqual(ans, expected) + + # test out + out = torch.empty_like(ans) + ans = torch.matmul(x, y, out=out) + self.assertIs(ans, out) + self.assertTrue(ans.is_contiguous()) + assertEqual(ans, expected) + + def gen_sizes_matmul(self, x_dim, y_dim=4, matrix_size=4, batch_size=3): + """ + Generates sequences of tuples (x, y) of with size(x) = x_dim and + size(y) <= y_dim that are compatible wrt. matmul + """ + assert x_dim >= 1 + assert y_dim >= 2 + x = x_dim + for y in range(1, y_dim + 1): + for batch, mn in product( + product(range(batch_size), repeat=max(x - 2, y - 2, 0)), + product(range(matrix_size), repeat=min(y, 2)), + ): + if x == 1: + size_x = mn[:1] + size_y = batch + mn + yield size_x, size_y + else: + for k in range(matrix_size): + size_x = (k,) + mn[:1] + if x > 2: + size_x = batch[-(x - 2) :] + size_x + size_y = mn + if y > 2: + size_y = batch[-(y - 2) :] + size_y + yield size_x, size_y + + @dtypes(torch.float) + def test_matmul_small_brute_force_1d_Nd(self, device, dtype): + make_arg = partial(make_tensor, device=device, dtype=dtype) + + for (size_x, size_y), nctg_x, nctg_y in product( + self.gen_sizes_matmul(1), (True, False), (True, False) + ): + x = make_arg(size_x, noncontiguous=nctg_x) + y = make_arg(size_y, noncontiguous=nctg_y) + self.check_single_matmul(x, y) + + @dtypes(torch.float) + def test_matmul_small_brute_force_2d_Nd(self, device, dtype): + make_arg = partial(make_tensor, device=device, dtype=dtype) + + for (size_x, size_y), nctg_x, nctg_y in product( + self.gen_sizes_matmul(2), (True, False), (True, False) + ): + x = make_arg(size_x, noncontiguous=nctg_x) + y = make_arg(size_y, noncontiguous=nctg_y) + self.check_single_matmul(x, y) + + @dtypes(torch.float) + def test_matmul_small_brute_force_3d_Nd(self, device, dtype): + make_arg = partial(make_tensor, device=device, dtype=dtype) + + for (size_x, size_y), nctg_x, nctg_y in product( + self.gen_sizes_matmul(3), (True, False), (True, False) + ): + x = make_arg(size_x, noncontiguous=nctg_x) + y = make_arg(size_y, noncontiguous=nctg_y) + self.check_single_matmul(x, y) + + @dtypes(torch.float) + def test_matmul_out_kernel_errors_with_autograd(self, device, dtype): + a = torch.empty( + (256, 512), device=device, dtype=dtype, requires_grad=True + ).unsqueeze(0) + b = torch.empty( + (4, 128, 512), device=device, dtype=dtype, requires_grad=True + ).transpose(-1, -2) + c = torch.empty((256, 4, 128), device=device, dtype=dtype).movedim(1, 0) + + torch.matmul(a.detach(), b.detach(), out=c) + + with self.assertRaisesRegex( + RuntimeError, + "functions with out=... arguments don't support automatic differentiation", + ): + torch.matmul(a, b, out=c) + + with torch.no_grad(): + torch.matmul(a, b, out=c) + + +instantiate_device_type_tests(TestBasicGEMM, globals(), only_for="xpu") + +if __name__ == "__main__": + run_tests()
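For reference, the following is a minimal usage sketch (not part of the patch) of the GEMM operators this PR registers for the XPU dispatch key. It assumes a PyTorch build with `USE_XPU=1`, an Intel GPU runtime, and the `torch.xpu` runtime support this PR depends on (#116019); the `torch.xpu.is_available()` check is an assumption about that runtime API.

import torch

# Hypothetical availability check; assumes the torch.xpu runtime module is present.
if torch.xpu.is_available():
    device = "xpu"
    m1 = torch.randn(8, 16, device=device)
    m2 = torch.randn(16, 4, device=device)
    bias = torch.randn(8, 4, device=device)

    # These calls should reach the XPU kernels registered in Blas.cpp
    # (mm, addmm.out, baddbmm, ...), which lower to onednn::matmul with post-ops.
    out = torch.mm(m1, m2)
    fused = torch.addmm(bias, m1, m2, beta=0.5, alpha=2.0)  # 0.5*bias + 2.0*(m1 @ m2)

    b1 = torch.randn(3, 8, 16, device=device)
    b2 = torch.randn(3, 16, 4, device=device)
    binput = torch.randn(3, 8, 4, device=device)
    batched = torch.baddbmm(binput, b1, b2, beta=1.0, alpha=1.0)

    # Double (and complex) matmul is expected to raise, since the oneDNN path rejects it.
    try:
        torch.mm(m1.double(), m2.double())
    except RuntimeError as err:
        print("expected error:", err)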