Revert D17159707: [pytorch][PR] [ONNX] Fixed Select symbolic to export slice when index = negative one

Test Plan: revert-hammer

Differential Revision:
D17159707

Original commit changeset: 2c3b27542108

fbshipit-source-id: accce910abdbe13270d0f592810a48b1dabe4b01
This commit is contained in:
Lu Fang
2019-10-08 01:56:28 -07:00
committed by Facebook Github Bot
parent 1b5df37441
commit 34662f77c6
9 changed files with 70 additions and 66 deletions

View File

@ -709,16 +709,20 @@ Caffe2Ops Caffe2Backend::CreateGather(
std::vector<std::string> inputs;
inputs.emplace_back(node.input(0));
inputs.emplace_back(node.input(1));
auto axis = onnx_node->attributes.get<int64_t>("axis", 0L);
caffe2::Argument arg_axis;
arg_axis.set_name("axis");
arg_axis.set_i(axis);
std::vector<std::string> outputs;
outputs.emplace_back(node.output(0));
BuildOperator(c2_op, "Gather", inputs, outputs, {arg_axis});
auto axis = onnx_node->attributes.get<int64_t>("axis", 0L);
if (axis == 0) {
BuildOperator(c2_op, "Gather", inputs, outputs);
} else if (axis == 1) {
BuildOperator(c2_op, "BatchGather", inputs, outputs);
} else {
CAFFE_THROW(
"Caffe2 only supports Gather with axis being 0 or 1, ",
"whereas axis is ",
axis);
}
return ret;
}

View File

