Add shape inference function for Split (#18838)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/18838

It turns out that we don't have a shape inference function for the `Split` op at all. This diff adds one.

Reviewed By: bertmaher

Differential Revision: D14766871

fbshipit-source-id: 535cb4f24bdada603c76579e00e7a39aee93e19f
This commit is contained in:
Yinghai Lu
2019-04-04 00:19:21 -07:00
committed by Facebook Github Bot
parent 0c237f1383
commit e5e2110a8e
2 changed files with 128 additions and 0 deletions

View File

@ -16,6 +16,76 @@ std::pair<std::vector<DeviceOption>, std::vector<DeviceOption>> splitOpDevInfer(
}
return std::make_pair(in_dev, out_dev);
}
// Shape inference for the `Split` operator.
//
// Args:
//   def: operator definition; the arguments `axis`, `order`, `add_axis` and
//        `split` are read from it, as are the input/output counts.
//   in:  shapes of the operator inputs; only in[0] is used.
//
// Returns:
//   - one TensorShape per output when inference succeeds;
//   - an empty vector when the op has exactly one known-shape input and
//     declares no outputs;
//   - one TensorShape per output with `unknown_shape` set when the op has a
//     number of inputs other than 1, the input shape is unknown, or the
//     requested split is inconsistent with the input shape (a warning is
//     logged in the inconsistent cases).
vector<TensorShape> TensorInferenceForSplit(
    const OperatorDef& def,
    const vector<TensorShape>& in) {
  // Builds the "cannot infer" result: every output flagged as unknown shape.
  auto ret_invalid_shape = [&def]() {
    vector<TensorShape> out(def.output().size());
    for (auto& out_ts : out) {
      out_ts.set_unknown_shape(true);
    }
    return out;
  };
  // We only support shape inference of Split with 1 input
  if (def.input_size() != 1 || in.empty() || in.front().unknown_shape()) {
    return ret_invalid_shape();
  } else if (def.output_size() == 0) {
    return vector<TensorShape>();
  }

  ArgumentHelper helper(def);
  // The split axis comes from `axis` when given, otherwise it is derived from
  // the `order` string (default "NCHW").
  const int axis = helper.HasArgument("axis")
      ? helper.GetSingleArgument<int>("axis", -1)
      : GetDimFromOrderString(
            helper.GetSingleArgument<string>("order", "NCHW"));
  // `add_axis` is only honored when `axis` is given explicitly; in the
  // `order`-based form it is always treated as 0.
  // NOTE(review): presumably this mirrors how the operator itself parses its
  // arguments -- confirm against SplitOp before changing.
  const int add_axis = helper.HasArgument("axis")
      ? helper.GetSingleArgument<int>("add_axis", 0)
      : 0;

  const auto& input = in[0];
  const int canonical_axis = canonical_axis_index_(axis, input.dims_size());
  const int input_channels = input.dims(canonical_axis);
  auto split = helper.GetRepeatedArgument<int>("split");
  const int output_size = def.output_size();

  if (split.empty()) {
    // No explicit `split`: divide the axis equally among the outputs.
    // The comparison must be spelled out: `!input_channels % output_size`
    // parses as `(!input_channels) % output_size` and does not test
    // divisibility.
    if (input_channels % output_size != 0) {
      LOG(WARNING) << "Input channels (" << input_channels
                   << ") should be divisible by number of outputs ("
                   << output_size << ")";
      return ret_invalid_shape();
    }
    split.resize(output_size, input_channels / output_size);
  } else if (static_cast<int>(split.size()) != output_size) {
    LOG(WARNING) << "`split` size (" << split.size()
                 << ") should be equal to output size (" << output_size << ")";
    return ret_invalid_shape();
  }

  // Check validity of the split. With `add_axis` the split axis is dropped
  // from every output, so the axis length has to equal the number of outputs;
  // otherwise the split lengths have to add up to the axis length.
  const int total_channels = add_axis
      ? output_size
      : std::accumulate(split.begin(), split.end(), 0);
  if (total_channels != input_channels) {
    LOG(WARNING) << "Input channels (" << input_channels
                 << ") is not equal to total output channels ("
                 << total_channels << ")";
    return ret_invalid_shape();
  }

  vector<int> output_dims(input.dims().begin(), input.dims().end());
  if (add_axis) {
    // All outputs share the input shape with the split axis removed.
    output_dims.erase(output_dims.begin() + canonical_axis);
  }

  vector<TensorShape> output_shapes;
  output_shapes.reserve(output_size);
  for (int i = 0; i < output_size; ++i) {
    if (!add_axis) {
      // Only the split axis differs between outputs.
      output_dims[canonical_axis] = split[i];
    }
    output_shapes.emplace_back(
        CreateTensorShape(output_dims, input.data_type()));
  }
  return output_shapes;
}
} // namespace.
// Register SplitOp<CPUContext> as the CPU operator for the name `Split`.
REGISTER_CPU_OPERATOR(Split, SplitOp<CPUContext>);
@ -29,11 +99,15 @@ OPERATOR_SCHEMA(Split)
"split",
"(*Tensor`<int>`*): [OPTIONAL] list of output lengths (see also arg `split`)")
.Arg("axis", "(*int*): axis to split on")
.Arg(
"add_axis",
"*(type: int)* Pass non-zero integer to remove the axis specified in `axis` to all input tensors.")
.Arg("split", "(*Tuple(int)*): length of each output")
.Arg(
"order",
"(*string*): order of dimensions of input and output blobs; either \"NCHW\" or \"NHWC\"")
.Output(0, "[output_0, output_1, ...]", "(*Tensor*): output tensor")
.TensorInferenceFunction(TensorInferenceForSplit)
.DeviceInferenceFunction(splitOpDevInfer)
.SetDoc(R"DOC(
Split an `input` tensor into a list of tensors, along the axis specified by the `axis` dimension. The lengths of the split can be specified using argument `split` or optional second input blob to the operator. Otherwise, the tensor is split to equal sized parts.

View File

@ -214,6 +214,60 @@ TEST(BoundShapeInference, ConcatMissingInput) {
{spec.max_batch_size, 2, 60});
}
// Runs bound shape inference over three Split ops and checks every blob:
//   1. X  -> Y0, Y1      equal split along axis 1 (no `split` argument)
//   2. X  -> Y2, Y3, Y4  explicit `split` lengths {4, 30, 14} along axis 1
//   3. X1 -> Y5, Y6      `add_axis` = 1, so axis 1 is dropped from outputs
TEST(BoundShapeInference, Split) {
  BoundShapeSpec spec(20, 1000);
  const auto batch = ShapeInfo::DimType::BATCH;

  // Known shapes for the two graph inputs.
  ShapeInfoMap input_shapes;
  input_shapes.emplace(
      "X", makeTensorInfo(batch, {spec.max_batch_size, 48}));
  input_shapes.emplace(
      "X1", makeTensorInfo(batch, {spec.max_batch_size, 2, 48}));

  // The three Split variants under test.
  const auto equal_split_op = CreateOperatorDef(
      "Split", "", {"X"}, {"Y0", "Y1"}, {MakeArgument<int>("axis", 1)});
  const auto explicit_split_op = CreateOperatorDef(
      "Split",
      "",
      {"X"},
      {"Y2", "Y3", "Y4"},
      {MakeArgument<int>("axis", 1),
       MakeArgument<std::vector<int>>("split", {4, 30, 14})});
  const auto add_axis_split_op = CreateOperatorDef(
      "Split",
      "",
      {"X1"},
      {"Y5", "Y6"},
      {MakeArgument<int>("axis", 1), MakeArgument<int>("add_axis", 1)});

  NetDef graph;
  graph.add_op()->CopyFrom(equal_split_op);
  graph.add_op()->CopyFrom(explicit_split_op);
  graph.add_op()->CopyFrom(add_axis_split_op);

  BoundShapeInferencer inferencer(spec);
  inferencer.InferBoundShapeAndType(graph, input_shapes);
  const auto& inferred = inferencer.shape_info();

  // Inputs keep the shapes they were seeded with.
  verifyShapeInfo(inferred, "X", batch, {spec.max_batch_size, 48});
  verifyShapeInfo(inferred, "X1", batch, {spec.max_batch_size, 2, 48});

  // Equal split: each output gets half of the 48 channels.
  verifyShapeInfo(inferred, "Y0", batch, {spec.max_batch_size, 48 / 2});
  verifyShapeInfo(inferred, "Y1", batch, {spec.max_batch_size, 48 / 2});

  // Explicit split: each output gets its own entry from `split`.
  verifyShapeInfo(inferred, "Y2", batch, {spec.max_batch_size, 4});
  verifyShapeInfo(inferred, "Y3", batch, {spec.max_batch_size, 30});
  verifyShapeInfo(inferred, "Y4", batch, {spec.max_batch_size, 14});

  // add_axis: the split axis (length 2) is gone from both outputs.
  verifyShapeInfo(inferred, "Y5", batch, {spec.max_batch_size, 48});
  verifyShapeInfo(inferred, "Y6", batch, {spec.max_batch_size, 48});
}
TEST(BoundShapeInference, FC) {
NetDef net;
net.add_op()->CopyFrom(