mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/51521 * Add loop & if node to the list of nodes that could produce sequence type output. * Switch from `[]` to `at()` to avoid segfault of out of range access. Test Plan: Imported from OSS Reviewed By: pbelevich Differential Revision: D26203112 Pulled By: SplitInfinity fbshipit-source-id: e990eeed933124b195be0be159271e33fb485063
This commit is contained in:
committed by
Facebook GitHub Bot
parent
3cc46002a3
commit
586c2e8d62
@ -83,7 +83,6 @@ if [[ "$BUILD_ENVIRONMENT" == *ort_test2* ]]; then
|
||||
"$top_dir/test/onnx/test_pytorch_onnx_onnxruntime.py::TestONNXRuntime_opset$i"
|
||||
done
|
||||
pytest "${args[@]}" \
|
||||
"$top_dir/test/onnx/test_pytorch_onnx_onnxruntime.py::TestONNXRuntime_opset12_onnx_shape_inference" \
|
||||
"$top_dir/test/onnx/test_pytorch_onnx_onnxruntime.py::TestONNXRuntime_opset12_IRv4_old_jit_API"
|
||||
fi
|
||||
|
||||
|
@ -18,6 +18,9 @@ from test_pytorch_common import BATCH_SIZE
|
||||
from test_pytorch_common import RNN_BATCH_SIZE, RNN_SEQUENCE_LENGTH, RNN_INPUT_SIZE, RNN_HIDDEN_SIZE
|
||||
from typing import List, Tuple, Optional
|
||||
import model_defs.word_language_model as word_language_model
|
||||
|
||||
import onnx
|
||||
|
||||
import torchvision
|
||||
from torchvision import ops
|
||||
from torchvision.models.detection.image_list import ImageList
|
||||
@ -26,7 +29,6 @@ from torchvision.models.detection.rpn import AnchorGenerator, RPNHead, RegionPro
|
||||
from torchvision.models.detection.roi_heads import RoIHeads
|
||||
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor, TwoMLPHead
|
||||
from collections import OrderedDict
|
||||
import onnx
|
||||
|
||||
def to_numpy(tensor):
|
||||
if tensor.requires_grad:
|
||||
@ -3876,7 +3878,6 @@ class TestONNXRuntime(unittest.TestCase):
|
||||
inputs = torch.zeros(1, 2, 3, dtype=torch.long)
|
||||
self.run_test(model, inputs)
|
||||
|
||||
@skipIfUnsupportedOpsetVersion([13])
|
||||
@skipIfUnsupportedMinOpsetVersion(11)
|
||||
def test_loop_with_list(self):
|
||||
class ListLoopModel(torch.jit.ScriptModule):
|
||||
@ -6063,7 +6064,6 @@ class TestONNXRuntime(unittest.TestCase):
|
||||
convert_to_onnx(model, input=(box_regression, proposal),
|
||||
example_outputs=outputs, use_new_jit_passes=True)
|
||||
|
||||
@skipIfUnsupportedOpsetVersion([13])
|
||||
def test_initializer_sequence(self):
|
||||
class MyModule(torch.nn.Module):
|
||||
def __init__(self, input_size, hidden_size, num_classes):
|
||||
@ -6681,12 +6681,5 @@ TestONNXRuntime_opset12_IRv4_old_jit_API = type(str("TestONNXRuntime_opset12_IRv
|
||||
keep_initializers_as_inputs=False,
|
||||
use_new_jit_passes=False))
|
||||
|
||||
|
||||
# opset 12 tests, with _onnx_shape_inference=True.
|
||||
TestONNXRuntime_opset12_onnx_shape_inference = type(str("TestONNXRuntime_opset12_onnx_shape_inference"),
|
||||
(unittest.TestCase,),
|
||||
dict(TestONNXRuntime.__dict__, opset_version=12,
|
||||
onnx_shape_inference=True))
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
@ -586,7 +586,8 @@ bool HasSequenceTypeOutput(Node* node) {
|
||||
node->kind() == ::c10::onnx::SequenceInsert ||
|
||||
node->kind() == ::c10::onnx::SequenceEmpty ||
|
||||
node->kind() == ::c10::onnx::SequenceErase ||
|
||||
node->kind() == ::c10::onnx::SequenceConstruct)
|
||||
node->kind() == ::c10::onnx::SequenceConstruct ||
|
||||
node->kind() == ::c10::onnx::Loop || node->kind() == ::c10::onnx::If)
|
||||
return true;
|
||||
return false;
|
||||
}
|
||||
@ -618,7 +619,7 @@ void ONNXAssignOutputShape(
|
||||
|
||||
if (PyList_Check(elem)) {
|
||||
size_t list_len = PyList_GET_SIZE(elem);
|
||||
if (HasSequenceTypeOutput(graph->outputs()[outputs_index]->node())) {
|
||||
if (HasSequenceTypeOutput(graph->outputs().at(outputs_index)->node())) {
|
||||
if (list_len > 0) {
|
||||
auto& var =
|
||||
reinterpret_cast<THPVariable*>(PyList_GET_ITEM(elem, 0))->cdata;
|
||||
@ -630,15 +631,18 @@ void ONNXAssignOutputShape(
|
||||
var.scalar_type() == new_var.scalar_type(),
|
||||
"Unsupported sequence type in model outputs. ONNX supports sequences of elements of the same data type.");
|
||||
}
|
||||
auto elem_type = graph->outputs()[outputs_index]
|
||||
auto elem_type = graph->outputs()
|
||||
.at(outputs_index)
|
||||
->type()
|
||||
->castRaw<ListType>()
|
||||
->getElementType()
|
||||
->cast<TensorType>();
|
||||
elem_type = elem_type->withScalarType(var.scalar_type());
|
||||
graph->outputs()[outputs_index]->setType(MergeInferredType(
|
||||
graph->outputs()[outputs_index]->type(),
|
||||
ListType::create(elem_type)));
|
||||
graph->outputs()
|
||||
.at(outputs_index)
|
||||
->setType(MergeInferredType(
|
||||
graph->outputs().at(outputs_index)->type(),
|
||||
ListType::create(elem_type)));
|
||||
outputs_index++;
|
||||
TORCH_INTERNAL_ASSERT(
|
||||
outputs_index <= graph->outputs().size(),
|
||||
@ -652,9 +656,11 @@ void ONNXAssignOutputShape(
|
||||
PyObject* list_elem = PyList_GET_ITEM(elem, j);
|
||||
TORCH_INTERNAL_ASSERT(THPVariable_Check(list_elem));
|
||||
auto& var = reinterpret_cast<THPVariable*>(list_elem)->cdata;
|
||||
graph->outputs()[outputs_index + j]->setType(MergeInferredType(
|
||||
graph->outputs()[outputs_index + j]->type(),
|
||||
TensorType::create(var)));
|
||||
graph->outputs()
|
||||
.at(outputs_index + j)
|
||||
->setType(MergeInferredType(
|
||||
graph->outputs().at(outputs_index + j)->type(),
|
||||
TensorType::create(var)));
|
||||
}
|
||||
outputs_index += list_len;
|
||||
TORCH_INTERNAL_ASSERT(
|
||||
@ -669,9 +675,11 @@ void ONNXAssignOutputShape(
|
||||
PyObject* tuple_elem = PyTuple_GET_ITEM(elem, j);
|
||||
TORCH_INTERNAL_ASSERT(THPVariable_Check(tuple_elem));
|
||||
auto& var = reinterpret_cast<THPVariable*>(tuple_elem)->cdata;
|
||||
graph->outputs()[outputs_index + j]->setType(MergeInferredType(
|
||||
graph->outputs()[outputs_index + j]->type(),
|
||||
TensorType::create(var)));
|
||||
graph->outputs()
|
||||
.at(outputs_index + j)
|
||||
->setType(MergeInferredType(
|
||||
graph->outputs().at(outputs_index + j)->type(),
|
||||
TensorType::create(var)));
|
||||
}
|
||||
outputs_index += tuple_len;
|
||||
TORCH_INTERNAL_ASSERT(
|
||||
@ -681,7 +689,7 @@ void ONNXAssignOutputShape(
|
||||
} else if (THPVariable_Check(elem)) {
|
||||
at::Tensor var = reinterpret_cast<THPVariable*>(elem)->cdata;
|
||||
ONNXUpdateTypeFromTensor(
|
||||
graph->outputs()[outputs_index], var, onnx_shape_inference);
|
||||
graph->outputs().at(outputs_index), var, onnx_shape_inference);
|
||||
outputs_index++;
|
||||
TORCH_INTERNAL_ASSERT(
|
||||
outputs_index <= graph->outputs().size(),
|
||||
@ -700,9 +708,11 @@ void ONNXAssignOutputShape(
|
||||
auto& var =
|
||||
reinterpret_cast<THPVariable*>(PyTuple_GET_ITEM(tuple_elem, 1))
|
||||
->cdata;
|
||||
graph->outputs()[outputs_index + j]->setType(MergeInferredType(
|
||||
graph->outputs()[outputs_index + j]->type(),
|
||||
TensorType::create(var)));
|
||||
graph->outputs()
|
||||
.at(outputs_index + j)
|
||||
->setType(MergeInferredType(
|
||||
graph->outputs().at(outputs_index + j)->type(),
|
||||
TensorType::create(var)));
|
||||
}
|
||||
outputs_index += unrolled_dict.size();
|
||||
TORCH_INTERNAL_ASSERT(
|
||||
|
Reference in New Issue
Block a user