Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 12:54:11 +08:00

[ONNX] Remove enable_fake_mode and exporter_legacy (#161222)

Remove `enable_fake_mode` and `exporter_legacy` entirely. Even though this is BC-breaking, `enable_fake_mode` is no longer compatible with the latest version of transformers, so it is no longer useful.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/161222
Approved by: https://github.com/titaiwangms

Committed by: PyTorch MergeBot
Parent: 3373b074f5
Commit: 419a2dbf5f
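With `enable_fake_mode` gone, the dynamo-based exporter is the supported path; the tests in this diff exercise it via `torch.onnx.export(..., dynamo=True)`, and weights can still be swapped into an exported program with `ONNXProgram.apply_weights`. A minimal sketch of that workflow (the model and file name are illustrative, not part of this commit):

```python
import torch


class TinyModel(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.weight = torch.nn.Parameter(torch.tensor(42.0))

    def forward(self, x):
        return self.weight + x


# Dynamo-based export; no fake-mode context manager is involved anymore.
onnx_program = torch.onnx.export(TinyModel(), (torch.tensor(1.0),), dynamo=True)

# Weights can still be (re)applied to the exported program before saving.
onnx_program.apply_weights({"weight": torch.tensor(42.0)})
onnx_program.save("tiny_model.onnx")
```
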
@@ -84,8 +84,6 @@ also be interested in reading our [development wiki](https://github.com/pytorch/
    :noindex:
.. autofunction:: is_in_onnx_export
    :noindex:
.. autofunction:: enable_fake_mode
    :noindex:
```

### Classes

@@ -245,5 +245,4 @@ Each initialized value, input, output has the following metadata:
.. autofunction:: torch.onnx.is_in_onnx_export
.. autoclass:: torch.onnx.OnnxExporterError
    :members:
.. autofunction:: torch.onnx.enable_fake_mode
```

@@ -7,11 +7,9 @@ import io
import logging
import os

import numpy as np
from onnxscript import BOOL, FLOAT, ir, opset18 as op
from onnxscript import BOOL, FLOAT, opset18 as op

import torch
import torch.onnx._flags
from torch.onnx._internal.exporter import _testing as onnx_testing
from torch.testing._internal import common_utils

@@ -339,6 +337,47 @@ class TestExportAPIDynamo(common_utils.TestCase):
            ),
        )

    def test_is_in_onnx_export(self):
        class Mod(torch.nn.Module):
            def forward(self, x):
                def f(x):
                    return x.sin() if torch.onnx.is_in_onnx_export() else x.cos()

                return f(x)

        self.assertFalse(torch.onnx.is_in_onnx_export())
        onnx_program = torch.onnx.export(
            Mod(),
            (torch.randn(3, 4),),
            dynamo=True,
            fallback=False,
        )
        self.assertFalse(torch.onnx.is_in_onnx_export())

        node_names = [n.op_type for n in onnx_program.model.graph]
        self.assertIn("Sin", node_names)

    def test_torchscript_exporter_raises_deprecation_warning(self):
        # Test that the deprecation warning is raised when using torchscript exporter
        with self.assertWarnsRegex(
            DeprecationWarning, "You are using the legacy TorchScript-based ONNX export"
        ):
            torch.onnx.export(
                SampleModel(), (torch.randn(1, 1, 2),), io.BytesIO(), dynamo=False
            )

    def test_model_output_can_be_none(self):
        class ModelWithNoneOutput(torch.nn.Module):
            def forward(self, x):
                return x + 1, None

        onnx_program = torch.onnx.export(
            ModelWithNoneOutput(),
            (torch.randn(1, 1, 2),),
            dynamo=True,
        )
        onnx_testing.assert_onnx_program(onnx_program)


class TestCustomTranslationTable(common_utils.TestCase):
    def test_custom_translation_table_overrides_ops(self):
@@ -471,147 +510,5 @@ class TestCustomTranslationTable(common_utils.TestCase):
        self.assertNotIn("Sub", all_nodes_decomp)


class TestFakeTensorExport(common_utils.TestCase):
    """Test exporting in fake mode."""

    def test_onnx_program_raises_when_model_defined_in_fake_mode(self):
        with torch.onnx.enable_fake_mode():

            class Model(torch.nn.Module):
                def __init__(self):
                    super().__init__()
                    self.weight = torch.nn.Parameter(torch.tensor(42.0))

                def forward(self, x):
                    return self.weight + x

            onnx_program = torch.onnx.export(
                Model(), (torch.tensor(1.0),), dynamo=True, optimize=False
            )
            assert onnx_program is not None
            # Convert to model proto and back to trigger to_bytes method which serializes the tensor
            with self.assertRaises(Exception):
                # The tensors need to be replaced with real tensors
                _ = onnx_program.model_proto

        # Convert to model proto and back to trigger to_bytes method which serializes the tensor
        with self.assertRaises(Exception):
            # It doesn't matter if it is called inside or outside of the enable_fake_mode() context
            _ = onnx_program.model_proto

        # If we replace with concrete tensors, the serialization will succeed.
        # This needs to happen outside of the fake context
        onnx_program.apply_weights({"weight": torch.tensor(42.0)})
        onnx_model = ir.serde.deserialize_model(onnx_program.model_proto)
        np.testing.assert_allclose(
            onnx_model.graph.initializers["weight"].const_value.numpy(), 42.0
        )

    def test_onnx_program_save_raises_when_model_initialized_in_fake_mode(self):
        class Model(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.weight = torch.nn.Parameter(torch.tensor(42.0))

            def forward(self, x):
                return self.weight + x

        with torch.onnx.enable_fake_mode():
            onnx_program = torch.onnx.export(
                Model(), (torch.tensor(1.0),), dynamo=True, optimize=False
            )
            assert onnx_program is not None
            # Convert to model proto and back to trigger to_bytes method which serializes the tensor
            with self.assertRaises(Exception):
                # The tensors need to be replaced with real tensors
                _ = onnx_program.model_proto

        with self.assertRaises(Exception):
            # It doesn't matter if it is called inside or outside of the enable_fake_mode() context
            _ = onnx_program.model_proto

        # If we replace with concrete tensors, the serialization will succeed
        # This needs to happen outside of the fake context
        onnx_program.apply_weights({"weight": torch.tensor(42.0)})
        onnx_model = ir.serde.deserialize_model(onnx_program.model_proto)
        np.testing.assert_allclose(
            onnx_model.graph.initializers["weight"].const_value.numpy(), 42.0
        )

    def test_onnx_program_save_succeeds_when_export_and_save_in_fake_mode(self):
        class Model(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.weight = torch.nn.Parameter(torch.tensor(42.0))

            def forward(self, x):
                return self.weight + x

        real_model = Model()

        with torch.onnx.enable_fake_mode():
            onnx_program = torch.onnx.export(
                real_model, (torch.tensor(1.0),), dynamo=True, optimize=False
            )

            assert onnx_program is not None
            # Convert to model proto and back to trigger to_bytes method which serializes the tensor
            # Note that even though we are calling .model_proto (equivalently .save()) in fake mode,
            # the concrete tensors are maintained.
            # This is due to the usage of torch._subclasses.fake_tensor.unset_fake_temporarily() in
            # TorchTensor.tobytes()
            onnx_model = ir.serde.deserialize_model(onnx_program.model_proto)
            np.testing.assert_allclose(
                onnx_model.graph.initializers["weight"].const_value.numpy(), 42.0
            )

        # This works inside or outside the fake mode
        onnx_model = ir.serde.deserialize_model(onnx_program.model_proto)
        np.testing.assert_allclose(
            onnx_model.graph.initializers["weight"].const_value.numpy(), 42.0
        )

    def test_is_in_onnx_export(self):
        class Mod(torch.nn.Module):
            def forward(self, x):
                def f(x):
                    return x.sin() if torch.onnx.is_in_onnx_export() else x.cos()

                return f(x)

        self.assertFalse(torch.onnx.is_in_onnx_export())
        onnx_program = torch.onnx.export(
            Mod(),
            (torch.randn(3, 4),),
            dynamo=True,
            fallback=False,
        )
        self.assertFalse(torch.onnx.is_in_onnx_export())

        node_names = [n.op_type for n in onnx_program.model.graph]
        self.assertIn("Sin", node_names)

    def test_torchscript_exporter_raises_deprecation_warning(self):
        # Test that the deprecation warning is raised when using torchscript exporter
        with self.assertWarnsRegex(
            DeprecationWarning, "You are using the legacy TorchScript-based ONNX export"
        ):
            torch.onnx.export(
                SampleModel(), (torch.randn(1, 1, 2),), io.BytesIO(), dynamo=False
            )

    def test_model_output_can_be_none(self):
        class ModelWithNoneOutput(torch.nn.Module):
            def forward(self, x):
                return x + 1, None

        onnx_program = torch.onnx.export(
            ModelWithNoneOutput(),
            (torch.randn(1, 1, 2),),
            dynamo=True,
        )
        onnx_testing.assert_onnx_program(onnx_program)


if __name__ == "__main__":
    common_utils.run_tests()

@@ -1,60 +0,0 @@
# Owner(s): ["module: onnx"]
import torch
import torch._dynamo
import torch.fx
from torch.onnx._internal.fx.passes import _utils as pass_utils
from torch.testing._internal import common_utils


class TestFxPasses(common_utils.TestCase):
    def test_set_node_name_correctly_renames_when_new_name_collides_recursively(self):
        def func(x, y, z):
            return x + y + z

        x = torch.randn(3)
        y = torch.randn(3)
        z = torch.randn(3)
        gm, _ = torch._dynamo.export(func)(x, y, z)
        torch._dynamo.reset()

        # Purposely name the nodes in a way that will cause a recursive collision later.
        # See :func:`set_node_name` for name collision renaming logic.
        base_name = "tensor"
        nodes = list(gm.graph.nodes)
        for i, node in enumerate(nodes[1:]):
            if i == 0:
                node.name = base_name
            else:
                node.name = f"{base_name}.{i}"

        # Run `set_node_name` and verify that the names are correct.
        name_to_node = {node.name: node for node in gm.graph.nodes}
        pass_utils.set_node_name(nodes[0], base_name, name_to_node)
        assert nodes[0].name == base_name, f"Expected {base_name}, got {nodes[0].name}"
        assert len({node.name for node in nodes}) == len(nodes), (
            f"Expected all names to be unique, got {nodes}"
        )

    def test_set_node_name_succeeds_when_no_name_collisions(self):
        def func(x, y, z):
            return x + y + z

        x = torch.randn(3)
        y = torch.randn(3)
        z = torch.randn(3)
        gm, _ = torch._dynamo.export(func)(x, y, z)
        torch._dynamo.reset()

        # Run `set_node_name` and verify that the names are correct.
        new_name = "some_tensor"
        nodes = list(gm.graph.nodes)
        name_to_node = {node.name: node for node in nodes}
        pass_utils.set_node_name(nodes[1], new_name, name_to_node)
        assert nodes[1].name == new_name, f"Expected {new_name}, got {nodes[0].name}"
        assert len({node.name for node in nodes}) == len(nodes), (
            f"Expected all names to be unique, got {nodes}"
        )


if __name__ == "__main__":
    common_utils.run_tests()

@@ -37,7 +37,6 @@ __all__ = [
    # Base error
    "OnnxExporterError",
    "ONNXProgram",
    "enable_fake_mode",
]

from typing import Any, Callable, TYPE_CHECKING

@@ -47,7 +46,6 @@ import torch
from torch._C import _onnx as _C_onnx
from torch._C._onnx import OperatorExportTypes, TensorProtoDataType, TrainingMode

from ._internal._exporter_legacy import enable_fake_mode
from ._internal.exporter._onnx_program import ONNXProgram
from ._type_utils import JitScalarType
from .errors import OnnxExporterError

@@ -90,7 +88,6 @@ if TYPE_CHECKING:
JitScalarType.__module__ = "torch.onnx"
ONNXProgram.__module__ = "torch.onnx"
OnnxExporterError.__module__ = "torch.onnx"
enable_fake_mode.__module__ = "torch.onnx"

producer_name = "pytorch"
producer_version = _C_onnx.PRODUCER_VERSION

@@ -1,118 +0,0 @@
# mypy: allow-untyped-defs
from __future__ import annotations


__all__ = [
    "enable_fake_mode",
]


import contextlib
import dataclasses
import logging
from typing import Any, TYPE_CHECKING

import torch
import torch._ops
from torch.onnx._internal.fx import patcher as patcher


# We can only import onnx from this module in a type-checking context to ensure that
# 'import torch.onnx' continues to work without having 'onnx' installed. We fully
# 'import onnx' inside of dynamo_export (by way of _assert_dependencies).
if TYPE_CHECKING:
    import io

    from torch._subclasses import fake_tensor

log = logging.getLogger(__name__)


@dataclasses.dataclass
class ONNXFakeContext:
    """A dataclass used to store context for model export using FakeTensor.

    This dataclass stores the FakeTensorMode instance used to convert
    real tensors and model parameters into fake tensors. This :attr:`ONNXFakeContext.fake_mode` is
    reused internally during tracing of a :class:`torch.nn.Module` into a FX :class:`GraphModule`.
    """

    fake_mode: fake_tensor.FakeTensorMode
    """The fake tensor mode used for tracing model using fake tensors and parameters."""

    state_dict_paths: tuple[str | io.BytesIO | dict[str, Any]] | None = None
    """List of paths of files that contain the model :meth:`state_dict`"""


@contextlib.contextmanager
def enable_fake_mode():
    """Enable fake mode for the duration of the context.

    Internally it instantiates a :class:`torch._subclasses.fake_tensor.FakeTensorMode` context manager
    that converts user input and model parameters into :class:`torch._subclasses.fake_tensor.FakeTensor`.

    A :class:`torch._subclasses.fake_tensor.FakeTensor`
    is a :class:`torch.Tensor` with the ability to run PyTorch code without having to
    actually do computation through tensors allocated on a ``meta`` device. Because
    there is no actual data being allocated on the device, this API allows for
    initializing and exporting large models without the actual memory footprint needed for executing it.

    It is highly recommended to initialize the model in fake mode when exporting models that
    are too large to fit into memory.

    .. note::
        This function does not support torch.onnx.export(..., dynamo=True, optimize=True).
        Please call ONNXProgram.optimize() outside of the function after the model is exported.

    Example::

        # xdoctest: +REQUIRES(env:TORCH_DOCTEST_ONNX)
        >>> import torch
        >>> class MyModel(torch.nn.Module):  # Model with a parameter
        ...     def __init__(self) -> None:
        ...         super().__init__()
        ...         self.weight = torch.nn.Parameter(torch.tensor(42.0))
        ...     def forward(self, x):
        ...         return self.weight + x
        >>> with torch.onnx.enable_fake_mode():
        ...     # When initialized in fake mode, the model's parameters are fake tensors
        ...     # They do not take up memory so we can initialize large models
        ...     my_nn_module = MyModel()
        ...     arg1 = torch.randn(2, 2, 2)
        >>> onnx_program = torch.onnx.export(my_nn_module, (arg1,), dynamo=True, optimize=False)
        >>> # Saving model WITHOUT initializers (only the architecture)
        >>> onnx_program.save(
        ...     "my_model_without_initializers.onnx",
        ...     include_initializers=False,
        ...     keep_initializers_as_inputs=True,
        ... )
        >>> # Saving model WITH initializers after applying concrete weights
        >>> onnx_program.apply_weights({"weight": torch.tensor(42.0)})
        >>> onnx_program.save("my_model_with_initializers.onnx")

    .. warning::
        This API is experimental and is *NOT* backward-compatible.

    """
    from torch._subclasses import fake_tensor
    from torch.fx.experimental.symbolic_shapes import ShapeEnv

    # This overrides the internal `FakeTensorMode` instance created by `torch._dynamo.export`[1].
    # It is a good idea to keep them in sync (constructor args) to maintain the same default behavior
    # [1] `torch/_dynamo/output_graph.py::InstructionTranslator::OutputGraph.__init__`
    # Mixed fake/real tensors are only allowed when `torch.onnx.dynamo_export` is not called within `FakeTensorMode`
    # This is needed because models can create new parameters during `forward(self, *args, **kwargs)` run
    fake_mode = fake_tensor.FakeTensorMode(
        allow_non_fake_inputs=not torch._guards.detect_fake_mode(),
        shape_env=ShapeEnv(
            allow_scalar_outputs=False, allow_dynamic_output_shape_ops=False
        ),
    )
    # The patcher is needed for when user calls `fake_model.load_state_dict(...)` within fake mode
    patcher_context = patcher.ONNXTorchPatcher()
    fake_context = ONNXFakeContext(fake_mode=fake_mode)
    with fake_mode, patcher_context:
        yield fake_context
    fake_context.state_dict_paths = tuple(
        patcher_context.paths,
    )  # type: ignore[assignment]

@@ -1,8 +0,0 @@
from .patcher import ONNXTorchPatcher
from .serialization import save_model_with_external_data


__all__ = [
    "save_model_with_external_data",
    "ONNXTorchPatcher",
]

@@ -1,114 +0,0 @@
# mypy: allow-untyped-defs
"""Common utility functions for FX passes.

These functions should NOT be directly invoked outside of `passes` package.
"""

from __future__ import annotations

import collections
import re
from typing import Callable

import torch.fx
import torch.fx.traceback as fx_traceback


def wrap_graph_module_for_node_meta_preservation(
    graph_module: torch.fx.GraphModule,
) -> Callable:
    """Wrap a GraphModule with contexts to preserve node meta information, such as stacktrace info.

    This is typically useful before calling `make_fx`. Without this wrapper, the
    stacktrace information will be lost afterwards.
    """

    def wrapped(*args):
        with fx_traceback.preserve_node_meta():
            return torch.fx.Interpreter(graph_module).run(*args)

    return wrapped


def _get_node_base_name(node_name: str) -> tuple[str, int | None]:
    pattern = r"(.*)\.(\d+)"
    match = re.match(pattern, node_name)
    if match is not None:
        base_name, count_str = match.groups()
        return base_name, int(count_str)
    return node_name, None


def set_node_name(
    node: torch.fx.Node,
    new_name: str,
    name_to_node_cache: dict[str, torch.fx.Node],
):
    """Safely set the unique name of a node.

    If the new name is already taken by another node, the name of the other node will be
    updated. If `new_name` is a string of format f"{base_name}.{count}", where `count`
    is an integer, the other node will be renamed as f"{base_name}.{count+1}". If not,
    the other node will be renamed as "{new_name}.1". This function will iteratively
    update the names until there is no conflict.

    ``name_to_node_cache`` is required as an argument to avoid recomputation. The caller
    is responsible for ensuring the cache is accurate and in sync with the owning module
    of the node. The values in the cache will be updated accordingly.

    Args:
        node: The node to update.
        new_name: The new name to use.
        name_to_node_cache: A cache of node names to nodes.
    """
    node_name_to_set = collections.deque([(node, new_name)])

    while node_name_to_set:
        node, new_name = node_name_to_set.pop()
        if new_name in name_to_node_cache and name_to_node_cache[new_name] != node:
            base_name, postfix_count = _get_node_base_name(new_name)
            if postfix_count is None:
                postfix_count = 0
            node_name_to_set.append(
                (name_to_node_cache[new_name], f"{base_name}.{postfix_count + 1}")
            )
        node.name = new_name
        name_to_node_cache[new_name] = node


def replace_placeholder_name_and_target(
    module: torch.fx.GraphModule, reference_module: torch.fx.GraphModule
):
    """Replace the argument names in module with those in reference_module.

    This function assumes the two modules have the same signature structure.
    The caller is responsible for ensuring this. Otherwise, the behavior of this
    function is undefined. This function only does minimal sanity check that the two
    modules have the same number of arguments.

    Name conflicts between new names and existing node names in the graph are handled.
    Check the documentation of :func:`set_node_name` for more details.

    Raises:
        RuntimeError: If the two modules have different number of arguments.
    """
    placeholders = [node for node in module.graph.nodes if node.op == "placeholder"]
    reference_placeholders = [
        node for node in reference_module.graph.nodes if node.op == "placeholder"
    ]

    if len(placeholders) != len(reference_placeholders):
        raise RuntimeError(
            "The two modules have different number of arguments. "
            f"module: {len(placeholders)}, reference_module: {len(reference_placeholders)}"
        )

    name_to_node: dict[str, torch.fx.Node] = {}
    for node in module.graph.nodes:
        name_to_node[node.name] = node

    for placeholder, reference_placeholder in zip(placeholders, reference_placeholders):
        placeholder.target = reference_placeholder.target
        set_node_name(placeholder, reference_placeholder.name, name_to_node)

    module.recompile()

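For readers following the removed `set_node_name` helper above: when a requested name is already taken, the current owner of that name is pushed onto the work queue with `{base_name}.{count+1}`, and this repeats until no conflict remains. A tiny standalone sketch of that cascade, using plain dicts and hypothetical node ids rather than real `torch.fx.Node` objects:

```python
import re
from collections import deque


def _base_name(name):
    # Split "tensor.3" into ("tensor", 3); names without a numeric suffix return (name, None).
    m = re.match(r"(.*)\.(\d+)", name)
    return (m.group(1), int(m.group(2))) if m else (name, None)


def rename(names, node, new_name):
    # names maps current name -> node id; mirrors the deque-based collision handling above.
    pending = deque([(node, new_name)])
    while pending:
        node, name = pending.pop()
        if name in names and names[name] != node:
            base, count = _base_name(name)
            pending.append((names[name], f"{base}.{(count or 0) + 1}"))
        # Drop the node's stale key so the mapping stays consistent.
        for old, owner in list(names.items()):
            if owner == node and old != name:
                del names[old]
        names[name] = node


names = {"tensor": 0, "tensor.1": 1, "tensor.2": 2, "x": 3}
rename(names, 3, "tensor")  # node 3 takes "tensor"; prior owners cascade to tensor.1 -> tensor.2 -> tensor.3
print(sorted(names))  # ['tensor', 'tensor.1', 'tensor.2', 'tensor.3']
```
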
@@ -1,143 +0,0 @@
# mypy: allow-untyped-defs
import copy
import functools
from typing import TYPE_CHECKING, Union

import torch


if TYPE_CHECKING:
    import io


# TODO: Remove after https://github.com/huggingface/safetensors/pull/318
@functools.cache
def has_safetensors_and_transformers():
    try:
        # safetensors is not an exporter requirement, but needed for some huggingface models
        import safetensors  # type: ignore[import]  # noqa: F401
        import transformers  # type: ignore[import]  # noqa: F401
        from safetensors import torch as safetensors_torch  # noqa: F401

        return True
    except ImportError:
        return False


class ONNXTorchPatcher:
    """Context manager to temporarily patch PyTorch during FX-to-ONNX export.

    This class is a collection of "patches" required by FX-to-ONNX exporter.

    This context overrides several torch functions to support symbolic
    export of large scale models.

    torch.load:
        This function is patched to record the files PyTorch stores model
        parameters and buffers. Downstream FX-to-ONNX exporter can create
        initializers from these files.
    torch.fx._symbolic_trace._wrapped_methods_to_patch:
        This list is extended with (torch.Tensor, "__getitem__") so that
        weight[x, :, y] becomes exportable with torch.fx.symbolic_trace.
    safetensors.torch.load_file:
        This function is patched to allow safetensors to be loaded within
        FakeTensorMode. Remove after https://github.com/huggingface/safetensors/pull/318

    Search for ONNXTorchPatcher in test_fx_to_onnx_with_onnxruntime.py for
    example usage.

    TODO: Should this really be a global patcher? Can we make it a local patcher?
        A reason for splitting this into several patchers is to patch one part of the code
        as a collateral damage of patching another part of the code. For example, we
        for tracing model with torch._dynamo.export, we don't need to patch
        `torch.fx._symbolic_trace._wrapped_methods_to_patch`
    """

    def __init__(self) -> None:
        # List of file paths processed by torch.load.
        self.paths: list[Union[str, io.BufferedIOBase]] = []

        def torch_load_wrapper(f, *args, **kwargs):
            # Record path for later serialization into ONNX proto
            self.paths.append(f)
            # Then, call the original torch.load.
            return self.torch_load(f, *args, **kwargs)

        # Original version of torch.load.
        self.torch_load = torch.load

        # Wrapper or modified version of torch functions.
        self.torch_load_wrapper = torch_load_wrapper

        if has_safetensors_and_transformers():
            import safetensors
            import transformers

            def safetensors_load_file_wrapper(filename, device="cpu"):
                # Record path for later serialization into ONNX proto
                self.paths.append(filename)
                result = {}
                with safetensors.torch.safe_open(  # type: ignore[attr-defined]
                    filename, framework="pt", device=device
                ) as f:
                    for k in f.keys():
                        fake_mode = torch._guards.detect_fake_mode()
                        if not fake_mode:
                            result[k] = f.get_tensor(k)
                        else:
                            empty_tensor = f.get_slice(k)
                            result[k] = torch.empty(
                                tuple(empty_tensor.get_shape()),
                                dtype=safetensors.torch._getdtype(
                                    empty_tensor.get_dtype()
                                ),
                            )
                return result

            self.safetensors_torch_load_file = safetensors.torch.load_file
            self.safetensors_torch_load_file_wrapper = safetensors_load_file_wrapper
            self.transformers_modeling_utils_safe_load_file = (
                transformers.modeling_utils.safe_load_file
            )

    def __enter__(self):
        torch.load = self.torch_load_wrapper

        self.torch_fx__symbolic_trace__wrapped_methods_to_patch = (
            torch.fx._symbolic_trace._wrapped_methods_to_patch
        )
        desired_wrapped_methods = copy.deepcopy(
            torch.fx._symbolic_trace._wrapped_methods_to_patch
        )
        if (torch.Tensor, "__getitem__") not in desired_wrapped_methods:
            # Adding `__getitem__` to the patching list will make tensor indexing traceable via
            # torch.fx.symbolic_trace. Otherwise, `tensor[x, :, y]` cannot be traced.
            # This happens because `__getitem__` is neither under torch domain nor an aten operator,
            # so the patching (or similar Proxy-generating mechanism) doesn't happen automatically.
            # Note that torch.fx.symbolic_trace defines FX_PATCH_GETITEM environment variable for
            # enabling the line below for patching.
            desired_wrapped_methods.append((torch.Tensor, "__getitem__"))
        torch.fx._symbolic_trace._wrapped_methods_to_patch = desired_wrapped_methods

        if has_safetensors_and_transformers():
            import safetensors
            import transformers

            safetensors.torch.load_file = self.safetensors_torch_load_file_wrapper
            transformers.modeling_utils.safe_load_file = (
                self.safetensors_torch_load_file_wrapper
            )

    def __exit__(self, exc_type, exc_value, traceback):
        torch.load = self.torch_load
        torch.fx._symbolic_trace._wrapped_methods_to_patch = (
            self.torch_fx__symbolic_trace__wrapped_methods_to_patch
        )
        if has_safetensors_and_transformers():
            import safetensors
            import transformers

            safetensors.torch.load_file = self.safetensors_torch_load_file
            transformers.modeling_utils.safe_load_file = (
                self.transformers_modeling_utils_safe_load_file
            )

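The removed `ONNXTorchPatcher` above follows a plain save/patch/restore context-manager pattern: keep a reference to the original function, install a wrapper in `__enter__`, and put the original back in `__exit__`. A stripped-down sketch of that pattern (illustrative only, not part of the PyTorch API):

```python
import torch


class RecordingPatcher:
    # Minimal sketch: record every path passed to torch.load while the
    # context is active, then restore the original function on exit.
    def __init__(self) -> None:
        self.paths = []
        self._original_load = torch.load

    def __enter__(self):
        def wrapper(f, *args, **kwargs):
            self.paths.append(f)
            return self._original_load(f, *args, **kwargs)

        torch.load = wrapper
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        torch.load = self._original_load
```
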
@@ -1,250 +0,0 @@
# mypy: allow-untyped-defs
from __future__ import annotations

import io
import logging
import os
from typing import IO, TYPE_CHECKING

import torch
from torch.onnx import _type_utils as jit_type_utils


if TYPE_CHECKING:
    import onnx

    from torch.types import FileLike

log = logging.getLogger(__name__)


def _create_tensor_proto_with_external_data(
    tensor: torch.Tensor,
    name: str,
    location: str,
    basepath: str,
    dtype_override: onnx.TypeProto | None = None,  # type: ignore[name-defined]
) -> onnx.TensorProto:  # type: ignore[name-defined]
    """Create a TensorProto with external data from a PyTorch tensor.
    The external data is saved to os.path.join(basepath, location).

    Args:
        tensor: Tensor to be saved.
        name: Name of the tensor (i.e., initializer name in ONNX graph).
        location: Relative location of the external data file
            (e.g., "/tmp/initializers/weight_0" when model is "/tmp/model_name.onnx").
        basepath: Base path of the external data file (e.g., "/tmp/external_data" while model must be in "/tmp").


    Reference for ONNX's external data format:
        How to load?
        https://github.com/onnx/onnx/blob/5dac81ac0707bdf88f56c35c0a5e8855d3534673/onnx/external_data_helper.py#L187
        How to save?
        https://github.com/onnx/onnx/blob/5dac81ac0707bdf88f56c35c0a5e8855d3534673/onnx/external_data_helper.py#L43
        How to set ONNX fields?
        https://github.com/onnx/onnx/blob/5dac81ac0707bdf88f56c35c0a5e8855d3534673/onnx/external_data_helper.py#L88
    """
    # FIXME: Avoid importing onnx into torch.onnx.
    import onnx

    scalar_type = (
        jit_type_utils.JitScalarType.from_onnx_type(
            dtype_override.tensor_type.elem_type
        )
        if dtype_override is not None
        else jit_type_utils.JitScalarType.from_dtype(tensor.dtype)
    )

    # Checkpoints can be stored with a different dtype as the model expects because
    # the user script can explicitly cast the original type to something or maybe
    # PyTorch's type promotion might do it
    if dtype_override is not None and scalar_type.dtype() != tensor.dtype:
        tensor = tensor.to(scalar_type.dtype())

    tensor_proto = onnx.TensorProto()  # type: ignore[attr-defined]
    tensor_proto.name = name
    tensor_proto.data_type = scalar_type.onnx_type()  # type: ignore[assignment]

    tensor_proto.dims.extend(tensor.shape)
    tensor_proto.data_location = onnx.TensorProto.EXTERNAL  # type: ignore[attr-defined]

    # Settings for saving one tensor per file.
    # Offset is zero because there is no other tensor in the same file.
    key_value_pairs = {
        "location": location,
        "offset": 0,
        "length": tensor.untyped_storage().nbytes(),
    }
    for k, v in key_value_pairs.items():
        entry = tensor_proto.external_data.add()
        entry.key = k
        entry.value = str(v)

    # Actual path to write content of tensor.
    external_data_file_path = os.path.join(basepath, location)
    if os.path.exists(external_data_file_path):
        os.remove(external_data_file_path)

    # Create external data's folder if not exists.
    external_data_dir_path = os.path.dirname(external_data_file_path)
    if not os.path.exists(external_data_dir_path):
        # if the demo_folder directory is not present
        # then create it.
        os.makedirs(external_data_dir_path)

    # Create a fresh file.
    with open(external_data_file_path, "xb") as data_file:
        # No need to call "seek" because offset is 0.
        # data_file.seek(0)
        # Write tensor content to the file.
        data_file.write(tensor.numpy(force=True).tobytes())

    return tensor_proto


def _convert_safetensors_to_torch_format(safetensors_file):
    # It this function is called, safetensors is guaranteed to exist
    # because the HF model with safetensors was already loaded and exported to ONNX
    from safetensors import safe_open  # type: ignore[import-not-found, import-untyped]

    tensors = {}
    with safe_open(safetensors_file, framework="pt", device="cpu") as f:  # type: ignore[attr-defined]
        for k in f.keys():
            tensors[k] = f.get_tensor(k).cpu()
    return tensors


# TODO: generalize to allow more checkpoints formats (torch or gguf)
def save_model_with_external_data(
    basepath: str,
    model_location: str,
    initializer_location: str,
    torch_state_dicts: tuple[dict | FileLike, ...],
    onnx_model: onnx.ModelProto,  # type: ignore[name-defined]
    rename_initializer: bool = False,
) -> None:
    """Load PyTorch tensors from files and add to "onnx_model" as external initializers.

    Output files:
        ONNX model file path:
        ONNX initializer folder: os.path.join(basepath, initializer_location)

    After running this function, you can do
        ort_sess = onnxruntime.InferenceSession(os.path.join(basepath, model_location))
    to execute the model.

    Arguments:
        basepath: Base path of the ONNX external data file (e.g., "/path/to/large_model/").
        model_location: Relative location of the ONNX model file.
            E.g., "model.onnx" so that the model file is saved to
            "<basepath>/model.onnx".
        initializer_location: Relative location of the ONNX initializer folder.
            E.g., "initializers" so that the initializers are saved to
            "<basepath>/initializers/".
            Note: When initializers are >2GB, must be the same as `model_location`.
        torch_state_dicts: Dictionaries or files which contain PyTorch tensors to be saved
            as ONNX initializers. For non-dict arguments, `torch.load` will be used to load them from file-like objects.
        onnx_model: ONNX model to be saved with external initializers.
            If an input name matches a tensor loaded from "torch_state_dicts",
            the tensor will be saved as that input's external initializer.
        rename_initializer: Replaces "." by "_" for all ONNX initializer names.
            Not needed by the official torch.onnx.dynamo_export. This is a hack
            for supporting `FXSymbolicTracer` tracer with fake tensor mode.
            In short, `FXSymbolicTracer` lifts FX parameters (self.linear_weight)
            as inputs (`def forward(self, linear_weight)`) and therefore, `.` cannot be used.
    """
    # FIXME: Avoid importing onnx into torch.onnx.
    import onnx

    initializers_to_be_deleted = {}  # Using dict because it is **ordered**
    existing_initializers = {
        k.name: idx for idx, k in enumerate(onnx_model.graph.initializer)
    }
    onnx_input_names = {input.name for input in onnx_model.graph.input}
    for el in torch_state_dicts:
        if isinstance(el, dict):
            # Useful for when state_dict is loaded with torch.load(..., mmap=True, map_location="cpu") by the user
            # Using torch.save wouldn't leverage mmap, leading to higher memory usage
            state_dict = el
        else:
            if isinstance(el, (str, os.PathLike)) and os.fspath(el).endswith(
                ".safetensors"
            ):
                state_dict = _convert_safetensors_to_torch_format(el)
            else:
                try:
                    # Loads checkpoint using memory-map on CPU to support really large models
                    # The underlying torch.UntypedStorage is memory mapped, so state_dict is lazy loaded
                    state_dict = torch.load(el, map_location="cpu", mmap=True)
                except (RuntimeError, ValueError) as e:
                    if "mmap can only be used with files saved with" in str(e) or (
                        isinstance(el, (io.IOBase, IO))
                        and el.readable()
                        and el.seekable()
                    ):
                        log.warning(
                            "Failed to load the checkpoint with memory-map enabled, retrying without memory-map."
                            "Consider updating the checkpoint with mmap by using torch.save() on PyTorch version >= 1.6."
                        )
                        if isinstance(el, (io.IOBase, IO)):
                            el.seek(0)  # torch.load from `try:` has read the file.
                        state_dict = torch.load(el, map_location="cpu")
                    else:
                        raise e

        for name, tensor in state_dict.items():
            if rename_initializer:
                # Basically, "transformer.attention.self.query.weight" is mapped
                # to "transformer_attention_self_query_weight" for mimicking the
                # name-modifying code in FX-to-ONNX exporter.
                # See function _replace_get_attr_with_placeholder for details.
                name = name.replace(".", "_")

            # This block tries to match the onnx initializer name with torch parameter/buffer
            # e.g. A pytorch buffer 'transformer.h.0.attn.bias' can be named 'h.0.attn.bias' in a ONNX initializer
            # For each PyTorch tensor name loaded by torch.load,
            # 1. Search its best match in ONNX model. E.g., the match of
            #    "transformer_attention_weight" could be "attention_weight".
            # 2. Set "tensor" as the initializer of the matched ONNX input.
            #    E.g., "tensor" is stored as the initializer of "attention_weight".
            # Step 1 is required because sometimes, tensor names are stored with prefix the dictionary
            # loaded by torch.load.
            if name in onnx_input_names:
                # Same input name shouldn't be matched again
                onnx_input_names.remove(name)
            else:
                for onnx_input_name in onnx_input_names:
                    if onnx_input_name.endswith(name) or name.endswith(onnx_input_name):
                        # Find a match. Change name to the matched ONNX input name, so that we
                        # create initializer with the right ONNX name.
                        name = onnx_input_name
                        onnx_input_names.remove(onnx_input_name)
                        break

            relative_tensor_file_path = os.path.join(initializer_location, name)
            # Create one file per tensor.
            # tensor_proto.raw_data is stored to external file at
            # os.path.join(basepath, relative_tensor_file_path).
            model_input_types = {k.name: k.type for k in onnx_model.graph.input}

            # Mark for deletion - a replacement will be appended next
            if name in existing_initializers:
                initializers_to_be_deleted[existing_initializers[name]] = name
            tensor_proto = _create_tensor_proto_with_external_data(
                tensor,
                name,
                relative_tensor_file_path,
                basepath,
                model_input_types.pop(name, None),
            )
            # Add the tensor_proto to the ONNX model as an initializer with external data.
            onnx_model.graph.initializer.append(tensor_proto)
    # Remove old duplicated initializers, if any. delete in desc order to not invalidate deletion indices
    initializers_to_be_deleted = dict(
        sorted(initializers_to_be_deleted.items(), reverse=True)
    )
    for idx in initializers_to_be_deleted.keys():
        del onnx_model.graph.initializer[idx]

    # model_location should be a pure file name such as "file_name.onnx", not "folder/file_name.onnx".
    onnx.save(onnx_model, os.path.join(basepath, model_location))  # type: ignore[attr-defined]