Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
[ONNX] Refactor torchscript based exporter (#161323)
Refactor the TorchScript-based exporter logic to move it into a single (private) location for better code management. The original public module and method APIs are preserved.

- Updated module paths in `torch/csrc/autograd/python_function.cpp` accordingly
- Removed `check_onnx_broadcast` from `torch/autograd/_functions/utils.py` because it is private and unused

@albanD / @soulitzer could you review the changes in `torch/csrc/autograd/python_function.cpp` and `torch/autograd/_functions/utils.py`? Thanks!

## BC Breaking

- **Deprecated members in `torch.onnx.verification` are removed**

Differential Revision: [D81236421](https://our.internmc.facebook.com/intern/diff/D81236421)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/161323
Approved by: https://github.com/titaiwangms, https://github.com/angelayi
Commit: 524b78d4f6 (parent 793fc12aff)
Committed by: PyTorch MergeBot
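Since the public `torch.onnx` surface is preserved, most users are unaffected by this move; only code that reached into the private modules directly (as the tests below did) needs new import paths. A minimal, hypothetical compatibility shim, assuming only the old and new paths that appear in this diff:

```python
# Hypothetical shim for code that imported the private modules directly.
# These are internal APIs and can change without notice.
try:
    # New locations introduced by this refactor
    from torch.onnx._internal.torchscript_exporter import registration, verification
except ImportError:
    # Old locations on builds that predate this change
    from torch.onnx import verification  # type: ignore[no-redef]
    from torch.onnx._internal import registration  # type: ignore[no-redef]
```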
@@ -4,7 +4,7 @@
 from collections.abc import Sequence

 from torch.onnx import errors
-from torch.onnx._internal import registration
+from torch.onnx._internal.torchscript_exporter import registration
 from torch.testing._internal import common_utils


@@ -17,7 +17,8 @@ import pytorch_test_common

 import torch
 from torch import export as torch_export
-from torch.onnx import _constants, verification
+from torch.onnx import _constants
+from torch.onnx._internal.torchscript_exporter import verification
 from torch.testing._internal import common_utils
 from torch.testing._internal.opinfo import core as opinfo_core
 from torch.types import Number
@@ -5,13 +5,11 @@ from onnx_test_common import run_model_test

 import torch
 from torch.onnx import OperatorExportTypes
-from torch.onnx._globals import GLOBALS
-from torch.onnx.utils import _model_to_graph
 from torch.testing._internal import common_utils


 class TestAutogradFuns(pytorch_test_common.ExportTestCase):
-    opset_version = GLOBALS.export_onnx_opset_version
+    opset_version = 20
     keep_initializers_as_inputs = False
     onnx_shape_inference = True

@@ -133,7 +131,7 @@ class TestAutogradFuns(pytorch_test_common.ExportTestCase):
         input = torch.ones(1, 5)

         # Test ONNX_FALLTHROUGH_MODE
-        graph, _, _ = _model_to_graph(
+        graph, _, _ = torch.onnx.utils._model_to_graph(
             model,
             (input,),
             operator_export_type=OperatorExportTypes.ONNX_FALLTHROUGH,
@@ -142,7 +140,7 @@ class TestAutogradFuns(pytorch_test_common.ExportTestCase):
         self.assertEqual(next(iter).kind(), "prim::PythonOp")

         # Test ATEN_FALLBACK_MODE
-        graph, _, _ = _model_to_graph(
+        graph, _, _ = torch.onnx.utils._model_to_graph(
             model,
             (input,),
             operator_export_type=OperatorExportTypes.ONNX_ATEN_FALLBACK,
@@ -11,7 +11,7 @@ import torch
 import torch.onnx
 from torch.nn import Module
 from torch.onnx import producer_name, producer_version
-from torch.onnx._globals import GLOBALS
+from torch.onnx._internal.torchscript_exporter._globals import GLOBALS
 from torch.testing._internal import common_utils


@@ -10,7 +10,7 @@ import onnxscript
 from onnxscript.onnx_types import FLOAT

 import torch
-from torch.onnx._internal import jit_utils
+from torch.onnx._internal.torchscript_exporter import jit_utils
 from torch.testing._internal import common_utils


@@ -9,7 +9,7 @@ import onnxscript
 from onnxscript.onnx_types import FLOAT

 import torch
-from torch.onnx._internal import jit_utils
+from torch.onnx._internal.torchscript_exporter import jit_utils
 from torch.testing._internal import common_utils


@@ -4,8 +4,11 @@ import pytorch_test_common
 from pytorch_test_common import skipIfNoCuda

 import torch
-from torch.onnx import verification
-from torch.onnx._globals import GLOBALS
+from torch.onnx._internal.torchscript_exporter import verification
+from torch.onnx._internal.torchscript_exporter._globals import GLOBALS
+from torch.onnx._internal.torchscript_exporter.utils import (
+    _trigger_symbolic_function_registration,
+)
 from torch.testing._internal import common_utils


@@ -20,6 +23,7 @@ def _jit_graph_to_onnx_model(graph, operator_export_type, opset_version):
     """

     GLOBALS.export_onnx_opset_version = opset_version
+    _trigger_symbolic_function_registration()
     graph = torch.onnx.utils._optimize_graph(
         graph, operator_export_type, params_dict={}
     )
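The newly added `_trigger_symbolic_function_registration()` call above suggests that, after the move, symbolic functions are no longer registered as an import side effect, so helpers driving the exporter at this level must set the target opset and then trigger registration explicitly. A sketch of that pattern, inferred from this hunk alone (internal APIs, subject to change):

```python
# Inferred post-refactor pattern for low-level export helpers (hypothetical).
from torch.onnx._internal.torchscript_exporter._globals import GLOBALS
from torch.onnx._internal.torchscript_exporter.utils import (
    _trigger_symbolic_function_registration,
)


def _prepare_export(opset_version: int) -> None:
    # Select the target opset first...
    GLOBALS.export_onnx_opset_version = opset_version
    # ...then explicitly register the symbolic functions for it.
    _trigger_symbolic_function_registration()
```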
@@ -22,7 +22,7 @@ import torch
 import torch.nn.functional as F
 from torch import Tensor
 from torch.onnx import symbolic_helper, utils
-from torch.onnx._internal import registration
+from torch.onnx._internal.torchscript_exporter import registration
 from torch.testing._internal import common_quantization, common_utils, jit_utils


@@ -430,9 +430,8 @@ class TestONNXExport(pytorch_test_common.ExportTestCase):
             torch.randn(3, 4, requires_grad=True),
             mocks=[
                 unittest.mock.patch(
-                    "torch.onnx._internal.registration.registry.get_function_group",
+                    "torch.onnx._internal.torchscript_exporter.registration.registry.get_function_group",
                     side_effect=break_is_registered_op_api,
-                    # wraps=registration.registry.get_function_group
                 )
             ],
             operator_export_type=torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK,
@@ -41,7 +41,9 @@ from pytorch_test_common import (
 import torch
 from torch import Tensor
 from torch.nn.utils import rnn as rnn_utils
-from torch.onnx import errors, verification
+from torch.onnx import errors
+from torch.onnx._internal.torchscript_exporter import verification
+from torch.onnx._internal.torchscript_exporter._type_utils import JitScalarType
 from torch.testing._internal import common_utils
 from torch.testing._internal.common_utils import skipIfNoLapack

@@ -13705,7 +13707,7 @@ class TestONNXRuntime(onnx_test_common._TestONNXRuntime):
             input_names=["x"],
         )
         exported = onnx.load_from_string(f.getvalue())
-        expected_elem_type = torch.onnx.JitScalarType.from_value(x).onnx_type()
+        expected_elem_type = JitScalarType.from_value(x).onnx_type()
         expected_output_type = onnx.helper.make_optional_type_proto(
             onnx.helper.make_tensor_type_proto(expected_elem_type, (dynamic_axis_name,))
         )
@@ -10,8 +10,8 @@ from pytorch_test_common import skipIfUnsupportedMinOpsetVersion

 import torch
 from torch.onnx import _constants, utils
-from torch.onnx._globals import GLOBALS
-from torch.onnx._internal import jit_utils
+from torch.onnx._internal.torchscript_exporter import jit_utils
+from torch.onnx._internal.torchscript_exporter._globals import GLOBALS
 from torch.testing._internal import common_utils


@@ -3,7 +3,7 @@

 import torch
 from torch.onnx import symbolic_helper
-from torch.onnx._globals import GLOBALS
+from torch.onnx._internal.torchscript_exporter._globals import GLOBALS
 from torch.testing._internal import common_utils


@@ -1,11 +1,9 @@
 # Owner(s): ["module: onnx"]

 import copy
-import functools
 import io
 import re
 import warnings
-from typing import Callable

 import onnx
@@ -23,7 +21,7 @@ import torch
 import torch.onnx
 import torch.utils.cpp_extension
 from torch.onnx import _constants, OperatorExportTypes, TrainingMode, utils
-from torch.onnx._globals import GLOBALS
+from torch.onnx._internal.torchscript_exporter._globals import GLOBALS
 from torch.onnx.symbolic_helper import _unpack_list, parse_args
 from torch.testing._internal import common_utils
 from torch.testing._internal.common_utils import skipIfNoLapack
@@ -86,86 +84,6 @@ class _BaseTestCase(pytorch_test_common.ExportTestCase):
         return graph, params_dict, torch_out


-@common_utils.instantiate_parametrized_tests
-class TestUnconvertibleOps(pytorch_test_common.ExportTestCase):
-    """Unit tests for the `unconvertible_ops` function."""
-
-    def setUp(self):
-        class EinsumModule(torch.nn.Module):
-            def forward(self, x):
-                return torch.einsum("ii", x)
-
-        self.einsum_module = EinsumModule()
-
-    def test_it_returns_graph_and_unconvertible_ops_at_lower_opset_version(self):
-        x = torch.randn(4, 4)
-
-        # Einsum is supported since opset 12. It should be unconvertible at opset 9.
-        graph, unconvertible_ops = utils.unconvertible_ops(
-            self.einsum_module, (x,), opset_version=9
-        )
-        nodes = graph.nodes()
-        self.assertEqual(next(nodes).kind(), "prim::Constant")
-        self.assertEqual(next(nodes).kind(), "prim::ListConstruct")
-        self.assertEqual(next(nodes).kind(), "prim::Constant")
-        self.assertEqual(next(nodes).kind(), "aten::einsum")
-        self.assertEqual(unconvertible_ops, ["aten::einsum"])
-
-    @common_utils.parametrize(
-        "jit_function",
-        [
-            common_utils.subtest(
-                functools.partial(torch.jit.trace, example_inputs=torch.randn(4, 4)),
-                name="traced",
-            ),
-            common_utils.subtest(torch.jit.script, name="scripted"),
-        ],
-    )
-    def test_it_returns_unconvertible_ops_at_lower_opset_version_for_jit_module(
-        self, jit_function: Callable
-    ):
-        module = jit_function(self.einsum_module)
-        x = torch.randn(4, 4)
-
-        # Einsum is supported since opset 12. It should be unconvertible at opset 9.
-        _, unconvertible_ops = utils.unconvertible_ops(module, (x,), opset_version=9)
-        self.assertEqual(unconvertible_ops, ["aten::einsum"])
-
-    @common_utils.parametrize(
-        "jit_function",
-        [
-            common_utils.subtest(lambda x: x, name="nn_module"),
-            common_utils.subtest(
-                functools.partial(torch.jit.trace, example_inputs=torch.randn(4, 4)),
-                name="traced",
-            ),
-            common_utils.subtest(torch.jit.script, name="scripted"),
-        ],
-    )
-    def test_it_returns_empty_list_when_all_ops_convertible(
-        self, jit_function: Callable
-    ):
-        module = jit_function(self.einsum_module)
-        x = torch.randn(4, 4)
-
-        # Einsum is supported since opset 12
-        _, unconvertible_ops = utils.unconvertible_ops(module, (x,), opset_version=12)
-        self.assertEqual(unconvertible_ops, [])
-
-    def test_it_returns_empty_list_when_model_contains_supported_inplace_ops(self):
-        class SkipConnectionModule(torch.nn.Module):
-            def forward(self, x):
-                out = x
-                out += x
-                out = torch.nn.functional.relu(out, inplace=True)
-                return out
-
-        module = SkipConnectionModule()
-        x = torch.randn(4, 4)
-        _, unconvertible_ops = utils.unconvertible_ops(module, (x,), opset_version=13)
-        self.assertEqual(unconvertible_ops, [])
-
-
 @parameterized.parameterized_class(
     [
         {"opset_version": opset}
@@ -13,7 +13,8 @@ import pytorch_test_common
 from packaging import version

 import torch
-from torch.onnx import _constants, _experimental, verification
+from torch.onnx import _constants
+from torch.onnx._internal.torchscript_exporter import _experimental, verification
 from torch.testing._internal import common_utils


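Taken together, the hunks above apply one consistent module move. Summarized as an illustrative mapping (the paths are copied from the import rewrites in this diff; the dict itself is hypothetical, not part of the PR):

```python
# Old private module path -> new home under torchscript_exporter,
# as observed in the import rewrites above (illustrative only).
MOVED_MODULES = {
    "torch.onnx._internal.registration": "torch.onnx._internal.torchscript_exporter.registration",
    "torch.onnx._internal.jit_utils": "torch.onnx._internal.torchscript_exporter.jit_utils",
    "torch.onnx._globals": "torch.onnx._internal.torchscript_exporter._globals",
    "torch.onnx.verification": "torch.onnx._internal.torchscript_exporter.verification",
    "torch.onnx._experimental": "torch.onnx._internal.torchscript_exporter._experimental",
}
```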