mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[ONNX] Add col2im for opset 18 (#84594)
Opset 18 will be used to introduce suport for ONNX's Col2Im-18 and resolve https://github.com/pytorch/pytorch/issues/84408 Depends: https://github.com/pytorch/pytorch/pull/83201 (CI will fail until ONNX submodule is updated) as per Faith recommendation, this PR should be merged post ORT 1.13 only Pull Request resolved: https://github.com/pytorch/pytorch/pull/84594 Approved by: https://github.com/justinchuby, https://github.com/titaiwangms, https://github.com/abock, https://github.com/BowenBao
This commit is contained in:
committed by
PyTorch MergeBot
parent
ea98ba02e2
commit
a63524684d
@ -1156,6 +1156,39 @@ class TestONNXExport(pytorch_test_common.ExportTestCase):
|
||||
dim,
|
||||
)
|
||||
|
||||
def test_col2im(self):
|
||||
# This test can be moved to test/onnx/test_pytorch_onnx_onnxruntime.py when ORT implement ::Col2Im
|
||||
|
||||
# Random batched RGB 32x32 image-shaped input tensor of batch size 64
|
||||
original_image_inputs = torch.randn((64, 3, 32, 32))
|
||||
output_size = tuple(original_image_inputs.shape[2:])
|
||||
kernel_size = (1, 2)
|
||||
dilation = 3
|
||||
padding = 2
|
||||
stride = 1
|
||||
model_im2col = torch.nn.Unfold(
|
||||
kernel_size, dilation=dilation, padding=padding, stride=stride
|
||||
)
|
||||
blocks = model_im2col(original_image_inputs)
|
||||
|
||||
model = torch.nn.Fold(
|
||||
output_size=output_size,
|
||||
kernel_size=kernel_size,
|
||||
dilation=dilation,
|
||||
padding=padding,
|
||||
stride=stride,
|
||||
)
|
||||
f = io.BytesIO()
|
||||
torch.onnx.export(model, (blocks,), f, opset_version=18)
|
||||
|
||||
onnx_model = onnx.load(io.BytesIO(f.getvalue()))
|
||||
self.assertEqual(onnx_model.graph.node[-1].op_type, "Col2Im")
|
||||
self.assertEqual(onnx_model.graph.node[-1].domain, "")
|
||||
self.assertEqual(len(onnx_model.graph.node[-1].input), 3)
|
||||
self.assertEqual(onnx_model.graph.node[-1].attribute[0].name, "dilations")
|
||||
self.assertEqual(onnx_model.graph.node[-1].attribute[1].name, "pads")
|
||||
self.assertEqual(onnx_model.graph.node[-1].attribute[2].name, "strides")
|
||||
|
||||
|
||||
class TestQuantizeEagerONNXExport(common_utils.TestCase):
|
||||
def _test_lower_graph_impl(self, model, data):
|
||||
|
@ -44,7 +44,9 @@ from torch.testing._internal.common_utils import skipIfNoLapack
|
||||
# The min onnx opset version to test for
|
||||
MIN_ONNX_OPSET_VERSION = 9
|
||||
# The max onnx opset version to test for
|
||||
MAX_ONNX_OPSET_VERSION = _constants.ONNX_MAX_OPSET
|
||||
MAX_ONNX_OPSET_VERSION = (
|
||||
_constants.ONNX_MAX_OPSET - 1
|
||||
) # TODO: ORT does not support opset 18 yet
|
||||
|
||||
|
||||
def _init_test_generalized_rcnn_transform():
|
||||
|
@ -59,7 +59,7 @@ namespace onnx_torch = ::torch::onnx;
|
||||
namespace onnx = ::ONNX_NAMESPACE;
|
||||
|
||||
const static int kInvalidOpsetVersion = -1;
|
||||
const static int kMainOpsetVersion = 17;
|
||||
const static int kMainOpsetVersion = 18;
|
||||
// Based on OP_SET_ID_VERSION_MAP in
|
||||
// https://github.com/onnx/onnx/blob/master/onnx/helper.py.
|
||||
constexpr static std::array<int64_t, kMainOpsetVersion + 1>
|
||||
@ -82,6 +82,7 @@ constexpr static std::array<int64_t, kMainOpsetVersion + 1>
|
||||
8, // opset 15
|
||||
8, // opset 16
|
||||
8, // opset 17
|
||||
8, // opset 18
|
||||
};
|
||||
|
||||
std::string getNodeStackTraceString(const Node* n) {
|
||||
|
@ -25,6 +25,7 @@ from . import ( # usort:skip. Keep the order instead of sorting lexicographical
|
||||
symbolic_opset15,
|
||||
symbolic_opset16,
|
||||
symbolic_opset17,
|
||||
symbolic_opset18,
|
||||
utils,
|
||||
)
|
||||
|
||||
@ -62,6 +63,7 @@ __all__ = [
|
||||
"symbolic_opset15",
|
||||
"symbolic_opset16",
|
||||
"symbolic_opset17",
|
||||
"symbolic_opset18",
|
||||
# Enums
|
||||
"ExportTypes",
|
||||
"OperatorExportTypes",
|
||||
|
@ -4,7 +4,7 @@ ONNX_ARCHIVE_MODEL_PROTO_NAME = "__MODEL_PROTO"
|
||||
|
||||
ONNX_BASE_OPSET = 9
|
||||
ONNX_MIN_OPSET = 7
|
||||
ONNX_MAX_OPSET = 17
|
||||
ONNX_MAX_OPSET = 18
|
||||
# ONNX_DEFAULT_OPSET generated by tools/onnx/update_default_opset_version.py
|
||||
ONNX_DEFAULT_OPSET = 14
|
||||
ONNX_CONSTANT_FOLDING_MIN_OPSET = 9
|
||||
|
70
torch/onnx/symbolic_opset18.py
Normal file
70
torch/onnx/symbolic_opset18.py
Normal file
@ -0,0 +1,70 @@
|
||||
"""This file exports ONNX ops for opset 18.
|
||||
|
||||
Note [ONNX Operators that are added/updated in opset 18]
|
||||
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
https://github.com/onnx/onnx/blob/main/docs/Changelog.md#version-18-of-the-default-onnx-operator-set
|
||||
New operators:
|
||||
CenterCropPad
|
||||
Col2Im
|
||||
Mish
|
||||
OptionalGetElement
|
||||
OptionalHasElement
|
||||
Pad
|
||||
Resize
|
||||
ScatterElements
|
||||
ScatterND
|
||||
"""
|
||||
|
||||
import functools
|
||||
from typing import Sequence
|
||||
|
||||
from torch import _C
|
||||
from torch.onnx import symbolic_helper
|
||||
from torch.onnx._internal import _beartype, registration
|
||||
|
||||
# EDITING THIS FILE? READ THIS FIRST!
|
||||
# see Note [Edit Symbolic Files] in symbolic_helper.py
|
||||
|
||||
__all__ = ["col2im"]
|
||||
|
||||
_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=18)
|
||||
|
||||
|
||||
@_onnx_symbolic("aten::col2im")
|
||||
@symbolic_helper.parse_args("v", "v", "v", "is", "is", "is")
|
||||
@_beartype.beartype
|
||||
def col2im(
|
||||
g,
|
||||
input: _C.Value,
|
||||
output_size: _C.Value,
|
||||
kernel_size: _C.Value,
|
||||
dilation: Sequence[int],
|
||||
padding: Sequence[int],
|
||||
stride: Sequence[int],
|
||||
):
|
||||
# convert [i0, i1, ..., in] into [i0, i0, i1, i1, ..., in, in]
|
||||
adjusted_padding = []
|
||||
for pad in padding:
|
||||
for _ in range(2):
|
||||
adjusted_padding.append(pad)
|
||||
|
||||
num_dimensional_axis = symbolic_helper._get_tensor_sizes(output_size)[0]
|
||||
if not adjusted_padding:
|
||||
adjusted_padding = [0, 0] * num_dimensional_axis
|
||||
|
||||
if not dilation:
|
||||
dilation = [1] * num_dimensional_axis
|
||||
|
||||
if not stride:
|
||||
stride = [1] * num_dimensional_axis
|
||||
|
||||
return g.op(
|
||||
"Col2Im",
|
||||
input,
|
||||
output_size,
|
||||
kernel_size,
|
||||
dilations_i=dilation,
|
||||
pads_i=adjusted_padding,
|
||||
strides_i=stride,
|
||||
)
|
Reference in New Issue
Block a user