mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
Revert D17159707: [pytorch][PR] [ONNX] Fixed Select symbolic to export slice when index = negative one
Test Plan: revert-hammer Differential Revision: D17159707 Original commit changeset: 2c3b27542108 fbshipit-source-id: accce910abdbe13270d0f592810a48b1dabe4b01
This commit is contained in:
committed by
Facebook Github Bot
parent
1b5df37441
commit
34662f77c6
@ -709,16 +709,20 @@ Caffe2Ops Caffe2Backend::CreateGather(
|
||||
std::vector<std::string> inputs;
|
||||
inputs.emplace_back(node.input(0));
|
||||
inputs.emplace_back(node.input(1));
|
||||
|
||||
auto axis = onnx_node->attributes.get<int64_t>("axis", 0L);
|
||||
caffe2::Argument arg_axis;
|
||||
arg_axis.set_name("axis");
|
||||
arg_axis.set_i(axis);
|
||||
|
||||
std::vector<std::string> outputs;
|
||||
outputs.emplace_back(node.output(0));
|
||||
|
||||
BuildOperator(c2_op, "Gather", inputs, outputs, {arg_axis});
|
||||
auto axis = onnx_node->attributes.get<int64_t>("axis", 0L);
|
||||
if (axis == 0) {
|
||||
BuildOperator(c2_op, "Gather", inputs, outputs);
|
||||
} else if (axis == 1) {
|
||||
BuildOperator(c2_op, "BatchGather", inputs, outputs);
|
||||
} else {
|
||||
CAFFE_THROW(
|
||||
"Caffe2 only supports Gather with axis being 0 or 1, ",
|
||||
"whereas axis is ",
|
||||
axis);
|
||||
}
|
||||
|
||||
return ret;
|
||||
}
|
||||
|
@ -3,12 +3,7 @@ producer_name: "pytorch"
|
||||
producer_version: "1.3"
|
||||
graph {
|
||||
node {
|
||||
input: "0"
|
||||
output: "1"
|
||||
op_type: "Shape"
|
||||
}
|
||||
node {
|
||||
output: "2"
|
||||
op_type: "Constant"
|
||||
attribute {
|
||||
name: "value"
|
||||
@ -20,8 +15,13 @@ graph {
|
||||
}
|
||||
}
|
||||
node {
|
||||
input: "1"
|
||||
input: "0"
|
||||
output: "2"
|
||||
op_type: "Shape"
|
||||
}
|
||||
node {
|
||||
input: "2"
|
||||
input: "1"
|
||||
output: "3"
|
||||
op_type: "Gather"
|
||||
attribute {
|
||||
|
@ -3,12 +3,7 @@ producer_name: "pytorch"
|
||||
producer_version: "1.3"
|
||||
graph {
|
||||
node {
|
||||
input: "0"
|
||||
output: "1"
|
||||
op_type: "Shape"
|
||||
}
|
||||
node {
|
||||
output: "2"
|
||||
op_type: "Constant"
|
||||
attribute {
|
||||
name: "value"
|
||||
@ -20,8 +15,13 @@ graph {
|
||||
}
|
||||
}
|
||||
node {
|
||||
input: "1"
|
||||
input: "0"
|
||||
output: "2"
|
||||
op_type: "Shape"
|
||||
}
|
||||
node {
|
||||
input: "2"
|
||||
input: "1"
|
||||
output: "3"
|
||||
op_type: "Gather"
|
||||
attribute {
|
||||
@ -31,12 +31,7 @@ graph {
|
||||
}
|
||||
}
|
||||
node {
|
||||
input: "0"
|
||||
output: "4"
|
||||
op_type: "Shape"
|
||||
}
|
||||
node {
|
||||
output: "5"
|
||||
op_type: "Constant"
|
||||
attribute {
|
||||
name: "value"
|
||||
@ -48,8 +43,13 @@ graph {
|
||||
}
|
||||
}
|
||||
node {
|
||||
input: "4"
|
||||
input: "0"
|
||||
output: "5"
|
||||
op_type: "Shape"
|
||||
}
|
||||
node {
|
||||
input: "5"
|
||||
input: "4"
|
||||
output: "6"
|
||||
op_type: "Gather"
|
||||
attribute {
|
||||
|
@ -3,12 +3,7 @@ producer_name: "pytorch"
|
||||
producer_version: "1.3"
|
||||
graph {
|
||||
node {
|
||||
input: "input"
|
||||
output: "1"
|
||||
op_type: "Shape"
|
||||
}
|
||||
node {
|
||||
output: "2"
|
||||
op_type: "Constant"
|
||||
attribute {
|
||||
name: "value"
|
||||
@ -20,8 +15,13 @@ graph {
|
||||
}
|
||||
}
|
||||
node {
|
||||
input: "1"
|
||||
input: "input"
|
||||
output: "2"
|
||||
op_type: "Shape"
|
||||
}
|
||||
node {
|
||||
input: "2"
|
||||
input: "1"
|
||||
output: "3"
|
||||
op_type: "Gather"
|
||||
attribute {
|
||||
@ -74,12 +74,7 @@ graph {
|
||||
op_type: "Floor"
|
||||
}
|
||||
node {
|
||||
input: "input"
|
||||
output: "9"
|
||||
op_type: "Shape"
|
||||
}
|
||||
node {
|
||||
output: "10"
|
||||
op_type: "Constant"
|
||||
attribute {
|
||||
name: "value"
|
||||
@ -91,8 +86,13 @@ graph {
|
||||
}
|
||||
}
|
||||
node {
|
||||
input: "9"
|
||||
input: "input"
|
||||
output: "10"
|
||||
op_type: "Shape"
|
||||
}
|
||||
node {
|
||||
input: "10"
|
||||
input: "9"
|
||||
output: "11"
|
||||
op_type: "Gather"
|
||||
attribute {
|
||||
|
@ -3,12 +3,7 @@ producer_name: "pytorch"
|
||||
producer_version: "1.3"
|
||||
graph {
|
||||
node {
|
||||
input: "0"
|
||||
output: "1"
|
||||
op_type: "Shape"
|
||||
}
|
||||
node {
|
||||
output: "2"
|
||||
op_type: "Constant"
|
||||
attribute {
|
||||
name: "value"
|
||||
@ -20,8 +15,13 @@ graph {
|
||||
}
|
||||
}
|
||||
node {
|
||||
input: "1"
|
||||
input: "0"
|
||||
output: "2"
|
||||
op_type: "Shape"
|
||||
}
|
||||
node {
|
||||
input: "2"
|
||||
input: "1"
|
||||
output: "3"
|
||||
op_type: "Gather"
|
||||
attribute {
|
||||
@ -31,12 +31,7 @@ graph {
|
||||
}
|
||||
}
|
||||
node {
|
||||
input: "0"
|
||||
output: "4"
|
||||
op_type: "Shape"
|
||||
}
|
||||
node {
|
||||
output: "5"
|
||||
op_type: "Constant"
|
||||
attribute {
|
||||
name: "value"
|
||||
@ -48,8 +43,13 @@ graph {
|
||||
}
|
||||
}
|
||||
node {
|
||||
input: "4"
|
||||
input: "0"
|
||||
output: "5"
|
||||
op_type: "Shape"
|
||||
}
|
||||
node {
|
||||
input: "5"
|
||||
input: "4"
|
||||
output: "6"
|
||||
op_type: "Gather"
|
||||
attribute {
|
||||
|
@ -186,7 +186,7 @@ class TestONNXOpset(TestCase):
|
||||
"attributes" : []}]
|
||||
ops = {9 : ops_9, 10 : ops_10}
|
||||
x = torch.randn(3)
|
||||
check_onnx_opsets_operator(MyModule(), x, ops, opset_versions=[9])
|
||||
check_onnx_opsets_operator(MyModule(), x, ops, opset_versions=[9, 10])
|
||||
|
||||
class DynamicSliceModel(torch.jit.ScriptModule):
|
||||
@torch.jit.script_method
|
||||
@ -196,7 +196,6 @@ class TestONNXOpset(TestCase):
|
||||
ops_9 = [{"op_name" : "Constant"},
|
||||
{"op_name" : "Constant"},
|
||||
{"op_name" : "Shape"},
|
||||
{"op_name": "Constant"},
|
||||
{"op_name" : "Gather",
|
||||
"attributes" : [{"name" : "axis", "i" : 0, "type" : 2}]},
|
||||
{"op_name" : "Unsqueeze",
|
||||
@ -209,7 +208,6 @@ class TestONNXOpset(TestCase):
|
||||
ops_10 = [{"op_name" : "Constant"},
|
||||
{"op_name" : "Constant"},
|
||||
{"op_name" : "Shape"},
|
||||
{"op_name": "Constant"},
|
||||
{"op_name" : "Gather",
|
||||
"attributes" : [{"name" : "axis", "i" : 0, "type" : 2}]},
|
||||
{"op_name" : "Unsqueeze",
|
||||
@ -286,11 +284,11 @@ class TestONNXOpset(TestCase):
|
||||
return torch.nn.functional.interpolate(x,
|
||||
size=size,
|
||||
mode='nearest')
|
||||
ops_9 = [{"op_name" : "Shape"},
|
||||
{"op_name" : "Constant"},
|
||||
{"op_name" : "Gather"},
|
||||
ops_9 = [{"op_name" : "Constant"},
|
||||
{"op_name" : "Shape"},
|
||||
{"op_name" : "Gather"},
|
||||
{"op_name" : "Constant"},
|
||||
{"op_name" : "Shape"},
|
||||
{"op_name" : "Gather"},
|
||||
{"op_name" : "Constant"},
|
||||
{"op_name" : "Mul"},
|
||||
@ -309,11 +307,11 @@ class TestONNXOpset(TestCase):
|
||||
{"op_name" : "Upsample",
|
||||
"attributes" :
|
||||
[{"name": "mode", "s": ("nearest").encode(), "type": 3}]}]
|
||||
ops_10 = [{"op_name" : "Shape"},
|
||||
{"op_name" : "Constant"},
|
||||
{"op_name" : "Gather"},
|
||||
ops_10 = [{"op_name" : "Constant"},
|
||||
{"op_name" : "Shape"},
|
||||
{"op_name" : "Gather"},
|
||||
{"op_name" : "Constant"},
|
||||
{"op_name" : "Shape"},
|
||||
{"op_name" : "Gather"},
|
||||
{"op_name" : "Constant"},
|
||||
{"op_name" : "Mul"},
|
||||
|
@ -1306,6 +1306,7 @@ class TestCaffe2Backend_opset9(unittest.TestCase):
|
||||
x = torch.randn(3, 4, 5, 6, 7)
|
||||
self.run_model_test(NegSlice(), train=False, input=(x,), batch_size=BATCH_SIZE, use_gpu=False)
|
||||
|
||||
@unittest.skip('https://github.com/pytorch/pytorch/issues/10984')
|
||||
@skipIfUnsupportedOpsetVersion([10])
|
||||
def test_neg_slice_large_negone(self):
|
||||
class NegSlice(torch.nn.Module):
|
||||
|
@ -345,11 +345,12 @@ class TestONNXRuntime(unittest.TestCase):
|
||||
def test_slice_neg_large(self):
|
||||
class NegSlice(torch.nn.Module):
|
||||
def forward(self, x):
|
||||
return x[:, :, -3:-1, :, -1]
|
||||
return x[:, :, :, :, -3]
|
||||
|
||||
x = torch.randn(3, 4, 5, 6, 7)
|
||||
self.run_test(NegSlice(), x)
|
||||
|
||||
@unittest.skip('https://github.com/pytorch/pytorch/issues/10984')
|
||||
def test_slice_neg_large_negone(self):
|
||||
class NegSlice(torch.nn.Module):
|
||||
def forward(self, x):
|
||||
|
@ -394,13 +394,13 @@ def split_with_sizes(g, self, split_sizes, dim):
|
||||
|
||||
@parse_args('v', 'i', 'v')
|
||||
def select(g, self, dim, index):
|
||||
index = sym_help._maybe_get_scalar(index)
|
||||
if (not sym_help._is_value(index)) and (index < 0):
|
||||
if index == -1:
|
||||
end_index = 9223372036854775807
|
||||
else:
|
||||
end_index = index + 1
|
||||
slice_node = sym_help._slice_helper(g, self, axes=[dim], starts=[index], ends=[end_index])
|
||||
if dim > 1:
|
||||
# TODO: this is a temporary hack because of the implementation details
|
||||
# of Gather in caffe2. We need to change this as soon as possible.
|
||||
# TODO: this breaks if index == -1
|
||||
index_val = _parse_arg(index, 'i')
|
||||
slice_node = sym_help._slice_helper(g, self, axes=[dim],
|
||||
starts=[index_val], ends=[index_val + 1])
|
||||
return g.op("Squeeze", slice_node, axes_i=[dim])
|
||||
else:
|
||||
return g.op("Gather", self, index, axis_i=dim)
|
||||
|
Reference in New Issue
Block a user