mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
translator and misc fixes for legacy group convolution, sigh.
This commit is contained in:
@ -15,14 +15,17 @@ class DepthSplitOp final : public Operator<dtype, DeviceContext> {
|
||||
DepthSplitOp(const OperatorDef& operator_def, Workspace* ws)
|
||||
: Operator<dtype, DeviceContext>(operator_def, ws),
|
||||
order_(StringToStorageOrder(
|
||||
OperatorBase::GetSingleArgument<string>("order", "NHWC"))) {}
|
||||
OperatorBase::GetSingleArgument<string>("order", "NHWC"))),
|
||||
dimensions_(
|
||||
OperatorBase::GetRepeatedArgument<int>("dimensions")) {}
|
||||
bool RunOnDevice() override;
|
||||
|
||||
protected:
|
||||
StorageOrder order_;
|
||||
// Input: X, dimensions
|
||||
vector<int> dimensions_;
|
||||
// Input: X, optionally dimensions
|
||||
// The dimensions are stored in CPU.
|
||||
INPUT_OUTPUT_STATS(2, 2, 1, INT_MAX);
|
||||
INPUT_OUTPUT_STATS(1, 2, 1, INT_MAX);
|
||||
DISABLE_COPY_AND_ASSIGN(DepthSplitOp);
|
||||
};
|
||||
|
||||
@ -49,12 +52,26 @@ class DepthConcatOp final : public Operator<dtype, DeviceContext> {
|
||||
template <typename dtype, class DeviceContext>
|
||||
bool DepthSplitOp<dtype, DeviceContext>::RunOnDevice() {
|
||||
auto& input = Input(0);
|
||||
auto& dimensions =
|
||||
OperatorBase::Input<Tensor<int, CPUContext> >(1);
|
||||
const int* dim_data = dimensions.data();
|
||||
DCHECK_EQ(dimensions.size(), OutputSize());
|
||||
DCHECK_EQ(std::accumulate(dim_data, dim_data + OutputSize(), 0),
|
||||
(order_ == StorageOrder::NCHW ? input.dim(1) : input.dim(3)));
|
||||
const int* dim_data;
|
||||
if (InputSize() == 2) {
|
||||
// We obtain dimensions from the input tensor.
|
||||
CHECK_EQ(dimensions_.size(), 0)
|
||||
<< "If you set dimensions with an input blob, do not pass in "
|
||||
<< "dimensions in the argument.";
|
||||
auto& dimensions_tensor =
|
||||
OperatorBase::Input<Tensor<int, CPUContext> >(1);
|
||||
CHECK_EQ(dimensions_tensor.size(), OutputSize());
|
||||
dim_data = dimensions_tensor.data();
|
||||
} else {
|
||||
// We obtain dimensions from the parameters.
|
||||
CHECK_EQ(dimensions_.size(), OutputSize());
|
||||
dim_data = dimensions_.data();
|
||||
}
|
||||
const int input_channels =
|
||||
(order_ == StorageOrder::NCHW ? input.dim(1) : input.dim(3));
|
||||
CHECK_EQ(std::accumulate(dim_data, dim_data + OutputSize(), 0),
|
||||
input_channels)
|
||||
<< "Dimensions do not match: should be " << input_channels;
|
||||
int input_offset = 0;
|
||||
for (int i = 0; i < OutputSize(); ++i) {
|
||||
auto* output = Output(i);
|
||||
|
@ -2,7 +2,7 @@ from caffe2.proto import caffe2_pb2, caffe2_legacy_pb2
|
||||
from caffe.proto import caffe_pb2
|
||||
from google.protobuf import text_format
|
||||
import numpy as np
|
||||
from pycaffe2 import utils
|
||||
from pycaffe2 import core, utils
|
||||
|
||||
|
||||
def _StateMeetsRule(state, rule):
|
||||
@ -49,7 +49,7 @@ def DeleteDropout(net):
|
||||
Outputs:
|
||||
None. The function works by modifying net in-place.
|
||||
"""
|
||||
for op in net.operators:
|
||||
for op in net.op:
|
||||
if op.type == 'Dropout':
|
||||
op.type = 'Alias'
|
||||
del op.output[1] # output 1 is the dropout mask, which is not needed.
|
||||
@ -73,7 +73,7 @@ class CacaRegistry(object):
|
||||
try:
|
||||
caffe_ops, params = cls.registry_[layer.type](layer, pretrained_blobs)
|
||||
except KeyError as err:
|
||||
raise KeyError('No translator registered for layer: %s' % str(layer))
|
||||
raise KeyError('No translator registered for layer: %s yet.' % str(layer))
|
||||
if caffe_ops is None:
|
||||
return []
|
||||
if type(caffe_ops) is not list:
|
||||
@ -110,7 +110,7 @@ class CacaRegistry(object):
|
||||
# print 'No pretrained layer for layer', layer.name
|
||||
pretrained_blobs = []
|
||||
operators, params = cls.TranslateLayer(layer, pretrained_blobs)
|
||||
net.operators.extend(operators)
|
||||
net.op.extend(operators)
|
||||
net_params.protos.extend(params)
|
||||
return net, net_params
|
||||
|
||||
@ -138,35 +138,72 @@ def AddArgument(op, key, value):
|
||||
|
||||
@CacaRegistry.Register("Convolution")
|
||||
def TranslateConv(layer, pretrained_blobs):
|
||||
param = layer.convolution_param
|
||||
if param.group > 1:
|
||||
return TranslateConvWithGroups(layer, pretrained_blobs)
|
||||
# If there is no odd things, we will basically translate it to a standard
|
||||
# caffe2 op.
|
||||
caffe_op = BaseTranslate(layer, "Conv")
|
||||
output = caffe_op.output[0]
|
||||
caffe_op.input.extend([output + '_w', output + '_b'])
|
||||
param = layer.convolution_param
|
||||
AddArgument(caffe_op, "stride", param.stride)
|
||||
AddArgument(caffe_op, "kernel", param.kernel_size)
|
||||
AddArgument(caffe_op, "pad", param.pad)
|
||||
AddArgument(caffe_op, "order", "NCHW")
|
||||
if param.group > 1:
|
||||
# Now, if the model is grouped convolution, let's do a backward hack and make
|
||||
# things working but in an efficient way by inserting zero parameters. Note
|
||||
# that this is not computationally safe, but grouped convolution is such an
|
||||
# antique technology that we should really deprecate it.
|
||||
n, c, h, w = pretrained_blobs[0].shape
|
||||
g = param.group
|
||||
og = int(n / g)
|
||||
if (og * g != n):
|
||||
raise ValueError("This should not happen")
|
||||
weight = np.zeros((n, c * g, h, w), dtype=np.float32)
|
||||
for i in range(param.group):
|
||||
weight[i * og : (i + 1) * og, i * c : (i+1) * c, :, :] = pretrained_blobs[0][i * og : (i + 1) * og]
|
||||
else:
|
||||
weight = pretrained_blobs[0]
|
||||
weight = utils.NumpyArrayToCaffe2Tensor(weight, output + '_w')
|
||||
weight = utils.NumpyArrayToCaffe2Tensor(pretrained_blobs[0], output + '_w')
|
||||
bias = utils.NumpyArrayToCaffe2Tensor(
|
||||
pretrained_blobs[1].flatten(), output + '_b')
|
||||
# Todo: deal with parameters.
|
||||
return caffe_op, [weight, bias]
|
||||
|
||||
def TranslateConvWithGroups(layer, pretrained_blobs):
|
||||
print ("Legacy warning: convolution with groups seem to be less and less " +
|
||||
"popular, so we no longer have it as a first-class citizen op. " +
|
||||
"Instead, we will simulate it with depth split followed by conv " +
|
||||
"followed by depth concat.")
|
||||
caffe_ops = []
|
||||
caffe_params = []
|
||||
param = layer.convolution_param
|
||||
weight, bias = pretrained_blobs
|
||||
bias = bias.flatten()
|
||||
n, c, h, w = weight.shape
|
||||
g = param.group # group
|
||||
od = int(n / g) # output dimension
|
||||
if (od * g != n):
|
||||
# This should not happen: n should always be divisible by g.
|
||||
raise ValueError("This should not happen.")
|
||||
output = layer.top[0]
|
||||
# first, depth_split
|
||||
depth_split_op = core.CreateOperator("DepthSplit")(
|
||||
layer.bottom[0],
|
||||
['_' + output + '_gconv_split_' + str(i) for i in range(g)],
|
||||
dimensions=[c for i in range(g)],
|
||||
order="NCHW")
|
||||
caffe_ops.append(depth_split_op)
|
||||
# second, convolutions
|
||||
for i in range(g):
|
||||
# convolution layer i
|
||||
this_weight = utils.NumpyArrayToCaffe2Tensor(
|
||||
weight[i * od : (i + 1) * od], output + '_' + str(i) + '_w')
|
||||
this_bias = utils.NumpyArrayToCaffe2Tensor(
|
||||
bias[i * od : (i + 1) * od], output + '_' + str(i) + '_b')
|
||||
conv_op = core.CreateOperator("Conv")(
|
||||
[depth_split_op.output[i], this_weight.name, this_bias.name],
|
||||
['_' + output + '_gconv_conv_' + str(i)],
|
||||
stride=param.stride,
|
||||
kernel=param.kernel_size,
|
||||
pad=param.pad,
|
||||
order="NCHW")
|
||||
caffe_ops.append(conv_op)
|
||||
caffe_params.extend([this_weight, this_bias])
|
||||
# third, depth concat
|
||||
depth_concat_op = core.CreateOperator("DepthConcat")(
|
||||
['_' + output + '_gconv_conv_' + str(i) for i in range(g)],
|
||||
[output, '_' + output + '_gconv_concat_dims'],
|
||||
order="NCHW")
|
||||
caffe_ops.append(depth_concat_op)
|
||||
return caffe_ops, caffe_params
|
||||
|
||||
|
||||
@CacaRegistry.Register("ReLU")
|
||||
def TranslateRelu(layer, pretrained_blobs):
|
||||
return BaseTranslate(layer, "Relu"), []
|
||||
|
@ -45,6 +45,8 @@ if __name__ == '__main__':
|
||||
net, pretrained_params = caffe_translator.TranslateModel(
|
||||
caffenet, caffenet_pretrained)
|
||||
caffe_translator.DeleteDropout(net)
|
||||
with open('data/testdata/caffe_translator/bvlc_reference_caffenet.translatedmodel', 'w') as fid:
|
||||
fid.write(str(net))
|
||||
for param in pretrained_params.protos:
|
||||
workspace.FeedBlob(param.name, utils.Caffe2TensorToNumpyArray(param))
|
||||
# Let's also feed in the data from the Caffe test code.
|
||||
|
@ -62,8 +62,11 @@ def CreateOperator(operator_type):
|
||||
operator.name = name
|
||||
if type(inputs) is str or type(inputs) is BlobReference:
|
||||
inputs = [inputs]
|
||||
elif type(inputs) is unicode:
|
||||
inputs = [str(inputs)]
|
||||
elif type(inputs) is not list:
|
||||
raise ValueError("Unknown input format: %s." % str(inputs))
|
||||
raise ValueError("Unknown input format: %s of type %s."
|
||||
% (str(inputs), type(inputs)))
|
||||
if type(outputs) is str or type(outputs) is BlobReference:
|
||||
outputs = [outputs]
|
||||
elif type(outputs) is not list:
|
||||
|
@ -40,7 +40,7 @@ def Allreduce(net, blobs, reduced_affix="_reduced", gpu_indices=None):
|
||||
|
||||
def Allreduce2(net, blobs, reduced_affix, gpu_indices):
|
||||
"""Allreduce for 2 gpus.
|
||||
|
||||
|
||||
Algorithm: 0r <- 0 + 1, 1r <- 0r, where r means "reduced"
|
||||
"""
|
||||
a, b = blobs
|
||||
@ -54,7 +54,7 @@ def Allreduce2(net, blobs, reduced_affix, gpu_indices):
|
||||
def Allreduce4(net, blobs, reduced_affix, gpu_indices):
|
||||
"""Allreduce for 4 gpus.
|
||||
|
||||
Algorithm: 2 level reduction.
|
||||
Algorithm: 2 level reduction.
|
||||
0r <- 0 + 1, 2r <- 2 + 3
|
||||
0r <- 0r + 2r
|
||||
2r <- 0r,
|
||||
@ -124,7 +124,7 @@ def Allreduce8(net, blobs, reduced_affix, gpu_indices):
|
||||
|
||||
def AllreduceFallback(net, blobs, reduced_affix, gpu_indices):
|
||||
"""A fallback option for Allreduce with no assumption on p2p.
|
||||
|
||||
|
||||
Algorithm: a flat operation on gpu 0
|
||||
0r <- 0
|
||||
0r <- 0r + i for i in gpu_indices[1:]
|
||||
|
Reference in New Issue
Block a user