mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[ONNX] Enable experimental exporter logic to dynamo_export and support refine dynamic_shapes (#134976)
(1) Enable experimental exporter logic to dynamo_export (2) Refine dynamic shapes and retry export in export strategies (3) Delete `torch_export_graph_extractor` and use the new export logic (4) Disable ExportedProgram test in `test_fx_onnx_with_onnxruntime.py`, as ONNXProgram is different now. Fixes https://github.com/pytorch/pytorch/issues/126479 Fixes #135183 Pull Request resolved: https://github.com/pytorch/pytorch/pull/134976 Approved by: https://github.com/justinchuby
This commit is contained in:
committed by
PyTorch MergeBot
parent
1e57ef08fa
commit
8f6e73f068
@ -148,6 +148,54 @@ class TestExportAPIDynamo(common_utils.TestCase):
|
||||
},
|
||||
)
|
||||
|
||||
def test_auto_convert_all_axes_to_dynamic_shapes_with_dynamo_export(self):
|
||||
os.environ["TORCH_ONNX_USE_EXPERIMENTAL_LOGIC"] = "1"
|
||||
assert os.environ.get("TORCH_ONNX_USE_EXPERIMENTAL_LOGIC") == "1"
|
||||
|
||||
class Nested(torch.nn.Module):
|
||||
def forward(self, x):
|
||||
(a0, a1), (b0, b1), (c0, c1, c2) = x
|
||||
return a0 + a1 + b0 + b1 + c0 + c1 + c2
|
||||
|
||||
inputs = (
|
||||
(1, 2),
|
||||
(
|
||||
torch.randn(4, 4),
|
||||
torch.randn(4, 4),
|
||||
),
|
||||
(
|
||||
torch.randn(4, 4),
|
||||
torch.randn(4, 4),
|
||||
torch.randn(4, 4),
|
||||
),
|
||||
)
|
||||
|
||||
onnx_program = torch.onnx.dynamo_export(
|
||||
Nested(),
|
||||
inputs,
|
||||
export_options=torch.onnx.ExportOptions(dynamic_shapes=True),
|
||||
)
|
||||
assert onnx_program is not None
|
||||
onnx_testing.assert_onnx_program(onnx_program)
|
||||
|
||||
def test_refine_dynamic_shapes_with_onnx_export(self):
|
||||
# NOTE: From test/export/test_export.py
|
||||
|
||||
# refine lower, upper bound
|
||||
class TestRefineDynamicShapeModel(torch.nn.Module):
|
||||
def forward(self, x, y):
|
||||
if x.shape[0] >= 6 and y.shape[0] <= 16:
|
||||
return x * 2.0, y + 1
|
||||
|
||||
inps = (torch.randn(16), torch.randn(12))
|
||||
dynamic_shapes = {
|
||||
"x": (torch.export.Dim("dx"),),
|
||||
"y": (torch.export.Dim("dy"),),
|
||||
}
|
||||
self.assert_export(
|
||||
TestRefineDynamicShapeModel(), inps, dynamic_shapes=dynamic_shapes
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
common_utils.run_tests()
|
||||
|
@ -346,6 +346,28 @@ def skipDtypeChecking(func):
|
||||
return wrapper
|
||||
|
||||
|
||||
def skip_if_fake_model_and_inititalizer(reason: Optional[str] = None):
|
||||
"""skip test with models using ExportedProgram as input.
|
||||
|
||||
Args:
|
||||
reason: The reason for skip the ONNX export test.
|
||||
|
||||
Returns:
|
||||
A decorator for skip tests.
|
||||
"""
|
||||
|
||||
def skip_dec(func):
|
||||
@functools.wraps(func)
|
||||
def wrapper(self, *args, **kwargs):
|
||||
if kwargs["use_fake_mode"] and kwargs["include_initializer"]:
|
||||
return unittest.SkipTest(reason)
|
||||
return func(self, *args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
return skip_dec
|
||||
|
||||
|
||||
def xfail_if_model_type_is_exportedprogram(
|
||||
error_message: str, reason: Optional[str] = None
|
||||
):
|
||||
|
@ -544,34 +544,6 @@ class TestFxToOnnx(pytorch_test_common.ExportTestCase):
|
||||
onnx.checker.check_model(onnx_program.model_proto)
|
||||
onnx.shape_inference.infer_shapes(onnx_program.model_proto)
|
||||
|
||||
def test_exported_program_input_with_custom_fx_tracer(self):
|
||||
from torch.onnx._internal import _exporter_legacy
|
||||
from torch.onnx._internal.fx import dynamo_graph_extractor
|
||||
|
||||
class Model(torch.nn.Module):
|
||||
def forward(self, x):
|
||||
return x + 1
|
||||
|
||||
x = torch.randn(1, 1, 2)
|
||||
exported_program = torch.export.export(Model(), args=(x,))
|
||||
|
||||
export_options = torch.onnx.ExportOptions()
|
||||
export_options = _exporter_legacy.ResolvedExportOptions(
|
||||
export_options, model=exported_program
|
||||
)
|
||||
export_options.fx_tracer = (
|
||||
dynamo_graph_extractor.DynamoExport()
|
||||
) # Override fx_tracer to an unsupported tracer
|
||||
with self.assertRaises(torch.onnx.OnnxExporterError):
|
||||
onnx_program = torch.onnx.dynamo_export(
|
||||
exported_program,
|
||||
x,
|
||||
export_options=export_options,
|
||||
)
|
||||
self.assertTrue(onnx_program._export_exception is not None)
|
||||
with self.assertRaises(torch.onnx.InvalidExportOptionsError):
|
||||
raise self._export_exception
|
||||
|
||||
def test_exported_program_torch_distributions_normal_Normal(self):
|
||||
class Model(torch.nn.Module):
|
||||
def __init__(self) -> None:
|
||||
@ -606,21 +578,6 @@ class TestFxToOnnx(pytorch_test_common.ExportTestCase):
|
||||
# with no Cast node in between.
|
||||
self.assertEqual(div_node.input[0], model_proto.graph.input[0].name)
|
||||
|
||||
def test_exported_program_as_input_with_model_signature(self):
|
||||
class Model(torch.nn.Module):
|
||||
def forward(self, x):
|
||||
return x + 1.0
|
||||
|
||||
x = torch.randn(1, 1, 2, dtype=torch.float)
|
||||
exported_program = torch.export.export(Model(), args=(x,))
|
||||
|
||||
onnx_program = torch.onnx.dynamo_export(
|
||||
exported_program,
|
||||
x,
|
||||
)
|
||||
|
||||
self.assertTrue(onnx_program.model_signature, torch.export.ExportGraphSignature)
|
||||
|
||||
@common_utils.parametrize(
|
||||
"float8_type",
|
||||
[
|
||||
@ -707,6 +664,7 @@ class TestFxToOnnx(pytorch_test_common.ExportTestCase):
|
||||
onnx_program.save(tmp_onnx_file.name)
|
||||
onnx.checker.check_model(tmp_onnx_file.name, full_check=True)
|
||||
|
||||
@pytorch_test_common.skip_if_fake_model_and_inititalizer("segfault")
|
||||
@common_utils.parametrize(
|
||||
"include_initializer",
|
||||
[
|
||||
|
@ -19,12 +19,6 @@ def assert_op_in_onnx_model(model: onnx.ModelProto, op_type: str):
|
||||
|
||||
|
||||
class TestDynamoExportDecompSkip(pytorch_test_common.ExportTestCase):
|
||||
def _test_exported_program_forces_decomposition(self, model, input, op_type):
|
||||
ep = torch.export.export(model, input)
|
||||
onnx_program = torch.onnx.dynamo_export(ep, *input)
|
||||
with self.assertRaises(AssertionError):
|
||||
assert_op_in_onnx_model(onnx_program.model_proto, op_type)
|
||||
|
||||
def test_upsample_bilinear2d(self):
|
||||
class TestModel(torch.nn.Module):
|
||||
def __init__(self) -> None:
|
||||
@ -37,9 +31,6 @@ class TestDynamoExportDecompSkip(pytorch_test_common.ExportTestCase):
|
||||
onnx_program = torch.onnx.dynamo_export(TestModel(), torch.randn(1, 1, 2, 2))
|
||||
# If decomposition is skipped, the model will contain a Resize op instead of fine grained subgraph.
|
||||
assert_op_in_onnx_model(onnx_program.model_proto, "Resize")
|
||||
self._test_exported_program_forces_decomposition(
|
||||
TestModel(), (torch.randn(1, 1, 2, 2),), "Resize"
|
||||
)
|
||||
|
||||
def test_upsample_bilinear2d_output_size(self):
|
||||
def func(x: torch.Tensor):
|
||||
@ -61,9 +52,6 @@ class TestDynamoExportDecompSkip(pytorch_test_common.ExportTestCase):
|
||||
onnx_program = torch.onnx.dynamo_export(TestModel(), torch.randn(1, 1, 2, 2, 3))
|
||||
# If decomposition is skipped, the model will contain a Resize op instead of fine grained subgraph.
|
||||
assert_op_in_onnx_model(onnx_program.model_proto, "Resize")
|
||||
self._test_exported_program_forces_decomposition(
|
||||
TestModel(), (torch.randn(1, 1, 2, 2, 3),), "Resize"
|
||||
)
|
||||
|
||||
def test_upsample_trilinear3d_output_size(self):
|
||||
def func(x: torch.Tensor):
|
||||
@ -82,9 +70,6 @@ class TestDynamoExportDecompSkip(pytorch_test_common.ExportTestCase):
|
||||
# If decomposition is skipped, the model will contain an InstanceNormalization op
|
||||
# instead of BatchNormalization op w/ training=True.
|
||||
assert_op_in_onnx_model(onnx_program.model_proto, "InstanceNormalization")
|
||||
self._test_exported_program_forces_decomposition(
|
||||
TestModel(), (torch.randn(1, 1, 2, 2),), "InstanceNormalization"
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -45,10 +45,7 @@ def _parameterized_class_attrs_and_values():
|
||||
input_values.extend(
|
||||
itertools.product(
|
||||
(True, False),
|
||||
(
|
||||
pytorch_test_common.TorchModelType.TORCH_NN_MODULE,
|
||||
pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM,
|
||||
),
|
||||
(pytorch_test_common.TorchModelType.TORCH_NN_MODULE,),
|
||||
)
|
||||
)
|
||||
return {
|
||||
@ -912,10 +909,7 @@ def _parameterized_class_attrs_and_values_with_fake_options():
|
||||
(True, False),
|
||||
(True, False),
|
||||
(True, False),
|
||||
(
|
||||
pytorch_test_common.TorchModelType.TORCH_NN_MODULE,
|
||||
pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM,
|
||||
),
|
||||
(pytorch_test_common.TorchModelType.TORCH_NN_MODULE,),
|
||||
)
|
||||
)
|
||||
return {
|
||||
@ -986,13 +980,6 @@ class TestFxToOnnxFakeTensorWithOnnxRuntime(onnx_test_common._TestONNXRuntime):
|
||||
# Create the toy model with real weight.
|
||||
real_model = create_model()
|
||||
state_dict = real_model.state_dict() # concrete (non-fake) state_dict
|
||||
if (
|
||||
model_type
|
||||
== pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM
|
||||
):
|
||||
real_model = torch.export.export(
|
||||
real_model, args=create_args(), kwargs=create_kwargs()
|
||||
)
|
||||
|
||||
with tempfile.NamedTemporaryFile(
|
||||
prefix=model_name, suffix=".pt"
|
||||
@ -1015,13 +1002,6 @@ class TestFxToOnnxFakeTensorWithOnnxRuntime(onnx_test_common._TestONNXRuntime):
|
||||
)
|
||||
|
||||
if export_within_fake_mode:
|
||||
if (
|
||||
model_type
|
||||
== pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM
|
||||
):
|
||||
fake_model = torch.export.export(
|
||||
fake_model, args=fake_args, kwargs=fake_kwargs
|
||||
)
|
||||
onnx_program = torch.onnx.dynamo_export(
|
||||
fake_model,
|
||||
*fake_args,
|
||||
@ -1030,13 +1010,6 @@ class TestFxToOnnxFakeTensorWithOnnxRuntime(onnx_test_common._TestONNXRuntime):
|
||||
)
|
||||
|
||||
if not export_within_fake_mode:
|
||||
if (
|
||||
model_type
|
||||
== pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM
|
||||
):
|
||||
fake_model = torch.export.export(
|
||||
fake_model, args=fake_args, kwargs=fake_kwargs
|
||||
)
|
||||
onnx_program = torch.onnx.dynamo_export(
|
||||
fake_model, *fake_args, **fake_kwargs, export_options=export_options
|
||||
)
|
||||
@ -1093,10 +1066,6 @@ class TestFxToOnnxFakeTensorWithOnnxRuntime(onnx_test_common._TestONNXRuntime):
|
||||
for ref_output, ort_output in zip(ref_outputs, ort_outputs):
|
||||
torch.testing.assert_close(ref_output, torch.tensor(ort_output))
|
||||
|
||||
@pytorch_test_common.skip_dynamic_fx_test(
|
||||
reason="Dynamic shape check is not expected for exported program in this test suite.",
|
||||
model_type=pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM,
|
||||
)
|
||||
def test_fake_tensor_mode_simple(self):
|
||||
def create_model() -> nn.Module:
|
||||
class Model(torch.nn.Module):
|
||||
@ -1126,10 +1095,6 @@ class TestFxToOnnxFakeTensorWithOnnxRuntime(onnx_test_common._TestONNXRuntime):
|
||||
model_type=self.model_type,
|
||||
)
|
||||
|
||||
@pytorch_test_common.skip_dynamic_fx_test(
|
||||
reason="Dynamic shape check is not expected for exported program in this test suite.",
|
||||
model_type=pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM,
|
||||
)
|
||||
@pytorch_test_common.xfail_dynamic_fx_test(
|
||||
error_message="!(it.GetName().empty())",
|
||||
reason="With after onnx==1.16, constant folding in optimizer causes this error.",
|
||||
@ -1166,10 +1131,6 @@ class TestFxToOnnxFakeTensorWithOnnxRuntime(onnx_test_common._TestONNXRuntime):
|
||||
model_type=self.model_type,
|
||||
)
|
||||
|
||||
@pytorch_test_common.skip_dynamic_fx_test(
|
||||
reason="Dynamic shape check is not expected for exported program in this test suite.",
|
||||
model_type=pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM,
|
||||
)
|
||||
def test_large_scale_exporter_with_toy_mlp(self):
|
||||
class MLPModel(nn.Module):
|
||||
def __init__(self) -> None:
|
||||
@ -1208,10 +1169,6 @@ class TestFxToOnnxFakeTensorWithOnnxRuntime(onnx_test_common._TestONNXRuntime):
|
||||
model_type=self.model_type,
|
||||
)
|
||||
|
||||
@pytorch_test_common.skip_dynamic_fx_test(
|
||||
reason="Dynamic shape check is not expected for exported program in this test suite.",
|
||||
model_type=pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM,
|
||||
)
|
||||
def test_fake_tensor_mode_huggingface_google_t5(self):
|
||||
config = transformers.T5Config(
|
||||
vocab_size=8096, d_model=64, num_layers=2, num_heads=2
|
||||
@ -1244,10 +1201,6 @@ class TestFxToOnnxFakeTensorWithOnnxRuntime(onnx_test_common._TestONNXRuntime):
|
||||
model_type=self.model_type,
|
||||
)
|
||||
|
||||
@pytorch_test_common.skip_dynamic_fx_test(
|
||||
reason="Dynamic shape check is not expected for exported program in this test suite.",
|
||||
model_type=pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM,
|
||||
)
|
||||
@pytorch_test_common.xfail_dynamic_fx_test(
|
||||
error_message="scaled_dot_product_attention(): argument 'is_causal' must be bool, not SymBool",
|
||||
reason="Dynamo error: scaled_dot_product_attention(): argument 'is_causal' must be bool, not SymBool",
|
||||
@ -1310,10 +1263,6 @@ class TestFxToOnnxFakeTensorWithOnnxRuntime(onnx_test_common._TestONNXRuntime):
|
||||
model_type=self.model_type,
|
||||
)
|
||||
|
||||
@pytorch_test_common.skip_dynamic_fx_test(
|
||||
reason="Dynamic shape check is not expected for exported program in this test suite.",
|
||||
model_type=pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM,
|
||||
)
|
||||
def test_fake_tensor_mode_huggingface_mosaicml_mpt(self):
|
||||
config = transformers.MptConfig(
|
||||
vocab_size=8096, d_model=64, n_heads=2, n_layers=3
|
||||
@ -1341,10 +1290,6 @@ class TestFxToOnnxFakeTensorWithOnnxRuntime(onnx_test_common._TestONNXRuntime):
|
||||
model_type=self.model_type,
|
||||
)
|
||||
|
||||
@pytorch_test_common.skip_dynamic_fx_test(
|
||||
reason="Dynamic shape check is not expected for exported program in this test suite.",
|
||||
model_type=pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM,
|
||||
)
|
||||
@pytorch_test_common.xfail_dynamic_fx_test(
|
||||
error_message="SymIntArrayRef expected to contain only concrete integers",
|
||||
model_type=pytorch_test_common.TorchModelType.TORCH_NN_MODULE,
|
||||
@ -1374,10 +1319,6 @@ class TestFxToOnnxFakeTensorWithOnnxRuntime(onnx_test_common._TestONNXRuntime):
|
||||
model_type=self.model_type,
|
||||
)
|
||||
|
||||
@pytorch_test_common.skip_dynamic_fx_test(
|
||||
reason="Dynamic shape check is not expected for exported program in this test suite.",
|
||||
model_type=pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM,
|
||||
)
|
||||
@pytorch_test_common.xfail_if_model_type_is_not_exportedprogram(
|
||||
error_message="Expected 5 inputs, got 3",
|
||||
reason="https://github.com/pytorch/pytorch/issues/115745",
|
||||
@ -1417,10 +1358,6 @@ class TestFxToOnnxFakeTensorWithOnnxRuntime(onnx_test_common._TestONNXRuntime):
|
||||
model_type=self.model_type,
|
||||
)
|
||||
|
||||
@pytorch_test_common.skip_dynamic_fx_test(
|
||||
reason="Dynamic shape check is not expected for exported program in this test suite.",
|
||||
model_type=pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM,
|
||||
)
|
||||
@pytorch_test_common.xfail_dynamic_fx_test(
|
||||
error_message="SymIntArrayRef expected to contain only concrete integers",
|
||||
model_type=pytorch_test_common.TorchModelType.TORCH_NN_MODULE,
|
||||
|
@ -36,14 +36,15 @@ class TestFxToOnnxWithOnnxRuntime(onnx_test_common._TestONNXRuntime):
|
||||
torch_outputs = torch_exported_program.module()(*input_args, **input_kwargs)
|
||||
else:
|
||||
torch_outputs = torch_exported_program(*input_args, **input_kwargs)
|
||||
torch_outputs_onnx_format = onnx_exported_program.adapt_torch_outputs_to_onnx(
|
||||
torch_outputs
|
||||
)
|
||||
if len(torch_outputs_onnx_format) != len(onnx_outputs):
|
||||
|
||||
if isinstance(torch_outputs, torch.Tensor):
|
||||
torch_outputs = [torch_outputs]
|
||||
|
||||
if len(torch_outputs) != len(onnx_outputs):
|
||||
raise AssertionError(
|
||||
f"Expected {len(torch_outputs_onnx_format)} outputs, got {len(onnx_outputs)}"
|
||||
f"Expected {len(torch_outputs)} outputs, got {len(onnx_outputs)}"
|
||||
)
|
||||
for torch_output, onnx_output in zip(torch_outputs_onnx_format, onnx_outputs):
|
||||
for torch_output, onnx_output in zip(torch_outputs, onnx_outputs):
|
||||
torch.testing.assert_close(
|
||||
torch_output, torch.tensor(onnx_output), rtol=rtol, atol=atol
|
||||
)
|
||||
|
@ -54,7 +54,7 @@ __all__ = [
|
||||
"is_onnxrt_backend_supported",
|
||||
]
|
||||
|
||||
from typing import Any, Collection, Mapping, Sequence, TYPE_CHECKING
|
||||
from typing import Any, Callable, Collection, Mapping, Sequence, TYPE_CHECKING
|
||||
|
||||
import torch
|
||||
from torch import _C
|
||||
@ -112,7 +112,6 @@ from ._internal._exporter_legacy import ( # usort: skip. needs to be last to av
|
||||
InvalidExportOptionsError,
|
||||
OnnxExporterError,
|
||||
OnnxRegistry,
|
||||
dynamo_export,
|
||||
enable_fake_mode,
|
||||
)
|
||||
|
||||
@ -126,7 +125,6 @@ JitScalarType.__module__ = "torch.onnx"
|
||||
ExportOptions.__module__ = "torch.onnx"
|
||||
ONNXProgram.__module__ = "torch.onnx"
|
||||
ONNXRuntimeOptions.__module__ = "torch.onnx"
|
||||
dynamo_export.__module__ = "torch.onnx"
|
||||
InvalidExportOptionsError.__module__ = "torch.onnx"
|
||||
OnnxExporterError.__module__ = "torch.onnx"
|
||||
enable_fake_mode.__module__ = "torch.onnx"
|
||||
@ -393,6 +391,131 @@ def export(
|
||||
return None
|
||||
|
||||
|
||||
def dynamo_export(
|
||||
model: torch.nn.Module | Callable | torch.export.ExportedProgram, # type: ignore[name-defined]
|
||||
/,
|
||||
*model_args,
|
||||
export_options: ExportOptions | None = None,
|
||||
**model_kwargs,
|
||||
) -> ONNXProgram | Any:
|
||||
"""Export a torch.nn.Module to an ONNX graph.
|
||||
|
||||
Args:
|
||||
model: The PyTorch model to be exported to ONNX.
|
||||
model_args: Positional inputs to ``model``.
|
||||
model_kwargs: Keyword inputs to ``model``.
|
||||
export_options: Options to influence the export to ONNX.
|
||||
|
||||
Returns:
|
||||
An in-memory representation of the exported ONNX model.
|
||||
|
||||
**Example 1 - Simplest export**
|
||||
::
|
||||
|
||||
class MyModel(torch.nn.Module):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.linear = torch.nn.Linear(2, 2)
|
||||
|
||||
def forward(self, x, bias=None):
|
||||
out = self.linear(x)
|
||||
out = out + bias
|
||||
return out
|
||||
|
||||
|
||||
model = MyModel()
|
||||
kwargs = {"bias": 3.0}
|
||||
args = (torch.randn(2, 2, 2),)
|
||||
onnx_program = torch.onnx.dynamo_export(model, *args, **kwargs).save(
|
||||
"my_simple_model.onnx"
|
||||
)
|
||||
|
||||
**Example 2 - Exporting with dynamic shapes**
|
||||
::
|
||||
|
||||
# The previous model can be exported with dynamic shapes
|
||||
export_options = torch.onnx.ExportOptions(dynamic_shapes=True)
|
||||
onnx_program = torch.onnx.dynamo_export(
|
||||
model, *args, **kwargs, export_options=export_options
|
||||
)
|
||||
onnx_program.save("my_dynamic_model.onnx")
|
||||
"""
|
||||
|
||||
# NOTE: The new exporter is experimental and is not enabled by default.
|
||||
import warnings
|
||||
|
||||
from torch.onnx import _flags
|
||||
from torch.onnx._internal import exporter
|
||||
from torch.utils import _pytree
|
||||
|
||||
if isinstance(model, torch.export.ExportedProgram):
|
||||
return exporter.export_compat(
|
||||
model, # type: ignore[arg-type]
|
||||
model_args,
|
||||
f=None,
|
||||
kwargs=model_kwargs,
|
||||
opset_version=18,
|
||||
external_data=True,
|
||||
export_params=True,
|
||||
fallback=True,
|
||||
)
|
||||
elif _flags.USE_EXPERIMENTAL_LOGIC:
|
||||
if export_options is not None:
|
||||
warnings.warn(
|
||||
"You are using an experimental ONNX export logic, which currently only supports dynamic shapes. "
|
||||
"For a more comprehensive set of export options, including advanced features, please consider using "
|
||||
"`torch.onnx.export(..., dynamo=True)`. ",
|
||||
category=FutureWarning,
|
||||
)
|
||||
|
||||
if export_options is not None and export_options.dynamic_shapes:
|
||||
# Make all shapes dynamic
|
||||
def _to_dynamic_shapes_mapper():
|
||||
arg_order = 0
|
||||
|
||||
def _to_dynamic_shape(x):
|
||||
nonlocal arg_order
|
||||
if isinstance(x, torch.Tensor):
|
||||
rank = len(x.shape)
|
||||
dynamic_shape = {}
|
||||
for i in range(rank):
|
||||
dynamic_shape[i] = torch.export.Dim(
|
||||
f"arg_{arg_order}_dim_{i}"
|
||||
)
|
||||
arg_order += 1
|
||||
return dynamic_shape
|
||||
else:
|
||||
return None
|
||||
|
||||
return _to_dynamic_shape
|
||||
|
||||
# model_args could be nested
|
||||
dynamic_shapes = _pytree.tree_map(
|
||||
_to_dynamic_shapes_mapper(),
|
||||
model_args,
|
||||
)
|
||||
else:
|
||||
dynamic_shapes = None
|
||||
|
||||
return exporter.export_compat(
|
||||
model, # type: ignore[arg-type]
|
||||
model_args,
|
||||
f=None,
|
||||
kwargs=model_kwargs,
|
||||
dynamic_shapes=dynamic_shapes,
|
||||
opset_version=18,
|
||||
external_data=True,
|
||||
export_params=True,
|
||||
fallback=True,
|
||||
)
|
||||
else:
|
||||
from torch.onnx._internal._exporter_legacy import dynamo_export
|
||||
|
||||
return dynamo_export(
|
||||
model, *model_args, export_options=export_options, **model_kwargs
|
||||
)
|
||||
|
||||
|
||||
# TODO(justinchuby): Deprecate these logging functions in favor of the new diagnostic module.
|
||||
|
||||
# Returns True iff ONNX logging is turned on.
|
||||
|
@ -16,7 +16,6 @@ from typing_extensions import Self
|
||||
|
||||
import torch
|
||||
import torch._ops
|
||||
import torch.export as torch_export
|
||||
import torch.utils._pytree as pytree
|
||||
from torch.onnx._internal import io_adapter
|
||||
from torch.onnx._internal.diagnostics import infra
|
||||
@ -304,27 +303,17 @@ class ResolvedExportOptions(ExportOptions):
|
||||
def __init__(
|
||||
self,
|
||||
options: ExportOptions | ResolvedExportOptions,
|
||||
model: torch.nn.Module | Callable | torch_export.ExportedProgram | None = None, # type: ignore[name-defined]
|
||||
model: torch.nn.Module | Callable | None = None, # type: ignore[name-defined]
|
||||
):
|
||||
from torch.onnx._internal.fx import ( # TODO: Prevent circular dep
|
||||
diagnostics,
|
||||
dynamo_graph_extractor,
|
||||
torch_export_graph_extractor,
|
||||
)
|
||||
|
||||
if isinstance(options, ResolvedExportOptions):
|
||||
self.dynamic_shapes = options.dynamic_shapes
|
||||
self.diagnostic_options = options.diagnostic_options
|
||||
self.fake_context = options.fake_context
|
||||
# private
|
||||
if isinstance(model, torch_export.ExportedProgram) and not isinstance(
|
||||
options.fx_tracer, torch_export_graph_extractor.TorchExport
|
||||
):
|
||||
message = "'model' of type 'ExportedProgram' is only supported with 'TorchExport' FX Tracer"
|
||||
e = InvalidExportOptionsError(message)
|
||||
raise InvalidExportOptionsError(
|
||||
ONNXProgram._from_failure(e, options.diagnostic_context), message
|
||||
)
|
||||
self.fx_tracer = options.fx_tracer
|
||||
self.onnx_registry = options.onnx_registry
|
||||
self.onnxfunction_dispatcher = options.onnxfunction_dispatcher
|
||||
@ -345,10 +334,8 @@ class ResolvedExportOptions(ExportOptions):
|
||||
self.diagnostic_options = resolve(
|
||||
options.diagnostic_options, DiagnosticOptions()
|
||||
)
|
||||
if isinstance(model, torch_export.ExportedProgram):
|
||||
self.fx_tracer = torch_export_graph_extractor.TorchExport()
|
||||
else:
|
||||
self.fx_tracer = dynamo_graph_extractor.DynamoExport()
|
||||
|
||||
self.fx_tracer = dynamo_graph_extractor.DynamoExport()
|
||||
|
||||
self.fake_context = resolve(options.fake_context, None) # type: ignore[arg-type]
|
||||
self.diagnostic_context = diagnostics.DiagnosticContext(
|
||||
@ -492,7 +479,6 @@ class ONNXProgram:
|
||||
diagnostic_context: Context object for the SARIF diagnostic system responsible for logging errors and metadata.
|
||||
fake_context: The fake context used for symbolic tracing.
|
||||
export_exception: The exception that occurred during export, if any.
|
||||
model_signature: The model signature for the exported ONNX graph.
|
||||
"""
|
||||
|
||||
_model_proto: Final[onnx.ModelProto] # type: ignore[name-defined, misc]
|
||||
@ -501,9 +487,8 @@ class ONNXProgram:
|
||||
_diagnostic_context: Final[diagnostics.DiagnosticContext] # type: ignore[misc]
|
||||
_fake_context: Final[ONNXFakeContext | None] # type: ignore[misc]
|
||||
_export_exception: Final[Exception | None] # type: ignore[misc]
|
||||
_model_signature: Final[torch.export.ExportGraphSignature | None] # type: ignore[misc]
|
||||
_model_torch: Final[ # type: ignore[misc]
|
||||
torch.nn.Module | Callable | torch_export.ExportedProgram | None
|
||||
torch.nn.Module | Callable | None
|
||||
]
|
||||
|
||||
def __init__(
|
||||
@ -515,14 +500,9 @@ class ONNXProgram:
|
||||
*,
|
||||
fake_context: ONNXFakeContext | None = None,
|
||||
export_exception: Exception | None = None,
|
||||
model_signature: torch.export.ExportGraphSignature | None = None,
|
||||
model_torch: torch.nn.Module
|
||||
| Callable
|
||||
| torch_export.ExportedProgram
|
||||
| None = None,
|
||||
model_torch: torch.nn.Module | Callable | None = None,
|
||||
):
|
||||
self._model_proto = model_proto
|
||||
self._model_signature = model_signature
|
||||
self._model_torch = model_torch
|
||||
self._input_adapter = input_adapter
|
||||
self._output_adapter = output_adapter
|
||||
@ -533,10 +513,7 @@ class ONNXProgram:
|
||||
def __call__(
|
||||
self,
|
||||
*args: Any,
|
||||
model_with_state_dict: torch.nn.Module
|
||||
| Callable
|
||||
| torch_export.ExportedProgram
|
||||
| None = None,
|
||||
model_with_state_dict: torch.nn.Module | Callable | None = None,
|
||||
options: ONNXRuntimeOptions | None = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
@ -571,8 +548,6 @@ class ONNXProgram:
|
||||
onnx_model = os.path.join(tmpdir_path, "model.onnx")
|
||||
if isinstance(model_with_state_dict, torch.nn.Module):
|
||||
model_state = model_with_state_dict.state_dict()
|
||||
elif isinstance(model_with_state_dict, torch_export.ExportedProgram):
|
||||
model_state = model_with_state_dict.state_dict
|
||||
else:
|
||||
model_state = None
|
||||
self.save(
|
||||
@ -608,104 +583,6 @@ class ONNXProgram:
|
||||
raise self._export_exception
|
||||
return self._model_proto
|
||||
|
||||
@property
|
||||
def model_signature(self) -> torch.export.ExportGraphSignature | None:
|
||||
"""The model signature for the exported ONNX graph.
|
||||
|
||||
This information is relevant because ONNX specification often differs from PyTorch's, resulting
|
||||
in a ONNX graph with input and output schema different from the actual PyTorch model implementation.
|
||||
By using the model signature, the users can understand the inputs and outputs differences
|
||||
and properly execute the model in ONNX Runtime.
|
||||
|
||||
NOTE: Model signature is only available when the ONNX graph was exported from a
|
||||
:class:`torch.export.ExportedProgram` object.
|
||||
|
||||
NOTE: Any transformation done to the model that changes the model signature must be accompanied
|
||||
by updates to this model signature as well through :class:`InputAdaptStep` and/or :class:`OutputAdaptStep`.
|
||||
|
||||
Example:
|
||||
|
||||
The following model produces different sets of inputs and outputs.
|
||||
The first 4 inputs are model parameters (namely conv1.weight, conv2.weight, fc1.weight, fc2.weight),
|
||||
and the next 2 inputs are registered buffers (namely my_buffer2, my_buffer1) and finally
|
||||
the last 2 inputs are user inputs (namely x and b).
|
||||
The first output is a buffer mutation (namely my_buffer2) and the last output is the actual model output.
|
||||
|
||||
>>> import pprint
|
||||
>>> class CustomModule(torch.nn.Module):
|
||||
... def __init__(self) -> None:
|
||||
... super().__init__()
|
||||
... self.my_parameter = torch.nn.Parameter(torch.tensor(2.0))
|
||||
... self.register_buffer("my_buffer1", torch.tensor(3.0))
|
||||
... self.register_buffer("my_buffer2", torch.tensor(4.0))
|
||||
... self.conv1 = torch.nn.Conv2d(1, 32, 3, 1, bias=False)
|
||||
... self.conv2 = torch.nn.Conv2d(32, 64, 3, 1, bias=False)
|
||||
... self.fc1 = torch.nn.Linear(9216, 128, bias=False)
|
||||
... self.fc2 = torch.nn.Linear(128, 10, bias=False)
|
||||
...
|
||||
... def forward(self, x, b):
|
||||
... tensor_x = self.conv1(x)
|
||||
... tensor_x = torch.nn.functional.sigmoid(tensor_x)
|
||||
... tensor_x = self.conv2(tensor_x)
|
||||
... tensor_x = torch.nn.functional.sigmoid(tensor_x)
|
||||
... tensor_x = torch.nn.functional.max_pool2d(tensor_x, 2)
|
||||
... tensor_x = torch.flatten(tensor_x, 1)
|
||||
... tensor_x = self.fc1(tensor_x)
|
||||
... tensor_x = torch.nn.functional.sigmoid(tensor_x)
|
||||
... tensor_x = self.fc2(tensor_x)
|
||||
... output = torch.nn.functional.log_softmax(tensor_x, dim=1)
|
||||
... (
|
||||
... self.my_buffer2.add_(1.0) + self.my_buffer1
|
||||
... ) # Mutate buffer through in-place addition
|
||||
... return output
|
||||
>>> inputs = (torch.rand((64, 1, 28, 28), dtype=torch.float32), torch.randn(3))
|
||||
>>> exported_program = torch.export.export(
|
||||
... CustomModule(), args=inputs
|
||||
... ).run_decompositions({})
|
||||
>>> onnx_program = torch.onnx.dynamo_export(exported_program, *inputs)
|
||||
>>> pprint.pprint(onnx_program.model_signature)
|
||||
ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.PARAMETER: 2>,
|
||||
arg=TensorArgument(name='p_conv1_weight'),
|
||||
target='conv1.weight',
|
||||
persistent=None),
|
||||
InputSpec(kind=<InputKind.PARAMETER: 2>,
|
||||
arg=TensorArgument(name='p_conv2_weight'),
|
||||
target='conv2.weight',
|
||||
persistent=None),
|
||||
InputSpec(kind=<InputKind.PARAMETER: 2>,
|
||||
arg=TensorArgument(name='p_fc1_weight'),
|
||||
target='fc1.weight',
|
||||
persistent=None),
|
||||
InputSpec(kind=<InputKind.PARAMETER: 2>,
|
||||
arg=TensorArgument(name='p_fc2_weight'),
|
||||
target='fc2.weight',
|
||||
persistent=None),
|
||||
InputSpec(kind=<InputKind.BUFFER: 3>,
|
||||
arg=TensorArgument(name='b_my_buffer2'),
|
||||
target='my_buffer2',
|
||||
persistent=True),
|
||||
InputSpec(kind=<InputKind.BUFFER: 3>,
|
||||
arg=TensorArgument(name='b_my_buffer1'),
|
||||
target='my_buffer1',
|
||||
persistent=True),
|
||||
InputSpec(kind=<InputKind.USER_INPUT: 1>,
|
||||
arg=TensorArgument(name='x'),
|
||||
target=None,
|
||||
persistent=None),
|
||||
InputSpec(kind=<InputKind.USER_INPUT: 1>,
|
||||
arg=TensorArgument(name='b'),
|
||||
target=None,
|
||||
persistent=None)],
|
||||
output_specs=[OutputSpec(kind=<OutputKind.BUFFER_MUTATION: 3>,
|
||||
arg=TensorArgument(name='add'),
|
||||
target='my_buffer2'),
|
||||
OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>,
|
||||
arg=TensorArgument(name='_log_softmax'),
|
||||
target=None)])
|
||||
"""
|
||||
|
||||
return self._model_signature
|
||||
|
||||
@property
|
||||
def diagnostic_context(self) -> diagnostics.DiagnosticContext:
|
||||
"""The diagnostic context associated with the export."""
|
||||
@ -721,10 +598,7 @@ class ONNXProgram:
|
||||
def adapt_torch_inputs_to_onnx(
|
||||
self,
|
||||
*model_args,
|
||||
model_with_state_dict: torch.nn.Module
|
||||
| Callable
|
||||
| torch_export.ExportedProgram
|
||||
| None = None,
|
||||
model_with_state_dict: torch.nn.Module | Callable | None = None,
|
||||
**model_kwargs,
|
||||
) -> Sequence[torch.Tensor | int | float | bool | torch.dtype]:
|
||||
"""Converts the PyTorch model inputs to exported ONNX model inputs format.
|
||||
@ -794,10 +668,7 @@ class ONNXProgram:
|
||||
def adapt_torch_outputs_to_onnx(
|
||||
self,
|
||||
model_outputs: Any,
|
||||
model_with_state_dict: torch.nn.Module
|
||||
| Callable
|
||||
| torch_export.ExportedProgram
|
||||
| None = None,
|
||||
model_with_state_dict: torch.nn.Module | Callable | None = None,
|
||||
) -> Sequence[torch.Tensor | int | float | bool]:
|
||||
"""Converts the PyTorch model outputs to exported ONNX model outputs format.
|
||||
|
||||
@ -1050,7 +921,7 @@ class Exporter:
|
||||
def __init__(
|
||||
self,
|
||||
options: ResolvedExportOptions,
|
||||
model: torch.nn.Module | Callable | torch_export.ExportedProgram,
|
||||
model: torch.nn.Module | Callable,
|
||||
model_args: Sequence[Any],
|
||||
model_kwargs: Mapping[str, Any],
|
||||
):
|
||||
@ -1138,9 +1009,6 @@ class Exporter:
|
||||
self.options.fx_tracer.output_adapter,
|
||||
self.options.diagnostic_context,
|
||||
fake_context=self.options.fake_context,
|
||||
model_signature=getattr(
|
||||
self.model, "graph_signature", None
|
||||
), # Available for isinstance(self.model, ExportedProgram) only
|
||||
model_torch=self.model,
|
||||
)
|
||||
|
||||
@ -1261,12 +1129,12 @@ def _assert_dependencies(export_options: ResolvedExportOptions):
|
||||
|
||||
|
||||
def dynamo_export(
|
||||
model: torch.nn.Module | Callable | torch_export.ExportedProgram, # type: ignore[name-defined]
|
||||
model: torch.nn.Module | Callable,
|
||||
/,
|
||||
*model_args,
|
||||
export_options: ExportOptions | None = None,
|
||||
**model_kwargs,
|
||||
) -> ONNXProgram:
|
||||
) -> ONNXProgram | Any:
|
||||
"""Export a torch.nn.Module to an ONNX graph.
|
||||
|
||||
Args:
|
||||
|
@ -120,9 +120,20 @@ class TorchExportStrategy(CaptureStrategy):
|
||||
def _capture(
|
||||
self, model, args, kwargs, dynamic_shapes
|
||||
) -> torch.export.ExportedProgram:
|
||||
return torch.export.export(
|
||||
model, args, kwargs=kwargs, dynamic_shapes=dynamic_shapes
|
||||
)
|
||||
try:
|
||||
return torch.export.export(
|
||||
model, args, kwargs=kwargs, dynamic_shapes=dynamic_shapes
|
||||
)
|
||||
except torch._dynamo.exc.UserError as exc:
|
||||
# Refine the dynamic shapes based on the suggested fixes.
|
||||
new_shapes = (
|
||||
torch.export.dynamic_shapes.refine_dynamic_shapes_from_suggested_fixes(
|
||||
exc.msg, dynamic_shapes
|
||||
)
|
||||
)
|
||||
return torch.export.export(
|
||||
model, args, kwargs=kwargs, dynamic_shapes=new_shapes
|
||||
)
|
||||
|
||||
def _enter(self, model) -> None:
|
||||
model_repr = _take_first_line(repr(model))
|
||||
@ -148,9 +159,20 @@ class TorchExportNonStrictStrategy(CaptureStrategy):
|
||||
def _capture(
|
||||
self, model, args, kwargs, dynamic_shapes
|
||||
) -> torch.export.ExportedProgram:
|
||||
return torch.export.export(
|
||||
model, args, kwargs=kwargs, dynamic_shapes=dynamic_shapes, strict=False
|
||||
)
|
||||
try:
|
||||
return torch.export.export(
|
||||
model, args, kwargs=kwargs, dynamic_shapes=dynamic_shapes, strict=False
|
||||
)
|
||||
except torch._dynamo.exc.UserError as exc:
|
||||
# Refine the dynamic shapes based on the suggested fixes.
|
||||
new_shapes = (
|
||||
torch.export.dynamic_shapes.refine_dynamic_shapes_from_suggested_fixes(
|
||||
exc.msg, dynamic_shapes
|
||||
)
|
||||
)
|
||||
return torch.export.export(
|
||||
model, args, kwargs=kwargs, dynamic_shapes=new_shapes, strict=False
|
||||
)
|
||||
|
||||
def _enter(self, model) -> None:
|
||||
model_repr = _take_first_line(repr(model))
|
||||
|
@ -9,7 +9,6 @@ import logging
|
||||
from typing import Any, Mapping, Sequence, TYPE_CHECKING
|
||||
|
||||
import torch
|
||||
import torch.export
|
||||
from torch.onnx._internal._lazy_import import onnxscript_apis, onnxscript_ir as ir
|
||||
from torch.onnx._internal.exporter import _core, _onnx_program
|
||||
|
||||
|
@ -176,7 +176,7 @@ class Transform(abc.ABC):
|
||||
|
||||
One important aspect to note is that if the transformation modifies the model input and/or output signature,
|
||||
(e.g. additional inputs/outputs are added to the model), :class:`InputAdaptStep` and/or :class:`OutputAdaptStep`
|
||||
are needed to reconcile :attr:`ONNXProgram.model_signature` and :attr:`ONNXProgram.model_proto`.
|
||||
are needed to reconcile :attr:`ONNXProgram.model_proto`.
|
||||
That is, the model signature and the model representation must match.
|
||||
|
||||
As an additional feature, this class provides builtin support for transformation recording using the diagnostics.
|
||||
|
@ -1,128 +0,0 @@
|
||||
# mypy: allow-untyped-defs
|
||||
# NOTE: This file is referenced by name at
|
||||
# /opt/pytorch/torch/_dynamo/eval_frame.py::DONT_WRAP_FILES.
|
||||
# introduced by https://github.com/pytorch/pytorch/pull/98894.
|
||||
# If this file is renamed, moved, etc please update the reference there!
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Callable, Mapping, Sequence, TYPE_CHECKING
|
||||
|
||||
import torch._dynamo
|
||||
import torch.fx
|
||||
from torch.onnx._internal import _exporter_legacy, io_adapter
|
||||
from torch.onnx._internal.diagnostics import infra
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import torch.onnx
|
||||
from torch.export.exported_program import ExportedProgram
|
||||
|
||||
|
||||
class TorchExport(_exporter_legacy.FXGraphExtractor):
|
||||
"""Generates a FX GraphModule using torch.export API
|
||||
Args:
|
||||
aten_graph: If True, exports a graph with ATen operators.
|
||||
If False, exports a graph with Python operators.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
aten_graph: bool | None = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.aten_graph = aten_graph or True
|
||||
|
||||
def generate_fx(
|
||||
self,
|
||||
options: _exporter_legacy.ResolvedExportOptions,
|
||||
model: ExportedProgram, # type: ignore[override]
|
||||
model_args: Sequence[Any],
|
||||
model_kwargs: Mapping[str, Any],
|
||||
) -> torch.fx.GraphModule:
|
||||
# No need to translate callable to FX graph.
|
||||
# This FX Graph extractor assumes `model` was obtained through
|
||||
# exported_program = torch.export.export(
|
||||
# model,
|
||||
# args=model_args, # type: ignore[arg-type]
|
||||
# kwargs=model_kwargs, # type: ignore[arg-type]
|
||||
# )
|
||||
|
||||
# Export FX graph to ONNX ModelProto.
|
||||
self.input_adapter.append_step(
|
||||
io_adapter.FlattenInputWithTreeSpecValidationInputStep()
|
||||
)
|
||||
self.input_adapter.append_step(
|
||||
io_adapter.PrependParamsBuffersConstantAotAutogradInputStep()
|
||||
)
|
||||
|
||||
# ONNX does not support None inputs. During graph building, all None inputs
|
||||
# are removed. Here we register this step to input adapter.
|
||||
options.fx_tracer.input_adapter.append_step(io_adapter.RemoveNoneInputStep())
|
||||
|
||||
# NOTE: temp workaround for https://github.com/pytorch/pytorch/issues/99534
|
||||
# Dynamo doesn't support non-tensor inputs.
|
||||
options.fx_tracer.input_adapter.append_step(
|
||||
io_adapter.RemoveNonTensorInputStep()
|
||||
)
|
||||
|
||||
# ONNX does not support complex inputs. During graph building, all complex inputs
|
||||
# are converted to real representation inputs. Here we register this step to
|
||||
# input/output adapter.
|
||||
options.fx_tracer.input_adapter.append_step(
|
||||
io_adapter.ConvertComplexToRealRepresentationInputStep()
|
||||
)
|
||||
|
||||
updated_model_args = self.input_adapter.apply(
|
||||
*model_args, model=model, **model_kwargs
|
||||
)
|
||||
|
||||
# ONNX can't represent collection types (e.g., dictionary, tuple of tuple of
|
||||
# tensor, etc), we flatten the collection and register each element as output.
|
||||
options.fx_tracer.output_adapter.append_step(io_adapter.FlattenOutputStep())
|
||||
|
||||
# Output post-processing steps should happen after `FlattenOutputStep`.
|
||||
options.fx_tracer.output_adapter.append_step(
|
||||
io_adapter.ConvertComplexToRealRepresentationOutputStep()
|
||||
)
|
||||
|
||||
options.fx_tracer.output_adapter.append_step(
|
||||
io_adapter.PrependParamsAndBuffersAotAutogradOutputStep()
|
||||
)
|
||||
|
||||
# run_decomposition generates a new graph module with decomposed ops.
|
||||
# Thus, we need to run this step after io_adapters.
|
||||
model = model.run_decompositions(options.decomposition_table)
|
||||
|
||||
# Export FX graph to ONNX ModelProto.
|
||||
return self.pre_export_passes( # type: ignore[return-value]
|
||||
options, model, model.graph_module, updated_model_args
|
||||
)
|
||||
|
||||
def pre_export_passes(
|
||||
self,
|
||||
options: _exporter_legacy.ResolvedExportOptions,
|
||||
original_model: torch.nn.Module | Callable,
|
||||
fx_module: torch.fx.GraphModule,
|
||||
fx_module_args: Sequence[Any],
|
||||
):
|
||||
# TODO: Import here to prevent circular dependency
|
||||
from torch.onnx._internal.fx import analysis, passes
|
||||
|
||||
diagnostic_context = options.diagnostic_context
|
||||
|
||||
# ONNX does not support concept of (implicit) type promotion.
|
||||
# Insert type casts explicitly where needed.
|
||||
fx_module = passes.InsertTypePromotion(diagnostic_context, fx_module).run()
|
||||
|
||||
analysis.UnsupportedFxNodesAnalysis(
|
||||
diagnostic_context, fx_module, options.onnxfunction_dispatcher
|
||||
).analyze(infra.levels.ERROR)
|
||||
|
||||
# This operation should be invoked as the last pre export pass.
|
||||
# See [NOTE: Modularize pass ordering]
|
||||
fx_module = passes.Modularize(
|
||||
diagnostic_context, fx_module, is_exported_program=True
|
||||
).run()
|
||||
|
||||
return fx_module
|
Reference in New Issue
Block a user