[ONNX] Remove enable_fake_mode and exporter_legacy (#161222)

Remove `enable_fake_mode` and `exporter_legacy` entirely. Even though this is BC-breaking, `enable_fake_mode` is no longer compatible with the latest version of transformers, so it is no longer useful.
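For context, the workflow being removed looked roughly like the sketch below. It is adapted from the `enable_fake_mode` docstring deleted later in this diff and no longer runs after this change; `ONNXProgram.apply_weights` is the part that survives.

```python
import torch


class MyModel(torch.nn.Module):  # Model with a parameter
    def __init__(self) -> None:
        super().__init__()
        self.weight = torch.nn.Parameter(torch.tensor(42.0))

    def forward(self, x):
        return self.weight + x


with torch.onnx.enable_fake_mode():  # removed by this PR
    # Parameters created here are fake tensors, so very large models can be
    # instantiated without allocating real memory.
    my_nn_module = MyModel()
    arg1 = torch.randn(2, 2, 2)

onnx_program = torch.onnx.export(my_nn_module, (arg1,), dynamo=True, optimize=False)
onnx_program.apply_weights({"weight": torch.tensor(42.0)})  # supply real weights
onnx_program.save("my_model_with_initializers.onnx")
```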

Pull Request resolved: https://github.com/pytorch/pytorch/pull/161222
Approved by: https://github.com/titaiwangms
Justin Chu
2025-08-22 22:15:27 +00:00
committed by PyTorch MergeBot
parent 3373b074f5
commit 419a2dbf5f
10 changed files with 42 additions and 844 deletions

View File

@@ -84,8 +84,6 @@ also be interested in reading our [development wiki](https://github.com/pytorch/
    :noindex:
.. autofunction:: is_in_onnx_export
    :noindex:
.. autofunction:: enable_fake_mode
    :noindex:
```
### Classes

View File

@@ -245,5 +245,4 @@ Each initialized value, input, output has the following metadata:
.. autofunction:: torch.onnx.is_in_onnx_export
.. autoclass:: torch.onnx.OnnxExporterError
    :members:
.. autofunction:: torch.onnx.enable_fake_mode
```

View File

@@ -7,11 +7,9 @@ import io
import logging
import os
import numpy as np
from onnxscript import BOOL, FLOAT, ir, opset18 as op
from onnxscript import BOOL, FLOAT, opset18 as op
import torch
import torch.onnx._flags
from torch.onnx._internal.exporter import _testing as onnx_testing
from torch.testing._internal import common_utils
@@ -339,6 +337,47 @@ class TestExportAPIDynamo(common_utils.TestCase):
            ),
        )

    def test_is_in_onnx_export(self):
        class Mod(torch.nn.Module):
            def forward(self, x):
                def f(x):
                    return x.sin() if torch.onnx.is_in_onnx_export() else x.cos()

                return f(x)

        self.assertFalse(torch.onnx.is_in_onnx_export())
        onnx_program = torch.onnx.export(
            Mod(),
            (torch.randn(3, 4),),
            dynamo=True,
            fallback=False,
        )
        self.assertFalse(torch.onnx.is_in_onnx_export())
        node_names = [n.op_type for n in onnx_program.model.graph]
        self.assertIn("Sin", node_names)

    def test_torchscript_exporter_raises_deprecation_warning(self):
        # Test that the deprecation warning is raised when using torchscript exporter
        with self.assertWarnsRegex(
            DeprecationWarning, "You are using the legacy TorchScript-based ONNX export"
        ):
            torch.onnx.export(
                SampleModel(), (torch.randn(1, 1, 2),), io.BytesIO(), dynamo=False
            )

    def test_model_output_can_be_none(self):
        class ModelWithNoneOutput(torch.nn.Module):
            def forward(self, x):
                return x + 1, None

        onnx_program = torch.onnx.export(
            ModelWithNoneOutput(),
            (torch.randn(1, 1, 2),),
            dynamo=True,
        )
        onnx_testing.assert_onnx_program(onnx_program)


class TestCustomTranslationTable(common_utils.TestCase):
    def test_custom_translation_table_overrides_ops(self):
@@ -471,147 +510,5 @@ class TestCustomTranslationTable(common_utils.TestCase):
        self.assertNotIn("Sub", all_nodes_decomp)


class TestFakeTensorExport(common_utils.TestCase):
    """Test exporting in fake mode."""

    def test_onnx_program_raises_when_model_defined_in_fake_mode(self):
        with torch.onnx.enable_fake_mode():

            class Model(torch.nn.Module):
                def __init__(self):
                    super().__init__()
                    self.weight = torch.nn.Parameter(torch.tensor(42.0))

                def forward(self, x):
                    return self.weight + x

            onnx_program = torch.onnx.export(
                Model(), (torch.tensor(1.0),), dynamo=True, optimize=False
            )
            assert onnx_program is not None
            # Convert to model proto and back to trigger to_bytes method which serializes the tensor
            with self.assertRaises(Exception):
                # The tensors need to be replaced with real tensors
                _ = onnx_program.model_proto

        # Convert to model proto and back to trigger to_bytes method which serializes the tensor
        with self.assertRaises(Exception):
            # It doesn't matter if it is called inside or outside of the enable_fake_mode() context
            _ = onnx_program.model_proto

        # If we replace with concrete tensors, the serialization will succeed.
        # This needs to happen outside of the fake context
        onnx_program.apply_weights({"weight": torch.tensor(42.0)})
        onnx_model = ir.serde.deserialize_model(onnx_program.model_proto)
        np.testing.assert_allclose(
            onnx_model.graph.initializers["weight"].const_value.numpy(), 42.0
        )

    def test_onnx_program_save_raises_when_model_initialized_in_fake_mode(self):
        class Model(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.weight = torch.nn.Parameter(torch.tensor(42.0))

            def forward(self, x):
                return self.weight + x

        with torch.onnx.enable_fake_mode():
            onnx_program = torch.onnx.export(
                Model(), (torch.tensor(1.0),), dynamo=True, optimize=False
            )
            assert onnx_program is not None
            # Convert to model proto and back to trigger to_bytes method which serializes the tensor
            with self.assertRaises(Exception):
                # The tensors need to be replaced with real tensors
                _ = onnx_program.model_proto

        with self.assertRaises(Exception):
            # It doesn't matter if it is called inside or outside of the enable_fake_mode() context
            _ = onnx_program.model_proto

        # If we replace with concrete tensors, the serialization will succeed
        # This needs to happen outside of the fake context
        onnx_program.apply_weights({"weight": torch.tensor(42.0)})
        onnx_model = ir.serde.deserialize_model(onnx_program.model_proto)
        np.testing.assert_allclose(
            onnx_model.graph.initializers["weight"].const_value.numpy(), 42.0
        )

    def test_onnx_program_save_succeeds_when_export_and_save_in_fake_mode(self):
        class Model(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.weight = torch.nn.Parameter(torch.tensor(42.0))

            def forward(self, x):
                return self.weight + x

        real_model = Model()

        with torch.onnx.enable_fake_mode():
            onnx_program = torch.onnx.export(
                real_model, (torch.tensor(1.0),), dynamo=True, optimize=False
            )
            assert onnx_program is not None
            # Convert to model proto and back to trigger to_bytes method which serializes the tensor
            # Note that even though we are calling .model_proto (equivalently .save()) in fake mode,
            # the concrete tensors are maintained.
            # This is due to the usage of torch._subclasses.fake_tensor.unset_fake_temporarily() in
            # TorchTensor.tobytes()
            onnx_model = ir.serde.deserialize_model(onnx_program.model_proto)
            np.testing.assert_allclose(
                onnx_model.graph.initializers["weight"].const_value.numpy(), 42.0
            )

        # This works inside or outside the fake mode
        onnx_model = ir.serde.deserialize_model(onnx_program.model_proto)
        np.testing.assert_allclose(
            onnx_model.graph.initializers["weight"].const_value.numpy(), 42.0
        )

    def test_is_in_onnx_export(self):
        class Mod(torch.nn.Module):
            def forward(self, x):
                def f(x):
                    return x.sin() if torch.onnx.is_in_onnx_export() else x.cos()

                return f(x)

        self.assertFalse(torch.onnx.is_in_onnx_export())
        onnx_program = torch.onnx.export(
            Mod(),
            (torch.randn(3, 4),),
            dynamo=True,
            fallback=False,
        )
        self.assertFalse(torch.onnx.is_in_onnx_export())
        node_names = [n.op_type for n in onnx_program.model.graph]
        self.assertIn("Sin", node_names)

    def test_torchscript_exporter_raises_deprecation_warning(self):
        # Test that the deprecation warning is raised when using torchscript exporter
        with self.assertWarnsRegex(
            DeprecationWarning, "You are using the legacy TorchScript-based ONNX export"
        ):
            torch.onnx.export(
                SampleModel(), (torch.randn(1, 1, 2),), io.BytesIO(), dynamo=False
            )

    def test_model_output_can_be_none(self):
        class ModelWithNoneOutput(torch.nn.Module):
            def forward(self, x):
                return x + 1, None

        onnx_program = torch.onnx.export(
            ModelWithNoneOutput(),
            (torch.randn(1, 1, 2),),
            dynamo=True,
        )
        onnx_testing.assert_onnx_program(onnx_program)


if __name__ == "__main__":
    common_utils.run_tests()

View File

@@ -1,60 +0,0 @@
# Owner(s): ["module: onnx"]
import torch
import torch._dynamo
import torch.fx
from torch.onnx._internal.fx.passes import _utils as pass_utils
from torch.testing._internal import common_utils


class TestFxPasses(common_utils.TestCase):
    def test_set_node_name_correctly_renames_when_new_name_collides_recursively(self):
        def func(x, y, z):
            return x + y + z

        x = torch.randn(3)
        y = torch.randn(3)
        z = torch.randn(3)
        gm, _ = torch._dynamo.export(func)(x, y, z)
        torch._dynamo.reset()

        # Purposely name the nodes in a way that will cause a recursive collision later.
        # See :func:`set_node_name` for name collision renaming logic.
        base_name = "tensor"
        nodes = list(gm.graph.nodes)
        for i, node in enumerate(nodes[1:]):
            if i == 0:
                node.name = base_name
            else:
                node.name = f"{base_name}.{i}"

        # Run `set_node_name` and verify that the names are correct.
        name_to_node = {node.name: node for node in gm.graph.nodes}
        pass_utils.set_node_name(nodes[0], base_name, name_to_node)
        assert nodes[0].name == base_name, f"Expected {base_name}, got {nodes[0].name}"
        assert len({node.name for node in nodes}) == len(nodes), (
            f"Expected all names to be unique, got {nodes}"
        )

    def test_set_node_name_succeeds_when_no_name_collisions(self):
        def func(x, y, z):
            return x + y + z

        x = torch.randn(3)
        y = torch.randn(3)
        z = torch.randn(3)
        gm, _ = torch._dynamo.export(func)(x, y, z)
        torch._dynamo.reset()

        # Run `set_node_name` and verify that the names are correct.
        new_name = "some_tensor"
        nodes = list(gm.graph.nodes)
        name_to_node = {node.name: node for node in nodes}
        pass_utils.set_node_name(nodes[1], new_name, name_to_node)
        assert nodes[1].name == new_name, f"Expected {new_name}, got {nodes[0].name}"
        assert len({node.name for node in nodes}) == len(nodes), (
            f"Expected all names to be unique, got {nodes}"
        )


if __name__ == "__main__":
    common_utils.run_tests()

View File

@@ -37,7 +37,6 @@ __all__ = [
    # Base error
    "OnnxExporterError",
    "ONNXProgram",
    "enable_fake_mode",
]
from typing import Any, Callable, TYPE_CHECKING
@@ -47,7 +46,6 @@ import torch
from torch._C import _onnx as _C_onnx
from torch._C._onnx import OperatorExportTypes, TensorProtoDataType, TrainingMode
from ._internal._exporter_legacy import enable_fake_mode
from ._internal.exporter._onnx_program import ONNXProgram
from ._type_utils import JitScalarType
from .errors import OnnxExporterError
@@ -90,7 +88,6 @@ if TYPE_CHECKING:
JitScalarType.__module__ = "torch.onnx"
ONNXProgram.__module__ = "torch.onnx"
OnnxExporterError.__module__ = "torch.onnx"
enable_fake_mode.__module__ = "torch.onnx"
producer_name = "pytorch"
producer_version = _C_onnx.PRODUCER_VERSION

View File

@@ -1,118 +0,0 @@
# mypy: allow-untyped-defs
from __future__ import annotations


__all__ = [
    "enable_fake_mode",
]

import contextlib
import dataclasses
import logging
from typing import Any, TYPE_CHECKING

import torch
import torch._ops
from torch.onnx._internal.fx import patcher as patcher


# We can only import onnx from this module in a type-checking context to ensure that
# 'import torch.onnx' continues to work without having 'onnx' installed. We fully
# 'import onnx' inside of dynamo_export (by way of _assert_dependencies).
if TYPE_CHECKING:
    import io

    from torch._subclasses import fake_tensor

log = logging.getLogger(__name__)


@dataclasses.dataclass
class ONNXFakeContext:
    """A dataclass used to store context for model export using FakeTensor.

    This dataclass stores the FakeTensorMode instance used to convert
    real tensors and model parameters into fake tensors. This :attr:`ONNXFakeContext.fake_mode` is
    reused internally during tracing of a :class:`torch.nn.Module` into a FX :class:`GraphModule`.
    """

    fake_mode: fake_tensor.FakeTensorMode
    """The fake tensor mode used for tracing model using fake tensors and parameters."""

    state_dict_paths: tuple[str | io.BytesIO | dict[str, Any]] | None = None
    """List of paths of files that contain the model :meth:`state_dict`"""


@contextlib.contextmanager
def enable_fake_mode():
    """Enable fake mode for the duration of the context.

    Internally it instantiates a :class:`torch._subclasses.fake_tensor.FakeTensorMode` context manager
    that converts user input and model parameters into :class:`torch._subclasses.fake_tensor.FakeTensor`.

    A :class:`torch._subclasses.fake_tensor.FakeTensor`
    is a :class:`torch.Tensor` with the ability to run PyTorch code without having to
    actually do computation through tensors allocated on a ``meta`` device. Because
    there is no actual data being allocated on the device, this API allows for
    initializing and exporting large models without the actual memory footprint needed for executing it.

    It is highly recommended to initialize the model in fake mode when exporting models that
    are too large to fit into memory.

    .. note::
        This function does not support torch.onnx.export(..., dynamo=True, optimize=True).
        Please call ONNXProgram.optimize() outside of the function after the model is exported.

    Example::

        # xdoctest: +REQUIRES(env:TORCH_DOCTEST_ONNX)
        >>> import torch
        >>> class MyModel(torch.nn.Module):  # Model with a parameter
        ...     def __init__(self) -> None:
        ...         super().__init__()
        ...         self.weight = torch.nn.Parameter(torch.tensor(42.0))
        ...     def forward(self, x):
        ...         return self.weight + x
        >>> with torch.onnx.enable_fake_mode():
        ...     # When initialized in fake mode, the model's parameters are fake tensors
        ...     # They do not take up memory so we can initialize large models
        ...     my_nn_module = MyModel()
        ...     arg1 = torch.randn(2, 2, 2)
        >>> onnx_program = torch.onnx.export(my_nn_module, (arg1,), dynamo=True, optimize=False)
        >>> # Saving model WITHOUT initializers (only the architecture)
        >>> onnx_program.save(
        ...     "my_model_without_initializers.onnx",
        ...     include_initializers=False,
        ...     keep_initializers_as_inputs=True,
        ... )
        >>> # Saving model WITH initializers after applying concrete weights
        >>> onnx_program.apply_weights({"weight": torch.tensor(42.0)})
        >>> onnx_program.save("my_model_with_initializers.onnx")

    .. warning::
        This API is experimental and is *NOT* backward-compatible.

    """
    from torch._subclasses import fake_tensor
    from torch.fx.experimental.symbolic_shapes import ShapeEnv

    # This overrides the internal `FakeTensorMode` instance created by `torch._dynamo.export`[1].
    # It is a good idea to keep them in sync (constructor args) to maintain the same default behavior
    # [1] `torch/_dynamo/output_graph.py::InstructionTranslator::OutputGraph.__init__`
    # Mixed fake/real tensors are only allowed when `torch.onnx.dynamo_export` is not called within `FakeTensorMode`
    # This is needed because models can create new parameters during `forward(self, *args, **kwargs)` run
    fake_mode = fake_tensor.FakeTensorMode(
        allow_non_fake_inputs=not torch._guards.detect_fake_mode(),
        shape_env=ShapeEnv(
            allow_scalar_outputs=False, allow_dynamic_output_shape_ops=False
        ),
    )
    # The patcher is needed for when user calls `fake_model.load_state_dict(...)` within fake mode
    patcher_context = patcher.ONNXTorchPatcher()
    fake_context = ONNXFakeContext(fake_mode=fake_mode)
    with fake_mode, patcher_context:
        yield fake_context
    fake_context.state_dict_paths = tuple(
        patcher_context.paths,
    )  # type: ignore[assignment]

View File

@@ -1,8 +0,0 @@
from .patcher import ONNXTorchPatcher
from .serialization import save_model_with_external_data


__all__ = [
    "save_model_with_external_data",
    "ONNXTorchPatcher",
]

View File

@@ -1,114 +0,0 @@
# mypy: allow-untyped-defs
"""Common utility functions for FX passes.

These functions should NOT be directly invoked outside of `passes` package.
"""

from __future__ import annotations

import collections
import re
from typing import Callable

import torch.fx
import torch.fx.traceback as fx_traceback


def wrap_graph_module_for_node_meta_preservation(
    graph_module: torch.fx.GraphModule,
) -> Callable:
    """Wrap a GraphModule with contexts to preserve node meta information, such as stacktrace info.

    This is typically useful before calling `make_fx`. Without this wrapper, the
    stacktrace information will be lost afterwards.
    """

    def wrapped(*args):
        with fx_traceback.preserve_node_meta():
            return torch.fx.Interpreter(graph_module).run(*args)

    return wrapped


def _get_node_base_name(node_name: str) -> tuple[str, int | None]:
    pattern = r"(.*)\.(\d+)"
    match = re.match(pattern, node_name)
    if match is not None:
        base_name, count_str = match.groups()
        return base_name, int(count_str)
    return node_name, None


def set_node_name(
    node: torch.fx.Node,
    new_name: str,
    name_to_node_cache: dict[str, torch.fx.Node],
):
    """Safely set the unique name of a node.

    If the new name is already taken by another node, the name of the other node will be
    updated. If `new_name` is a string of format f"{base_name}.{count}", where `count`
    is an integer, the other node will be renamed as f"{base_name}.{count+1}". If not,
    the other node will be renamed as "{new_name}.1". This function will iteratively
    update the names until there is no conflict.

    ``name_to_node_cache`` is required as an argument to avoid recomputation. The caller
    is responsible for ensuring the cache is accurate and in sync with the owning module
    of the node. The values in the cache will be updated accordingly.

    Args:
        node: The node to update.
        new_name: The new name to use.
        name_to_node_cache: A cache of node names to nodes.
    """
    node_name_to_set = collections.deque([(node, new_name)])

    while node_name_to_set:
        node, new_name = node_name_to_set.pop()
        if new_name in name_to_node_cache and name_to_node_cache[new_name] != node:
            base_name, postfix_count = _get_node_base_name(new_name)
            if postfix_count is None:
                postfix_count = 0
            node_name_to_set.append(
                (name_to_node_cache[new_name], f"{base_name}.{postfix_count + 1}")
            )
        node.name = new_name
        name_to_node_cache[new_name] = node


def replace_placeholder_name_and_target(
    module: torch.fx.GraphModule, reference_module: torch.fx.GraphModule
):
    """Replace the argument names in module with those in reference_module.

    This function assumes the two modules have the same signature structure.
    The caller is responsible for ensuring this. Otherwise, the behavior of this
    function is undefined. This function only does minimal sanity check that the two
    modules have the same number of arguments.

    Name conflicts between new names and existing node names in the graph are handled.
    Check the documentation of :func:`set_node_name` for more details.

    Raises:
        RuntimeError: If the two modules have different number of arguments.
    """
    placeholders = [node for node in module.graph.nodes if node.op == "placeholder"]
    reference_placeholders = [
        node for node in reference_module.graph.nodes if node.op == "placeholder"
    ]

    if len(placeholders) != len(reference_placeholders):
        raise RuntimeError(
            "The two modules have different number of arguments. "
            f"module: {len(placeholders)}, reference_module: {len(reference_placeholders)}"
        )

    name_to_node: dict[str, torch.fx.Node] = {}
    for node in module.graph.nodes:
        name_to_node[node.name] = node

    for placeholder, reference_placeholder in zip(placeholders, reference_placeholders):
        placeholder.target = reference_placeholder.target
        set_node_name(placeholder, reference_placeholder.name, name_to_node)

    module.recompile()
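As an illustration of the renaming rule described in the `set_node_name` docstring above, here is a minimal standalone sketch; the helper below re-implements the deleted `_get_node_base_name` for demonstration and is not part of the diff.

```python
from __future__ import annotations

import re


def get_node_base_name(node_name: str) -> tuple[str, int | None]:
    # Mirrors the deleted _get_node_base_name: "tensor.3" -> ("tensor", 3)
    match = re.match(r"(.*)\.(\d+)", node_name)
    if match is not None:
        base_name, count_str = match.groups()
        return base_name, int(count_str)
    return node_name, None


# Collision handling per the docstring: assigning "tensor" when "tensor" is taken
# renames the other node to "tensor.1"; if "tensor.1" is also taken, that node
# becomes "tensor.2", and so on until no conflict remains.
assert get_node_base_name("tensor") == ("tensor", None)
assert get_node_base_name("tensor.3") == ("tensor", 3)
```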

View File

@@ -1,143 +0,0 @@
# mypy: allow-untyped-defs
import copy
import functools
from typing import TYPE_CHECKING, Union
import torch
if TYPE_CHECKING:
import io
# TODO: Remove after https://github.com/huggingface/safetensors/pull/318
@functools.cache
def has_safetensors_and_transformers():
try:
# safetensors is not an exporter requirement, but needed for some huggingface models
import safetensors # type: ignore[import] # noqa: F401
import transformers # type: ignore[import] # noqa: F401
from safetensors import torch as safetensors_torch # noqa: F401
return True
except ImportError:
return False
class ONNXTorchPatcher:
"""Context manager to temporarily patch PyTorch during FX-to-ONNX export.
This class is a collection of "patches" required by FX-to-ONNX exporter.
This context overrides several torch functions to support symbolic
export of large scale models.
torch.load:
This function is patched to record the files PyTorch stores model
parameters and buffers. Downstream FX-to-ONNX exporter can create
initializers from these files.
torch.fx._symbolic_trace._wrapped_methods_to_patch:
This list is extended with (torch.Tensor, "__getitem__") so that
weight[x, :, y] becomes exportable with torch.fx.symbolic_trace.
safetensors.torch.load_file:
This function is patched to allow safetensors to be loaded within
FakeTensorMode. Remove after https://github.com/huggingface/safetensors/pull/318
Search for ONNXTorchPatcher in test_fx_to_onnx_with_onnxruntime.py for
example usage.
TODO: Should this really be a global patcher? Can we make it a local patcher?
A reason for splitting this into several patchers is to avoid patching one part of
the code as collateral damage of patching another part. For example, for tracing a
model with torch._dynamo.export, we don't need to patch
`torch.fx._symbolic_trace._wrapped_methods_to_patch`
"""
def __init__(self) -> None:
# List of file paths processed by torch.load.
self.paths: list[Union[str, io.BufferedIOBase]] = []
def torch_load_wrapper(f, *args, **kwargs):
# Record path for later serialization into ONNX proto
self.paths.append(f)
# Then, call the original torch.load.
return self.torch_load(f, *args, **kwargs)
# Original version of torch.load.
self.torch_load = torch.load
# Wrapper or modified version of torch functions.
self.torch_load_wrapper = torch_load_wrapper
if has_safetensors_and_transformers():
import safetensors
import transformers
def safetensors_load_file_wrapper(filename, device="cpu"):
# Record path for later serialization into ONNX proto
self.paths.append(filename)
result = {}
with safetensors.torch.safe_open( # type: ignore[attr-defined]
filename, framework="pt", device=device
) as f:
for k in f.keys():
fake_mode = torch._guards.detect_fake_mode()
if not fake_mode:
result[k] = f.get_tensor(k)
else:
empty_tensor = f.get_slice(k)
result[k] = torch.empty(
tuple(empty_tensor.get_shape()),
dtype=safetensors.torch._getdtype(
empty_tensor.get_dtype()
),
)
return result
self.safetensors_torch_load_file = safetensors.torch.load_file
self.safetensors_torch_load_file_wrapper = safetensors_load_file_wrapper
self.transformers_modeling_utils_safe_load_file = (
transformers.modeling_utils.safe_load_file
)
def __enter__(self):
torch.load = self.torch_load_wrapper
self.torch_fx__symbolic_trace__wrapped_methods_to_patch = (
torch.fx._symbolic_trace._wrapped_methods_to_patch
)
desired_wrapped_methods = copy.deepcopy(
torch.fx._symbolic_trace._wrapped_methods_to_patch
)
if (torch.Tensor, "__getitem__") not in desired_wrapped_methods:
# Adding `__getitem__` to the patching list will make tensor indexing traceable via
# torch.fx.symbolic_trace. Otherwise, `tensor[x, :, y]` cannot be traced.
# This happens because `__getitem__` is neither under torch domain nor an aten operator,
# so the patching (or similar Proxy-generating mechanism) doesn't happen automatically.
# Note that torch.fx.symbolic_trace defines FX_PATCH_GETITEM environment variable for
# enabling the line below for patching.
desired_wrapped_methods.append((torch.Tensor, "__getitem__"))
torch.fx._symbolic_trace._wrapped_methods_to_patch = desired_wrapped_methods
if has_safetensors_and_transformers():
import safetensors
import transformers
safetensors.torch.load_file = self.safetensors_torch_load_file_wrapper
transformers.modeling_utils.safe_load_file = (
self.safetensors_torch_load_file_wrapper
)
def __exit__(self, exc_type, exc_value, traceback):
torch.load = self.torch_load
torch.fx._symbolic_trace._wrapped_methods_to_patch = (
self.torch_fx__symbolic_trace__wrapped_methods_to_patch
)
if has_safetensors_and_transformers():
import safetensors
import transformers
safetensors.torch.load_file = self.safetensors_torch_load_file
transformers.modeling_utils.safe_load_file = (
self.transformers_modeling_utils_safe_load_file
)
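The heart of the deleted patcher is a record-and-restore pattern around `torch.load`. A minimal standalone sketch follows (not the full class; the context-manager name is hypothetical):

```python
import contextlib

import torch


@contextlib.contextmanager
def record_torch_load_paths(paths: list):
    """Wrap torch.load to record which checkpoint files were read, then restore it."""
    original_load = torch.load

    def torch_load_wrapper(f, *args, **kwargs):
        paths.append(f)  # remember the path for later serialization into ONNX
        return original_load(f, *args, **kwargs)

    torch.load = torch_load_wrapper
    try:
        yield paths
    finally:
        torch.load = original_load
```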

View File

@@ -1,250 +0,0 @@
# mypy: allow-untyped-defs
from __future__ import annotations
import io
import logging
import os
from typing import IO, TYPE_CHECKING
import torch
from torch.onnx import _type_utils as jit_type_utils
if TYPE_CHECKING:
import onnx
from torch.types import FileLike
log = logging.getLogger(__name__)
def _create_tensor_proto_with_external_data(
tensor: torch.Tensor,
name: str,
location: str,
basepath: str,
dtype_override: onnx.TypeProto | None = None, # type: ignore[name-defined]
) -> onnx.TensorProto: # type: ignore[name-defined]
"""Create a TensorProto with external data from a PyTorch tensor.
The external data is saved to os.path.join(basepath, location).
Args:
tensor: Tensor to be saved.
name: Name of the tensor (i.e., initializer name in ONNX graph).
location: Relative location of the external data file
(e.g., "/tmp/initializers/weight_0" when model is "/tmp/model_name.onnx").
basepath: Base path of the external data file (e.g., "/tmp/external_data" while model must be in "/tmp").
Reference for ONNX's external data format:
How to load?
https://github.com/onnx/onnx/blob/5dac81ac0707bdf88f56c35c0a5e8855d3534673/onnx/external_data_helper.py#L187
How to save?
https://github.com/onnx/onnx/blob/5dac81ac0707bdf88f56c35c0a5e8855d3534673/onnx/external_data_helper.py#L43
How to set ONNX fields?
https://github.com/onnx/onnx/blob/5dac81ac0707bdf88f56c35c0a5e8855d3534673/onnx/external_data_helper.py#L88
"""
# FIXME: Avoid importing onnx into torch.onnx.
import onnx
scalar_type = (
jit_type_utils.JitScalarType.from_onnx_type(
dtype_override.tensor_type.elem_type
)
if dtype_override is not None
else jit_type_utils.JitScalarType.from_dtype(tensor.dtype)
)
# Checkpoints can be stored with a different dtype than the model expects because
# the user script may explicitly cast the original type to something else, or
# PyTorch's type promotion may do it.
if dtype_override is not None and scalar_type.dtype() != tensor.dtype:
tensor = tensor.to(scalar_type.dtype())
tensor_proto = onnx.TensorProto() # type: ignore[attr-defined]
tensor_proto.name = name
tensor_proto.data_type = scalar_type.onnx_type() # type: ignore[assignment]
tensor_proto.dims.extend(tensor.shape)
tensor_proto.data_location = onnx.TensorProto.EXTERNAL # type: ignore[attr-defined]
# Settings for saving one tensor per file.
# Offset is zero because there is no other tensor in the same file.
key_value_pairs = {
"location": location,
"offset": 0,
"length": tensor.untyped_storage().nbytes(),
}
for k, v in key_value_pairs.items():
entry = tensor_proto.external_data.add()
entry.key = k
entry.value = str(v)
# Actual path to write content of tensor.
external_data_file_path = os.path.join(basepath, location)
if os.path.exists(external_data_file_path):
os.remove(external_data_file_path)
# Create external data's folder if not exists.
external_data_dir_path = os.path.dirname(external_data_file_path)
if not os.path.exists(external_data_dir_path):
# Create the external data directory if it does not exist.
os.makedirs(external_data_dir_path)
# Create a fresh file.
with open(external_data_file_path, "xb") as data_file:
# No need to call "seek" because offset is 0.
# data_file.seek(0)
# Write tensor content to the file.
data_file.write(tensor.numpy(force=True).tobytes())
return tensor_proto
def _convert_safetensors_to_torch_format(safetensors_file):
# If this function is called, safetensors is guaranteed to exist
# because the HF model with safetensors was already loaded and exported to ONNX
from safetensors import safe_open # type: ignore[import-not-found, import-untyped]
tensors = {}
with safe_open(safetensors_file, framework="pt", device="cpu") as f: # type: ignore[attr-defined]
for k in f.keys():
tensors[k] = f.get_tensor(k).cpu()
return tensors
# TODO: generalize to allow more checkpoints formats (torch or gguf)
def save_model_with_external_data(
basepath: str,
model_location: str,
initializer_location: str,
torch_state_dicts: tuple[dict | FileLike, ...],
onnx_model: onnx.ModelProto, # type: ignore[name-defined]
rename_initializer: bool = False,
) -> None:
"""Load PyTorch tensors from files and add to "onnx_model" as external initializers.
Output files:
ONNX model file path:
ONNX initializer folder: os.path.join(basepath, initializer_location)
After running this function, you can do
ort_sess = onnxruntime.InferenceSession(os.path.join(basepath, model_location))
to execute the model.
Arguments:
basepath: Base path of the ONNX external data file (e.g., "/path/to/large_model/").
model_location: Relative location of the ONNX model file.
E.g., "model.onnx" so that the model file is saved to
"<basepath>/model.onnx".
initializer_location: Relative location of the ONNX initializer folder.
E.g., "initializers" so that the initializers are saved to
"<basepath>/initializers/".
Note: When initializers are >2GB, this must be the same as `model_location`.
torch_state_dicts: Dictionaries or files which contain PyTorch tensors to be saved
as ONNX initializers. For non-dict arguments, `torch.load` will be used to load them from file-like objects.
onnx_model: ONNX model to be saved with external initializers.
If an input name matches a tensor loaded from "torch_state_dicts",
the tensor will be saved as that input's external initializer.
rename_initializer: Replaces "." by "_" for all ONNX initializer names.
Not needed by the official torch.onnx.dynamo_export. This is a hack
for supporting `FXSymbolicTracer` tracer with fake tensor mode.
In short, `FXSymbolicTracer` lifts FX parameters (self.linear_weight)
as inputs (`def forward(self, linear_weight)`) and therefore, `.` cannot be used.
"""
# FIXME: Avoid importing onnx into torch.onnx.
import onnx
initializers_to_be_deleted = {} # Using dict because it is **ordered**
existing_initializers = {
k.name: idx for idx, k in enumerate(onnx_model.graph.initializer)
}
onnx_input_names = {input.name for input in onnx_model.graph.input}
for el in torch_state_dicts:
if isinstance(el, dict):
# Useful for when state_dict is loaded with torch.load(..., mmap=True, map_location="cpu") by the user
# Using torch.save wouldn't leverage mmap, leading to higher memory usage
state_dict = el
else:
if isinstance(el, (str, os.PathLike)) and os.fspath(el).endswith(
".safetensors"
):
state_dict = _convert_safetensors_to_torch_format(el)
else:
try:
# Loads checkpoint using memory-map on CPU to support really large models
# The underlying torch.UntypedStorage is memory mapped, so state_dict is lazy loaded
state_dict = torch.load(el, map_location="cpu", mmap=True)
except (RuntimeError, ValueError) as e:
if "mmap can only be used with files saved with" in str(e) or (
isinstance(el, (io.IOBase, IO))
and el.readable()
and el.seekable()
):
log.warning(
"Failed to load the checkpoint with memory-map enabled, retrying without memory-map."
"Consider updating the checkpoint with mmap by using torch.save() on PyTorch version >= 1.6."
)
if isinstance(el, (io.IOBase, IO)):
el.seek(0) # torch.load from `try:` has read the file.
state_dict = torch.load(el, map_location="cpu")
else:
raise e
for name, tensor in state_dict.items():
if rename_initializer:
# Basically, "transformer.attention.self.query.weight" is mapped
# to "transformer_attention_self_query_weight" for mimicking the
# name-modifying code in FX-to-ONNX exporter.
# See function _replace_get_attr_with_placeholder for details.
name = name.replace(".", "_")
# This block tries to match the onnx initializer name with torch parameter/buffer
# e.g. A pytorch buffer 'transformer.h.0.attn.bias' can be named 'h.0.attn.bias' in a ONNX initializer
# For each PyTorch tensor name loaded by torch.load,
# 1. Search its best match in ONNX model. E.g., the match of
# "transformer_attention_weight" could be "attention_weight".
# 2. Set "tensor" as the initializer of the matched ONNX input.
# E.g., "tensor" is stored as the initializer of "attention_weight".
# Step 1 is required because sometimes tensor names are stored with a prefix in the
# dictionary loaded by torch.load.
if name in onnx_input_names:
# Same input name shouldn't be matched again
onnx_input_names.remove(name)
else:
for onnx_input_name in onnx_input_names:
if onnx_input_name.endswith(name) or name.endswith(onnx_input_name):
# Find a match. Change name to the matched ONNX input name, so that we
# create initializer with the right ONNX name.
name = onnx_input_name
onnx_input_names.remove(onnx_input_name)
break
relative_tensor_file_path = os.path.join(initializer_location, name)
# Create one file per tensor.
# tensor_proto.raw_data is stored to external file at
# os.path.join(basepath, relative_tensor_file_path).
model_input_types = {k.name: k.type for k in onnx_model.graph.input}
# Mark for deletion - a replacement will be appended next
if name in existing_initializers:
initializers_to_be_deleted[existing_initializers[name]] = name
tensor_proto = _create_tensor_proto_with_external_data(
tensor,
name,
relative_tensor_file_path,
basepath,
model_input_types.pop(name, None),
)
# Add the tensor_proto to the ONNX model as an initializer with external data.
onnx_model.graph.initializer.append(tensor_proto)
# Remove old duplicated initializers, if any. delete in desc order to not invalidate deletion indices
initializers_to_be_deleted = dict(
sorted(initializers_to_be_deleted.items(), reverse=True)
)
for idx in initializers_to_be_deleted.keys():
del onnx_model.graph.initializer[idx]
# model_location should be a pure file name such as "file_name.onnx", not "folder/file_name.onnx".
onnx.save(onnx_model, os.path.join(basepath, model_location)) # type: ignore[attr-defined]
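For reference, a hypothetical call to the deleted helper, pieced together from its docstring; the paths and checkpoint file are placeholders and not part of this PR.

```python
import os

import onnx
import onnxruntime

# Removed by this PR; shown as it existed before the change.
from torch.onnx._internal.fx.serialization import save_model_with_external_data

basepath = "/path/to/large_model/"
onnx_model = onnx.load(
    os.path.join(basepath, "model_no_weights.onnx"), load_external_data=False
)
save_model_with_external_data(
    basepath=basepath,
    model_location="model.onnx",  # written to <basepath>/model.onnx
    initializer_location="initializers",  # tensors written under <basepath>/initializers/
    torch_state_dicts=("/path/to/checkpoint.pt",),
    onnx_model=onnx_model,
)
ort_sess = onnxruntime.InferenceSession(os.path.join(basepath, "model.onnx"))
```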