support pre-convert filter format for mkldnn training mode and change 'OptimizeForIdeep' to 'OptimizeForMkldnn' (#15171)
Summary: For MKL-DNN, the filter data is reordered into the primitive format at run time, which takes a lot of time. This patch therefore provides a way to pre-convert the filter format before training. It also renames "OptimizeForIdeep" to "OptimizeForMkldnn". This patch depends on https://github.com/pytorch/pytorch/pull/12866

Pull Request resolved: https://github.com/pytorch/pytorch/pull/15171

Differential Revision: D14590741

Pulled By: yinghai

fbshipit-source-id: 07971c9977edac3c8eec08ca2c39cda639683492
Commit: e13101e069
Parent: d73c830e23
Committed by: Facebook Github Bot
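
For orientation, a minimal usage sketch of the renamed Python entry point, assuming a Caffe2 build with MKL-DNN/IDEEP enabled; the example net is illustrative and not part of this patch:

    from caffe2.python import core
    from caffe2.python.transformations import optimizeForMKLDNN

    net = core.Net("example")  # hypothetical net, assumed to run on the IDEEP device

    # Inference (default): apply the full fusion/inference-mode rewrite,
    # formerly spelled optimizeForIDEEP(net).
    optimizeForMKLDNN(net)

    # Training: only pre-convert Conv/ConvTranspose/FC filters into the
    # MKL-DNN primitive format, so they are not reordered on every iteration.
    optimizeForMKLDNN(net, training_mode=True)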
@@ -9,6 +9,12 @@
 namespace caffe2 {

+enum ConvAlgorithm {
+  CONV_ALGORITHM_AUTO = 0,
+  CONV_ALGORITHM_WINOGRAD = 1,
+  CONV_ALGORITHM_MAX = CONV_ALGORITHM_WINOGRAD + 1
+};
+
 #define USE_IDEEP_DEF_ALIASES()                    \
   using itensor = ideep::tensor;                   \
   using iformat = ideep::format;                   \
@@ -18,7 +24,4 @@ namespace caffe2 {
   using iattr = ideep::descriptor_group::attr_t;   \
   using ibn_flag = ideep::batch_normalization_flag;

-const int CONV_ALGORITHM_AUTO = 0;
-const int CONV_ALGORITHM_WINOGRAD = 1;
-
 } // namespace caffe2

@@ -140,8 +140,30 @@ class ConvConverter : public Converter {
   ~ConvConverter() override {}
 };

+class ConvTransposeConverter : public Converter {
+  std::unique_ptr<nom::repr::NeuralNetOperator> convertToNeuralNetOperator(
+      const OperatorDef& op) override {
+    std::unique_ptr<repr::NeuralNetOperator> nnOp;
+    auto argMap = getArgumentsFromOperator(op);
+    auto kernelShape = getKernelShape(argMap);
+    nnOp = util::make_unique<repr::ConvTranspose>(kernelShape);
+    auto c = dyn_cast<repr::ConvTranspose>(nnOp.get());
+
+    c->setStrides(getStrides(argMap));
+    c->setPads(getPads(argMap));
+    c->setGroup(getGroup(argMap));
+
+    return nnOp;
+  }
+  // Does not override default converter to OperatorDef
+
+  virtual ~ConvTransposeConverter() {}
+};
+
 REGISTER_CONVERTER(Conv, ConvConverter);

+REGISTER_CONVERTER(ConvTranspose, ConvTransposeConverter);
+
 TRIVIAL_CONVERTER(Relu);
 REGISTER_CONVERTER(Relu, ReluConverter);

@@ -12,7 +12,7 @@ namespace opt {
 using namespace nom;

 #ifndef CAFFE2_USE_MKLDNN
-void OptimizeForIdeep(
+void OptimizeForMkldnn(
     repr::NNModule* nn,
     caffe2::Workspace* ws,
     bool training_mode) {

@@ -37,6 +37,15 @@ T* getTensor(Blob* blob) {
   return nullptr;
 }

+template <class T>
+T* getMutableTensor(Blob* blob) {
+  CAFFE_ENFORCE(blob, "Blob is invalid");
+  if (blob->template IsType<T>()) {
+    return blob->template GetMutable<T>();
+  }
+  return nullptr;
+}
+
 const caffe2::OperatorDef& getOpDef(const repr::NeuralNetOperator& nnOp) {
   auto annotation = nnOp.getAnnotation();
   if (annotation == nullptr) {

@@ -72,7 +81,7 @@ bool shouldFuseConv(const repr::Conv& conv) {
   return isOnIdeepDevice(conv) ? (conv.getGroup() <= 1) : false;
 }

-void removeStopGradientForInference(repr::NNModule *nn) {
+void removeStopGradientForInference(repr::NNModule* nn) {
   auto isStopGradientNode = [](const repr::NNGraph::NodeRef& node) {
     if (!repr::nn::is<repr::NeuralNetOperator>(node)) {
       return false;

@@ -430,10 +439,10 @@ void enforceFusionInplaceForIdeep(repr::NNModule* nn) {
   }
 }

-void setPoolingInferenceMode(repr::NNModule *nn) {
+void setPoolingInferenceMode(repr::NNModule* nn) {
   for (auto node_pair : repr::nn::dataIterator<repr::MaxPool>(nn->dataFlow)) {
     repr::NNGraph::NodeRef maxPoolNode;
-    repr::MaxPool *maxPool;
+    repr::MaxPool* maxPool;
     std::tie(maxPool, maxPoolNode) = node_pair;

     if (!isOnIdeepDevice(*maxPool)) {

@@ -441,9 +450,9 @@ void setPoolingInferenceMode(repr::NNModule *nn) {
       continue;
     }

-    auto *op = getMutableOpDef(*maxPool);
+    auto* op = getMutableOpDef(*maxPool);
     bool found_training_mode = false;
-    for (auto &arg : *op->mutable_arg()) {
+    for (auto& arg : *op->mutable_arg()) {
       if (arg.name() == "training_mode") {
         arg.set_i(0);
         found_training_mode = true;

@@ -452,19 +461,149 @@ void setPoolingInferenceMode(repr::NNModule *nn) {
     }

     if (!found_training_mode) {
-      auto *arg = op->add_arg();
+      auto* arg = op->add_arg();
       arg->set_name("training_mode");
       arg->set_i(0);
     }
   }
 }

-void OptimizeForIdeep(
-    repr::NNModule* nn,
-    caffe2::Workspace* ws,
-    bool training_mode) {
+// Pre-convert filters format to expected one here
+// in order to avoid boring conversions during computations
+void preConvertFiltersFormat(repr::NNModule* nn, caffe2::Workspace* ws) {
+  for (auto& node : nn->dataFlow.getMutableNodes()) {
+    if (!repr::nn::is<repr::ConvTranspose>(node) &&
+        !repr::nn::is<repr::Conv>(node) && !repr::nn::is<repr::FC>(node)) {
+      continue;
+    }
+
+    auto* nnOp = repr::nn::get<repr::NeuralNetOperator>(node);
+    if (!isOnIdeepDevice(*nnOp)) {
+      LOG(INFO) << "Not a IDEEP operator";
+      continue;
+    }
+
+    auto inputs = repr::nn::getInputs(node);
+    if (inputs.size() < 2) {
+      LOG(WARNING) << "Invalid input size";
+      continue;
+    }
+
+    auto* filterBlob = getBlob(inputs[1], ws);
+    auto* filter = getMutableTensor<itensor>(filterBlob);
+    if (filter == nullptr) {
+      continue;
+    }
+
+    itensor::descriptor expectedDesc;
+    if (repr::nn::is<repr::ConvTranspose>(node)) {
+      if (filter->get_public_format() == ideep::format::iohw)
+        continue;
+      auto convTranspose = repr::nn::get<repr::ConvTranspose>(node);
+      auto initValue = [](vector<int>& v, vector<int> i) {
+        if (v.empty())
+          v = i;
+      };
+      auto strides = convTranspose->getStrides();
+      initValue(strides, {1, 1});
+      auto pads = convTranspose->getPads();
+      initValue(pads, {0, 0, 0, 0});
+      auto* op = getMutableOpDef(*convTranspose);
+      auto aalgorithm = ialgo::deconvolution_direct;
+      auto dataType = filter->get_data_type();
+      ideep::tensor::dims filter_dims_mkldnn{filter->get_dim(1),
+                                             filter->get_dim(0),
+                                             filter->get_dim(2),
+                                             filter->get_dim(3)};
+      expectedDesc =
+          ideep::convolution_transpose_forward::expected_weights_descriptor(
+              filter_dims_mkldnn,
+              dataType,
+              strides,
+              {pads[0], pads[1]},
+              {pads[2], pads[3]});
+
+      if (filter->get_descriptor() != expectedDesc) {
+        filter->set_public_format(ideep::format::iohw);
+        itensor&& newFilter(expectedDesc);
+        ideep::reorder::compute(*filter, newFilter);
+        newFilter.set_public_format(ideep::format::iohw);
+        filterBlob->Reset<itensor>(new itensor(newFilter));
+      }
+    } else if (repr::nn::is<repr::Conv>(node)) {
+      auto conv = repr::nn::get<repr::Conv>(node);
+      auto initValue = [](vector<int>& v, vector<int> i) {
+        if (v.empty())
+          v = i;
+      };
+      auto strides = conv->getStrides();
+      initValue(strides, {1, 1});
+      auto pads = conv->getPads();
+      initValue(pads, {0, 0, 0, 0});
+      auto dilations = conv->getDilations();
+      initValue(dilations, {1, 1});
+
+      auto* op = getMutableOpDef(*conv);
+      auto aalgorithm = ialgo::convolution_direct;
+      for (auto& arg : *op->mutable_arg()) {
+        if ((arg.name() == "conv_algorithm") &&
+            (arg.i() == CONV_ALGORITHM_WINOGRAD)) {
+          aalgorithm = ialgo::convolution_winograd;
+        }
+      }
+      auto dataType = filter->get_data_type();
+
+      filter->make_group(conv->getGroup());
+      expectedDesc = ideep::convolution_forward::expected_weights_descriptor(
+          filter->get_dims(),
+          dataType,
+          strides,
+          {pads[0], pads[1]},
+          {pads[2], pads[3]},
+          dilations,
+          conv->getGroup(),
+          aalgorithm);
+
+      if (filter->get_descriptor() != expectedDesc) {
+        itensor&& newFilter(expectedDesc);
+        ideep::reorder::compute(*filter, newFilter);
+        filterBlob->Reset<itensor>(new itensor(newFilter));
+      }
+      // convert weights for FC
+    } else if (repr::nn::is<repr::FC>(node)) {
+      auto fc = repr::nn::get<repr::FC>(node);
+      auto axis_w = fc->getAxisW();
+      if (axis_w != 1) {
+        auto f_dims = filter->get_dims();
+        auto f_dim0 = std::accumulate(
+            f_dims.begin(),
+            f_dims.begin() + axis_w,
+            1,
+            std::multiplies<itensor::dim_t>());
+        auto f_dim1 = std::accumulate(
+            f_dims.begin() + axis_w,
+            f_dims.end(),
+            1,
+            std::multiplies<itensor::dim_t>());
+        filter->reshape({f_dim0, f_dim1});
+      }
+
+      expectedDesc = ideep::inner_product_forward::expected_weights_descriptor(
+          filter->get_dims());
+
+      if (filter->get_descriptor() != expectedDesc) {
+        itensor&& newFilter(expectedDesc);
+        ideep::reorder::compute(filter->as_weights(), newFilter);
+        filterBlob->Reset<itensor>(new itensor(newFilter));
+      }
+    }
+  }
+}
+
+void OptimizeForMkldnn(repr::NNModule *nn, caffe2::Workspace *ws,
+    bool training_mode) {
+  if (training_mode) {
+    // Only support inference so far
+    preConvertFiltersFormat(nn, ws);
+    return;
+  }

@@ -8,7 +8,7 @@
 namespace caffe2 {
 namespace opt {

-CAFFE2_API void OptimizeForIdeep(
+CAFFE2_API void OptimizeForMkldnn(
     nom::repr::NNModule* nn,
     caffe2::Workspace* ws,
     bool training_mode = false);

@@ -9,7 +9,7 @@ from hypothesis import given, settings
 import numpy as np
 from caffe2.proto import caffe2_pb2
 from caffe2.python import core, workspace
-from caffe2.python.transformations import optimizeForIDEEP
+from caffe2.python.transformations import optimizeForMKLDNN
 import caffe2.python.hypothesis_test_util as hu
 import caffe2.python.ideep_test_util as mu

@@ -133,7 +133,7 @@ class ConvTest(hu.HypothesisTestCase):
         old_net = caffe2_pb2.NetDef()
         old_net.op.extend([op1])
         net.Proto().CopyFrom(old_net)
-        optimizeForIDEEP(net)
+        optimizeForMKLDNN(net)
         workspace.RunOperatorOnce(net.Proto().op[0])
         Y1 = workspace.FetchBlob('Y')

@@ -10,7 +10,7 @@ import copy
 import numpy as np
 from caffe2.proto import caffe2_pb2
 from caffe2.python import core, workspace
-from caffe2.python.transformations import optimizeForIDEEP
+from caffe2.python.transformations import optimizeForMKLDNN
 import caffe2.python.hypothesis_test_util as hu
 import caffe2.python.ideep_test_util as mu

@@ -103,7 +103,7 @@ class ConvFusionTest(hu.HypothesisTestCase):
         workspace.FeedBlob('b0', b, dc[1])
         net = core.Net("net")
         net.Proto().CopyFrom(old_net)
-        optimizeForIDEEP(net)
+        optimizeForMKLDNN(net)
         self.assertTrue(len(net.Proto().op) == 1)
         self.assertTrue(net.Proto().op[0].type == "ConvFusion")
         workspace.RunOperatorOnce(net.Proto().op[0])

@@ -243,7 +243,7 @@ class ConvFusionTest(hu.HypothesisTestCase):
         workspace.FeedBlob('b0', b, dc[1])
         net = core.Net("net")
         net.Proto().CopyFrom(old_net)
-        optimizeForIDEEP(net)
+        optimizeForMKLDNN(net)
         self.assertTrue(len(net.Proto().op) == 2)
         self.assertTrue(net.Proto().op[1].type == "ConvFusion")
         workspace.RunNetOnce(net.Proto())

@@ -393,7 +393,7 @@ class ConvFusionTest(hu.HypothesisTestCase):
         workspace.FeedBlob('b0', b, dc[1])
         net = core.Net("net")
         net.Proto().CopyFrom(old_net)
-        optimizeForIDEEP(net)
+        optimizeForMKLDNN(net)
         self.assertTrue(len(net.Proto().op) == 2)
         self.assertTrue(net.Proto().op[1].type == "ConvFusion")
         workspace.RunNetOnce(net.Proto())

@@ -481,7 +481,7 @@ class ConvFusionTest(hu.HypothesisTestCase):
         workspace.FeedBlob('var', var, dc[1])
         net = core.Net("net")
         net.Proto().CopyFrom(old_net)
-        optimizeForIDEEP(net)
+        optimizeForMKLDNN(net)
         self.assertTrue(len(net.Proto().op) == 1)
         self.assertTrue(net.Proto().op[0].type == "Conv")
         workspace.RunOperatorOnce(net.Proto().op[0])

@@ -562,7 +562,7 @@ class ConvFusionTest(hu.HypothesisTestCase):
         workspace.FeedBlob('bias', bias, dc[1])
         net = core.Net("net")
         net.Proto().CopyFrom(old_net)
-        optimizeForIDEEP(net)
+        optimizeForMKLDNN(net)
         self.assertTrue(len(net.Proto().op) == 1)
         self.assertTrue(net.Proto().op[0].type == "Conv")
         workspace.RunOperatorOnce(net.Proto().op[0])

caffe2/python/ideep/pre_convert_test.py (new file, 97 lines)
@@ -0,0 +1,97 @@
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+from __future__ import unicode_literals
+
+import unittest
+import hypothesis.strategies as st
+from hypothesis import given
+import numpy as np
+from caffe2.proto import caffe2_pb2
+from caffe2.python import (
+    brew,
+    core,
+    model_helper,
+    workspace,
+)
+from caffe2.python.transformations import optimizeForMKLDNN
+import caffe2.python.hypothesis_test_util as hu
+
+
+@unittest.skipIf(not workspace.C.use_mkldnn, "No MKLDNN support.")
+class PreConvertTest(hu.HypothesisTestCase):
+    @given(input_channels=st.integers(15, 16),
+           batch_size=st.integers(1, 3))
+    def test_preConvert(self, input_channels, batch_size):
+        def AddModel(model, data):
+            conv1 = brew.conv(model, data, 'conv1', dim_in=input_channels,
+                              dim_out=10, kernel=3, stride=1, pad=1, training_mode=1)
+            deconv1 = brew.conv_transpose(model, conv1, 'deconv1', dim_in=10, dim_out=10,
+                                          kernel=2, stride=2, pad=0, training_mode=1)
+            fc1 = brew.fc(model, deconv1, 'fc1', dim_in=10 * 56 * 56, dim_out=3)
+            softmax = brew.softmax(model, fc1, 'softmax')
+
+            return softmax
+
+        def AddTrainingOperators(model, softmax, label):
+            """Adds training operators to the model."""
+            # Compute cross entropy between softmax scores and labels
+            xent = model.LabelCrossEntropy([softmax, label], 'xent')
+            # Compute the expected loss
+            loss = model.AveragedLoss(xent, "loss")
+            # Use the average loss we just computed to add gradient operators to the model
+            model.AddGradientOperators([loss])
+
+        arg_scope = {"order": "NCHW", 'no_bias': False}
+        # Create the model helper for the train model
+        device_opt = core.DeviceOption(caffe2_pb2.IDEEP, 0)
+        with core.DeviceScope(device_opt):
+            train_model = model_helper.ModelHelper(name="test_train", arg_scope=arg_scope)
+            # Add the model definition (fc layers, conv layers, softmax, etc.)
+            softmax = AddModel(train_model, "X")
+            AddTrainingOperators(train_model, softmax, "label")
+
+            X = np.random.rand(
+                batch_size, input_channels, 28, 28).astype(np.float32) - 0.5
+            label = np.random.randint(3, size=batch_size).astype(np.int32)
+            blob_dict = {}
+            output_dict = {}
+            output_dict_cosim = {}
+            old_ws_name = workspace.CurrentWorkspace()
+            workspace.FeedBlob('X', X)
+            workspace.FeedBlob('label', label)
+            workspace.RunNetOnce(train_model.param_init_net)
+            for op in train_model.net.Proto().op:
+                if op.type == "Softmax":
+                    break
+                for j in range(1, len(op.input)):
+                    blob_dict[op.input[j]] = workspace.FetchBlob(op.input[j])
+
+            workspace.CreateNet(train_model.net, overwrite=True)
+            optimizeForMKLDNN(train_model.net, training_mode=True)
+            workspace.RunNet(train_model.net)
+            for op in train_model.net.Proto().op:
+                for blob in op.output:
+                    output_dict[blob] = workspace.FetchBlob(blob)
+
+            workspace.SwitchWorkspace("_device_check_", True)
+            workspace.FeedBlob('X', X)
+            workspace.FeedBlob('label', label)
+            for blob in blob_dict.keys():
+                workspace.FeedBlob(blob, blob_dict[blob])
+            workspace.CreateNet(train_model.net, overwrite=True)
+            workspace.RunNet(train_model.net)
+            for blob in output_dict.keys():
+                output_dict_cosim[blob] = workspace.FetchBlob(blob)
+
+            for blob in output_dict.keys():
+                if not np.allclose(output_dict[blob], output_dict_cosim[blob], atol=0.001, rtol=0.0001):
+                    print("blob {} error".format(blob))
+                    print(np.max(np.abs(output_dict[blob] - output_dict_cosim[blob])))
+                    self.assertTrue(False)
+
+            workspace.ResetWorkspace()
+            workspace.SwitchWorkspace(old_ws_name)
+
+
+if __name__ == "__main__":
+    unittest.main()

@@ -9,7 +9,6 @@ from hypothesis import given, settings
 import numpy as np
 from caffe2.proto import caffe2_pb2
 from caffe2.python import core, workspace
-from caffe2.python.transformations import optimizeForIDEEP
 import caffe2.python.hypothesis_test_util as hu
 import caffe2.python.ideep_test_util as mu

@@ -277,7 +277,7 @@ def fuse_conv_relu(net):
         op.device_option.CopyFrom(device_option)

     new_net = caffe2_pb2.NetDef()
-    new_net.ParseFromString(C.transform_optimizeForIDEEP(net.SerializeToString()))
+    new_net.ParseFromString(C.transform_optimizeForMKLDNN(net.SerializeToString()))
     return new_net

@@ -1691,12 +1691,12 @@ void addGlobalMethods(py::module& m) {
   // into a python interface in transformations.py
   // Prefix the transformation with transform_ to avoid clobbering the
   // function namespace.
-  m.def("transform_optimizeForIDEEP", [](py::bytes def, bool training_mode) {
+  m.def("transform_optimizeForMKLDNN", [](py::bytes def, bool training_mode) {
     caffe2::NetDef proto;
     CAFFE_ENFORCE(ParseProtoFromLargeString(def.cast<std::string>(), &proto));

     auto nn = caffe2::convertToNNModule(proto);
-    opt::OptimizeForIdeep(&nn, gWorkspace, training_mode);
+    opt::OptimizeForMkldnn(&nn, gWorkspace, training_mode);
     auto new_proto = caffe2::convertToCaffe2Proto(nn, proto);

     std::string out;

@@ -64,7 +64,8 @@ public:
         numpy_type != -1,
         "Unsupported ideep memory data type? This usually should not happen "
         "since ideep memory usually only do float and double.");
-    itensor::dims dims = atensor.get_dims();
+    itensor::dims dims = atensor.get_public_format_dims();
+
     std::vector<npy_intp> npy_dims(dims.begin(), dims.end());

     result.copied = force_copy || atensor.need_reorder();

@@ -46,9 +46,9 @@ def fuseNNPACKConvRelu(net):
     )


-def optimizeForIDEEP(net, training_mode = False):
+def optimizeForMKLDNN(net, training_mode = False):
     net.Proto().ParseFromString(
-        C.transform_optimizeForIDEEP(net.Proto().SerializeToString(), training_mode)
+        C.transform_optimizeForMKLDNN(net.Proto().SerializeToString(), training_mode)
     )
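
Because the old Python name is replaced rather than aliased, downstream callers need a one-line migration; a sketch (the caller code here is illustrative, not from this patch):

    # Before this patch:
    #   from caffe2.python.transformations import optimizeForIDEEP
    #   optimizeForIDEEP(net)
    # After this patch:
    from caffe2.python.transformations import optimizeForMKLDNN
    optimizeForMKLDNN(net)  # training_mode defaults to False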