Intel GPU oneDNN upstreaming for primitive integration (#117112)

# Motivation

As proposed in https://github.com/pytorch/pytorch/issues/114848 and https://github.com/pytorch/pytorch/issues/114723, the oneDNN library is an important component of the Intel GPU software ecosystem.

This PR is based on #117098, which makes the oneDNN library available for Intel GPU. It adds the integration code from ATen to oneDNN, with the GEMM integration as its core. Alongside GEMM, it also includes more basic support such as the runtime (device, stream) and primitive attributes.

We put the oneDNN integration code in the directory `aten/src/ATen/native/mkldnn/xpu/detail` and add the namespace `at::native::xpu::onednn` for the oneDNN integration.

The code in this PR will be used in follow-up PRs, where ATen operators will call the functions in this integration code. We split the PRs because the oneDNN integration is logically separable from the ATen operator implementation; this also eases the review burden by avoiding too much code in a single PR.
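For illustration, a follow-up ATen XPU kernel could route a GEMM through this integration layer roughly as in the sketch below. The operator name and the exact header path are hypothetical; only `onednn::matmul` and `onednn::Attr` come from this PR.

```cpp
// Hypothetical follow-up ATen XPU kernel (not part of this PR) calling the
// integration layer added here.
#include <ATen/ATen.h>
// Assumed path of the header (added in this PR) that declares onednn::matmul.
#include <ATen/native/mkldnn/xpu/detail/Matmul.h>

at::Tensor& addmm_out_xpu_sketch(
    const at::Tensor& bias,   // 1-D bias of length n
    const at::Tensor& mat1,   // (m, k)
    const at::Tensor& mat2,   // (k, n)
    at::Tensor& result) {     // preallocated (m, n) output
  // Plain fp32 GEMM: a default Attr means no post-ops and q_scale = 1.
  at::native::onednn::Attr attr;
  // mat2 is consumed in its stored (k, n) layout, so m2_trans = true.
  at::native::onednn::matmul(result, mat1, mat2, bias, /*m2_trans=*/true, attr);
  return result;
}
```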

Co-authored-by: xiaolil1 <xiaoli.liu@intel.com>
Co-authored-by: lei,zhenyuan <zhenyuan.lei@intel.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/117112
Approved by: https://github.com/EikanWang, https://github.com/jgong5, https://github.com/albanD
ZhiweiYan-96
2024-04-16 06:58:14 +00:00
committed by PyTorch MergeBot
parent 944d046645
commit cc18afa25f
7 changed files with 1139 additions and 0 deletions


@@ -0,0 +1,365 @@
#pragma once
#include <ATen/ATen.h>
#include <oneapi/dnnl/dnnl.hpp>
#include <oneapi/dnnl/dnnl_types.h>
#include <ATen/native/mkldnn/xpu/detail/Utils.h>
#include <ATen/native/mkldnn/xpu/detail/oneDNNContext.h>
namespace at::native::onednn {
/* oneDNN quantization usage:
   https://oneapi-src.github.io/oneDNN/dev_guide_attributes_quantization.html#

   src_fp32 = scale_src * (src_int8 - zero_point)
   wei_fp32 = scale_wei * (wei_int8 - zero_point)
   dst_fp32 = scale_dst * (dst_int8 - zero_point)

   fp32 Convolution: dst_fp32 = src_fp32 * wei_fp32
   Int8 Convolution: dst_fp32 = (src_int8 * wei_int8) * (scale_src * scale_wei)
   Int8 Convolution: dst_int8 = 1 / scale_dst * dst_fp32;

   Considering zero-point (asymmetric):
   dst_fp32 = (src_int8 - src_zp) * src_sc * wei_int8 * wei_sc
   dst_sc * (dst_int8 - dst_zp) = (src_int8 - src_zp) * wei_int8 * src_sc * wei_sc
   dst_int8 = (src_int8 - src_zp) * wei_int8 * src_sc * wei_sc / dst_sc + dst_zp

   Considering bias:
   fp32 Convolution: dst_fp32 = src_fp32 * wei_fp32 + bias
   Int8 Convolution: dst_fp32 = (src_int8 * wei_int8) * (scale_src * scale_wei) + bias
   Int8 Convolution: dst_fp32 = (src_int8 * wei_int8 + bias / (scale_src * scale_wei)) * (scale_src * scale_wei)
   Int8 Convolution: dst_int8 = 1 / scale_dst * dst_fp32;
*/
/*
oneDNN post-ops usage:
Currently, oneDNN supports 5 kinds of post ops. More details can be
found in the oneDNN doc:
https://oneapi-src.github.io/oneDNN/dev_guide_attributes_post_ops.html#doxid-dev-guide-attributes-post-ops-1dev-guide-attributes-post-ops-eltwise

0. without post ops
   dst = Conv(src, wei) + bias;
   dst_int8 = 1/q_scale * dst; q_scale is the op output quantization scale
   fp32 API: Attr attr;
   int8 API: Attr attr(q_scale);

1. append eltwise post op
   dst = elt_scale * Eltwise{conv_scale * [Conv(src, wei) + bias], alpha, beta}
   dst_int8 = 1/q_scale * dst;
   fp32 API:
     Attr attr;
     attr.append_post_eltwise(1.f, conv_scale, 0.f, kind_with_linear)
     attr.append_post_eltwise(elt_scale, alpha, beta, eltwise_algorithm)
   int8 API:
     Attr attr(q_scale);
     attr.append_post_eltwise(1.f, conv_scale, 0.f, kind_with_linear)
     attr.append_post_eltwise(elt_scale, alpha, beta, eltwise_algorithm)

2. append sum post op
   dst = conv_scale * Conv(src, wei) + sum_scale * (dst - zp)
   dst_int8 = 1/q_scale * dst;
   fp32 API:
     Attr attr;
     attr.append_post_eltwise(1.f, conv_scale, 0.f, kind_with_linear)
     attr.append_post_sum(sum_scale)
   int8 API:
     Attr attr(q_scale);
     attr.append_post_eltwise(1.f, conv_scale, 0.f, kind_with_linear)
     attr.append_post_sum(sum_scale)

3. append binary post op
   dst = Binary[Conv(src, wei)]
*/
using kind_t = dnnl::primitive::kind;
struct PostOpParam {
// eltwise post op constructor
PostOpParam(float scale, float alpha, float beta, dnnl::algorithm algo, kind_t kind)
: scale_(scale), alpha_(alpha), beta_(beta), algo_(algo), kind_(kind) {}
// sum post op constructor
PostOpParam(float scale, kind_t kind) : scale_(scale), kind_(kind) {}
// binary post op constructor
PostOpParam(
at::Tensor& binary,
dnnl::memory::desc& binary_md,
dnnl::memory::desc& expected_md,
dnnl::algorithm algo,
kind_t kind)
: binary_(binary),
meta_(binary_md),
expected_meta_(expected_md),
algo_(algo),
kind_(kind) {}
// prelu post op constructor
PostOpParam(int mask, kind_t kind) : mask_(mask), kind_(kind) {}
// post sum or binary with scale post op constructor
PostOpParam(at::Tensor& binary, float scale, dnnl::algorithm algo, kind_t kind)
: scale_(scale), binary_(binary), algo_(algo), kind_(kind) {}
// for int8 sum/eltwise
float scale_ = 1.0;
// for eltwise
float alpha_ = 0.0;
float beta_ = 0.0;
// for binary
at::Tensor binary_ = at::Tensor();
at::Tensor expected_binary_ = at::Tensor();
void* binary_ptr_ = nullptr;
dnnl::memory::desc meta_ = dnnl::memory::desc();
dnnl::memory::desc expected_meta_ = dnnl::memory::desc();
// for prelu
int mask_ = 0;
// common
dnnl::algorithm algo_ = dnnl::algorithm::eltwise_relu;
kind_t kind_ = kind_t::eltwise;
};
class Attr {
public:
Attr() : q_scale_(1.f), q_zero_point_(0) {}
Attr(float q_scale, int64_t zp = 0) : q_scale_(q_scale), q_zero_point_(zp) {}
/***** eltwise *****/
dnnl::algorithm kind_with_relu = dnnl::algorithm::eltwise_relu;
dnnl::algorithm kind_with_sigmoid = dnnl::algorithm::eltwise_logistic;
dnnl::algorithm kind_with_gelu_tanh = dnnl::algorithm::eltwise_gelu_tanh;
dnnl::algorithm kind_with_gelu_erf = dnnl::algorithm::eltwise_gelu_erf;
dnnl::algorithm kind_with_mish = dnnl::algorithm::eltwise_mish;
dnnl::algorithm kind_with_linear = dnnl::algorithm::eltwise_linear;
dnnl::algorithm kind_with_swish = dnnl::algorithm::eltwise_swish;
dnnl::algorithm kind_with_sqrt = dnnl::algorithm::eltwise_sqrt;
dnnl::algorithm kind_with_tanh = dnnl::algorithm::eltwise_tanh;
dnnl::algorithm kind_with_square = dnnl::algorithm::eltwise_square;
dnnl::algorithm kind_with_abs = dnnl::algorithm::eltwise_abs;
dnnl::algorithm kind_with_exp = dnnl::algorithm::eltwise_exp;
dnnl::algorithm kind_with_log = dnnl::algorithm::eltwise_log;
dnnl::algorithm kind_with_round = dnnl::algorithm::eltwise_round;
dnnl::algorithm kind_with_hardswish = dnnl::algorithm::eltwise_hardswish;
dnnl::algorithm kind_with_soft_relu = dnnl::algorithm::eltwise_soft_relu;
dnnl::algorithm kind_with_elu = dnnl::algorithm::eltwise_elu;
dnnl::algorithm kind_with_pow = dnnl::algorithm::eltwise_pow;
dnnl::algorithm kind_with_clip = dnnl::algorithm::eltwise_clip;
// note: oneDNN does not seem to support hardsigmoid yet
dnnl::algorithm kind_with_hardsigmoid = dnnl::algorithm::eltwise_hardsigmoid;
/***** binary *****/
dnnl::algorithm kind_with_binary_mul = dnnl::algorithm::binary_mul;
dnnl::algorithm kind_with_binary_add = dnnl::algorithm::binary_add;
dnnl::algorithm kind_with_binary_sub = dnnl::algorithm::binary_sub;
dnnl::algorithm kind_with_binary_div = dnnl::algorithm::binary_div;
dnnl::algorithm kind_with_binary_eq = dnnl::algorithm::binary_eq;
dnnl::algorithm kind_with_binary_ne = dnnl::algorithm::binary_ne;
dnnl::algorithm kind_with_binary_ge = dnnl::algorithm::binary_ge;
dnnl::algorithm kind_with_binary_gt = dnnl::algorithm::binary_gt;
dnnl::algorithm kind_with_binary_le = dnnl::algorithm::binary_le;
dnnl::algorithm kind_with_binary_lt = dnnl::algorithm::binary_lt;
dnnl::algorithm kind_with_binary_max = dnnl::algorithm::binary_max;
dnnl::algorithm kind_with_binary_min = dnnl::algorithm::binary_min;
// append sum post op
Attr& append_post_sum(
float sum_scale,
float sum_q_scale = 1.f,
int64_t zp = 0) {
ops_params_.push_back(
PostOpParam(/*scale_sum*/ sum_scale * sum_q_scale, kind_t::sum));
return *this;
}
// append eltwise post op
Attr& append_post_eltwise(
float scale,
float alpha,
float beta,
dnnl::algorithm algo) {
ops_params_.push_back(
PostOpParam(scale, alpha, beta, algo, kind_t::eltwise));
return *this;
}
// append binary post op
Attr& append_post_binary(dnnl::algorithm algo, const at::Tensor& binary) {
auto binary_ = binary.is_quantized() ? at::dequantize(binary) : binary;
bool binary_is_channels_last = (binary_.suggest_memory_format() == at::MemoryFormat::ChannelsLast ||
binary_.suggest_memory_format() == at::MemoryFormat::ChannelsLast3d);
binary_ = binary_is_channels_last ? binary_ : binary_.contiguous();
dnnl::memory::desc md = get_onednn_md(binary_);
auto expected_md = dnnl::memory::desc(
md.get_dims(), md.get_data_type(), dnnl::memory::format_tag::any);
ops_params_.push_back(
PostOpParam(binary_, md, expected_md, algo, kind_t::binary));
return *this;
}
Attr& append_scale_binary(
dnnl::algorithm algo,
at::Tensor binary,
float scale,
float sum_q_scale = 1.f,
int64_t zp = 0) {
ops_params_.push_back(PostOpParam(
binary, /*scale_sum*/ scale * sum_q_scale, algo, kind_t::binary));
return *this;
}
// append bias with binary_add method (only used for QConv now)
template <int N>
Attr& append_bias(const at::Tensor& binary) {
// In PyTorch, bias is in the shape of [OC];
// we expand its shape according to the Conv dimension:
// Conv1d [OC, 1, 1], Conv2d [1, OC, 1, 1], Conv3d [1, OC, 1, 1, 1]
at::Tensor binary_ = binary.contiguous();
dnnl::memory::desc binary_md;
switch (N) {
case 1:
binary_md = dnnl::memory::desc(
{binary.size(0), 1, 1},
dnnl::memory::data_type::f32,
dnnl::memory::format_tag::abc);
break;
case 2:
binary_md = dnnl::memory::desc(
{1, binary.size(0), 1, 1},
dnnl::memory::data_type::f32,
dnnl::memory::format_tag::abcd);
break;
case 3:
binary_md = dnnl::memory::desc(
{1, binary.size(0), 1, 1, 1},
dnnl::memory::data_type::f32,
dnnl::memory::format_tag::abcde);
break;
default:
TORCH_INTERNAL_ASSERT(0,
"XPU only supports append_bias for Conv1d, Conv2d and Conv3d.");
}
// In this case, expected_md = binary_md
ops_params_.push_back(PostOpParam(
binary_, binary_md, binary_md, kind_with_binary_add, kind_t::binary));
return *this;
}
// append prelu post op
Attr& append_post_prelu(int mask) {
ops_params_.push_back(PostOpParam(mask, kind_t::prelu));
return *this;
}
dnnl::post_ops extract_post_ops(const at::Tensor& dst){
// this function is used to extract post ops params from the ops_params_
// and put them into onednn post ops
for (size_t i = 0; i < ops_params_.size(); ++i) {
kind_t kind = ops_params_[i].kind_;
switch (kind) {
case kind_t::eltwise: {
dnnl::algorithm algo = ops_params_[i].algo_;
float alpha = ops_params_[i].alpha_;
float beta = ops_params_[i].beta_;
dnnl_post_ops_.append_eltwise(algo, alpha, beta);
break;
}
case kind_t::sum: {
float scale = ops_params_[i].scale_;
// TODO [Asymmetric]:
// Post-sum zp for gpu is not supported currently
dnnl_post_ops_.append_sum(scale);
break;
}
case kind_t::binary: {
dnnl::algorithm algo = ops_params_[i].algo_;
auto expected_md = ops_params_[i].expected_meta_;
// The user may create the src1 memory descriptor with format_tag::any
// or with a specific tag. However, in the latter case, if the tag
// mismatches with dst, it would result in suboptimal performance.
// Thus we use expected_md (created with format_tag::any) instead of
// the original md to create the pd, so that the fastest format can be
// selected.
dnnl_post_ops_.append_binary(algo, expected_md);
break;
}
default:
break;
}
}
// if output is quantized, then append the eltwise linear to adjust the
// output scale/zero_point
if (dst.is_quantized()) {
// [Note: Gap of u8 qtensor scale between oneDNN and PyTorch]
// The /2 here is because the output_scale collected by the observer
// differs from the quantization requirement in oneDNN.
// For the Observer, the conv_scale (activation scale in other cases) is
// computed as 2 * max_v / (qmax - qmin), where max_v is collected
// from the tensor to be observed.
// (https://pytorch.org/docs/stable/generated/torch.quantization.observer.MinMaxObserver.html#torch.quantization.observer.MinMaxObserver)
// On the other hand, for u8 in oneDNN, the scale for quantization is
// defined as max_v / (qmax - qmin). Hence, we need to divide by 2 here.
// (https://oneapi-src.github.io/oneDNN/dev_guide_inference_int8.html)
dnnl_post_ops_.append_eltwise(
kind_with_linear, 1.f / q_scale_, q_zero_point_);
}
return dnnl_post_ops_;
}
bool with_sum() {
for (size_t i = 0; i < ops_params_.size(); ++i) {
if (ops_params_[i].kind_ == kind_t::sum) {
return true;
}
}
return false;
}
bool with_binary() {
for (size_t i = 0; i < ops_params_.size(); ++i) {
if (ops_params_[i].kind_ == kind_t::binary) {
return true;
}
}
return false;
}
void construct_post_binary(
dnnl::primitive_desc& pd,
std::unordered_map<int, dnnl::memory>& args) {
// This function is used to construct the binary memory desc for binary post ops.
// According to the oneDNN doc, the binary tensor can be in the shape of
// [1, 1, 1, 1], tensor broadcast
// [1, C, 1, 1], channel broadcast
// [dst.shape], no broadcast, element-wise binary operation on dst
auto engine =
GpuEngineManager::Instance().get_engine({c10::kXPU, c10::xpu::current_device()});
for (size_t i = 0; i < ops_params_.size(); ++i) {
kind_t kind = ops_params_[i].kind_;
if (kind == kind_t::binary) {
dnnl::memory binary_m;
auto binary = ops_params_[i].binary_;
auto md = ops_params_[i].meta_;
// query expected_md to achieve peak performance
auto expected_md = pd.query_md(
dnnl::query::exec_arg_md,
DNNL_ARG_ATTR_MULTIPLE_POST_OP(i) | DNNL_ARG_SRC_1);
binary_m = at::native::onednn::make_onednn_memory(
md, engine, binary.data_ptr()
);
args.insert(
{DNNL_ARG_ATTR_MULTIPLE_POST_OP(i) | DNNL_ARG_SRC_1, binary_m});
}
}
}
float q_scale_ = 1.0; // the scale used to quantize the fused result from fp32
// to int8, only works for int8 case
int64_t q_zero_point_ = 0;
std::vector<PostOpParam> ops_params_; // series of post ops
dnnl::post_ops dnnl_post_ops_;
};
} // namespace at::native::onednn
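For reference, the post-op chaining documented at the top of this header would be driven roughly as in the sketch below. This is illustrative only; the helper names and the ReLU fusion pattern are made up, while `Attr`, `append_post_eltwise`, and `extract_post_ops` are the members defined above.

```cpp
// Sketch: building post-op attributes for a hypothetical op + ReLU fusion.
using namespace at::native::onednn;

// fp32 path: default q_scale; with a non-quantized dst, extract_post_ops()
// appends no output-requantization op.
Attr make_relu_attr() {
  Attr attr;
  attr.append_post_eltwise(
      /*scale=*/1.f, /*alpha=*/0.f, /*beta=*/0.f, attr.kind_with_relu);
  return attr;
}

// int8 path: pass the output quantization scale; when the destination tensor
// is quantized, extract_post_ops(dst) appends the trailing eltwise_linear
// (1 / q_scale) itself.
Attr make_int8_relu_attr(float q_scale) {
  Attr attr(q_scale);
  attr.append_post_eltwise(1.f, 0.f, 0.f, attr.kind_with_relu);
  return attr;
}
```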


@@ -0,0 +1,244 @@
#include <c10/xpu/XPUFunctions.h>
#include <ATen/ATen.h>
#include <ATen/record_function.h>
#include <Attr.h>
#include <Utils.h>
#include <oneapi/dnnl/dnnl.hpp>
namespace at::native::onednn {
sycl::event matmul(
at::Tensor& result,
const at::Tensor& mat1,
const at::Tensor& mat2,
const at::Tensor& b_raw,
bool m2_trans,
Attr attr,
const std::vector<sycl::event>& deps) {
int64_t dims = result.dim();
TORCH_CHECK(
dims == 2 || dims == 3,
"oneDNN matmul only works with 2D or 3D, got ",
dims);
TORCH_CHECK(
dims == mat1.dim() && dims == mat2.dim(),
"oneDNN input matrices must have the same rank");
TORCH_CHECK(result.defined(), "oneDNN matmul result should be defined");
at::Device cur_device = at::Device(at::kXPU, c10::xpu::current_device());
auto engine = GpuEngineManager::Instance().get_engine(cur_device);
auto stream = GpuStreamManager::Instance().get_stream();
at::Tensor m1 = is_onednn_matmul_strides(mat1) ? mat1 : mat1.contiguous();
at::Tensor m2 = is_onednn_matmul_strides(mat2) ? mat2 : mat2.contiguous();
at::Tensor dst = is_onednn_matmul_strides(result, true) ? result : result.contiguous();
int64_t m = dst.size(-2);
int64_t n = dst.size(-1);
int64_t k = m1.size(-1);
int64_t mb = 1;
if (dims == 3) {
mb = dst.size(0);
TORCH_CHECK(
mb == m1.size(0) && mb == m2.size(0),
"batch size mismatch, dst mb: ",
mb,
" m1 mb: ",
m1.size(0),
" m2 mb: ",
m2.size(0));
}
// validate bias and make it compatible with oneDNN implementation
bool with_bias = false;
at::Tensor b = b_raw;
if (b.defined()) {
with_bias = true;
if (b.dim() == 1) {
TORCH_CHECK(
b.size(0) == n || b.size(0) == 1,
"matmul supports [n] or [1] when bias dim is 1 ...");
if (b.size(0) == 0) {
with_bias = false;
} else if (m1.dim() == 3) {
b = b.expand({mb, m, n}).contiguous();
} else if (m1.dim() == 2) {
b = b.expand({1, n}).contiguous();
}
} else if (b.dim() == 2) {
TORCH_CHECK(
(b.size(0) == m && b.size(1) == n) ||
(b.size(0) == 1 && b.size(1) == n) ||
(b.size(0) == m && b.size(1) == 1) ||
(b.size(0) == 1 && b.size(1) == 1),
"matmul supports [m, n] or [1, n] or [m, 1] or [1, 1] when bias dim is 2 ...");
if (b.size(0) == 1 && b.size(1) == 1)
b = b.expand({1, n}).contiguous();
} else if (b.dim() == 3) {
TORCH_CHECK(
at::are_expandable({mb, m, n}, b.sizes()),
"matmul bias must be expandable to:",
dst.sizes(),
" but got:",
b.sizes());
b = b.expand({mb, m, n}).contiguous();
} else if (b.dim() == 0) {
TORCH_CHECK(
b.numel() == 1, "matmul supports 1 numel when bias dim is [] ...");
if (m1.dim() == 3) {
b = b.expand({mb, m, n}).contiguous();
} else {
b = b.expand({1, n}).contiguous();
}
} else {
TORCH_CHECK(0, "unsupported bias dim in matmul ...");
}
}
b = b.contiguous(); // avoid reordering twice
// xpu matmul supports both ab/ba shapes for the m2 tensor, so no further check is needed
auto m1_usr_dt = get_onednn_dtype(m1);
auto m2_usr_dt = get_onednn_dtype(m2);
auto dst_usr_dt = get_onednn_dtype(dst);
auto m1_dt = m1_usr_dt;
auto m2_dt = m2_usr_dt;
auto dst_dt = dst_usr_dt;
dnnl::memory::data_type bias_dt;
dnnl::memory::desc m1_md, m1_usr_md, m1_any_md;
dnnl::memory::desc m2_md, m2_usr_md, m2_any_md;
dnnl::memory::desc dst_md, dst_usr_md, dst_any_md;
dnnl::memory::desc bias_md;
// Naive Master weight
if (m1_dt == dnnl::memory::data_type::bf16 && m2_dt == dnnl::memory::data_type::f32) {
m2_dt = dnnl::memory::data_type::bf16;
dst_dt = dnnl::memory::data_type::bf16;
} else if (
m1_dt == dnnl::memory::data_type::f32 && m2_dt == dnnl::memory::data_type::bf16) {
m1_dt = dnnl::memory::data_type::bf16;
dst_dt = dnnl::memory::data_type::bf16;
}
dnnl::memory::dims m1_dims, m2_dims, dst_dims, bias_dims;
dnnl::memory::dims m1_strides, m2_strides, dst_strides, bias_strides;
if (dims == 2) {
m1_dims = {m, k};
m2_dims = {k, n};
dst_dims = {m, n};
m1_strides = {m1.stride(0), m1.stride(1)};
if (m2_trans) {
m2_strides = {m2.stride(0), m2.stride(1)};
} else {
m2_strides = {m2.stride(1), m2.stride(0)};
}
dst_strides = {dst.stride(0), dst.stride(1)};
} else {
m1_dims = {mb, m, k};
m2_dims = {mb, k, n};
dst_dims = {mb, m, n};
m1_strides = {m1.stride(0), m1.stride(1), m1.stride(2)};
if (m2_trans) {
m2_strides = {m2.stride(0), m2.stride(1), m2.stride(2)};
} else {
m2_strides = {m2.stride(0), m2.stride(2), m2.stride(1)};
}
dst_strides = {dst.stride(0), dst.stride(1), dst.stride(2)};
}
if (with_bias) {
bias_dims = get_onednn_dims(b);
bias_dt = get_onednn_dtype(b);
bias_strides = get_onednn_strides(b);
}
dnnl::post_ops po = attr.extract_post_ops(dst);
std::unordered_map<int, dnnl::memory> args;
dnnl::matmul matmul_p;
dnnl::matmul::primitive_desc matmul_pd;
// STEP1: create memory desc
m1_md = dnnl::memory::desc(m1_dims, m1_dt, m1_strides);
m2_md = dnnl::memory::desc(m2_dims, m2_dt, m2_strides);
dst_md = dnnl::memory::desc(dst_dims, dst_dt, dst_strides);
// STEP2: create attribute
dnnl::primitive_attr pattr;
pattr.set_post_ops(po);
#if ONEDNN_SUPPORT_DETERMINISTIC
if(at::globalContext().deterministicAlgorithms())
pattr.set_deterministic(true);
#endif
// scratchpad
pattr.set_scratchpad_mode(dnnl::scratchpad_mode::user);
if (m1_dt == dnnl::memory::data_type::f32) {
pattr.set_fpmath_mode(dnnl::fpmath_mode::strict);
}
// STEP3: create primitive
if (with_bias) {
bias_md = dnnl::memory::desc(bias_dims, bias_dt, bias_strides);
matmul_pd =
dnnl::matmul::primitive_desc(engine, m1_md, m2_md, bias_md, dst_md, pattr);
} else {
matmul_pd = dnnl::matmul::primitive_desc(engine, m1_md, m2_md, dst_md, pattr);
}
matmul_p = dnnl::matmul(matmul_pd);
m1_usr_md = dnnl::memory::desc(m1_dims, m1_usr_dt, m1_strides);
m2_usr_md = dnnl::memory::desc(m2_dims, m2_usr_dt, m2_strides);
dst_usr_md = dnnl::memory::desc(dst_dims, dst_usr_dt, dst_strides);
// STEP4: create memory
auto m1_usr_m = make_onednn_memory(m1_usr_md, engine, m1.data_ptr());
auto m2_usr_m = make_onednn_memory(m2_usr_md, engine, m2.data_ptr());
auto dst_usr_m = make_onednn_memory(dst_usr_md, engine, dst.data_ptr());
auto expected_m1_md = matmul_pd.src_desc();
auto expected_m2_md = matmul_pd.weights_desc();
auto expected_dst_md = matmul_pd.dst_desc();
dnnl::memory m1_m = m1_usr_m, m2_m = m2_usr_m, dst_m = dst_usr_m;
at::Tensor m1_, m2_, dst_;
if (attr.with_binary())
attr.construct_post_binary(matmul_pd, args);
size_t scratchpad_size = matmul_pd.scratchpad_desc().get_size();
at::Tensor scratchpad_tensor = at::empty(
{static_cast<int64_t>(scratchpad_size)}, m1.options().dtype(at::kByte), c10::nullopt);
auto scratchpad_memory = make_onednn_memory(
matmul_pd.scratchpad_desc(), engine, scratchpad_tensor.data_ptr());
args.insert({DNNL_ARG_SCRATCHPAD, scratchpad_memory});
args.insert({DNNL_ARG_SRC, m1_m});
args.insert({DNNL_ARG_WEIGHTS, m2_m});
args.insert({DNNL_ARG_DST, dst_m});
if (with_bias) {
auto bias_m = make_onednn_memory(bias_md, engine, b.data_ptr());
args.insert({DNNL_ARG_BIAS, bias_m});
}
sycl::event matmul_event = dnnl::sycl_interop::execute(matmul_p, stream, args, deps);
if (!dst.is_same(result))
result.copy_(dst);
return matmul_event;
}
} // namespace at::native::onednn
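As an illustration of how an `Attr` epilogue flows through this function, a caller fusing GELU into the GEMM might look roughly like the sketch below. The wrapper name is hypothetical; the call itself only uses the `matmul` and `Attr` APIs shown in this PR.

```cpp
// Sketch: GEMM with a fused tanh-approximated GELU epilogue.
sycl::event gemm_gelu_sketch(
    at::Tensor& out,          // preallocated (m, n) or (b, m, n) output
    const at::Tensor& a,      // (m, k) or (b, m, k)
    const at::Tensor& b,      // (k, n) or (b, k, n)
    const at::Tensor& bias) { // 1-D bias of length n (other shapes per the checks above)
  at::native::onednn::Attr attr;
  attr.append_post_eltwise(
      /*scale=*/1.f, /*alpha=*/0.f, /*beta=*/0.f, attr.kind_with_gelu_tanh);
  // b is consumed in its stored (k, n) layout, so m2_trans = true.
  return at::native::onednn::matmul(out, a, b, bias, /*m2_trans=*/true, attr);
}
```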


@@ -0,0 +1,352 @@
#include <ATen/native/mkldnn/xpu/detail/Utils.h>
namespace at::native::onednn {
dnnl::memory make_onednn_memory(
dnnl::memory::desc md,
dnnl::engine& engine,
void* ptr){
return dnnl::sycl_interop::make_memory(
md,
engine,
dnnl::sycl_interop::memory_kind::usm,
ptr == nullptr ? DNNL_MEMORY_ALLOCATE : ptr);
}
dnnl::memory::format_tag get_dnnl_default_format(
int ndims,
bool is_channels_last,
bool allow_undef) {
switch (ndims) {
case 1:
return dnnl::memory::format_tag::a;
case 2:
return dnnl::memory::format_tag::ab;
case 3:
return is_channels_last ? dnnl::memory::format_tag::acb
: dnnl::memory::format_tag::abc;
case 4:
return is_channels_last ? dnnl::memory::format_tag::acdb
: dnnl::memory::format_tag::abcd;
case 5:
return is_channels_last ? dnnl::memory::format_tag::acdeb
: dnnl::memory::format_tag::abcde;
case 6:
return dnnl::memory::format_tag::abcdef;
case 7:
return dnnl::memory::format_tag::abcdefg;
case 8:
return dnnl::memory::format_tag::abcdefgh;
case 9:
return dnnl::memory::format_tag::abcdefghi;
case 10:
return dnnl::memory::format_tag::abcdefghij;
case 11:
return dnnl::memory::format_tag::abcdefghijk;
case 12:
return dnnl::memory::format_tag::abcdefghijkl;
default:
if (!allow_undef) {
TORCH_CHECK(false, "oneDNN doesn't support tensor dimension > 12");
}
return dnnl::memory::format_tag::undef;
}
}
dnnl::memory::data_type get_onednn_dtype(
const at::Tensor& tensor,
bool allow_undef) {
switch (tensor.scalar_type()) {
case at::ScalarType::Byte:
return dnnl::memory::data_type::u8;
case at::ScalarType::Char:
return dnnl::memory::data_type::s8;
case at::ScalarType::QInt8:
return dnnl::memory::data_type::s8;
case at::ScalarType::QUInt8:
return dnnl::memory::data_type::u8;
case at::ScalarType::Int:
return dnnl::memory::data_type::s32;
case at::ScalarType::Half:
return dnnl::memory::data_type::f16;
case at::ScalarType::Float:
return dnnl::memory::data_type::f32;
case at::ScalarType::BFloat16:
return dnnl::memory::data_type::bf16;
default:
if (!allow_undef) {
TORCH_CHECK(
false,
c10::toString(tensor.scalar_type()),
" is not supported in oneDNN!");
}
return dnnl::memory::data_type::undef;
};
}
dnnl::memory::data_type get_onednn_dtype_include_double(
const at::Tensor& tensor,
bool allow_undef) {
if (tensor.scalar_type() == at::ScalarType::Double)
return dnnl::memory::data_type::f64;
return get_onednn_dtype(tensor, allow_undef);
}
bool is_supported_onednn_dtype(const at::Tensor& tensor) {
return get_onednn_dtype(tensor, /*allow_undef*/ true) ==
dnnl::memory::data_type::undef
? false
: true;
}
dnnl::memory::dims get_onednn_dims(const at::Tensor& tensor) {
dnnl::memory::dims dims;
for (size_t i = 0; i < tensor.sizes().size(); i++)
dims.push_back(tensor.size(i));
return dims;
}
dnnl::memory::dims get_onednn_strides(const at::Tensor& tensor) {
dnnl::memory::dims strides;
for (size_t i = 0; i < tensor.strides().size(); i++)
strides.push_back(tensor.stride(i));
return strides;
}
dnnl::memory::desc get_onednn_md(const at::Tensor& tensor) {
return {
get_onednn_dims(tensor),
get_onednn_dtype(tensor),
get_onednn_strides(tensor)};
}
bool onednn_strides_check(const at::Tensor& src) {
auto adims = get_onednn_dims(src);
int ndims = (int)adims.size();
auto dims = adims.data();
auto data_type = static_cast<dnnl_data_type_t>(
get_onednn_dtype(src, /*allow_undef*/ true));
auto strides_info = get_onednn_strides(src);
auto strides = strides_info.empty() ? nullptr : &strides_info[0];
dnnl_memory_desc_t md;
dnnl_memory_desc_create_with_strides(&md, ndims, dims, data_type, strides);
dnnl_format_kind_t md_fmt_kind;
int md_ndims;
int md_inner_nblks;
dnnl_dims_t* md_padded_dims = nullptr;
dnnl_memory_desc_query(md, dnnl_query_inner_nblks_s32, &md_inner_nblks);
dnnl_memory_desc_query(md, dnnl_query_format_kind, &md_fmt_kind);
dnnl_memory_desc_query(md, dnnl_query_ndims_s32, &md_ndims);
dnnl_memory_desc_query(md, dnnl_query_padded_dims, &md_padded_dims);
if (strides == nullptr || md_ndims == 0 ||
md_fmt_kind != dnnl_format_kind_t::dnnl_blocked)
return true;
dnnl_dims_t blocks = {0};
int perm[DNNL_MAX_NDIMS] = {0};
for (int d = 0; d < md_ndims; ++d) {
// no strides check needed for empty tensor
if (md_padded_dims[d] == 0)
return true;
// no strides verification for runtime dims
if (strides[d] == DNNL_RUNTIME_DIM_VAL)
return true;
perm[d] = d;
blocks[d] = 1;
}
auto block_size = 1;
dnnl_dims_t md_inner_blks;
dnnl_dims_t md_blk_inner_idxs;
dnnl_memory_desc_query(md, dnnl_query_inner_idxs, &md_blk_inner_idxs);
dnnl_memory_desc_query(md, dnnl_query_inner_blks, &md_inner_blks);
for (int iblk = 0; iblk < md_inner_nblks; ++iblk) {
blocks[md_blk_inner_idxs[iblk]] *= md_inner_blks[iblk];
block_size *= md_inner_blks[iblk];
}
// A custom comparator to yield linear order on perm
auto idx_sorter = [&](const int a, const int b) -> bool {
if (strides[a] == strides[b] && md_padded_dims[a] == md_padded_dims[b])
return a < b;
else if (strides[a] == strides[b])
return md_padded_dims[a] < md_padded_dims[b];
else
return strides[a] < strides[b];
};
std::sort(perm, perm + md_ndims, idx_sorter);
auto min_stride = block_size;
for (int idx = 0; idx < md_ndims; ++idx) {
const int d = perm[idx];
// Make an exception for strides[d] == 0 as it has broadcast semantics
// Note: owing to being sorted, these are the initial strides
if (strides[d] == 0)
continue;
else if (strides[d] < min_stride)
return false;
// update min_stride for next iteration
const auto padded_dim = *md_padded_dims[d];
min_stride = block_size * strides[d] * (padded_dim / blocks[d]);
}
return true;
}
bool is_broadcast(const at::Tensor& t) {
for (int i = 0; i < t.dim(); i++) {
if (t.stride(i) == 0)
return true;
}
return false;
}
bool is_onednn_matmul_strides(
const at::Tensor& tensor,
bool is_dst) {
// https://oneapi-src.github.io/oneDNN/dev_guide_matmul.html
// oneDNN matmul only supports 2-dim and 3-dim tensors:
// 2D: src (MxK), wei (KxN), dst (MxN)
// 3D: src (BxMxK), wei (BxKxN), dst (BxMxN)
auto sizes = tensor.sizes();
auto tensor_dim = sizes.size();
if (tensor_dim != 2 && tensor_dim != 3)
return false;
if (tensor.is_contiguous())
return true;
// overlapping cases are not supported
dnnl::memory::dims strides = get_onednn_strides(tensor);
int64_t storage_size = 1;
for (size_t dim = 0; dim < tensor_dim; ++dim)
storage_size += (sizes[dim] - 1) * strides[dim];
if (storage_size < tensor.numel())
return false;
// the broadcast cases are not supported
if (is_broadcast(tensor)) {
return false;
}
if (is_dst) {
// The memory format of the destination tensor should always
// be plain with n axis contiguous
if (strides[tensor_dim - 1] != 1)
return false;
} else {
// the src and weight tensors must each have at least one contiguous
// axis (stride = 1): m or k for src, and k or n for weight.
if (strides[tensor_dim - 1] != 1 && strides[tensor_dim - 2] != 1)
return false;
}
if (!onednn_strides_check(tensor))
return false;
return true;
}
bool is_broadcast_from_other_to_self(
const at::Tensor& self,
const at::Tensor& other) {
return (
self.sizes() != other.sizes() &&
at::is_expandable_to(other.sizes(), self.sizes()));
}
at::MemoryFormat get_cl_tag_by_ndim(const int64_t ndim) {
TORCH_CHECK(
3 == ndim || 4 == ndim || 5 == ndim,
"ndim must be 3, 4 or 5 when getting the channels-last tag");
if (3 == ndim) {
return at::MemoryFormat::Contiguous;
} else if (5 == ndim) {
return at::MemoryFormat::ChannelsLast3d;
} else {
return at::MemoryFormat::ChannelsLast;
}
}
bool binary_valid(
const at::Tensor& self,
const at::Tensor& other,
bool is_fusion) {
if (self.sizes() != other.sizes() &&
!is_broadcast_from_other_to_self(self, other))
return false;
/* The oneDNN path is selected only if the following conditions are all satisfied:
 * 1. self and other should be defined xpu tensors.
 * 2. neither self nor other should be a scalar (wrapped tensor).
 * 3. self and other should have the same dim, which must be larger than 0 and smaller than 7.
 * 4. the datatype should be supported by the oneDNN primitive.
 * 5. self and other should have the same datatype.
 * 6. self and other should be contiguous or channels-last contiguous. */
// 1. self and other should be defined xpu tensors.
if ((!self.defined()) || (!other.defined()) || (!self.is_xpu()) ||
(!other.is_xpu()))
return false;
// 2. self or other should not be scalar (wrapped tensor).
if (self.unsafeGetTensorImpl()->is_wrapped_number() || other.unsafeGetTensorImpl()->is_wrapped_number())
return false;
// 3. dim of self and other should be equal and must be larger than 0 and
// smaller than 7.
if ((self.dim() <= 0) || (other.dim() <= 0) || (self.dim() != other.dim()) ||
(self.dim() > 6) || (other.dim() > 6))
return false;
// 4. the datatype should be supported by oneDNN primitive.
switch (self.scalar_type()) {
case at::ScalarType::Char:
break;
case at::ScalarType::Byte:
break;
case at::ScalarType::Half:
break;
case at::ScalarType::Float:
break;
case at::ScalarType::BFloat16:
break;
default:
return false;
};
// 5. datatype check
if (is_fusion) {
// for the fusion case, the fusion can be performed when other has the
// same scalar_type as self or is of Float datatype.
if (self.scalar_type() != other.scalar_type() &&
other.scalar_type() != at::ScalarType::Float) {
return false;
}
} else {
if (self.scalar_type() != other.scalar_type()) {
// for non-fusion case: self and other should be in the same datatype.
return false;
}
}
// 6. self and other should be contiguous or channel-last contiguous.
const auto ndim = self.ndimension();
auto cl_tag = at::MemoryFormat::ChannelsLast;
if (3 == ndim || 4 == ndim || 5 == ndim) {
cl_tag = get_cl_tag_by_ndim(ndim);
}
if ((self.is_contiguous() && other.is_contiguous()) ||
(self.is_contiguous(cl_tag) && other.is_contiguous(cl_tag)))
return true;
return false;
}
}
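These helpers compose as in the following sketch, which wraps an XPU tensor as a oneDNN USM memory object. It is illustrative only; the function name is made up, while the includes are the detail headers added in this PR and every call is declared in them.

```cpp
#include <ATen/native/mkldnn/xpu/detail/Utils.h>
#include <ATen/native/mkldnn/xpu/detail/oneDNNContext.h>

// Sketch: turn an XPU tensor into a dnnl::memory bound to the device engine.
dnnl::memory wrap_tensor_sketch(const at::Tensor& t) {
  using namespace at::native::onednn;
  auto& engine = GpuEngineManager::Instance().get_engine(
      {c10::kXPU, c10::xpu::current_device()});
  // Fall back to a contiguous copy if oneDNN cannot consume the strides;
  // in real code the copy must outlive the returned memory object.
  at::Tensor src = onednn_strides_check(t) ? t : t.contiguous();
  dnnl::memory::desc md = get_onednn_md(src);
  return make_onednn_memory(md, engine, src.data_ptr());
}
```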


@@ -0,0 +1,56 @@
#pragma once
#include <iostream>
#include <ATen/ATen.h>
#include <ATen/Tensor.h>
#include <ATen/core/Tensor.h>
#include <ATen/core/grad_mode.h>
#include <c10/core/MemoryFormat.h>
#include <oneapi/dnnl/dnnl.hpp>
#include <oneapi/dnnl/dnnl_sycl.hpp>
#include <oneapi/dnnl/dnnl_version.h>
#define ONEDNN_SUPPORT_DETERMINISTIC (DNNL_VERSION_MAJOR >=3 && DNNL_VERSION_MINOR >=4)
namespace at::native::onednn {
dnnl::memory::format_tag get_dnnl_default_format(
int ndims,
bool is_channels_last = false,
bool allow_undef = false);
dnnl::memory::data_type get_onednn_dtype(
const at::Tensor& tensor,
bool allow_undef = false);
dnnl::memory::data_type get_onednn_dtype_include_double(
const at::Tensor& tensor,
bool allow_undef = false);
bool is_supported_onednn_dtype(const at::Tensor& tensor);
dnnl::memory::dims get_onednn_dims(const at::Tensor& tensor);
dnnl::memory::dims get_onednn_strides(const at::Tensor& tensor);
dnnl::memory::desc get_onednn_md(const at::Tensor& tensor);
bool onednn_strides_check(const at::Tensor& src);
bool is_broadcast(const at::Tensor& t);
bool is_onednn_matmul_strides(
const at::Tensor& tensor,
bool is_dst = false);
bool is_broadcast_from_other_to_self(
const at::Tensor& self,
const at::Tensor& other);
at::MemoryFormat get_cl_tag_by_ndim(const int64_t ndim);
bool binary_valid(
const at::Tensor& self,
const at::Tensor& other,
bool is_fusion = false);
} // namespace at::native::onednn


@@ -0,0 +1,20 @@
#pragma once
#include <ATen/ATen.h>
#include <ATen/native/mkldnn/xpu/detail/oneDNNContext.h>
#include <ATen/native/mkldnn/xpu/detail/Attr.h>
#include <ATen/native/mkldnn/xpu/detail/Utils.h>
namespace at::native::onednn{
TORCH_API sycl::event matmul(
at::Tensor& result,
const at::Tensor& mat1,
const at::Tensor& mat2,
const at::Tensor& b_raw,
bool m2_trans,
Attr attr,
const std::vector<sycl::event>& deps = {});
} // namespace at::native::onednn
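The `deps` argument can be used to order dependent GEMMs explicitly, as in this sketch. The wrapper name is hypothetical and shapes are assumed compatible; the event-chaining itself is just the `deps` parameter declared above.

```cpp
// Sketch: run two dependent GEMMs, passing the first event as a dependency
// of the second so oneDNN orders them on the SYCL queue.
sycl::event chained_matmul_sketch(
    at::Tensor& out1, at::Tensor& out2,
    const at::Tensor& a, const at::Tensor& b, const at::Tensor& c,
    const at::Tensor& bias1, const at::Tensor& bias2) {
  at::native::onednn::Attr attr;
  sycl::event e1 = at::native::onednn::matmul(
      out1, a, b, bias1, /*m2_trans=*/true, attr);
  // out1 feeds the second GEMM; hand e1 over via deps.
  return at::native::onednn::matmul(
      out2, out1, c, bias2, /*m2_trans=*/true, attr, {e1});
}
```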


@@ -0,0 +1,27 @@
#include <ATen/native/mkldnn/xpu/detail/oneDNNContext.h>
#include <ATen/native/mkldnn/xpu/detail/Utils.h>
/*
 * Do NOT put any kernels or call any device binaries here!
 * Only maintain oneDNN runtime states in this file.
 */
namespace at::native::onednn {
using namespace dnnl;
GpuEngineManager& GpuEngineManager::Instance() {
static GpuEngineManager myInstance;
return myInstance;
}
GpuStreamManager& GpuStreamManager::Instance() {
static thread_local GpuStreamManager myInstance;
return myInstance;
}
bool set_onednn_verbose(int level) {
dnnl::status rs = dnnl::set_verbose(level);
return rs == dnnl::status::success;
}
} // namespace at::native::onednn


@@ -0,0 +1,75 @@
#pragma once
#include <ATen/Config.h>
#include <c10/core/Device.h>
#include <c10/xpu/XPUFunctions.h>
#include <c10/xpu/XPUStream.h>
#include <oneapi/dnnl/dnnl.hpp>
#include <oneapi/dnnl/dnnl_sycl.hpp>
#include <vector>
namespace at::native::onednn {
TORCH_API dnnl::memory make_onednn_memory(
dnnl::memory::desc md,
dnnl::engine& engine,
void* ptr);
// Keep non-static and non-inline
bool set_onednn_verbose(int level);
// GpuEngineManager singleton
struct TORCH_API GpuEngineManager {
static GpuEngineManager& Instance(); // Singleton
dnnl::engine& get_engine(const Device& device) {
TORCH_INTERNAL_ASSERT(device.type() == kXPU);
TORCH_INTERNAL_ASSERT(device.index() < c10::xpu::device_count());
return *engine_pool[device.index()];
}
GpuEngineManager(GpuEngineManager const&) = delete;
GpuEngineManager& operator=(GpuEngineManager const&) = delete;
protected:
GpuEngineManager() {
int device_count = (int)c10::xpu::device_count();
TORCH_INTERNAL_ASSERT(device_count > 0);
for (int i = 0; i < device_count; i++) {
engine_pool.push_back(
std::make_shared<dnnl::engine>(dnnl::sycl_interop::make_engine(
c10::xpu::get_raw_device(i), c10::xpu::get_device_context()
)));
}
}
~GpuEngineManager() {}
private:
std::vector<std::shared_ptr<dnnl::engine>> engine_pool;
};
// GpuStreamManager singleton
struct TORCH_API GpuStreamManager {
static GpuStreamManager& Instance(); // Singleton
dnnl::stream get_stream() {
c10::DeviceIndex device_index = c10::xpu::current_device();
TORCH_INTERNAL_ASSERT(device_index < c10::xpu::device_count());
return dnnl::sycl_interop::make_stream(
GpuEngineManager::Instance().get_engine({c10::kXPU, device_index}),
c10::xpu::getCurrentXPUStream(device_index).queue());
}
GpuStreamManager(GpuStreamManager const&) = delete;
GpuStreamManager& operator=(GpuStreamManager const&) = delete;
protected:
GpuStreamManager() {
}
~GpuStreamManager() {}
};
} // namespace at::native::onednn
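Putting the two singletons together, a new integration entry point would typically obtain its runtime state as in the sketch below. It is illustrative only: the function name is made up and the reorder is just an example primitive, while the engine/stream managers, `make_onednn_memory`, `get_onednn_md`, and `set_onednn_verbose` are the APIs declared in the files above.

```cpp
#include <ATen/native/mkldnn/xpu/detail/Utils.h>
#include <ATen/native/mkldnn/xpu/detail/oneDNNContext.h>

// Sketch: copy one XPU tensor into another through a oneDNN reorder,
// using the engine/stream managers declared above.
sycl::event reorder_copy_sketch(const at::Tensor& src, at::Tensor& dst) {
  using namespace at::native::onednn;
  at::Device dev(at::kXPU, c10::xpu::current_device());
  auto& engine = GpuEngineManager::Instance().get_engine(dev);
  auto stream = GpuStreamManager::Instance().get_stream();

  auto src_m = make_onednn_memory(get_onednn_md(src), engine, src.data_ptr());
  auto dst_m = make_onednn_memory(get_onednn_md(dst), engine, dst.data_ptr());
  dnnl::reorder reorder_p(src_m, dst_m);
  // Optional: enable oneDNN verbose logging (level 1 prints primitive execution).
  // set_onednn_verbose(1);
  return dnnl::sycl_interop::execute(
      reorder_p, stream, {{DNNL_ARG_FROM, src_m}, {DNNL_ARG_TO, dst_m}});
}
```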