[ONNX] Enable experimental exporter logic to dynamo_export and support refine dynamic_shapes (#134976)

(1) Enable experimental exporter logic to dynamo_export
(2) Refine dynamic shapes and retry export in export strategies
(3) Delete `torch_export_graph_extractor` and use the new export logic
(4) Disable ExportedProgram test in `test_fx_onnx_with_onnxruntime.py`, as ONNXProgram is different now.

Fixes https://github.com/pytorch/pytorch/issues/126479
Fixes #135183
Pull Request resolved: https://github.com/pytorch/pytorch/pull/134976
Approved by: https://github.com/justinchuby
This commit is contained in:
titaiwangms
2024-09-06 01:29:54 +00:00
committed by PyTorch MergeBot
parent 1e57ef08fa
commit 8f6e73f068
12 changed files with 246 additions and 411 deletions

View File

@ -148,6 +148,54 @@ class TestExportAPIDynamo(common_utils.TestCase):
},
)
def test_auto_convert_all_axes_to_dynamic_shapes_with_dynamo_export(self):
os.environ["TORCH_ONNX_USE_EXPERIMENTAL_LOGIC"] = "1"
assert os.environ.get("TORCH_ONNX_USE_EXPERIMENTAL_LOGIC") == "1"
class Nested(torch.nn.Module):
def forward(self, x):
(a0, a1), (b0, b1), (c0, c1, c2) = x
return a0 + a1 + b0 + b1 + c0 + c1 + c2
inputs = (
(1, 2),
(
torch.randn(4, 4),
torch.randn(4, 4),
),
(
torch.randn(4, 4),
torch.randn(4, 4),
torch.randn(4, 4),
),
)
onnx_program = torch.onnx.dynamo_export(
Nested(),
inputs,
export_options=torch.onnx.ExportOptions(dynamic_shapes=True),
)
assert onnx_program is not None
onnx_testing.assert_onnx_program(onnx_program)
def test_refine_dynamic_shapes_with_onnx_export(self):
# NOTE: From test/export/test_export.py
# refine lower, upper bound
class TestRefineDynamicShapeModel(torch.nn.Module):
def forward(self, x, y):
if x.shape[0] >= 6 and y.shape[0] <= 16:
return x * 2.0, y + 1
inps = (torch.randn(16), torch.randn(12))
dynamic_shapes = {
"x": (torch.export.Dim("dx"),),
"y": (torch.export.Dim("dy"),),
}
self.assert_export(
TestRefineDynamicShapeModel(), inps, dynamic_shapes=dynamic_shapes
)
if __name__ == "__main__":
common_utils.run_tests()

View File

@ -346,6 +346,28 @@ def skipDtypeChecking(func):
return wrapper
def skip_if_fake_model_and_inititalizer(reason: Optional[str] = None):
"""skip test with models using ExportedProgram as input.
Args:
reason: The reason for skip the ONNX export test.
Returns:
A decorator for skip tests.
"""
def skip_dec(func):
@functools.wraps(func)
def wrapper(self, *args, **kwargs):
if kwargs["use_fake_mode"] and kwargs["include_initializer"]:
return unittest.SkipTest(reason)
return func(self, *args, **kwargs)
return wrapper
return skip_dec
def xfail_if_model_type_is_exportedprogram(
error_message: str, reason: Optional[str] = None
):

View File

