Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 12:54:11 +08:00)
[Intel GPU] format XPU oneDNN integration codes (#139721)
# Motivation

This PR adds the XPU oneDNN integration code to the lintrunner config `.lintrunner.toml`, so that the C++ sources and headers under `aten/src/ATen/native/mkldnn/xpu/` and `aten/src/ATen/native/mkldnn/xpu/detail/` are formatted by the linter (a sketch of the resulting config entry follows the commit metadata below).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/139721
Approved by: https://github.com/guangyey, https://github.com/cyyever, https://github.com/EikanWang, https://github.com/Skylion007, https://github.com/malfet
Committed by: PyTorch MergeBot
Parent: 4e487eda7a
Commit: e21ee6327d
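For reference, a minimal sketch of what one affected `.lintrunner.toml` entry looks like after this change, assuming the first hunk below belongs to the CLANGFORMAT linter entry; the linter name and the neighboring patterns are illustrative, and only the two `mkldnn/xpu` globs are the lines actually added by this PR:

```toml
# Hypothetical excerpt of .lintrunner.toml after this PR (linter name assumed).
[[linter]]
code = 'CLANGFORMAT'
include_patterns = [
    'aten/src/ATen/native/cudnn/*.h',
    'aten/src/ATen/native/cudnn/*.cpp',
    'aten/src/ATen/native/mkldnn/xpu/**/*.h',   # added by this PR
    'aten/src/ATen/native/mkldnn/xpu/**/*.cpp', # added by this PR
    'c10/**/*.h',
]
```

With the patterns in place, running something like `lintrunner -a --take CLANGFORMAT` over the XPU files applies the formatting shown in the hunks below.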
@@ -68,6 +68,8 @@ include_patterns = [
'aten/src/ATen/native/cuda/Fused*.cu',
'aten/src/ATen/native/cudnn/*.h',
'aten/src/ATen/native/cudnn/*.cpp',
'aten/src/ATen/native/mkldnn/xpu/**/*.h',
'aten/src/ATen/native/mkldnn/xpu/**/*.cpp',
'c10/**/*.h',
'c10/**/*.cpp',
'torch/csrc/**/*.h',

@@ -200,6 +202,8 @@ include_patterns = [
'aten/src/ATen/core/*.cpp',
'aten/src/ATen/cudnn/*.h',
'aten/src/ATen/cudnn/*.cpp',
'aten/src/ATen/native/mkldnn/xpu/**/*.h',
'aten/src/ATen/native/mkldnn/xpu/**/*.cpp',
'aten/src/ATen/detail/*',
'aten/src/ATen/functorch/*.h',
'aten/src/ATen/functorch/*.cpp',
@ -1,8 +1,8 @@
|
||||
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
|
||||
#include <ATen/WrapDimUtilsMulti.h>
|
||||
#include <ATen/native/Resize.h>
|
||||
#include <torch/library.h>
|
||||
#include <ATen/native/mkldnn/xpu/detail/oneDNN.h>
|
||||
#include <torch/library.h>
|
||||
#ifndef AT_PER_OPERATOR_HEADERS
|
||||
|
||||
#include <ATen/Functions.h>
|
||||
@ -45,9 +45,16 @@ Tensor& addmm_out(
|
||||
mat2.sizes()[1],
|
||||
")");
|
||||
TORCH_CHECK(
|
||||
mat1.dtype() == mat2.dtype(),
|
||||
"expected mat1 and mat2 to have the same dtype, but got: ", mat1.dtype(), " != ", mat2.dtype()
|
||||
)
|
||||
mat1.dtype() == mat2.dtype(),
|
||||
"expected mat1 and mat2 to have the same dtype, but got: ",
|
||||
mat1.dtype(),
|
||||
" != ",
|
||||
mat2.dtype())
|
||||
// complex/double case
|
||||
if (mat1.is_complex() || mat1.scalar_type() == ScalarType::Double) {
|
||||
TORCH_CHECK(
|
||||
false, "Double and complex datatype matmul is not supported in oneDNN");
|
||||
}
|
||||
|
||||
std::vector<int64_t> result_shape = {mat1.size(0), mat2.size(1)};
|
||||
result.resize_(result_shape);
|
||||
@ -57,21 +64,15 @@ Tensor& addmm_out(
|
||||
return result;
|
||||
}
|
||||
|
||||
if (mat1.numel() == 0){
|
||||
if(beta.to<float>() == 0.f){
|
||||
if (mat1.numel() == 0) {
|
||||
if (beta.to<float>() == 0.f) {
|
||||
return result.zero_();
|
||||
}
|
||||
return at::mul_out(
|
||||
result,
|
||||
self.expand(result.sizes()),
|
||||
at::native::scalar_tensor(
|
||||
beta,
|
||||
self.scalar_type(),
|
||||
std::nullopt,
|
||||
at::kCPU,
|
||||
std::nullopt
|
||||
)
|
||||
);
|
||||
result,
|
||||
self.expand(result.sizes()),
|
||||
at::native::scalar_tensor(
|
||||
beta, self.scalar_type(), std::nullopt, at::kCPU, std::nullopt));
|
||||
}
|
||||
|
||||
TORCH_CHECK(
|
||||
@ -81,12 +82,6 @@ Tensor& addmm_out(
|
||||
" but got:",
|
||||
self.sizes());
|
||||
|
||||
// complex/double case
|
||||
if (mat1.is_complex() || mat1.scalar_type() == ScalarType::Double) {
|
||||
TORCH_CHECK(false,
|
||||
"Double and complex datatype matmul is not supported in oneDNN");
|
||||
}
|
||||
|
||||
// general case
|
||||
Tensor bias = Tensor();
|
||||
onednn::Attr attr;
|
||||
@ -151,9 +146,11 @@ Tensor& mm_out(const Tensor& self, const Tensor& mat2, Tensor& result) {
|
||||
mat2.sizes()[1],
|
||||
")");
|
||||
TORCH_CHECK(
|
||||
self.dtype() == mat2.dtype(),
|
||||
"expected self and mat2 to have the same dtype, but got: ", self.dtype(), " != ", mat2.dtype()
|
||||
)
|
||||
self.dtype() == mat2.dtype(),
|
||||
"expected self and mat2 to have the same dtype, but got: ",
|
||||
self.dtype(),
|
||||
" != ",
|
||||
mat2.dtype())
|
||||
|
||||
result.resize_({self.size(0), mat2.size(1)});
|
||||
if (self.numel() == 0 || mat2.numel() == 0) {
|
||||
@ -163,8 +160,8 @@ Tensor& mm_out(const Tensor& self, const Tensor& mat2, Tensor& result) {
|
||||
}
|
||||
|
||||
if (self.is_complex() || self.scalar_type() == ScalarType::Double) {
|
||||
TORCH_CHECK(false,
|
||||
"Double and complex datatype matmul is not supported in oneDNN");
|
||||
TORCH_CHECK(
|
||||
false, "Double and complex datatype matmul is not supported in oneDNN");
|
||||
}
|
||||
|
||||
onednn::matmul(result, self, mat2, Tensor(), true, onednn::Attr());
|
||||
@ -182,7 +179,6 @@ Tensor mv(const Tensor& self, const Tensor& vec) {
|
||||
return at::addmv_(result, self, vec, 0, 1);
|
||||
}
|
||||
|
||||
|
||||
// result = beta * input + alpha * (batch1 @ batch2)
|
||||
Tensor& baddbmm_out(
|
||||
const Tensor& input,
|
||||
@ -198,12 +194,12 @@ Tensor& baddbmm_out(
|
||||
std::vector<int64_t> result_shape = {
|
||||
batch1.size(0), batch1.size(1), batch2.size(2)};
|
||||
result.resize_(result_shape);
|
||||
if (result.numel() == 0){
|
||||
if (result.numel() == 0) {
|
||||
return result;
|
||||
} else if (batch1.size(2) == 0){
|
||||
if (beta.to<c10::complex<double>>() == 0.0){
|
||||
} else if (batch1.size(2) == 0) {
|
||||
if (beta.to<c10::complex<double>>() == 0.0) {
|
||||
return result.zero_();
|
||||
}else{
|
||||
} else {
|
||||
at::mul_out(result, input, beta);
|
||||
return result;
|
||||
}
|
||||
@ -218,8 +214,8 @@ Tensor& baddbmm_out(
|
||||
|
||||
// complex and double case
|
||||
if (batch1.is_complex() || batch2.scalar_type() == ScalarType::Double) {
|
||||
TORCH_CHECK(false,
|
||||
"Double and complex datatype matmul is not supported in oneDNN");
|
||||
TORCH_CHECK(
|
||||
false, "Double and complex datatype matmul is not supported in oneDNN");
|
||||
}
|
||||
|
||||
// general case
|
||||
@ -251,9 +247,15 @@ Tensor& baddbmm_(
|
||||
const Tensor& batch2,
|
||||
const Scalar& beta,
|
||||
const Scalar& alpha) {
|
||||
TORCH_CHECK(self.dtype() == batch1.dtype(), "Input dtypes must be the same, got: input ", self.dtype(), ", batch1: ", batch1.dtype(), ", batch2: ", batch2.dtype());
|
||||
return at::native::xpu::baddbmm_out(
|
||||
self, batch1, batch2, beta, alpha, self);
|
||||
TORCH_CHECK(
|
||||
self.dtype() == batch1.dtype(),
|
||||
"Input dtypes must be the same, got: input ",
|
||||
self.dtype(),
|
||||
", batch1: ",
|
||||
batch1.dtype(),
|
||||
", batch2: ",
|
||||
batch2.dtype());
|
||||
return at::native::xpu::baddbmm_out(self, batch1, batch2, beta, alpha, self);
|
||||
}
|
||||
|
||||
Tensor baddbmm(
|
||||
@ -263,7 +265,14 @@ Tensor baddbmm(
|
||||
const Scalar& beta,
|
||||
const Scalar& alpha) {
|
||||
Tensor r = at::empty({0}, input.options());
|
||||
TORCH_CHECK(input.dtype() == batch1.dtype(), "Input dtypes must be the same, got: input ", input.dtype(), ", batch1: ", batch1.dtype(), ", batch2: ", batch2.dtype());
|
||||
TORCH_CHECK(
|
||||
input.dtype() == batch1.dtype(),
|
||||
"Input dtypes must be the same, got: input ",
|
||||
input.dtype(),
|
||||
", batch1: ",
|
||||
batch1.dtype(),
|
||||
", batch2: ",
|
||||
batch2.dtype());
|
||||
r = at::native::xpu::baddbmm_out(input, batch1, batch2, beta, alpha, r);
|
||||
return r;
|
||||
}
|
||||
@ -282,6 +291,10 @@ Tensor& addbmm_out(
|
||||
batch1.dim(),
|
||||
" and ",
|
||||
batch2.dim());
|
||||
if (self.is_complex() || self.scalar_type() == ScalarType::Double) {
|
||||
TORCH_CHECK(
|
||||
false, "Double and complex datatype matmul is not supported in oneDNN");
|
||||
}
|
||||
|
||||
out.resize_({batch1.size(1), batch2.size(2)});
|
||||
if (alpha.to<float>() == 0.f || batch1.numel() == 0 || batch2.numel() == 0) {
|
||||
@ -344,8 +357,8 @@ Tensor& bmm_out(const Tensor& self, const Tensor& batch2, Tensor& result) {
|
||||
}
|
||||
|
||||
if (self.is_complex() || self.scalar_type() == ScalarType::Double) {
|
||||
TORCH_CHECK(false,
|
||||
"Double and complex datatype matmul is not supported in oneDNN");
|
||||
TORCH_CHECK(
|
||||
false, "Double and complex datatype matmul is not supported in oneDNN");
|
||||
}
|
||||
onednn::matmul(result, self, batch2, at::Tensor(), true, onednn::Attr());
|
||||
return result;
|
||||
@ -439,35 +452,63 @@ Tensor& tensordot_out(
|
||||
return result;
|
||||
}
|
||||
|
||||
TORCH_LIBRARY_IMPL(aten, XPU, m){
|
||||
TORCH_LIBRARY_IMPL(aten, XPU, m) {
|
||||
m.impl("tensordot.out", TORCH_FN(tensordot_out));
|
||||
}
|
||||
} // namespace xpu
|
||||
|
||||
TORCH_IMPL_FUNC(addmm_out_xpu)(const Tensor& self, const Tensor& mat1, const Tensor& mat2, const Scalar& beta, const Scalar& alpha, const Tensor& result) {
|
||||
TORCH_IMPL_FUNC(addmm_out_xpu)
|
||||
(const Tensor& self,
|
||||
const Tensor& mat1,
|
||||
const Tensor& mat2,
|
||||
const Scalar& beta,
|
||||
const Scalar& alpha,
|
||||
const Tensor& result) {
|
||||
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
|
||||
xpu::addmm_out(self, mat1, mat2, beta, alpha, const_cast<Tensor&>(result));
|
||||
}
|
||||
|
||||
TORCH_IMPL_FUNC(mm_out_xpu)(const Tensor& self, const Tensor& mat2, const Tensor& result) {
|
||||
TORCH_IMPL_FUNC(mm_out_xpu)
|
||||
(const Tensor& self, const Tensor& mat2, const Tensor& result) {
|
||||
xpu::mm_out(self, mat2, const_cast<Tensor&>(result));
|
||||
}
|
||||
|
||||
TORCH_IMPL_FUNC(bmm_out_xpu)(const Tensor& self, const Tensor& batch2, const Tensor &result) {
|
||||
TORCH_IMPL_FUNC(bmm_out_xpu)
|
||||
(const Tensor& self, const Tensor& batch2, const Tensor& result) {
|
||||
xpu::bmm_out(self, batch2, const_cast<Tensor&>(result));
|
||||
}
|
||||
|
||||
TORCH_IMPL_FUNC(addmm_activation_out_xpu)(const Tensor& self, const Tensor& mat1, const Tensor& mat2, const Scalar& beta, const Scalar& alpha, bool use_gelu, const Tensor& result) {
|
||||
TORCH_IMPL_FUNC(addmm_activation_out_xpu)
|
||||
(const Tensor& self,
|
||||
const Tensor& mat1,
|
||||
const Tensor& mat2,
|
||||
const Scalar& beta,
|
||||
const Scalar& alpha,
|
||||
bool use_gelu,
|
||||
const Tensor& result) {
|
||||
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
|
||||
xpu::_addmm_activation_out(self, mat1, mat2, beta, alpha, use_gelu, const_cast<Tensor&>(result));
|
||||
xpu::_addmm_activation_out(
|
||||
self, mat1, mat2, beta, alpha, use_gelu, const_cast<Tensor&>(result));
|
||||
}
|
||||
|
||||
TORCH_IMPL_FUNC(baddbmm_out_xpu)(const Tensor& self, const Tensor& batch1, const Tensor& batch2, const Scalar& beta, const Scalar& alpha, const Tensor& result) {
|
||||
TORCH_IMPL_FUNC(baddbmm_out_xpu)
|
||||
(const Tensor& self,
|
||||
const Tensor& batch1,
|
||||
const Tensor& batch2,
|
||||
const Scalar& beta,
|
||||
const Scalar& alpha,
|
||||
const Tensor& result) {
|
||||
xpu::baddbmm_out(
|
||||
self, batch1, batch2, beta, alpha, const_cast<Tensor&>(result));
|
||||
}
|
||||
|
||||
TORCH_IMPL_FUNC(addmv_out_xpu)(const Tensor &self, const Tensor &mat, const Tensor &vec, const Scalar& beta, const Scalar& alpha, const Tensor& result) {
|
||||
TORCH_IMPL_FUNC(addmv_out_xpu)
|
||||
(const Tensor& self,
|
||||
const Tensor& mat,
|
||||
const Tensor& vec,
|
||||
const Scalar& beta,
|
||||
const Scalar& alpha,
|
||||
const Tensor& result) {
|
||||
xpu::addmv_out(self, mat, vec, beta, alpha, const_cast<Tensor&>(result));
|
||||
}
|
||||
|
||||
|
@@ -2,15 +2,15 @@

#include <ATen/core/ATen_fwd.h>
#include <ATen/core/interned_strings.h>
#include <ATen/native/ConvUtils.h>
#include <ATen/native/mkldnn/xpu/detail/oneDNN.h>
#include <ATen/native/utils/ParamUtils.h>
#include <ATen/ops/full.h>
#include <ATen/ops/neg.h>
#include <c10/core/Scalar.h>
#include <c10/util/Exception.h>
#include <optional>
#include <ATen/native/utils/ParamUtils.h>
#include <ATen/native/mkldnn/xpu/detail/oneDNN.h>
#include <torch/library.h>
#include <ATen/native/ConvUtils.h>
#include <optional>

using namespace dnnl;
using namespace at::native;

@@ -337,9 +337,13 @@ Attr get_onednn_conv_sum_attr(
dilation_);
MemoryFormat mem_fmt = at::MemoryFormat::Contiguous;
auto input_fmt = input_r.suggest_memory_format();
auto input_is_cl = (input_fmt == at::MemoryFormat::ChannelsLast || input_fmt == at::MemoryFormat::ChannelsLast3d);
auto input_is_cl =
(input_fmt == at::MemoryFormat::ChannelsLast ||
input_fmt == at::MemoryFormat::ChannelsLast3d);
auto weight_fmt = weight_r.suggest_memory_format();
auto weight_is_cl = (weight_fmt == at::MemoryFormat::ChannelsLast || weight_fmt == at::MemoryFormat::ChannelsLast3d);
auto weight_is_cl =
(weight_fmt == at::MemoryFormat::ChannelsLast ||
weight_fmt == at::MemoryFormat::ChannelsLast3d);

bool propagate_channels_last = input_is_cl || weight_is_cl;
if (propagate_channels_last)

@@ -403,7 +407,8 @@ Tensor _convolution_out(
3 == ndim || 4 == ndim || 5 == ndim,
"convolution only supports 3D, 4D, 5D tensor");
// get computation format for Conv/TransposedConv
bool is_channels_last_suggested = use_channels_last_for_conv(input_r, weight_r, transposed_);
bool is_channels_last_suggested =
use_channels_last_for_conv(input_r, weight_r, transposed_);

Tensor input = input_r, weight = weight_r;
// PyTorch does not support ChannelsLast1D case,

@@ -499,7 +504,7 @@ Tensor _convolution_out(
}

// create output and propagate memory format
if (! output_r.defined()) {
if (!output_r.defined()) {
auto dst_tz = conv_dst_size(
input.ndimension(),
input.sizes(),

@@ -577,7 +582,8 @@ Tensor convolution_overrideable(
auto k = weight_r.ndimension();
at::MemoryFormat backend_memory_format = at::MemoryFormat::Contiguous;
if (xpu_conv_use_channels_last(input_r, weight_r)) {
backend_memory_format = (k == 5) ? at::MemoryFormat::ChannelsLast3d : at::MemoryFormat::ChannelsLast;
backend_memory_format = (k == 5) ? at::MemoryFormat::ChannelsLast3d
: at::MemoryFormat::ChannelsLast;
}
Tensor input_c = input_r.contiguous(backend_memory_format);
Tensor weight_c = weight_r.contiguous(backend_memory_format);

@@ -618,7 +624,8 @@ std::tuple<Tensor, Tensor, Tensor> convolution_backward_overrideable(
"so far only support float, bfloat16, half and double convolution backward in XPU backend, your data type is ",
grad_output.scalar_type());

bool is_channels_last_suggested = use_channels_last_for_conv(input, weight, transposed);
bool is_channels_last_suggested =
use_channels_last_for_conv(input, weight, transposed);

Tensor grad_output_, input_, weight_;
IntArrayRef stride_, padding_, dilation_, output_padding_;

@@ -655,9 +662,10 @@ std::tuple<Tensor, Tensor, Tensor> convolution_backward_overrideable(
}

// ensure the tensors are contiguous
auto mfmt = is_channels_last_suggested ? get_cl_tag_by_ndim(input_.ndimension())
auto mfmt = is_channels_last_suggested
? get_cl_tag_by_ndim(input_.ndimension())
: at::MemoryFormat::Contiguous;
grad_output_ = grad_output_.contiguous(mfmt);
grad_output_ = grad_output_.contiguous(mfmt);
weight_ = weight_.contiguous(mfmt);
input_ = input_.contiguous(mfmt);

@@ -730,9 +738,11 @@ std::tuple<Tensor, Tensor, Tensor> convolution_backward_overrideable(
return std::tuple<Tensor, Tensor, Tensor>{grad_input, grad_weight, grad_bias};
}

TORCH_LIBRARY_IMPL(aten, XPU, m){
TORCH_LIBRARY_IMPL(aten, XPU, m) {
m.impl("convolution_overrideable", TORCH_FN(convolution_overrideable));
m.impl("convolution_backward_overrideable", TORCH_FN(convolution_backward_overrideable));
m.impl(
"convolution_backward_overrideable",
TORCH_FN(convolution_backward_overrideable));
}

} // namespace xpu
@ -1,10 +1,10 @@
|
||||
#pragma once
|
||||
|
||||
#include <ATen/ATen.h>
|
||||
#include <oneapi/dnnl/dnnl.hpp>
|
||||
#include <oneapi/dnnl/dnnl_types.h>
|
||||
#include <ATen/native/mkldnn/xpu/detail/Utils.h>
|
||||
#include <ATen/native/mkldnn/xpu/detail/oneDNNContext.h>
|
||||
#include <oneapi/dnnl/dnnl.hpp>
|
||||
#include <oneapi/dnnl/dnnl_types.h>
|
||||
|
||||
namespace at::native::onednn {
|
||||
/* oneDNN quantization usage:
|
||||
@ -75,7 +75,12 @@ to oneDNN doc.
|
||||
using kind_t = dnnl::primitive::kind;
|
||||
struct PostOpParam {
|
||||
// eltwise post op constructor
|
||||
PostOpParam(float scale, float alpha, float beta, dnnl::algorithm algo, kind_t kind)
|
||||
PostOpParam(
|
||||
float scale,
|
||||
float alpha,
|
||||
float beta,
|
||||
dnnl::algorithm algo,
|
||||
kind_t kind)
|
||||
: scale_(scale), alpha_(alpha), beta_(beta), algo_(algo), kind_(kind) {}
|
||||
// sum post op constructor
|
||||
PostOpParam(float scale, kind_t kind) : scale_(scale), kind_(kind) {}
|
||||
@ -95,7 +100,11 @@ struct PostOpParam {
|
||||
PostOpParam(int mask, kind_t kind) : mask_(mask), kind_(kind) {}
|
||||
|
||||
// post sum or binary with scale post op constructor
|
||||
PostOpParam(at::Tensor& binary, float scale, dnnl::algorithm algo, kind_t kind)
|
||||
PostOpParam(
|
||||
at::Tensor& binary,
|
||||
float scale,
|
||||
dnnl::algorithm algo,
|
||||
kind_t kind)
|
||||
: scale_(scale), binary_(binary), algo_(algo), kind_(kind) {}
|
||||
|
||||
// for int8 sum/eltwise
|
||||
@ -182,8 +191,9 @@ class Attr {
|
||||
// append binary post op
|
||||
Attr& append_post_binary(dnnl::algorithm algo, const at::Tensor& binary) {
|
||||
auto binary_ = binary.is_quantized() ? at::dequantize(binary) : binary;
|
||||
bool binary_is_channels_last = (binary_.suggest_memory_format() == at::MemoryFormat::ChannelsLast ||
|
||||
binary_.suggest_memory_format() == at::MemoryFormat::ChannelsLast3d);
|
||||
bool binary_is_channels_last =
|
||||
(binary_.suggest_memory_format() == at::MemoryFormat::ChannelsLast ||
|
||||
binary_.suggest_memory_format() == at::MemoryFormat::ChannelsLast3d);
|
||||
|
||||
binary_ = binary_is_channels_last ? binary_ : binary_.contiguous();
|
||||
dnnl::memory::desc md = get_onednn_md(binary_);
|
||||
@ -233,8 +243,8 @@ class Attr {
|
||||
dnnl::memory::format_tag::abcde);
|
||||
break;
|
||||
default:
|
||||
TORCH_INTERNAL_ASSERT(0,
|
||||
"XPU only supports append_bias for Conv1d, Conv2d and Conv3d.");
|
||||
TORCH_INTERNAL_ASSERT(
|
||||
0, "XPU only supports append_bias for Conv1d, Conv2d and Conv3d.");
|
||||
}
|
||||
// In this case, expected_md = binary_md
|
||||
ops_params_.push_back(PostOpParam(
|
||||
@ -248,7 +258,7 @@ class Attr {
|
||||
return *this;
|
||||
}
|
||||
|
||||
dnnl::post_ops extract_post_ops(const at::Tensor& dst){
|
||||
dnnl::post_ops extract_post_ops(const at::Tensor& dst) {
|
||||
// this function is used to extract post ops params from the ops_params_
|
||||
// and put them into onednn post ops
|
||||
for (size_t i = 0; i < ops_params_.size(); ++i) {
|
||||
@ -332,8 +342,8 @@ class Attr {
|
||||
// [1, C, 1, 1], channel broadcast
|
||||
// [dst.shape], no broadcast and eltwise-wise binary operations on dst
|
||||
|
||||
auto engine =
|
||||
GpuEngineManager::Instance().get_engine({c10::kXPU, c10::xpu::current_device()});
|
||||
auto engine = GpuEngineManager::Instance().get_engine(
|
||||
{c10::kXPU, c10::xpu::current_device()});
|
||||
for (size_t i = 0; i < ops_params_.size(); ++i) {
|
||||
kind_t kind = ops_params_[i].kind_;
|
||||
if (kind == kind_t::binary) {
|
||||
@ -346,8 +356,7 @@ class Attr {
|
||||
DNNL_ARG_ATTR_MULTIPLE_POST_OP(i) | DNNL_ARG_SRC_1);
|
||||
|
||||
binary_m = at::native::onednn::make_onednn_memory(
|
||||
md, engine, binary.data_ptr()
|
||||
);
|
||||
md, engine, binary.data_ptr());
|
||||
|
||||
args.insert(
|
||||
{DNNL_ARG_ATTR_MULTIPLE_POST_OP(i) | DNNL_ARG_SRC_1, binary_m});
|
||||
|
@ -24,7 +24,7 @@ dnnl::memory::dims conv_dst_size(
|
||||
IntArrayRef stride,
|
||||
IntArrayRef dilation) {
|
||||
bool has_dilation = dilation.size() > 0;
|
||||
dnnl::memory::dims dst_size(ndim);
|
||||
dnnl::memory::dims dst_size(ndim);
|
||||
dst_size[0] = src_size[src_batch_size_dim];
|
||||
dst_size[1] = weight_size[weight_dst_channels_dim];
|
||||
for (int d = 2; d < ndim; ++d) {
|
||||
@ -41,7 +41,7 @@ dnnl::memory::dims conv_dst_size(
|
||||
}
|
||||
|
||||
static inline dnnl::memory::dims compatible_dilation(IntArrayRef& dilation) {
|
||||
dnnl::memory::dims ret = dilation.vec();
|
||||
dnnl::memory::dims ret = dilation.vec();
|
||||
for (auto it = ret.begin(); it != ret.end(); it++) {
|
||||
*it -= 1;
|
||||
}
|
||||
@ -71,18 +71,20 @@ static inline dnnl::memory::format_tag conv_weight_fmt(
|
||||
const bool grouped = false,
|
||||
const bool is_channels_last = false) {
|
||||
if (!is_channels_last) {
|
||||
return (ndim == 3)
|
||||
? (grouped ? dnnl::memory::format_tag::goiw : dnnl::memory::format_tag::oiw)
|
||||
return (ndim == 3) ? (grouped ? dnnl::memory::format_tag::goiw
|
||||
: dnnl::memory::format_tag::oiw)
|
||||
: (ndim == 4)
|
||||
? (grouped ? dnnl::memory::format_tag::goihw : dnnl::memory::format_tag::oihw)
|
||||
? (grouped ? dnnl::memory::format_tag::goihw
|
||||
: dnnl::memory::format_tag::oihw)
|
||||
: ((ndim == 5) ? (grouped ? dnnl::memory::format_tag::goidhw
|
||||
: dnnl::memory::format_tag::oidhw)
|
||||
: dnnl::memory::format_tag::undef);
|
||||
} else {
|
||||
return (ndim == 3)
|
||||
? (grouped ? dnnl::memory::format_tag::gowi : dnnl::memory::format_tag::owi)
|
||||
return (ndim == 3) ? (grouped ? dnnl::memory::format_tag::gowi
|
||||
: dnnl::memory::format_tag::owi)
|
||||
: (ndim == 4)
|
||||
? (grouped ? dnnl::memory::format_tag::gohwi : dnnl::memory::format_tag::ohwi)
|
||||
? (grouped ? dnnl::memory::format_tag::gohwi
|
||||
: dnnl::memory::format_tag::ohwi)
|
||||
: ((ndim == 5) ? (grouped ? dnnl::memory::format_tag::godhwi
|
||||
: dnnl::memory::format_tag::odhwi)
|
||||
: dnnl::memory::format_tag::undef);
|
||||
@ -97,8 +99,9 @@ static inline dnnl::memory::dims compatible_weight_dims(
|
||||
const IntArrayRef wsizes) {
|
||||
if (ndim == 3) {
|
||||
auto kw = wsizes[2];
|
||||
return (groups != 1) ? dnnl::memory::dims({groups, oc / groups, ic / groups, kw})
|
||||
: dnnl::memory::dims({oc, ic, kw});
|
||||
return (groups != 1)
|
||||
? dnnl::memory::dims({groups, oc / groups, ic / groups, kw})
|
||||
: dnnl::memory::dims({oc, ic, kw});
|
||||
} else if (ndim == 4) {
|
||||
auto kh = wsizes[2];
|
||||
auto kw = wsizes[3];
|
||||
@ -117,11 +120,8 @@ static inline dnnl::memory::dims compatible_weight_dims(
|
||||
return {};
|
||||
}
|
||||
|
||||
static std::tuple<
|
||||
dnnl::memory::desc,
|
||||
dnnl::memory::desc,
|
||||
dnnl::memory::desc>
|
||||
conv_get_md(
|
||||
static std::tuple<dnnl::memory::desc, dnnl::memory::desc, dnnl::memory::desc>
|
||||
conv_get_md(
|
||||
const at::Tensor& src,
|
||||
const at::Tensor& weight,
|
||||
const at::Tensor& dst,
|
||||
@ -130,8 +130,7 @@ static std::tuple<
|
||||
// create memory desc from the src/weight/dst tensors
|
||||
dnnl::memory::desc src_usr_md, weight_usr_md, dst_usr_md;
|
||||
auto ndim = src.ndimension();
|
||||
auto fmt_src =
|
||||
conv_src_fmt(ndim, is_channels_last);
|
||||
auto fmt_src = conv_src_fmt(ndim, is_channels_last);
|
||||
|
||||
auto src_size = src.sizes().vec();
|
||||
auto src_data_t = get_onednn_dtype_include_double(src);
|
||||
@ -146,10 +145,7 @@ static std::tuple<
|
||||
auto wei_data_t = get_onednn_dtype_include_double(weight);
|
||||
dnnl::memory::dims weight_size =
|
||||
compatible_weight_dims(ndim, groups, oc, ic, weight.sizes());
|
||||
auto fmt_weight = conv_weight_fmt(
|
||||
ndim,
|
||||
groups != 1,
|
||||
is_channels_last);
|
||||
auto fmt_weight = conv_weight_fmt(ndim, groups != 1, is_channels_last);
|
||||
weight_usr_md = dnnl::memory::desc(weight_size, wei_data_t, fmt_weight);
|
||||
|
||||
return {src_usr_md, weight_usr_md, dst_usr_md};
|
||||
@ -167,14 +163,15 @@ sycl::event convolution(
|
||||
int64_t groups,
|
||||
Attr& attr,
|
||||
const std::vector<sycl::event>& deps) {
|
||||
auto engine =
|
||||
GpuEngineManager::Instance().get_engine({c10::kXPU, c10::xpu::current_device()});
|
||||
auto engine = GpuEngineManager::Instance().get_engine(
|
||||
{c10::kXPU, c10::xpu::current_device()});
|
||||
auto stream = GpuStreamManager::Instance().get_stream();
|
||||
|
||||
bool is_channels_last = use_channels_last_for_conv(src, weight, false);
|
||||
|
||||
// create usr_md for tensors, and md for conv primitive
|
||||
auto [src_md, weight_md, dst_md] = conv_get_md(src, weight, dst, groups, is_channels_last);
|
||||
auto [src_md, weight_md, dst_md] =
|
||||
conv_get_md(src, weight, dst, groups, is_channels_last);
|
||||
|
||||
auto bia_fmt = dnnl::memory::format_tag::x;
|
||||
auto bia_md = bia.defined()
|
||||
@ -185,7 +182,8 @@ sycl::event convolution(
|
||||
// create conv primitive descriptor
|
||||
dnnl::memory::dims _stride = stride.vec();
|
||||
dnnl::memory::dims _padding_front_top_left = padding_front_top_left.vec();
|
||||
dnnl::memory::dims _padding_back_bottom_right = padding_back_bottom_right.vec();
|
||||
dnnl::memory::dims _padding_back_bottom_right =
|
||||
padding_back_bottom_right.vec();
|
||||
dnnl::memory::dims _dilation = compatible_dilation(dilation);
|
||||
|
||||
// extract post ops
|
||||
@ -195,11 +193,12 @@ sycl::event convolution(
|
||||
|
||||
pattr.set_scratchpad_mode(dnnl::scratchpad_mode::user);
|
||||
|
||||
#if ONEDNN_SUPPORT_DETERMINISTIC
|
||||
if(at::globalContext().deterministicAlgorithms() || at::globalContext().deterministicMkldnn()){
|
||||
pattr.set_deterministic(true);
|
||||
}
|
||||
#endif
|
||||
#if ONEDNN_SUPPORT_DETERMINISTIC
|
||||
if (at::globalContext().deterministicAlgorithms() ||
|
||||
at::globalContext().deterministicMkldnn()) {
|
||||
pattr.set_deterministic(true);
|
||||
}
|
||||
#endif
|
||||
|
||||
auto conv_fwd_pd = dnnl::convolution_forward::primitive_desc(
|
||||
engine,
|
||||
@ -222,7 +221,6 @@ sycl::event convolution(
|
||||
weight_m = make_onednn_memory(weight_md, engine, weight.data_ptr());
|
||||
dst_m = make_onednn_memory(dst_md, engine, dst.data_ptr());
|
||||
|
||||
|
||||
std::unordered_map<int, dnnl::memory> args;
|
||||
if (bia.defined()) {
|
||||
bia_m = make_onednn_memory(bia_md, engine, bia.data_ptr());
|
||||
@ -238,13 +236,16 @@ sycl::event convolution(
|
||||
|
||||
size_t scratchpad_size = conv_fwd_pd.scratchpad_desc().get_size();
|
||||
at::Tensor scratchpad_tensor = at::empty(
|
||||
{static_cast<int64_t>(scratchpad_size)}, src.options().dtype(at::kByte), std::nullopt);
|
||||
{static_cast<int64_t>(scratchpad_size)},
|
||||
src.options().dtype(at::kByte),
|
||||
std::nullopt);
|
||||
auto scratchpad_m = make_onednn_memory(
|
||||
conv_fwd_pd.scratchpad_desc(), engine, scratchpad_tensor.data_ptr());
|
||||
args.insert({DNNL_ARG_SCRATCHPAD, scratchpad_m});
|
||||
|
||||
auto conv_forward = dnnl::convolution_forward(conv_fwd_pd);
|
||||
auto conv_fwd_event = dnnl::sycl_interop::execute(conv_forward, stream, args, deps);
|
||||
auto conv_fwd_event =
|
||||
dnnl::sycl_interop::execute(conv_forward, stream, args, deps);
|
||||
|
||||
return conv_fwd_event;
|
||||
}
|
||||
@ -261,11 +262,12 @@ sycl::event convolution_backward_weights(
|
||||
IntArrayRef dilation,
|
||||
int64_t groups,
|
||||
const std::vector<sycl::event>& deps) {
|
||||
auto engine =
|
||||
GpuEngineManager::Instance().get_engine({c10::kXPU, c10::xpu::current_device()});
|
||||
auto engine = GpuEngineManager::Instance().get_engine(
|
||||
{c10::kXPU, c10::xpu::current_device()});
|
||||
auto stream = GpuStreamManager::Instance().get_stream();
|
||||
|
||||
bool is_channels_last = use_channels_last_for_conv(src, diff_dst, /*is_transposed=*/false);
|
||||
bool is_channels_last =
|
||||
use_channels_last_for_conv(src, diff_dst, /*is_transposed=*/false);
|
||||
|
||||
// create dnnl::memory desc
|
||||
auto [src_md, weight_md, dst_md] =
|
||||
@ -278,15 +280,17 @@ sycl::event convolution_backward_weights(
|
||||
// create fwd primitive hint
|
||||
dnnl::memory::dims _stride = stride.vec();
|
||||
dnnl::memory::dims _padding_front_top_left = padding_front_top_left.vec();
|
||||
dnnl::memory::dims _padding_back_bottom_right = padding_back_bottom_right.vec();
|
||||
dnnl::memory::dims _padding_back_bottom_right =
|
||||
padding_back_bottom_right.vec();
|
||||
dnnl::memory::dims _dilation = compatible_dilation(dilation);
|
||||
dnnl::primitive_attr pattr;
|
||||
|
||||
#if ONEDNN_SUPPORT_DETERMINISTIC
|
||||
if(at::globalContext().deterministicAlgorithms() || at::globalContext().deterministicMkldnn()){
|
||||
pattr.set_deterministic(true);
|
||||
}
|
||||
#endif
|
||||
#if ONEDNN_SUPPORT_DETERMINISTIC
|
||||
if (at::globalContext().deterministicAlgorithms() ||
|
||||
at::globalContext().deterministicMkldnn()) {
|
||||
pattr.set_deterministic(true);
|
||||
}
|
||||
#endif
|
||||
|
||||
pattr.set_scratchpad_mode(dnnl::scratchpad_mode::user);
|
||||
auto conv_fwd_pd = dnnl::convolution_forward::primitive_desc(
|
||||
@ -339,14 +343,17 @@ sycl::event convolution_backward_weights(
|
||||
|
||||
size_t scratchpad_size = conv_bwd_w_pd.scratchpad_desc().get_size();
|
||||
at::Tensor scratchpad_tensor = at::empty(
|
||||
{static_cast<int64_t>(scratchpad_size)}, src.options().dtype(at::kByte), std::nullopt);
|
||||
{static_cast<int64_t>(scratchpad_size)},
|
||||
src.options().dtype(at::kByte),
|
||||
std::nullopt);
|
||||
auto scratchpad_m = make_onednn_memory(
|
||||
conv_bwd_w_pd.scratchpad_desc(), engine, scratchpad_tensor.data_ptr());
|
||||
args.insert({DNNL_ARG_SCRATCHPAD, scratchpad_m});
|
||||
|
||||
// execute primitive
|
||||
auto conv_bwd_w = dnnl::convolution_backward_weights(conv_bwd_w_pd);
|
||||
sycl::event conv_bwd_w_event = dnnl::sycl_interop::execute(conv_bwd_w, stream, args, deps);
|
||||
sycl::event conv_bwd_w_event =
|
||||
dnnl::sycl_interop::execute(conv_bwd_w, stream, args, deps);
|
||||
|
||||
return conv_bwd_w_event;
|
||||
}
|
||||
@ -362,33 +369,37 @@ sycl::event convolution_backward_data(
|
||||
int64_t groups,
|
||||
bool bias_defined,
|
||||
const std::vector<sycl::event>& deps) {
|
||||
auto engine =
|
||||
GpuEngineManager::Instance().get_engine({c10::kXPU, c10::xpu::current_device()});
|
||||
auto engine = GpuEngineManager::Instance().get_engine(
|
||||
{c10::kXPU, c10::xpu::current_device()});
|
||||
auto stream = GpuStreamManager::Instance().get_stream();
|
||||
|
||||
bool is_channels_last = use_channels_last_for_conv(diff_dst, weight, /*is_transposed=*/false);
|
||||
bool is_channels_last =
|
||||
use_channels_last_for_conv(diff_dst, weight, /*is_transposed=*/false);
|
||||
|
||||
// create memory desc
|
||||
auto [src_md, weight_md, dst_md] =
|
||||
conv_get_md(diff_src, weight, diff_dst, groups, is_channels_last);
|
||||
dnnl::memory::format_tag bia_fmt = dnnl::memory::format_tag::x;
|
||||
auto bia_md = bias_defined
|
||||
? dnnl::memory::desc({diff_dst.size(1)}, weight_md.get_data_type(), bia_fmt)
|
||||
? dnnl::memory::desc(
|
||||
{diff_dst.size(1)}, weight_md.get_data_type(), bia_fmt)
|
||||
: dnnl::memory::desc();
|
||||
|
||||
// create fwd primitive desc hint
|
||||
dnnl::primitive_attr pattr;
|
||||
|
||||
#if ONEDNN_SUPPORT_DETERMINISTIC
|
||||
if(at::globalContext().deterministicAlgorithms() || at::globalContext().deterministicMkldnn()){
|
||||
pattr.set_deterministic(true);
|
||||
}
|
||||
#endif
|
||||
#if ONEDNN_SUPPORT_DETERMINISTIC
|
||||
if (at::globalContext().deterministicAlgorithms() ||
|
||||
at::globalContext().deterministicMkldnn()) {
|
||||
pattr.set_deterministic(true);
|
||||
}
|
||||
#endif
|
||||
|
||||
pattr.set_scratchpad_mode(dnnl::scratchpad_mode::user);
|
||||
dnnl::memory::dims _stride = stride.vec();
|
||||
dnnl::memory::dims _padding_front_top_left = padding_front_top_left.vec();
|
||||
dnnl::memory::dims _padding_back_bottom_right = padding_back_bottom_right.vec();
|
||||
dnnl::memory::dims _padding_back_bottom_right =
|
||||
padding_back_bottom_right.vec();
|
||||
dnnl::memory::dims _dilation = compatible_dilation(dilation);
|
||||
auto conv_forward_pd = dnnl::convolution_forward::primitive_desc(
|
||||
engine,
|
||||
@ -425,12 +436,13 @@ sycl::event convolution_backward_data(
|
||||
wei_m = make_onednn_memory(weight_md, engine, weight.data_ptr());
|
||||
diff_dst_m = make_onednn_memory(dst_md, engine, diff_dst.data_ptr());
|
||||
|
||||
|
||||
// insert args
|
||||
std::unordered_map<int, dnnl::memory> args;
|
||||
size_t scratchpad_size = conv_backward_data_pd.scratchpad_desc().get_size();
|
||||
at::Tensor scratchpad_tensor = at::empty(
|
||||
{static_cast<int64_t>(scratchpad_size)}, diff_dst.options().dtype(at::kByte), std::nullopt);
|
||||
{static_cast<int64_t>(scratchpad_size)},
|
||||
diff_dst.options().dtype(at::kByte),
|
||||
std::nullopt);
|
||||
auto scratchpad_memory = make_onednn_memory(
|
||||
conv_backward_data_pd.scratchpad_desc(),
|
||||
engine,
|
||||
@ -443,9 +455,9 @@ sycl::event convolution_backward_data(
|
||||
// execute primitive
|
||||
auto conv_backward_data =
|
||||
dnnl::convolution_backward_data(conv_backward_data_pd);
|
||||
auto conv_backward_data_event = dnnl::sycl_interop::execute(conv_backward_data, stream, args, deps);
|
||||
auto conv_backward_data_event =
|
||||
dnnl::sycl_interop::execute(conv_backward_data, stream, args, deps);
|
||||
return conv_backward_data_event;
|
||||
|
||||
}
|
||||
|
||||
} // namespace at::native::onednn
|
||||
|
@ -1,14 +1,15 @@
|
||||
#include <c10/xpu/XPUFunctions.h>
|
||||
#include <ATen/ATen.h>
|
||||
#include <c10/xpu/XPUFunctions.h>
|
||||
|
||||
#include <oneapi/dnnl/dnnl.hpp>
|
||||
#include <ATen/native/mkldnn/xpu/detail/oneDNNContext.h>
|
||||
#include <ATen/native/mkldnn/xpu/detail/Utils.h>
|
||||
#include <ATen/native/mkldnn/xpu/detail/Attr.h>
|
||||
#include <ATen/native/mkldnn/xpu/detail/Utils.h>
|
||||
#include <ATen/native/mkldnn/xpu/detail/oneDNNContext.h>
|
||||
#include <oneapi/dnnl/dnnl.hpp>
|
||||
|
||||
namespace at::native::onednn {
|
||||
|
||||
static inline dnnl::memory::dims deconv_compatible_dilation(IntArrayRef& dilation) {
|
||||
static inline dnnl::memory::dims deconv_compatible_dilation(
|
||||
IntArrayRef& dilation) {
|
||||
dnnl::memory::dims ret = dilation.vec();
|
||||
for (auto it = ret.begin(); it != ret.end(); it++) {
|
||||
*it -= 1;
|
||||
@ -96,8 +97,9 @@ static inline dnnl::memory::dims deconv_compatible_weight_dims(
|
||||
IntArrayRef weight_size) {
|
||||
if (ndim == 3) {
|
||||
auto kw = weight_size[2];
|
||||
return (groups != 1) ? dnnl::memory::dims({groups, oc / groups, ic / groups, kw})
|
||||
: dnnl::memory::dims({oc, ic, kw});
|
||||
return (groups != 1)
|
||||
? dnnl::memory::dims({groups, oc / groups, ic / groups, kw})
|
||||
: dnnl::memory::dims({oc, ic, kw});
|
||||
} else if (ndim == 4) {
|
||||
auto kh = weight_size[2];
|
||||
auto kw = weight_size[3];
|
||||
@ -116,10 +118,7 @@ static inline dnnl::memory::dims deconv_compatible_weight_dims(
|
||||
}
|
||||
}
|
||||
|
||||
static std::tuple<
|
||||
dnnl::memory::desc,
|
||||
dnnl::memory::desc,
|
||||
dnnl::memory::desc>
|
||||
static std::tuple<dnnl::memory::desc, dnnl::memory::desc, dnnl::memory::desc>
|
||||
deconv_get_plain_md(
|
||||
const at::Tensor& src,
|
||||
const at::Tensor& weight,
|
||||
@ -141,7 +140,8 @@ deconv_get_plain_md(
|
||||
auto weight_dt = get_onednn_dtype_include_double(weight);
|
||||
auto fmt_weight = deconv_weight_fmt(
|
||||
weight, ndim, weight_size, groups != 1, is_channels_last_suggested);
|
||||
dnnl::memory::desc weight_usr_md = dnnl::memory::desc(weight_size, weight_dt, fmt_weight);
|
||||
dnnl::memory::desc weight_usr_md =
|
||||
dnnl::memory::desc(weight_size, weight_dt, fmt_weight);
|
||||
|
||||
return {src_usr_md, weight_usr_md, dst_usr_md};
|
||||
}
|
||||
@ -158,11 +158,12 @@ sycl::event deconvolution(
|
||||
int64_t groups,
|
||||
Attr& attr,
|
||||
const std::vector<sycl::event>& deps) {
|
||||
auto engine =
|
||||
GpuEngineManager::Instance().get_engine({c10::kXPU, c10::xpu::current_device()});
|
||||
auto engine = GpuEngineManager::Instance().get_engine(
|
||||
{c10::kXPU, c10::xpu::current_device()});
|
||||
auto stream = GpuStreamManager::Instance().get_stream();
|
||||
|
||||
bool is_channels_last_suggested = use_channels_last_for_conv(src, weight, /*is_transposed=*/true);
|
||||
bool is_channels_last_suggested =
|
||||
use_channels_last_for_conv(src, weight, /*is_transposed=*/true);
|
||||
|
||||
// create usr_md for tensors, and md for conv primitive
|
||||
auto [src_md, weight_md, dst_md] =
|
||||
@ -183,10 +184,11 @@ sycl::event deconvolution(
|
||||
dnnl::primitive_attr pattr;
|
||||
dnnl::post_ops po = attr.extract_post_ops(dst);
|
||||
pattr.set_post_ops(po);
|
||||
#if ONEDNN_SUPPORT_DETERMINISTIC
|
||||
if(at::globalContext().deterministicAlgorithms() || at::globalContext().deterministicMkldnn())
|
||||
pattr.set_deterministic(true);
|
||||
#endif
|
||||
#if ONEDNN_SUPPORT_DETERMINISTIC
|
||||
if (at::globalContext().deterministicAlgorithms() ||
|
||||
at::globalContext().deterministicMkldnn())
|
||||
pattr.set_deterministic(true);
|
||||
#endif
|
||||
|
||||
pattr.set_scratchpad_mode(dnnl::scratchpad_mode::user);
|
||||
|
||||
@ -225,15 +227,17 @@ sycl::event deconvolution(
|
||||
|
||||
size_t scratchpad_size = deconv_fwd_pd.scratchpad_desc().get_size();
|
||||
at::Tensor scratchpad_tensor = at::empty(
|
||||
{static_cast<int64_t>(scratchpad_size)}, src.options().dtype(at::kByte), std::nullopt);
|
||||
{static_cast<int64_t>(scratchpad_size)},
|
||||
src.options().dtype(at::kByte),
|
||||
std::nullopt);
|
||||
auto scratchpad_m = make_onednn_memory(
|
||||
deconv_fwd_pd.scratchpad_desc(), engine, scratchpad_tensor.data_ptr());
|
||||
args.insert({DNNL_ARG_SCRATCHPAD, scratchpad_m});
|
||||
|
||||
auto deconv_fwd = dnnl::deconvolution_forward(deconv_fwd_pd);
|
||||
sycl::event deconv_event = dnnl::sycl_interop::execute(deconv_fwd, stream, args, deps);
|
||||
sycl::event deconv_event =
|
||||
dnnl::sycl_interop::execute(deconv_fwd, stream, args, deps);
|
||||
return deconv_event;
|
||||
|
||||
}
|
||||
|
||||
sycl::event deconvolution_backward_data(
|
||||
@ -246,29 +250,30 @@ sycl::event deconvolution_backward_data(
|
||||
int64_t groups,
|
||||
bool bias_defined,
|
||||
const std::vector<sycl::event>& deps) {
|
||||
auto engine =
|
||||
GpuEngineManager::Instance().get_engine({c10::kXPU, c10::xpu::current_device()});
|
||||
auto engine = GpuEngineManager::Instance().get_engine(
|
||||
{c10::kXPU, c10::xpu::current_device()});
|
||||
auto stream = GpuStreamManager::Instance().get_stream();
|
||||
|
||||
bool is_channels_last_suggested =
|
||||
use_channels_last_for_conv(diff_dst, weight, /*is_transposed=*/true);
|
||||
// create memory desc
|
||||
auto [src_md, weight_md, dst_md] =
|
||||
deconv_get_plain_md(
|
||||
diff_src, weight, diff_dst, groups, is_channels_last_suggested);
|
||||
auto [src_md, weight_md, dst_md] = deconv_get_plain_md(
|
||||
diff_src, weight, diff_dst, groups, is_channels_last_suggested);
|
||||
|
||||
dnnl::memory::format_tag bia_fmt = dnnl::memory::format_tag::x;
|
||||
auto bias_md = bias_defined
|
||||
? dnnl::memory::desc({diff_dst.size(1)}, weight_md.get_data_type(), bia_fmt)
|
||||
? dnnl::memory::desc(
|
||||
{diff_dst.size(1)}, weight_md.get_data_type(), bia_fmt)
|
||||
: dnnl::memory::desc();
|
||||
|
||||
// create fwd primitive desc hint
|
||||
dnnl::primitive_attr pattr;
|
||||
pattr.set_scratchpad_mode(dnnl::scratchpad_mode::user);
|
||||
#if ONEDNN_SUPPORT_DETERMINISTIC
|
||||
if(at::globalContext().deterministicAlgorithms() || at::globalContext().deterministicMkldnn())
|
||||
pattr.set_deterministic(true);
|
||||
#endif
|
||||
#if ONEDNN_SUPPORT_DETERMINISTIC
|
||||
if (at::globalContext().deterministicAlgorithms() ||
|
||||
at::globalContext().deterministicMkldnn())
|
||||
pattr.set_deterministic(true);
|
||||
#endif
|
||||
|
||||
dnnl::memory::dims _stride = stride.vec();
|
||||
dnnl::memory::dims _padding = padding.vec();
|
||||
@ -288,17 +293,18 @@ sycl::event deconvolution_backward_data(
|
||||
pattr);
|
||||
|
||||
// create bwd primitive desc
|
||||
auto deconv_backward_data_pd = dnnl::deconvolution_backward_data::primitive_desc(
|
||||
engine,
|
||||
dnnl::algorithm::deconvolution_direct,
|
||||
src_md,
|
||||
weight_md,
|
||||
dst_md,
|
||||
_stride,
|
||||
_dilation,
|
||||
_padding,
|
||||
_padding,
|
||||
deconv_fwd_pd);
|
||||
auto deconv_backward_data_pd =
|
||||
dnnl::deconvolution_backward_data::primitive_desc(
|
||||
engine,
|
||||
dnnl::algorithm::deconvolution_direct,
|
||||
src_md,
|
||||
weight_md,
|
||||
dst_md,
|
||||
_stride,
|
||||
_dilation,
|
||||
_padding,
|
||||
_padding,
|
||||
deconv_fwd_pd);
|
||||
|
||||
// create memory
|
||||
dnnl::memory diff_dst_m, wei_m, diff_src_m;
|
||||
@ -311,7 +317,9 @@ sycl::event deconvolution_backward_data(
|
||||
std::unordered_map<int, dnnl::memory> args;
|
||||
size_t scratchpad_size = deconv_backward_data_pd.scratchpad_desc().get_size();
|
||||
at::Tensor scratchpad_tensor = at::empty(
|
||||
{static_cast<int64_t>(scratchpad_size)}, diff_dst.options().dtype(at::kByte), std::nullopt);
|
||||
{static_cast<int64_t>(scratchpad_size)},
|
||||
diff_dst.options().dtype(at::kByte),
|
||||
std::nullopt);
|
||||
auto scratchpad_memory = make_onednn_memory(
|
||||
deconv_backward_data_pd.scratchpad_desc(),
|
||||
engine,
|
||||
@ -324,9 +332,9 @@ sycl::event deconvolution_backward_data(
|
||||
// execute primitive
|
||||
auto deconv_backward_data =
|
||||
dnnl::deconvolution_backward_data(deconv_backward_data_pd);
|
||||
sycl::event deconv_bwd_data_event = dnnl::sycl_interop::execute(deconv_backward_data, stream, args, deps);
|
||||
sycl::event deconv_bwd_data_event =
|
||||
dnnl::sycl_interop::execute(deconv_backward_data, stream, args, deps);
|
||||
return deconv_bwd_data_event;
|
||||
|
||||
}
|
||||
|
||||
sycl::event deconvolution_backward_weights(
|
||||
@ -339,8 +347,8 @@ sycl::event deconvolution_backward_weights(
|
||||
IntArrayRef dilation,
|
||||
int64_t groups,
|
||||
const std::vector<sycl::event>& deps) {
|
||||
auto engine =
|
||||
GpuEngineManager::Instance().get_engine({c10::kXPU, c10::xpu::current_device()});
|
||||
auto engine = GpuEngineManager::Instance().get_engine(
|
||||
{c10::kXPU, c10::xpu::current_device()});
|
||||
auto stream = GpuStreamManager::Instance().get_stream();
|
||||
|
||||
bool is_channels_last_suggested =
|
||||
@ -348,7 +356,7 @@ sycl::event deconvolution_backward_weights(
|
||||
|
||||
// create memory desc
|
||||
auto [src_md, weight_md, dst_md] = deconv_get_plain_md(
|
||||
src, diff_weight, diff_dst, groups, is_channels_last_suggested);
|
||||
src, diff_weight, diff_dst, groups, is_channels_last_suggested);
|
||||
|
||||
dnnl::memory::format_tag bia_fmt = dnnl::memory::format_tag::x;
|
||||
auto bia_md = diff_bia.defined()
|
||||
@ -361,10 +369,11 @@ sycl::event deconvolution_backward_weights(
|
||||
dnnl::memory::dims _dilation = deconv_compatible_dilation(dilation);
|
||||
dnnl::primitive_attr pattr;
|
||||
|
||||
#if ONEDNN_SUPPORT_DETERMINISTIC
|
||||
if(at::globalContext().deterministicAlgorithms() || at::globalContext().deterministicMkldnn())
|
||||
pattr.set_deterministic(true);
|
||||
#endif
|
||||
#if ONEDNN_SUPPORT_DETERMINISTIC
|
||||
if (at::globalContext().deterministicAlgorithms() ||
|
||||
at::globalContext().deterministicMkldnn())
|
||||
pattr.set_deterministic(true);
|
||||
#endif
|
||||
pattr.set_scratchpad_mode(dnnl::scratchpad_mode::user);
|
||||
auto deconv_fwd_pd = dnnl::deconvolution_forward::primitive_desc(
|
||||
engine,
|
||||
@ -415,7 +424,9 @@ sycl::event deconvolution_backward_weights(
|
||||
|
||||
size_t scratchpad_size = deconv_bwd_w_pd.scratchpad_desc().get_size();
|
||||
at::Tensor scratchpad_tensor = at::empty(
|
||||
{static_cast<int64_t>(scratchpad_size)}, src.options().dtype(at::kByte), std::nullopt);
|
||||
{static_cast<int64_t>(scratchpad_size)},
|
||||
src.options().dtype(at::kByte),
|
||||
std::nullopt);
|
||||
auto scratchpad_m = make_onednn_memory(
|
||||
deconv_bwd_w_pd.scratchpad_desc(), engine, scratchpad_tensor.data_ptr());
|
||||
args.insert({DNNL_ARG_SCRATCHPAD, scratchpad_m});
|
||||
@ -423,9 +434,9 @@ sycl::event deconvolution_backward_weights(
|
||||
// execute primitive
|
||||
auto deconv_bwd_w = dnnl::deconvolution_backward_weights(deconv_bwd_w_pd);
|
||||
|
||||
sycl::event deconv_bwd_w_event = dnnl::sycl_interop::execute(deconv_bwd_w, stream, args, deps);
|
||||
sycl::event deconv_bwd_w_event =
|
||||
dnnl::sycl_interop::execute(deconv_bwd_w, stream, args, deps);
|
||||
return deconv_bwd_w_event;
|
||||
|
||||
}
|
||||
|
||||
} // namespace at::native::onednn
|
||||
|
@@ -35,7 +35,8 @@ sycl::event matmul(

at::Tensor m1 = is_onednn_matmul_strides(mat1) ? mat1 : mat1.contiguous();
at::Tensor m2 = is_onednn_matmul_strides(mat2) ? mat2 : mat2.contiguous();
at::Tensor dst = is_onednn_matmul_strides(result, true) ? result : result.contiguous();
at::Tensor dst =
is_onednn_matmul_strides(result, true) ? result : result.contiguous();

int64_t m = dst.size(-2);
int64_t n = dst.size(-1);

@@ -118,11 +119,13 @@ sycl::event matmul(
dnnl::memory::desc bias_md;

// Naive Master weight
if (m1_dt == dnnl::memory::data_type::bf16 && m2_dt == dnnl::memory::data_type::f32) {
if (m1_dt == dnnl::memory::data_type::bf16 &&
m2_dt == dnnl::memory::data_type::f32) {
m2_dt = dnnl::memory::data_type::bf16;
dst_dt = dnnl::memory::data_type::bf16;
} else if (
m1_dt == dnnl::memory::data_type::f32 && m2_dt == dnnl::memory::data_type::bf16) {
m1_dt == dnnl::memory::data_type::f32 &&
m2_dt == dnnl::memory::data_type::bf16) {
m1_dt = dnnl::memory::data_type::bf16;
dst_dt = dnnl::memory::data_type::bf16;
}

@@ -176,10 +179,11 @@ sycl::event matmul(
dnnl::primitive_attr pattr;
pattr.set_post_ops(po);

#if ONEDNN_SUPPORT_DETERMINISTIC
if(at::globalContext().deterministicAlgorithms() || at::globalContext().deterministicMkldnn())
pattr.set_deterministic(true);
#endif
#if ONEDNN_SUPPORT_DETERMINISTIC
if (at::globalContext().deterministicAlgorithms() ||
at::globalContext().deterministicMkldnn())
pattr.set_deterministic(true);
#endif

// scratchpad
pattr.set_scratchpad_mode(dnnl::scratchpad_mode::user);

@@ -191,10 +195,11 @@ sycl::event matmul(
// STEP3: create primitive
if (with_bias) {
bias_md = dnnl::memory::desc(bias_dims, bias_dt, bias_strides);
matmul_pd =
dnnl::matmul::primitive_desc(engine, m1_md, m2_md, bias_md, dst_md, pattr);
matmul_pd = dnnl::matmul::primitive_desc(
engine, m1_md, m2_md, bias_md, dst_md, pattr);
} else {
matmul_pd = dnnl::matmul::primitive_desc(engine, m1_md, m2_md, dst_md, pattr);
matmul_pd =
dnnl::matmul::primitive_desc(engine, m1_md, m2_md, dst_md, pattr);
}

matmul_p = dnnl::matmul(matmul_pd);

@@ -220,7 +225,9 @@ sycl::event matmul(

size_t scratchpad_size = matmul_pd.scratchpad_desc().get_size();
at::Tensor scratchpad_tensor = at::empty(
{static_cast<int64_t>(scratchpad_size)}, m1.options().dtype(at::kByte), std::nullopt);
{static_cast<int64_t>(scratchpad_size)},
m1.options().dtype(at::kByte),
std::nullopt);
auto scratchpad_memory = make_onednn_memory(
matmul_pd.scratchpad_desc(), engine, scratchpad_tensor.data_ptr());
args.insert({DNNL_ARG_SCRATCHPAD, scratchpad_memory});

@@ -233,7 +240,8 @@ sycl::event matmul(
args.insert({DNNL_ARG_BIAS, bias_m});
}

sycl::event matmul_event = dnnl::sycl_interop::execute(matmul_p, stream, args, deps);
sycl::event matmul_event =
dnnl::sycl_interop::execute(matmul_p, stream, args, deps);

if (!dst.is_same(result))
result.copy_(dst);
@@ -5,7 +5,7 @@ namespace at::native::onednn {
dnnl::memory make_onednn_memory(
dnnl::memory::desc md,
dnnl::engine& engine,
void* ptr){
void* ptr) {
return dnnl::sycl_interop::make_memory(
md,
engine,

@@ -114,10 +114,8 @@ dnnl::memory::dims get_onednn_strides(const at::Tensor& tensor) {
}

dnnl::memory::desc get_onednn_md(const at::Tensor& tensor) {
return {
get_onednn_dims(tensor),
get_onednn_dtype(tensor),
get_onednn_strides(tensor)};
Tensor t = tensor.sizes().size() == 0 ? tensor.unsqueeze(0) : tensor;
return {get_onednn_dims(t), get_onednn_dtype(t), get_onednn_strides(t)};
}

bool onednn_strides_check(const at::Tensor& src) {

@@ -206,9 +204,7 @@ bool is_broadcast(const at::Tensor& t) {
return false;
}

bool is_onednn_matmul_strides(
const at::Tensor& tensor,
bool is_dst) {
bool is_onednn_matmul_strides(const at::Tensor& tensor, bool is_dst) {
// https://oneapi-src.github.io/oneDNN/dev_guide_matmul.html
// oneDNN matmul only support 2-dim and 3-dim
// 2D src(Mxk), wei(KxN), dst(MxN)

@@ -290,14 +286,14 @@ bool binary_valid(
* 5. self and other should be in the same datatype.
* 6. self and other should be contiguous or channel-last contiguous.*/

// 1. self and other should be xpu tensor and be defined.
if ((!self.defined()) || (!other.defined()) || (!self.is_xpu()) ||
(!other.is_xpu()))
return false;

// 2. self or other should not be scalar (wrapped tensor).
if (self.unsafeGetTensorImpl()->is_wrapped_number() || other.unsafeGetTensorImpl()->is_wrapped_number())
if (self.unsafeGetTensorImpl()->is_wrapped_number() ||
other.unsafeGetTensorImpl()->is_wrapped_number())
return false;

// 3. dim of self and other should be equal and must be larger than 0 and

@@ -349,19 +345,19 @@ bool binary_valid(
return false;
}

static inline bool is_channels_last(at::MemoryFormat fmt){
return (at::MemoryFormat::ChannelsLast == fmt) || (at::MemoryFormat::ChannelsLast3d == fmt);
static inline bool is_channels_last(at::MemoryFormat fmt) {
return (at::MemoryFormat::ChannelsLast == fmt) ||
(at::MemoryFormat::ChannelsLast3d == fmt);
}

static inline bool is_smf_channels_last(const Tensor& t){
static inline bool is_smf_channels_last(const Tensor& t) {
return is_channels_last(t.suggest_memory_format());
}

bool use_channels_last_for_conv(
const at::Tensor& src,
const at::Tensor& weight,
bool is_transpose){

bool is_transpose) {
if (!src.defined() || src.is_sparse()) {
// suggest channels_first
return false;

@@ -377,4 +373,4 @@ bool use_channels_last_for_conv(
return false;
}

}
} // namespace at::native::onednn
@@ -1,8 +1,8 @@
#pragma once
#include <iostream>
#include <ATen/ATen.h>
#include <ATen/Tensor.h>
#include <ATen/core/Tensor.h>
#include <iostream>

#include <ATen/core/grad_mode.h>
#include <c10/core/MemoryFormat.h>

@@ -10,8 +10,8 @@
#include <oneapi/dnnl/dnnl_sycl.hpp>
#include <oneapi/dnnl/dnnl_version.h>

#define ONEDNN_SUPPORT_DETERMINISTIC (DNNL_VERSION_MAJOR >=3 && DNNL_VERSION_MINOR >=4)
#define ONEDNN_SUPPORT_DETERMINISTIC \
(DNNL_VERSION_MAJOR >= 3 && DNNL_VERSION_MINOR >= 4)

namespace at::native::onednn {

@@ -38,9 +38,7 @@ dnnl::memory::desc get_onednn_md(const at::Tensor& tensor);
bool onednn_strides_check(const at::Tensor& src);
bool is_broadcast(const at::Tensor& t);

bool is_onednn_matmul_strides(
const at::Tensor& tensor,
bool is_dst = false);
bool is_onednn_matmul_strides(const at::Tensor& tensor, bool is_dst = false);

bool is_broadcast_from_other_to_self(
const at::Tensor& self,
@@ -1,11 +1,11 @@
#pragma once

#include <ATen/ATen.h>
#include <ATen/native/mkldnn/xpu/detail/oneDNNContext.h>
#include <ATen/native/mkldnn/xpu/detail/Attr.h>
#include <ATen/native/mkldnn/xpu/detail/Utils.h>
#include <ATen/native/mkldnn/xpu/detail/oneDNNContext.h>

namespace at::native::onednn{
namespace at::native::onednn {

TORCH_API sycl::event matmul(
at::Tensor& result,
@@ -1,5 +1,5 @@
#include <ATen/native/mkldnn/xpu/detail/oneDNNContext.h>
#include <ATen/native/mkldnn/xpu/detail/Utils.h>
#include <ATen/native/mkldnn/xpu/detail/oneDNNContext.h>

/* *
* Do NOT put any kernels or call any device binaries here!
@@ -343,8 +343,8 @@ inductor_expected_failures_single_sample["xpu"] = {
"cholesky_inverse": {f64},
# could not create a primitive
"addbmm": {f64},
"addmm": {f16, f32, f64},
"addmv": {f32, f64},
"addmm": {f64},
"addmv": {f64},
# could not create a primitive descriptor for
# a deconvolution forward propagation primitive
"nn.functional.conv_transpose2d": {f32, f64},