[ONNX] Fix graph sequence output from loop node (#51305) (#51521)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/51521

* Add `Loop` and `If` nodes to the list of node kinds that can produce sequence type outputs (see the repro sketch below).
* Switch from `[]` to `at()` when indexing graph outputs, so that an out-of-range index throws an exception instead of segfaulting through unchecked access.
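
For context, a minimal repro sketch of the kind of model this fixes (the model, names, and shapes are illustrative assumptions, not the PR's actual test case): a scripted loop builds a tensor list, which the exporter lowers to an onnx::Loop node with a sequence-typed output. Previously HasSequenceTypeOutput did not recognize Loop or If, so the output bookkeeping in ONNXAssignOutputShape could walk past the end of graph->outputs(), which the unchecked `[]` access could turn into a segfault.

    import io
    from typing import List

    import torch

    class LoopListModel(torch.nn.Module):
        def forward(self, x) -> List[torch.Tensor]:
            res: List[torch.Tensor] = []
            for i in range(x.size(0)):
                # Each append lowers to onnx::SequenceInsert inside the loop
                # body, so the graph output is produced by an onnx::Loop node.
                res.append(x[i])
            return res

    model = torch.jit.script(LoopListModel())
    x = torch.randn(4, 3)
    f = io.BytesIO()
    # ONNX Sequence ops require opset >= 11; with the torch-1.8-era exporter
    # API, scripted modules also needed example_outputs.
    torch.onnx.export(model, (x,), f, opset_version=11, example_outputs=model(x))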

Test Plan: Imported from OSS

Reviewed By: pbelevich

Differential Revision: D26203112

Pulled By: SplitInfinity

fbshipit-source-id: e990eeed933124b195be0be159271e33fb485063
Author: BowenBao
Date: 2021-02-04 12:35:27 -08:00
Committed by: Facebook GitHub Bot
Parent: 3cc46002a3
Commit: 586c2e8d62

3 changed files with 29 additions and 27 deletions


@@ -83,7 +83,6 @@ if [[ "$BUILD_ENVIRONMENT" == *ort_test2* ]]; then
       "$top_dir/test/onnx/test_pytorch_onnx_onnxruntime.py::TestONNXRuntime_opset$i"
   done
   pytest "${args[@]}" \
-    "$top_dir/test/onnx/test_pytorch_onnx_onnxruntime.py::TestONNXRuntime_opset12_onnx_shape_inference" \
     "$top_dir/test/onnx/test_pytorch_onnx_onnxruntime.py::TestONNXRuntime_opset12_IRv4_old_jit_API"
 fi


@@ -18,6 +18,9 @@ from test_pytorch_common import BATCH_SIZE
 from test_pytorch_common import RNN_BATCH_SIZE, RNN_SEQUENCE_LENGTH, RNN_INPUT_SIZE, RNN_HIDDEN_SIZE
 from typing import List, Tuple, Optional
 import model_defs.word_language_model as word_language_model
+import onnx
 import torchvision
 from torchvision import ops
 from torchvision.models.detection.image_list import ImageList
@@ -26,7 +29,6 @@ from torchvision.models.detection.rpn import AnchorGenerator, RPNHead, RegionPro
 from torchvision.models.detection.roi_heads import RoIHeads
 from torchvision.models.detection.faster_rcnn import FastRCNNPredictor, TwoMLPHead
 from collections import OrderedDict
-import onnx

 def to_numpy(tensor):
     if tensor.requires_grad:
@@ -3876,7 +3878,6 @@ class TestONNXRuntime(unittest.TestCase):
         inputs = torch.zeros(1, 2, 3, dtype=torch.long)
         self.run_test(model, inputs)

-    @skipIfUnsupportedOpsetVersion([13])
     @skipIfUnsupportedMinOpsetVersion(11)
     def test_loop_with_list(self):
         class ListLoopModel(torch.jit.ScriptModule):
@@ -6063,7 +6064,6 @@ class TestONNXRuntime(unittest.TestCase):
         convert_to_onnx(model, input=(box_regression, proposal),
                         example_outputs=outputs, use_new_jit_passes=True)

-    @skipIfUnsupportedOpsetVersion([13])
     def test_initializer_sequence(self):
         class MyModule(torch.nn.Module):
             def __init__(self, input_size, hidden_size, num_classes):
@@ -6681,12 +6681,5 @@ TestONNXRuntime_opset12_IRv4_old_jit_API = type(str("TestONNXRuntime_opset12_IRv
                                                  keep_initializers_as_inputs=False,
                                                  use_new_jit_passes=False))

-# opset 12 tests, with _onnx_shape_inference=True.
-TestONNXRuntime_opset12_onnx_shape_inference = type(str("TestONNXRuntime_opset12_onnx_shape_inference"),
-                                                    (unittest.TestCase,),
-                                                    dict(TestONNXRuntime.__dict__, opset_version=12,
-                                                         onnx_shape_inference=True))
-
 if __name__ == '__main__':
     unittest.main()


@@ -586,7 +586,8 @@ bool HasSequenceTypeOutput(Node* node) {
       node->kind() == ::c10::onnx::SequenceInsert ||
       node->kind() == ::c10::onnx::SequenceEmpty ||
       node->kind() == ::c10::onnx::SequenceErase ||
-      node->kind() == ::c10::onnx::SequenceConstruct)
+      node->kind() == ::c10::onnx::SequenceConstruct ||
+      node->kind() == ::c10::onnx::Loop || node->kind() == ::c10::onnx::If)
     return true;
   return false;
 }
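
The same applies to onnx::If: a scripted data-dependent conditional whose branches both return a tensor list yields an If node with a sequence-typed output. A hedged sketch (model and shapes are assumptions for illustration, not from the PR):

    from typing import List

    import torch

    class IfListModel(torch.nn.Module):
        def forward(self, x) -> List[torch.Tensor]:
            # A data-dependent branch becomes an onnx::If node under scripting;
            # both branches return List[Tensor], so the If output is a sequence.
            if bool(x.sum() > 0):
                return [x, x + 1]
            return [x - 1]

    scripted = torch.jit.script(IfListModel())
    print(scripted(torch.randn(2, 2)))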
@@ -618,7 +619,7 @@ void ONNXAssignOutputShape(
     if (PyList_Check(elem)) {
       size_t list_len = PyList_GET_SIZE(elem);
-      if (HasSequenceTypeOutput(graph->outputs()[outputs_index]->node())) {
+      if (HasSequenceTypeOutput(graph->outputs().at(outputs_index)->node())) {
         if (list_len > 0) {
           auto& var =
               reinterpret_cast<THPVariable*>(PyList_GET_ITEM(elem, 0))->cdata;
@@ -630,15 +631,18 @@ void ONNXAssignOutputShape(
                 var.scalar_type() == new_var.scalar_type(),
                 "Unsupported sequence type in model outputs. ONNX supports sequences of elements of the same data type.");
           }
-          auto elem_type = graph->outputs()[outputs_index]
+          auto elem_type = graph->outputs()
+                               .at(outputs_index)
                                ->type()
                                ->castRaw<ListType>()
                                ->getElementType()
                                ->cast<TensorType>();
           elem_type = elem_type->withScalarType(var.scalar_type());
-          graph->outputs()[outputs_index]->setType(MergeInferredType(
-              graph->outputs()[outputs_index]->type(),
-              ListType::create(elem_type)));
+          graph->outputs()
+              .at(outputs_index)
+              ->setType(MergeInferredType(
+                  graph->outputs().at(outputs_index)->type(),
+                  ListType::create(elem_type)));
           outputs_index++;
           TORCH_INTERNAL_ASSERT(
               outputs_index <= graph->outputs().size(),
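
Aside: the TORCH_CHECK in the hunk above is the guard that rejects heterogeneous sequences, since an ONNX sequence must hold elements of a single data type. A sketch of a model (assumed, for illustration) that is expected to trip that check at export time:

    from typing import List

    import torch

    class MixedDtypeListModel(torch.nn.Module):
        def forward(self, x) -> List[torch.Tensor]:
            # float32 and int64 elements in one output list; exporting this
            # should fail with "Unsupported sequence type in model outputs."
            return [x.float(), x.long()]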
@@ -652,9 +656,11 @@ void ONNXAssignOutputShape(
           PyObject* list_elem = PyList_GET_ITEM(elem, j);
           TORCH_INTERNAL_ASSERT(THPVariable_Check(list_elem));
           auto& var = reinterpret_cast<THPVariable*>(list_elem)->cdata;
-          graph->outputs()[outputs_index + j]->setType(MergeInferredType(
-              graph->outputs()[outputs_index + j]->type(),
-              TensorType::create(var)));
+          graph->outputs()
+              .at(outputs_index + j)
+              ->setType(MergeInferredType(
+                  graph->outputs().at(outputs_index + j)->type(),
+                  TensorType::create(var)));
         }
         outputs_index += list_len;
         TORCH_INTERNAL_ASSERT(
@@ -669,9 +675,11 @@ void ONNXAssignOutputShape(
         PyObject* tuple_elem = PyTuple_GET_ITEM(elem, j);
         TORCH_INTERNAL_ASSERT(THPVariable_Check(tuple_elem));
         auto& var = reinterpret_cast<THPVariable*>(tuple_elem)->cdata;
-        graph->outputs()[outputs_index + j]->setType(MergeInferredType(
-            graph->outputs()[outputs_index + j]->type(),
-            TensorType::create(var)));
+        graph->outputs()
+            .at(outputs_index + j)
+            ->setType(MergeInferredType(
+                graph->outputs().at(outputs_index + j)->type(),
+                TensorType::create(var)));
       }
       outputs_index += tuple_len;
       TORCH_INTERNAL_ASSERT(
@@ -681,7 +689,7 @@ void ONNXAssignOutputShape(
     } else if (THPVariable_Check(elem)) {
       at::Tensor var = reinterpret_cast<THPVariable*>(elem)->cdata;
       ONNXUpdateTypeFromTensor(
-          graph->outputs()[outputs_index], var, onnx_shape_inference);
+          graph->outputs().at(outputs_index), var, onnx_shape_inference);
       outputs_index++;
       TORCH_INTERNAL_ASSERT(
           outputs_index <= graph->outputs().size(),
@@ -700,9 +708,11 @@ void ONNXAssignOutputShape(
         auto& var =
             reinterpret_cast<THPVariable*>(PyTuple_GET_ITEM(tuple_elem, 1))
                 ->cdata;
-        graph->outputs()[outputs_index + j]->setType(MergeInferredType(
-            graph->outputs()[outputs_index + j]->type(),
-            TensorType::create(var)));
+        graph->outputs()
+            .at(outputs_index + j)
+            ->setType(MergeInferredType(
+                graph->outputs().at(outputs_index + j)->type(),
+                TensorType::create(var)));
       }
       outputs_index += unrolled_dict.size();
       TORCH_INTERNAL_ASSERT(
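
The final hunk is the dictionary-output path: the exporter unrolls a dict output into (key, value) pairs and merges only each value's TensorType into the corresponding graph output. A sketch of a model that exercises this path (assumed for illustration; the dict keys themselves are not represented in the exported graph):

    from typing import Dict

    import torch

    class DictOutputModel(torch.nn.Module):
        def forward(self, x) -> Dict[str, torch.Tensor]:
            # Exported as two tensor outputs; each value's TensorType is
            # merged into the matching graph output via MergeInferredType.
            return {"double": x + x, "square": x * x}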