[ONNX] Delete ONNXProgramSerializer (#135261)

Fixes #135182

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135261
Approved by: https://github.com/justinchuby
This commit is contained in:
titaiwangms
2024-09-05 23:52:49 +00:00
committed by PyTorch MergeBot
parent b2386bdca1
commit 28ccfba248
4 changed files with 19 additions and 157 deletions

View File

@ -28,7 +28,6 @@ The exporter is designed to be modular and extensible. It is composed of the fol
- **FX Graph Extractor**: :class:`FXGraphExtractor` extracts the FX graph from the PyTorch model.
- **Fake Mode**: :class:`ONNXFakeContext` is a context manager that enables fake mode for large scale models.
- **ONNX Program**: :class:`ONNXProgram` is the output of the exporter that contains the exported ONNX graph and diagnostics.
- **ONNX Program Serializer**: :class:`ONNXProgramSerializer` serializes the exported model to a file.
- **ONNX Diagnostic Options**: :class:`DiagnosticOptions` has a set of options that control the diagnostics emitted by the exporter.
Dependencies
@ -144,9 +143,6 @@ API Reference
.. autoclass:: torch.onnx.ONNXProgram
:members:
.. autoclass:: torch.onnx.ONNXProgramSerializer
:members:
.. autoclass:: torch.onnx.ONNXRuntimeOptions
:members:

View File

@ -7,10 +7,7 @@ import onnx
import torch
from torch.onnx import dynamo_export, ExportOptions, ONNXProgram
from torch.onnx._internal import _exporter_legacy
from torch.onnx._internal._exporter_legacy import (
ONNXProgramSerializer,
ResolvedExportOptions,
)
from torch.onnx._internal._exporter_legacy import ResolvedExportOptions
from torch.testing._internal import common_utils
@ -75,40 +72,6 @@ class TestDynamoExportAPI(common_utils.TestCase):
dynamo_export(SampleModel(), torch.randn(1, 1, 2)).save(buffer)
onnx.load(buffer)
def test_save_to_file_using_specified_serializer(self):
expected_buffer = "I am not actually ONNX"
class CustomSerializer(ONNXProgramSerializer):
def serialize(
self, onnx_program: ONNXProgram, destination: io.BufferedIOBase
) -> None:
destination.write(expected_buffer.encode())
with common_utils.TemporaryFileName() as path:
dynamo_export(SampleModel(), torch.randn(1, 1, 2)).save(
path, serializer=CustomSerializer()
)
with open(path) as fp:
self.assertEqual(fp.read(), expected_buffer)
def test_save_to_file_using_specified_serializer_without_inheritance(self):
expected_buffer = "I am not actually ONNX"
# NOTE: Inheritance from `ONNXProgramSerializer` is not required.
# Because `ONNXProgramSerializer` is a Protocol class.
class CustomSerializer:
def serialize(
self, onnx_program: ONNXProgram, destination: io.BufferedIOBase
) -> None:
destination.write(expected_buffer.encode())
with common_utils.TemporaryFileName() as path:
dynamo_export(SampleModel(), torch.randn(1, 1, 2)).save(
path, serializer=CustomSerializer()
)
with open(path) as fp:
self.assertEqual(fp.read(), expected_buffer)
def test_save_sarif_log_to_file_with_successful_export(self):
with common_utils.TemporaryFileName(suffix=".sarif") as path:
dynamo_export(SampleModel(), torch.randn(1, 1, 2)).save_diagnostics(path)

View File

@ -44,7 +44,6 @@ __all__ = [
"DiagnosticOptions",
"ExportOptions",
"ONNXProgram",
"ONNXProgramSerializer",
"ONNXRuntimeOptions",
"InvalidExportOptionsError",
"OnnxExporterError",
@ -109,7 +108,6 @@ from ._internal._exporter_legacy import ( # usort: skip. needs to be last to av
DiagnosticOptions,
ExportOptions,
ONNXProgram,
ONNXProgramSerializer,
ONNXRuntimeOptions,
InvalidExportOptionsError,
OnnxExporterError,
@ -127,7 +125,6 @@ ExportTypes.__module__ = "torch.onnx"
JitScalarType.__module__ = "torch.onnx"
ExportOptions.__module__ = "torch.onnx"
ONNXProgram.__module__ = "torch.onnx"
ONNXProgramSerializer.__module__ = "torch.onnx"
ONNXRuntimeOptions.__module__ = "torch.onnx"
dynamo_export.__module__ = "torch.onnx"
InvalidExportOptionsError.__module__ = "torch.onnx"

View File

@ -11,17 +11,7 @@ import os
import tempfile
import warnings
from collections import defaultdict
from typing import (
Any,
Callable,
Final,
Mapping,
Protocol,
runtime_checkable,
Sequence,
TYPE_CHECKING,
TypeVar,
)
from typing import Any, Callable, Final, Mapping, Sequence, TYPE_CHECKING, TypeVar
from typing_extensions import Self
import torch
@ -462,97 +452,6 @@ def enable_fake_mode():
) # type: ignore[assignment]
@runtime_checkable
class ONNXProgramSerializer(Protocol):
"""Protocol for serializing an ONNX graph into a specific format (e.g. Protobuf).
Note that this is an advanced usage scenario."""
def serialize(
self, onnx_program: ONNXProgram, destination: io.BufferedIOBase
) -> None:
"""Protocol method that must be implemented for serialization.
Args:
onnx_program: Represents the in-memory exported ONNX model
destination: A binary IO stream or pre-allocated buffer into which
the serialized model should be written.
Example:
A simple serializer that writes the exported :py:obj:`onnx.ModelProto` in Protobuf
format to ``destination``:
::
# xdoctest: +REQUIRES(env:TORCH_DOCTEST_ONNX)
>>> import io
>>> import torch
>>> import torch.onnx
>>> class MyModel(torch.nn.Module): # Dummy model
... def __init__(self) -> None:
... super().__init__()
... self.linear = torch.nn.Linear(2, 2)
... def forward(self, x):
... out = self.linear(x)
... return out
>>> class ProtobufONNXProgramSerializer:
... def serialize(
... self, onnx_program: torch.onnx.ONNXProgram, destination: io.BufferedIOBase
... ) -> None:
... destination.write(onnx_program.model_proto.SerializeToString())
>>> model = MyModel()
>>> arg1 = torch.randn(2, 2, 2) # positional input 1
>>> torch.onnx.dynamo_export(model, arg1).save(
... destination="exported_model.onnx",
... serializer=ProtobufONNXProgramSerializer(),
... )
"""
...
class ProtobufONNXProgramSerializer:
"""Serializes ONNX graph as Protobuf."""
def serialize(
self, onnx_program: ONNXProgram, destination: io.BufferedIOBase
) -> None:
import onnx
if not isinstance(onnx_program.model_proto, onnx.ModelProto): # type: ignore[attr-defined]
raise ValueError("onnx_program.ModelProto is not an onnx.ModelProto")
destination.write(onnx_program.model_proto.SerializeToString())
class LargeProtobufONNXProgramSerializer:
"""Serializes ONNX graph as Protobuf.
Fallback to serializing as Protobuf with external data for models larger than 2GB.
"""
_destination_path: Final[str] # type: ignore[misc]
def __init__(self, destination_path: str):
self._destination_path = destination_path
def serialize(
self, onnx_program: ONNXProgram, destination: io.BufferedIOBase
) -> None:
"""`destination` is ignored. The model is saved to `self._destination_path` instead."""
import onnx
if onnx_program.model_proto.ByteSize() < _PROTOBUF_SIZE_MAX_LIMIT:
onnx.save_model(onnx_program.model_proto, self._destination_path) # type: ignore[attr-defined]
else:
# ValueError: Message onnx.ModelProto exceeds maximum protobuf size of 2GB
# Fallback to serializing the model with external data.
onnx.save_model( # type: ignore[attr-defined]
onnx_program.model_proto,
self._destination_path,
save_as_external_data=True,
all_tensors_to_one_file=True,
)
class ONNXRuntimeOptions:
"""Options to influence the execution of the ONNX model through ONNX Runtime.
@ -959,7 +858,6 @@ class ONNXProgram:
*,
include_initializers: bool = True,
model_state: dict[str, Any] | str | None = None,
serializer: ONNXProgramSerializer | None = None,
) -> None:
"""Saves the in-memory ONNX model to ``destination`` using specified ``serializer``.
@ -975,17 +873,12 @@ class ONNXProgram:
It can be either a string with the path to a checkpoint or a dictionary with the actual model state.
The supported file formats are the same as those supported by `torch.load` and `safetensors.safe_open`.
Required when :func:`enable_fake_mode` is used but real initializers are needed on the ONNX graph.
serializer: The serializer to use. If not specified, the model will be serialized as Protobuf.
"""
import onnx
assert (
include_initializers is True or model_state is None
), "Cannot specify both `include_initializers=False` and `model_state`."
if serializer is None:
if isinstance(destination, str):
serializer = LargeProtobufONNXProgramSerializer(destination)
else:
serializer = ProtobufONNXProgramSerializer()
# Add initializers when symbolic tracing is enabled
_model_state_files: list[str | io.BytesIO | dict[str, Any]] = []
@ -1031,10 +924,24 @@ class ONNXProgram:
else:
if isinstance(destination, str):
with open(destination, "wb") as f:
serializer.serialize(self, f)
if self.model_proto.ByteSize() < _PROTOBUF_SIZE_MAX_LIMIT:
onnx.save_model(self.model_proto, destination) # type: ignore[attr-defined]
else:
# ValueError: Message onnx.ModelProto exceeds maximum protobuf size of 2GB
# Fallback to serializing the model with external data.
onnx.save_model( # type: ignore[attr-defined]
self.model_proto,
destination,
save_as_external_data=True,
all_tensors_to_one_file=True,
)
else:
try:
serializer.serialize(self, destination)
if not isinstance(self.model_proto, onnx.ModelProto): # type: ignore[attr-defined]
raise ValueError(
"onnx_program.ModelProto is not an onnx.ModelProto"
)
destination.write(self.model_proto.SerializeToString())
except ValueError as exc:
raise ValueError(
"'destination' should be provided as a path-like string when saving a model larger than 2GB. "
@ -1544,7 +1451,6 @@ __all__ = [
"DiagnosticOptions",
"ExportOptions",
"ONNXProgram",
"ONNXProgramSerializer",
"ONNXRuntimeOptions",
"InvalidExportOptionsError",
"OnnxExporterError",