Support pre-converting the filter format for MKL-DNN training mode and change 'OptimizeForIdeep' to 'OptimizeForMkldnn' (#15171)

Summary:
For MKL-DNN, the filter data is reordered into the primitive's expected format at run time, which takes a lot of time. This patch provides a method to convert the filter format before training. It also changes "OptimizeForIdeep" to "OptimizeForMkldnn".
This patch depends on https://github.com/pytorch/pytorch/pull/12866.
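A minimal usage sketch (an editor's illustration, not part of the patch): it assumes a Caffe2 build with MKL-DNN/IDEEP support; the tiny model and blob names are made up, while optimizeForMKLDNN and its training_mode argument are the API touched by this change.

import numpy as np
from caffe2.proto import caffe2_pb2
from caffe2.python import brew, core, model_helper, workspace
from caffe2.python.transformations import optimizeForMKLDNN

# Build a small training-mode conv net on the IDEEP device (illustrative only).
device_opt = core.DeviceOption(caffe2_pb2.IDEEP, 0)
with core.DeviceScope(device_opt):
    model = model_helper.ModelHelper(name="example", arg_scope={"order": "NCHW"})
    brew.conv(model, "X", "conv", dim_in=3, dim_out=16, kernel=3, training_mode=1)

workspace.FeedBlob("X", np.random.rand(1, 3, 28, 28).astype(np.float32), device_opt)
workspace.RunNetOnce(model.param_init_net)

# Reorder conv/deconv/FC filters into MKL-DNN's expected layout once up front,
# instead of paying for the reorder inside every training iteration.
optimizeForMKLDNN(model.net, training_mode=True)

workspace.CreateNet(model.net, overwrite=True)
workspace.RunNet(model.net)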
Pull Request resolved: https://github.com/pytorch/pytorch/pull/15171

Differential Revision: D14590741

Pulled By: yinghai

fbshipit-source-id: 07971c9977edac3c8eec08ca2c39cda639683492
Cheng,Penghui
2019-03-29 18:51:50 -07:00
committed by Facebook Github Bot
parent d73c830e23
commit e13101e069
12 changed files with 292 additions and 31 deletions

View File

@@ -9,6 +9,12 @@
namespace caffe2 {
enum ConvAlgorithm {
CONV_ALGORITHM_AUTO = 0,
CONV_ALGORITHM_WINOGRAD = 1,
CONV_ALGORITHM_MAX = CONV_ALGORITHM_WINOGRAD + 1
};
#define USE_IDEEP_DEF_ALIASES() \
using itensor = ideep::tensor; \
using iformat = ideep::format; \
@@ -18,7 +24,4 @@ namespace caffe2 {
using iattr = ideep::descriptor_group::attr_t; \
using ibn_flag = ideep::batch_normalization_flag;
const int CONV_ALGORITHM_AUTO = 0;
const int CONV_ALGORITHM_WINOGRAD = 1;
} // namespace caffe2

View File

@@ -140,8 +140,30 @@ class ConvConverter : public Converter {
~ConvConverter() override {}
};
class ConvTransposeConverter : public Converter {
std::unique_ptr<nom::repr::NeuralNetOperator> convertToNeuralNetOperator(
const OperatorDef& op) override {
std::unique_ptr<repr::NeuralNetOperator> nnOp;
auto argMap = getArgumentsFromOperator(op);
auto kernelShape = getKernelShape(argMap);
nnOp = util::make_unique<repr::ConvTranspose>(kernelShape);
auto c = dyn_cast<repr::ConvTranspose>(nnOp.get());
c->setStrides(getStrides(argMap));
c->setPads(getPads(argMap));
c->setGroup(getGroup(argMap));
return nnOp;
}
// Does not override default converter to OperatorDef
virtual ~ConvTransposeConverter() {}
};
REGISTER_CONVERTER(Conv, ConvConverter);
REGISTER_CONVERTER(ConvTranspose, ConvTransposeConverter);
TRIVIAL_CONVERTER(Relu);
REGISTER_CONVERTER(Relu, ReluConverter);

View File

@@ -12,7 +12,7 @@ namespace opt {
using namespace nom;
#ifndef CAFFE2_USE_MKLDNN
void OptimizeForIdeep(
void OptimizeForMkldnn(
repr::NNModule* nn,
caffe2::Workspace* ws,
bool training_mode) {
@@ -37,6 +37,15 @@ T* getTensor(Blob* blob) {
return nullptr;
}
template <class T>
T* getMutableTensor(Blob* blob) {
CAFFE_ENFORCE(blob, "Blob is invalid");
if (blob->template IsType<T>()) {
return blob->template GetMutable<T>();
}
return nullptr;
}
const caffe2::OperatorDef& getOpDef(const repr::NeuralNetOperator& nnOp) {
auto annotation = nnOp.getAnnotation();
if (annotation == nullptr) {
@@ -72,7 +81,7 @@ bool shouldFuseConv(const repr::Conv& conv) {
return isOnIdeepDevice(conv) ? (conv.getGroup() <= 1) : false;
}
void removeStopGradientForInference(repr::NNModule *nn) {
void removeStopGradientForInference(repr::NNModule* nn) {
auto isStopGradientNode = [](const repr::NNGraph::NodeRef& node) {
if (!repr::nn::is<repr::NeuralNetOperator>(node)) {
return false;
@@ -430,10 +439,10 @@ void enforceFusionInplaceForIdeep(repr::NNModule* nn) {
}
}
void setPoolingInferenceMode(repr::NNModule *nn) {
void setPoolingInferenceMode(repr::NNModule* nn) {
for (auto node_pair : repr::nn::dataIterator<repr::MaxPool>(nn->dataFlow)) {
repr::NNGraph::NodeRef maxPoolNode;
repr::MaxPool *maxPool;
repr::MaxPool* maxPool;
std::tie(maxPool, maxPoolNode) = node_pair;
if (!isOnIdeepDevice(*maxPool)) {
@@ -441,9 +450,9 @@ void setPoolingInferenceMode(repr::NNModule *nn) {
continue;
}
auto *op = getMutableOpDef(*maxPool);
auto* op = getMutableOpDef(*maxPool);
bool found_training_mode = false;
for (auto &arg : *op->mutable_arg()) {
for (auto& arg : *op->mutable_arg()) {
if (arg.name() == "training_mode") {
arg.set_i(0);
found_training_mode = true;
@@ -452,19 +461,149 @@ void setPoolingInferenceMode(repr::NNModule *nn) {
}
if (!found_training_mode) {
auto *arg = op->add_arg();
auto* arg = op->add_arg();
arg->set_name("training_mode");
arg->set_i(0);
}
}
}
void OptimizeForIdeep(
repr::NNModule* nn,
caffe2::Workspace* ws,
bool training_mode) {
// Pre-convert filters format to expected one here
// in order to avoid boring conversions during computations
void preConvertFiltersFormat(repr::NNModule* nn, caffe2::Workspace* ws) {
for (auto& node : nn->dataFlow.getMutableNodes()) {
if (!repr::nn::is<repr::ConvTranspose>(node) &&
!repr::nn::is<repr::Conv>(node) && !repr::nn::is<repr::FC>(node)) {
continue;
}
auto* nnOp = repr::nn::get<repr::NeuralNetOperator>(node);
if (!isOnIdeepDevice(*nnOp)) {
LOG(INFO) << "Not a IDEEP operator";
continue;
}
auto inputs = repr::nn::getInputs(node);
if (inputs.size() < 2) {
LOG(WARNING) << "Invalid input size";
continue;
}
auto* filterBlob = getBlob(inputs[1], ws);
auto* filter = getMutableTensor<itensor>(filterBlob);
if (filter == nullptr) {
continue;
}
itensor::descriptor expectedDesc;
if (repr::nn::is<repr::ConvTranspose>(node)) {
if (filter->get_public_format() == ideep::format::iohw)
continue;
auto convTranspose = repr::nn::get<repr::ConvTranspose>(node);
auto initValue = [](vector<int>& v, vector<int> i) {
if (v.empty())
v = i;
};
auto strides = convTranspose->getStrides();
initValue(strides, {1, 1});
auto pads = convTranspose->getPads();
initValue(pads, {0, 0, 0, 0});
auto* op = getMutableOpDef(*convTranspose);
auto aalgorithm = ialgo::deconvolution_direct;
auto dataType = filter->get_data_type();
ideep::tensor::dims filter_dims_mkldnn{filter->get_dim(1),
filter->get_dim(0),
filter->get_dim(2),
filter->get_dim(3)};
expectedDesc =
ideep::convolution_transpose_forward::expected_weights_descriptor(
filter_dims_mkldnn,
dataType,
strides,
{pads[0], pads[1]},
{pads[2], pads[3]});
if (filter->get_descriptor() != expectedDesc) {
filter->set_public_format(ideep::format::iohw);
itensor&& newFilter(expectedDesc);
ideep::reorder::compute(*filter, newFilter);
newFilter.set_public_format(ideep::format::iohw);
filterBlob->Reset<itensor>(new itensor(newFilter));
}
} else if (repr::nn::is<repr::Conv>(node)) {
auto conv = repr::nn::get<repr::Conv>(node);
auto initValue = [](vector<int>& v, vector<int> i) {
if (v.empty())
v = i;
};
auto strides = conv->getStrides();
initValue(strides, {1, 1});
auto pads = conv->getPads();
initValue(pads, {0, 0, 0, 0});
auto dilations = conv->getDilations();
initValue(dilations, {1, 1});
auto* op = getMutableOpDef(*conv);
auto aalgorithm = ialgo::convolution_direct;
for (auto& arg : *op->mutable_arg()) {
if ((arg.name() == "conv_algorithm") &&
(arg.i() == CONV_ALGORITHM_WINOGRAD)) {
aalgorithm = ialgo::convolution_winograd;
}
}
auto dataType = filter->get_data_type();
filter->make_group(conv->getGroup());
expectedDesc = ideep::convolution_forward::expected_weights_descriptor(
filter->get_dims(),
dataType,
strides,
{pads[0], pads[1]},
{pads[2], pads[3]},
dilations,
conv->getGroup(),
aalgorithm);
if (filter->get_descriptor() != expectedDesc) {
itensor&& newFilter(expectedDesc);
ideep::reorder::compute(*filter, newFilter);
filterBlob->Reset<itensor>(new itensor(newFilter));
}
// convert weights for FC
} else if (repr::nn::is<repr::FC>(node)) {
auto fc = repr::nn::get<repr::FC>(node);
auto axis_w = fc->getAxisW();
if (axis_w != 1) {
auto f_dims = filter->get_dims();
auto f_dim0 = std::accumulate(
f_dims.begin(),
f_dims.begin() + axis_w,
1,
std::multiplies<itensor::dim_t>());
auto f_dim1 = std::accumulate(
f_dims.begin() + axis_w,
f_dims.end(),
1,
std::multiplies<itensor::dim_t>());
filter->reshape({f_dim0, f_dim1});
}
expectedDesc = ideep::inner_product_forward::expected_weights_descriptor(
filter->get_dims());
if (filter->get_descriptor() != expectedDesc) {
itensor&& newFilter(expectedDesc);
ideep::reorder::compute(filter->as_weights(), newFilter);
filterBlob->Reset<itensor>(new itensor(newFilter));
}
}
}
}
void OptimizeForMkldnn(repr::NNModule *nn, caffe2::Workspace *ws,
bool training_mode) {
if (training_mode) {
// Only support inference so far
preConvertFiltersFormat(nn, ws);
return;
}

View File

@@ -8,7 +8,7 @@
namespace caffe2 {
namespace opt {
CAFFE2_API void OptimizeForIdeep(
CAFFE2_API void OptimizeForMkldnn(
nom::repr::NNModule* nn,
caffe2::Workspace* ws,
bool training_mode = false);

View File

@@ -9,7 +9,7 @@ from hypothesis import given, settings
import numpy as np
from caffe2.proto import caffe2_pb2
from caffe2.python import core, workspace
from caffe2.python.transformations import optimizeForIDEEP
from caffe2.python.transformations import optimizeForMKLDNN
import caffe2.python.hypothesis_test_util as hu
import caffe2.python.ideep_test_util as mu
@@ -133,7 +133,7 @@ class ConvTest(hu.HypothesisTestCase):
old_net = caffe2_pb2.NetDef()
old_net.op.extend([op1])
net.Proto().CopyFrom(old_net)
optimizeForIDEEP(net)
optimizeForMKLDNN(net)
workspace.RunOperatorOnce(net.Proto().op[0])
Y1 = workspace.FetchBlob('Y')

View File

@@ -10,7 +10,7 @@ import copy
import numpy as np
from caffe2.proto import caffe2_pb2
from caffe2.python import core, workspace
from caffe2.python.transformations import optimizeForIDEEP
from caffe2.python.transformations import optimizeForMKLDNN
import caffe2.python.hypothesis_test_util as hu
import caffe2.python.ideep_test_util as mu
@@ -103,7 +103,7 @@ class ConvFusionTest(hu.HypothesisTestCase):
workspace.FeedBlob('b0', b, dc[1])
net = core.Net("net")
net.Proto().CopyFrom(old_net)
optimizeForIDEEP(net)
optimizeForMKLDNN(net)
self.assertTrue(len(net.Proto().op) == 1)
self.assertTrue(net.Proto().op[0].type == "ConvFusion")
workspace.RunOperatorOnce(net.Proto().op[0])
@@ -243,7 +243,7 @@ class ConvFusionTest(hu.HypothesisTestCase):
workspace.FeedBlob('b0', b, dc[1])
net = core.Net("net")
net.Proto().CopyFrom(old_net)
optimizeForIDEEP(net)
optimizeForMKLDNN(net)
self.assertTrue(len(net.Proto().op) == 2)
self.assertTrue(net.Proto().op[1].type == "ConvFusion")
workspace.RunNetOnce(net.Proto())
@@ -393,7 +393,7 @@ class ConvFusionTest(hu.HypothesisTestCase):
workspace.FeedBlob('b0', b, dc[1])
net = core.Net("net")
net.Proto().CopyFrom(old_net)
optimizeForIDEEP(net)
optimizeForMKLDNN(net)
self.assertTrue(len(net.Proto().op) == 2)
self.assertTrue(net.Proto().op[1].type == "ConvFusion")
workspace.RunNetOnce(net.Proto())
@@ -481,7 +481,7 @@ class ConvFusionTest(hu.HypothesisTestCase):
workspace.FeedBlob('var', var, dc[1])
net = core.Net("net")
net.Proto().CopyFrom(old_net)
optimizeForIDEEP(net)
optimizeForMKLDNN(net)
self.assertTrue(len(net.Proto().op) == 1)
self.assertTrue(net.Proto().op[0].type == "Conv")
workspace.RunOperatorOnce(net.Proto().op[0])
@@ -562,7 +562,7 @@ class ConvFusionTest(hu.HypothesisTestCase):
workspace.FeedBlob('bias', bias, dc[1])
net = core.Net("net")
net.Proto().CopyFrom(old_net)
optimizeForIDEEP(net)
optimizeForMKLDNN(net)
self.assertTrue(len(net.Proto().op) == 1)
self.assertTrue(net.Proto().op[0].type == "Conv")
workspace.RunOperatorOnce(net.Proto().op[0])

View File

@@ -0,0 +1,97 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import unittest
import hypothesis.strategies as st
from hypothesis import given
import numpy as np
from caffe2.proto import caffe2_pb2
from caffe2.python import (
brew,
core,
model_helper,
workspace,
)
from caffe2.python.transformations import optimizeForMKLDNN
import caffe2.python.hypothesis_test_util as hu
@unittest.skipIf(not workspace.C.use_mkldnn, "No MKLDNN support.")
class PreConvertTest(hu.HypothesisTestCase):
    @given(input_channels=st.integers(15, 16),
           batch_size=st.integers(1, 3))
    def test_preConvert(self, input_channels, batch_size):
        def AddModel(model, data):
            conv1 = brew.conv(model, data, 'conv1', dim_in=input_channels,
                              dim_out=10, kernel=3, stride=1, pad=1, training_mode=1)
            deconv1 = brew.conv_transpose(model, conv1, 'deconv1', dim_in=10, dim_out=10,
                                          kernel=2, stride=2, pad=0, training_mode=1)
            fc1 = brew.fc(model, deconv1, 'fc1', dim_in=10 * 56 * 56, dim_out=3)
            softmax = brew.softmax(model, fc1, 'softmax')
            return softmax

        def AddTrainingOperators(model, softmax, label):
            """Adds training operators to the model."""
            # Compute cross entropy between softmax scores and labels
            xent = model.LabelCrossEntropy([softmax, label], 'xent')
            # Compute the expected loss
            loss = model.AveragedLoss(xent, "loss")
            # Use the average loss we just computed to add gradient operators to the model
            model.AddGradientOperators([loss])

        arg_scope = {"order": "NCHW", 'no_bias': False}
        # Create the model helper for the train model
        device_opt = core.DeviceOption(caffe2_pb2.IDEEP, 0)
        with core.DeviceScope(device_opt):
            train_model = model_helper.ModelHelper(name="test_train", arg_scope=arg_scope)
            # Add the model definition (fc layers, conv layers, softmax, etc.)
            softmax = AddModel(train_model, "X")
            AddTrainingOperators(train_model, softmax, "label")

            X = np.random.rand(
                batch_size, input_channels, 28, 28).astype(np.float32) - 0.5
            label = np.random.randint(3, size=batch_size).astype(np.int32)
            blob_dict = {}
            output_dict = {}
            output_dict_cosim = {}
            old_ws_name = workspace.CurrentWorkspace()
            workspace.FeedBlob('X', X)
            workspace.FeedBlob('label', label)
            workspace.RunNetOnce(train_model.param_init_net)
            for op in train_model.net.Proto().op:
                if op.type == "Softmax":
                    break
                for j in range(1, len(op.input)):
                    blob_dict[op.input[j]] = workspace.FetchBlob(op.input[j])

            workspace.CreateNet(train_model.net, overwrite=True)
            optimizeForMKLDNN(train_model.net, training_mode=True)
            workspace.RunNet(train_model.net)
            for op in train_model.net.Proto().op:
                for blob in op.output:
                    output_dict[blob] = workspace.FetchBlob(blob)

            workspace.SwitchWorkspace("_device_check_", True)
            workspace.FeedBlob('X', X)
            workspace.FeedBlob('label', label)
            for blob in blob_dict.keys():
                workspace.FeedBlob(blob, blob_dict[blob])

            workspace.CreateNet(train_model.net, overwrite=True)
            workspace.RunNet(train_model.net)
            for blob in output_dict.keys():
                output_dict_cosim[blob] = workspace.FetchBlob(blob)

            for blob in output_dict.keys():
                if not np.allclose(output_dict[blob], output_dict_cosim[blob], atol=0.001, rtol=0.0001):
                    print("blob {} error".format(blob))
                    print(np.max(np.abs(output_dict[blob] - output_dict_cosim[blob])))
                    self.assertTrue(False)

            workspace.ResetWorkspace()
            workspace.SwitchWorkspace(old_ws_name)
if __name__ == "__main__":
unittest.main()

View File

@@ -9,7 +9,6 @@ from hypothesis import given, settings
import numpy as np
from caffe2.proto import caffe2_pb2
from caffe2.python import core, workspace
from caffe2.python.transformations import optimizeForIDEEP
import caffe2.python.hypothesis_test_util as hu
import caffe2.python.ideep_test_util as mu

View File

@@ -277,7 +277,7 @@ def fuse_conv_relu(net):
op.device_option.CopyFrom(device_option)
new_net = caffe2_pb2.NetDef()
new_net.ParseFromString(C.transform_optimizeForIDEEP(net.SerializeToString()))
new_net.ParseFromString(C.transform_optimizeForMKLDNN(net.SerializeToString()))
return new_net

View File

@@ -1691,12 +1691,12 @@ void addGlobalMethods(py::module& m) {
// into a python interface in transformations.py
// Prefix the transformation with transform_ to avoid clobbering the
// function namespace.
m.def("transform_optimizeForIDEEP", [](py::bytes def, bool training_mode) {
m.def("transform_optimizeForMKLDNN", [](py::bytes def, bool training_mode) {
caffe2::NetDef proto;
CAFFE_ENFORCE(ParseProtoFromLargeString(def.cast<std::string>(), &proto));
auto nn = caffe2::convertToNNModule(proto);
opt::OptimizeForIdeep(&nn, gWorkspace, training_mode);
opt::OptimizeForMkldnn(&nn, gWorkspace, training_mode);
auto new_proto = caffe2::convertToCaffe2Proto(nn, proto);
std::string out;

View File

@@ -64,7 +64,8 @@ public:
numpy_type != -1,
"Unsupported ideep memory data type? This usually should not happen "
"since ideep memory usually only do float and double.");
itensor::dims dims = atensor.get_dims();
itensor::dims dims = atensor.get_public_format_dims();
std::vector<npy_intp> npy_dims(dims.begin(), dims.end());
result.copied = force_copy || atensor.need_reorder();

View File

@@ -46,9 +46,9 @@ def fuseNNPACKConvRelu(net):
)
def optimizeForIDEEP(net, training_mode = False):
def optimizeForMKLDNN(net, training_mode = False):
net.Proto().ParseFromString(
C.transform_optimizeForIDEEP(net.Proto().SerializeToString(), training_mode)
C.transform_optimizeForMKLDNN(net.Proto().SerializeToString(), training_mode)
)