@ -544,34 +544,6 @@ class TestFxToOnnx(pytorch_test_common.ExportTestCase):
onnx.checker.check_model(onnx_program.model_proto)
onnx.shape_inference.infer_shapes(onnx_program.model_proto)
def test_exported_program_input_with_custom_fx_tracer(self):
from torch.onnx._internal import _exporter_legacy
from torch.onnx._internal.fx import dynamo_graph_extractor
class Model(torch.nn.Module):
def forward(self, x):
return x + 1
x = torch.randn(1, 1, 2)
exported_program = torch.export.export(Model(), args=(x,))
export_options = torch.onnx.ExportOptions()
export_options = _exporter_legacy.ResolvedExportOptions(
export_options, model=exported_program
)
export_options.fx_tracer = (
dynamo_graph_extractor.DynamoExport()
) # Override fx_tracer to an unsupported tracer
with self.assertRaises(torch.onnx.OnnxExporterError):
onnx_program = torch.onnx.dynamo_export(
exported_program,
x,
export_options=export_options,
)
self.assertTrue(onnx_program._export_exception is not None)
with self.assertRaises(torch.onnx.InvalidExportOptionsError):
raise self._export_exception
def test_exported_program_torch_distributions_normal_Normal(self):
class Model(torch.nn.Module):
def __init__(self) -> None:
@ -606,21 +578,6 @@ class TestFxToOnnx(pytorch_test_common.ExportTestCase):
# with no Cast node in between.
self.assertEqual(div_node.input[0], model_proto.graph.input[0].name)
def test_exported_program_as_input_with_model_signature(self):
class Model(torch.nn.Module):
def forward(self, x):
return x + 1.0
x = torch.randn(1, 1, 2, dtype=torch.float)
exported_program = torch.export.export(Model(), args=(x,))
onnx_program = torch.onnx.dynamo_export(
exported_program,
x,
)
self.assertTrue(onnx_program.model_signature, torch.export.ExportGraphSignature)
@common_utils.parametrize(
"float8_type",
[
@ -707,6 +664,7 @@ class TestFxToOnnx(pytorch_test_common.ExportTestCase):
onnx_program.save(tmp_onnx_file.name)
onnx.checker.check_model(tmp_onnx_file.name, full_check=True)
@pytorch_test_common.skip_if_fake_model_and_inititalizer("segfault")
@common_utils.parametrize(
"include_initializer",
[

View File

@ -19,12 +19,6 @@ def assert_op_in_onnx_model(model: onnx.ModelProto, op_type: str):
class TestDynamoExportDecompSkip(pytorch_test_common.ExportTestCase):
def _test_exported_program_forces_decomposition(self, model, input, op_type):
ep = torch.export.export(model, input)
onnx_program = torch.onnx.dynamo_export(ep, *input)
with self.assertRaises(AssertionError):
assert_op_in_onnx_model(onnx_program.model_proto, op_type)
def test_upsample_bilinear2d(self):
class TestModel(torch.nn.Module):
def __init__(self) -> None:
@ -37,9 +31,6 @@ class TestDynamoExportDecompSkip(pytorch_test_common.ExportTestCase):
onnx_program = torch.onnx.dynamo_export(TestModel(), torch.randn(1, 1, 2, 2))
# If decomposition is skipped, the model will contain a Resize op instead of fine grained subgraph.
assert_op_in_onnx_model(onnx_program.model_proto, "Resize")
self._test_exported_program_forces_decomposition(
TestModel(), (torch.randn(1, 1, 2, 2),), "Resize"
)
def test_upsample_bilinear2d_output_size(self):
def func(x: torch.Tensor):
@ -61,9 +52,6 @@ class TestDynamoExportDecompSkip(pytorch_test_common.ExportTestCase):
onnx_program = torch.onnx.dynamo_export(TestModel(), torch.randn(1, 1, 2, 2, 3))
# If decomposition is skipped, the model will contain a Resize op instead of fine grained subgraph.
assert_op_in_onnx_model(onnx_program.model_proto, "Resize")
self._test_exported_program_forces_decomposition(
TestModel(), (torch.randn(1, 1, 2, 2, 3),), "Resize"
)
def test_upsample_trilinear3d_output_size(self):
def func(x: torch.Tensor):
@ -82,9 +70,6 @@ class TestDynamoExportDecompSkip(pytorch_test_common.ExportTestCase):
# If decomposition is skipped, the model will contain an InstanceNormalization op
# instead of BatchNormalization op w/ training=True.
assert_op_in_onnx_model(onnx_program.model_proto, "InstanceNormalization")
self._test_exported_program_forces_decomposition(
TestModel(), (torch.randn(1, 1, 2, 2),), "InstanceNormalization"
)
if __name__ == "__main__":

View File

@ -45,10 +45,7 @@ def _parameterized_class_attrs_and_values():
input_values.extend(
itertools.product(
(True, False),
(
pytorch_test_common.TorchModelType.TORCH_NN_MODULE,
pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM,
),
(pytorch_test_common.TorchModelType.TORCH_NN_MODULE,),
)
)
return {
@ -912,10 +909,7 @@ def _parameterized_class_attrs_and_values_with_fake_options():
(True, False),
(True, False),
(True, False),
(
pytorch_test_common.TorchModelType.TORCH_NN_MODULE,
pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM,
),
(pytorch_test_common.TorchModelType.TORCH_NN_MODULE,),
)
)
return {
@ -986,13 +980,6 @@ class TestFxToOnnxFakeTensorWithOnnxRuntime(onnx_test_common._TestONNXRuntime):
# Create the toy model with real weight.
real_model = create_model()
state_dict = real_model.state_dict() # concrete (non-fake) state_dict
if (
model_type
== pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM
):
real_model = torch.export.export(
real_model, args=create_args(), kwargs=create_kwargs()
)
with tempfile.NamedTemporaryFile(
prefix=model_name, suffix=".pt"
@ -1015,13 +1002,6 @@ class TestFxToOnnxFakeTensorWithOnnxRuntime(onnx_test_common._TestONNXRuntime):
)
if export_within_fake_mode:
if (
model_type
== pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM
):
fake_model = torch.export.export(
fake_model, args=fake_args, kwargs=fake_kwargs
)
onnx_program = torch.onnx.dynamo_export(
fake_model,
*fake_args,
@ -1030,13 +1010,6 @@ class TestFxToOnnxFakeTensorWithOnnxRuntime(onnx_test_common._TestONNXRuntime):
)
if not export_within_fake_mode:
if (
model_type
== pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM
):
fake_model = torch.export.export(
fake_model, args=fake_args, kwargs=fake_kwargs
)
onnx_program = torch.onnx.dynamo_export(
fake_model, *fake_args, **fake_kwargs, export_options=export_options
)
@ -1093,10 +1066,6 @@ class TestFxToOnnxFakeTensorWithOnnxRuntime(onnx_test_common._TestONNXRuntime):
for ref_output, ort_output in zip(ref_outputs, ort_outputs):
torch.testing.assert_close(ref_output, torch.tensor(ort_output))
@pytorch_test_common.skip_dynamic_fx_test(
reason="Dynamic shape check is not expected for exported program in this test suite.",
model_type=pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM,
)
def test_fake_tensor_mode_simple(self):
def create_model() -> nn.Module:
class Model(torch.nn.Module):
@ -1126,10 +1095,6 @@ class TestFxToOnnxFakeTensorWithOnnxRuntime(onnx_test_common._TestONNXRuntime):
model_type=self.model_type,
)
@pytorch_test_common.skip_dynamic_fx_test(
reason="Dynamic shape check is not expected for exported program in this test suite.",
model_type=pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM,
)
@pytorch_test_common.xfail_dynamic_fx_test(
error_message="!(it.GetName().empty())",
reason="With after onnx==1.16, constant folding in optimizer causes this error.",
@ -1166,10 +1131,6 @@ class TestFxToOnnxFakeTensorWithOnnxRuntime(onnx_test_common._TestONNXRuntime):
model_type=self.model_type,
)
@pytorch_test_common.skip_dynamic_fx_test(
reason="Dynamic shape check is not expected for exported program in this test suite.",
model_type=pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM,
)
def test_large_scale_exporter_with_toy_mlp(self):
class MLPModel(nn.Module):
def __init__(self) -> None:
@ -1208,10 +1169,6 @@ class TestFxToOnnxFakeTensorWithOnnxRuntime(onnx_test_common._TestONNXRuntime):
model_type=self.model_type,
)
@pytorch_test_common.skip_dynamic_fx_test(
reason="Dynamic shape check is not expected for exported program in this test suite.",
model_type=pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM,
)
def test_fake_tensor_mode_huggingface_google_t5(self):
config = transformers.T5Config(
vocab_size=8096, d_model=64, num_layers=2, num_heads=2
@ -1244,10 +1201,6 @@ class TestFxToOnnxFakeTensorWithOnnxRuntime(onnx_test_common._TestONNXRuntime):
model_type=self.model_type,
)
@pytorch_test_common.skip_dynamic_fx_test(
reason="Dynamic shape check is not expected for exported program in this test suite.",
model_type=pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM,
)
@pytorch_test_common.xfail_dynamic_fx_test(
error_message="scaled_dot_product_attention(): argument 'is_causal' must be bool, not SymBool",
reason="Dynamo error: scaled_dot_product_attention(): argument 'is_causal' must be bool, not SymBool",
@ -1310,10 +1263,6 @@ class TestFxToOnnxFakeTensorWithOnnxRuntime(onnx_test_common._TestONNXRuntime):
model_type=self.model_type,
)
@pytorch_test_common.skip_dynamic_fx_test(
reason="Dynamic shape check is not expected for exported program in this test suite.",
model_type=pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM,
)
def test_fake_tensor_mode_huggingface_mosaicml_mpt(self):
config = transformers.MptConfig(
vocab_size=8096, d_model=64, n_heads=2, n_layers=3
@ -1341,10 +1290,6 @@ class TestFxToOnnxFakeTensorWithOnnxRuntime(onnx_test_common._TestONNXRuntime):
model_type=self.model_type,
)
@pytorch_test_common.skip_dynamic_fx_test(
reason="Dynamic shape check is not expected for exported program in this test suite.",
model_type=pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM,
)
@pytorch_test_common.xfail_dynamic_fx_test(
error_message="SymIntArrayRef expected to contain only concrete integers",
model_type=pytorch_test_common.TorchModelType.TORCH_NN_MODULE,
@ -1374,10 +1319,6 @@ class TestFxToOnnxFakeTensorWithOnnxRuntime(onnx_test_common._TestONNXRuntime):
model_type=self.model_type,
)
@pytorch_test_common.skip_dynamic_fx_test(
reason="Dynamic shape check is not expected for exported program in this test suite.",
model_type=pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM,
)
@pytorch_test_common.xfail_if_model_type_is_not_exportedprogram(
error_message="Expected 5 inputs, got 3",
reason="https://github.com/pytorch/pytorch/issues/115745",
@ -1417,10 +1358,6 @@ class TestFxToOnnxFakeTensorWithOnnxRuntime(onnx_test_common._TestONNXRuntime):
model_type=self.model_type,
)
@pytorch_test_common.skip_dynamic_fx_test(
reason="Dynamic shape check is not expected for exported program in this test suite.",
model_type=pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM,
)
@pytorch_test_common.xfail_dynamic_fx_test(
error_message="SymIntArrayRef expected to contain only concrete integers",
model_type=pytorch_test_common.TorchModelType.TORCH_NN_MODULE,

View File

@ -36,14 +36,15 @@ class TestFxToOnnxWithOnnxRuntime(onnx_test_common._TestONNXRuntime):
torch_outputs = torch_exported_program.module()(*input_args, **input_kwargs)
else:
torch_outputs = torch_exported_program(*input_args, **input_kwargs)
torch_outputs_onnx_format = onnx_exported_program.adapt_torch_outputs_to_onnx(
torch_outputs
)
if len(torch_outputs_onnx_format) != len(onnx_outputs):
if isinstance(torch_outputs, torch.Tensor):
torch_outputs = [torch_outputs]
if len(torch_outputs) != len(onnx_outputs):
raise AssertionError(
f"Expected {len(torch_outputs_onnx_format)} outputs, got {len(onnx_outputs)}"
f"Expected {len(torch_outputs)} outputs, got {len(onnx_outputs)}"
)
for torch_output, onnx_output in zip(torch_outputs_onnx_format, onnx_outputs):
for torch_output, onnx_output in zip(torch_outputs, onnx_outputs):
torch.testing.assert_close(
torch_output, torch.tensor(onnx_output), rtol=rtol, atol=atol
)

View File

@ -54,7 +54,7 @@ __all__ = [
"is_onnxrt_backend_supported",
]
from typing import Any, Collection, Mapping, Sequence, TYPE_CHECKING
from typing import Any, Callable, Collection, Mapping, Sequence, TYPE_CHECKING
import torch
from torch import _C
@ -112,7 +112,6 @@ from ._internal._exporter_legacy import ( # usort: skip. needs to be last to av
InvalidExportOptionsError,
OnnxExporterError,
OnnxRegistry,
dynamo_export,
enable_fake_mode,
)
@ -126,7 +125,6 @@ JitScalarType.__module__ = "torch.onnx"
ExportOptions.__module__ = "torch.onnx"
ONNXProgram.__module__ = "torch.onnx"
ONNXRuntimeOptions.__module__ = "torch.onnx"
dynamo_export.__module__ = "torch.onnx"
InvalidExportOptionsError.__module__ = "torch.onnx"
OnnxExporterError.__module__ = "torch.onnx"
enable_fake_mode.__module__ = "torch.onnx"
@ -393,6 +391,131 @@ def export(
return None
def dynamo_export(
model: torch.nn.Module | Callable | torch.export.ExportedProgram, # type: ignore[name-defined]
/,
*model_args,
export_options: ExportOptions | None = None,
**model_kwargs,
) -> ONNXProgram | Any:
"""Export a torch.nn.Module to an ONNX graph.
Args:
model: The PyTorch model to be exported to ONNX.
model_args: Positional inputs to ``model``.
model_kwargs: Keyword inputs to ``model``.
export_options: Options to influence the export to ONNX.
Returns:
An in-memory representation of the exported ONNX model.
**Example 1 - Simplest export**
::
class MyModel(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.linear = torch.nn.Linear(2, 2)
def forward(self, x, bias=None):
out = self.linear(x)
out = out + bias
return out
model = MyModel()
kwargs = {"bias": 3.0}
args = (torch.randn(2, 2, 2),)
onnx_program = torch.onnx.dynamo_export(model, *args, **kwargs).save(
"my_simple_model.onnx"
)
**Example 2 - Exporting with dynamic shapes**
::
# The previous model can be exported with dynamic shapes
export_options = torch.onnx.ExportOptions(dynamic_shapes=True)
onnx_program = torch.onnx.dynamo_export(
model, *args, **kwargs, export_options=export_options
)
onnx_program.save("my_dynamic_model.onnx")
"""
# NOTE: The new exporter is experimental and is not enabled by default.
import warnings
from torch.onnx import _flags
from torch.onnx._internal import exporter
from torch.utils import _pytree
if isinstance(model, torch.export.ExportedProgram):
return exporter.export_compat(
model, # type: ignore[arg-type]
model_args,
f=None,
kwargs=model_kwargs,
opset_version=18,
external_data=True,
export_params=True,
fallback=True,
)
elif _flags.USE_EXPERIMENTAL_LOGIC:
if export_options is not None:
warnings.warn(
"You are using an experimental ONNX export logic, which currently only supports dynamic shapes. "
"For a more comprehensive set of export options, including advanced features, please consider using "
"`torch.onnx.export(..., dynamo=True)`. ",
category=FutureWarning,
)
if export_options is not None and export_options.dynamic_shapes:
# Make all shapes dynamic
def _to_dynamic_shapes_mapper():
arg_order = 0
def _to_dynamic_shape(x):
nonlocal arg_order
if isinstance(x, torch.Tensor):
rank = len(x.shape)
dynamic_shape = {}
for i in range(rank):
dynamic_shape[i] = torch.export.Dim(
f"arg_{arg_order}_dim_{i}"
)
arg_order += 1
return dynamic_shape
else:
return None
return _to_dynamic_shape
# model_args could be nested
dynamic_shapes = _pytree.tree_map(
_to_dynamic_shapes_mapper(),
model_args,
)
else:
dynamic_shapes = None
return exporter.export_compat(
model, # type: ignore[arg-type]
model_args,
f=None,
kwargs=model_kwargs,
dynamic_shapes=dynamic_shapes,
opset_version=18,
external_data=True,
export_params=True,
fallback=True,
)
else:
from torch.onnx._internal._exporter_legacy import dynamo_export
return dynamo_export(
model, *model_args, export_options=export_options, **model_kwargs
)
# TODO(justinchuby): Deprecate these logging functions in favor of the new diagnostic module.
# Returns True iff ONNX logging is turned on.

View File

@ -16,7 +16,6 @@ from typing_extensions import Self
import torch
import torch._ops
import torch.export as torch_export
import torch.utils._pytree as pytree
from torch.onnx._internal import io_adapter
from torch.onnx._internal.diagnostics import infra
@ -304,27 +303,17 @@ class ResolvedExportOptions(ExportOptions):
def __init__(
self,
options: ExportOptions | ResolvedExportOptions,
model: torch.nn.Module | Callable | torch_export.ExportedProgram | None = None, # type: ignore[name-defined]
model: torch.nn.Module | Callable | None = None, # type: ignore[name-defined]
):
from torch.onnx._internal.fx import ( # TODO: Prevent circular dep
diagnostics,
dynamo_graph_extractor,
torch_export_graph_extractor,
)
if isinstance(options, ResolvedExportOptions):
self.dynamic_shapes = options.dynamic_shapes
self.diagnostic_options = options.diagnostic_options
self.fake_context = options.fake_context
# private
if isinstance(model, torch_export.ExportedProgram) and not isinstance(
options.fx_tracer, torch_export_graph_extractor.TorchExport
):
message = "'model' of type 'ExportedProgram' is only supported with 'TorchExport' FX Tracer"
e = InvalidExportOptionsError(message)
raise InvalidExportOptionsError(
ONNXProgram._from_failure(e, options.diagnostic_context), message
)
self.fx_tracer = options.fx_tracer
self.onnx_registry = options.onnx_registry
self.onnxfunction_dispatcher = options.onnxfunction_dispatcher
@ -345,10 +334,8 @@ class ResolvedExportOptions(ExportOptions):
self.diagnostic_options = resolve(
options.diagnostic_options, DiagnosticOptions()
)
if isinstance(model, torch_export.ExportedProgram):
self.fx_tracer = torch_export_graph_extractor.TorchExport()
else:
self.fx_tracer = dynamo_graph_extractor.DynamoExport()
self.fx_tracer = dynamo_graph_extractor.DynamoExport()
self.fake_context = resolve(options.fake_context, None) # type: ignore[arg-type]
self.diagnostic_context = diagnostics.DiagnosticContext(
@ -492,7 +479,6 @@ class ONNXProgram:
diagnostic_context: Context object for the SARIF diagnostic system responsible for logging errors and metadata.
fake_context: The fake context used for symbolic tracing.
export_exception: The exception that occurred during export, if any.
model_signature: The model signature for the exported ONNX graph.
"""
_model_proto: Final[onnx.ModelProto] # type: ignore[name-defined, misc]
@ -501,9 +487,8 @@ class ONNXProgram:
_diagnostic_context: Final[diagnostics.DiagnosticContext] # type: ignore[misc]
_fake_context: Final[ONNXFakeContext | None] # type: ignore[misc]
_export_exception: Final[Exception | None] # type: ignore[misc]
_model_signature: Final[torch.export.ExportGraphSignature | None] # type: ignore[misc]
_model_torch: Final[ # type: ignore[misc]
torch.nn.Module | Callable | torch_export.ExportedProgram | None
torch.nn.Module | Callable | None
]
def __init__(
@ -515,14 +500,9 @@ class ONNXProgram:
*,
fake_context: ONNXFakeContext | None = None,
export_exception: Exception | None = None,
model_signature: torch.export.ExportGraphSignature | None = None,
model_torch: torch.nn.Module
| Callable
| torch_export.ExportedProgram
| None = None,
model_torch: torch.nn.Module | Callable | None = None,
):
self._model_proto = model_proto
self._model_signature = model_signature
self._model_torch = model_torch
self._input_adapter = input_adapter
self._output_adapter = output_adapter
@ -533,10 +513,7 @@ class ONNXProgram:
def __call__(
self,
*args: Any,
model_with_state_dict: torch.nn.Module
| Callable
| torch_export.ExportedProgram
| None = None,
model_with_state_dict: torch.nn.Module | Callable | None = None,
options: ONNXRuntimeOptions | None = None,
**kwargs: Any,
) -> Any:
@ -571,8 +548,6 @@ class ONNXProgram:
onnx_model = os.path.join(tmpdir_path, "model.onnx")
if isinstance(model_with_state_dict, torch.nn.Module):
model_state = model_with_state_dict.state_dict()
elif isinstance(model_with_state_dict, torch_export.ExportedProgram):
model_state = model_with_state_dict.state_dict
else:
model_state = None
self.save(
@ -608,104 +583,6 @@ class ONNXProgram:
raise self._export_exception
return self._model_proto
@property
def model_signature(self) -> torch.export.ExportGraphSignature | None:
"""The model signature for the exported ONNX graph.
This information is relevant because ONNX specification often differs from PyTorch's, resulting
in a ONNX graph with input and output schema different from the actual PyTorch model implementation.
By using the model signature, the users can understand the inputs and outputs differences
and properly execute the model in ONNX Runtime.
NOTE: Model signature is only available when the ONNX graph was exported from a
:class:`torch.export.ExportedProgram` object.
NOTE: Any transformation done to the model that changes the model signature must be accompanied
by updates to this model signature as well through :class:`InputAdaptStep` and/or :class:`OutputAdaptStep`.
Example:
The following model produces different sets of inputs and outputs.
The first 4 inputs are model parameters (namely conv1.weight, conv2.weight, fc1.weight, fc2.weight),
and the next 2 inputs are registered buffers (namely my_buffer2, my_buffer1) and finally
the last 2 inputs are user inputs (namely x and b).
The first output is a buffer mutation (namely my_buffer2) and the last output is the actual model output.
>>> import pprint
>>> class CustomModule(torch.nn.Module):
... def __init__(self) -> None:
... super().__init__()
... self.my_parameter = torch.nn.Parameter(torch.tensor(2.0))
... self.register_buffer("my_buffer1", torch.tensor(3.0))
... self.register_buffer("my_buffer2", torch.tensor(4.0))
... self.conv1 = torch.nn.Conv2d(1, 32, 3, 1, bias=False)
... self.conv2 = torch.nn.Conv2d(32, 64, 3, 1, bias=False)
... self.fc1 = torch.nn.Linear(9216, 128, bias=False)
... self.fc2 = torch.nn.Linear(128, 10, bias=False)
...
... def forward(self, x, b):
... tensor_x = self.conv1(x)
... tensor_x = torch.nn.functional.sigmoid(tensor_x)
... tensor_x = self.conv2(tensor_x)
... tensor_x = torch.nn.functional.sigmoid(tensor_x)
... tensor_x = torch.nn.functional.max_pool2d(tensor_x, 2)
... tensor_x = torch.flatten(tensor_x, 1)
... tensor_x = self.fc1(tensor_x)
... tensor_x = torch.nn.functional.sigmoid(tensor_x)
... tensor_x = self.fc2(tensor_x)
... output = torch.nn.functional.log_softmax(tensor_x, dim=1)
... (
... self.my_buffer2.add_(1.0) + self.my_buffer1
... ) # Mutate buffer through in-place addition
... return output
>>> inputs = (torch.rand((64, 1, 28, 28), dtype=torch.float32), torch.randn(3))
>>> exported_program = torch.export.export(
... CustomModule(), args=inputs
... ).run_decompositions({})
>>> onnx_program = torch.onnx.dynamo_export(exported_program, *inputs)
>>> pprint.pprint(onnx_program.model_signature)
ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.PARAMETER: 2>,
arg=TensorArgument(name='p_conv1_weight'),
target='conv1.weight',
persistent=None),
InputSpec(kind=<InputKind.PARAMETER: 2>,
arg=TensorArgument(name='p_conv2_weight'),
target='conv2.weight',
persistent=None),
InputSpec(kind=<InputKind.PARAMETER: 2>,
arg=TensorArgument(name='p_fc1_weight'),
target='fc1.weight',
persistent=None),
InputSpec(kind=<InputKind.PARAMETER: 2>,
arg=TensorArgument(name='p_fc2_weight'),
target='fc2.weight',
persistent=None),
InputSpec(kind=<InputKind.BUFFER: 3>,
arg=TensorArgument(name='b_my_buffer2'),
target='my_buffer2',
persistent=True),
InputSpec(kind=<InputKind.BUFFER: 3>,
arg=TensorArgument(name='b_my_buffer1'),
target='my_buffer1',
persistent=True),
InputSpec(kind=<InputKind.USER_INPUT: 1>,
arg=TensorArgument(name='x'),
target=None,
persistent=None),
InputSpec(kind=<InputKind.USER_INPUT: 1>,
arg=TensorArgument(name='b'),
target=None,
persistent=None)],
output_specs=[OutputSpec(kind=<OutputKind.BUFFER_MUTATION: 3>,
arg=TensorArgument(name='add'),
target='my_buffer2'),
OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>,
arg=TensorArgument(name='_log_softmax'),
target=None)])
"""
return self._model_signature
@property
def diagnostic_context(self) -> diagnostics.DiagnosticContext:
"""The diagnostic context associated with the export."""
@ -721,10 +598,7 @@ class ONNXProgram:
def adapt_torch_inputs_to_onnx(
self,
*model_args,
model_with_state_dict: torch.nn.Module
| Callable
| torch_export.ExportedProgram
| None = None,
model_with_state_dict: torch.nn.Module | Callable | None = None,
**model_kwargs,
) -> Sequence[torch.Tensor | int | float | bool | torch.dtype]:
"""Converts the PyTorch model inputs to exported ONNX model inputs format.
@ -794,10 +668,7 @@ class ONNXProgram:
def adapt_torch_outputs_to_onnx(
self,
model_outputs: Any,
model_with_state_dict: torch.nn.Module
| Callable
| torch_export.ExportedProgram
| None = None,
model_with_state_dict: torch.nn.Module | Callable | None = None,
) -> Sequence[torch.Tensor | int | float | bool]:
"""Converts the PyTorch model outputs to exported ONNX model outputs format.
@ -1050,7 +921,7 @@ class Exporter:
def __init__(
self,
options: ResolvedExportOptions,
model: torch.nn.Module | Callable | torch_export.ExportedProgram,
model: torch.nn.Module | Callable,
model_args: Sequence[Any],
model_kwargs: Mapping[str, Any],
):
@ -1138,9 +1009,6 @@ class Exporter:
self.options.fx_tracer.output_adapter,
self.options.diagnostic_context,
fake_context=self.options.fake_context,
model_signature=getattr(
self.model, "graph_signature", None
), # Available for isinstance(self.model, ExportedProgram) only
model_torch=self.model,
)
@ -1261,12 +1129,12 @@ def _assert_dependencies(export_options: ResolvedExportOptions):
def dynamo_export(
model: torch.nn.Module | Callable | torch_export.ExportedProgram, # type: ignore[name-defined]
model: torch.nn.Module | Callable,
/,
*model_args,
export_options: ExportOptions | None = None,
**model_kwargs,
) -> ONNXProgram:
) -> ONNXProgram | Any:
"""Export a torch.nn.Module to an ONNX graph.
Args:

View File

@ -120,9 +120,20 @@ class TorchExportStrategy(CaptureStrategy):
def _capture(
self, model, args, kwargs, dynamic_shapes
) -> torch.export.ExportedProgram:
return torch.export.export(
model, args, kwargs=kwargs, dynamic_shapes=dynamic_shapes
)
try:
return torch.export.export(
model, args, kwargs=kwargs, dynamic_shapes=dynamic_shapes
)
except torch._dynamo.exc.UserError as exc:
# Refine the dynamic shapes based on the suggested fixes.
new_shapes = (
torch.export.dynamic_shapes.refine_dynamic_shapes_from_suggested_fixes(
exc.msg, dynamic_shapes
)
)
return torch.export.export(
model, args, kwargs=kwargs, dynamic_shapes=new_shapes
)
def _enter(self, model) -> None:
model_repr = _take_first_line(repr(model))
@ -148,9 +159,20 @@ class TorchExportNonStrictStrategy(CaptureStrategy):
def _capture(
self, model, args, kwargs, dynamic_shapes
) -> torch.export.ExportedProgram:
return torch.export.export(
model, args, kwargs=kwargs, dynamic_shapes=dynamic_shapes, strict=False
)
try:
return torch.export.export(
model, args, kwargs=kwargs, dynamic_shapes=dynamic_shapes, strict=False
)
except torch._dynamo.exc.UserError as exc:
# Refine the dynamic shapes based on the suggested fixes.
new_shapes = (
torch.export.dynamic_shapes.refine_dynamic_shapes_from_suggested_fixes(
exc.msg, dynamic_shapes
)
)
return torch.export.export(
model, args, kwargs=kwargs, dynamic_shapes=new_shapes, strict=False
)
def _enter(self, model) -> None:
model_repr = _take_first_line(repr(model))

View File

@ -9,7 +9,6 @@ import logging
from typing import Any, Mapping, Sequence, TYPE_CHECKING
import torch
import torch.export
from torch.onnx._internal._lazy_import import onnxscript_apis, onnxscript_ir as ir
from torch.onnx._internal.exporter import _core, _onnx_program

View File

@ -176,7 +176,7 @@ class Transform(abc.ABC):
One important aspect to note is that if the transformation modifies the model input and/or output signature,
(e.g. additional inputs/outputs are added to the model), :class:`InputAdaptStep` and/or :class:`OutputAdaptStep`
are needed to reconcile :attr:`ONNXProgram.model_signature` and :attr:`ONNXProgram.model_proto`.
are needed to reconcile :attr:`ONNXProgram.model_proto`.
That is, the model signature and the model representation must match.
As an additional feature, this class provides builtin support for transformation recording using the diagnostics.

View File

@ -1,128 +0,0 @@
# mypy: allow-untyped-defs
# NOTE: This file is referenced by name at
# /opt/pytorch/torch/_dynamo/eval_frame.py::DONT_WRAP_FILES.
# introduced by https://github.com/pytorch/pytorch/pull/98894.
# If this file is renamed, moved, etc please update the reference there!
from __future__ import annotations
from typing import Any, Callable, Mapping, Sequence, TYPE_CHECKING
import torch._dynamo
import torch.fx
from torch.onnx._internal import _exporter_legacy, io_adapter
from torch.onnx._internal.diagnostics import infra
if TYPE_CHECKING:
import torch.onnx
from torch.export.exported_program import ExportedProgram
class TorchExport(_exporter_legacy.FXGraphExtractor):
"""Generates a FX GraphModule using torch.export API
Args:
aten_graph: If True, exports a graph with ATen operators.
If False, exports a graph with Python operators.
"""
def __init__(
self,
aten_graph: bool | None = None,
):
super().__init__()
self.aten_graph = aten_graph or True
def generate_fx(
self,
options: _exporter_legacy.ResolvedExportOptions,
model: ExportedProgram, # type: ignore[override]
model_args: Sequence[Any],
model_kwargs: Mapping[str, Any],
) -> torch.fx.GraphModule:
# No need to translate callable to FX graph.
# This FX Graph extractor assumes `model` was obtained through
# exported_program = torch.export.export(
# model,
# args=model_args, # type: ignore[arg-type]
# kwargs=model_kwargs, # type: ignore[arg-type]
# )
# Export FX graph to ONNX ModelProto.
self.input_adapter.append_step(
io_adapter.FlattenInputWithTreeSpecValidationInputStep()
)
self.input_adapter.append_step(
io_adapter.PrependParamsBuffersConstantAotAutogradInputStep()
)
# ONNX does not support None inputs. During graph building, all None inputs
# are removed. Here we register this step to input adapter.
options.fx_tracer.input_adapter.append_step(io_adapter.RemoveNoneInputStep())
# NOTE: temp workaround for https://github.com/pytorch/pytorch/issues/99534
# Dynamo doesn't support non-tensor inputs.
options.fx_tracer.input_adapter.append_step(
io_adapter.RemoveNonTensorInputStep()
)
# ONNX does not support complex inputs. During graph building, all complex inputs
# are converted to real representation inputs. Here we register this step to
# input/output adapter.
options.fx_tracer.input_adapter.append_step(
io_adapter.ConvertComplexToRealRepresentationInputStep()
)
updated_model_args = self.input_adapter.apply(
*model_args, model=model, **model_kwargs
)
# ONNX can't represent collection types (e.g., dictionary, tuple of tuple of
# tensor, etc), we flatten the collection and register each element as output.
options.fx_tracer.output_adapter.append_step(io_adapter.FlattenOutputStep())
# Output post-processing steps should happen after `FlattenOutputStep`.
options.fx_tracer.output_adapter.append_step(
io_adapter.ConvertComplexToRealRepresentationOutputStep()
)
options.fx_tracer.output_adapter.append_step(
io_adapter.PrependParamsAndBuffersAotAutogradOutputStep()
)
# run_decomposition generates a new graph module with decomposed ops.
# Thus, we need to run this step after io_adapters.
model = model.run_decompositions(options.decomposition_table)
# Export FX graph to ONNX ModelProto.
return self.pre_export_passes( # type: ignore[return-value]
options, model, model.graph_module, updated_model_args
)
def pre_export_passes(
self,
options: _exporter_legacy.ResolvedExportOptions,
original_model: torch.nn.Module | Callable,
fx_module: torch.fx.GraphModule,
fx_module_args: Sequence[Any],
):
# TODO: Import here to prevent circular dependency
from torch.onnx._internal.fx import analysis, passes
diagnostic_context = options.diagnostic_context
# ONNX does not support concept of (implicit) type promotion.
# Insert type casts explicitly where needed.
fx_module = passes.InsertTypePromotion(diagnostic_context, fx_module).run()
analysis.UnsupportedFxNodesAnalysis(
diagnostic_context, fx_module, options.onnxfunction_dispatcher
).analyze(infra.levels.ERROR)
# This operation should be invoked as the last pre export pass.
# See [NOTE: Modularize pass ordering]
fx_module = passes.Modularize(
diagnostic_context, fx_module, is_exported_program=True
).run()
return fx_module