pytorch/torch/onnx/symbolic_opset18.py
Thiago Crepaldi a63524684d [ONNX] Add col2im for opset 18 (#84594)
Opset 18 will be used to introduce support for ONNX's Col2Im-18 and to resolve https://github.com/pytorch/pytorch/issues/84408

Depends on: https://github.com/pytorch/pytorch/pull/83201 (CI will fail until the ONNX submodule is updated)

As per Faith's recommendation, this PR should be merged only after ORT 1.13.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/84594
Approved by: https://github.com/justinchuby, https://github.com/titaiwangms, https://github.com/abock, https://github.com/BowenBao
2023-02-09 19:54:42 +00:00


"""This file exports ONNX ops for opset 18.

Note [ONNX Operators that are added/updated in opset 18]
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
https://github.com/onnx/onnx/blob/main/docs/Changelog.md#version-18-of-the-default-onnx-operator-set
New or updated operators:
    CenterCropPad
    Col2Im
    Mish
    OptionalGetElement
    OptionalHasElement
    Pad
    Resize
    ScatterElements
    ScatterND
"""
import functools
from typing import Sequence

from torch import _C
from torch.onnx import symbolic_helper
from torch.onnx._internal import _beartype, registration

# EDITING THIS FILE? READ THIS FIRST!
# see Note [Edit Symbolic Files] in symbolic_helper.py

__all__ = ["col2im"]

_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=18)


@_onnx_symbolic("aten::col2im")
@symbolic_helper.parse_args("v", "v", "v", "is", "is", "is")
@_beartype.beartype
def col2im(
    g,
    input: _C.Value,
    output_size: _C.Value,
    kernel_size: _C.Value,
    dilation: Sequence[int],
    padding: Sequence[int],
    stride: Sequence[int],
):
    # ONNX Col2Im expects pads as [x1_begin, x2_begin, ..., x1_end, x2_end].
    # PyTorch applies padding[i] to both ends of spatial axis i, so the begin
    # and end halves are identical: [p0, p1] -> [p0, p1, p0, p1].
    adjusted_padding = list(padding) * 2

    # output_size is a 1-D tensor with one entry per spatial axis.
    num_dimensional_axis = symbolic_helper._get_tensor_sizes(output_size)[0]
    if not adjusted_padding:
        adjusted_padding = [0, 0] * num_dimensional_axis

    if not dilation:
        dilation = [1] * num_dimensional_axis

    if not stride:
        stride = [1] * num_dimensional_axis

    return g.op(
        "Col2Im",
        input,
        output_size,
        kernel_size,
        dilations_i=dilation,
        pads_i=adjusted_padding,
        strides_i=stride,
    )
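Col2Im is the inverse rearrangement of im2col. Each column of the input holds one flattened sliding block, and the blocks are written back into the image, with overlapping values summed. As a rough reference for what the exported node computes, here is a pure-Python sketch for a single batch element with two spatial axes. `col2im_2d` and its argument layout are assumptions for illustration: they follow the `torch.nn.functional.fold` conventions (rows ordered channel-major, then kernel positions in row-major order) and are not code from this file or from ONNX Runtime.

```python
from typing import List, Tuple

Matrix = List[List[float]]


def col2im_2d(
    cols: Matrix,
    channels: int,
    output_size: Tuple[int, int],
    kernel_size: Tuple[int, int],
    dilation: Tuple[int, int] = (1, 1),
    padding: Tuple[int, int] = (0, 0),
    stride: Tuple[int, int] = (1, 1),
) -> List[Matrix]:
    """Reference col2im for one batch element (hypothetical helper, unoptimized).

    `cols` has shape (channels * kh * kw, L); each column is one flattened block.
    Returns a (channels, H, W) nested list where overlapping values are summed.
    """
    height, width = output_size
    kh, kw = kernel_size
    dh, dw = dilation
    ph, pw = padding
    sh, sw = stride
    # Number of sliding-block positions along each spatial axis.
    blocks_h = (height + 2 * ph - dh * (kh - 1) - 1) // sh + 1
    blocks_w = (width + 2 * pw - dw * (kw - 1) - 1) // sw + 1
    if len(cols) != channels * kh * kw:
        raise ValueError("expected channels * kh * kw rows")
    if any(len(row) != blocks_h * blocks_w for row in cols):
        raise ValueError("expected one column per sliding block")

    out = [[[0.0] * width for _ in range(height)] for _ in range(channels)]
    for c in range(channels):
        for ki in range(kh):
            for kj in range(kw):
                # Rows are ordered channel-major, then kernel row, then kernel column.
                row = (c * kh + ki) * kw + kj
                for bh in range(blocks_h):
                    for bw in range(blocks_w):
                        y = bh * sh - ph + ki * dh
                        x = bw * sw - pw + kj * dw
                        # Positions that fall in the padding region are dropped.
                        if 0 <= y < height and 0 <= x < width:
                            out[c][y][x] += cols[row][bh * blocks_w + bw]
    return out


# Four 2x2 blocks of ones folded into a 3x3 image.
ones = [[1.0] * 4 for _ in range(4)]
print(col2im_2d(ones, channels=1, output_size=(3, 3), kernel_size=(2, 2)))
# [[[1.0, 2.0, 1.0], [2.0, 4.0, 2.0], [1.0, 2.0, 1.0]]]
```

Each output pixel in the usage example counts how many 2x2 blocks cover it, which is why the centre pixel is 4 and the corners are 1.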
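The attribute handling in `col2im` (expanding `padding` into `pads` and defaulting empty `dilation` and `stride` to 1 per spatial axis) can be checked without building a graph. The sketch below is a plain-Python stand-in, assuming the `pads` layout documented for the ONNX Col2Im operator (`[x1_begin, x2_begin, ..., x1_end, x2_end]`) and PyTorch's convention that `padding[i]` is applied to both ends of spatial axis `i`. `build_col2im_attributes` is a hypothetical helper, not a `torch.onnx` API.

```python
from typing import Dict, List, Sequence


def build_col2im_attributes(
    padding: Sequence[int],
    dilation: Sequence[int],
    stride: Sequence[int],
    num_spatial_axes: int,
) -> Dict[str, List[int]]:
    """Map PyTorch-style col2im arguments to ONNX Col2Im attributes.

    Hypothetical helper for illustration only; it is not part of torch.onnx.
    """
    # ONNX lists all begin pads first, then all end pads. Because PyTorch pads
    # both ends of axis i by padding[i], the two halves are the same list.
    pads = list(padding) * 2 if padding else [0] * (2 * num_spatial_axes)
    # Empty dilation/stride fall back to 1 on every spatial axis.
    dilations = list(dilation) if dilation else [1] * num_spatial_axes
    strides = list(stride) if stride else [1] * num_spatial_axes
    return {"dilations": dilations, "pads": pads, "strides": strides}


print(build_col2im_attributes([1, 2], [], [], num_spatial_axes=2))
# {'dilations': [1, 1], 'pads': [1, 2, 1, 2], 'strides': [1, 1]}
```

With unequal per-axis padding such as `[1, 2]`, the begin block `[1, 2]` and the end block `[1, 2]` keep each value attached to its own axis.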