mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Upgrade MKL-DNN to DNNL v1.2 (#32422)
Summary: ## Motivation This PR upgrades MKL-DNN from v0.20 to DNNL v1.2 and resolves https://github.com/pytorch/pytorch/issues/30300. DNNL (Deep Neural Network Library) is the new brand of MKL-DNN, which improves performance, quality, and usability over the old version. This PR focuses on the migration of all existing functionalities, including minor fixes, performance improvement and code clean up. It serves as the cornerstone of our future efforts to accommodate new features like OpenCL support, BF16 training, INT8 inference, etc. and to let the Pytorch community derive more benefits from the Intel Architecture. <br> ## What's included? Even DNNL has many breaking changes to the API, we managed to absorb most of them in ideep. This PR contains minimalist changes to the integration code in pytorch. Below is a summary of the changes: <br> **General:** 1. Replace op-level allocator with global-registered allocator ``` // before ideep::sum::compute<AllocForMKLDNN>(scales, {x, y}, z); // after ideep::sum::compute(scales, {x, y}, z); ``` The allocator is now being registeted at `aten/src/ATen/native/mkldnn/IDeepRegistration.cpp`. Thereafter all tensors derived from the `cpu_engine` (by default) will use the c10 allocator. ``` RegisterEngineAllocator cpu_alloc( ideep::engine::cpu_engine(), [](size_t size) { return c10::GetAllocator(c10::DeviceType::CPU)->raw_allocate(size); }, [](void* p) { c10::GetAllocator(c10::DeviceType::CPU)->raw_deallocate(p); } ); ``` ------ 2. Simplify group convolution We had such a scenario in convolution where ideep tensor shape mismatched aten tensor: when `groups > 1`, DNNL expects weights tensors to be 5-d with an extra group dimension, e.g. `goihw` instead of `oihw` in 2d conv case. As shown below, a lot of extra checks came with this difference in shape before. Now we've completely hidden this difference in ideep and all tensors are going to align with pytorch's definition. So we could safely remove these checks from both aten and c2 integration code. ``` // aten/src/ATen/native/mkldnn/Conv.cpp if (w.ndims() == x.ndims() + 1) { AT_ASSERTM( groups > 1, "Only group _mkldnn_conv2d weights could have been reordered to 5d"); kernel_size[0] = w.get_dim(0) * w.get_dim(1); std::copy_n( w.get_dims().cbegin() + 2, x.ndims() - 1, kernel_size.begin() + 1); } else { std::copy_n(w.get_dims().cbegin(), x.ndims(), kernel_size.begin()); } ``` ------ 3. Enable DNNL built-in cache Previously, we stored DNNL jitted kernels along with intermediate buffers inside ideep using an LRU cache. Now we are switching to the newly added DNNL built-in cache, and **no longer** caching buffers in order to reduce memory footprint. This change will be mainly reflected in lower memory usage from memory profiling results. On the code side, we removed couple of lines of `op_key_` that depended on the ideep cache before. ------ 4. Use 64-bit integer to denote dimensions We changed the type of `ideep::dims` from `vector<int32_t>` to `vector<int64_t>`. This renders ideep dims no longer compatible with 32-bit dims used by caffe2. So we use something like `{stride_.begin(), stride_.end()}` to cast parameter `stride_` into a int64 vector. <br> **Misc changes in each commit:** **Commit:** change build options Some build options were slightly changed, mainly to avoid name collisions with other projects that include DNNL as a subproject. In addition, DNNL built-in cache is enabled by option `DNNL_ENABLE_PRIMITIVE_CACHE`. Old | New -- | -- WITH_EXAMPLE | MKLDNN_BUILD_EXAMPLES WITH_TEST | MKLDNN_BUILD_TESTS MKLDNN_THREADING | MKLDNN_CPU_RUNTIME MKLDNN_USE_MKL | N/A (not use MKL anymore) ------ **Commit:** aten reintegration - aten/src/ATen/native/mkldnn/BinaryOps.cpp Implement binary ops using new operation `binary` provided by DNNL - aten/src/ATen/native/mkldnn/Conv.cpp Clean up group convolution checks Simplify conv backward integration - aten/src/ATen/native/mkldnn/MKLDNNConversions.cpp Simplify prepacking convolution weights - test/test_mkldnn.py Fixed an issue in conv2d unit test: it didn't check conv results between mkldnn and aten implementation before. Instead, it compared the mkldnn with mkldnn as the default cpu path will also go into mkldnn. Now we use `torch.backends.mkldnn.flags` to fix this issue - torch/utils/mkldnn.py Prepack weight tensor on module `__init__` to achieve better performance significantly ------ **Commit:** caffe2 reintegration - caffe2/ideep/ideep_utils.h Clean up unused type definitions - caffe2/ideep/operators/adam_op.cc & caffe2/ideep/operators/momentum_sgd_op.cc Unify tensor initialization with `ideep::tensor::init`. Obsolete `ideep::tensor::reinit` - caffe2/ideep/operators/conv_op.cc & caffe2/ideep/operators/quantization/int8_conv_op.cc Clean up group convolution checks Revamp convolution API - caffe2/ideep/operators/conv_transpose_op.cc Clean up group convolution checks Clean up deconv workaround code ------ **Commit:** custom allocator - Register c10 allocator as mentioned above <br><br> ## Performance We tested inference on some common models based on user scenarios, and most performance numbers are either better than or on par with DNNL 0.20. ratio: new / old | Latency (batch=1 4T) | Throughput (batch=64 56T) -- | -- | -- pytorch resnet18 | 121.4% | 99.7% pytorch resnet50 | 123.1% | 106.9% pytorch resnext101_32x8d | 116.3% | 100.1% pytorch resnext50_32x4d | 141.9% | 104.4% pytorch mobilenet_v2 | 163.0% | 105.8% caffe2 alexnet | 303.0% | 99.2% caffe2 googlenet-v3 | 101.1% | 99.2% caffe2 inception-v1 | 102.2% | 101.7% caffe2 mobilenet-v1 | 356.1% | 253.7% caffe2 resnet101 | 100.4% | 99.8% caffe2 resnet152 | 99.8% | 99.8% caffe2 shufflenet | 141.1% | 69.0% † caffe2 squeezenet | 98.5% | 99.2% caffe2 vgg16 | 136.8% | 100.6% caffe2 googlenet-v3 int8 | 100.0% | 100.7% caffe2 mobilenet-v1 int8 | 779.2% | 943.0% caffe2 resnet50 int8 | 99.5% | 95.5% _Configuration: Platform: Skylake 8180 Latency Test: 4 threads, warmup 30, iteration 500, batch size 1 Throughput Test: 56 threads, warmup 30, iteration 200, batch size 64_ † Shufflenet is one of the few models that require temp buffers during inference. The performance degradation is an expected issue since we no longer cache any buffer in the ideep. As for the solution, we suggest users opt for caching allocator like **jemalloc** as a drop-in replacement for system allocator in such heavy workloads. Pull Request resolved: https://github.com/pytorch/pytorch/pull/32422 Test Plan: Perf results: https://our.intern.facebook.com/intern/fblearner/details/177790608?tab=Experiment%20Results 10% improvement for ResNext with avx512, neutral on avx2 More results: https://fb.quip.com/ob10AL0bCDXW#NNNACAUoHJP Reviewed By: yinghai Differential Revision: D20381325 Pulled By: dzhulgakov fbshipit-source-id: 803b906fd89ed8b723c5fcab55039efe3e4bcb77
This commit is contained in:
committed by
Facebook GitHub Bot
parent
8240db11e1
commit
bd604cb5b7
@ -1,5 +0,0 @@
|
||||
#include <ATen/mkldnn/Runtime.h>
|
||||
|
||||
namespace at { namespace native {
|
||||
|
||||
}} // namespace at::native
|
@ -1,49 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#include <mkldnn.hpp>
|
||||
|
||||
using namespace mkldnn;
|
||||
|
||||
namespace at { namespace native {
|
||||
|
||||
// CpuEngine singleton
|
||||
struct CpuEngine {
|
||||
static CpuEngine& Instance() {
|
||||
static CpuEngine myInstance;
|
||||
return myInstance;
|
||||
}
|
||||
engine& get_engine() {
|
||||
return _cpu_engine;
|
||||
}
|
||||
CpuEngine(CpuEngine const&) = delete;
|
||||
CpuEngine& operator=(CpuEngine const&) = delete;
|
||||
|
||||
protected:
|
||||
CpuEngine():_cpu_engine(mkldnn::engine::cpu, 0) {}
|
||||
~CpuEngine() {}
|
||||
|
||||
private:
|
||||
engine _cpu_engine;
|
||||
};
|
||||
|
||||
// Stream singleton
|
||||
struct Stream {
|
||||
static Stream& Instance() {
|
||||
static thread_local Stream myInstance;
|
||||
return myInstance;
|
||||
};
|
||||
stream& get_stream() {
|
||||
return _cpu_stream;
|
||||
}
|
||||
Stream(Stream const&) = delete;
|
||||
Stream& operator=(Stream const&) = delete;
|
||||
|
||||
protected:
|
||||
Stream():_cpu_stream(mkldnn::stream::kind::eager) {}
|
||||
~Stream() {}
|
||||
|
||||
private:
|
||||
stream _cpu_stream;
|
||||
};
|
||||
|
||||
}} // namespace at::native
|
@ -390,7 +390,7 @@ auto ConvParams::use_cudnn_depthwise(
|
||||
|
||||
static void check_shape_forward(const at::Tensor& input,
|
||||
const c10::IntArrayRef& weight_sizes, const at::Tensor& bias,
|
||||
const ConvParams& params, bool input_is_mkldnn) {
|
||||
const ConvParams& params) {
|
||||
int64_t k = input.ndimension();
|
||||
int64_t weight_dim = weight_sizes.size();
|
||||
int64_t groups = params.groups;
|
||||
@ -564,18 +564,7 @@ at::Tensor _convolution(
|
||||
auto weight = weight_r;
|
||||
auto bias = bias_r;
|
||||
auto k = weight.ndimension();
|
||||
// mkldnn conv2d weights could have been re-ordered to 5d by
|
||||
// mkldnn_reorder_conv2d_weight
|
||||
std::vector<int64_t> weight_sizes_mkl;
|
||||
c10::IntArrayRef weight_sizes = weight.sizes();
|
||||
if (input_is_mkldnn && (k == input.ndimension() + 1)) {
|
||||
k = input.ndimension();
|
||||
weight_sizes_mkl.resize(k);
|
||||
weight_sizes_mkl[0] = weight.size(0) * weight.size(1);
|
||||
std::copy_n(
|
||||
weight.sizes().cbegin() + 2, k - 1, weight_sizes_mkl.begin() + 1);
|
||||
weight_sizes = c10::IntArrayRef(weight_sizes_mkl);
|
||||
}
|
||||
int64_t dim = k - 2;
|
||||
|
||||
|
||||
@ -592,7 +581,7 @@ at::Tensor _convolution(
|
||||
params.deterministic = deterministic;
|
||||
params.cudnn_enabled = cudnn_enabled;
|
||||
|
||||
check_shape_forward(input, weight_sizes, bias, params, input_is_mkldnn);
|
||||
check_shape_forward(input, weight_sizes, bias, params);
|
||||
|
||||
if (input.size(0) == 0) {
|
||||
// don't send empty inputs through backends
|
||||
|
@ -55,7 +55,7 @@ Tensor& mkldnn_add_out(
|
||||
|
||||
ideep::tensor& z = itensor_from_mkldnn(result);
|
||||
const std::vector<float> scales{1.0, alpha.to<float>()};
|
||||
ideep::sum::compute<AllocForMKLDNN>(scales, {x, y}, z);
|
||||
ideep::sum::compute(scales, {x, y}, z);
|
||||
|
||||
return result;
|
||||
}
|
||||
@ -66,7 +66,7 @@ Tensor mkldnn_add(const Tensor& self, const Tensor& other, Scalar alpha) {
|
||||
|
||||
ideep::tensor z;
|
||||
const std::vector<float> scales{1.0, alpha.to<float>()};
|
||||
ideep::sum::compute<AllocForMKLDNN>(scales, {x, y}, z);
|
||||
ideep::sum::compute(scales, {x, y}, z);
|
||||
|
||||
return new_with_itensor_mkldnn(std::move(z), self.options());
|
||||
}
|
||||
@ -83,7 +83,7 @@ Tensor& mkldnn_mul_out(Tensor& result, const Tensor& self, const Tensor& other)
|
||||
|
||||
// for zero_dim tensor
|
||||
if (other.ndimension() == 0) {
|
||||
ideep::eltwise_forward::compute<AllocForMKLDNN>(
|
||||
ideep::eltwise_forward::compute(
|
||||
x, z, ideep::algorithm::eltwise_linear,
|
||||
ideep::prop_kind::forward_inference, /*alpha*/ other.item().to<float>());
|
||||
|
||||
@ -92,8 +92,7 @@ Tensor& mkldnn_mul_out(Tensor& result, const Tensor& self, const Tensor& other)
|
||||
AT_ASSERTM(self.sizes() == other.sizes(),
|
||||
"mkldnn_mul_out: currently mkldnn not support broadcasting");
|
||||
ideep::tensor y = itensor_from_mkldnn(other);
|
||||
auto op = ideep::eltwise_binary::eltwise_binary_op::ELTWISE_MUL;
|
||||
ideep::eltwise_binary::compute<AllocForMKLDNN>(op, x, y, z);
|
||||
ideep::binary::compute(x, y, z, dnnl::algorithm::binary_mul);
|
||||
|
||||
return result;
|
||||
}
|
||||
|
@ -34,13 +34,10 @@ std::tuple<at::Tensor,at::Tensor,at::Tensor> mkldnn_convolution_backward(
|
||||
|
||||
#else // AT_MKLDNN_EBABLED
|
||||
|
||||
#include <ATen/mkldnn/Runtime.h>
|
||||
#include <ATen/native/mkldnn/MKLDNNCommon.h>
|
||||
#include <ATen/native/mkldnn/Utils.h>
|
||||
#include <ATen/native/ConvUtils.h>
|
||||
|
||||
using namespace mkldnn;
|
||||
|
||||
namespace {
|
||||
// Helper function for getting an ideep tensor out of an aten Tensor.
|
||||
// Note in case the aten Tensor is a dense tensor, the returned ideep
|
||||
@ -66,28 +63,16 @@ ideep::tensor _mkldnn_conv2d(
|
||||
at::IntArrayRef stride,
|
||||
at::IntArrayRef dilation,
|
||||
int64_t groups) {
|
||||
std::vector<int64_t> kernel_size(x.ndims());
|
||||
// mkldnn conv2d weights could have been re-ordered to 5d by
|
||||
// mkldnn_reorder_conv2d_weight
|
||||
if (w.ndims() == x.ndims() + 1) {
|
||||
AT_ASSERTM(
|
||||
groups > 1,
|
||||
"Only group _mkldnn_conv2d weights could have been reordered to 5d");
|
||||
kernel_size[0] = w.get_dim(0) * w.get_dim(1);
|
||||
std::copy_n(
|
||||
w.get_dims().cbegin() + 2, x.ndims() - 1, kernel_size.begin() + 1);
|
||||
} else {
|
||||
std::copy_n(w.get_dims().cbegin(), x.ndims(), kernel_size.begin());
|
||||
}
|
||||
|
||||
const ideep::param::dims x_dims = x.get_dims();
|
||||
std::vector<int64_t> input_size{x_dims.cbegin(), x_dims.cend()};
|
||||
auto kernel_size = w.get_dims();
|
||||
|
||||
std::vector<int64_t> input_size = x.get_dims();
|
||||
std::vector<int64_t> output_sizes =
|
||||
conv_output_size(input_size, kernel_size, padding, stride, dilation);
|
||||
|
||||
ideep::tensor y;
|
||||
if (b.has_value()) {
|
||||
ideep::convolution_forward::compute<AllocForMKLDNN>(
|
||||
ideep::convolution_forward::compute(
|
||||
x,
|
||||
w,
|
||||
b.value(),
|
||||
@ -97,24 +82,18 @@ ideep::tensor _mkldnn_conv2d(
|
||||
{dilation.begin(), dilation.end()},
|
||||
{padding.begin(), padding.end()},
|
||||
{padding.begin(), padding.end()},
|
||||
groups,
|
||||
ideep::descriptor_group::attr_t{},
|
||||
ideep::algorithm::convolution_direct,
|
||||
ideep::prop_kind::forward);
|
||||
groups);
|
||||
} else {
|
||||
ideep::convolution_forward::compute<AllocForMKLDNN>(
|
||||
x,
|
||||
w,
|
||||
{output_sizes.cbegin(), output_sizes.cend()},
|
||||
y,
|
||||
{stride.begin(), stride.end()},
|
||||
{dilation.begin(), dilation.end()},
|
||||
{padding.begin(), padding.end()},
|
||||
{padding.begin(), padding.end()},
|
||||
groups,
|
||||
ideep::descriptor_group::attr_t{},
|
||||
ideep::algorithm::convolution_direct,
|
||||
ideep::prop_kind::forward);
|
||||
ideep::convolution_forward::compute(
|
||||
x,
|
||||
w,
|
||||
{output_sizes.cbegin(), output_sizes.cend()},
|
||||
y,
|
||||
{stride.begin(), stride.end()},
|
||||
{dilation.begin(), dilation.end()},
|
||||
{padding.begin(), padding.end()},
|
||||
{padding.begin(), padding.end()},
|
||||
groups);
|
||||
}
|
||||
return y;
|
||||
}
|
||||
@ -155,245 +134,63 @@ Tensor mkldnn_convolution_backward_input(
|
||||
IntArrayRef input_size, const at::Tensor& grad_output, const at::Tensor& weight,
|
||||
IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups, bool bias_defined)
|
||||
{
|
||||
auto grad_input = at::empty(input_size, grad_output.options());
|
||||
auto mkldnn_grad_output = get_mkldnn_tensor(grad_output);
|
||||
auto mkldnn_weight = get_mkldnn_tensor(weight);
|
||||
|
||||
auto cpu_engine = CpuEngine::Instance().get_engine();
|
||||
ideep::tensor mkldnn_grad_input;
|
||||
ideep::convolution_backward_data::compute(
|
||||
mkldnn_grad_output,
|
||||
mkldnn_weight,
|
||||
input_size.vec(),
|
||||
mkldnn_grad_input,
|
||||
stride.vec(),
|
||||
dilation.vec(),
|
||||
padding.vec(),
|
||||
padding.vec(),
|
||||
groups);
|
||||
|
||||
int32_t g = groups;
|
||||
|
||||
int32_t n = grad_input.size(0);
|
||||
int32_t ic = grad_input.size(1);
|
||||
int32_t ih = grad_input.size(2);
|
||||
int32_t iw = grad_input.size(3);
|
||||
|
||||
int32_t oc = grad_output.size(1);
|
||||
int32_t oh = grad_output.size(2);
|
||||
int32_t ow = grad_output.size(3);
|
||||
|
||||
int32_t kh = weight.size(2);
|
||||
int32_t kw = weight.size(3);
|
||||
|
||||
int32_t sh = stride[0];
|
||||
int32_t sw = stride[1];
|
||||
int32_t ph = padding[0];
|
||||
int32_t pw = padding[1];
|
||||
|
||||
auto data_t = memory::data_type::f32;
|
||||
auto format_any = memory::format::any;
|
||||
auto format_nchw = memory::format::nchw;
|
||||
auto format_weight = (g!= 1) ? memory::format::goihw : memory::format::oihw;
|
||||
|
||||
memory::dims input_tz = {n, ic, ih, iw};
|
||||
memory::dims weight_tz = (g!= 1) ? memory::dims{g, oc/g, ic/g, kh, kw} : memory::dims{oc, ic, kh, kw};
|
||||
memory::dims bias_tz = {oc};
|
||||
memory::dims output_tz = {n, oc, oh, ow};
|
||||
memory::dims _stride = {sh, sw};
|
||||
memory::dims _padding = {ph, pw};
|
||||
|
||||
auto input_md = memory::desc({input_tz}, data_t, format_any);
|
||||
auto weight_md = memory::desc({weight_tz}, data_t, format_any);
|
||||
auto bias_md = memory::desc({bias_tz}, data_t, format_any);
|
||||
auto output_md = memory::desc({output_tz}, data_t, format_any);
|
||||
|
||||
// need to re-create conv_forward_pd to feed conv_backward_data_pd
|
||||
std::shared_ptr<convolution_forward::desc> conv_forward_desc;
|
||||
if (bias_defined) {
|
||||
conv_forward_desc.reset(new convolution_forward::desc(prop_kind::forward,
|
||||
convolution_direct, input_md, weight_md, bias_md, output_md,
|
||||
_stride, _padding, _padding, padding_kind::zero));
|
||||
} else {
|
||||
conv_forward_desc.reset(new convolution_forward::desc(prop_kind::forward,
|
||||
convolution_direct, input_md, weight_md, output_md,
|
||||
_stride, _padding, _padding, padding_kind::zero));
|
||||
}
|
||||
|
||||
std::shared_ptr<convolution_forward::primitive_desc> conv_forward_pd;
|
||||
conv_forward_pd.reset(new convolution_forward::primitive_desc(
|
||||
*conv_forward_desc, cpu_engine));
|
||||
|
||||
std::shared_ptr<convolution_backward_data::desc> conv_backward_data_desc;
|
||||
conv_backward_data_desc.reset(new convolution_backward_data::desc(
|
||||
convolution_direct, input_md, weight_md, output_md,
|
||||
_stride, _padding, _padding, padding_kind::zero));
|
||||
|
||||
std::shared_ptr<convolution_backward_data::primitive_desc> conv_backward_data_pd;
|
||||
conv_backward_data_pd.reset(new convolution_backward_data::primitive_desc(
|
||||
*conv_backward_data_desc, cpu_engine, *conv_forward_pd));
|
||||
|
||||
auto grad_output_usr_memory = memory({{{output_tz}, data_t, format_nchw}, cpu_engine},
|
||||
grad_output.data_ptr());
|
||||
auto weight_usr_memory = memory({{{weight_tz}, data_t, format_weight}, cpu_engine},
|
||||
weight.data_ptr());
|
||||
auto grad_input_usr_memory = memory({{{input_tz}, data_t, format_nchw}, cpu_engine},
|
||||
grad_input.data_ptr());
|
||||
|
||||
std::vector<primitive> net;
|
||||
|
||||
auto grad_output_pd = conv_backward_data_pd->diff_dst_primitive_desc();
|
||||
auto grad_output_memory = grad_output_usr_memory;
|
||||
if (grad_output_usr_memory.get_primitive_desc() != memory::primitive_desc(grad_output_pd)) {
|
||||
grad_output_memory = memory(grad_output_pd);
|
||||
net.push_back(reorder(grad_output_usr_memory, grad_output_memory));
|
||||
}
|
||||
|
||||
auto weight_pd = conv_backward_data_pd->weights_primitive_desc();
|
||||
auto weight_memory = weight_usr_memory;
|
||||
if (weight_usr_memory.get_primitive_desc() != memory::primitive_desc(weight_pd)) {
|
||||
weight_memory = memory(weight_pd);
|
||||
net.push_back(reorder(weight_usr_memory, weight_memory));
|
||||
}
|
||||
|
||||
auto grad_input_pd = conv_backward_data_pd->diff_src_primitive_desc();
|
||||
auto grad_input_memory = grad_input_usr_memory;
|
||||
if (grad_input_memory.get_primitive_desc() != memory::primitive_desc(grad_input_pd)) {
|
||||
grad_input_memory = memory(grad_input_pd);
|
||||
}
|
||||
|
||||
std::shared_ptr<convolution_backward_data> conv_backward_data;
|
||||
conv_backward_data.reset(new convolution_backward_data(*conv_backward_data_pd,
|
||||
grad_output_memory, weight_memory, grad_input_memory));
|
||||
net.push_back(*conv_backward_data);
|
||||
|
||||
if (grad_input_memory != grad_input_usr_memory) {
|
||||
net.push_back(reorder(grad_input_memory, grad_input_usr_memory));
|
||||
}
|
||||
|
||||
Stream::Instance().get_stream().submit(net);
|
||||
|
||||
return grad_input;
|
||||
return mkldnn_to_dense(new_with_itensor_mkldnn(std::move(mkldnn_grad_input),
|
||||
grad_output.options()));
|
||||
}
|
||||
|
||||
std::tuple<at::Tensor, at::Tensor> mkldnn_convolution_backward_weights(
|
||||
IntArrayRef weight_size, const at::Tensor& grad_output, const at::Tensor& input,
|
||||
IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups, bool bias_defined)
|
||||
{
|
||||
auto grad_weight = at::empty(weight_size, grad_output.options());
|
||||
const ideep::tensor mkldnn_grad_output = get_mkldnn_tensor(grad_output);
|
||||
const ideep::tensor mkldnn_input = get_mkldnn_tensor(input);
|
||||
|
||||
Tensor grad_bias;
|
||||
ideep::tensor mkldnn_grad_weight, mkldnn_grad_bias;
|
||||
if (bias_defined) {
|
||||
grad_bias = at::empty({grad_output.size(1)}, grad_output.options());
|
||||
}
|
||||
|
||||
auto cpu_engine = CpuEngine::Instance().get_engine();
|
||||
|
||||
int32_t g = groups;
|
||||
|
||||
int32_t n = input.size(0);
|
||||
int32_t ic = input.size(1);
|
||||
int32_t ih = input.size(2);
|
||||
int32_t iw = input.size(3);
|
||||
|
||||
int32_t oc = grad_output.size(1);
|
||||
int32_t oh = grad_output.size(2);
|
||||
int32_t ow = grad_output.size(3);
|
||||
|
||||
int32_t kh = grad_weight.size(2);
|
||||
int32_t kw = grad_weight.size(3);
|
||||
|
||||
int32_t sh = stride[0];
|
||||
int32_t sw = stride[1];
|
||||
int32_t ph = padding[0];
|
||||
int32_t pw = padding[1];
|
||||
|
||||
auto data_t = memory::data_type::f32;
|
||||
auto format_any = memory::format::any;
|
||||
auto format_nchw = memory::format::nchw;
|
||||
auto format_weight = (g!= 1) ? memory::format::goihw : memory::format::oihw;
|
||||
auto format_x = memory::format::x;
|
||||
|
||||
memory::dims input_tz = {n, ic, ih, iw};
|
||||
memory::dims weight_tz = (g!= 1) ? memory::dims{g, oc/g, ic/g, kh, kw} : memory::dims{oc, ic, kh, kw};
|
||||
memory::dims bias_tz = {oc};
|
||||
memory::dims output_tz = {n, oc, oh, ow};
|
||||
memory::dims _stride = {sh, sw};
|
||||
memory::dims _padding = {ph, pw};
|
||||
|
||||
memory::desc input_md({input_tz}, data_t, format_any);
|
||||
memory::desc weight_md({weight_tz}, data_t, format_any);
|
||||
memory::desc bias_md({bias_tz}, data_t, format_any);
|
||||
memory::desc output_md({output_tz}, data_t, format_any);
|
||||
|
||||
// need to re-create conv_forward_pd to feed conv_backward_weight_pd
|
||||
std::shared_ptr<convolution_forward::desc> conv_forward_desc;
|
||||
if (bias_defined) {
|
||||
conv_forward_desc.reset(new convolution_forward::desc(prop_kind::forward,
|
||||
convolution_direct, input_md, weight_md, bias_md, output_md,
|
||||
_stride, _padding, _padding, padding_kind::zero));
|
||||
ideep::convolution_backward_weights::compute(
|
||||
mkldnn_input,
|
||||
mkldnn_grad_output,
|
||||
weight_size.vec(),
|
||||
mkldnn_grad_weight,
|
||||
mkldnn_grad_bias,
|
||||
stride.vec(),
|
||||
dilation.vec(),
|
||||
padding.vec(),
|
||||
padding.vec(),
|
||||
groups);
|
||||
} else {
|
||||
conv_forward_desc.reset(new convolution_forward::desc(prop_kind::forward,
|
||||
convolution_direct, input_md, weight_md, output_md,
|
||||
_stride, _padding, _padding, padding_kind::zero));
|
||||
ideep::convolution_backward_weights::compute(
|
||||
mkldnn_input,
|
||||
mkldnn_grad_output,
|
||||
weight_size.vec(),
|
||||
mkldnn_grad_weight,
|
||||
stride.vec(),
|
||||
dilation.vec(),
|
||||
padding.vec(),
|
||||
padding.vec(),
|
||||
groups);
|
||||
}
|
||||
|
||||
std::shared_ptr<convolution_forward::primitive_desc> conv_forward_pd;
|
||||
conv_forward_pd.reset(new convolution_forward::primitive_desc(
|
||||
*conv_forward_desc, cpu_engine));
|
||||
|
||||
std::shared_ptr<convolution_backward_weights::desc> conv_backward_weight_desc;
|
||||
if (bias_defined) {
|
||||
conv_backward_weight_desc.reset(new convolution_backward_weights::desc(
|
||||
convolution_direct, input_md, weight_md, bias_md, output_md,
|
||||
_stride, _padding, _padding, padding_kind::zero));
|
||||
} else {
|
||||
conv_backward_weight_desc.reset(new convolution_backward_weights::desc(
|
||||
convolution_direct, input_md, weight_md, output_md,
|
||||
_stride, _padding, _padding, padding_kind::zero));
|
||||
}
|
||||
|
||||
std::shared_ptr<convolution_backward_weights::primitive_desc> conv_backward_weight_pd;
|
||||
conv_backward_weight_pd.reset(new convolution_backward_weights::primitive_desc(
|
||||
*conv_backward_weight_desc, cpu_engine, *conv_forward_pd));
|
||||
|
||||
auto input_usr_memory = memory({{{input_tz}, data_t, format_nchw}, cpu_engine},
|
||||
input.data_ptr());
|
||||
auto grad_output_usr_memory = memory({{{output_tz}, data_t, format_nchw}, cpu_engine},
|
||||
grad_output.data_ptr());
|
||||
auto grad_weight_usr_memory = memory({{{weight_tz}, data_t, format_weight}, cpu_engine},
|
||||
grad_weight.data_ptr());
|
||||
std::shared_ptr<memory> grad_bias_memory;
|
||||
|
||||
std::vector<primitive> net;
|
||||
|
||||
auto input_pd = conv_backward_weight_pd->src_primitive_desc();
|
||||
auto input_memory = input_usr_memory;
|
||||
if (input_usr_memory.get_primitive_desc() != memory::primitive_desc(input_pd)) {
|
||||
input_memory = memory(input_pd);
|
||||
net.push_back(reorder(input_usr_memory, input_memory));
|
||||
}
|
||||
|
||||
auto grad_output_pd = conv_backward_weight_pd->diff_dst_primitive_desc();
|
||||
auto grad_output_memory = grad_output_usr_memory;
|
||||
if (grad_output_usr_memory.get_primitive_desc() != memory::primitive_desc(grad_output_pd)) {
|
||||
grad_output_memory = memory(grad_output_pd);
|
||||
net.push_back(reorder(grad_output_usr_memory, grad_output_memory));
|
||||
}
|
||||
|
||||
auto grad_weight_pd = conv_backward_weight_pd->diff_weights_primitive_desc();
|
||||
auto grad_weight_memory = grad_weight_usr_memory;
|
||||
if (grad_weight_usr_memory.get_primitive_desc() != memory::primitive_desc(grad_weight_pd)) {
|
||||
grad_weight_memory = memory(grad_weight_pd);
|
||||
}
|
||||
|
||||
std::shared_ptr<convolution_backward_weights> conv_backward_weight;
|
||||
if (bias_defined) {
|
||||
grad_bias_memory.reset(new memory({{{bias_tz}, data_t, format_x}, cpu_engine},
|
||||
grad_bias.data_ptr()));
|
||||
conv_backward_weight.reset(new convolution_backward_weights(*conv_backward_weight_pd,
|
||||
input_memory, grad_output_memory, grad_weight_memory, *grad_bias_memory));
|
||||
} else {
|
||||
conv_backward_weight.reset(new convolution_backward_weights(*conv_backward_weight_pd,
|
||||
input_memory, grad_output_memory, grad_weight_memory));
|
||||
}
|
||||
|
||||
net.push_back(*conv_backward_weight);
|
||||
|
||||
if (grad_weight_memory != grad_weight_usr_memory) {
|
||||
net.push_back(reorder(grad_weight_memory, grad_weight_usr_memory));
|
||||
}
|
||||
|
||||
Stream::Instance().get_stream().submit(net);
|
||||
|
||||
return std::tuple<at::Tensor, at::Tensor>{grad_weight, grad_bias};
|
||||
return std::make_tuple(
|
||||
mkldnn_to_dense(new_with_itensor_mkldnn(std::move(mkldnn_grad_weight),
|
||||
grad_output.options())),
|
||||
mkldnn_to_dense(new_with_itensor_mkldnn(std::move(mkldnn_grad_bias),
|
||||
grad_output.options())));
|
||||
}
|
||||
|
||||
std::tuple<at::Tensor,at::Tensor,at::Tensor> mkldnn_convolution_backward(
|
||||
@ -412,7 +209,7 @@ std::tuple<at::Tensor,at::Tensor,at::Tensor> mkldnn_convolution_backward(
|
||||
weight.sizes(), grad_output, input, padding, stride, dilation, groups, output_mask[2]);
|
||||
}
|
||||
|
||||
return std::tuple<Tensor, Tensor, Tensor>{grad_input, grad_weight, grad_bias};
|
||||
return std::make_tuple(grad_input, grad_weight, grad_bias);
|
||||
}
|
||||
|
||||
}} // namespace at::native
|
||||
|
@ -1,3 +1,4 @@
|
||||
#include <ATen/ATen.h>
|
||||
#include <ATen/Config.h>
|
||||
|
||||
#if AT_MKLDNN_ENABLED()
|
||||
@ -5,4 +6,16 @@
|
||||
// needs to be included only once in library.
|
||||
#include <ideep_pin_singletons.hpp>
|
||||
|
||||
using namespace ideep;
|
||||
|
||||
RegisterEngineAllocator cpu_alloc(
|
||||
engine::cpu_engine(),
|
||||
[](size_t size) {
|
||||
return c10::GetAllocator(c10::DeviceType::CPU)->raw_allocate(size);
|
||||
},
|
||||
[](void* p) {
|
||||
c10::GetAllocator(c10::DeviceType::CPU)->raw_deallocate(p);
|
||||
}
|
||||
);
|
||||
|
||||
#endif // AT_MKLDNN_ENALBED()
|
||||
|
@ -8,19 +8,6 @@
|
||||
|
||||
namespace at { namespace native {
|
||||
|
||||
// Custom allocator using c10 CPU allocator for `ideep::tensor`
|
||||
struct AllocForMKLDNN {
|
||||
static char* malloc(size_t size) {
|
||||
auto allocator = c10::GetAllocator(c10::DeviceType::CPU);
|
||||
return (char*)allocator->raw_allocate(size);
|
||||
}
|
||||
|
||||
static void free(void* p) {
|
||||
auto allocator = c10::GetAllocator(c10::DeviceType::CPU);
|
||||
allocator->raw_deallocate(p);
|
||||
}
|
||||
};
|
||||
|
||||
// Construct aten MKL-DNN tensor given an ideep tensor
|
||||
Tensor new_with_itensor_mkldnn(ideep::tensor&& it, const TensorOptions& options);
|
||||
|
||||
|
@ -15,7 +15,9 @@ Tensor mkldnn_to_dense(const Tensor& mkldnn_tensor) {
|
||||
Tensor cpu_tensor = at::empty(
|
||||
std::vector<int64_t>(dims.begin(), dims.end()),
|
||||
mkldnn_tensor.options().layout(c10::kStrided));
|
||||
stensor.to_public(cpu_tensor.template data_ptr<float>());
|
||||
if (stensor.is_empty()) return cpu_tensor;
|
||||
auto pub_tensor = stensor.to_public(cpu_tensor.template data_ptr<float>());
|
||||
cpu_tensor.as_strided_(dims, pub_tensor.get_strides());
|
||||
return cpu_tensor;
|
||||
}
|
||||
|
||||
@ -55,7 +57,7 @@ Tensor mkldnn_reorder_conv2d_weight(
|
||||
auto padding_vec = expand_param_if_needed(padding, "padding", 2);
|
||||
auto dilation_vec = expand_param_if_needed(dilation, "dilation", 2);
|
||||
|
||||
ideep::tensor w = itensor_from_mkldnn(self).as_weights();
|
||||
auto w = itensor_from_mkldnn(self);
|
||||
|
||||
// Legacy mkldnn conv2d jitted module may contain a 5-d weight with an extra
|
||||
// dimension when groups > 1, having dimension [g, o/g, i, h, w] instead of
|
||||
@ -67,9 +69,8 @@ Tensor mkldnn_reorder_conv2d_weight(
|
||||
w.reshape({wdims[0] * wdims[1], wdims[2], wdims[3], wdims[4]});
|
||||
}
|
||||
|
||||
w.make_group(groups);
|
||||
ideep::tensor::descriptor desc =
|
||||
ideep::convolution_forward::expected_weights_descriptor(
|
||||
auto desc =
|
||||
ideep::convolution_forward::expected_weights_desc(
|
||||
w.get_dims(),
|
||||
w.get_data_type(),
|
||||
{stride_vec.cbegin(), stride_vec.cend()},
|
||||
@ -79,7 +80,7 @@ Tensor mkldnn_reorder_conv2d_weight(
|
||||
groups,
|
||||
ideep::algorithm::convolution_direct);
|
||||
ideep::tensor result;
|
||||
result.init<AllocForMKLDNN>(desc);
|
||||
result.init(desc);
|
||||
result.feed_from(w);
|
||||
|
||||
return new_with_itensor_mkldnn(std::move(result), self.options());
|
||||
|
@ -62,7 +62,7 @@ std::tuple<Tensor, Tensor, Tensor> mkldnn_batch_norm(
|
||||
} else {
|
||||
AT_ASSERTM(input.dim() == 4 || input.dim() == 5,
|
||||
"mkldnn_batch_norm: currently mkldnn only support 2d and 3d batchnorm");
|
||||
ideep::batch_normalization_forward_inference::compute<AllocForMKLDNN>(
|
||||
ideep::batch_normalization_forward_inference::compute(
|
||||
x, m, v, w, b, y, eps);
|
||||
return std::make_tuple(
|
||||
new_with_itensor_mkldnn(std::move(y), input.options()),
|
||||
|
@ -130,7 +130,7 @@ static Tensor _mkldnn_pool2d(
|
||||
}
|
||||
|
||||
ideep::tensor y;
|
||||
ideep::pooling_forward::compute<AllocForMKLDNN>(
|
||||
ideep::pooling_forward::compute(
|
||||
x,
|
||||
{output_sizes.cbegin(), output_sizes.cend()},
|
||||
y,
|
||||
|
@ -26,14 +26,14 @@ namespace at { namespace native {
|
||||
Tensor mkldnn_relu(const Tensor& input) {
|
||||
const ideep::tensor& x = itensor_from_mkldnn(input);
|
||||
ideep::tensor y;
|
||||
ideep::eltwise_forward::compute<AllocForMKLDNN>(
|
||||
ideep::eltwise_forward::compute(
|
||||
x, y, ideep::algorithm::eltwise_relu, ideep::prop_kind::forward_training, /*alpha*/ 0.0);
|
||||
return new_with_itensor_mkldnn(std::move(y), input.options());
|
||||
}
|
||||
|
||||
Tensor& mkldnn_relu_(Tensor& input) {
|
||||
ideep::tensor& x = itensor_from_mkldnn(input);
|
||||
ideep::eltwise_forward::compute<AllocForMKLDNN>(
|
||||
ideep::eltwise_forward::compute(
|
||||
x, x, ideep::algorithm::eltwise_relu, ideep::prop_kind::forward_training, /*alpha*/ 0.0);
|
||||
return input;
|
||||
}
|
||||
|
@ -34,7 +34,7 @@ Tensor mkldnn_softmax(
|
||||
const int64_t wrapped_dim = maybe_wrap_dim(dim, self.dim());
|
||||
ideep::tensor& x = itensor_from_mkldnn(self);
|
||||
ideep::tensor y;
|
||||
ideep::softmax_forward::compute<AllocForMKLDNN>(x, y, wrapped_dim);
|
||||
ideep::softmax_forward::compute(x, y, wrapped_dim);
|
||||
return new_with_itensor_mkldnn(std::move(y), self.options());
|
||||
}
|
||||
|
||||
|
@ -14,8 +14,7 @@ Tensor empty_mkldnn(IntArrayRef sizes, const TensorOptions& options, c10::option
|
||||
// NOTE: int32_t dims from ideep::tensor but sizes needs int64_t
|
||||
// TODO: support int64_t dims in ideep::tensor to avoid extra conversion
|
||||
ideep::tensor::dims dst_dims (sizes.begin(), sizes.end());
|
||||
ideep::tensor it;
|
||||
it.resize<AllocForMKLDNN>(dst_dims, ideep::tensor::data_type::f32);
|
||||
ideep::tensor it {dst_dims, ideep::tensor::data_type::f32};
|
||||
return new_with_itensor_mkldnn(std::move(it), options);
|
||||
}
|
||||
|
||||
|
@ -50,7 +50,7 @@ Tensor mkldnn_reshape(const Tensor& self, IntArrayRef size) {
|
||||
}
|
||||
const ideep::tensor& x = itensor_from_mkldnn(self);
|
||||
ideep::tensor y{x};
|
||||
y.reshape<AllocForMKLDNN>({inferred_size.cbegin(), inferred_size.cend()});
|
||||
y.reshape(inferred_size);
|
||||
return new_with_itensor_mkldnn(std::move(y), self.options());
|
||||
}
|
||||
|
||||
@ -61,7 +61,7 @@ Tensor mkldnn_clone(const Tensor& self, c10::optional<c10::MemoryFormat> optiona
|
||||
optional_memory_format.value());
|
||||
ideep::tensor& src = itensor_from_mkldnn(self);
|
||||
ideep::tensor dst;
|
||||
ideep::direct_copy::compute<AllocForMKLDNN>(src, dst);
|
||||
ideep::direct_copy::compute(src, dst);
|
||||
return new_with_itensor_mkldnn(std::move(dst), self.options());
|
||||
}
|
||||
|
||||
@ -71,7 +71,7 @@ Tensor mkldnn_transpose(const Tensor & self, int64_t dim0, int64_t dim1) {
|
||||
std::vector<int> axes(x.ndims());
|
||||
std::iota(axes.begin(), axes.end(), 0);
|
||||
std::swap(axes[dim0], axes[dim1]);
|
||||
y.transpose_from<AllocForMKLDNN>(x, axes);
|
||||
y.transpose_from(x, axes);
|
||||
return new_with_itensor_mkldnn(std::move(y), self.options());
|
||||
}
|
||||
|
||||
|
@ -24,12 +24,10 @@ enum FusionType {
|
||||
};
|
||||
|
||||
#define USE_IDEEP_DEF_ALIASES() \
|
||||
/* the hash key of cahced operator generated by iDEEP */ \
|
||||
using ikey = ideep::key_t; \
|
||||
/* the tensor type created/handled by iDEEP */ \
|
||||
using itensor = ideep::tensor; \
|
||||
/* the date layout of iDEEP tensor */ \
|
||||
using iformat = ideep::format; \
|
||||
using iformat = ideep::format_tag; \
|
||||
/* the scales for iDEEP tensor with different data type */ \
|
||||
using iscale = ideep::scale_t; \
|
||||
/* the detial algorithm for iDEEP operators, e.g. winograd */ \
|
||||
@ -38,14 +36,12 @@ enum FusionType {
|
||||
using iprop = ideep::prop_kind; \
|
||||
/* the kind of low precision operators, e.g. signed/unsigned activation */ \
|
||||
using ilowp_kind = ideep::lowp_kind; \
|
||||
/* the kind of padding, usually set as zero padding */ \
|
||||
using ipadding = ideep::padding_kind; \
|
||||
/* the data type of iDEEP tensor, e.g. f32, u8, s8 */ \
|
||||
using idtype = ideep::tensor::data_type; \
|
||||
/* the descriptor of iDEEP tensor */ \
|
||||
using itdesc = ideep::tensor::descriptor; \
|
||||
/* the attribute for operator to describe the details of inputs&fusion */ \
|
||||
using iattr = ideep::descriptor_group::attr_t; \
|
||||
using iattr = ideep::attr_t; \
|
||||
/* the detail flags for batch normalization */ \
|
||||
using ibn_flag = ideep::batch_normalization_flag;
|
||||
|
||||
|
@ -110,11 +110,11 @@ class IDEEPAdamOp final : public IDEEPOperator {
|
||||
CAFFE_ENFORCE(grad.get_nelems() == moment_1.get_nelems());
|
||||
CAFFE_ENFORCE(grad.get_nelems() == moment_2.get_nelems());
|
||||
if (params != *out_params)
|
||||
out_params->reinit(params.get_descriptor());
|
||||
out_params->init(params.get_descriptor());
|
||||
if (moment_1 != *out_moment1)
|
||||
out_moment1->reinit(moment_1.get_descriptor());
|
||||
out_moment1->init(moment_1.get_descriptor());
|
||||
if (moment_2 != *out_moment2)
|
||||
out_moment2->reinit(moment_2.get_descriptor());
|
||||
out_moment2->init(moment_2.get_descriptor());
|
||||
const auto w = static_cast<float *>(params.get_data_handle());
|
||||
const auto g = static_cast<float *>(grad.get_data_handle());
|
||||
const auto m = static_cast<float *>(moment_1.get_data_handle());
|
||||
@ -146,7 +146,7 @@ class IDEEPAdamOp final : public IDEEPOperator {
|
||||
} else {
|
||||
auto* out_grad = Output(OUTPUT_GRAD);
|
||||
if (grad != *out_grad)
|
||||
out_grad->reinit(grad.get_descriptor());
|
||||
out_grad->init(grad.get_descriptor());
|
||||
auto ng = static_cast<float *>(out_grad->get_data_handle());
|
||||
adam_ideep_compute_output_grad(
|
||||
grad.get_nelems(),
|
||||
|
@ -36,95 +36,98 @@ class IDEEPConvOp : public IDEEPConvPoolOpBase {
|
||||
const auto& X = Input(INPUT_X);
|
||||
const auto& filter = Input(FILTER);
|
||||
auto* Y = Output(OUTPUT);
|
||||
auto grouped = filter.is_grouped() ? 1 : 0;
|
||||
auto Y_dims_conv = CalcOutputDims(
|
||||
X,
|
||||
grouped ? (filter.get_dim(0) * filter.get_dim(1)) : filter.get_dim(0));
|
||||
|
||||
CAFFE_ENFORCE(4 == X.ndims());
|
||||
CAFFE_ENFORCE(4 == filter.ndims() || (grouped && (group_ > 1)));
|
||||
CAFFE_ENFORCE_EQ(filter.get_dim(2 + grouped), kernel_h());
|
||||
CAFFE_ENFORCE_EQ(filter.get_dim(3 + grouped), kernel_w());
|
||||
CAFFE_ENFORCE(4 == filter.ndims());
|
||||
CAFFE_ENFORCE_EQ(filter.get_dim(2), kernel_h());
|
||||
CAFFE_ENFORCE_EQ(filter.get_dim(3), kernel_w());
|
||||
CAFFE_ENFORCE(
|
||||
X.get_dim(1) == filter.get_dim(1 + grouped) * group_,
|
||||
X.get_dim(1) == filter.get_dim(1) * group_,
|
||||
"Convolution op: input channels does not match: # of input channels ",
|
||||
X.get_dim(1),
|
||||
" is not equal to kernel channels * group:",
|
||||
filter.get_dim(1 + grouped),
|
||||
filter.get_dim(1),
|
||||
"*",
|
||||
group_);
|
||||
|
||||
bool input_changed = (cached_X_descriptor_ != X.get_descriptor());
|
||||
if (input_changed) {
|
||||
op_key_.clear();
|
||||
cached_X_descriptor_ = X.dup_descriptor();
|
||||
}
|
||||
|
||||
bool weights_changed = (cached_weights_descriptor_ != filter.get_descriptor());
|
||||
if (!training_mode_ && weights_changed) {
|
||||
op_key_.clear();
|
||||
cached_weights_descriptor_ = filter.dup_descriptor();
|
||||
auto filter_in = filter.as_weights();
|
||||
filter_in.make_group(group_);
|
||||
|
||||
auto expected_descriptor =
|
||||
ideep::convolution_forward::expected_weights_descriptor(
|
||||
filter_in.get_dims(),
|
||||
ideep::convolution_forward::expected_weights_desc(
|
||||
filter.get_dims(),
|
||||
idtype::f32,
|
||||
stride_,
|
||||
{stride_.begin(), stride_.end()},
|
||||
pad_tl(),
|
||||
pad_br(),
|
||||
dilation_,
|
||||
{dilation_.begin(), dilation_.end()},
|
||||
group_,
|
||||
algo_,
|
||||
pk_,
|
||||
idtype::f32,
|
||||
X.get_dims());
|
||||
if (filter_in.get_descriptor() != expected_descriptor) {
|
||||
if (filter.get_descriptor() != expected_descriptor) {
|
||||
filter_.init(expected_descriptor);
|
||||
filter_.feed_from(filter_in);
|
||||
filter_.feed_from(filter);
|
||||
} else {
|
||||
filter_ = filter_in;
|
||||
filter_ = filter;
|
||||
}
|
||||
}
|
||||
|
||||
if (InputSize() > last_input_) {
|
||||
ideep::convolution_forward::compute(
|
||||
op_key_,
|
||||
X,
|
||||
training_mode_ ? filter : filter_,
|
||||
Input(BIAS_OR_INPUT_S),
|
||||
Y_dims_conv,
|
||||
*Y,
|
||||
stride_,
|
||||
dilation_,
|
||||
pad_tl(),
|
||||
pad_br(),
|
||||
group_,
|
||||
dummy_scale_,
|
||||
dummy_scale_,
|
||||
dummy_scale_,
|
||||
attr_,
|
||||
algo_,
|
||||
pk_);
|
||||
bool with_bias = InputSize() > last_input_;
|
||||
auto filter_in = training_mode_ ? filter : filter_;
|
||||
if (training_mode_ || input_changed || weights_changed) {
|
||||
auto Y_dims_conv = CalcOutputDims(X, filter.get_dim(0));
|
||||
if (with_bias) {
|
||||
ideep::convolution_forward::prepare(
|
||||
conv_param,
|
||||
X,
|
||||
filter_in,
|
||||
Input(BIAS_OR_INPUT_S),
|
||||
Y_dims_conv,
|
||||
*Y,
|
||||
{stride_.begin(), stride_.end()},
|
||||
{dilation_.begin(), dilation_.end()},
|
||||
pad_tl(),
|
||||
pad_br(),
|
||||
group_,
|
||||
dummy_scale_,
|
||||
dummy_scale_,
|
||||
dummy_scale_,
|
||||
attr_,
|
||||
algo_,
|
||||
pk_);
|
||||
} else {
|
||||
ideep::convolution_forward::prepare(
|
||||
conv_param,
|
||||
X,
|
||||
filter_in,
|
||||
Y_dims_conv,
|
||||
*Y,
|
||||
{stride_.begin(), stride_.end()},
|
||||
{dilation_.begin(), dilation_.end()},
|
||||
pad_tl(),
|
||||
pad_br(),
|
||||
group_,
|
||||
dummy_scale_,
|
||||
dummy_scale_,
|
||||
dummy_scale_,
|
||||
attr_,
|
||||
algo_,
|
||||
pk_);
|
||||
}
|
||||
}
|
||||
|
||||
if (with_bias) {
|
||||
ideep::convolution_forward::compute(conv_param, X, filter_in,
|
||||
Input(BIAS_OR_INPUT_S), *Y);
|
||||
} else {
|
||||
ideep::convolution_forward::compute(
|
||||
op_key_,
|
||||
X,
|
||||
training_mode_ ? filter : filter_,
|
||||
Y_dims_conv,
|
||||
*Y,
|
||||
stride_,
|
||||
dilation_,
|
||||
pad_tl(),
|
||||
pad_br(),
|
||||
group_,
|
||||
dummy_scale_,
|
||||
dummy_scale_,
|
||||
dummy_scale_,
|
||||
attr_,
|
||||
algo_,
|
||||
pk_);
|
||||
ideep::convolution_forward::compute(conv_param, X, filter_in, *Y);
|
||||
}
|
||||
|
||||
if (fusion_type_ == FUSION_CONV_SUM
|
||||
@ -140,13 +143,13 @@ class IDEEPConvOp : public IDEEPConvPoolOpBase {
|
||||
iprop pk_;
|
||||
ialgo algo_;
|
||||
iattr attr_;
|
||||
ikey op_key_;
|
||||
int last_input_;
|
||||
bool training_mode_;
|
||||
FusionType fusion_type_;
|
||||
itensor filter_;
|
||||
iscale dummy_scale_;
|
||||
itensor::descriptor cached_X_descriptor_, cached_weights_descriptor_;
|
||||
ideep::convolution_forward_params conv_param;
|
||||
|
||||
INPUT_TAGS(INPUT_X, FILTER, BIAS_OR_INPUT_S, INPUT_S);
|
||||
OUTPUT_TAGS(OUTPUT);
|
||||
@ -289,8 +292,8 @@ class IDEEPConvGradientOp final : public IDEEPConvPoolOpBase {
|
||||
dY,
|
||||
filter.get_dims(),
|
||||
*dfilter,
|
||||
stride_,
|
||||
dilation_,
|
||||
{stride_.begin(), stride_.end()},
|
||||
{dilation_.begin(), dilation_.end()},
|
||||
pad_tl(),
|
||||
pad_br(),
|
||||
group_);
|
||||
@ -302,8 +305,8 @@ class IDEEPConvGradientOp final : public IDEEPConvPoolOpBase {
|
||||
filter.get_dims(),
|
||||
*dfilter,
|
||||
*dbias,
|
||||
stride_,
|
||||
dilation_,
|
||||
{stride_.begin(), stride_.end()},
|
||||
{dilation_.begin(), dilation_.end()},
|
||||
pad_tl(),
|
||||
pad_br(),
|
||||
group_);
|
||||
@ -316,8 +319,8 @@ class IDEEPConvGradientOp final : public IDEEPConvPoolOpBase {
|
||||
filter,
|
||||
X.get_dims(),
|
||||
*dX,
|
||||
stride_,
|
||||
dilation_,
|
||||
{stride_.begin(), stride_.end()},
|
||||
{dilation_.begin(), dilation_.end()},
|
||||
pad_tl(),
|
||||
pad_br(),
|
||||
group_);
|
||||
|
@ -33,7 +33,7 @@ class IDEEPConvPoolOpBase : public ConvPoolOpBase<IDEEPContext> {
|
||||
const ideep::tensor& input,
|
||||
int output_channel) {
|
||||
CAFFE_ENFORCE_GT(input.get_size(), 0);
|
||||
ideep::tensor::dims output_dims;
|
||||
std::vector<int> output_dims;
|
||||
const auto input_dims = input.get_dims();
|
||||
std::vector<std::int64_t> input_Tdims(
|
||||
input_dims.cbegin(), input_dims.cend());
|
||||
@ -48,7 +48,7 @@ class IDEEPConvPoolOpBase : public ConvPoolOpBase<IDEEPContext> {
|
||||
&kernel_,
|
||||
&pads_,
|
||||
&output_dims);
|
||||
return output_dims;
|
||||
return {output_dims.begin(), output_dims.end()};
|
||||
}
|
||||
|
||||
bool RunOnDevice() override {
|
||||
|
@ -28,60 +28,48 @@ class IDEEPConvTransposeOp final : public IDEEPConvTransposeUnpoolBase {
|
||||
CAFFE_ENFORCE_EQ(filter.ndims(), 4);
|
||||
CAFFE_ENFORCE_EQ(filter.get_dim(2), kernel_h());
|
||||
CAFFE_ENFORCE_EQ(filter.get_dim(3), kernel_w());
|
||||
CAFFE_ENFORCE_EQ(filter.get_dim(0), X.get_dim(1),
|
||||
"filter number must be equal to input channel number");
|
||||
|
||||
ideep::tensor::dims Y_dims;
|
||||
const bool pre_converted = filter.get_public_format() == ideep::format::iohw;
|
||||
if (!pre_converted) {
|
||||
CAFFE_ENFORCE_EQ(
|
||||
filter.get_dim(0), X.get_dim(1),
|
||||
"filter number must be equal to input channel number");
|
||||
auto Y_dims = CalcOutputDims(X, filter.get_dim(1));
|
||||
|
||||
Y_dims = CalcOutputDims(X, filter.get_dim(1));
|
||||
bool weights_changed = (cached_weights_descriptor_ != filter.get_descriptor());
|
||||
if (!training_mode_ && weights_changed) {
|
||||
cached_weights_descriptor_ = filter.dup_descriptor();
|
||||
auto filter_in = filter;
|
||||
|
||||
ideep::tensor::dims filter_dims_mkldnn {filter.get_dim(1), filter.get_dim(0),
|
||||
filter.get_dim(2), filter.get_dim(3)};
|
||||
auto expected_descriptor =
|
||||
ideep::convolution_transpose_forward::expected_weights_descriptor(
|
||||
filter_dims_mkldnn,
|
||||
ideep::convolution_transpose_forward::expected_weights_desc(
|
||||
filter.get_dims(),
|
||||
filter.get_data_type(),
|
||||
stride_,
|
||||
{stride_.begin(), stride_.end()},
|
||||
pad_tl(),
|
||||
pad_br());
|
||||
const bool weights_changed =
|
||||
(cached_weights_descriptor_ != filter.get_descriptor());
|
||||
if (weights_changed) {
|
||||
cached_weights_descriptor_ = filter.dup_descriptor();
|
||||
}
|
||||
|
||||
if (training_mode_ || weights_changed) {
|
||||
auto filter_in = filter;
|
||||
// Framework has filters in IOHW while MKL-DNN requires OIHW,
|
||||
// we have to do explicit conversion here.
|
||||
filter_in.set_public_format(ideep::format::iohw);
|
||||
if (filter_in.get_descriptor() != expected_descriptor) {
|
||||
filter_.init(expected_descriptor);
|
||||
filter_.feed_from(filter_in);
|
||||
filter_.feed_from(filter_in, /*is_deconv_weights=*/true);
|
||||
} else {
|
||||
filter_ = filter_in;
|
||||
}
|
||||
|
||||
} else {
|
||||
CAFFE_ENFORCE_EQ(
|
||||
filter.get_dim(1), X.get_dim(1),
|
||||
"filter number must be equal to input channel number");
|
||||
|
||||
Y_dims = CalcOutputDims(X, filter.get_dim(0));
|
||||
}
|
||||
|
||||
auto transposed_filter = training_mode_ ? filter : filter_;
|
||||
transposed_filter.transpose_(0, 1);
|
||||
|
||||
if (InputSize() > BIAS) {
|
||||
const auto& bias = Input(BIAS);
|
||||
CAFFE_ENFORCE_EQ(bias.ndims(), 1, "bias must be 1D tensor");
|
||||
CAFFE_ENFORCE_EQ(
|
||||
bias.get_dim(0), pre_converted ? filter.get_dim(0) : filter.get_dim(1),
|
||||
bias.get_dim(0), filter.get_dim(1),
|
||||
"bias dimension must be equal to output channel number");
|
||||
|
||||
ideep::convolution_transpose_forward::compute(
|
||||
X, pre_converted ? filter : filter_, bias, Y_dims, *Y, stride_, pad_tl(), pad_br());
|
||||
X, transposed_filter, bias, Y_dims, *Y,
|
||||
{stride_.begin(), stride_.end()} , pad_tl(), pad_br());
|
||||
} else {
|
||||
ideep::convolution_transpose_forward::compute(
|
||||
X, pre_converted ? filter : filter_, Y_dims, *Y, stride_, pad_tl(), pad_br());
|
||||
X, transposed_filter, Y_dims, *Y,
|
||||
{stride_.begin(), stride_.end()}, pad_tl(), pad_br());
|
||||
}
|
||||
return true;
|
||||
}
|
||||
@ -121,60 +109,36 @@ class IDEEPConvTransposeGradientOp final : public IDEEPConvTransposeUnpoolBase {
|
||||
const auto& filter = Input(FILTER);
|
||||
const auto& dY = Input(OUTPUT_GRAD);
|
||||
auto* dfilter = Output(FILTER_GRAD);
|
||||
|
||||
itensor dfilter_;
|
||||
itensor filter_;
|
||||
auto filter_in = filter;
|
||||
|
||||
itensor::dims oihw_dims {filter.get_dim(1), filter.get_dim(0),
|
||||
filter.get_dim(2), filter.get_dim(3)};
|
||||
const bool pre_converted = (filter.get_public_format() == ideep::format::iohw);
|
||||
if (!pre_converted) {
|
||||
auto expected_descriptor =
|
||||
ideep::convolution_transpose_forward::expected_weights_descriptor(
|
||||
oihw_dims,
|
||||
filter.get_data_type(),
|
||||
stride_,
|
||||
pad_tl(),
|
||||
pad_br());
|
||||
// Framework has filters in IOHW while MKL-DNN requires OIHW,
|
||||
// we have to do explicit conversion here.
|
||||
filter_in.set_public_format(ideep::format::iohw);
|
||||
filter_.init(expected_descriptor);
|
||||
filter_.feed_from(filter_in);
|
||||
}
|
||||
auto transposed_filter = filter;
|
||||
transposed_filter.transpose_(0, 1);
|
||||
|
||||
if (no_bias_) {
|
||||
ideep::convolution_transpose_backward_weights::compute(
|
||||
X, dY, pre_converted ? filter.get_dims() : oihw_dims,
|
||||
pre_converted ? *dfilter : dfilter_, stride_, pad_tl(), pad_br());
|
||||
X,
|
||||
dY,
|
||||
filter.get_dims(),
|
||||
*dfilter,
|
||||
{stride_.begin(), stride_.end()},
|
||||
pad_tl(),
|
||||
pad_br());
|
||||
} else {
|
||||
auto* dbias = Output(BIAS_OR_INPUT_GRAD);
|
||||
ideep::convolution_transpose_backward_weights::compute(
|
||||
X,
|
||||
dY,
|
||||
pre_converted ? filter.get_dims() : oihw_dims,
|
||||
pre_converted ? *dfilter : dfilter_,
|
||||
filter.get_dims(),
|
||||
*dfilter,
|
||||
*dbias,
|
||||
stride_,
|
||||
{stride_.begin(), stride_.end()},
|
||||
pad_tl(),
|
||||
pad_br());
|
||||
}
|
||||
|
||||
if (!pre_converted) {
|
||||
// Framework has filters in IOHW while MKL-DNN requires OIHW,
|
||||
// we have to do explicit conversion here.
|
||||
dfilter_.set_public_format(ideep::format::iohw);
|
||||
dfilter->reinit(filter.get_descriptor());
|
||||
dfilter_.to_public(dfilter->get_data_handle());
|
||||
} else {
|
||||
dfilter->set_public_format(ideep::format::iohw);
|
||||
}
|
||||
|
||||
if (OutputSize() == 3 || (no_bias_ && (OutputSize() == 2))) {
|
||||
auto* dX = Output(no_bias_ ? BIAS_OR_INPUT_GRAD : INPUT_GRAD);
|
||||
ideep::convolution_transpose_backward_data::compute(
|
||||
dY, pre_converted ? filter : filter_, X.get_dims(), *dX, stride_, pad_tl(), pad_br());
|
||||
dY, transposed_filter, X.get_dims(), *dX,
|
||||
{stride_.begin(), stride_.end()}, pad_tl(), pad_br());
|
||||
}
|
||||
|
||||
return true;
|
||||
|
@ -28,7 +28,6 @@ class IDEEPFullyConnectedOp final : public IDEEPOperator {
|
||||
}
|
||||
|
||||
if (training_mode_) {
|
||||
op_key_.clear();
|
||||
filter_ = filter;
|
||||
auto filter_dims = CanonicalDims(filter_.get_dims(), axis_w_);
|
||||
if (filter_.get_dims() != filter_dims) {
|
||||
@ -40,12 +39,10 @@ class IDEEPFullyConnectedOp final : public IDEEPOperator {
|
||||
}
|
||||
} else {
|
||||
if (cached_X_descriptor_ != X.get_descriptor()) {
|
||||
op_key_.clear();
|
||||
cached_X_descriptor_ = X.dup_descriptor();
|
||||
}
|
||||
|
||||
if (cached_weights_descriptor_ != filter.get_descriptor()) {
|
||||
op_key_.clear();
|
||||
cached_weights_descriptor_ = filter.dup_descriptor();
|
||||
|
||||
filter_ = filter.has_scale() ? filter.to_public() : filter;
|
||||
@ -63,9 +60,9 @@ class IDEEPFullyConnectedOp final : public IDEEPOperator {
|
||||
|
||||
if (InputSize() > BIAS) {
|
||||
ideep::inner_product_forward::compute(
|
||||
op_key_, X_in, filter_, bias_, *Y);
|
||||
X_in, filter_, bias_, *Y);
|
||||
} else {
|
||||
ideep::inner_product_forward::compute(op_key_, X_in, filter_, *Y);
|
||||
ideep::inner_product_forward::compute(X_in, filter_, *Y);
|
||||
}
|
||||
|
||||
return true;
|
||||
@ -76,7 +73,6 @@ class IDEEPFullyConnectedOp final : public IDEEPOperator {
|
||||
size_t axis_w_{1};
|
||||
bool training_mode_;
|
||||
|
||||
ikey op_key_;
|
||||
itensor filter_, bias_;
|
||||
itensor::descriptor cached_X_descriptor_, cached_weights_descriptor_;
|
||||
|
||||
@ -115,6 +111,7 @@ class IDEEPFullyConnectedGradientOp final : public IDEEPOperator {
|
||||
}
|
||||
|
||||
ideep::inner_product_backward_weights::compute(X_in, dY, *dfilter, *dbias);
|
||||
dfilter->to_default_format();
|
||||
|
||||
/**
|
||||
* In mkl-dnn,weight gradient shape is determined by X_in,
|
||||
|
@ -49,10 +49,10 @@ class IDEEPMomentumSGDOp final : public IDEEPOperator {
|
||||
bool RunOnDevice() override {
|
||||
CAFFE_ENFORCE(Input(GRAD).get_nelems() == Input(MOMENTUM).get_nelems());
|
||||
if (Input(GRAD) != *Output(OUTPUT_GRAD)) {
|
||||
Output(OUTPUT_GRAD)->reinit(Input(GRAD).get_descriptor());
|
||||
Output(OUTPUT_GRAD)->init(Input(GRAD).get_descriptor());
|
||||
}
|
||||
if (Input(MOMENTUM) != *Output(OUTPUT_MOMENTUM)) {
|
||||
Output(OUTPUT_MOMENTUM)->reinit(Input(MOMENTUM).get_descriptor());
|
||||
Output(OUTPUT_MOMENTUM)->init(Input(MOMENTUM).get_descriptor());
|
||||
}
|
||||
|
||||
// TODO: Use itensor after 0-dim is supported. Now use CPU tensor.
|
||||
@ -91,10 +91,10 @@ class IDEEPMomentumSGDUpdateOp final : public IDEEPOperator {
|
||||
bool RunOnDevice() override {
|
||||
CAFFE_ENFORCE(Input(GRAD).get_nelems() == Input(MOMENTUM).get_nelems());
|
||||
if (Input(GRAD) != *Output(OUTPUT_GRAD)) {
|
||||
Output(OUTPUT_GRAD)->reinit(Input(GRAD).get_descriptor());
|
||||
Output(OUTPUT_GRAD)->init(Input(GRAD).get_descriptor());
|
||||
}
|
||||
if (Input(MOMENTUM) != *Output(OUTPUT_MOMENTUM)) {
|
||||
Output(OUTPUT_MOMENTUM)->reinit(Input(MOMENTUM).get_descriptor());
|
||||
Output(OUTPUT_MOMENTUM)->init(Input(MOMENTUM).get_descriptor());
|
||||
}
|
||||
|
||||
// TODO: Use itensor after 0-dim is supported. Now use CPU tensor.
|
||||
|
@ -94,7 +94,7 @@ class IDEEPFallbackOp final : public IDEEPOperator {
|
||||
dtensor->Resize(input.get_dims());
|
||||
// If fallback from INT8, the public format of original input is nhwc.
|
||||
// While the required format is nchw, need to reorder to nchw.
|
||||
if (input.get_public_format() == iformat::nhwc) {
|
||||
if (input.get_desc().is_nhwc()) {
|
||||
itensor temp_ten ({input.get_dims(), idtype::f32, iformat::nchw},
|
||||
dtensor->template mutable_data<float>());
|
||||
temp_ten.feed_from(input);
|
||||
|
@ -13,7 +13,7 @@ class IDEEPNHWC2NCHWOp final : public IDEEPOperator {
|
||||
bool RunOnDevice() override {
|
||||
const auto& X = Input(0);
|
||||
CAFFE_ENFORCE_EQ(X.ndims(), 4);
|
||||
CAFFE_ENFORCE(X.get_internal_format() == iformat::nhwc);
|
||||
CAFFE_ENFORCE(X.get_desc().is_nhwc());
|
||||
|
||||
auto *Y = Output(OUTPUT);
|
||||
CAFFE_ENFORCE(Y != &X);
|
||||
@ -39,7 +39,7 @@ class IDEEPNCHW2NHWCOp final : public IDEEPOperator {
|
||||
bool RunOnDevice() override {
|
||||
const auto& X = Input(0);
|
||||
CAFFE_ENFORCE_EQ(X.ndims(), 4);
|
||||
CAFFE_ENFORCE(X.get_internal_format() == iformat::nchw);
|
||||
CAFFE_ENFORCE(X.get_desc().is_nchw());
|
||||
|
||||
auto *Y = Output(OUTPUT);
|
||||
CAFFE_ENFORCE(Y != &X);
|
||||
|
@ -41,12 +41,13 @@ class IDEEPPoolOp final : public IDEEPConvPoolOpBase {
|
||||
auto Y_dims = CalcOutputDims(X, X.get_dim(1));
|
||||
|
||||
if (cached_X_descriptor_ != X.get_descriptor()) {
|
||||
op_key_.clear();
|
||||
cached_X_descriptor_ = X.dup_descriptor();
|
||||
}
|
||||
|
||||
ideep::pooling_forward::compute(op_key_, X, Y_dims, *Y,
|
||||
stride_, kernel_, pad_tl(), pad_br(), algo_, pk_);
|
||||
ideep::pooling_forward::compute(X, Y_dims, *Y,
|
||||
{stride_.begin(), stride_.end()},
|
||||
{kernel_.begin(), kernel_.end()},
|
||||
pad_tl(), pad_br(), algo_, pk_);
|
||||
|
||||
return true;
|
||||
}
|
||||
@ -54,7 +55,6 @@ class IDEEPPoolOp final : public IDEEPConvPoolOpBase {
|
||||
private:
|
||||
iprop pk_;
|
||||
ialgo algo_;
|
||||
ikey op_key_;
|
||||
itensor::descriptor cached_X_descriptor_;
|
||||
|
||||
INPUT_TAGS(INPUT);
|
||||
@ -95,7 +95,9 @@ class IDEEPPoolGradientOp final : public IDEEPConvPoolOpBase {
|
||||
auto* dX = Output(INPUT_GRAD);
|
||||
|
||||
ideep::pooling_backward::compute(dY, Y, X, *dX,
|
||||
stride_, kernel_, pad_tl(), pad_br(), algo_);
|
||||
{stride_.begin(), stride_.end()},
|
||||
{kernel_.begin(), kernel_.end()},
|
||||
pad_tl(), pad_br(), algo_);
|
||||
|
||||
return true;
|
||||
}
|
||||
|
@ -39,7 +39,8 @@ class IDEEPInt8SumReluOp final : public IDEEPOperator {
|
||||
if (input_dims.empty())
|
||||
input_dims = Xi.get_dims();
|
||||
CAFFE_ENFORCE(input_dims == Xi.get_dims());
|
||||
inputs_itensor.emplace_back(Xi);
|
||||
inputs_itensor.emplace_back(
|
||||
Xi.get_data_type() != idtype::f32 ? Xi.dequantize() : Xi);
|
||||
}
|
||||
|
||||
temp_ten.init({input_dims, idtype::f32});
|
||||
|
@ -33,7 +33,6 @@ class IDEEPInt8ConvOp : public IDEEPConvPoolOpBase {
|
||||
const auto &X = Input(INPUT_X);
|
||||
const auto &filter = Input(FILTER);
|
||||
auto *Y = Output(OUTPUT);
|
||||
auto Y_dims = CalcOutputDims(X, filter.get_dim(0));
|
||||
|
||||
CAFFE_ENFORCE(X.has_scale());
|
||||
CAFFE_ENFORCE(4 == X.ndims() && 4 == filter.ndims());
|
||||
@ -49,38 +48,38 @@ class IDEEPInt8ConvOp : public IDEEPConvPoolOpBase {
|
||||
|
||||
bool input_changed = (cached_X_descriptor_ != X.get_descriptor());
|
||||
if (input_changed) {
|
||||
op_key_.clear();
|
||||
cached_X_descriptor_ = X.dup_descriptor();
|
||||
}
|
||||
|
||||
bool weights_changed = (cached_weights_descriptor_ != filter.get_descriptor());
|
||||
if (weights_changed) {
|
||||
op_key_.clear();
|
||||
cached_weights_descriptor_ = filter.dup_descriptor();
|
||||
CAFFE_ENFORCE(filter.get_data_type() == idtype::s8 && filter.has_scale());
|
||||
|
||||
itensor filter_in;
|
||||
auto X_dt = X.get_data_type();
|
||||
lowp_kind_ = ilowp_kind::LOWP_U8S8;
|
||||
auto filter_scale = filter.get_scale();
|
||||
if (X_dt == idtype::s8) {
|
||||
lowp_kind_ = ilowp_kind::LOWP_S8S8;
|
||||
filter_in = filter.as_weights().to_public();
|
||||
} else {
|
||||
filter_in = filter.as_weights();
|
||||
}
|
||||
filter_in.make_group(group_);
|
||||
|
||||
auto expected_descriptor =
|
||||
ideep::convolution_forward::expected_weights_descriptor(
|
||||
filter_in.get_dims(), idtype::s8, stride_, pad_tl(), pad_br(),
|
||||
dilation_, group_, algo_, iprop::forward_inference, X_dt, X.get_dims());
|
||||
if (filter_in.get_descriptor() != expected_descriptor) {
|
||||
ideep::convolution_forward::expected_weights_desc(
|
||||
filter.get_dims(),
|
||||
idtype::s8,
|
||||
{stride_.begin(), stride_.end()},
|
||||
pad_tl(),
|
||||
pad_br(),
|
||||
{dilation_.begin(), dilation_.end()},
|
||||
group_,
|
||||
algo_,
|
||||
iprop::forward_inference,
|
||||
X_dt, X.get_dims());
|
||||
if (filter.get_desc() != expected_descriptor) {
|
||||
filter_.init(expected_descriptor);
|
||||
filter_.set_scale(filter_scale);
|
||||
filter_.feed_from(filter_in);
|
||||
filter_.set_scale(filter.get_scale());
|
||||
filter_.feed_from(filter);
|
||||
} else {
|
||||
filter_ = filter_in;
|
||||
filter_ = filter;
|
||||
}
|
||||
|
||||
if (InputSize() > last_input_) {
|
||||
@ -96,18 +95,55 @@ class IDEEPInt8ConvOp : public IDEEPConvPoolOpBase {
|
||||
}
|
||||
}
|
||||
|
||||
if (InputSize() > last_input_) {
|
||||
ideep::convolution_forward::compute(
|
||||
op_key_, X, filter_, bias_, Y_dims, *Y,
|
||||
stride_, dilation_, pad_tl(), pad_br(), group_,
|
||||
iscale(), iscale(), Y_scales_, attr_, algo_,
|
||||
iprop::forward_inference, ipadding::zero, lowp_kind_);
|
||||
bool with_bias = InputSize() > last_input_;
|
||||
if (input_changed || weights_changed) {
|
||||
auto Y_dims = CalcOutputDims(X, filter.get_dim(0));
|
||||
if (with_bias) {
|
||||
ideep::convolution_forward::prepare(
|
||||
conv_param,
|
||||
X,
|
||||
filter_,
|
||||
bias_,
|
||||
Y_dims,
|
||||
*Y,
|
||||
{stride_.begin(), stride_.end()},
|
||||
{dilation_.begin(), dilation_.end()},
|
||||
pad_tl(),
|
||||
pad_br(),
|
||||
group_,
|
||||
iscale(),
|
||||
iscale(),
|
||||
Y_scales_,
|
||||
attr_,
|
||||
algo_,
|
||||
iprop::forward_inference,
|
||||
lowp_kind_);
|
||||
} else {
|
||||
ideep::convolution_forward::prepare(
|
||||
conv_param,
|
||||
X,
|
||||
filter_,
|
||||
Y_dims,
|
||||
*Y,
|
||||
{stride_.begin(), stride_.end()},
|
||||
{dilation_.begin(), dilation_.end()},
|
||||
pad_tl(),
|
||||
pad_br(),
|
||||
group_,
|
||||
iscale(),
|
||||
iscale(),
|
||||
Y_scales_,
|
||||
attr_,
|
||||
algo_,
|
||||
iprop::forward_inference,
|
||||
lowp_kind_);
|
||||
}
|
||||
}
|
||||
|
||||
if (with_bias) {
|
||||
ideep::convolution_forward::compute(conv_param, X, filter_, bias_, *Y);
|
||||
} else {
|
||||
ideep::convolution_forward::compute(
|
||||
op_key_, X, filter_, Y_dims, *Y,
|
||||
stride_, dilation_, pad_tl(), pad_br(), group_,
|
||||
iscale(), iscale(), Y_scales_, attr_, algo_,
|
||||
iprop::forward_inference, ipadding::zero, lowp_kind_);
|
||||
ideep::convolution_forward::compute(conv_param, X, filter_, *Y);
|
||||
}
|
||||
|
||||
if (fusion_type_ != FUSION_CONV_RELU && fusion_type_ != FUSION_UNKNOWN) {
|
||||
@ -122,7 +158,6 @@ class IDEEPInt8ConvOp : public IDEEPConvPoolOpBase {
|
||||
protected:
|
||||
iattr attr_;
|
||||
ialgo algo_;
|
||||
ikey op_key_;
|
||||
float scale_;
|
||||
int last_input_;
|
||||
int32_t zero_point_;
|
||||
@ -132,6 +167,7 @@ class IDEEPInt8ConvOp : public IDEEPConvPoolOpBase {
|
||||
itensor filter_, bias_;
|
||||
iscale Y_scales_;
|
||||
itensor::descriptor cached_X_descriptor_, cached_weights_descriptor_;
|
||||
ideep::convolution_forward_params conv_param;
|
||||
|
||||
INPUT_TAGS(INPUT_X, FILTER, BIAS_OR_INPUT_S, INPUT_S);
|
||||
OUTPUT_TAGS(OUTPUT);
|
||||
|
@ -14,7 +14,8 @@ class IDEEPInt8DequantizeOp final : public IDEEPOperator {
|
||||
|
||||
if (HasArgument("output_order")) {
|
||||
Y_fmt_ = static_cast<iformat>(
|
||||
this->template GetSingleArgument<int>("output_order", iformat::nchw));
|
||||
this->template GetSingleArgument<int>("output_order",
|
||||
static_cast<int>(iformat::nchw)));
|
||||
}
|
||||
}
|
||||
virtual ~IDEEPInt8DequantizeOp() {}
|
||||
@ -22,17 +23,18 @@ class IDEEPInt8DequantizeOp final : public IDEEPOperator {
|
||||
bool RunOnDevice() override {
|
||||
const auto& X = Input(0);
|
||||
auto* Y = Output(0);
|
||||
|
||||
Y->init({X.get_dims(), idtype::f32,
|
||||
Y_fmt_ != iformat::format_undef
|
||||
? Y_fmt_ : X.get_public_format()});
|
||||
if (Y_fmt_ != iformat::undef) {
|
||||
Y->init(X.get_desc().to_type(idtype::f32).to_format(Y_fmt_));
|
||||
} else {
|
||||
Y->init(X.get_desc().to_type(idtype::f32));
|
||||
}
|
||||
Y->feed_from(X);
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
private:
|
||||
iformat Y_fmt_ {iformat::format_undef};
|
||||
iformat Y_fmt_ {iformat::undef};
|
||||
};
|
||||
|
||||
REGISTER_IDEEP_OPERATOR_WITH_ENGINE(Int8Dequantize, DNNLOWP, IDEEPInt8DequantizeOp);
|
||||
|
@ -40,13 +40,11 @@ public:
|
||||
}
|
||||
|
||||
if (cached_X_descriptor_ != X.get_descriptor()) {
|
||||
op_key_.clear();
|
||||
cached_X_descriptor_ = X.dup_descriptor();
|
||||
Y_.init({{X.get_dim(0), filter.get_dim(0)}, idtype::f32});
|
||||
}
|
||||
|
||||
if (cached_weights_descriptor_ != filter.get_descriptor()) {
|
||||
op_key_.clear();
|
||||
cached_weights_descriptor_ = filter.dup_descriptor();
|
||||
CAFFE_ENFORCE(filter.get_data_type() == idtype::s8 && filter.has_scale());
|
||||
|
||||
@ -66,11 +64,10 @@ public:
|
||||
|
||||
if (InputSize() > BIAS) {
|
||||
ideep::inner_product_forward::compute(
|
||||
op_key_, X_in, filter_, bias_, Y_);
|
||||
X_in, filter_, bias_, Y_);
|
||||
} else {
|
||||
ideep::inner_product_forward::compute(op_key_, X_in, filter_, Y_);
|
||||
ideep::inner_product_forward::compute(X_in, filter_, Y_);
|
||||
}
|
||||
|
||||
Y->init({Y_.get_dims(), Y_data_type_});
|
||||
Y->set_scale(Y_scales_);
|
||||
Y->feed_from(Y_);
|
||||
@ -83,7 +80,6 @@ private:
|
||||
float scale_;
|
||||
int32_t zero_point_;
|
||||
|
||||
ikey op_key_;
|
||||
idtype Y_data_type_;
|
||||
itensor filter_, bias_, Y_;
|
||||
iscale Y_scales_;
|
||||
|
@ -13,7 +13,7 @@ class IDEEPInt8GivenTensorFillOp final : public IDEEPOperator {
|
||||
: IDEEPOperator(operator_def, ws),
|
||||
zero_point_(
|
||||
this->template GetSingleArgument<int32_t>("Y_zero_point", 0)),
|
||||
shape_(this->template GetRepeatedArgument<int>("shape")) {
|
||||
shape_(this->template GetRepeatedArgument<itensor::dim>("shape")) {
|
||||
CAFFE_ENFORCE(shape_.size() == 4 || shape_.size() == 2 || shape_.size() == 1);
|
||||
CAFFE_ENFORCE(zero_point_ == 0 || zero_point_ == 128,
|
||||
"Not support zero point");
|
||||
@ -37,9 +37,10 @@ class IDEEPInt8GivenTensorFillOp final : public IDEEPOperator {
|
||||
}
|
||||
|
||||
auto source_values = this->template GetSingleArgument<string>("values", "");
|
||||
values_.Resize(source_values.size());
|
||||
auto src_size = source_values.size();
|
||||
values_.Resize(src_size);
|
||||
uint8_t* values_data = values_.template mutable_data<uint8_t>();
|
||||
for (int i = 0; i < source_values.size(); i++) {
|
||||
for (int i = 0; i < src_size; i++) {
|
||||
values_data[i] = static_cast<uint8_t>(source_values[i]);
|
||||
}
|
||||
}
|
||||
@ -64,7 +65,8 @@ class IDEEPInt8GivenTensorFillOp final : public IDEEPOperator {
|
||||
// Shift quantized data to s8 per zero point
|
||||
if (zero_point_ == 128) {
|
||||
auto* data_s8 = static_cast<int8_t*>(temp_ten.get_data_handle());
|
||||
for (int i = 0; i < temp_ten.get_nelems(); i++) {
|
||||
auto nelems = temp_ten.get_nelems();
|
||||
for (int i = 0; i < nelems; i++) {
|
||||
data_s8[i] = data_s8[i] - zero_point_;
|
||||
}
|
||||
}
|
||||
@ -95,7 +97,7 @@ class IDEEPInt8GivenIntTensorFillOp final : public IDEEPOperator {
|
||||
: IDEEPOperator(operator_def, ws),
|
||||
zero_point_(
|
||||
this->template GetSingleArgument<int32_t>("Y_zero_point", 0)),
|
||||
shape_(this->template GetRepeatedArgument<int>("shape")) {
|
||||
shape_(this->template GetRepeatedArgument<itensor::dim>("shape")) {
|
||||
CAFFE_ENFORCE(zero_point_ == 0, "Not support zero point");
|
||||
if (HasArgument("Y_scales")) {
|
||||
scales_ = this->template GetRepeatedArgument<float>("Y_scales");
|
||||
@ -105,9 +107,10 @@ class IDEEPInt8GivenIntTensorFillOp final : public IDEEPOperator {
|
||||
}
|
||||
|
||||
auto source_values = this->template GetRepeatedArgument<int32_t>("values");
|
||||
values_.Resize(source_values.size());
|
||||
auto src_size = source_values.size();
|
||||
values_.Resize(src_size);
|
||||
auto* values_data = values_.template mutable_data<int32_t>();
|
||||
for (int i = 0; i < source_values.size(); i++) {
|
||||
for (int i = 0; i < src_size; i++) {
|
||||
values_data[i] = static_cast<int32_t>(source_values[i]);
|
||||
}
|
||||
}
|
||||
|
@ -38,19 +38,20 @@ class IDEEPInt8PoolOp final : public IDEEPConvPoolOpBase {
|
||||
auto Y_dims = CalcOutputDims(X, X.get_dim(1));
|
||||
|
||||
if (cached_X_descriptor_ != X.get_descriptor()) {
|
||||
op_key_.clear();
|
||||
cached_X_descriptor_ = X.dup_descriptor();
|
||||
}
|
||||
|
||||
ideep::pooling_forward::compute(op_key_, X, Y_dims, *Y,
|
||||
stride_, kernel_, pad_tl(), pad_br(), algo_, iprop::forward_inference);
|
||||
ideep::pooling_forward::compute(X, Y_dims, *Y,
|
||||
{stride_.begin(), stride_.end()},
|
||||
{kernel_.begin(), kernel_.end()},
|
||||
pad_tl(), pad_br(), algo_,
|
||||
iprop::forward_inference);
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
private:
|
||||
ialgo algo_;
|
||||
ikey op_key_;
|
||||
itensor::descriptor cached_X_descriptor_;
|
||||
|
||||
INPUT_TAGS(INPUT);
|
||||
|
@ -17,7 +17,8 @@ class IDEEPInt8QuantizeOp final : public IDEEPOperator {
|
||||
|
||||
if (HasArgument("output_order")) {
|
||||
Y_fmt_ = static_cast<iformat>(
|
||||
this->template GetSingleArgument<int>("output_order", iformat::nchw));
|
||||
this->template GetSingleArgument<int>("output_order",
|
||||
static_cast<int>(iformat::nchw)));
|
||||
}
|
||||
|
||||
CAFFE_ENFORCE(zero_point_ == 0 || zero_point_ == 128,
|
||||
@ -32,9 +33,11 @@ class IDEEPInt8QuantizeOp final : public IDEEPOperator {
|
||||
CAFFE_ENFORCE(X.get_data_type() == idtype::f32, "Not support data type");
|
||||
|
||||
auto* Y = Output(0);
|
||||
Y->init({X.get_dims(), Y_data_type_,
|
||||
Y_fmt_ != iformat::format_undef
|
||||
? Y_fmt_ : X.get_public_format()});
|
||||
if (Y_fmt_ != iformat::undef) {
|
||||
Y->init(X.get_desc().to_type(Y_data_type_).to_format(Y_fmt_));
|
||||
} else {
|
||||
Y->init(X.get_desc().to_type(Y_data_type_));
|
||||
}
|
||||
Y->set_scale(Y_scales_);
|
||||
Y->feed_from(X);
|
||||
|
||||
@ -46,7 +49,7 @@ class IDEEPInt8QuantizeOp final : public IDEEPOperator {
|
||||
int32_t zero_point_;
|
||||
iscale Y_scales_;
|
||||
idtype Y_data_type_;
|
||||
iformat Y_fmt_ {iformat::format_undef};
|
||||
iformat Y_fmt_ {iformat::undef};
|
||||
|
||||
INPUT_TAGS(INPUT0);
|
||||
OUTPUT_TAGS(OUTPUT);
|
||||
|
@ -31,7 +31,6 @@ class IDEEPInt8ReluOp final : public IDEEPOperator {
|
||||
}
|
||||
|
||||
private:
|
||||
ikey op_key_;
|
||||
float alpha_;
|
||||
|
||||
INPUT_TAGS(INPUT);
|
||||
|
@ -12,7 +12,7 @@ class IDEEPReshapeOp final : public IDEEPOperator {
|
||||
|
||||
IDEEPReshapeOp(const OperatorDef& operator_def, Workspace* ws)
|
||||
: IDEEPOperator(operator_def, ws),
|
||||
new_shape_(OperatorBase::GetRepeatedArgument<int>("shape")) {}
|
||||
new_shape_(OperatorBase::GetRepeatedArgument<itensor::dim>("shape")) {}
|
||||
|
||||
bool RunOnDevice() override {
|
||||
ideep::tensor::dims actual_new_shape = new_shape_;
|
||||
|
@ -37,8 +37,9 @@ class IDEEPSpatialBNOp final : public IDEEPOperator {
|
||||
if (is_test_) {
|
||||
const auto& est_mean = Input(EST_MEAN);
|
||||
const auto& est_var = Input(EST_VAR);
|
||||
auto X_ = X.get_data_type() != idtype::f32 ? X.dequantize() : X;
|
||||
ideep::batch_normalization_forward_inference::compute(
|
||||
X, est_mean, est_var, scale, bias, *Y, epsilon_);
|
||||
X_, est_mean, est_var, scale, bias, *Y, epsilon_);
|
||||
} else {
|
||||
auto* saved_mean = Output(SAVED_MEAN);
|
||||
auto* saved_var = Output(SAVED_VAR);
|
||||
|
@ -65,7 +65,8 @@ class CopyIDEEPToCPUOp final : public IDEEPOperator {
|
||||
OperatorBase::OutputTensor(0, dims, at::dtype<float>().device(CPU));
|
||||
X.to_public(Y->template mutable_data<float>());
|
||||
} else {
|
||||
CAFFE_THROW("Unsupported ideep type: ", X.get_data_type());
|
||||
CAFFE_THROW("Unsupported ideep type: ",
|
||||
static_cast<int>(X.get_data_type()));
|
||||
}
|
||||
}
|
||||
return true;
|
||||
|
@ -671,7 +671,7 @@ bool fuseOrderSwitchToQuantizeOp(repr::NNModule* nn, caffe2::Workspace* ws) {
|
||||
auto* seqOp = getMutableOpDef(*seq);
|
||||
auto* arg = seqOp->add_arg();
|
||||
arg->set_name("output_order");
|
||||
arg->set_i(iformat::nhwc);
|
||||
arg->set_i(static_cast<int64_t>(iformat::nhwc));
|
||||
|
||||
auto input = repr::nn::getInputs(osNode).front();
|
||||
nn->dataFlow.replaceNode(output, input);
|
||||
@ -696,7 +696,7 @@ bool fuseOrderSwitchToQuantizeOp(repr::NNModule* nn, caffe2::Workspace* ws) {
|
||||
auto* preOp = getMutableOpDef(*pre);
|
||||
auto* arg = preOp->add_arg();
|
||||
arg->set_name("output_order");
|
||||
arg->set_i(iformat::nchw);
|
||||
arg->set_i(static_cast<int64_t>(iformat::nchw));
|
||||
|
||||
auto output = repr::nn::getOutputs(osNode).front();
|
||||
nn->dataFlow.replaceNode(input, output);
|
||||
@ -864,7 +864,7 @@ void preConvertFiltersFormat(repr::NNModule* nn, caffe2::Workspace* ws) {
|
||||
|
||||
itensor::descriptor expectedDesc;
|
||||
if (repr::nn::is<repr::ConvTranspose>(node)) {
|
||||
if (filter->get_public_format() == ideep::format::iohw)
|
||||
if (filter->get_desc().is_iohw())
|
||||
continue;
|
||||
auto convTranspose = repr::nn::get<repr::ConvTranspose>(node);
|
||||
auto initValue = [](vector<int>& v, vector<int> i) {
|
||||
@ -883,19 +883,17 @@ void preConvertFiltersFormat(repr::NNModule* nn, caffe2::Workspace* ws) {
|
||||
filter->get_dim(2),
|
||||
filter->get_dim(3)};
|
||||
expectedDesc =
|
||||
ideep::convolution_transpose_forward::expected_weights_descriptor(
|
||||
ideep::convolution_transpose_forward::expected_weights_desc(
|
||||
filter_dims_mkldnn,
|
||||
dataType,
|
||||
strides,
|
||||
{strides.begin(), strides.end()},
|
||||
{pads[0], pads[1]},
|
||||
{pads[2], pads[3]});
|
||||
|
||||
if (filter->get_descriptor() != expectedDesc) {
|
||||
filter->set_public_format(ideep::format::iohw);
|
||||
itensor newFilter;
|
||||
newFilter.init(expectedDesc);
|
||||
newFilter.feed_from(*filter);
|
||||
newFilter.set_public_format(ideep::format::iohw);
|
||||
filterBlob->Reset<itensor>(new itensor(std::move(newFilter)));
|
||||
}
|
||||
} else if (repr::nn::is<repr::Conv>(node)) {
|
||||
@ -919,16 +917,14 @@ void preConvertFiltersFormat(repr::NNModule* nn, caffe2::Workspace* ws) {
|
||||
aalgorithm = ialgo::convolution_winograd;
|
||||
}
|
||||
}
|
||||
auto dataType = filter->get_data_type();
|
||||
|
||||
filter->make_group(conv->getGroup());
|
||||
expectedDesc = ideep::convolution_forward::expected_weights_descriptor(
|
||||
expectedDesc = ideep::convolution_forward::expected_weights_desc(
|
||||
filter->get_dims(),
|
||||
dataType,
|
||||
strides,
|
||||
filter->get_data_type(),
|
||||
{strides.begin(), strides.end()},
|
||||
{pads[0], pads[1]},
|
||||
{pads[2], pads[3]},
|
||||
dilations,
|
||||
{dilations.begin(), dilations.end()},
|
||||
conv->getGroup(),
|
||||
aalgorithm);
|
||||
|
||||
@ -957,13 +953,13 @@ void preConvertFiltersFormat(repr::NNModule* nn, caffe2::Workspace* ws) {
|
||||
filter->reshape({f_dim0, f_dim1});
|
||||
}
|
||||
|
||||
expectedDesc = ideep::inner_product_forward::expected_weights_descriptor(
|
||||
expectedDesc = ideep::inner_product_forward::expected_weights_desc(
|
||||
filter->get_dims());
|
||||
|
||||
if (filter->get_descriptor() != expectedDesc) {
|
||||
itensor newFilter;
|
||||
newFilter.init(expectedDesc);
|
||||
newFilter.feed_from(filter->as_weights());
|
||||
newFilter.feed_from(*filter);
|
||||
filterBlob->Reset<itensor>(new itensor(std::move(newFilter)));
|
||||
}
|
||||
}
|
||||
|
@ -107,7 +107,7 @@ class IDeepFeeder : public BlobFeederBase {
|
||||
else if (meta == TypeMeta::Make<uint8_t>())
|
||||
return itensor::data_type::u8;
|
||||
else
|
||||
return itensor::data_type::data_undef;
|
||||
return itensor::data_type::undef;
|
||||
}
|
||||
|
||||
public:
|
||||
|
@ -43,9 +43,9 @@ ELSE(MKL_FOUND)
|
||||
ENDIF(MKL_FOUND)
|
||||
|
||||
SET(MKL_cmake_included TRUE)
|
||||
IF (NOT MKLDNN_THREADING)
|
||||
SET(MKLDNN_THREADING "OMP:COMP" CACHE STRING "")
|
||||
ELSEIF (MKLDNN_THREADING STREQUAL "TBB")
|
||||
IF (NOT MKLDNN_CPU_RUNTIME)
|
||||
SET(MKLDNN_CPU_RUNTIME "OMP" CACHE STRING "")
|
||||
ELSEIF (MKLDNN_CPU_RUNTIME STREQUAL "TBB")
|
||||
IF (USE_TBB)
|
||||
MESSAGE(STATUS "MKL-DNN is using TBB")
|
||||
|
||||
@ -59,14 +59,15 @@ ELSEIF (MKLDNN_THREADING STREQUAL "TBB")
|
||||
INCLUDE_DIRECTORIES(${TBB_INCLUDE_DIRS})
|
||||
LIST(APPEND EXTRA_SHARED_LIBS tbb)
|
||||
ELSE()
|
||||
MESSAGE(FATAL_ERROR "MKLDNN_THREADING is set to TBB but TBB is not used")
|
||||
MESSAGE(FATAL_ERROR "MKLDNN_CPU_RUNTIME is set to TBB but TBB is not used")
|
||||
ENDIF()
|
||||
ENDIF()
|
||||
MESSAGE(STATUS "MKLDNN_THREADING = ${MKLDNN_THREADING}")
|
||||
MESSAGE(STATUS "MKLDNN_CPU_RUNTIME = ${MKLDNN_CPU_RUNTIME}")
|
||||
|
||||
SET(WITH_TEST FALSE CACHE BOOL "" FORCE)
|
||||
SET(WITH_EXAMPLE FALSE CACHE BOOL "" FORCE)
|
||||
SET(MKLDNN_BUILD_TESTS FALSE CACHE BOOL "" FORCE)
|
||||
SET(MKLDNN_BUILD_EXAMPLES FALSE CACHE BOOL "" FORCE)
|
||||
SET(MKLDNN_LIBRARY_TYPE STATIC CACHE STRING "" FORCE)
|
||||
SET(DNNL_ENABLE_PRIMITIVE_CACHE TRUE CACHE BOOL "" FORCE)
|
||||
IF(MKLDNN_USE_NATIVE_ARCH) # Disable HostOpts in MKLDNN unless MKLDNN_USE_NATIVE_ARCH is set.
|
||||
SET(ARCH_OPT_FLAGS "HostOpts" CACHE STRING "" FORCE)
|
||||
ELSE()
|
||||
@ -78,23 +79,17 @@ ELSE()
|
||||
ENDIF()
|
||||
|
||||
ADD_SUBDIRECTORY(${MKLDNN_ROOT})
|
||||
IF(NOT TARGET mkldnn)
|
||||
IF(NOT TARGET dnnl)
|
||||
MESSAGE("Failed to include MKL-DNN target")
|
||||
RETURN()
|
||||
ENDIF(NOT TARGET mkldnn)
|
||||
IF(MKL_FOUND)
|
||||
SET(USE_MKL_CBLAS -DUSE_MKL)
|
||||
IF(USE_MKLDNN_CBLAS)
|
||||
LIST(APPEND USE_MKL_CBLAS -DUSE_CBLAS)
|
||||
ENDIF(USE_MKLDNN_CBLAS)
|
||||
TARGET_COMPILE_DEFINITIONS(mkldnn PRIVATE USE_MKL_CBLAS)
|
||||
ENDIF(MKL_FOUND)
|
||||
ENDIF(NOT TARGET dnnl)
|
||||
|
||||
IF(NOT APPLE AND CMAKE_COMPILER_IS_GNUCC)
|
||||
TARGET_COMPILE_OPTIONS(mkldnn PRIVATE -Wno-maybe-uninitialized)
|
||||
TARGET_COMPILE_OPTIONS(mkldnn PRIVATE -Wno-strict-overflow)
|
||||
TARGET_COMPILE_OPTIONS(mkldnn PRIVATE -Wno-error=strict-overflow)
|
||||
TARGET_COMPILE_OPTIONS(dnnl PRIVATE -Wno-maybe-uninitialized)
|
||||
TARGET_COMPILE_OPTIONS(dnnl PRIVATE -Wno-strict-overflow)
|
||||
TARGET_COMPILE_OPTIONS(dnnl PRIVATE -Wno-error=strict-overflow)
|
||||
ENDIF(NOT APPLE AND CMAKE_COMPILER_IS_GNUCC)
|
||||
LIST(APPEND MKLDNN_LIBRARIES mkldnn)
|
||||
LIST(APPEND MKLDNN_LIBRARIES dnnl)
|
||||
|
||||
SET(MKLDNN_FOUND TRUE)
|
||||
MESSAGE(STATUS "Found MKL-DNN: TRUE")
|
||||
|
@ -70,15 +70,15 @@ library a better choice in their application.
|
||||
PyTorch allows selecting of the parallelization backend used by ATen and other
|
||||
libraries at the build time with the following build options:
|
||||
|
||||
+------------+-----------------------+-----------------------------+----------------------------------------+
|
||||
| Library | Build Option | Values | Notes |
|
||||
+============+=======================+=============================+========================================+
|
||||
| ATen | ``ATEN_THREADING`` | ``OMP`` (default), ``TBB`` | |
|
||||
+------------+-----------------------+-----------------------------+----------------------------------------+
|
||||
| MKL | ``MKL_THREADING`` | (same) | To enable MKL use ``BLAS=MKL`` |
|
||||
+------------+-----------------------+-----------------------------+----------------------------------------+
|
||||
| MKL-DNN | ``MKLDNN_THREADING`` | (same) | To enable MKL-DNN use ``USE_MKLDNN=1`` |
|
||||
+------------+-----------------------+-----------------------------+----------------------------------------+
|
||||
+------------+------------------------+-----------------------------+----------------------------------------+
|
||||
| Library | Build Option | Values | Notes |
|
||||
+============+========================+=============================+========================================+
|
||||
| ATen | ``ATEN_THREADING`` | ``OMP`` (default), ``TBB`` | |
|
||||
+------------+------------------------+-----------------------------+----------------------------------------+
|
||||
| MKL | ``MKL_THREADING`` | (same) | To enable MKL use ``BLAS=MKL`` |
|
||||
+------------+------------------------+-----------------------------+----------------------------------------+
|
||||
| MKL-DNN | ``MKLDNN_CPU_RUNTIME`` | (same) | To enable MKL-DNN use ``USE_MKLDNN=1`` |
|
||||
+------------+------------------------+-----------------------------+----------------------------------------+
|
||||
|
||||
It is recommended not to mix OpenMP and TBB within one build.
|
||||
|
||||
|
2
setup.py
2
setup.py
@ -42,7 +42,7 @@
|
||||
# USE_MKLDNN=0
|
||||
# disables use of MKLDNN
|
||||
#
|
||||
# MKLDNN_THREADING
|
||||
# MKLDNN_CPU_RUNTIME
|
||||
# MKL-DNN threading mode: TBB or OMP (default)
|
||||
#
|
||||
# USE_NNPACK=0
|
||||
|
@ -12,6 +12,7 @@ skipIfNoTorchVision = unittest.skipIf(not HAS_TORCHVISION, "no torchvision")
|
||||
|
||||
import torch
|
||||
import torch.jit
|
||||
import torch.backends.mkldnn
|
||||
from torch.utils import mkldnn as mkldnn_utils
|
||||
from torch.testing._internal.common_utils import TestCase, run_tests, TemporaryFileName
|
||||
|
||||
@ -105,7 +106,7 @@ class TestMkldnn(TestCase):
|
||||
N = torch.randint(3, 10, (1,)).item()
|
||||
C = torch.randint(1, 3, (1,)).item() * groups
|
||||
M = torch.randint(1, 3, (1,)).item() * groups
|
||||
x = torch.randn(N, C, 224, 224, dtype=torch.float32) * 100
|
||||
x = torch.randn(N, C, 224, 224, dtype=torch.float32)
|
||||
for bias in [True, False]:
|
||||
conv2d = torch.nn.Conv2d(in_channels=C,
|
||||
out_channels=M,
|
||||
@ -115,9 +116,10 @@ class TestMkldnn(TestCase):
|
||||
bias=bias,
|
||||
groups=groups).float()
|
||||
mkldnn_conv2d = mkldnn_utils.to_mkldnn(copy.deepcopy(conv2d))
|
||||
self.assertEqual(
|
||||
conv2d(x),
|
||||
mkldnn_conv2d(x.to_mkldnn()).to_dense())
|
||||
with torch.backends.mkldnn.flags(enabled=False):
|
||||
y_aten = conv2d(x)
|
||||
y_mkldnn = mkldnn_conv2d(x.to_mkldnn()).to_dense()
|
||||
self.assertEqual(y_aten, y_mkldnn)
|
||||
|
||||
self._test_serialization(mkldnn_conv2d, (x.to_mkldnn(),))
|
||||
self._test_tracing(mkldnn_conv2d, (x.to_mkldnn(),))
|
||||
@ -143,9 +145,7 @@ class TestMkldnn(TestCase):
|
||||
conv2d_loaded = torch.jit.load(fname)
|
||||
|
||||
self.assertEqual(conv2d_mkldnn.weight.ndimension(), 5)
|
||||
# with DNNL upgrade we should switch to no-reordering,
|
||||
# but for now we keep the 5d tensor
|
||||
# self.assertEqual(conv2d_loaded.weight.ndimension(), 4)
|
||||
self.assertEqual(conv2d_loaded.weight.ndimension(), 4)
|
||||
self.assertEqual(
|
||||
conv2d(x),
|
||||
conv2d_loaded(x.to_mkldnn()).to_dense())
|
||||
@ -323,6 +323,22 @@ class TestMkldnn(TestCase):
|
||||
z.to_dense(),
|
||||
)
|
||||
|
||||
def test_reshape_blocked_format(self):
|
||||
# construct an mkldnn blocked tensor with mkldnn conv2d
|
||||
C = 7
|
||||
m = mkldnn_utils.to_mkldnn(torch.nn.Conv2d(C, C, 3))
|
||||
x = torch.randn(1, C, 8, 8).to_mkldnn()
|
||||
|
||||
# mkldnn tensor w/ blocked format
|
||||
y_block = m(x)
|
||||
# aten tensor w/ plain format
|
||||
y_plain = y_block.to_dense()
|
||||
|
||||
y_block_reshape = y_block.reshape(C, -1)
|
||||
y_plain_reshape = y_plain.reshape(C, -1)
|
||||
|
||||
self.assertEqual(y_plain_reshape, y_block_reshape.to_dense())
|
||||
|
||||
def test_clone(self):
|
||||
x = torch.randn(4, 5, dtype=torch.float32) * 10
|
||||
self.assertEqual(
|
||||
|
2
third_party/ideep
vendored
2
third_party/ideep
vendored
Submodule third_party/ideep updated: 3fc96899dc...c36993484d
@ -237,7 +237,7 @@ class CMake:
|
||||
'INTEL_MKL_DIR',
|
||||
'INTEL_OMP_DIR',
|
||||
'MKL_THREADING',
|
||||
'MKLDNN_THREADING',
|
||||
'MKLDNN_CPU_RUNTIME',
|
||||
'MSVC_Z7_OVERRIDE',
|
||||
'Numa_INCLUDE_DIR',
|
||||
'Numa_LIBRARIES',
|
||||
|
Reference in New Issue
Block a user