pytorch/torch/onnx/symbolic_opset16.py
Thiago Crepaldi a8f40b39ce Update all ONNX symbolics with new JitScalarType API (#87245)
Fixes https://github.com/pytorch/pytorch/issues/84365 and more

This PR addresses not only the issue above, but the entire family of issues related to `torch._C.Value.type()` parsing when `scalarType()` or `dtype()` is not available.

This issue existed before `JitScalarType` was introduced, but the new implementation carried the bug over: the new APIs `from_name` and `from_dtype` require parsing `torch._C.Value.type()` to get proper inputs, which is exactly the root cause of this family of bugs.

Therefore, `from_name` and `from_dtype` should only be called when the implementor already knows the `name` or `dtype` without parsing a `torch._C.Value`. To handle the corner cases hidden within `torch._C.Value`, a new `from_value` API was introduced, and it should be preferred over the former ones in most cases. The new API is safer because it does not require the user to parse types, which could trigger JIT asserts in the PyTorch core.
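
A minimal, self-contained sketch of the fallback pattern described above. It assumes a toy `FakeValue` class standing in for `torch._C.Value` and a toy `ScalarType` enum standing in for `_type_utils.JitScalarType`; both names are hypothetical, and this is an illustration of the idea (resolve the type from the value's metadata, return a caller-supplied default when it is unknown) rather than PyTorch's actual implementation.

```python
import enum
from typing import Optional


class ScalarType(enum.Enum):
    """Hypothetical stand-in for `_type_utils.JitScalarType`."""

    FLOAT = "Float"
    INT64 = "Long"
    UNDEFINED = "Undefined"


class FakeValue:
    """Hypothetical stand-in for `torch._C.Value`.

    `scalar_type` is None when the type is unknown, e.g. when type inference
    could not determine it.
    """

    def __init__(self, scalar_type: Optional[str] = None) -> None:
        self._scalar_type = scalar_type

    def scalar_type(self) -> Optional[str]:
        return self._scalar_type


_NO_DEFAULT = object()  # sentinel: distinguishes "no default given" from any real default


def from_value(value: FakeValue, default: object = _NO_DEFAULT) -> ScalarType:
    """Resolve a value's scalar type from its metadata, falling back to `default`."""
    name = value.scalar_type()
    if name is not None:
        return ScalarType(name)  # enum lookup by value; the caller parses nothing
    if default is not _NO_DEFAULT:
        return default  # type: ignore[return-value]
    # In this sketch, an unknown type with no default is an error.
    raise ValueError("scalar type is unknown and no default was given")


def needs_cast(self_value: FakeValue, src_value: FakeValue) -> bool:
    """Mirror the scatter_add check: compare the `self` and `src` scalar types.

    An unknown `src` type resolves to UNDEFINED instead of raising, so it
    compares unequal to any known `self` type.
    """
    src_type = from_value(src_value, ScalarType.UNDEFINED)
    return from_value(self_value) != src_type
```

For example, `needs_cast(FakeValue("Float"), FakeValue("Long"))` and `needs_cast(FakeValue("Float"), FakeValue())` are both true, while `needs_cast(FakeValue("Long"), FakeValue("Long"))` is false. Passing the default explicitly keeps the "type unknown" case visible at the call site instead of hiding it inside string parsing.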

Although CI is passing for all tests, please carefully review all symbolic/helper refactoring to make sure the meaning/intention of the old calls is not changed by the new calls.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/87245
Approved by: https://github.com/justinchuby, https://github.com/BowenBao
2022-11-03 03:01:33 +00:00

"""This file exports ONNX ops for opset 16.

Note [ONNX Operators that are added/updated in opset 16]
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
https://github.com/onnx/onnx/blob/main/docs/Changelog.md#version-16-of-the-default-onnx-operator-set
New operators:
    GridSample https://github.com/onnx/onnx/pull/3557

Updated operators:
    Identity
    If
    LeakyRelu
    Loop
    PRelu
    RoiAlign
    Scan
    ScatterElements
    ScatterND
    Where
    GreaterOrEqual
    LessOrEqual
"""
# EDITING THIS FILE? READ THIS FIRST!
# see Note [Edit Symbolic Files] in README.md

import functools

from torch.nn.functional import (
    GRID_SAMPLE_INTERPOLATION_MODES,
    GRID_SAMPLE_PADDING_MODES,
)
from torch.onnx import _type_utils, symbolic_helper
from torch.onnx._internal import _beartype, jit_utils, registration

_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=16)


# note (mkozuki): Why `grid_sampler` instead of `grid_sample`?
# Because `torch.nn.functional.grid_sample` calls `torch.grid_sampler`.
@_onnx_symbolic("aten::grid_sampler")
@symbolic_helper.parse_args("v", "v", "i", "i", "b")
@_beartype.beartype
def grid_sampler(
    g: jit_utils.GraphContext,
    input,
    grid,
    mode_enum,
    padding_mode_enum,
    align_corners,
):
    mode_s = {v: k for k, v in GRID_SAMPLE_INTERPOLATION_MODES.items()}[mode_enum]  # type: ignore[call-arg]
    padding_mode_s = {v: k for k, v in GRID_SAMPLE_PADDING_MODES.items()}[padding_mode_enum]  # type: ignore[call-arg]
    return g.op(
        "GridSample",
        input,
        grid,
        align_corners_i=int(align_corners),
        mode_s=mode_s,
        padding_mode_s=padding_mode_s,
    )


@_onnx_symbolic("aten::scatter_add")
@symbolic_helper.parse_args("v", "i", "v", "v")
@_beartype.beartype
def scatter_add(g: jit_utils.GraphContext, self, dim, index, src):
    if symbolic_helper.is_caffe2_aten_fallback():
        return g.at("scatter", self, dim, index, src, overload_name="src")

    src_type = _type_utils.JitScalarType.from_value(
        src, _type_utils.JitScalarType.UNDEFINED
    )
    src_sizes = symbolic_helper._get_tensor_sizes(src)
    index_sizes = symbolic_helper._get_tensor_sizes(index)

    if src_sizes != index_sizes:
        return symbolic_helper._unimplemented(
            "scatter_add",
            f"`index` ({index_sizes}) should have the same dimensionality as `src` ({src_sizes})",
        )

    src = symbolic_helper._maybe_get_scalar(src)
    if symbolic_helper._is_value(src):
        return g.op("ScatterElements", self, index, src, axis_i=dim, reduction_s="add")
    else:
        # Check whether a scalar `src` has the same type as `self`. PyTorch allows
        # a scalar `src` to have a different type (a tensor `src` must match), so
        # insert a Cast node when the types differ.
        if _type_utils.JitScalarType.from_value(self) != src_type:
            src = g.op(
                "Cast",
                src,
                to_i=_type_utils.JitScalarType.from_value(self).onnx_type(),
            )

        return g.op(
            "ScatterElements",
            self,
            index,
            src,
            axis_i=dim,
            reduction_s="add",
        )
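
The `grid_sampler` symbolic above recovers the ONNX `mode` and `padding_mode` strings by inverting the name-to-enum dictionaries imported from `torch.nn.functional`. The sketch below shows that inversion in isolation. The two dictionaries are local copies assumed to mirror `GRID_SAMPLE_INTERPOLATION_MODES` and `GRID_SAMPLE_PADDING_MODES`; check them against your installed PyTorch release. The helper names `enum_to_name` and `grid_sample_attributes` are hypothetical.

```python
from typing import Dict

# Local copies, assumed to mirror torch.nn.functional's mappings
# (mode name -> integer enum passed to aten::grid_sampler).
INTERPOLATION_MODES: Dict[str, int] = {"bilinear": 0, "nearest": 1, "bicubic": 2}
PADDING_MODES: Dict[str, int] = {"zeros": 0, "border": 1, "reflection": 2}


def enum_to_name(mapping: Dict[str, int], enum_value: int) -> str:
    """Invert a name -> enum mapping and return the name for `enum_value`."""
    inverted = {v: k for k, v in mapping.items()}
    return inverted[enum_value]  # raises KeyError for an unknown enum value


def grid_sample_attributes(
    mode_enum: int, padding_mode_enum: int, align_corners: bool
) -> Dict[str, object]:
    """Attribute values the symbolic would attach to the ONNX GridSample node."""
    return {
        "align_corners": int(align_corners),  # ONNX expects an int, not a bool
        "mode": enum_to_name(INTERPOLATION_MODES, mode_enum),
        "padding_mode": enum_to_name(PADDING_MODES, padding_mode_enum),
    }
```

For example, `grid_sample_attributes(1, 2, True)` yields `{"align_corners": 1, "mode": "nearest", "padding_mode": "reflection"}`. Inverting the existing dictionaries, rather than maintaining a second enum-to-name table, keeps the exporter in sync with whatever modes `torch.nn.functional` defines.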