Fixes https://github.com/pytorch/pytorch/issues/84365 and more.

This PR addresses not only the issue above but the entire family of issues related to parsing `torch._C.Value.type()` when `scalarType()` or `dtype()` is not available. The bug predates `JitScalarType`, but the new implementation carried it over: the new `from_name` and `from_dtype` APIs require parsing `torch._C.Value.type()` to produce proper inputs, which is exactly the root cause of this family of bugs. Therefore `from_name` and `from_dtype` should only be called when the implementor already knows the `name` or `dtype` without parsing a `torch._C.Value`.

To handle the corner cases hidden within `torch._C.Value`, a new `from_value` API was introduced; it should be preferred over the former ones in most cases. The new API is safer and does not require type parsing from the user, which could trigger JIT asserts in the core of PyTorch.

Although CI passes for all tests, please review all symbolics/helpers refactoring carefully to make sure the meaning/intention of each old call is preserved in the new call.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/87245
Approved by: https://github.com/justinchuby, https://github.com/BowenBao
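A minimal sketch of the change in call pattern described above (illustrative only; `value` stands for any `torch._C.Value`, and exact signatures may differ across versions):

# Old pattern: the caller parses torch._C.Value.type() itself, which breaks
# when scalarType() is not statically known:
scalar_type = _type_utils.JitScalarType.from_name(value.type().scalarType())

# New pattern: from_value handles untyped corner cases internally and
# accepts a fallback default:
scalar_type = _type_utils.JitScalarType.from_value(
    value, _type_utils.JitScalarType.UNDEFINED
)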
105 lines · 3.1 KiB · Python
"""This file exports ONNX ops for opset 16.
|
|
|
|
Note [ONNX Operators that are added/updated in opset 16]
|
|
|
|
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
|
https://github.com/onnx/onnx/blob/main/docs/Changelog.md#version-16-of-the-default-onnx-operator-set
|
|
New operators:
|
|
GridSample https://github.com/onnx/onnx/pull/3557
|
|
|
|
Updated operators:
|
|
Identity
|
|
If
|
|
LeakyRelu
|
|
Loop
|
|
PRelu
|
|
RoiAlign
|
|
Scan
|
|
ScatterElemenets
|
|
ScatterND
|
|
Where
|
|
GreaterOrEqual
|
|
LessOrEqual
|
|
"""

# EDITING THIS FILE? READ THIS FIRST!
# see Note [Edit Symbolic Files] in README.md

import functools

from torch.nn.functional import (
    GRID_SAMPLE_INTERPOLATION_MODES,
    GRID_SAMPLE_PADDING_MODES,
)
from torch.onnx import _type_utils, symbolic_helper
from torch.onnx._internal import _beartype, jit_utils, registration

_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=16)
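# With the partial above, a decorator like @_onnx_symbolic("aten::grid_sampler")
# below is shorthand for @registration.onnx_symbolic("aten::grid_sampler", opset=16).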

# note (mkozuki): Why `grid_sampler` instead of `grid_sample`?
# Because `torch.nn.functional.grid_sample` calls `torch.grid_sampler`.
@_onnx_symbolic("aten::grid_sampler")
@symbolic_helper.parse_args("v", "v", "i", "i", "b")
@_beartype.beartype
def grid_sampler(
    g: jit_utils.GraphContext,
    input,
    grid,
    mode_enum,
    padding_mode_enum,
    align_corners,
):
    mode_s = {v: k for k, v in GRID_SAMPLE_INTERPOLATION_MODES.items()}[mode_enum]  # type: ignore[call-arg]
    padding_mode_s = {v: k for k, v in GRID_SAMPLE_PADDING_MODES.items()}[padding_mode_enum]  # type: ignore[call-arg]
    return g.op(
        "GridSample",
        input,
        grid,
        align_corners_i=int(align_corners),
        mode_s=mode_s,
        padding_mode_s=padding_mode_s,
    )
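# For illustration (assuming the torch.nn.functional lookup tables map mode
# names to ints, e.g. GRID_SAMPLE_INTERPOLATION_MODES = {"bilinear": 0,
# "nearest": 1, "bicubic": 2} and GRID_SAMPLE_PADDING_MODES = {"zeros": 0,
# "border": 1, "reflection": 2}), the reverse mappings above would turn a call
# with mode_enum=0, padding_mode_enum=1, align_corners=True into:
#   GridSample(input, grid, mode_s="bilinear", padding_mode_s="border",
#              align_corners_i=1)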

@_onnx_symbolic("aten::scatter_add")
@symbolic_helper.parse_args("v", "i", "v", "v")
@_beartype.beartype
def scatter_add(g: jit_utils.GraphContext, self, dim, index, src):
    if symbolic_helper.is_caffe2_aten_fallback():
        return g.at("scatter", self, dim, index, src, overload_name="src")

    src_type = _type_utils.JitScalarType.from_value(
        src, _type_utils.JitScalarType.UNDEFINED
    )
    src_sizes = symbolic_helper._get_tensor_sizes(src)
    index_sizes = symbolic_helper._get_tensor_sizes(index)

    if src_sizes != index_sizes:
        return symbolic_helper._unimplemented(
            "scatter_add",
            f"`index` ({index_sizes}) should have the same dimensionality as `src` ({src_sizes})",
        )

    src = symbolic_helper._maybe_get_scalar(src)
    if symbolic_helper._is_value(src):
        return g.op("ScatterElements", self, index, src, axis_i=dim, reduction_s="add")
    else:
        # PyTorch allows a scalar `src` to have a different type than `self`
        # (a tensor `src` must match). If the types differ, insert a Cast node.
        if _type_utils.JitScalarType.from_value(self) != src_type:
            src = g.op(
                "Cast",
                src,
                to_i=_type_utils.JitScalarType.from_value(self).onnx_type(),
            )

        return g.op(
            "ScatterElements",
            self,
            index,
            src,
            axis_i=dim,
            reduction_s="add",
        )
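# Usage sketch (illustrative, not part of this module): exporting a module
# whose forward calls Tensor.scatter_add with opset_version=16 routes through
# the symbolic above and emits ScatterElements with reduction="add":
#
#   class M(torch.nn.Module):
#       def forward(self, x, index, src):
#           return x.scatter_add(0, index, src)
#
#   torch.onnx.export(
#       M(), (x, index, src), "scatter_add.onnx", opset_version=16
#   )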