Upgrade mkldnn-bridge for dnnlowp support (#16308)

Summary:
The mkldnn-bridge is upgraded in this PR to support DNNLOWP operators.
Meanwhile, the caffe2 APIs have been updated to use the latest version.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/16308

Differential Revision: D14697018

Pulled By: yinghai

fbshipit-source-id: ca952589098accb08295fd5aa92924c61e74d69c
Author: Gu, Jinghui
Date: 2019-04-03 10:29:19 -07:00
Committed by: Facebook Github Bot
parent 46a68c1b37
commit a7b82a44c4
12 changed files with 406 additions and 320 deletions


@@ -12,16 +12,41 @@ namespace caffe2 {
enum ConvAlgorithm {
CONV_ALGORITHM_AUTO = 0,
CONV_ALGORITHM_WINOGRAD = 1,
CONV_ALGORITHM_MAX = CONV_ALGORITHM_WINOGRAD + 1
CONV_ALGORITHM_MAX
};
enum FusionType {
FUSION_UNKNOWN = 0,
FUSION_CONV_RELU = 1,
FUSION_CONV_SUM = 2,
FUSION_CONV_SUM_RELU = 3,
FUSION_MAX
};
#define USE_IDEEP_DEF_ALIASES() \
/* the hash key of a cached operator generated by iDEEP */ \
using ikey = ideep::key_t; \
/* the tensor type created/handled by iDEEP */ \
using itensor = ideep::tensor; \
/* the data layout of an iDEEP tensor */ \
using iformat = ideep::format; \
/* the scales for iDEEP tensor with different data type */ \
using iscale = ideep::scale_t; \
/* the specific algorithm for iDEEP operators, e.g. winograd */ \
using ialgo = ideep::algorithm; \
/* the kind of propagation for iDEEP operators, e.g. forward, training */ \
using iprop = ideep::prop_kind; \
/* the kind of low precision operators, e.g. signed/unsigned activation */ \
using ilowp_kind = ideep::lowp_kind; \
/* the kind of padding, usually set as zero padding */ \
using ipadding = ideep::padding_kind; \
/* the data type of iDEEP tensor, e.g. f32, u8, s8 */ \
using idtype = ideep::tensor::data_type; \
/* the descriptor of iDEEP tensor */ \
using itdesc = ideep::tensor::descriptor; \
/* the attribute for an operator, describing the details of inputs & fusion */ \
using iattr = ideep::descriptor_group::attr_t; \
/* the detail flags for batch normalization */ \
using ibn_flag = ideep::batch_normalization_flag;
} // namespace caffe2
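
For illustration, a minimal sketch (not part of this commit) of how these aliases read at a use site; the demo class below is hypothetical, and the include path is assumed from the diff:

#include "caffe2/ideep/ideep_utils.h"  // assumed location of the macro above

namespace caffe2 {
// Hypothetical demo class: USE_IDEEP_DEF_ALIASES() makes the short names
// available as member type aliases.
class AliasDemo {
 public:
  USE_IDEEP_DEF_ALIASES();
  static idtype DefaultDataType() { return idtype::f32; }              // ideep::tensor::data_type
  static ialgo DefaultConvAlgo() { return ialgo::convolution_direct; } // ideep::algorithm
  static iprop InferenceKind() { return iprop::forward_inference; }    // ideep::prop_kind
};
} // namespace caffe2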


@@ -1,219 +0,0 @@
#include <caffe2/ideep/operators/conv_pool_base_op.h>
namespace caffe2 {
class IDEEPConvFusionOp final : public IDEEPConvPoolOpBase {
public:
USE_IDEEP_DEF_ALIASES();
USE_IDEEP_CONV_POOL_BASE_FUNCTIONS();
enum FusionType {
FUSION_UNKNOWN = 0,
FUSION_CONV_RELU = 1,
FUSION_CONV_SUM = 2,
FUSION_CONV_SUM_RELU = 3,
FUSION_MAX = FUSION_CONV_SUM_RELU + 1,
};
IDEEPConvFusionOp(const OperatorDef& operator_def, Workspace* ws)
: IDEEPConvPoolOpBase(operator_def, ws),
fusion_type_(static_cast<FusionType>(
OperatorBase::GetSingleArgument<int>("fusion_type", 0))),
training_mode_(
OperatorBase::GetSingleArgument<int>("training_mode", 0)),
conv_algorithm_(
OperatorBase::GetSingleArgument<int>("conv_algorithm", CONV_ALGORITHM_AUTO)) {
OPERATOR_NEEDS_FEATURE(
pad_l() == pad_r() && pad_t() == pad_b(),
"Uneven padding not supported.");
OPERATOR_NEEDS_FEATURE(group_ == 1, "Group not supported.");
OPERATOR_NEEDS_FEATURE(
fusion_type_ > FUSION_UNKNOWN && fusion_type_ < FUSION_MAX,
"Undefined Conv fusion type.",
fusion_type_);
// Check kernel only if we are doing conv. The reason is that a
// few other ops, like PadImage, are also using this base class. We really
// need to clean this up.
for (int dim = 0; dim < kernel_.size(); ++dim) {
CAFFE_ENFORCE_GE(pads_[dim], 0);
CAFFE_ENFORCE_GE(pads_[kernel_.size() + dim], 0);
CAFFE_ENFORCE(
kernel_[dim],
"If you are doing convolution, you will need to set "
"explicitly the kernel size.");
}
}
~IDEEPConvFusionOp() override {}
bool RunOnDeviceWithOrderNCHW() override {
const auto& X = Input(INPUT_X);
const auto& filter = Input(FILTER);
auto* Y = Output(OUTPUT);
auto Y_dims_conv = CalcOutputDims(X, filter.get_dim(0));
auto attr = [this]() {
return (fusion_type_ == FUSION_CONV_RELU)
? iattr::fuse_relu()
: ((fusion_type_ == FUSION_CONV_SUM)
? iattr::fuse_sum()
: ((fusion_type_ == FUSION_CONV_SUM_RELU) ? iattr::residual()
: iattr()));
};
auto last_input = [this]() {
return (fusion_type_ == FUSION_CONV_RELU) ? BIAS_OR_INPUT_S : INPUT_S;
};
CAFFE_ENFORCE(4 == X.ndims());
CAFFE_ENFORCE(4 == filter.ndims());
CAFFE_ENFORCE(filter.get_dim(2) == kernel_h());
CAFFE_ENFORCE(filter.get_dim(3) == kernel_w());
CAFFE_ENFORCE(
X.get_dim(1) == filter.get_dim(1) * group_,
"Convolution fusion op: input channels does not match: "
"# of input channels ",
X.get_dim(1),
" is not equal to kernel channels * group:",
filter.get_dim(1),
"*",
group_);
ideep::algorithm aalgorithm = ideep::algorithm::convolution_direct;
if (conv_algorithm_ == CONV_ALGORITHM_WINOGRAD) {
aalgorithm = ideep::algorithm::convolution_winograd;
}
bool weights_changed =
(cached_weights_descriptor_ != filter.get_descriptor());
if (weights_changed && !training_mode_) {
cached_weights_descriptor_ = filter.get_descriptor();
filter_ = filter;
auto expected_descriptor =
ideep::convolution_forward::expected_weights_descriptor(
filter.get_dims());
if (filter_.get_descriptor() != expected_descriptor) {
filter_.init<ideep::utils::allocator, ideep::convolution_forward>(
expected_descriptor);
ideep::reorder::compute(filter, filter_);
}
}
if (InputSize() > last_input()) {
ideep::convolution_forward::compute(
X,
training_mode_ ? filter : filter_,
Input(BIAS_OR_INPUT_S),
Y_dims_conv,
*Y,
stride_,
dilation_,
pad_tl(),
pad_br(),
group_,
attr(),
aalgorithm);
} else {
ideep::convolution_forward::compute(
X,
training_mode_ ? filter : filter_,
Y_dims_conv,
*Y,
stride_,
dilation_,
pad_tl(),
pad_br(),
group_,
attr(),
aalgorithm);
}
if (fusion_type_ != FUSION_CONV_RELU) {
CAFFE_ENFORCE(
Y == &(Input(InputSize() - 1)),
"Convolution fusion op: InPlace is enforced for sum fusion.");
}
return true;
}
private:
FusionType fusion_type_;
bool training_mode_;
int conv_algorithm_;
ideep::tensor filter_;
ideep::tensor::descriptor cached_weights_descriptor_;
INPUT_TAGS(INPUT_X, FILTER, BIAS_OR_INPUT_S, INPUT_S);
OUTPUT_TAGS(OUTPUT);
};
REGISTER_IDEEP_OPERATOR(ConvFusion, IDEEPConvFusionOp);
const char* kConvFusionDoc = R"DOC(
Note that other parameters, such as the stride and
kernel size, or the pads' sizes in each direction are not necessary for input
because they are provided by the ConvPoolOpBase operator. Various dimension
checks are done implicitly, and the sizes are specified in the Input docs for
this operator. As is expected, the filter is convolved with a subset of the
image and the bias is added; this is done throughout the image data and the
output is computed. As a side note on the implementation layout:
conv_op_impl.h is the templated implementation of the conv_op.h file, which is
why they are separate files.
)DOC";
std::function<void(OpSchema&)> ConvFusionDocGenerator(const char* dim) {
return [=](OpSchema& schema) {
string doc = R"DOC(
The convolution fusion operator consumes an input vector, a {dim}filter blob,
a bias blob and another input vector and computes the output. This operator
gives the chance to fuse the ReLU or element-wise Sum with a convolution
operator. {conv_fusion_doc})DOC";
c10::ReplaceAll(doc, "{dim}", dim);
c10::ReplaceAll(doc, "{conv_fusion_doc}", kConvFusionDoc);
schema.SetDoc(doc);
schema.Input(
0,
"X",
"Input data blob from previous layer; has size (N x C x H x W), "
"where N is the batch size, C is the number of channels, "
"and H and W are the height and width. Note that this is for the NCHW "
"usage. On the other hand, the NHWC Op has a different set of "
"dimension constraints. ");
schema.Input(
1,
"filter",
"The filter blob that will be used in the "
"convolutions; has size (M x C x kH x kW), where C is the number of "
"channels, and kH and kW are the height and width of the kernel.");
schema.Input(
2,
"bias",
"The 1D bias blob that is added through the "
"convolution; has size (M).");
schema.Input(
3,
"S",
"Input data blob for element-wise Sum fusion from previous layer; "
"has the same size of convolution output. Its input index should "
"be 2 if no bias for this convolution, and it MUST be inplace with "
"output Y.");
schema.Output(
0,
"Y",
"Output data blob that contains the result of the "
"convolution fusion. The output dimensions are functions of the kernel "
"size, stride size, and pad lengths."
"");
};
}
OPERATOR_SCHEMA(ConvFusion)
.NumInputs(2, 4)
.NumOutputs(1)
.TensorInferenceFunction(ConvPoolOpBase<CPUContext>::TensorInferenceForConv)
.CostInferenceFunction(OpSchema::CostInferenceFunctionType(
ConvPoolOpBase<CPUContext>::CostInferenceForConv))
.Arg("fusion_type", "Which fusion type is used")
.AllowInplace({{2, 0}, {3, 0}})
.FillUsing(ConvFusionDocGenerator(""));
} // namespace caffe2


@@ -2,116 +2,258 @@
namespace caffe2 {
class IDEEPConvOp final : public IDEEPConvPoolOpBase {
class IDEEPConvOp : public IDEEPConvPoolOpBase {
public:
USE_IDEEP_DEF_ALIASES();
USE_IDEEP_CONV_POOL_BASE_FUNCTIONS();
IDEEPConvOp(const OperatorDef& operator_def, Workspace* ws)
: IDEEPConvPoolOpBase(operator_def, ws),
training_mode_(
OperatorBase::GetSingleArgument<int>("training_mode", 0)),
conv_algorithm_(
OperatorBase::GetSingleArgument<int>("conv_algorithm", CONV_ALGORITHM_AUTO)) {
: IDEEPConvPoolOpBase(operator_def, ws) {
OPERATOR_NEEDS_FEATURE(
order_ == StorageOrder::NCHW, "Unsupported storage order.");
OPERATOR_NEEDS_FEATURE(
pad_l() == pad_r() && pad_t() == pad_b(),
"Uneven padding not supported.");
fusion_type_ = FUSION_UNKNOWN;
last_input_ = BIAS_OR_INPUT_S;
training_mode_ = OperatorBase::GetSingleArgument<int>("training_mode", 0);
pk_ = training_mode_ ? iprop::forward_training : iprop::forward_inference;
algo_ = ialgo::convolution_direct;
auto conv_algorithm = OperatorBase::GetSingleArgument<int>(
"conv_algorithm", CONV_ALGORITHM_AUTO);
if (conv_algorithm == CONV_ALGORITHM_WINOGRAD) {
algo_ = ialgo::convolution_winograd;
}
}
~IDEEPConvOp() override {}
virtual ~IDEEPConvOp() {}
bool RunOnDeviceWithOrderNCHW() override {
const auto& X = Input(INPUT);
const auto& X = Input(INPUT_X);
const auto& filter = Input(FILTER);
auto* Y = Output(OUTPUT);
auto Y_dims = CalcOutputDims(X, filter.get_dim(0));
auto grouped = filter.is_grouped() ? 1 : 0;
auto Y_dims_conv = CalcOutputDims(
X,
grouped ? (filter.get_dim(0) * filter.get_dim(1)) : filter.get_dim(0));
CAFFE_ENFORCE(4 == X.ndims());
CAFFE_ENFORCE(4 == filter.ndims());
CAFFE_ENFORCE(filter.get_dim(2) == kernel_h());
CAFFE_ENFORCE(filter.get_dim(3) == kernel_w());
CAFFE_ENFORCE(4 == filter.ndims() || (grouped && (group_ > 1)));
CAFFE_ENFORCE_EQ(filter.get_dim(2 + grouped), kernel_h());
CAFFE_ENFORCE_EQ(filter.get_dim(3 + grouped), kernel_w());
CAFFE_ENFORCE(
X.get_dim(1) == filter.get_dim(1) * group_,
X.get_dim(1) == filter.get_dim(1 + grouped) * group_,
"Convolution op: input channels does not match: # of input channels ",
X.get_dim(1),
" is not equal to kernel channels * group:",
filter.get_dim(1),
filter.get_dim(1 + grouped),
"*",
group_);
ideep::algorithm aalgorithm = ideep::algorithm::convolution_direct;
if (conv_algorithm_ == CONV_ALGORITHM_WINOGRAD) {
aalgorithm = ideep::algorithm::convolution_winograd;
}
bool weights_changed =
(cached_weights_descriptor_ != filter.get_descriptor());
if (weights_changed && !training_mode_) {
cached_weights_descriptor_ = filter.get_descriptor();
auto filter_in = filter;
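// On a weights change in inference mode: drop the cached op key, query the
// weights descriptor iDEEP expects for this algorithm, prop_kind and input
// shape, and reorder into filter_ once so subsequent runs reuse the layout.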
op_key_.clear();
cached_weights_descriptor_ = filter.dup_descriptor();
auto filter_in = filter.as_weights();
filter_in.make_group(group_);
auto expected_descriptor =
ideep::convolution_forward::expected_weights_descriptor(
filter_in.get_dims(),
filter_in.get_data_type(),
idtype::f32,
stride_,
pad_tl(),
pad_br(),
dilation_,
group_,
aalgorithm);
filter_.init<ideep::utils::allocator, ideep::convolution_forward>(
expected_descriptor);
ideep::reorder::compute(filter_in, filter_);
algo_,
pk_,
idtype::f32,
X.get_dims());
if (filter_in.get_descriptor() != expected_descriptor) {
filter_.init(expected_descriptor);
filter_.feed_from(filter_in);
} else {
filter_ = filter_in;
}
}
// NB: actually, in the case when `group_ > 1`, IDEEP will create
// an intermediate tensor for each run below. However, this tensor is merely
// a view of the weights and there is no actual data copy, so I'll let it
// go now. If we encounter a performance surprise when convolving with group
// > 1, this is the first place to check, and we need to do the same cache
// trick as above.
if (InputSize() > BIAS) {
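// The op key identifies the cached iDEEP primitive; clear it whenever the
// input descriptor changes so a fresh primitive is created for the new shape.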
if (cached_X_descriptor_ != X.get_descriptor()) {
op_key_.clear();
cached_X_descriptor_ = X.dup_descriptor();
}
if (InputSize() > last_input_) {
ideep::convolution_forward::compute(
op_key_,
X,
training_mode_ ? filter : filter_,
Input(BIAS),
Y_dims,
Input(BIAS_OR_INPUT_S),
Y_dims_conv,
*Y,
stride_,
dilation_,
pad_tl(),
pad_br(),
group_,
ideep::descriptor_group::attr_t(),
aalgorithm);
dummy_scale_,
dummy_scale_,
dummy_scale_,
attr_,
algo_,
pk_);
} else {
ideep::convolution_forward::compute(
op_key_,
X,
training_mode_ ? filter : filter_,
Y_dims,
Y_dims_conv,
*Y,
stride_,
dilation_,
pad_tl(),
pad_br(),
group_,
ideep::descriptor_group::attr_t(),
aalgorithm);
dummy_scale_,
dummy_scale_,
dummy_scale_,
attr_,
algo_,
pk_);
}
if (fusion_type_ == FUSION_CONV_SUM
|| fusion_type_ == FUSION_CONV_SUM_RELU) {
CAFFE_ENFORCE_EQ(Y, &(Input(InputSize() - 1)),
"Convolution fusion op: InPlace is enforced for sum fusion.");
}
return true;
}
private:
INPUT_TAGS(INPUT, FILTER, BIAS);
OUTPUT_TAGS(OUTPUT);
protected:
iprop pk_;
ialgo algo_;
iattr attr_;
ikey op_key_;
int last_input_;
bool training_mode_;
int conv_algorithm_;
ideep::tensor filter_;
ideep::tensor::descriptor cached_weights_descriptor_;
FusionType fusion_type_;
itensor filter_;
iscale dummy_scale_;
itensor::descriptor cached_X_descriptor_, cached_weights_descriptor_;
INPUT_TAGS(INPUT_X, FILTER, BIAS_OR_INPUT_S, INPUT_S);
OUTPUT_TAGS(OUTPUT);
};
class IDEEPConvFusionOp final : public IDEEPConvOp {
public:
USE_IDEEP_DEF_ALIASES();
USE_IDEEP_CONV_POOL_BASE_FUNCTIONS();
IDEEPConvFusionOp(const OperatorDef& operator_def, Workspace* ws)
: IDEEPConvOp(operator_def, ws) {
CAFFE_ENFORCE(OperatorBase::HasArgument("fusion_type"),
"You should specify the fusion type");
fusion_type_ = static_cast<FusionType>(
OperatorBase::GetSingleArgument<int>("fusion_type", FUSION_UNKNOWN));
OPERATOR_NEEDS_FEATURE(
fusion_type_ > FUSION_UNKNOWN && fusion_type_ < FUSION_MAX,
"Undefined Conv fusion type.",
fusion_type_);
switch (fusion_type_) {
case FUSION_CONV_RELU:
attr_ = iattr::fuse_relu();
last_input_ = BIAS_OR_INPUT_S;
break;
case FUSION_CONV_SUM:
attr_ = iattr::fuse_sum();
last_input_ = INPUT_S;
break;
case FUSION_CONV_SUM_RELU:
attr_ = iattr::residual();
last_input_ = INPUT_S;
break;
default:
CAFFE_THROW("Unsupported conv fusion type!");
}
}
virtual ~IDEEPConvFusionOp() {}
};
const char* kConvFusionDoc = R"DOC(
Note that other parameters, such as the stride and
kernel size, or the pads' sizes in each direction are not necessary for input
because they are provided by the ConvPoolOpBase operator. Various dimension
checks are done implicitly, and the sizes are specified in the Input docs for
this operator. As is expected, the filter is convolved with a subset of the
image and the bias is added; this is done throughout the image data and the
output is computed. As a side note on the implementation layout:
conv_op_impl.h is the templated implementation of the conv_op.h file, which is
why they are separate files.
)DOC";
std::function<void(OpSchema&)> ConvFusionDocGenerator(const char* dim) {
return [=](OpSchema& schema) {
string doc = R"DOC(
The convolution fusion operator consumes an input vector, a {dim}filter blob,
a bias blob, and another input vector, and computes the output. This operator
makes it possible to fuse a ReLU or an element-wise Sum with the convolution
operator. {conv_fusion_doc})DOC";
c10::ReplaceAll(doc, "{dim}", dim);
c10::ReplaceAll(doc, "{conv_fusion_doc}", kConvFusionDoc);
schema.SetDoc(doc);
schema.Input(
0,
"X",
"Input data blob from previous layer; has size (N x C x H x W), "
"where N is the batch size, C is the number of channels, "
"and H and W are the height and width. Note that this is for the NCHW "
"usage. On the other hand, the NHWC Op has a different set of "
"dimension constraints. ");
schema.Input(
1,
"filter",
"The filter blob that will be used in the "
"convolutions; has size (M x C x kH x kW), where C is the number of "
"channels, and kH and kW are the height and width of the kernel.");
schema.Input(
2,
"bias",
"The 1D bias blob that is added through the "
"convolution; has size (M).");
schema.Input(
3,
"S",
"Input data blob for element-wise Sum fusion from previous layer; "
"has the same size of convolution output. Its input index should "
"be 2 if no bias for this convolution, and it MUST be inplace with "
"output Y.");
schema.Output(
0,
"Y",
"Output data blob that contains the result of the "
"convolution fusion. The output dimensions are functions of the kernel "
"size, stride size, and pad lengths."
"");
};
}
OPERATOR_SCHEMA(ConvFusion)
.NumInputs(2, 4)
.NumOutputs(1)
.TensorInferenceFunction(ConvPoolOpBase<CPUContext>::TensorInferenceForConv)
.CostInferenceFunction(OpSchema::CostInferenceFunctionType(
ConvPoolOpBase<CPUContext>::CostInferenceForConv))
.Arg("fusion_type", "Which fusion type is used")
.AllowInplace({{2, 0}, {3, 0}})
.FillUsing(ConvFusionDocGenerator(""));
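
As a usage sketch (illustrative only, not part of this commit), a fused Conv + Sum + ReLU operator could be assembled as below; the blob names are made up, caffe2.pb.h is assumed to be included, and note that the sum input S must alias the output, matching the AllowInplace entries and the in-place enforcement above:

// Hypothetical construction of a ConvFusion OperatorDef.
OperatorDef def;
def.set_type("ConvFusion");
def.add_input("X");   // activation
def.add_input("w");   // filter
def.add_input("b");   // bias
def.add_input("S");   // element-wise sum input
def.add_output("S");  // MUST be in-place with the sum input
auto* fusion_arg = def.add_arg();
fusion_arg->set_name("fusion_type");
fusion_arg->set_i(FUSION_CONV_SUM_RELU);
// stride/pad/kernel args omitted here; they follow the usual Conv arguments.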
class IDEEPConvGradientOp final : public IDEEPConvPoolOpBase {
public:
USE_IDEEP_DEF_ALIASES();
@@ -131,7 +273,7 @@ class IDEEPConvGradientOp final : public IDEEPConvPoolOpBase {
"In order to backward propagate weights correctly, "
"please set training_mode=1");
}
~IDEEPConvGradientOp() override {}
virtual ~IDEEPConvGradientOp() {}
bool RunOnDeviceWithOrderNCHW() override {
const auto& X = Input(INPUT);
@@ -190,6 +332,7 @@ class IDEEPConvGradientOp final : public IDEEPConvPoolOpBase {
};
REGISTER_IDEEP_OPERATOR(Conv, IDEEPConvOp);
REGISTER_IDEEP_OPERATOR(ConvFusion, IDEEPConvFusionOp);
REGISTER_IDEEP_OPERATOR(ConvGradient, IDEEPConvGradientOp);
} // namespace caffe2


@@ -11,10 +11,7 @@ namespace caffe2 {
class IDEEPConvPoolOpBase : public ConvPoolOpBase<IDEEPContext> {
public:
IDEEPConvPoolOpBase(const OperatorDef& operator_def, Workspace* ws)
: ConvPoolOpBase<IDEEPContext>(operator_def, ws) {
OPERATOR_NEEDS_FEATURE(
order_ == StorageOrder::NCHW, "Unsupported storage order.");
}
: ConvPoolOpBase<IDEEPContext>(operator_def, ws) {}
virtual ~IDEEPConvPoolOpBase() {}
inline const ideep::tensor& Input(int index) {
@@ -35,7 +32,7 @@ class IDEEPConvPoolOpBase : public ConvPoolOpBase<IDEEPContext> {
ideep::tensor::dims CalcOutputDims(
const ideep::tensor& input,
int output_channel) {
CAFFE_ENFORCE(input.get_descriptor().get_size() > 0);
CAFFE_ENFORCE_GT(input.get_size(), 0);
ideep::tensor::dims output_dims;
const auto input_dims = input.get_dims();
std::vector<std::int64_t> input_Tdims(
@@ -43,7 +40,7 @@ class IDEEPConvPoolOpBase : public ConvPoolOpBase<IDEEPContext> {
InferOutputSize(
input_Tdims,
output_channel,
order_,
StorageOrder::NCHW, //order_,
global_pooling_,
legacy_pad_,
dilation_,


@@ -77,7 +77,7 @@ class IDEEPConvTransposeOp final : public IDEEPConvTransposeUnpoolBase {
// we have to do explicit conversion here.
filter_in.set_public_format(ideep::format::iohw);
filter_.init(expected_descriptor);
ideep::reorder::compute(filter_in, filter_);
filter_.feed_from(filter_in);
}
// TODO: The code below works around correctness issues with particular input shapes
@@ -178,7 +178,7 @@ class IDEEPConvTransposeGradientOp final : public IDEEPConvTransposeUnpoolBase {
// we have to do explicit conversion here.
filter_in.set_public_format(ideep::format::iohw);
filter_.init(expected_descriptor);
ideep::reorder::compute(filter_in, filter_);
filter_.feed_from(filter_in);
// TODO: The code below works around correctness issues with particular input shapes
// in MKL-DNN v0.17, will be removed with the fixes in MKL-DNN 0.18.


@@ -82,20 +82,29 @@ class C10_EXPORT IDEEPFallbackOp final : public IDEEPOperator {
bool RunOnDevice() override {
for (int i = 0; i < InputSize(); ++i) {
if (InputIsType<itensor>(i) &&
Input(i).get_data_type() == itensor::data_type::f32) {
if (InputIsType<itensor>(i)
&& (Input(i).has_scale()
|| Input(i).get_data_type() == idtype::f32)) {
auto& input = Input(i);
if (input_share_[i]) {
local_input_blobs_[i]->Reset();
input_share_[i] = false;
}
input_share_[i] = false;
auto dtensor = BlobGetMutableTensor(local_input_blobs_[i], CPU);
dtensor->Resize(input.get_dims());
if (input.is_public_format()) {
// When falling back from INT8, the public format of the original input
// is nhwc, while the required format is nchw, so reorder to nchw here.
if (input.get_public_format() == iformat::nhwc) {
itensor temp_ten ({input.get_dims(), idtype::f32, iformat::nchw},
dtensor->template mutable_data<float>());
temp_ten.feed_from(input);
} else if (!input.need_reorder()) {
CAFFE_ENFORCE(!input.has_scale(),
"Incorrect invocation of get_data_handle");
dtensor->ShareExternalPointer(
static_cast<float*>(input.get_data_handle()));
} else {
input.reorder_to(dtensor->template mutable_data<float>());
input.to_public(dtensor->template mutable_data<float>());
}
} else {
VLOG(1) << "Input " << i << " is not ideep::tensor. Skipping copy.";
@@ -143,12 +152,14 @@ class C10_EXPORT IDEEPFallbackOp final : public IDEEPOperator {
itensor::dims dst_dims (src_dims.begin(), src_dims.end());
auto dtensor = dst->template GetMutable<itensor>();
if (dtensor->get_dims() != dst_dims) {
dtensor->resize(dst_dims, itensor::data_type::f32);
dtensor->resize(dst_dims, idtype::f32);
}
if (output_inplace_[i]) {
dtensor->reorder_from(dst_dims, itensor::data_type::f32,
const_cast<void*>(src.raw_data()));
dtensor->feed_from(dst_dims, idtype::f32,
const_cast<void*>(src.raw_data()));
} else {
CAFFE_ENFORCE(!dtensor->has_scale(),
"Incorrect invocation of set_data_handle");
dtensor->set_data_handle(const_cast<void *>(src.raw_data()));
}
} else {


@@ -8,9 +8,7 @@ class IDEEPPoolOp final : public IDEEPConvPoolOpBase {
USE_IDEEP_CONV_POOL_BASE_FUNCTIONS();
IDEEPPoolOp(const OperatorDef& operator_def, Workspace* ws)
: IDEEPConvPoolOpBase(operator_def, ws),
training_mode_(
OperatorBase::GetSingleArgument<int>("training_mode", 1)) {
: IDEEPConvPoolOpBase(operator_def, ws) {
CAFFE_ENFORCE(
(dilation_h() == 1) && (dilation_w() == 1),
"Pooling op does not support dilation right now.");
@@ -20,6 +18,10 @@ class IDEEPPoolOp final : public IDEEPConvPoolOpBase {
pad_l() < kernel_w() && pad_r() < kernel_w(),
"Pad should be smaller than kernel.");
}
bool training_mode = OperatorBase::GetSingleArgument<int>("training_mode", 1);
pk_ = training_mode ? iprop::forward_training : iprop::forward_inference;
// Figure out the pooling descriptor.
if (operator_def.type().substr(0, 7) == "MaxPool") {
algo_ = ialgo::pooling_max;
@@ -35,18 +37,23 @@ class IDEEPPoolOp final : public IDEEPConvPoolOpBase {
auto& X = Input(INPUT);
auto* Y = Output(OUTPUT);
auto Y_dims = CalcOutputDims(X, X.get_dim(1));
mkldnn::prop_kind pk = training_mode_ ?
mkldnn::prop_kind::forward_training : mkldnn::prop_kind::forward_inference;
ideep::pooling_forward::compute(X, Y_dims, *Y,
stride_, kernel_, pad_tl(), pad_br(), algo_, pk);
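// As in the conv op, reset the cached op key when the input descriptor
// changes so the pooling primitive is rebuilt for the new input shape.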
if (cached_X_descriptor_ != X.get_descriptor()) {
op_key_.clear();
cached_X_descriptor_ = X.dup_descriptor();
}
ideep::pooling_forward::compute(op_key_, X, Y_dims, *Y,
stride_, kernel_, pad_tl(), pad_br(), algo_, pk_);
return true;
}
private:
iprop pk_;
ialgo algo_;
bool training_mode_;
ikey op_key_;
itensor::descriptor cached_X_descriptor_;
INPUT_TAGS(INPUT);
OUTPUT_TAGS(OUTPUT);


@@ -19,7 +19,7 @@ class CopyCPUToIDEEPOp final : public IDEEPOperator {
Y->Reset(new itensor());
Y->GetMutable<itensor>()->resize(src_dims, itensor::data_type::f32);
}
Y->GetMutable<itensor>()->reorder_from(
Y->GetMutable<itensor>()->feed_from(
src_dims, itensor::data_type::f32, X.raw_data());
return true;
}
@@ -61,7 +61,7 @@ class CopyIDEEPToCPUOp final : public IDEEPOperator {
}
auto* Y =
OperatorBase::OutputTensor(0, dims, at::dtype<float>().device(CPU));
X.reorder_to(Y->template mutable_data<float>());
X.to_public(Y->template mutable_data<float>());
} else {
CAFFE_THROW("Unsupported ideep type: ", X.get_data_type());
}


@@ -16,6 +16,8 @@ C10_DECLARE_REGISTRY(
C10_REGISTER_CREATOR(IDEEPOperatorRegistry, key, __VA_ARGS__)
#define REGISTER_IDEEP_OPERATOR(name, ...) \
C10_REGISTER_CLASS(IDEEPOperatorRegistry, name, __VA_ARGS__)
#define REGISTER_IDEEP_OPERATOR_WITH_ENGINE(name, engine, ...) \
C10_REGISTER_CLASS(IDEEPOperatorRegistry, name##_ENGINE_##engine, __VA_ARGS__)
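For example, a low-precision operator could be registered under the DNNLOWP engine like this (hypothetical operator and class names, a sketch only):

// Expands to C10_REGISTER_CLASS(IDEEPOperatorRegistry, Int8Conv_ENGINE_DNNLOWP, ...)
REGISTER_IDEEP_OPERATOR_WITH_ENGINE(Int8Conv, DNNLOWP, IDEEPInt8ConvOp);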
#define REGISTER_IDEEP_OPERATOR_STR(str_name, ...) \
C10_REGISTER_TYPED_CLASS(IDEEPOperatorRegistry, str_name, __VA_ARGS__)
#define REGISTER_IDEEP_COMPARE_OPERATOR(Op) \
@@ -27,8 +29,6 @@ C10_DECLARE_REGISTRY(
Op##Functor<CPUContext>, \
FixedType<bool>>>)
#define REGISTER_IDEEP_OPERATOR_WITH_ENGINE(name, engine, ...) \
C10_REGISTER_CLASS(IDEEPOperatorRegistry, name##_ENGINE_##engine, __VA_ARGS__)
// IDEEPOperator is the base scaffolding of the operators that uses IDEEP. It
// provides a few operators that are useful to IDEEP specific implementations.
@@ -39,8 +39,6 @@ class IDEEPOperator : public OperatorBase {
context_(operator_def.device_option()),
order_(StringToStorageOrder(
OperatorBase::GetSingleArgument<string>("order", "NCHW"))) {
OPERATOR_NEEDS_FEATURE(
order_ == StorageOrder::NCHW, "Unsupported storage order.");
}
virtual ~IDEEPOperator() {}
@@ -119,4 +117,19 @@ class IDEEPOperator : public OperatorBase {
: IDEEPOperator(operator_def, ws) {} \
virtual ~name() {}
// Convert zero_point scales to min_max scales
// NOTE:
// The scales in the operator are saved in FBGEMM format,
// and FBGEMM scales are the reciprocals of MKL-DNN scales.
// This function converts scales from FBGEMM format to MKL-DNN format.
inline ideep::scale_t ConvertScales(
const std::vector<float> scales_z) {
ideep::scale_t scales (scales_z);
for (auto it = scales.begin(); it != scales.end(); it++) {
*it = 1.0f / *it;
}
return scales;
}
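
A small usage sketch (hypothetical values) of the reciprocal conversion:

// FBGEMM-format scales in, MKL-DNN-format scales out.
std::vector<float> fbgemm_scales{0.5f, 0.25f};
ideep::scale_t mkldnn_scales = ConvertScales(fbgemm_scales);
// mkldnn_scales now holds {2.0f, 4.0f}, the reciprocals expected by MKL-DNN.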
} // namespace caffe2


@@ -3,6 +3,7 @@
#include "caffe2/opt/fusion.h"
#ifdef CAFFE2_USE_MKLDNN
#include <cpuinfo.h>
#include "caffe2/ideep/ideep_utils.h"
#endif
@@ -78,7 +79,7 @@ bool isOnIdeepDevice(const repr::NeuralNetOperator& nnOp) {
}
bool shouldFuseConv(const repr::Conv& conv) {
return isOnIdeepDevice(conv) ? (conv.getGroup() <= 1) : false;
return isOnIdeepDevice(conv);
}
void removeStopGradientForInference(repr::NNModule* nn) {
@@ -110,10 +111,6 @@ void removeStopGradientForInference(repr::NNModule* nn) {
}
void resetConvForFusion(repr::NNGraph::NodeRef convNode, int fusion_type) {
// Fusion types:
// FUSION_CONV_RELU = 1
// FUSION_CONV_SUM = 2
// FUSION_CONV_SUM_RELU = 3
auto conv = repr::nn::get<repr::Conv>(convNode);
auto annotation = conv->getMutableAnnotation();
if (!annotation || !isa<Caffe2Annotation>(annotation)) {
@@ -126,19 +123,18 @@ void resetConvForFusion(repr::NNGraph::NodeRef convNode, int fusion_type) {
}
if (op->type() == "ConvFusion") {
CAFFE_ENFORCE(fusion_type == 1, "Invalid nest fusion");
CAFFE_ENFORCE(fusion_type == FUSION_CONV_RELU, "Invalid nest fusion");
for (auto& arg : *op->mutable_arg()) {
if (arg.name() == "fusion_type") {
// Only from FUSION_CONV_SUM to FUSION_CONV_SUM_RELU
CAFFE_ENFORCE(arg.i() == 2, "Invalid nest fusion");
arg.set_i(3);
CAFFE_ENFORCE(arg.i() == FUSION_CONV_SUM, "Invalid nest fusion");
arg.set_i(FUSION_CONV_SUM_RELU);
return;
}
}
return;
}
CAFFE_ENFORCE(fusion_type < 3, "Invalid fusion type");
CAFFE_ENFORCE_LT(fusion_type, FUSION_CONV_SUM_RELU, "Invalid fusion type");
op->set_type("ConvFusion");
auto* arg = op->add_arg();
arg->set_name("fusion_type");
@@ -224,7 +220,7 @@ bool fuseConvBNAndAffChHelperForIdeep(repr::NNModule* nn, caffe2::Workspace* ws)
continue; \
} \
name##Tensor.resize(name->get_dims(), name->get_data_type()); \
name##Tensor.reorder_from(*name); \
name##Tensor.feed_from(*name); \
CAFFE_ENFORCE( \
name##Tensor.is_public_format(), #name " not with public format"); \
name##Data = static_cast<float*>(name##Tensor.get_data_handle()); \
@@ -263,8 +259,8 @@ bool fuseConvBNAndAffChHelperForIdeep(repr::NNModule* nn, caffe2::Workspace* ws)
}
}
filter->reorder_from(filterTensor);
biasConv->reorder_from(biasConvTensor);
filter->feed_from(filterTensor);
biasConv->feed_from(biasConvTensor);
nn->dataFlow.replaceNode(convOutput, bnOrAffChOutput);
nn->dataFlow.deleteNode(bnOrAffChNode);
@@ -282,6 +278,7 @@ void fuseConvBNAndAffChForIdeep(repr::NNModule* nn, caffe2::Workspace* ws) {
}
void fuseConvSumForIdeep(repr::NNModule* nn, caffe2::Workspace* ws) {
CAFFE_ENFORCE(cpuinfo_initialize(), "failed to initialize cpuinfo");
// Assume the order of nodes from getMutableNodes conforms to
// the original topo order of operators
auto allNodes = nn->dataFlow.getMutableNodes();
@@ -342,11 +339,16 @@
}
auto conv = repr::nn::get<repr::Conv>(convNode);
if (!shouldFuseConv(*conv)) {
if (!isOnIdeepDevice(*conv)) {
LOG(WARNING) << "Not a IDEEP operator";
continue;
}
if (conv->getGroup() > 1 && !cpuinfo_has_x86_avx512f()) {
LOG(WARNING) << "Not support conv sum fusion with grouped filter";
continue;
}
auto convOutput = repr::nn::getOutputs(convNode).front();
repr::NNGraph::NodeRef sumInputX =
(sumInputs[0] == convOutput ? sumInputs[1] : sumInputs[0]);
@@ -366,8 +368,7 @@
auto sumOutput = repr::nn::getOutputs(sumNode).front();
nn->dataFlow.replaceNode(sumOutput, newOutput);
// 2 means FUSION_CONV_SUM
resetConvForFusion(convNode, 2);
resetConvForFusion(convNode, FUSION_CONV_SUM);
nn->dataFlow.createEdge(sumInputX, convNode);
nn->dataFlow.createEdge(convNode, newOutput);
@@ -405,8 +406,8 @@ void enforceFusionInplaceForIdeep(repr::NNModule* nn) {
bool enforce_inplace = false;
for (const auto& arg : op.arg()) {
// Only check FUSION_CONV_SUM & FUSION_CONV_SUM_RELU
if (arg.name() == "fusion_type" && (arg.i() == 2 || arg.i() == 3)) {
if (arg.name() == "fusion_type"
&& (arg.i() == FUSION_CONV_SUM || arg.i() == FUSION_CONV_SUM_RELU)) {
enforce_inplace = true;
break;
}


@@ -256,6 +256,7 @@ class ConvFusionTest(hu.HypothesisTestCase):
workspace.SwitchWorkspace(old_ws_name)
@given(stride=st.integers(1, 3),
pad=st.integers(0, 3),
kernel=st.integers(3, 5),
@@ -406,6 +407,113 @@ class ConvFusionTest(hu.HypothesisTestCase):
workspace.SwitchWorkspace(old_ws_name)
@given(stride=st.integers(1, 3),
pad=st.integers(0, 3),
kernel=st.integers(3, 5),
size=st.integers(8, 20),
input_channels=st.integers(7, 17),
output_channels=st.integers(5, 15),
batch_size=st.integers(1, 3),
use_bias=st.booleans(),
group=st.integers(2, 5),
**mu.gcs)
def test_convolution_grouped_sum_relu_fusion(self, stride, pad, kernel, size,
input_channels, output_channels,
batch_size, use_bias, group, gc, dc):
conv_S0 = core.CreateOperator(
"Conv",
["SX0", "Sw0", "Sb0"] if use_bias else ["SX0", "Sw0"],
["S0"],
stride=stride,
pad=pad,
kernel=kernel,
group=group,
device_option=dc[0]
)
conv = core.CreateOperator(
"Conv",
["X0", "w0", "b0"] if use_bias else ["X0", "w0"],
["Y0"],
stride=stride,
pad=pad,
kernel=kernel,
group=group,
device_option=dc[0]
)
sum = core.CreateOperator(
"Sum",
["S0", "Y0"],
["S0"],
device_option=dc[0]
)
relu = core.CreateOperator(
"Relu",
["S0"],
["S0"],
device_option=dc[0]
)
SX = np.random.rand(
batch_size, input_channels * group, size, size).astype(np.float32) - 0.5
Sw = np.random.rand(
output_channels * group, input_channels, kernel, kernel) \
.astype(np.float32) - 0.5
Sb = np.random.rand(output_channels * group).astype(np.float32) - 0.5
X = np.random.rand(
batch_size, input_channels * group, size, size).astype(np.float32) - 0.5
w = np.random.rand(
output_channels * group, input_channels, kernel, kernel) \
.astype(np.float32) - 0.5
b = np.random.rand(output_channels * group).astype(np.float32) - 0.5
old_ws_name = workspace.CurrentWorkspace()
workspace.SwitchWorkspace("_device_check_", True)
workspace.FeedBlob('SX0', SX, dc[0])
workspace.FeedBlob('Sw0', Sw, dc[0])
workspace.FeedBlob('Sb0', Sb, dc[0])
workspace.FeedBlob('X0', X, dc[0])
workspace.FeedBlob('w0', w, dc[0])
workspace.FeedBlob('b0', b, dc[0])
workspace.RunOperatorOnce(conv_S0)
workspace.RunOperatorOnce(conv)
workspace.RunOperatorOnce(sum)
workspace.RunOperatorOnce(relu)
S0 = workspace.FetchBlob('S0')
workspace.ResetWorkspace()
old_net = caffe2_pb2.NetDef()
conv_S0_old = caffe2_pb2.OperatorDef()
conv_S0_old.CopyFrom(conv_S0)
conv_S0_old.device_option.CopyFrom(dc[1])
conv_old = caffe2_pb2.OperatorDef()
conv_old.CopyFrom(conv)
conv_old.device_option.CopyFrom(dc[1])
sum_old = caffe2_pb2.OperatorDef()
sum_old.CopyFrom(sum)
sum_old.device_option.CopyFrom(dc[1])
relu_old = caffe2_pb2.OperatorDef()
relu_old.CopyFrom(relu)
relu_old.device_option.CopyFrom(dc[1])
old_net.op.extend([conv_S0_old, conv_old, sum_old, relu_old])
workspace.FeedBlob('SX0', SX, dc[1])
workspace.FeedBlob('Sw0', Sw, dc[1])
workspace.FeedBlob('Sb0', Sb, dc[1])
workspace.FeedBlob('X0', X, dc[1])
workspace.FeedBlob('w0', w, dc[1])
workspace.FeedBlob('b0', b, dc[1])
net = core.Net("net")
net.Proto().CopyFrom(old_net)
optimizeForIDEEP(net)
workspace.RunNetOnce(net.Proto())
S2 = workspace.FetchBlob('S0')
if not np.allclose(S0, S2, atol=0.01, rtol=0.01):
print(S2.flatten())
print(S0.flatten())
print(np.max(np.abs(S2 - S0)))
self.assertTrue(False)
workspace.SwitchWorkspace(old_ws_name)
@given(stride=st.integers(1, 3),
pad=st.integers(0, 3),
kernel=st.integers(3, 5),


@@ -59,13 +59,13 @@ public:
(atensor.get_nelems() == 0 ||
atensor.get_data_handle() != nullptr),
"Trying to fetch uninitialized tensor");
const int numpy_type = CaffeToNumpyType(type_transform(atensor));
// NOTE: Only float is supported so far.
const int numpy_type = NPY_FLOAT;
CAFFE_ENFORCE(
numpy_type != -1,
"Unsupported ideep memory data type? This usually should not happen "
"since ideep memory usually only do float and double.");
itensor::dims dims = atensor.get_public_format_dims();
std::vector<npy_intp> npy_dims(dims.begin(), dims.end());
result.copied = force_copy || atensor.need_reorder();
@@ -86,7 +86,7 @@ public:
}
if (result.copied) {
atensor.reorder_to(outPtr);
atensor.to_public(outPtr);
}
return result;
@@ -144,7 +144,7 @@ public:
if (tensor->get_dims() != adims || type != tensor->get_data_type()) {
tensor->resize(adims, type);
}
tensor->reorder_from(adims, type,
tensor->feed_from(adims, type,
static_cast<void *>(PyArray_DATA(array)));
}
#else