@ -3,12 +3,7 @@ producer_name: "pytorch"
producer_version: "1.3"
graph {
node {
input: "0"
output: "1"
op_type: "Shape"
}
node {
output: "2"
op_type: "Constant"
attribute {
name: "value"
@ -20,8 +15,13 @@ graph {
}
}
node {
input: "1"
input: "0"
output: "2"
op_type: "Shape"
}
node {
input: "2"
input: "1"
output: "3"
op_type: "Gather"
attribute {

View File

@ -3,12 +3,7 @@ producer_name: "pytorch"
producer_version: "1.3"
graph {
node {
input: "0"
output: "1"
op_type: "Shape"
}
node {
output: "2"
op_type: "Constant"
attribute {
name: "value"
@ -20,8 +15,13 @@ graph {
}
}
node {
input: "1"
input: "0"
output: "2"
op_type: "Shape"
}
node {
input: "2"
input: "1"
output: "3"
op_type: "Gather"
attribute {
@ -31,12 +31,7 @@ graph {
}
}
node {
input: "0"
output: "4"
op_type: "Shape"
}
node {
output: "5"
op_type: "Constant"
attribute {
name: "value"
@ -48,8 +43,13 @@ graph {
}
}
node {
input: "4"
input: "0"
output: "5"
op_type: "Shape"
}
node {
input: "5"
input: "4"
output: "6"
op_type: "Gather"
attribute {

View File

@ -3,12 +3,7 @@ producer_name: "pytorch"
producer_version: "1.3"
graph {
node {
input: "input"
output: "1"
op_type: "Shape"
}
node {
output: "2"
op_type: "Constant"
attribute {
name: "value"
@ -20,8 +15,13 @@ graph {
}
}
node {
input: "1"
input: "input"
output: "2"
op_type: "Shape"
}
node {
input: "2"
input: "1"
output: "3"
op_type: "Gather"
attribute {
@ -74,12 +74,7 @@ graph {
op_type: "Floor"
}
node {
input: "input"
output: "9"
op_type: "Shape"
}
node {
output: "10"
op_type: "Constant"
attribute {
name: "value"
@ -91,8 +86,13 @@ graph {
}
}
node {
input: "9"
input: "input"
output: "10"
op_type: "Shape"
}
node {
input: "10"
input: "9"
output: "11"
op_type: "Gather"
attribute {

View File

@ -3,12 +3,7 @@ producer_name: "pytorch"
producer_version: "1.3"
graph {
node {
input: "0"
output: "1"
op_type: "Shape"
}
node {
output: "2"
op_type: "Constant"
attribute {
name: "value"
@ -20,8 +15,13 @@ graph {
}
}
node {
input: "1"
input: "0"
output: "2"
op_type: "Shape"
}
node {
input: "2"
input: "1"
output: "3"
op_type: "Gather"
attribute {
@ -31,12 +31,7 @@ graph {
}
}
node {
input: "0"
output: "4"
op_type: "Shape"
}
node {
output: "5"
op_type: "Constant"
attribute {
name: "value"
@ -48,8 +43,13 @@ graph {
}
}
node {
input: "4"
input: "0"
output: "5"
op_type: "Shape"
}
node {
input: "5"
input: "4"
output: "6"
op_type: "Gather"
attribute {

View File

@ -186,7 +186,7 @@ class TestONNXOpset(TestCase):
"attributes" : []}]
ops = {9 : ops_9, 10 : ops_10}
x = torch.randn(3)
check_onnx_opsets_operator(MyModule(), x, ops, opset_versions=[9])
check_onnx_opsets_operator(MyModule(), x, ops, opset_versions=[9, 10])
class DynamicSliceModel(torch.jit.ScriptModule):
@torch.jit.script_method
@ -196,7 +196,6 @@ class TestONNXOpset(TestCase):
ops_9 = [{"op_name" : "Constant"},
{"op_name" : "Constant"},
{"op_name" : "Shape"},
{"op_name": "Constant"},
{"op_name" : "Gather",
"attributes" : [{"name" : "axis", "i" : 0, "type" : 2}]},
{"op_name" : "Unsqueeze",
@ -209,7 +208,6 @@ class TestONNXOpset(TestCase):
ops_10 = [{"op_name" : "Constant"},
{"op_name" : "Constant"},
{"op_name" : "Shape"},
{"op_name": "Constant"},
{"op_name" : "Gather",
"attributes" : [{"name" : "axis", "i" : 0, "type" : 2}]},
{"op_name" : "Unsqueeze",
@ -286,11 +284,11 @@ class TestONNXOpset(TestCase):
return torch.nn.functional.interpolate(x,
size=size,
mode='nearest')
ops_9 = [{"op_name" : "Shape"},
{"op_name" : "Constant"},
{"op_name" : "Gather"},
ops_9 = [{"op_name" : "Constant"},
{"op_name" : "Shape"},
{"op_name" : "Gather"},
{"op_name" : "Constant"},
{"op_name" : "Shape"},
{"op_name" : "Gather"},
{"op_name" : "Constant"},
{"op_name" : "Mul"},
@ -309,11 +307,11 @@ class TestONNXOpset(TestCase):
{"op_name" : "Upsample",
"attributes" :
[{"name": "mode", "s": ("nearest").encode(), "type": 3}]}]
ops_10 = [{"op_name" : "Shape"},
{"op_name" : "Constant"},
{"op_name" : "Gather"},
ops_10 = [{"op_name" : "Constant"},
{"op_name" : "Shape"},
{"op_name" : "Gather"},
{"op_name" : "Constant"},
{"op_name" : "Shape"},
{"op_name" : "Gather"},
{"op_name" : "Constant"},
{"op_name" : "Mul"},

View File

@ -1306,6 +1306,7 @@ class TestCaffe2Backend_opset9(unittest.TestCase):
x = torch.randn(3, 4, 5, 6, 7)
self.run_model_test(NegSlice(), train=False, input=(x,), batch_size=BATCH_SIZE, use_gpu=False)
@unittest.skip('https://github.com/pytorch/pytorch/issues/10984')
@skipIfUnsupportedOpsetVersion([10])
def test_neg_slice_large_negone(self):
class NegSlice(torch.nn.Module):

View File

@ -345,11 +345,12 @@ class TestONNXRuntime(unittest.TestCase):
def test_slice_neg_large(self):
class NegSlice(torch.nn.Module):
def forward(self, x):
return x[:, :, -3:-1, :, -1]
return x[:, :, :, :, -3]
x = torch.randn(3, 4, 5, 6, 7)
self.run_test(NegSlice(), x)
@unittest.skip('https://github.com/pytorch/pytorch/issues/10984')
def test_slice_neg_large_negone(self):
class NegSlice(torch.nn.Module):
def forward(self, x):

View File

@ -394,13 +394,13 @@ def split_with_sizes(g, self, split_sizes, dim):
@parse_args('v', 'i', 'v')
def select(g, self, dim, index):
index = sym_help._maybe_get_scalar(index)
if (not sym_help._is_value(index)) and (index < 0):
if index == -1:
end_index = 9223372036854775807
else:
end_index = index + 1
slice_node = sym_help._slice_helper(g, self, axes=[dim], starts=[index], ends=[end_index])
if dim > 1:
# TODO: this is a temporary hack because of the implementation details
# of Gather in caffe2. We need to change this as soon as possible.
# TODO: this breaks if index == -1
index_val = _parse_arg(index, 'i')
slice_node = sym_help._slice_helper(g, self, axes=[dim],
starts=[index_val], ends=[index_val + 1])
return g.op("Squeeze", slice_node, axes_i=[dim])
else:
return g.op("Gather", self, index, axis_i=dim)