pytorch/torch/onnx/symbolic_opset16.py
Justin Chu c6cdca5c68 [ONNX] Reland #81953 Type utility for converting among JIT, torch and ONNX data types (#82995)
Re-land #81953

Add `_type_utils` for handling data type conversion among JIT, torch and ONNX.

- Replace dictionary / list indexing with methods in ScalarType
- Breaking: **Remove ScalarType from `symbolic_helper`** and move it to `_type_utils`
- Deprecated: "cast_pytorch_to_onnx", "pytorch_name_to_type", "scalar_name_to_pytorch", "scalar_type_to_onnx", "scalar_type_to_pytorch_type" in `symbolic_helper`
- Deprecate the type mappings and lists. Remove all internal references
- Move _cast_func_template to opset 9 and remove its reference elsewhere (clean up). Added documentation for easy discovery

Why: List / dictionary indexing and lookup are error-prone and convoluted.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/82995
Approved by: https://github.com/kit1980
2022-08-08 23:43:43 +00:00
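
The rationale in the commit message (conversion methods on a scalar-type enum instead of ad hoc dictionary / list indexing) can be illustrated with a minimal sketch. `SketchScalarType` and its lookup tables are hypothetical stand-ins written for this example, not the real `_type_utils` implementation. Only the `from_name(...).onnx_type()` call shape is taken from the file itself. The integer codes are the ONNX `TensorProto` data type values for FLOAT, INT64 and BOOL.

```python
import enum


class SketchScalarType(enum.IntEnum):
    """Hypothetical stand-in for a scalar-type enum (not the real torch class)."""

    FLOAT = 0
    INT64 = 1
    BOOL = 2

    @classmethod
    def from_name(cls, name: str) -> "SketchScalarType":
        """Convert a JIT-style scalar type name such as "Float" to a member."""
        try:
            return _NAME_TO_TYPE[name]
        except KeyError:
            # One central place to report bad names, instead of a bare
            # KeyError surfacing at every call site that indexes a dict.
            raise ValueError(f"Unknown scalar type name: {name!r}") from None

    def onnx_type(self) -> int:
        """Return the ONNX TensorProto data type code for this member."""
        return _TYPE_TO_ONNX[self]


# The tables stay private to the module; callers only see the methods.
_NAME_TO_TYPE = {
    "Float": SketchScalarType.FLOAT,
    "Long": SketchScalarType.INT64,
    "Bool": SketchScalarType.BOOL,
}

_TYPE_TO_ONNX = {
    SketchScalarType.FLOAT: 1,  # TensorProto.FLOAT
    SketchScalarType.INT64: 7,  # TensorProto.INT64
    SketchScalarType.BOOL: 9,  # TensorProto.BOOL
}

onnx_code = SketchScalarType.from_name("Long").onnx_type()  # 7
```

Keeping the tables behind methods means a lookup failure is reported in one place with a descriptive message, which is the kind of error-proneness the commit message attributes to direct indexing.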

"""This file exports ONNX ops for opset 16.

Note [ONNX Operators that are added/updated in opset 16]
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
https://github.com/onnx/onnx/blob/main/docs/Changelog.md#version-16-of-the-default-onnx-operator-set
New operators:
    GridSample https://github.com/onnx/onnx/pull/3557

Updated operators:
    Identity
    If
    LeakyRelu
    Loop
    PRelu
    RoiAlign
    Scan
    ScatterElements
    ScatterND
    Where
    GreaterOrEqual
    LessOrEqual
    SequenceMap
"""
# EDITING THIS FILE? READ THIS FIRST!
# see Note [Edit Symbolic Files] in symbolic_helper.py
from torch.nn.functional import (
    GRID_SAMPLE_INTERPOLATION_MODES,
    GRID_SAMPLE_PADDING_MODES,
)
from torch.onnx import _type_utils, symbolic_helper


# note (mkozuki): Why `grid_sampler` instead of `grid_sample`?
# Because `torch.nn.functional.grid_sample` calls `torch.grid_sampler`.
@symbolic_helper.parse_args("v", "v", "i", "i", "b")
def grid_sampler(g, input, grid, mode_enum, padding_mode_enum, align_corners):
    # Invert the name -> enum maps to recover the string names that the
    # ONNX GridSample attributes expect.
    mode_s = {v: k for k, v in GRID_SAMPLE_INTERPOLATION_MODES.items()}[mode_enum]  # type: ignore[call-arg]
    padding_mode_s = {v: k for k, v in GRID_SAMPLE_PADDING_MODES.items()}[padding_mode_enum]  # type: ignore[call-arg]
    return g.op(
        "GridSample",
        input,
        grid,
        align_corners_i=int(align_corners),
        mode_s=mode_s,
        padding_mode_s=padding_mode_s,
    )


@symbolic_helper.parse_args("v", "i", "v", "v")
def scatter_add(g, self, dim, index, src):
    if symbolic_helper.is_caffe2_aten_fallback():
        return g.at("scatter", self, dim, index, src, overload_name="src")

    src_type = src.type().scalarType()
    src_sizes = symbolic_helper._get_tensor_sizes(src)
    index_sizes = symbolic_helper._get_tensor_sizes(index)

    # This symbolic only handles `index` and `src` with identical sizes.
    if src_sizes != index_sizes:
        return symbolic_helper._unimplemented(
            "scatter_add",
            f"`index` ({index_sizes}) should have the same dimensionality as `src` ({src_sizes})",
        )

    src = symbolic_helper._maybe_get_scalar(src)
    if symbolic_helper._is_value(src):
        return g.op("ScatterElements", self, index, src, axis_i=dim, reduction_s="add")
    else:
        # Check if scalar "src" has same type as self (PyTorch allows different
        # type for scalar src (but not when src is tensor)). If not, insert Cast node.
        if self.type().scalarType() != src_type:
            src = g.op(
                "Cast",
                src,
                to_i=_type_utils.JitScalarType.from_name(
                    self.type().scalarType()
                ).onnx_type(),
            )

        return g.op(
            "ScatterElements",
            self,
            index,
            src,
            axis_i=dim,
            reduction_s="add",
        )
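
For readers unfamiliar with the op that `scatter_add` lowers to, the following is a minimal pure-Python sketch of what `ScatterElements` with `reduction="add"` computes, restricted to 2-D nested lists. The function name `scatter_elements_add` is hypothetical and exists only for this illustration. It is not part of torch or ONNX, and it assumes `indices` and `updates` have the same shape, mirroring the size check above.

```python
import copy
from typing import List

Matrix = List[List[int]]


def scatter_elements_add(
    data: Matrix, indices: Matrix, updates: Matrix, axis: int = 0
) -> Matrix:
    """Accumulate `updates` into a copy of `data` at positions given by `indices`.

    axis == 0: out[indices[i][j]][j] += updates[i][j]
    axis == 1: out[i][indices[i][j]] += updates[i][j]
    """
    if axis not in (0, 1):
        raise ValueError("this sketch only supports 2-D inputs (axis 0 or 1)")

    out = copy.deepcopy(data)  # leave the caller's data untouched
    for i, index_row in enumerate(indices):
        for j, idx in enumerate(index_row):
            if axis == 0:
                out[idx][j] += updates[i][j]
            else:
                out[i][idx] += updates[i][j]
    return out


# Two updates target the same slot (column 1), so their values accumulate.
row_result = scatter_elements_add(
    data=[[1, 2, 3, 4, 5]], indices=[[1, 1]], updates=[[10, 20]], axis=1
)
# row_result == [[1, 32, 3, 4, 5]]

col_result = scatter_elements_add(
    data=[[0, 0], [0, 0]], indices=[[0, 1], [0, 1]], updates=[[1, 2], [3, 4]], axis=0
)
# col_result == [[4, 0], [0, 6]]
```

The accumulation on duplicate indices is what distinguishes the `"add"` reduction from a plain scatter, where a later write to the same slot would replace the earlier one.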