From 8f6e73f068ef995bf17bea1eca4bd3e66138f11a Mon Sep 17 00:00:00 2001 From: titaiwangms Date: Fri, 6 Sep 2024 01:29:54 +0000 Subject: [PATCH] [ONNX] Enable experimental exporter logic to dynamo_export and support refine dynamic_shapes (#134976) (1) Enable experimental exporter logic to dynamo_export (2) Refine dynamic shapes and retry export in export strategies (3) Delete `torch_export_graph_extractor` and use the new export logic (4) Disable ExportedProgram test in `test_fx_onnx_with_onnxruntime.py`, as ONNXProgram is different now. Fixes https://github.com/pytorch/pytorch/issues/126479 Fixes #135183 Pull Request resolved: https://github.com/pytorch/pytorch/pull/134976 Approved by: https://github.com/justinchuby --- test/onnx/exporter/test_api.py | 48 ++++++ test/onnx/pytorch_test_common.py | 22 +++ test/onnx/test_fx_to_onnx.py | 44 +---- test/onnx/test_fx_to_onnx_decomp_skip.py | 15 -- test/onnx/test_fx_to_onnx_with_onnxruntime.py | 67 +------- .../test_torch_export_with_onnxruntime.py | 13 +- torch/onnx/__init__.py | 129 ++++++++++++++- torch/onnx/_internal/_exporter_legacy.py | 154 ++---------------- .../_internal/exporter/_capture_strategies.py | 34 +++- torch/onnx/_internal/exporter/_compat.py | 1 - torch/onnx/_internal/fx/_pass.py | 2 +- .../fx/torch_export_graph_extractor.py | 128 --------------- 12 files changed, 246 insertions(+), 411 deletions(-) delete mode 100644 torch/onnx/_internal/fx/torch_export_graph_extractor.py diff --git a/test/onnx/exporter/test_api.py b/test/onnx/exporter/test_api.py index 6f2a1f942313..183307555104 100644 --- a/test/onnx/exporter/test_api.py +++ b/test/onnx/exporter/test_api.py @@ -148,6 +148,54 @@ class TestExportAPIDynamo(common_utils.TestCase): }, ) + def test_auto_convert_all_axes_to_dynamic_shapes_with_dynamo_export(self): + os.environ["TORCH_ONNX_USE_EXPERIMENTAL_LOGIC"] = "1" + assert os.environ.get("TORCH_ONNX_USE_EXPERIMENTAL_LOGIC") == "1" + + class Nested(torch.nn.Module): + def forward(self, x): + (a0, a1), 
(b0, b1), (c0, c1, c2) = x + return a0 + a1 + b0 + b1 + c0 + c1 + c2 + + inputs = ( + (1, 2), + ( + torch.randn(4, 4), + torch.randn(4, 4), + ), + ( + torch.randn(4, 4), + torch.randn(4, 4), + torch.randn(4, 4), + ), + ) + + onnx_program = torch.onnx.dynamo_export( + Nested(), + inputs, + export_options=torch.onnx.ExportOptions(dynamic_shapes=True), + ) + assert onnx_program is not None + onnx_testing.assert_onnx_program(onnx_program) + + def test_refine_dynamic_shapes_with_onnx_export(self): + # NOTE: From test/export/test_export.py + + # refine lower, upper bound + class TestRefineDynamicShapeModel(torch.nn.Module): + def forward(self, x, y): + if x.shape[0] >= 6 and y.shape[0] <= 16: + return x * 2.0, y + 1 + + inps = (torch.randn(16), torch.randn(12)) + dynamic_shapes = { + "x": (torch.export.Dim("dx"),), + "y": (torch.export.Dim("dy"),), + } + self.assert_export( + TestRefineDynamicShapeModel(), inps, dynamic_shapes=dynamic_shapes + ) + if __name__ == "__main__": common_utils.run_tests() diff --git a/test/onnx/pytorch_test_common.py b/test/onnx/pytorch_test_common.py index 408168a9c711..6c90290a3e96 100644 --- a/test/onnx/pytorch_test_common.py +++ b/test/onnx/pytorch_test_common.py @@ -346,6 +346,28 @@ def skipDtypeChecking(func): return wrapper +def skip_if_fake_model_and_inititalizer(reason: Optional[str] = None): + """skip test with models using ExportedProgram as input. + + Args: + reason: The reason for skip the ONNX export test. + + Returns: + A decorator for skip tests. 
+ """ + + def skip_dec(func): + @functools.wraps(func) + def wrapper(self, *args, **kwargs): + if kwargs["use_fake_mode"] and kwargs["include_initializer"]: + raise unittest.SkipTest(reason) + return func(self, *args, **kwargs) + + return wrapper + + return skip_dec + + def xfail_if_model_type_is_exportedprogram( error_message: str, reason: Optional[str] = None ): diff --git a/test/onnx/test_fx_to_onnx.py b/test/onnx/test_fx_to_onnx.py index 5de711f181f4..460d9996cb93 100644 --- a/test/onnx/test_fx_to_onnx.py +++ b/test/onnx/test_fx_to_onnx.py @@ -544,34 +544,6 @@ class TestFxToOnnx(pytorch_test_common.ExportTestCase): onnx.checker.check_model(onnx_program.model_proto) onnx.shape_inference.infer_shapes(onnx_program.model_proto) - def test_exported_program_input_with_custom_fx_tracer(self): - from torch.onnx._internal import _exporter_legacy - from torch.onnx._internal.fx import dynamo_graph_extractor - - class Model(torch.nn.Module): - def forward(self, x): - return x + 1 - - x = torch.randn(1, 1, 2) - exported_program = torch.export.export(Model(), args=(x,)) - - export_options = torch.onnx.ExportOptions() - export_options = _exporter_legacy.ResolvedExportOptions( - export_options, model=exported_program - ) - export_options.fx_tracer = ( - dynamo_graph_extractor.DynamoExport() - ) # Override fx_tracer to an unsupported tracer - with self.assertRaises(torch.onnx.OnnxExporterError): - onnx_program = torch.onnx.dynamo_export( - exported_program, - x, - export_options=export_options, - ) - self.assertTrue(onnx_program._export_exception is not None) - with self.assertRaises(torch.onnx.InvalidExportOptionsError): - raise self._export_exception - def test_exported_program_torch_distributions_normal_Normal(self): class Model(torch.nn.Module): def __init__(self) -> None: @@ -606,21 +578,6 @@ class TestFxToOnnx(pytorch_test_common.ExportTestCase): # with no Cast node in between. 
self.assertEqual(div_node.input[0], model_proto.graph.input[0].name) - def test_exported_program_as_input_with_model_signature(self): - class Model(torch.nn.Module): - def forward(self, x): - return x + 1.0 - - x = torch.randn(1, 1, 2, dtype=torch.float) - exported_program = torch.export.export(Model(), args=(x,)) - - onnx_program = torch.onnx.dynamo_export( - exported_program, - x, - ) - - self.assertTrue(onnx_program.model_signature, torch.export.ExportGraphSignature) - @common_utils.parametrize( "float8_type", [ @@ -707,6 +664,7 @@ class TestFxToOnnx(pytorch_test_common.ExportTestCase): onnx_program.save(tmp_onnx_file.name) onnx.checker.check_model(tmp_onnx_file.name, full_check=True) + @pytorch_test_common.skip_if_fake_model_and_inititalizer("segfault") @common_utils.parametrize( "include_initializer", [ diff --git a/test/onnx/test_fx_to_onnx_decomp_skip.py b/test/onnx/test_fx_to_onnx_decomp_skip.py index db8edce14259..466ee4a0bb95 100644 --- a/test/onnx/test_fx_to_onnx_decomp_skip.py +++ b/test/onnx/test_fx_to_onnx_decomp_skip.py @@ -19,12 +19,6 @@ def assert_op_in_onnx_model(model: onnx.ModelProto, op_type: str): class TestDynamoExportDecompSkip(pytorch_test_common.ExportTestCase): - def _test_exported_program_forces_decomposition(self, model, input, op_type): - ep = torch.export.export(model, input) - onnx_program = torch.onnx.dynamo_export(ep, *input) - with self.assertRaises(AssertionError): - assert_op_in_onnx_model(onnx_program.model_proto, op_type) - def test_upsample_bilinear2d(self): class TestModel(torch.nn.Module): def __init__(self) -> None: @@ -37,9 +31,6 @@ class TestDynamoExportDecompSkip(pytorch_test_common.ExportTestCase): onnx_program = torch.onnx.dynamo_export(TestModel(), torch.randn(1, 1, 2, 2)) # If decomposition is skipped, the model will contain a Resize op instead of fine grained subgraph. 
assert_op_in_onnx_model(onnx_program.model_proto, "Resize") - self._test_exported_program_forces_decomposition( - TestModel(), (torch.randn(1, 1, 2, 2),), "Resize" - ) def test_upsample_bilinear2d_output_size(self): def func(x: torch.Tensor): @@ -61,9 +52,6 @@ class TestDynamoExportDecompSkip(pytorch_test_common.ExportTestCase): onnx_program = torch.onnx.dynamo_export(TestModel(), torch.randn(1, 1, 2, 2, 3)) # If decomposition is skipped, the model will contain a Resize op instead of fine grained subgraph. assert_op_in_onnx_model(onnx_program.model_proto, "Resize") - self._test_exported_program_forces_decomposition( - TestModel(), (torch.randn(1, 1, 2, 2, 3),), "Resize" - ) def test_upsample_trilinear3d_output_size(self): def func(x: torch.Tensor): @@ -82,9 +70,6 @@ class TestDynamoExportDecompSkip(pytorch_test_common.ExportTestCase): # If decomposition is skipped, the model will contain an InstanceNormalization op # instead of BatchNormalization op w/ training=True. assert_op_in_onnx_model(onnx_program.model_proto, "InstanceNormalization") - self._test_exported_program_forces_decomposition( - TestModel(), (torch.randn(1, 1, 2, 2),), "InstanceNormalization" - ) if __name__ == "__main__": diff --git a/test/onnx/test_fx_to_onnx_with_onnxruntime.py b/test/onnx/test_fx_to_onnx_with_onnxruntime.py index dd53ded23cdb..ff4d3a91bd1a 100644 --- a/test/onnx/test_fx_to_onnx_with_onnxruntime.py +++ b/test/onnx/test_fx_to_onnx_with_onnxruntime.py @@ -45,10 +45,7 @@ def _parameterized_class_attrs_and_values(): input_values.extend( itertools.product( (True, False), - ( - pytorch_test_common.TorchModelType.TORCH_NN_MODULE, - pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM, - ), + (pytorch_test_common.TorchModelType.TORCH_NN_MODULE,), ) ) return { @@ -912,10 +909,7 @@ def _parameterized_class_attrs_and_values_with_fake_options(): (True, False), (True, False), (True, False), - ( - pytorch_test_common.TorchModelType.TORCH_NN_MODULE, - 
pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM, - ), + (pytorch_test_common.TorchModelType.TORCH_NN_MODULE,), ) ) return { @@ -986,13 +980,6 @@ class TestFxToOnnxFakeTensorWithOnnxRuntime(onnx_test_common._TestONNXRuntime): # Create the toy model with real weight. real_model = create_model() state_dict = real_model.state_dict() # concrete (non-fake) state_dict - if ( - model_type - == pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM - ): - real_model = torch.export.export( - real_model, args=create_args(), kwargs=create_kwargs() - ) with tempfile.NamedTemporaryFile( prefix=model_name, suffix=".pt" @@ -1015,13 +1002,6 @@ class TestFxToOnnxFakeTensorWithOnnxRuntime(onnx_test_common._TestONNXRuntime): ) if export_within_fake_mode: - if ( - model_type - == pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM - ): - fake_model = torch.export.export( - fake_model, args=fake_args, kwargs=fake_kwargs - ) onnx_program = torch.onnx.dynamo_export( fake_model, *fake_args, @@ -1030,13 +1010,6 @@ class TestFxToOnnxFakeTensorWithOnnxRuntime(onnx_test_common._TestONNXRuntime): ) if not export_within_fake_mode: - if ( - model_type - == pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM - ): - fake_model = torch.export.export( - fake_model, args=fake_args, kwargs=fake_kwargs - ) onnx_program = torch.onnx.dynamo_export( fake_model, *fake_args, **fake_kwargs, export_options=export_options ) @@ -1093,10 +1066,6 @@ class TestFxToOnnxFakeTensorWithOnnxRuntime(onnx_test_common._TestONNXRuntime): for ref_output, ort_output in zip(ref_outputs, ort_outputs): torch.testing.assert_close(ref_output, torch.tensor(ort_output)) - @pytorch_test_common.skip_dynamic_fx_test( - reason="Dynamic shape check is not expected for exported program in this test suite.", - model_type=pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM, - ) def test_fake_tensor_mode_simple(self): def create_model() -> nn.Module: class Model(torch.nn.Module): 
@@ -1126,10 +1095,6 @@ class TestFxToOnnxFakeTensorWithOnnxRuntime(onnx_test_common._TestONNXRuntime): model_type=self.model_type, ) - @pytorch_test_common.skip_dynamic_fx_test( - reason="Dynamic shape check is not expected for exported program in this test suite.", - model_type=pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM, - ) @pytorch_test_common.xfail_dynamic_fx_test( error_message="!(it.GetName().empty())", reason="With after onnx==1.16, constant folding in optimizer causes this error.", @@ -1166,10 +1131,6 @@ class TestFxToOnnxFakeTensorWithOnnxRuntime(onnx_test_common._TestONNXRuntime): model_type=self.model_type, ) - @pytorch_test_common.skip_dynamic_fx_test( - reason="Dynamic shape check is not expected for exported program in this test suite.", - model_type=pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM, - ) def test_large_scale_exporter_with_toy_mlp(self): class MLPModel(nn.Module): def __init__(self) -> None: @@ -1208,10 +1169,6 @@ class TestFxToOnnxFakeTensorWithOnnxRuntime(onnx_test_common._TestONNXRuntime): model_type=self.model_type, ) - @pytorch_test_common.skip_dynamic_fx_test( - reason="Dynamic shape check is not expected for exported program in this test suite.", - model_type=pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM, - ) def test_fake_tensor_mode_huggingface_google_t5(self): config = transformers.T5Config( vocab_size=8096, d_model=64, num_layers=2, num_heads=2 @@ -1244,10 +1201,6 @@ class TestFxToOnnxFakeTensorWithOnnxRuntime(onnx_test_common._TestONNXRuntime): model_type=self.model_type, ) - @pytorch_test_common.skip_dynamic_fx_test( - reason="Dynamic shape check is not expected for exported program in this test suite.", - model_type=pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM, - ) @pytorch_test_common.xfail_dynamic_fx_test( error_message="scaled_dot_product_attention(): argument 'is_causal' must be bool, not SymBool", reason="Dynamo error: 
scaled_dot_product_attention(): argument 'is_causal' must be bool, not SymBool", @@ -1310,10 +1263,6 @@ class TestFxToOnnxFakeTensorWithOnnxRuntime(onnx_test_common._TestONNXRuntime): model_type=self.model_type, ) - @pytorch_test_common.skip_dynamic_fx_test( - reason="Dynamic shape check is not expected for exported program in this test suite.", - model_type=pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM, - ) def test_fake_tensor_mode_huggingface_mosaicml_mpt(self): config = transformers.MptConfig( vocab_size=8096, d_model=64, n_heads=2, n_layers=3 @@ -1341,10 +1290,6 @@ class TestFxToOnnxFakeTensorWithOnnxRuntime(onnx_test_common._TestONNXRuntime): model_type=self.model_type, ) - @pytorch_test_common.skip_dynamic_fx_test( - reason="Dynamic shape check is not expected for exported program in this test suite.", - model_type=pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM, - ) @pytorch_test_common.xfail_dynamic_fx_test( error_message="SymIntArrayRef expected to contain only concrete integers", model_type=pytorch_test_common.TorchModelType.TORCH_NN_MODULE, @@ -1374,10 +1319,6 @@ class TestFxToOnnxFakeTensorWithOnnxRuntime(onnx_test_common._TestONNXRuntime): model_type=self.model_type, ) - @pytorch_test_common.skip_dynamic_fx_test( - reason="Dynamic shape check is not expected for exported program in this test suite.", - model_type=pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM, - ) @pytorch_test_common.xfail_if_model_type_is_not_exportedprogram( error_message="Expected 5 inputs, got 3", reason="https://github.com/pytorch/pytorch/issues/115745", @@ -1417,10 +1358,6 @@ class TestFxToOnnxFakeTensorWithOnnxRuntime(onnx_test_common._TestONNXRuntime): model_type=self.model_type, ) - @pytorch_test_common.skip_dynamic_fx_test( - reason="Dynamic shape check is not expected for exported program in this test suite.", - model_type=pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM, - ) 
@pytorch_test_common.xfail_dynamic_fx_test( error_message="SymIntArrayRef expected to contain only concrete integers", model_type=pytorch_test_common.TorchModelType.TORCH_NN_MODULE, diff --git a/test/onnx/torch_export/test_torch_export_with_onnxruntime.py b/test/onnx/torch_export/test_torch_export_with_onnxruntime.py index 7e3b24874e00..7e7cf24a84e8 100644 --- a/test/onnx/torch_export/test_torch_export_with_onnxruntime.py +++ b/test/onnx/torch_export/test_torch_export_with_onnxruntime.py @@ -36,14 +36,15 @@ class TestFxToOnnxWithOnnxRuntime(onnx_test_common._TestONNXRuntime): torch_outputs = torch_exported_program.module()(*input_args, **input_kwargs) else: torch_outputs = torch_exported_program(*input_args, **input_kwargs) - torch_outputs_onnx_format = onnx_exported_program.adapt_torch_outputs_to_onnx( - torch_outputs - ) - if len(torch_outputs_onnx_format) != len(onnx_outputs): + + if isinstance(torch_outputs, torch.Tensor): + torch_outputs = [torch_outputs] + + if len(torch_outputs) != len(onnx_outputs): raise AssertionError( - f"Expected {len(torch_outputs_onnx_format)} outputs, got {len(onnx_outputs)}" + f"Expected {len(torch_outputs)} outputs, got {len(onnx_outputs)}" ) - for torch_output, onnx_output in zip(torch_outputs_onnx_format, onnx_outputs): + for torch_output, onnx_output in zip(torch_outputs, onnx_outputs): torch.testing.assert_close( torch_output, torch.tensor(onnx_output), rtol=rtol, atol=atol ) diff --git a/torch/onnx/__init__.py b/torch/onnx/__init__.py index 9d4084da1ac4..963766a33e3b 100644 --- a/torch/onnx/__init__.py +++ b/torch/onnx/__init__.py @@ -54,7 +54,7 @@ __all__ = [ "is_onnxrt_backend_supported", ] -from typing import Any, Collection, Mapping, Sequence, TYPE_CHECKING +from typing import Any, Callable, Collection, Mapping, Sequence, TYPE_CHECKING import torch from torch import _C @@ -112,7 +112,6 @@ from ._internal._exporter_legacy import ( # usort: skip. 
needs to be last to av InvalidExportOptionsError, OnnxExporterError, OnnxRegistry, - dynamo_export, enable_fake_mode, ) @@ -126,7 +125,6 @@ JitScalarType.__module__ = "torch.onnx" ExportOptions.__module__ = "torch.onnx" ONNXProgram.__module__ = "torch.onnx" ONNXRuntimeOptions.__module__ = "torch.onnx" -dynamo_export.__module__ = "torch.onnx" InvalidExportOptionsError.__module__ = "torch.onnx" OnnxExporterError.__module__ = "torch.onnx" enable_fake_mode.__module__ = "torch.onnx" @@ -393,6 +391,131 @@ def export( return None +def dynamo_export( + model: torch.nn.Module | Callable | torch.export.ExportedProgram, # type: ignore[name-defined] + /, + *model_args, + export_options: ExportOptions | None = None, + **model_kwargs, +) -> ONNXProgram | Any: + """Export a torch.nn.Module to an ONNX graph. + + Args: + model: The PyTorch model to be exported to ONNX. + model_args: Positional inputs to ``model``. + model_kwargs: Keyword inputs to ``model``. + export_options: Options to influence the export to ONNX. + + Returns: + An in-memory representation of the exported ONNX model. + + **Example 1 - Simplest export** + :: + + class MyModel(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.linear = torch.nn.Linear(2, 2) + + def forward(self, x, bias=None): + out = self.linear(x) + out = out + bias + return out + + + model = MyModel() + kwargs = {"bias": 3.0} + args = (torch.randn(2, 2, 2),) + onnx_program = torch.onnx.dynamo_export(model, *args, **kwargs).save( + "my_simple_model.onnx" + ) + + **Example 2 - Exporting with dynamic shapes** + :: + + # The previous model can be exported with dynamic shapes + export_options = torch.onnx.ExportOptions(dynamic_shapes=True) + onnx_program = torch.onnx.dynamo_export( + model, *args, **kwargs, export_options=export_options + ) + onnx_program.save("my_dynamic_model.onnx") + """ + + # NOTE: The new exporter is experimental and is not enabled by default. 
+ import warnings + + from torch.onnx import _flags + from torch.onnx._internal import exporter + from torch.utils import _pytree + + if isinstance(model, torch.export.ExportedProgram): + return exporter.export_compat( + model, # type: ignore[arg-type] + model_args, + f=None, + kwargs=model_kwargs, + opset_version=18, + external_data=True, + export_params=True, + fallback=True, + ) + elif _flags.USE_EXPERIMENTAL_LOGIC: + if export_options is not None: + warnings.warn( + "You are using an experimental ONNX export logic, which currently only supports dynamic shapes. " + "For a more comprehensive set of export options, including advanced features, please consider using " + "`torch.onnx.export(..., dynamo=True)`. ", + category=FutureWarning, + ) + + if export_options is not None and export_options.dynamic_shapes: + # Make all shapes dynamic + def _to_dynamic_shapes_mapper(): + arg_order = 0 + + def _to_dynamic_shape(x): + nonlocal arg_order + if isinstance(x, torch.Tensor): + rank = len(x.shape) + dynamic_shape = {} + for i in range(rank): + dynamic_shape[i] = torch.export.Dim( + f"arg_{arg_order}_dim_{i}" + ) + arg_order += 1 + return dynamic_shape + else: + return None + + return _to_dynamic_shape + + # model_args could be nested + dynamic_shapes = _pytree.tree_map( + _to_dynamic_shapes_mapper(), + model_args, + ) + else: + dynamic_shapes = None + + return exporter.export_compat( + model, # type: ignore[arg-type] + model_args, + f=None, + kwargs=model_kwargs, + dynamic_shapes=dynamic_shapes, + opset_version=18, + external_data=True, + export_params=True, + fallback=True, + ) + else: + from torch.onnx._internal._exporter_legacy import dynamo_export + + return dynamo_export( + model, *model_args, export_options=export_options, **model_kwargs + ) + + # TODO(justinchuby): Deprecate these logging functions in favor of the new diagnostic module. # Returns True iff ONNX logging is turned on. 
diff --git a/torch/onnx/_internal/_exporter_legacy.py b/torch/onnx/_internal/_exporter_legacy.py index 4d1830b3e9ce..04347773abf0 100644 --- a/torch/onnx/_internal/_exporter_legacy.py +++ b/torch/onnx/_internal/_exporter_legacy.py @@ -16,7 +16,6 @@ from typing_extensions import Self import torch import torch._ops -import torch.export as torch_export import torch.utils._pytree as pytree from torch.onnx._internal import io_adapter from torch.onnx._internal.diagnostics import infra @@ -304,27 +303,17 @@ class ResolvedExportOptions(ExportOptions): def __init__( self, options: ExportOptions | ResolvedExportOptions, - model: torch.nn.Module | Callable | torch_export.ExportedProgram | None = None, # type: ignore[name-defined] + model: torch.nn.Module | Callable | None = None, # type: ignore[name-defined] ): from torch.onnx._internal.fx import ( # TODO: Prevent circular dep diagnostics, dynamo_graph_extractor, - torch_export_graph_extractor, ) if isinstance(options, ResolvedExportOptions): self.dynamic_shapes = options.dynamic_shapes self.diagnostic_options = options.diagnostic_options self.fake_context = options.fake_context - # private - if isinstance(model, torch_export.ExportedProgram) and not isinstance( - options.fx_tracer, torch_export_graph_extractor.TorchExport - ): - message = "'model' of type 'ExportedProgram' is only supported with 'TorchExport' FX Tracer" - e = InvalidExportOptionsError(message) - raise InvalidExportOptionsError( - ONNXProgram._from_failure(e, options.diagnostic_context), message - ) self.fx_tracer = options.fx_tracer self.onnx_registry = options.onnx_registry self.onnxfunction_dispatcher = options.onnxfunction_dispatcher @@ -345,10 +334,8 @@ class ResolvedExportOptions(ExportOptions): self.diagnostic_options = resolve( options.diagnostic_options, DiagnosticOptions() ) - if isinstance(model, torch_export.ExportedProgram): - self.fx_tracer = torch_export_graph_extractor.TorchExport() - else: - self.fx_tracer = 
dynamo_graph_extractor.DynamoExport() + + self.fx_tracer = dynamo_graph_extractor.DynamoExport() self.fake_context = resolve(options.fake_context, None) # type: ignore[arg-type] self.diagnostic_context = diagnostics.DiagnosticContext( @@ -492,7 +479,6 @@ class ONNXProgram: diagnostic_context: Context object for the SARIF diagnostic system responsible for logging errors and metadata. fake_context: The fake context used for symbolic tracing. export_exception: The exception that occurred during export, if any. - model_signature: The model signature for the exported ONNX graph. """ _model_proto: Final[onnx.ModelProto] # type: ignore[name-defined, misc] @@ -501,9 +487,8 @@ class ONNXProgram: _diagnostic_context: Final[diagnostics.DiagnosticContext] # type: ignore[misc] _fake_context: Final[ONNXFakeContext | None] # type: ignore[misc] _export_exception: Final[Exception | None] # type: ignore[misc] - _model_signature: Final[torch.export.ExportGraphSignature | None] # type: ignore[misc] _model_torch: Final[ # type: ignore[misc] - torch.nn.Module | Callable | torch_export.ExportedProgram | None + torch.nn.Module | Callable | None ] def __init__( @@ -515,14 +500,9 @@ class ONNXProgram: *, fake_context: ONNXFakeContext | None = None, export_exception: Exception | None = None, - model_signature: torch.export.ExportGraphSignature | None = None, - model_torch: torch.nn.Module - | Callable - | torch_export.ExportedProgram - | None = None, + model_torch: torch.nn.Module | Callable | None = None, ): self._model_proto = model_proto - self._model_signature = model_signature self._model_torch = model_torch self._input_adapter = input_adapter self._output_adapter = output_adapter @@ -533,10 +513,7 @@ class ONNXProgram: def __call__( self, *args: Any, - model_with_state_dict: torch.nn.Module - | Callable - | torch_export.ExportedProgram - | None = None, + model_with_state_dict: torch.nn.Module | Callable | None = None, options: ONNXRuntimeOptions | None = None, **kwargs: Any, ) -> Any: 
@@ -571,8 +548,6 @@ class ONNXProgram: onnx_model = os.path.join(tmpdir_path, "model.onnx") if isinstance(model_with_state_dict, torch.nn.Module): model_state = model_with_state_dict.state_dict() - elif isinstance(model_with_state_dict, torch_export.ExportedProgram): - model_state = model_with_state_dict.state_dict else: model_state = None self.save( @@ -608,104 +583,6 @@ class ONNXProgram: raise self._export_exception return self._model_proto - @property - def model_signature(self) -> torch.export.ExportGraphSignature | None: - """The model signature for the exported ONNX graph. - - This information is relevant because ONNX specification often differs from PyTorch's, resulting - in a ONNX graph with input and output schema different from the actual PyTorch model implementation. - By using the model signature, the users can understand the inputs and outputs differences - and properly execute the model in ONNX Runtime. - - NOTE: Model signature is only available when the ONNX graph was exported from a - :class:`torch.export.ExportedProgram` object. - - NOTE: Any transformation done to the model that changes the model signature must be accompanied - by updates to this model signature as well through :class:`InputAdaptStep` and/or :class:`OutputAdaptStep`. - - Example: - - The following model produces different sets of inputs and outputs. - The first 4 inputs are model parameters (namely conv1.weight, conv2.weight, fc1.weight, fc2.weight), - and the next 2 inputs are registered buffers (namely my_buffer2, my_buffer1) and finally - the last 2 inputs are user inputs (namely x and b). - The first output is a buffer mutation (namely my_buffer2) and the last output is the actual model output. - - >>> import pprint - >>> class CustomModule(torch.nn.Module): - ... def __init__(self) -> None: - ... super().__init__() - ... self.my_parameter = torch.nn.Parameter(torch.tensor(2.0)) - ... self.register_buffer("my_buffer1", torch.tensor(3.0)) - ... 
self.register_buffer("my_buffer2", torch.tensor(4.0)) - ... self.conv1 = torch.nn.Conv2d(1, 32, 3, 1, bias=False) - ... self.conv2 = torch.nn.Conv2d(32, 64, 3, 1, bias=False) - ... self.fc1 = torch.nn.Linear(9216, 128, bias=False) - ... self.fc2 = torch.nn.Linear(128, 10, bias=False) - ... - ... def forward(self, x, b): - ... tensor_x = self.conv1(x) - ... tensor_x = torch.nn.functional.sigmoid(tensor_x) - ... tensor_x = self.conv2(tensor_x) - ... tensor_x = torch.nn.functional.sigmoid(tensor_x) - ... tensor_x = torch.nn.functional.max_pool2d(tensor_x, 2) - ... tensor_x = torch.flatten(tensor_x, 1) - ... tensor_x = self.fc1(tensor_x) - ... tensor_x = torch.nn.functional.sigmoid(tensor_x) - ... tensor_x = self.fc2(tensor_x) - ... output = torch.nn.functional.log_softmax(tensor_x, dim=1) - ... ( - ... self.my_buffer2.add_(1.0) + self.my_buffer1 - ... ) # Mutate buffer through in-place addition - ... return output - >>> inputs = (torch.rand((64, 1, 28, 28), dtype=torch.float32), torch.randn(3)) - >>> exported_program = torch.export.export( - ... CustomModule(), args=inputs - ... 
).run_decompositions({}) - >>> onnx_program = torch.onnx.dynamo_export(exported_program, *inputs) - >>> pprint.pprint(onnx_program.model_signature) - ExportGraphSignature(input_specs=[InputSpec(kind=, - arg=TensorArgument(name='p_conv1_weight'), - target='conv1.weight', - persistent=None), - InputSpec(kind=, - arg=TensorArgument(name='p_conv2_weight'), - target='conv2.weight', - persistent=None), - InputSpec(kind=, - arg=TensorArgument(name='p_fc1_weight'), - target='fc1.weight', - persistent=None), - InputSpec(kind=, - arg=TensorArgument(name='p_fc2_weight'), - target='fc2.weight', - persistent=None), - InputSpec(kind=, - arg=TensorArgument(name='b_my_buffer2'), - target='my_buffer2', - persistent=True), - InputSpec(kind=, - arg=TensorArgument(name='b_my_buffer1'), - target='my_buffer1', - persistent=True), - InputSpec(kind=, - arg=TensorArgument(name='x'), - target=None, - persistent=None), - InputSpec(kind=, - arg=TensorArgument(name='b'), - target=None, - persistent=None)], - output_specs=[OutputSpec(kind=, - arg=TensorArgument(name='add'), - target='my_buffer2'), - OutputSpec(kind=, - arg=TensorArgument(name='_log_softmax'), - target=None)]) - """ - - return self._model_signature - @property def diagnostic_context(self) -> diagnostics.DiagnosticContext: """The diagnostic context associated with the export.""" @@ -721,10 +598,7 @@ class ONNXProgram: def adapt_torch_inputs_to_onnx( self, *model_args, - model_with_state_dict: torch.nn.Module - | Callable - | torch_export.ExportedProgram - | None = None, + model_with_state_dict: torch.nn.Module | Callable | None = None, **model_kwargs, ) -> Sequence[torch.Tensor | int | float | bool | torch.dtype]: """Converts the PyTorch model inputs to exported ONNX model inputs format. 
@@ -794,10 +668,7 @@ class ONNXProgram: def adapt_torch_outputs_to_onnx( self, model_outputs: Any, - model_with_state_dict: torch.nn.Module - | Callable - | torch_export.ExportedProgram - | None = None, + model_with_state_dict: torch.nn.Module | Callable | None = None, ) -> Sequence[torch.Tensor | int | float | bool]: """Converts the PyTorch model outputs to exported ONNX model outputs format. @@ -1050,7 +921,7 @@ class Exporter: def __init__( self, options: ResolvedExportOptions, - model: torch.nn.Module | Callable | torch_export.ExportedProgram, + model: torch.nn.Module | Callable, model_args: Sequence[Any], model_kwargs: Mapping[str, Any], ): @@ -1138,9 +1009,6 @@ class Exporter: self.options.fx_tracer.output_adapter, self.options.diagnostic_context, fake_context=self.options.fake_context, - model_signature=getattr( - self.model, "graph_signature", None - ), # Available for isinstance(self.model, ExportedProgram) only model_torch=self.model, ) @@ -1261,12 +1129,12 @@ def _assert_dependencies(export_options: ResolvedExportOptions): def dynamo_export( - model: torch.nn.Module | Callable | torch_export.ExportedProgram, # type: ignore[name-defined] + model: torch.nn.Module | Callable, /, *model_args, export_options: ExportOptions | None = None, **model_kwargs, -) -> ONNXProgram: +) -> ONNXProgram | Any: """Export a torch.nn.Module to an ONNX graph. 
Args: diff --git a/torch/onnx/_internal/exporter/_capture_strategies.py b/torch/onnx/_internal/exporter/_capture_strategies.py index dc511491d6b4..8b908bc35e93 100644 --- a/torch/onnx/_internal/exporter/_capture_strategies.py +++ b/torch/onnx/_internal/exporter/_capture_strategies.py @@ -120,9 +120,20 @@ class TorchExportStrategy(CaptureStrategy): def _capture( self, model, args, kwargs, dynamic_shapes ) -> torch.export.ExportedProgram: - return torch.export.export( - model, args, kwargs=kwargs, dynamic_shapes=dynamic_shapes - ) + try: + return torch.export.export( + model, args, kwargs=kwargs, dynamic_shapes=dynamic_shapes + ) + except torch._dynamo.exc.UserError as exc: + # Refine the dynamic shapes based on the suggested fixes. + new_shapes = ( + torch.export.dynamic_shapes.refine_dynamic_shapes_from_suggested_fixes( + exc.msg, dynamic_shapes + ) + ) + return torch.export.export( + model, args, kwargs=kwargs, dynamic_shapes=new_shapes + ) def _enter(self, model) -> None: model_repr = _take_first_line(repr(model)) @@ -148,9 +159,20 @@ class TorchExportNonStrictStrategy(CaptureStrategy): def _capture( self, model, args, kwargs, dynamic_shapes ) -> torch.export.ExportedProgram: - return torch.export.export( - model, args, kwargs=kwargs, dynamic_shapes=dynamic_shapes, strict=False - ) + try: + return torch.export.export( + model, args, kwargs=kwargs, dynamic_shapes=dynamic_shapes, strict=False + ) + except torch._dynamo.exc.UserError as exc: + # Refine the dynamic shapes based on the suggested fixes. 
+ new_shapes = ( + torch.export.dynamic_shapes.refine_dynamic_shapes_from_suggested_fixes( + exc.msg, dynamic_shapes + ) + ) + return torch.export.export( + model, args, kwargs=kwargs, dynamic_shapes=new_shapes, strict=False + ) def _enter(self, model) -> None: model_repr = _take_first_line(repr(model)) diff --git a/torch/onnx/_internal/exporter/_compat.py b/torch/onnx/_internal/exporter/_compat.py index 72d2a411c980..3fddef36b8b4 100644 --- a/torch/onnx/_internal/exporter/_compat.py +++ b/torch/onnx/_internal/exporter/_compat.py @@ -9,7 +9,6 @@ import logging from typing import Any, Mapping, Sequence, TYPE_CHECKING import torch -import torch.export from torch.onnx._internal._lazy_import import onnxscript_apis, onnxscript_ir as ir from torch.onnx._internal.exporter import _core, _onnx_program diff --git a/torch/onnx/_internal/fx/_pass.py b/torch/onnx/_internal/fx/_pass.py index 388ae29cb699..5246788756f3 100644 --- a/torch/onnx/_internal/fx/_pass.py +++ b/torch/onnx/_internal/fx/_pass.py @@ -176,7 +176,7 @@ class Transform(abc.ABC): One important aspect to note is that if the transformation modifies the model input and/or output signature, (e.g. additional inputs/outputs are added to the model), :class:`InputAdaptStep` and/or :class:`OutputAdaptStep` - are needed to reconcile :attr:`ONNXProgram.model_signature` and :attr:`ONNXProgram.model_proto`. + are needed to reconcile :attr:`ONNXProgram.model_proto`. That is, the model signature and the model representation must match. As an additional feature, this class provides builtin support for transformation recording using the diagnostics. 
diff --git a/torch/onnx/_internal/fx/torch_export_graph_extractor.py b/torch/onnx/_internal/fx/torch_export_graph_extractor.py deleted file mode 100644 index aff9c154cd9e..000000000000 --- a/torch/onnx/_internal/fx/torch_export_graph_extractor.py +++ /dev/null @@ -1,128 +0,0 @@ -# mypy: allow-untyped-defs -# NOTE: This file is referenced by name at -# /opt/pytorch/torch/_dynamo/eval_frame.py::DONT_WRAP_FILES. -# introduced by https://github.com/pytorch/pytorch/pull/98894. -# If this file is renamed, moved, etc please update the reference there! - -from __future__ import annotations - -from typing import Any, Callable, Mapping, Sequence, TYPE_CHECKING - -import torch._dynamo -import torch.fx -from torch.onnx._internal import _exporter_legacy, io_adapter -from torch.onnx._internal.diagnostics import infra - - -if TYPE_CHECKING: - import torch.onnx - from torch.export.exported_program import ExportedProgram - - -class TorchExport(_exporter_legacy.FXGraphExtractor): - """Generates a FX GraphModule using torch.export API - Args: - aten_graph: If True, exports a graph with ATen operators. - If False, exports a graph with Python operators. - """ - - def __init__( - self, - aten_graph: bool | None = None, - ): - super().__init__() - self.aten_graph = aten_graph or True - - def generate_fx( - self, - options: _exporter_legacy.ResolvedExportOptions, - model: ExportedProgram, # type: ignore[override] - model_args: Sequence[Any], - model_kwargs: Mapping[str, Any], - ) -> torch.fx.GraphModule: - # No need to translate callable to FX graph. - # This FX Graph extractor assumes `model` was obtained through - # exported_program = torch.export.export( - # model, - # args=model_args, # type: ignore[arg-type] - # kwargs=model_kwargs, # type: ignore[arg-type] - # ) - - # Export FX graph to ONNX ModelProto. 
- self.input_adapter.append_step( - io_adapter.FlattenInputWithTreeSpecValidationInputStep() - ) - self.input_adapter.append_step( - io_adapter.PrependParamsBuffersConstantAotAutogradInputStep() - ) - - # ONNX does not support None inputs. During graph building, all None inputs - # are removed. Here we register this step to input adapter. - options.fx_tracer.input_adapter.append_step(io_adapter.RemoveNoneInputStep()) - - # NOTE: temp workaround for https://github.com/pytorch/pytorch/issues/99534 - # Dynamo doesn't support non-tensor inputs. - options.fx_tracer.input_adapter.append_step( - io_adapter.RemoveNonTensorInputStep() - ) - - # ONNX does not support complex inputs. During graph building, all complex inputs - # are converted to real representation inputs. Here we register this step to - # input/output adapter. - options.fx_tracer.input_adapter.append_step( - io_adapter.ConvertComplexToRealRepresentationInputStep() - ) - - updated_model_args = self.input_adapter.apply( - *model_args, model=model, **model_kwargs - ) - - # ONNX can't represent collection types (e.g., dictionary, tuple of tuple of - # tensor, etc), we flatten the collection and register each element as output. - options.fx_tracer.output_adapter.append_step(io_adapter.FlattenOutputStep()) - - # Output post-processing steps should happen after `FlattenOutputStep`. - options.fx_tracer.output_adapter.append_step( - io_adapter.ConvertComplexToRealRepresentationOutputStep() - ) - - options.fx_tracer.output_adapter.append_step( - io_adapter.PrependParamsAndBuffersAotAutogradOutputStep() - ) - - # run_decomposition generates a new graph module with decomposed ops. - # Thus, we need to run this step after io_adapters. - model = model.run_decompositions(options.decomposition_table) - - # Export FX graph to ONNX ModelProto. 
- return self.pre_export_passes( # type: ignore[return-value] - options, model, model.graph_module, updated_model_args - ) - - def pre_export_passes( - self, - options: _exporter_legacy.ResolvedExportOptions, - original_model: torch.nn.Module | Callable, - fx_module: torch.fx.GraphModule, - fx_module_args: Sequence[Any], - ): - # TODO: Import here to prevent circular dependency - from torch.onnx._internal.fx import analysis, passes - - diagnostic_context = options.diagnostic_context - - # ONNX does not support concept of (implicit) type promotion. - # Insert type casts explicitly where needed. - fx_module = passes.InsertTypePromotion(diagnostic_context, fx_module).run() - - analysis.UnsupportedFxNodesAnalysis( - diagnostic_context, fx_module, options.onnxfunction_dispatcher - ).analyze(infra.levels.ERROR) - - # This operation should be invoked as the last pre export pass. - # See [NOTE: Modularize pass ordering] - fx_module = passes.Modularize( - diagnostic_context, fx_module, is_exported_program=True - ).run() - - return fx_module