[ONNX] Add col2im for opset 18 (#84594)

Opset 18 will be used to introduce suport for ONNX's Col2Im-18 and resolve https://github.com/pytorch/pytorch/issues/84408

Depends: https://github.com/pytorch/pytorch/pull/83201 (CI will fail until ONNX submodule is updated)

as per Faith recommendation, this PR should be merged post ORT 1.13 only
Pull Request resolved: https://github.com/pytorch/pytorch/pull/84594
Approved by: https://github.com/justinchuby, https://github.com/titaiwangms, https://github.com/abock, https://github.com/BowenBao
This commit is contained in:
Thiago Crepaldi
2023-02-09 19:54:42 +00:00
committed by PyTorch MergeBot
parent ea98ba02e2
commit a63524684d
6 changed files with 111 additions and 3 deletions

View File

@ -1156,6 +1156,39 @@ class TestONNXExport(pytorch_test_common.ExportTestCase):
dim,
)
def test_col2im(self):
# This test can be moved to test/onnx/test_pytorch_onnx_onnxruntime.py when ORT implement ::Col2Im
# Random batched RGB 32x32 image-shaped input tensor of batch size 64
original_image_inputs = torch.randn((64, 3, 32, 32))
output_size = tuple(original_image_inputs.shape[2:])
kernel_size = (1, 2)
dilation = 3
padding = 2
stride = 1
model_im2col = torch.nn.Unfold(
kernel_size, dilation=dilation, padding=padding, stride=stride
)
blocks = model_im2col(original_image_inputs)
model = torch.nn.Fold(
output_size=output_size,
kernel_size=kernel_size,
dilation=dilation,
padding=padding,
stride=stride,
)
f = io.BytesIO()
torch.onnx.export(model, (blocks,), f, opset_version=18)
onnx_model = onnx.load(io.BytesIO(f.getvalue()))
self.assertEqual(onnx_model.graph.node[-1].op_type, "Col2Im")
self.assertEqual(onnx_model.graph.node[-1].domain, "")
self.assertEqual(len(onnx_model.graph.node[-1].input), 3)
self.assertEqual(onnx_model.graph.node[-1].attribute[0].name, "dilations")
self.assertEqual(onnx_model.graph.node[-1].attribute[1].name, "pads")
self.assertEqual(onnx_model.graph.node[-1].attribute[2].name, "strides")
class TestQuantizeEagerONNXExport(common_utils.TestCase):
def _test_lower_graph_impl(self, model, data):

View File

@ -44,7 +44,9 @@ from torch.testing._internal.common_utils import skipIfNoLapack
# The min onnx opset version to test for
MIN_ONNX_OPSET_VERSION = 9
# The max onnx opset version to test for
MAX_ONNX_OPSET_VERSION = _constants.ONNX_MAX_OPSET
MAX_ONNX_OPSET_VERSION = (
_constants.ONNX_MAX_OPSET - 1
) # TODO: ORT does not support opset 18 yet
def _init_test_generalized_rcnn_transform():

View File

@ -59,7 +59,7 @@ namespace onnx_torch = ::torch::onnx;
namespace onnx = ::ONNX_NAMESPACE;
const static int kInvalidOpsetVersion = -1;
const static int kMainOpsetVersion = 17;
const static int kMainOpsetVersion = 18;
// Based on OP_SET_ID_VERSION_MAP in
// https://github.com/onnx/onnx/blob/master/onnx/helper.py.
constexpr static std::array<int64_t, kMainOpsetVersion + 1>
@ -82,6 +82,7 @@ constexpr static std::array<int64_t, kMainOpsetVersion + 1>
8, // opset 15
8, // opset 16
8, // opset 17
8, // opset 18
};
std::string getNodeStackTraceString(const Node* n) {

View File

@ -25,6 +25,7 @@ from . import ( # usort:skip. Keep the order instead of sorting lexicographical
symbolic_opset15,
symbolic_opset16,
symbolic_opset17,
symbolic_opset18,
utils,
)
@ -62,6 +63,7 @@ __all__ = [
"symbolic_opset15",
"symbolic_opset16",
"symbolic_opset17",
"symbolic_opset18",
# Enums
"ExportTypes",
"OperatorExportTypes",

View File

@ -4,7 +4,7 @@ ONNX_ARCHIVE_MODEL_PROTO_NAME = "__MODEL_PROTO"
ONNX_BASE_OPSET = 9
ONNX_MIN_OPSET = 7
ONNX_MAX_OPSET = 17
ONNX_MAX_OPSET = 18
# ONNX_DEFAULT_OPSET generated by tools/onnx/update_default_opset_version.py
ONNX_DEFAULT_OPSET = 14
ONNX_CONSTANT_FOLDING_MIN_OPSET = 9

View File

@ -0,0 +1,70 @@
"""This file exports ONNX ops for opset 18.
Note [ONNX Operators that are added/updated in opset 18]
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
https://github.com/onnx/onnx/blob/main/docs/Changelog.md#version-18-of-the-default-onnx-operator-set
New operators:
CenterCropPad
Col2Im
Mish
OptionalGetElement
OptionalHasElement
Pad
Resize
ScatterElements
ScatterND
"""
import functools
from typing import Sequence
from torch import _C
from torch.onnx import symbolic_helper
from torch.onnx._internal import _beartype, registration
# EDITING THIS FILE? READ THIS FIRST!
# see Note [Edit Symbolic Files] in symbolic_helper.py
__all__ = ["col2im"]
_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=18)
@_onnx_symbolic("aten::col2im")
@symbolic_helper.parse_args("v", "v", "v", "is", "is", "is")
@_beartype.beartype
def col2im(
g,
input: _C.Value,
output_size: _C.Value,
kernel_size: _C.Value,
dilation: Sequence[int],
padding: Sequence[int],
stride: Sequence[int],
):
# convert [i0, i1, ..., in] into [i0, i0, i1, i1, ..., in, in]
adjusted_padding = []
for pad in padding:
for _ in range(2):
adjusted_padding.append(pad)
num_dimensional_axis = symbolic_helper._get_tensor_sizes(output_size)[0]
if not adjusted_padding:
adjusted_padding = [0, 0] * num_dimensional_axis
if not dilation:
dilation = [1] * num_dimensional_axis
if not stride:
stride = [1] * num_dimensional_axis
return g.op(
"Col2Im",
input,
output_size,
kernel_size,
dilations_i=dilation,
pads_i=adjusted_padding,
strides_i=stride,
)