From 34662f77c665ee8a8369a7d23fe0560992de92e9 Mon Sep 17 00:00:00 2001 From: Lu Fang Date: Tue, 8 Oct 2019 01:56:28 -0700 Subject: [PATCH] Revert D17159707: [pytorch][PR] [ONNX] Fixed Select symbolic to export slice when index = negative one Test Plan: revert-hammer Differential Revision: D17159707 Original commit changeset: 2c3b27542108 fbshipit-source-id: accce910abdbe13270d0f592810a48b1dabe4b01 --- caffe2/onnx/backend.cc | 18 ++++++++------ .../TestOperators.test_dyn_arange.expect | 12 +++++----- .../expect/TestOperators.test_full.expect | 24 +++++++++---------- ...TestOperators.test_upsample_nearest.expect | 24 +++++++++---------- .../TestOperators.test_view_flatten.expect | 24 +++++++++---------- test/onnx/test_onnx_opset.py | 16 ++++++------- test/onnx/test_pytorch_onnx_caffe2.py | 1 + test/onnx/test_pytorch_onnx_onnxruntime.py | 3 ++- torch/onnx/symbolic_opset9.py | 14 +++++------ 9 files changed, 70 insertions(+), 66 deletions(-) diff --git a/caffe2/onnx/backend.cc b/caffe2/onnx/backend.cc index 284be62b9627..e0943294580c 100644 --- a/caffe2/onnx/backend.cc +++ b/caffe2/onnx/backend.cc @@ -709,16 +709,20 @@ Caffe2Ops Caffe2Backend::CreateGather( std::vector inputs; inputs.emplace_back(node.input(0)); inputs.emplace_back(node.input(1)); - - auto axis = onnx_node->attributes.get("axis", 0L); - caffe2::Argument arg_axis; - arg_axis.set_name("axis"); - arg_axis.set_i(axis); - std::vector outputs; outputs.emplace_back(node.output(0)); - BuildOperator(c2_op, "Gather", inputs, outputs, {arg_axis}); + auto axis = onnx_node->attributes.get("axis", 0L); + if (axis == 0) { + BuildOperator(c2_op, "Gather", inputs, outputs); + } else if (axis == 1) { + BuildOperator(c2_op, "BatchGather", inputs, outputs); + } else { + CAFFE_THROW( + "Caffe2 only supports Gather with axis being 0 or 1, ", + "whereas axis is ", + axis); + } return ret; } diff --git a/test/onnx/expect/TestOperators.test_dyn_arange.expect b/test/onnx/expect/TestOperators.test_dyn_arange.expect index 2dd557d6af4d..552114be962b 100644 --- a/test/onnx/expect/TestOperators.test_dyn_arange.expect +++ b/test/onnx/expect/TestOperators.test_dyn_arange.expect @@ -3,12 +3,7 @@ producer_name: "pytorch" producer_version: "1.3" graph { node { - input: "0" output: "1" - op_type: "Shape" - } - node { - output: "2" op_type: "Constant" attribute { name: "value" @@ -20,8 +15,13 @@ graph { } } node { - input: "1" + input: "0" + output: "2" + op_type: "Shape" + } + node { input: "2" + input: "1" output: "3" op_type: "Gather" attribute { diff --git a/test/onnx/expect/TestOperators.test_full.expect b/test/onnx/expect/TestOperators.test_full.expect index c11d3e94c1f8..d13667ef8475 100644 --- a/test/onnx/expect/TestOperators.test_full.expect +++ b/test/onnx/expect/TestOperators.test_full.expect @@ -3,12 +3,7 @@ producer_name: "pytorch" producer_version: "1.3" graph { node { - input: "0" output: "1" - op_type: "Shape" - } - node { - output: "2" op_type: "Constant" attribute { name: "value" @@ -20,8 +15,13 @@ graph { } } node { - input: "1" + input: "0" + output: "2" + op_type: "Shape" + } + node { input: "2" + input: "1" output: "3" op_type: "Gather" attribute { @@ -31,12 +31,7 @@ graph { } } node { - input: "0" output: "4" - op_type: "Shape" - } - node { - output: "5" op_type: "Constant" attribute { name: "value" @@ -48,8 +43,13 @@ graph { } } node { - input: "4" + input: "0" + output: "5" + op_type: "Shape" + } + node { input: "5" + input: "4" output: "6" op_type: "Gather" attribute { diff --git a/test/onnx/expect/TestOperators.test_upsample_nearest.expect b/test/onnx/expect/TestOperators.test_upsample_nearest.expect index baf27592822b..16451536e181 100644 --- a/test/onnx/expect/TestOperators.test_upsample_nearest.expect +++ b/test/onnx/expect/TestOperators.test_upsample_nearest.expect @@ -3,12 +3,7 @@ producer_name: "pytorch" producer_version: "1.3" graph { node { - input: "input" output: "1" - op_type: "Shape" - } - node { - output: "2" op_type: "Constant" attribute { name: "value" @@ -20,8 +15,13 @@ graph { } } node { - input: "1" + input: "input" + output: "2" + op_type: "Shape" + } + node { input: "2" + input: "1" output: "3" op_type: "Gather" attribute { @@ -74,12 +74,7 @@ graph { op_type: "Floor" } node { - input: "input" output: "9" - op_type: "Shape" - } - node { - output: "10" op_type: "Constant" attribute { name: "value" @@ -91,8 +86,13 @@ graph { } } node { - input: "9" + input: "input" + output: "10" + op_type: "Shape" + } + node { input: "10" + input: "9" output: "11" op_type: "Gather" attribute { diff --git a/test/onnx/expect/TestOperators.test_view_flatten.expect b/test/onnx/expect/TestOperators.test_view_flatten.expect index 9909c19c1526..4d81920b95f7 100644 --- a/test/onnx/expect/TestOperators.test_view_flatten.expect +++ b/test/onnx/expect/TestOperators.test_view_flatten.expect @@ -3,12 +3,7 @@ producer_name: "pytorch" producer_version: "1.3" graph { node { - input: "0" output: "1" - op_type: "Shape" - } - node { - output: "2" op_type: "Constant" attribute { name: "value" @@ -20,8 +15,13 @@ graph { } } node { - input: "1" + input: "0" + output: "2" + op_type: "Shape" + } + node { input: "2" + input: "1" output: "3" op_type: "Gather" attribute { @@ -31,12 +31,7 @@ graph { } } node { - input: "0" output: "4" - op_type: "Shape" - } - node { - output: "5" op_type: "Constant" attribute { name: "value" @@ -48,8 +43,13 @@ graph { } } node { - input: "4" + input: "0" + output: "5" + op_type: "Shape" + } + node { input: "5" + input: "4" output: "6" op_type: "Gather" attribute { diff --git a/test/onnx/test_onnx_opset.py b/test/onnx/test_onnx_opset.py index cfcd510b5940..0efe3dd875d6 100644 --- a/test/onnx/test_onnx_opset.py +++ b/test/onnx/test_onnx_opset.py @@ -186,7 +186,7 @@ class TestONNXOpset(TestCase): "attributes" : []}] ops = {9 : ops_9, 10 : ops_10} x = torch.randn(3) - check_onnx_opsets_operator(MyModule(), x, ops, opset_versions=[9]) + check_onnx_opsets_operator(MyModule(), x, ops, opset_versions=[9, 10]) class DynamicSliceModel(torch.jit.ScriptModule): @torch.jit.script_method @@ -196,7 +196,6 @@ class TestONNXOpset(TestCase): ops_9 = [{"op_name" : "Constant"}, {"op_name" : "Constant"}, {"op_name" : "Shape"}, - {"op_name": "Constant"}, {"op_name" : "Gather", "attributes" : [{"name" : "axis", "i" : 0, "type" : 2}]}, {"op_name" : "Unsqueeze", @@ -209,7 +208,6 @@ class TestONNXOpset(TestCase): ops_10 = [{"op_name" : "Constant"}, {"op_name" : "Constant"}, {"op_name" : "Shape"}, - {"op_name": "Constant"}, {"op_name" : "Gather", "attributes" : [{"name" : "axis", "i" : 0, "type" : 2}]}, {"op_name" : "Unsqueeze", @@ -286,11 +284,11 @@ class TestONNXOpset(TestCase): return torch.nn.functional.interpolate(x, size=size, mode='nearest') - ops_9 = [{"op_name" : "Shape"}, - {"op_name" : "Constant"}, - {"op_name" : "Gather"}, + ops_9 = [{"op_name" : "Constant"}, {"op_name" : "Shape"}, + {"op_name" : "Gather"}, {"op_name" : "Constant"}, + {"op_name" : "Shape"}, {"op_name" : "Gather"}, {"op_name" : "Constant"}, {"op_name" : "Mul"}, @@ -309,11 +307,11 @@ class TestONNXOpset(TestCase): {"op_name" : "Upsample", "attributes" : [{"name": "mode", "s": ("nearest").encode(), "type": 3}]}] - ops_10 = [{"op_name" : "Shape"}, - {"op_name" : "Constant"}, - {"op_name" : "Gather"}, + ops_10 = [{"op_name" : "Constant"}, {"op_name" : "Shape"}, + {"op_name" : "Gather"}, {"op_name" : "Constant"}, + {"op_name" : "Shape"}, {"op_name" : "Gather"}, {"op_name" : "Constant"}, {"op_name" : "Mul"}, diff --git a/test/onnx/test_pytorch_onnx_caffe2.py b/test/onnx/test_pytorch_onnx_caffe2.py index fac6715f167c..32ff6d97c56f 100644 --- a/test/onnx/test_pytorch_onnx_caffe2.py +++ b/test/onnx/test_pytorch_onnx_caffe2.py @@ -1306,6 +1306,7 @@ class TestCaffe2Backend_opset9(unittest.TestCase): x = torch.randn(3, 4, 5, 6, 7) self.run_model_test(NegSlice(), train=False, input=(x,), batch_size=BATCH_SIZE, use_gpu=False) + @unittest.skip('https://github.com/pytorch/pytorch/issues/10984') @skipIfUnsupportedOpsetVersion([10]) def test_neg_slice_large_negone(self): class NegSlice(torch.nn.Module): diff --git a/test/onnx/test_pytorch_onnx_onnxruntime.py b/test/onnx/test_pytorch_onnx_onnxruntime.py index 11af41143c31..3c741fb8845d 100644 --- a/test/onnx/test_pytorch_onnx_onnxruntime.py +++ b/test/onnx/test_pytorch_onnx_onnxruntime.py @@ -345,11 +345,12 @@ class TestONNXRuntime(unittest.TestCase): def test_slice_neg_large(self): class NegSlice(torch.nn.Module): def forward(self, x): - return x[:, :, -3:-1, :, -1] + return x[:, :, :, :, -3] x = torch.randn(3, 4, 5, 6, 7) self.run_test(NegSlice(), x) + @unittest.skip('https://github.com/pytorch/pytorch/issues/10984') def test_slice_neg_large_negone(self): class NegSlice(torch.nn.Module): def forward(self, x): diff --git a/torch/onnx/symbolic_opset9.py b/torch/onnx/symbolic_opset9.py index a174788005ec..19dd546ecc83 100644 --- a/torch/onnx/symbolic_opset9.py +++ b/torch/onnx/symbolic_opset9.py @@ -394,13 +394,13 @@ def split_with_sizes(g, self, split_sizes, dim): @parse_args('v', 'i', 'v') def select(g, self, dim, index): - index = sym_help._maybe_get_scalar(index) - if (not sym_help._is_value(index)) and (index < 0): - if index == -1: - end_index = 9223372036854775807 - else: - end_index = index + 1 - slice_node = sym_help._slice_helper(g, self, axes=[dim], starts=[index], ends=[end_index]) + if dim > 1: + # TODO: this is a temporary hack because of the implementation details + # of Gather in caffe2. We need to change this as soon as possible. + # TODO: this breaks if index == -1 + index_val = _parse_arg(index, 'i') + slice_node = sym_help._slice_helper(g, self, axes=[dim], + starts=[index_val], ends=[index_val + 1]) return g.op("Squeeze", slice_node, axes_i=[dim]) else: return g.op("Gather", self, index, axis_i=dim)