mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[ONNX] Delete ONNXProgramSerializer (#135261)
Fixes #135182 Pull Request resolved: https://github.com/pytorch/pytorch/pull/135261 Approved by: https://github.com/justinchuby
This commit is contained in:
committed by
PyTorch MergeBot
parent
b2386bdca1
commit
28ccfba248
@ -28,7 +28,6 @@ The exporter is designed to be modular and extensible. It is composed of the fol
|
||||
- **FX Graph Extractor**: :class:`FXGraphExtractor` extracts the FX graph from the PyTorch model.
|
||||
- **Fake Mode**: :class:`ONNXFakeContext` is a context manager that enables fake mode for large scale models.
|
||||
- **ONNX Program**: :class:`ONNXProgram` is the output of the exporter that contains the exported ONNX graph and diagnostics.
|
||||
- **ONNX Program Serializer**: :class:`ONNXProgramSerializer` serializes the exported model to a file.
|
||||
- **ONNX Diagnostic Options**: :class:`DiagnosticOptions` has a set of options that control the diagnostics emitted by the exporter.
|
||||
|
||||
Dependencies
|
||||
@ -144,9 +143,6 @@ API Reference
|
||||
.. autoclass:: torch.onnx.ONNXProgram
|
||||
:members:
|
||||
|
||||
.. autoclass:: torch.onnx.ONNXProgramSerializer
|
||||
:members:
|
||||
|
||||
.. autoclass:: torch.onnx.ONNXRuntimeOptions
|
||||
:members:
|
||||
|
||||
|
@ -7,10 +7,7 @@ import onnx
|
||||
import torch
|
||||
from torch.onnx import dynamo_export, ExportOptions, ONNXProgram
|
||||
from torch.onnx._internal import _exporter_legacy
|
||||
from torch.onnx._internal._exporter_legacy import (
|
||||
ONNXProgramSerializer,
|
||||
ResolvedExportOptions,
|
||||
)
|
||||
from torch.onnx._internal._exporter_legacy import ResolvedExportOptions
|
||||
from torch.testing._internal import common_utils
|
||||
|
||||
|
||||
@ -75,40 +72,6 @@ class TestDynamoExportAPI(common_utils.TestCase):
|
||||
dynamo_export(SampleModel(), torch.randn(1, 1, 2)).save(buffer)
|
||||
onnx.load(buffer)
|
||||
|
||||
def test_save_to_file_using_specified_serializer(self):
|
||||
expected_buffer = "I am not actually ONNX"
|
||||
|
||||
class CustomSerializer(ONNXProgramSerializer):
|
||||
def serialize(
|
||||
self, onnx_program: ONNXProgram, destination: io.BufferedIOBase
|
||||
) -> None:
|
||||
destination.write(expected_buffer.encode())
|
||||
|
||||
with common_utils.TemporaryFileName() as path:
|
||||
dynamo_export(SampleModel(), torch.randn(1, 1, 2)).save(
|
||||
path, serializer=CustomSerializer()
|
||||
)
|
||||
with open(path) as fp:
|
||||
self.assertEqual(fp.read(), expected_buffer)
|
||||
|
||||
def test_save_to_file_using_specified_serializer_without_inheritance(self):
|
||||
expected_buffer = "I am not actually ONNX"
|
||||
|
||||
# NOTE: Inheritance from `ONNXProgramSerializer` is not required.
|
||||
# Because `ONNXProgramSerializer` is a Protocol class.
|
||||
class CustomSerializer:
|
||||
def serialize(
|
||||
self, onnx_program: ONNXProgram, destination: io.BufferedIOBase
|
||||
) -> None:
|
||||
destination.write(expected_buffer.encode())
|
||||
|
||||
with common_utils.TemporaryFileName() as path:
|
||||
dynamo_export(SampleModel(), torch.randn(1, 1, 2)).save(
|
||||
path, serializer=CustomSerializer()
|
||||
)
|
||||
with open(path) as fp:
|
||||
self.assertEqual(fp.read(), expected_buffer)
|
||||
|
||||
def test_save_sarif_log_to_file_with_successful_export(self):
|
||||
with common_utils.TemporaryFileName(suffix=".sarif") as path:
|
||||
dynamo_export(SampleModel(), torch.randn(1, 1, 2)).save_diagnostics(path)
|
||||
|
@ -44,7 +44,6 @@ __all__ = [
|
||||
"DiagnosticOptions",
|
||||
"ExportOptions",
|
||||
"ONNXProgram",
|
||||
"ONNXProgramSerializer",
|
||||
"ONNXRuntimeOptions",
|
||||
"InvalidExportOptionsError",
|
||||
"OnnxExporterError",
|
||||
@ -109,7 +108,6 @@ from ._internal._exporter_legacy import ( # usort: skip. needs to be last to av
|
||||
DiagnosticOptions,
|
||||
ExportOptions,
|
||||
ONNXProgram,
|
||||
ONNXProgramSerializer,
|
||||
ONNXRuntimeOptions,
|
||||
InvalidExportOptionsError,
|
||||
OnnxExporterError,
|
||||
@ -127,7 +125,6 @@ ExportTypes.__module__ = "torch.onnx"
|
||||
JitScalarType.__module__ = "torch.onnx"
|
||||
ExportOptions.__module__ = "torch.onnx"
|
||||
ONNXProgram.__module__ = "torch.onnx"
|
||||
ONNXProgramSerializer.__module__ = "torch.onnx"
|
||||
ONNXRuntimeOptions.__module__ = "torch.onnx"
|
||||
dynamo_export.__module__ = "torch.onnx"
|
||||
InvalidExportOptionsError.__module__ = "torch.onnx"
|
||||
|
@ -11,17 +11,7 @@ import os
|
||||
import tempfile
|
||||
import warnings
|
||||
from collections import defaultdict
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
Final,
|
||||
Mapping,
|
||||
Protocol,
|
||||
runtime_checkable,
|
||||
Sequence,
|
||||
TYPE_CHECKING,
|
||||
TypeVar,
|
||||
)
|
||||
from typing import Any, Callable, Final, Mapping, Sequence, TYPE_CHECKING, TypeVar
|
||||
from typing_extensions import Self
|
||||
|
||||
import torch
|
||||
@ -462,97 +452,6 @@ def enable_fake_mode():
|
||||
) # type: ignore[assignment]
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class ONNXProgramSerializer(Protocol):
|
||||
"""Protocol for serializing an ONNX graph into a specific format (e.g. Protobuf).
|
||||
Note that this is an advanced usage scenario."""
|
||||
|
||||
def serialize(
|
||||
self, onnx_program: ONNXProgram, destination: io.BufferedIOBase
|
||||
) -> None:
|
||||
"""Protocol method that must be implemented for serialization.
|
||||
|
||||
Args:
|
||||
onnx_program: Represents the in-memory exported ONNX model
|
||||
destination: A binary IO stream or pre-allocated buffer into which
|
||||
the serialized model should be written.
|
||||
|
||||
Example:
|
||||
|
||||
A simple serializer that writes the exported :py:obj:`onnx.ModelProto` in Protobuf
|
||||
format to ``destination``:
|
||||
|
||||
::
|
||||
|
||||
# xdoctest: +REQUIRES(env:TORCH_DOCTEST_ONNX)
|
||||
>>> import io
|
||||
>>> import torch
|
||||
>>> import torch.onnx
|
||||
>>> class MyModel(torch.nn.Module): # Dummy model
|
||||
... def __init__(self) -> None:
|
||||
... super().__init__()
|
||||
... self.linear = torch.nn.Linear(2, 2)
|
||||
... def forward(self, x):
|
||||
... out = self.linear(x)
|
||||
... return out
|
||||
>>> class ProtobufONNXProgramSerializer:
|
||||
... def serialize(
|
||||
... self, onnx_program: torch.onnx.ONNXProgram, destination: io.BufferedIOBase
|
||||
... ) -> None:
|
||||
... destination.write(onnx_program.model_proto.SerializeToString())
|
||||
>>> model = MyModel()
|
||||
>>> arg1 = torch.randn(2, 2, 2) # positional input 1
|
||||
>>> torch.onnx.dynamo_export(model, arg1).save(
|
||||
... destination="exported_model.onnx",
|
||||
... serializer=ProtobufONNXProgramSerializer(),
|
||||
... )
|
||||
"""
|
||||
...
|
||||
|
||||
|
||||
class ProtobufONNXProgramSerializer:
|
||||
"""Serializes ONNX graph as Protobuf."""
|
||||
|
||||
def serialize(
|
||||
self, onnx_program: ONNXProgram, destination: io.BufferedIOBase
|
||||
) -> None:
|
||||
import onnx
|
||||
|
||||
if not isinstance(onnx_program.model_proto, onnx.ModelProto): # type: ignore[attr-defined]
|
||||
raise ValueError("onnx_program.ModelProto is not an onnx.ModelProto")
|
||||
destination.write(onnx_program.model_proto.SerializeToString())
|
||||
|
||||
|
||||
class LargeProtobufONNXProgramSerializer:
|
||||
"""Serializes ONNX graph as Protobuf.
|
||||
|
||||
Fallback to serializing as Protobuf with external data for models larger than 2GB.
|
||||
"""
|
||||
|
||||
_destination_path: Final[str] # type: ignore[misc]
|
||||
|
||||
def __init__(self, destination_path: str):
|
||||
self._destination_path = destination_path
|
||||
|
||||
def serialize(
|
||||
self, onnx_program: ONNXProgram, destination: io.BufferedIOBase
|
||||
) -> None:
|
||||
"""`destination` is ignored. The model is saved to `self._destination_path` instead."""
|
||||
import onnx
|
||||
|
||||
if onnx_program.model_proto.ByteSize() < _PROTOBUF_SIZE_MAX_LIMIT:
|
||||
onnx.save_model(onnx_program.model_proto, self._destination_path) # type: ignore[attr-defined]
|
||||
else:
|
||||
# ValueError: Message onnx.ModelProto exceeds maximum protobuf size of 2GB
|
||||
# Fallback to serializing the model with external data.
|
||||
onnx.save_model( # type: ignore[attr-defined]
|
||||
onnx_program.model_proto,
|
||||
self._destination_path,
|
||||
save_as_external_data=True,
|
||||
all_tensors_to_one_file=True,
|
||||
)
|
||||
|
||||
|
||||
class ONNXRuntimeOptions:
|
||||
"""Options to influence the execution of the ONNX model through ONNX Runtime.
|
||||
|
||||
@ -959,7 +858,6 @@ class ONNXProgram:
|
||||
*,
|
||||
include_initializers: bool = True,
|
||||
model_state: dict[str, Any] | str | None = None,
|
||||
serializer: ONNXProgramSerializer | None = None,
|
||||
) -> None:
|
||||
"""Saves the in-memory ONNX model to ``destination`` using specified ``serializer``.
|
||||
|
||||
@ -975,17 +873,12 @@ class ONNXProgram:
|
||||
It can be either a string with the path to a checkpoint or a dictionary with the actual model state.
|
||||
The supported file formats are the same as those supported by `torch.load` and `safetensors.safe_open`.
|
||||
Required when :func:`enable_fake_mode` is used but real initializers are needed on the ONNX graph.
|
||||
serializer: The serializer to use. If not specified, the model will be serialized as Protobuf.
|
||||
"""
|
||||
import onnx
|
||||
|
||||
assert (
|
||||
include_initializers is True or model_state is None
|
||||
), "Cannot specify both `include_initializers=False` and `model_state`."
|
||||
if serializer is None:
|
||||
if isinstance(destination, str):
|
||||
serializer = LargeProtobufONNXProgramSerializer(destination)
|
||||
else:
|
||||
serializer = ProtobufONNXProgramSerializer()
|
||||
|
||||
# Add initializers when symbolic tracing is enabled
|
||||
_model_state_files: list[str | io.BytesIO | dict[str, Any]] = []
|
||||
@ -1031,10 +924,24 @@ class ONNXProgram:
|
||||
else:
|
||||
if isinstance(destination, str):
|
||||
with open(destination, "wb") as f:
|
||||
serializer.serialize(self, f)
|
||||
if self.model_proto.ByteSize() < _PROTOBUF_SIZE_MAX_LIMIT:
|
||||
onnx.save_model(self.model_proto, destination) # type: ignore[attr-defined]
|
||||
else:
|
||||
# ValueError: Message onnx.ModelProto exceeds maximum protobuf size of 2GB
|
||||
# Fallback to serializing the model with external data.
|
||||
onnx.save_model( # type: ignore[attr-defined]
|
||||
self.model_proto,
|
||||
destination,
|
||||
save_as_external_data=True,
|
||||
all_tensors_to_one_file=True,
|
||||
)
|
||||
else:
|
||||
try:
|
||||
serializer.serialize(self, destination)
|
||||
if not isinstance(self.model_proto, onnx.ModelProto): # type: ignore[attr-defined]
|
||||
raise ValueError(
|
||||
"onnx_program.ModelProto is not an onnx.ModelProto"
|
||||
)
|
||||
destination.write(self.model_proto.SerializeToString())
|
||||
except ValueError as exc:
|
||||
raise ValueError(
|
||||
"'destination' should be provided as a path-like string when saving a model larger than 2GB. "
|
||||
@ -1544,7 +1451,6 @@ __all__ = [
|
||||
"DiagnosticOptions",
|
||||
"ExportOptions",
|
||||
"ONNXProgram",
|
||||
"ONNXProgramSerializer",
|
||||
"ONNXRuntimeOptions",
|
||||
"InvalidExportOptionsError",
|
||||
"OnnxExporterError",
|
||||
|
Reference in New Issue
Block a user