Improved onnx export for 3 onnx ops. (#18512)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/18512

Ceil and Floor have been supported since version 6 of ONNX: export them using the native onnx ops instead of an Aten op.
Similarly, support for the Where op has been added in version 9, so we don't need to wrap these op in an Aten op.

Reviewed By: houseroad

Differential Revision: D14635130

fbshipit-source-id: d54a2b6e295074a6214b5939b21051a6735c9958
This commit is contained in:
Benoit Steiner
2019-03-28 08:52:01 -07:00
committed by Facebook Github Bot
parent ffc7158bf2
commit eee760dbd3
4 changed files with 28 additions and 3 deletions

View File

@ -362,7 +362,8 @@ Caffe2Backend::get_special_operators() const {
{"Dropout", &Caffe2Backend::CreateDropout},
{"LRN", &Caffe2Backend::CreateLRN},
{"DynamicSlice", &Caffe2Backend::CreateDynamicSlice},
{"RandomNormal", &Caffe2Backend::CreateRandomNormal}};
{"RandomNormal", &Caffe2Backend::CreateRandomNormal},
{"Where", &Caffe2Backend::CreateWhereOp}};
return kSpecialOperators;
}
@ -580,6 +581,21 @@ Caffe2Ops Caffe2Backend::CreateRandomNormal(
return CommonOnnxNodeToCaffe2Ops(onnx_node, ctx);
}
Caffe2Ops Caffe2Backend::CreateWhereOp(
OnnxNode* onnx_node,
const ConversionContext& ctx) {
// The native Caffe2 op doesn't support broadcasting, so we defer the handling
// of this op to the ATen library that does.
onnx::NodeProto converted;
converted.CopyFrom(onnx_node->node);
converted.set_op_type("ATen");
onnx::AttributeProto* attr = converted.add_attribute();
attr->set_name("operator");
attr->set_s("where");
OnnxNode new_node(converted);
return CommonOnnxNodeToCaffe2Ops(&new_node, ctx);
}
Caffe2Ops Caffe2Backend::CreateReciprocal(
OnnxNode* onnx_node,
const ConversionContext& ctx) {

View File

@ -236,6 +236,8 @@ class CAFFE2_API Caffe2Backend {
OnnxNode* onnx_node,
const ConversionContext& ctx);
Caffe2Ops CreateWhereOp(OnnxNode* onnx_node, const ConversionContext& ctx);
Caffe2Ops CreateBatchNormalization(
OnnxNode* onnx_node,
const ConversionContext& ctx);

View File

@ -52,7 +52,6 @@ backend_test.exclude(r'(test_hardsigmoid' # Does not support Hardsigmoid.
'|test_isnan.*' # Needs implementation
'|test_scatter.*' # Should be similar to ScatterAssign
'|test_constantofshape_int.*' # Needs implementation
'|test_where.*' # Needs implementation
'|test_shrink.*' # Needs implementation
'|test_strnorm.*' # Needs implementation
'|test_nonzero.*' # Needs implementation

View File

@ -548,6 +548,14 @@ def relu(g, input):
return g.op("Relu", input)
def ceil(g, input):
return g.op("Ceil", input)
def floor(g, input):
return g.op("Floor", input)
@parse_args('v', 't', 't')
def threshold(g, self, threshold, value):
# See Note [Export inplace]
@ -922,7 +930,7 @@ def le(g, input, other):
def where(g, condition, self, other):
return g.op("ATen", condition, self, other, operator_s="where")
return g.op("Where", condition, self, other)
@parse_args('v', 'i', 'i')