From 419a2dbf5f69cee52382090200b532a81da92c69 Mon Sep 17 00:00:00 2001
From: Justin Chu
Date: Fri, 22 Aug 2025 22:15:27 +0000
Subject: [PATCH] [ONNX] Remove enable_fake_mode and exporter_legacy (#161222)

Remove enable_fake_mode and exporter_legacy entirely. Even though this is
BC-breaking, `enable_fake_mode` is no longer compatible with the latest
version of transformers, and so it is no longer useful.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/161222
Approved by: https://github.com/titaiwangms
---
 docs/source/onnx.md                      |   2 -
 docs/source/onnx_export.md               |   1 -
 test/onnx/exporter/test_api.py           | 187 ++++---------
 test/onnx/test_fx_passes.py              |  60 ------
 torch/onnx/__init__.py                   |   3 -
 torch/onnx/_internal/_exporter_legacy.py | 118 -----------
 torch/onnx/_internal/fx/__init__.py      |   8 -
 torch/onnx/_internal/fx/passes/_utils.py | 114 -----------
 torch/onnx/_internal/fx/patcher.py       | 143 -------------
 torch/onnx/_internal/fx/serialization.py | 250 -----------------------
 10 files changed, 42 insertions(+), 844 deletions(-)
 delete mode 100644 test/onnx/test_fx_passes.py
 delete mode 100644 torch/onnx/_internal/_exporter_legacy.py
 delete mode 100644 torch/onnx/_internal/fx/passes/_utils.py
 delete mode 100644 torch/onnx/_internal/fx/patcher.py
 delete mode 100644 torch/onnx/_internal/fx/serialization.py

diff --git a/docs/source/onnx.md b/docs/source/onnx.md
index 06b049ec39bc..b0ed78dbe69b 100644
--- a/docs/source/onnx.md
+++ b/docs/source/onnx.md
@@ -84,8 +84,6 @@ also be interested in reading our [development wiki](https://github.com/pytorch/
    :noindex:
 .. autofunction:: is_in_onnx_export
    :noindex:
-.. autofunction:: enable_fake_mode
-   :noindex:
 ```

 ### Classes
diff --git a/docs/source/onnx_export.md b/docs/source/onnx_export.md
index 029952aa4e99..0adfec359d0b 100644
--- a/docs/source/onnx_export.md
+++ b/docs/source/onnx_export.md
@@ -245,5 +245,4 @@ Each initialized value, input, output has the following metadata:
 .. autofunction:: torch.onnx.is_in_onnx_export
 .. autoclass:: torch.onnx.OnnxExporterError
    :members:
-.. autofunction:: torch.onnx.enable_fake_mode
 ```
diff --git a/test/onnx/exporter/test_api.py b/test/onnx/exporter/test_api.py
index 593cc524ebe7..67f38902acc6 100644
--- a/test/onnx/exporter/test_api.py
+++ b/test/onnx/exporter/test_api.py
@@ -7,11 +7,9 @@
 import io
 import logging
 import os

-import numpy as np
-from onnxscript import BOOL, FLOAT, ir, opset18 as op
+from onnxscript import BOOL, FLOAT, opset18 as op

 import torch
-import torch.onnx._flags
 from torch.onnx._internal.exporter import _testing as onnx_testing
 from torch.testing._internal import common_utils
@@ -339,6 +337,47 @@ class TestExportAPIDynamo(common_utils.TestCase):
             ),
         )

+    def test_is_in_onnx_export(self):
+        class Mod(torch.nn.Module):
+            def forward(self, x):
+                def f(x):
+                    return x.sin() if torch.onnx.is_in_onnx_export() else x.cos()
+
+                return f(x)
+
+        self.assertFalse(torch.onnx.is_in_onnx_export())
+        onnx_program = torch.onnx.export(
+            Mod(),
+            (torch.randn(3, 4),),
+            dynamo=True,
+            fallback=False,
+        )
+        self.assertFalse(torch.onnx.is_in_onnx_export())
+
+        node_names = [n.op_type for n in onnx_program.model.graph]
+        self.assertIn("Sin", node_names)
+
+    def test_torchscript_exporter_raises_deprecation_warning(self):
+        # Test that the deprecation warning is raised when using torchscript exporter
+        with self.assertWarnsRegex(
+            DeprecationWarning, "You are using the legacy TorchScript-based ONNX export"
+        ):
+            torch.onnx.export(
+                SampleModel(), (torch.randn(1, 1, 2),), io.BytesIO(), dynamo=False
+            )
+
+    def test_model_output_can_be_none(self):
+        class ModelWithNoneOutput(torch.nn.Module):
+            def forward(self, x):
+                return x + 1, None
+
+        onnx_program = torch.onnx.export(
+            ModelWithNoneOutput(),
+            (torch.randn(1, 1, 2),),
+            dynamo=True,
+        )
+        onnx_testing.assert_onnx_program(onnx_program)
+

 class TestCustomTranslationTable(common_utils.TestCase):
     def test_custom_translation_table_overrides_ops(self):
@@ -471,147 +510,5 @@ class TestCustomTranslationTable(common_utils.TestCase):
         self.assertNotIn("Sub", all_nodes_decomp)


-class TestFakeTensorExport(common_utils.TestCase):
-    """Test exporting in fake mode."""
-
-    def test_onnx_program_raises_when_model_defined_in_fake_mode(self):
-        with torch.onnx.enable_fake_mode():
-
-            class Model(torch.nn.Module):
-                def __init__(self):
-                    super().__init__()
-                    self.weight = torch.nn.Parameter(torch.tensor(42.0))
-
-                def forward(self, x):
-                    return self.weight + x
-
-            onnx_program = torch.onnx.export(
-                Model(), (torch.tensor(1.0),), dynamo=True, optimize=False
-            )
-            assert onnx_program is not None
-            # Convert to model proto and back to trigger to_bytes method which serializes the tensor
-            with self.assertRaises(Exception):
-                # The tensors need to be replaced with real tensors
-                _ = onnx_program.model_proto
-
-        # Convert to model proto and back to trigger to_bytes method which serializes the tensor
-        with self.assertRaises(Exception):
-            # It doesn't matter if it is called inside or outside of the enable_fake_mode() context
-            _ = onnx_program.model_proto
-
-        # If we replace with concrete tensors, the serialization will succeed.
-        # This needs to happen outside of the fake context
-        onnx_program.apply_weights({"weight": torch.tensor(42.0)})
-        onnx_model = ir.serde.deserialize_model(onnx_program.model_proto)
-        np.testing.assert_allclose(
-            onnx_model.graph.initializers["weight"].const_value.numpy(), 42.0
-        )
-
-    def test_onnx_program_save_raises_when_model_initialized_in_fake_mode(self):
-        class Model(torch.nn.Module):
-            def __init__(self):
-                super().__init__()
-                self.weight = torch.nn.Parameter(torch.tensor(42.0))
-
-            def forward(self, x):
-                return self.weight + x
-
-        with torch.onnx.enable_fake_mode():
-            onnx_program = torch.onnx.export(
-                Model(), (torch.tensor(1.0),), dynamo=True, optimize=False
-            )
-            assert onnx_program is not None
-            # Convert to model proto and back to trigger to_bytes method which serializes the tensor
-            with self.assertRaises(Exception):
-                # The tensors need to be replaced with real tensors
-                _ = onnx_program.model_proto
-
-        with self.assertRaises(Exception):
-            # It doesn't matter if it is called inside or outside of the enable_fake_mode() context
-            _ = onnx_program.model_proto
-
-        # If we replace with concrete tensors, the serialization will succeed
-        # This needs to happen outside of the fake context
-        onnx_program.apply_weights({"weight": torch.tensor(42.0)})
-        onnx_model = ir.serde.deserialize_model(onnx_program.model_proto)
-        np.testing.assert_allclose(
-            onnx_model.graph.initializers["weight"].const_value.numpy(), 42.0
-        )
-
-    def test_onnx_program_save_succeeds_when_export_and_save_in_fake_mode(self):
-        class Model(torch.nn.Module):
-            def __init__(self):
-                super().__init__()
-                self.weight = torch.nn.Parameter(torch.tensor(42.0))
-
-            def forward(self, x):
-                return self.weight + x
-
-        real_model = Model()
-
-        with torch.onnx.enable_fake_mode():
-            onnx_program = torch.onnx.export(
-                real_model, (torch.tensor(1.0),), dynamo=True, optimize=False
-            )
-
-            assert onnx_program is not None
-            # Convert to model proto and back to trigger to_bytes method which serializes the tensor
-            # Note that even though we are calling .model_proto (equivalently .save()) in fake mode,
-            # the concrete tensors are maintained.
-            # This is due to the usage of torch._subclasses.fake_tensor.unset_fake_temporarily() in
-            # TorchTensor.tobytes()
-            onnx_model = ir.serde.deserialize_model(onnx_program.model_proto)
-            np.testing.assert_allclose(
-                onnx_model.graph.initializers["weight"].const_value.numpy(), 42.0
-            )
-
-        # This works inside or outside the fake mode
-        onnx_model = ir.serde.deserialize_model(onnx_program.model_proto)
-        np.testing.assert_allclose(
-            onnx_model.graph.initializers["weight"].const_value.numpy(), 42.0
-        )
-
-    def test_is_in_onnx_export(self):
-        class Mod(torch.nn.Module):
-            def forward(self, x):
-                def f(x):
-                    return x.sin() if torch.onnx.is_in_onnx_export() else x.cos()
-
-                return f(x)
-
-        self.assertFalse(torch.onnx.is_in_onnx_export())
-        onnx_program = torch.onnx.export(
-            Mod(),
-            (torch.randn(3, 4),),
-            dynamo=True,
-            fallback=False,
-        )
-        self.assertFalse(torch.onnx.is_in_onnx_export())
-
-        node_names = [n.op_type for n in onnx_program.model.graph]
-        self.assertIn("Sin", node_names)
-
-    def test_torchscript_exporter_raises_deprecation_warning(self):
-        # Test that the deprecation warning is raised when using torchscript exporter
-        with self.assertWarnsRegex(
-            DeprecationWarning, "You are using the legacy TorchScript-based ONNX export"
-        ):
-            torch.onnx.export(
-                SampleModel(), (torch.randn(1, 1, 2),), io.BytesIO(), dynamo=False
-            )
-
-    def test_model_output_can_be_none(self):
-        class ModelWithNoneOutput(torch.nn.Module):
-            def forward(self, x):
-                return x + 1, None
-
-        onnx_program = torch.onnx.export(
-            ModelWithNoneOutput(),
-            (torch.randn(1, 1, 2),),
-            dynamo=True,
-        )
-        onnx_testing.assert_onnx_program(onnx_program)
-
-
 if __name__ == "__main__":
     common_utils.run_tests()
diff --git a/test/onnx/test_fx_passes.py b/test/onnx/test_fx_passes.py
deleted file mode 100644
index 97d255abdcb1..000000000000
--- a/test/onnx/test_fx_passes.py
+++ /dev/null
@@ -1,60 +0,0 @@
-# Owner(s): ["module: onnx"]
-import torch
-import torch._dynamo
-import torch.fx
-from torch.onnx._internal.fx.passes import _utils as pass_utils
-from torch.testing._internal import common_utils
-
-
-class TestFxPasses(common_utils.TestCase):
-    def test_set_node_name_correctly_renames_when_new_name_collides_recursively(self):
-        def func(x, y, z):
-            return x + y + z
-
-        x = torch.randn(3)
-        y = torch.randn(3)
-        z = torch.randn(3)
-        gm, _ = torch._dynamo.export(func)(x, y, z)
-        torch._dynamo.reset()
-
-        # Purposely name the nodes in a way that will cause a recursive collision later.
-        # See :func:`set_node_name` for name collision renaming logic.
-        base_name = "tensor"
-        nodes = list(gm.graph.nodes)
-        for i, node in enumerate(nodes[1:]):
-            if i == 0:
-                node.name = base_name
-            else:
-                node.name = f"{base_name}.{i}"
-
-        # Run `set_node_name` and verify that the names are correct.
-        name_to_node = {node.name: node for node in gm.graph.nodes}
-        pass_utils.set_node_name(nodes[0], base_name, name_to_node)
-        assert nodes[0].name == base_name, f"Expected {base_name}, got {nodes[0].name}"
-        assert len({node.name for node in nodes}) == len(nodes), (
-            f"Expected all names to be unique, got {nodes}"
-        )
-
-    def test_set_node_name_succeeds_when_no_name_collisions(self):
-        def func(x, y, z):
-            return x + y + z
-
-        x = torch.randn(3)
-        y = torch.randn(3)
-        z = torch.randn(3)
-        gm, _ = torch._dynamo.export(func)(x, y, z)
-        torch._dynamo.reset()
-
-        # Run `set_node_name` and verify that the names are correct.
- new_name = "some_tensor" - nodes = list(gm.graph.nodes) - name_to_node = {node.name: node for node in nodes} - pass_utils.set_node_name(nodes[1], new_name, name_to_node) - assert nodes[1].name == new_name, f"Expected {new_name}, got {nodes[0].name}" - assert len({node.name for node in nodes}) == len(nodes), ( - f"Expected all names to be unique, got {nodes}" - ) - - -if __name__ == "__main__": - common_utils.run_tests() diff --git a/torch/onnx/__init__.py b/torch/onnx/__init__.py index 6c301ef294eb..8c6f295a8a50 100644 --- a/torch/onnx/__init__.py +++ b/torch/onnx/__init__.py @@ -37,7 +37,6 @@ __all__ = [ # Base error "OnnxExporterError", "ONNXProgram", - "enable_fake_mode", ] from typing import Any, Callable, TYPE_CHECKING @@ -47,7 +46,6 @@ import torch from torch._C import _onnx as _C_onnx from torch._C._onnx import OperatorExportTypes, TensorProtoDataType, TrainingMode -from ._internal._exporter_legacy import enable_fake_mode from ._internal.exporter._onnx_program import ONNXProgram from ._type_utils import JitScalarType from .errors import OnnxExporterError @@ -90,7 +88,6 @@ if TYPE_CHECKING: JitScalarType.__module__ = "torch.onnx" ONNXProgram.__module__ = "torch.onnx" OnnxExporterError.__module__ = "torch.onnx" -enable_fake_mode.__module__ = "torch.onnx" producer_name = "pytorch" producer_version = _C_onnx.PRODUCER_VERSION diff --git a/torch/onnx/_internal/_exporter_legacy.py b/torch/onnx/_internal/_exporter_legacy.py deleted file mode 100644 index f9ae42b26b84..000000000000 --- a/torch/onnx/_internal/_exporter_legacy.py +++ /dev/null @@ -1,118 +0,0 @@ -# mypy: allow-untyped-defs -from __future__ import annotations - - -__all__ = [ - "enable_fake_mode", -] - - -import contextlib -import dataclasses -import logging -from typing import Any, TYPE_CHECKING - -import torch -import torch._ops -from torch.onnx._internal.fx import patcher as patcher - - -# We can only import onnx from this module in a type-checking context to ensure that -# 'import torch.onnx' continues to work without having 'onnx' installed. We fully -# 'import onnx' inside of dynamo_export (by way of _assert_dependencies). -if TYPE_CHECKING: - import io - - from torch._subclasses import fake_tensor - -log = logging.getLogger(__name__) - - -@dataclasses.dataclass -class ONNXFakeContext: - """A dataclass used to store context for model export using FakeTensor. - - This dataclass stores the FakeTensorMode instance used to convert - real tensors and model parameters into fake tensors. This :attr:`ONNXFakeContext.fake_mode` is - reused internally during tracing of a :class:`torch.nn.Module` into a FX :class:`GraphModule`. - """ - - fake_mode: fake_tensor.FakeTensorMode - """The fake tensor mode used for tracing model using fake tensors and parameters.""" - - state_dict_paths: tuple[str | io.BytesIO | dict[str, Any]] | None = None - """List of paths of files that contain the model :meth:`state_dict`""" - - -@contextlib.contextmanager -def enable_fake_mode(): - """Enable fake mode for the duration of the context. - - Internally it instantiates a :class:`torch._subclasses.fake_tensor.FakeTensorMode` context manager - that converts user input and model parameters into :class:`torch._subclasses.fake_tensor.FakeTensor`. - - A :class:`torch._subclasses.fake_tensor.FakeTensor` - is a :class:`torch.Tensor` with the ability to run PyTorch code without having to - actually do computation through tensors allocated on a ``meta`` device. 
Because - there is no actual data being allocated on the device, this API allows for - initializing and exporting large models without the actual memory footprint needed for executing it. - - It is highly recommended to initialize the model in fake mode when exporting models that - are too large to fit into memory. - - .. note:: - This function does not support torch.onnx.export(..., dynamo=True, optimize=True). - Please call ONNXProgram.optimize() outside of the function after the model is exported. - - Example:: - - # xdoctest: +REQUIRES(env:TORCH_DOCTEST_ONNX) - >>> import torch - >>> class MyModel(torch.nn.Module): # Model with a parameter - ... def __init__(self) -> None: - ... super().__init__() - ... self.weight = torch.nn.Parameter(torch.tensor(42.0)) - ... def forward(self, x): - ... return self.weight + x - >>> with torch.onnx.enable_fake_mode(): - ... # When initialized in fake mode, the model's parameters are fake tensors - ... # They do not take up memory so we can initialize large models - ... my_nn_module = MyModel() - ... arg1 = torch.randn(2, 2, 2) - >>> onnx_program = torch.onnx.export(my_nn_module, (arg1,), dynamo=True, optimize=False) - >>> # Saving model WITHOUT initializers (only the architecture) - >>> onnx_program.save( - ... "my_model_without_initializers.onnx", - ... include_initializers=False, - ... keep_initializers_as_inputs=True, - ... ) - >>> # Saving model WITH initializers after applying concrete weights - >>> onnx_program.apply_weights({"weight": torch.tensor(42.0)}) - >>> onnx_program.save("my_model_with_initializers.onnx") - - .. warning:: - This API is experimental and is *NOT* backward-compatible. - - """ - from torch._subclasses import fake_tensor - from torch.fx.experimental.symbolic_shapes import ShapeEnv - - # This overrides the internal `FakeTensorMode` instance created by `torch._dynamo.export`[1]. - # It is a good idea to keep them in sync (constructor args) to maintain the same default behavior - # [1] `torch/_dynamo/output_graph.py::InstructionTranslator::OutputGraph.__init__` - # Mixed fake/real tensors are only allowed when `torch.onnx.dynamo_export` is not called within `FakeTensorMode` - # This is needed because models can create new parameters during `forward(self, *args, **kwargs)` run - fake_mode = fake_tensor.FakeTensorMode( - allow_non_fake_inputs=not torch._guards.detect_fake_mode(), - shape_env=ShapeEnv( - allow_scalar_outputs=False, allow_dynamic_output_shape_ops=False - ), - ) - # The patcher is needed for when user calls `fake_model.load_state_dict(...)` within fake mode - patcher_context = patcher.ONNXTorchPatcher() - fake_context = ONNXFakeContext(fake_mode=fake_mode) - with fake_mode, patcher_context: - yield fake_context - fake_context.state_dict_paths = tuple( - patcher_context.paths, - ) # type: ignore[assignment] diff --git a/torch/onnx/_internal/fx/__init__.py b/torch/onnx/_internal/fx/__init__.py index b5716bdafced..e69de29bb2d1 100644 --- a/torch/onnx/_internal/fx/__init__.py +++ b/torch/onnx/_internal/fx/__init__.py @@ -1,8 +0,0 @@ -from .patcher import ONNXTorchPatcher -from .serialization import save_model_with_external_data - - -__all__ = [ - "save_model_with_external_data", - "ONNXTorchPatcher", -] diff --git a/torch/onnx/_internal/fx/passes/_utils.py b/torch/onnx/_internal/fx/passes/_utils.py deleted file mode 100644 index a7b05786ab17..000000000000 --- a/torch/onnx/_internal/fx/passes/_utils.py +++ /dev/null @@ -1,114 +0,0 @@ -# mypy: allow-untyped-defs -"""Common utility functions for FX passes. 
-
-These functions should NOT be directly invoked outside of `passes` package.
-"""
-
-from __future__ import annotations
-
-import collections
-import re
-from typing import Callable
-
-import torch.fx
-import torch.fx.traceback as fx_traceback
-
-
-def wrap_graph_module_for_node_meta_preservation(
-    graph_module: torch.fx.GraphModule,
-) -> Callable:
-    """Wrap a GraphModule with contexts to preserve node meta information, such as stacktrace info.
-
-    This is typically useful before calling `make_fx`. Without this wrapper, the
-    stacktrace information will be lost afterwards.
-    """
-
-    def wrapped(*args):
-        with fx_traceback.preserve_node_meta():
-            return torch.fx.Interpreter(graph_module).run(*args)
-
-    return wrapped
-
-
-def _get_node_base_name(node_name: str) -> tuple[str, int | None]:
-    pattern = r"(.*)\.(\d+)"
-    match = re.match(pattern, node_name)
-    if match is not None:
-        base_name, count_str = match.groups()
-        return base_name, int(count_str)
-    return node_name, None
-
-
-def set_node_name(
-    node: torch.fx.Node,
-    new_name: str,
-    name_to_node_cache: dict[str, torch.fx.Node],
-):
-    """Safely set the unique name of a node.
-
-    If the new name is already taken by another node, the name of the other node will be
-    updated. If `new_name` is a string of format f"{base_name}.{count}", where `count`
-    is an integer, the other node will be renamed as f"{base_name}.{count+1}". If not,
-    the other node will be renamed as "{new_name}.1". This function will iteratively
-    update the names until there is no conflict.
-
-    ``name_to_node_cache`` is required as an argument to avoid recomputation. The caller
-    is responsible for ensuring the cache is accurate and in sync with the owning module
-    of the node. The values in the cache will be updated accordingly.
-
-    Args:
-        node: The node to update.
-        new_name: The new name to use.
-        name_to_node_cache: A cache of node names to nodes.
-    """
-    node_name_to_set = collections.deque([(node, new_name)])
-
-    while node_name_to_set:
-        node, new_name = node_name_to_set.pop()
-        if new_name in name_to_node_cache and name_to_node_cache[new_name] != node:
-            base_name, postfix_count = _get_node_base_name(new_name)
-            if postfix_count is None:
-                postfix_count = 0
-            node_name_to_set.append(
-                (name_to_node_cache[new_name], f"{base_name}.{postfix_count + 1}")
-            )
-        node.name = new_name
-        name_to_node_cache[new_name] = node
-
-
-def replace_placeholder_name_and_target(
-    module: torch.fx.GraphModule, reference_module: torch.fx.GraphModule
-):
-    """Replace the argument names in module with those in reference_module.
-
-    This function assumes the two modules have the same signature structure.
-    The caller is responsible for ensuring this. Otherwise, the behavior of this
-    function is undefined. This function only does minimal sanity check that the two
-    modules have the same number of arguments.
-
-    Name conflicts between new names and existing node names in the graph are handled.
-    Check the documentation of :func:`set_node_name` for more details.
-
-    Raises:
-        RuntimeError: If the two modules have different number of arguments.
-    """
-    placeholders = [node for node in module.graph.nodes if node.op == "placeholder"]
-    reference_placeholders = [
-        node for node in reference_module.graph.nodes if node.op == "placeholder"
-    ]
-
-    if len(placeholders) != len(reference_placeholders):
-        raise RuntimeError(
-            "The two modules have different number of arguments. "
-            f"module: {len(placeholders)}, reference_module: {len(reference_placeholders)}"
-        )
-
-    name_to_node: dict[str, torch.fx.Node] = {}
-    for node in module.graph.nodes:
-        name_to_node[node.name] = node
-
-    for placeholder, reference_placeholder in zip(placeholders, reference_placeholders):
-        placeholder.target = reference_placeholder.target
-        set_node_name(placeholder, reference_placeholder.name, name_to_node)
-
-    module.recompile()
diff --git a/torch/onnx/_internal/fx/patcher.py b/torch/onnx/_internal/fx/patcher.py
deleted file mode 100644
index 6c9724e9f5a7..000000000000
--- a/torch/onnx/_internal/fx/patcher.py
+++ /dev/null
@@ -1,143 +0,0 @@
-# mypy: allow-untyped-defs
-import copy
-import functools
-from typing import TYPE_CHECKING, Union
-
-import torch
-
-
-if TYPE_CHECKING:
-    import io
-
-
-# TODO: Remove after https://github.com/huggingface/safetensors/pull/318
-@functools.cache
-def has_safetensors_and_transformers():
-    try:
-        # safetensors is not an exporter requirement, but needed for some huggingface models
-        import safetensors  # type: ignore[import]  # noqa: F401
-        import transformers  # type: ignore[import]  # noqa: F401
-        from safetensors import torch as safetensors_torch  # noqa: F401
-
-        return True
-    except ImportError:
-        return False
-
-
-class ONNXTorchPatcher:
-    """Context manager to temporarily patch PyTorch during FX-to-ONNX export.
-
-    This class is a collection of "patches" required by FX-to-ONNX exporter.
-
-    This context overrides several torch functions to support symbolic
-    export of large scale models.
-
-    torch.load:
-        This function is patched to record the files PyTorch stores model
-        parameters and buffers. Downstream FX-to-ONNX exporter can create
-        initializers from these files.
-    torch.fx._symbolic_trace._wrapped_methods_to_patch:
-        This list is extended with (torch.Tensor, "__getitem__") so that
-        weight[x, :, y] becomes exportable with torch.fx.symbolic_trace.
-    safetensors.torch.load_file:
-        This function is patched to allow safetensors to be loaded within
-        FakeTensorMode. Remove after https://github.com/huggingface/safetensors/pull/318
-
-    Search for ONNXTorchPatcher in test_fx_to_onnx_with_onnxruntime.py for
-    example usage.
-
-    TODO: Should this really be a global patcher? Can we make it a local patcher?
-    A reason for splitting this into several patchers is to patch one part of the code
-    as a collateral damage of patching another part of the code. For example, we
-    for tracing model with torch._dynamo.export, we don't need to patch
-    `torch.fx._symbolic_trace._wrapped_methods_to_patch`
-    """
-
-    def __init__(self) -> None:
-        # List of file paths processed by torch.load.
-        self.paths: list[Union[str, io.BufferedIOBase]] = []
-
-        def torch_load_wrapper(f, *args, **kwargs):
-            # Record path for later serialization into ONNX proto
-            self.paths.append(f)
-            # Then, call the original torch.load.
-            return self.torch_load(f, *args, **kwargs)
-
-        # Original version of torch.load.
-        self.torch_load = torch.load
-
-        # Wrapper or modified version of torch functions.
-        self.torch_load_wrapper = torch_load_wrapper
-
-        if has_safetensors_and_transformers():
-            import safetensors
-            import transformers
-
-            def safetensors_load_file_wrapper(filename, device="cpu"):
-                # Record path for later serialization into ONNX proto
-                self.paths.append(filename)
-                result = {}
-                with safetensors.torch.safe_open(  # type: ignore[attr-defined]
-                    filename, framework="pt", device=device
-                ) as f:
-                    for k in f.keys():
-                        fake_mode = torch._guards.detect_fake_mode()
-                        if not fake_mode:
-                            result[k] = f.get_tensor(k)
-                        else:
-                            empty_tensor = f.get_slice(k)
-                            result[k] = torch.empty(
-                                tuple(empty_tensor.get_shape()),
-                                dtype=safetensors.torch._getdtype(
-                                    empty_tensor.get_dtype()
-                                ),
-                            )
-                return result
-
-            self.safetensors_torch_load_file = safetensors.torch.load_file
-            self.safetensors_torch_load_file_wrapper = safetensors_load_file_wrapper
-            self.transformers_modeling_utils_safe_load_file = (
-                transformers.modeling_utils.safe_load_file
-            )
-
-    def __enter__(self):
-        torch.load = self.torch_load_wrapper
-
-        self.torch_fx__symbolic_trace__wrapped_methods_to_patch = (
-            torch.fx._symbolic_trace._wrapped_methods_to_patch
-        )
-        desired_wrapped_methods = copy.deepcopy(
-            torch.fx._symbolic_trace._wrapped_methods_to_patch
-        )
-        if (torch.Tensor, "__getitem__") not in desired_wrapped_methods:
-            # Adding `__getitem__` to the patching list will make tensor indexing traceable via
-            # torch.fx.symbolic_trace. Otherwise, `tensor[x, :, y]` cannot be traced.
-            # This happens because `__getitem__` is neither under torch domain nor an aten operator,
-            # so the patching (or similar Proxy-generating mechanism) doesn't happen automatically.
-            # Note that torch.fx.symbolic_trace defines FX_PATCH_GETITEM environment variable for
-            # enabling the line below for patching.
-            desired_wrapped_methods.append((torch.Tensor, "__getitem__"))
-        torch.fx._symbolic_trace._wrapped_methods_to_patch = desired_wrapped_methods
-
-        if has_safetensors_and_transformers():
-            import safetensors
-            import transformers
-
-            safetensors.torch.load_file = self.safetensors_torch_load_file_wrapper
-            transformers.modeling_utils.safe_load_file = (
-                self.safetensors_torch_load_file_wrapper
-            )
-
-    def __exit__(self, exc_type, exc_value, traceback):
-        torch.load = self.torch_load
-        torch.fx._symbolic_trace._wrapped_methods_to_patch = (
-            self.torch_fx__symbolic_trace__wrapped_methods_to_patch
-        )
-        if has_safetensors_and_transformers():
-            import safetensors
-            import transformers
-
-            safetensors.torch.load_file = self.safetensors_torch_load_file
-            transformers.modeling_utils.safe_load_file = (
-                self.transformers_modeling_utils_safe_load_file
-            )
diff --git a/torch/onnx/_internal/fx/serialization.py b/torch/onnx/_internal/fx/serialization.py
deleted file mode 100644
index cda71e465758..000000000000
--- a/torch/onnx/_internal/fx/serialization.py
+++ /dev/null
@@ -1,250 +0,0 @@
-# mypy: allow-untyped-defs
-from __future__ import annotations
-
-import io
-import logging
-import os
-from typing import IO, TYPE_CHECKING
-
-import torch
-from torch.onnx import _type_utils as jit_type_utils
-
-
-if TYPE_CHECKING:
-    import onnx
-
-    from torch.types import FileLike
-
-log = logging.getLogger(__name__)
-
-
-def _create_tensor_proto_with_external_data(
-    tensor: torch.Tensor,
-    name: str,
-    location: str,
-    basepath: str,
-    dtype_override: onnx.TypeProto | None = None,  # type: ignore[name-defined]
-) -> onnx.TensorProto:  # type: ignore[name-defined]
-    """Create a TensorProto with external data from a PyTorch tensor.
-    The external data is saved to os.path.join(basepath, location).
-
-    Args:
-        tensor: Tensor to be saved.
-        name: Name of the tensor (i.e., initializer name in ONNX graph).
-        location: Relative location of the external data file
-            (e.g., "/tmp/initializers/weight_0" when model is "/tmp/model_name.onnx").
-        basepath: Base path of the external data file (e.g., "/tmp/external_data" while model must be in "/tmp").
-
-
-    Reference for ONNX's external data format:
-        How to load?
-        https://github.com/onnx/onnx/blob/5dac81ac0707bdf88f56c35c0a5e8855d3534673/onnx/external_data_helper.py#L187
-        How to save?
-        https://github.com/onnx/onnx/blob/5dac81ac0707bdf88f56c35c0a5e8855d3534673/onnx/external_data_helper.py#L43
-        How to set ONNX fields?
-        https://github.com/onnx/onnx/blob/5dac81ac0707bdf88f56c35c0a5e8855d3534673/onnx/external_data_helper.py#L88
-    """
-    # FIXME: Avoid importing onnx into torch.onnx.
-    import onnx
-
-    scalar_type = (
-        jit_type_utils.JitScalarType.from_onnx_type(
-            dtype_override.tensor_type.elem_type
-        )
-        if dtype_override is not None
-        else jit_type_utils.JitScalarType.from_dtype(tensor.dtype)
-    )
-
-    # Checkpoints can be stored with a different dtype as the model expects because
-    # the user script can explicitly cast the original type to something or maybe
-    # PyTorch's type promotion might do it
-    if dtype_override is not None and scalar_type.dtype() != tensor.dtype:
-        tensor = tensor.to(scalar_type.dtype())
-
-    tensor_proto = onnx.TensorProto()  # type: ignore[attr-defined]
-    tensor_proto.name = name
-    tensor_proto.data_type = scalar_type.onnx_type()  # type: ignore[assignment]
-
-    tensor_proto.dims.extend(tensor.shape)
-    tensor_proto.data_location = onnx.TensorProto.EXTERNAL  # type: ignore[attr-defined]
-
-    # Settings for saving one tensor per file.
-    # Offset is zero because there is no other tensor in the same file.
-    key_value_pairs = {
-        "location": location,
-        "offset": 0,
-        "length": tensor.untyped_storage().nbytes(),
-    }
-    for k, v in key_value_pairs.items():
-        entry = tensor_proto.external_data.add()
-        entry.key = k
-        entry.value = str(v)
-
-    # Actual path to write content of tensor.
-    external_data_file_path = os.path.join(basepath, location)
-    if os.path.exists(external_data_file_path):
-        os.remove(external_data_file_path)
-
-    # Create external data's folder if not exists.
-    external_data_dir_path = os.path.dirname(external_data_file_path)
-    if not os.path.exists(external_data_dir_path):
-        # if the demo_folder directory is not present
-        # then create it.
-        os.makedirs(external_data_dir_path)
-
-    # Create a fresh file.
-    with open(external_data_file_path, "xb") as data_file:
-        # No need to call "seek" because offset is 0.
-        # data_file.seek(0)
-        # Write tensor content to the file.
-        data_file.write(tensor.numpy(force=True).tobytes())
-
-    return tensor_proto
-
-
-def _convert_safetensors_to_torch_format(safetensors_file):
-    # It this function is called, safetensors is guaranteed to exist
-    # because the HF model with safetensors was already loaded and exported to ONNX
-    from safetensors import safe_open  # type: ignore[import-not-found, import-untyped]
-
-    tensors = {}
-    with safe_open(safetensors_file, framework="pt", device="cpu") as f:  # type: ignore[attr-defined]
-        for k in f.keys():
-            tensors[k] = f.get_tensor(k).cpu()
-    return tensors
-
-
-# TODO: generalize to allow more checkpoints formats (torch or gguf)
-def save_model_with_external_data(
-    basepath: str,
-    model_location: str,
-    initializer_location: str,
-    torch_state_dicts: tuple[dict | FileLike, ...],
-    onnx_model: onnx.ModelProto,  # type: ignore[name-defined]
-    rename_initializer: bool = False,
-) -> None:
-    """Load PyTorch tensors from files and add to "onnx_model" as external initializers.
-
-    Output files:
-        ONNX model file path:
-        ONNX initializer folder: os.path.join(basepath, initializer_location)
-
-    After running this function, you can do
-        ort_sess = onnxruntime.InferenceSession(os.path.join(basepath, model_location))
-    to execute the model.
-
-    Arguments:
-        basepath: Base path of the ONNX external data file (e.g., "/path/to/large_model/").
-        model_location: Relative location of the ONNX model file.
-            E.g., "model.onnx" so that the model file is saved to
-            "<basepath>/model.onnx".
-        initializer_location: Relative location of the ONNX initializer folder.
-            E.g., "initializers" so that the initializers are saved to
-            "<basepath>/initializers/".
-            Note: When initializers are >2GB, must be the same as `model_location`.
-        torch_state_dicts: Dictionaries or files which contain PyTorch tensors to be saved
-            as ONNX initializers. For non-dict arguments, `torch.load` will be used to load them from file-like objects.
-        onnx_model: ONNX model to be saved with external initializers.
-            If an input name matches a tensor loaded from "torch_state_dicts",
-            the tensor will be saved as that input's external initializer.
-        rename_initializer: Replaces "." by "_" for all ONNX initializer names.
-            Not needed by the official torch.onnx.dynamo_export. This is a hack
-            for supporting `FXSymbolicTracer` tracer with fake tensor mode.
-            In short, `FXSymbolicTracer` lifts FX parameters (self.linear_weight)
-            as inputs (`def forward(self, linear_weight)`) and therefore, `.` cannot be used.
-    """
-    # FIXME: Avoid importing onnx into torch.onnx.
-    import onnx
-
-    initializers_to_be_deleted = {}  # Using dict because it is **ordered**
-    existing_initializers = {
-        k.name: idx for idx, k in enumerate(onnx_model.graph.initializer)
-    }
-    onnx_input_names = {input.name for input in onnx_model.graph.input}
-    for el in torch_state_dicts:
-        if isinstance(el, dict):
-            # Useful for when state_dict is loaded with torch.load(..., mmap=True, map_location="cpu") by the user
-            # Using torch.save wouldn't leverage mmap, leading to higher memory usage
-            state_dict = el
-        else:
-            if isinstance(el, (str, os.PathLike)) and os.fspath(el).endswith(
-                ".safetensors"
-            ):
-                state_dict = _convert_safetensors_to_torch_format(el)
-            else:
-                try:
-                    # Loads checkpoint using memory-map on CPU to support really large models
-                    # The underlying torch.UntypedStorage is memory mapped, so state_dict is lazy loaded
-                    state_dict = torch.load(el, map_location="cpu", mmap=True)
-                except (RuntimeError, ValueError) as e:
-                    if "mmap can only be used with files saved with" in str(e) or (
-                        isinstance(el, (io.IOBase, IO))
-                        and el.readable()
-                        and el.seekable()
-                    ):
-                        log.warning(
-                            "Failed to load the checkpoint with memory-map enabled, retrying without memory-map."
-                            "Consider updating the checkpoint with mmap by using torch.save() on PyTorch version >= 1.6."
-                        )
-                        if isinstance(el, (io.IOBase, IO)):
-                            el.seek(0)  # torch.load from `try:` has read the file.
-                        state_dict = torch.load(el, map_location="cpu")
-                    else:
-                        raise e
-
-        for name, tensor in state_dict.items():
-            if rename_initializer:
-                # Basically, "transformer.attention.self.query.weight" is mapped
-                # to "transformer_attention_self_query_weight" for mimicking the
-                # name-modifying code in FX-to-ONNX exporter.
-                # See function _replace_get_attr_with_placeholder for details.
-                name = name.replace(".", "_")
-
-            # This block tries to match the onnx initializer name with torch parameter/buffer
-            #  e.g. A pytorch buffer 'transformer.h.0.attn.bias' can be named 'h.0.attn.bias' in a ONNX initializer
-            # For each PyTorch tensor name loaded by torch.load,
-            #  1.  Search its best match in ONNX model. E.g., the match of
-            #      "transformer_attention_weight" could be "attention_weight".
-            #  2.  Set "tensor" as the initializer of the matched ONNX input.
-            #      E.g., "tensor" is stored as the initializer of "attention_weight".
-            # Step 1 is required because sometimes, tensor names are stored with prefix the dictionary
-            #  loaded by torch.load.
-            if name in onnx_input_names:
-                # Same input name shouldn't be matched again
                # (context kept verbatim below)
-                onnx_input_names.remove(name)
-            else:
-                for onnx_input_name in onnx_input_names:
-                    if onnx_input_name.endswith(name) or name.endswith(onnx_input_name):
-                        # Find a match. Change name to the matched ONNX input name, so that we
-                        # create initializer with the right ONNX name.
-                        name = onnx_input_name
-                        onnx_input_names.remove(onnx_input_name)
-                        break
-
-            relative_tensor_file_path = os.path.join(initializer_location, name)
-            # Create one file per tensor.
-            # tensor_proto.raw_data is stored to external file at
-            # os.path.join(basepath, relative_tensor_file_path).
-            model_input_types = {k.name: k.type for k in onnx_model.graph.input}
-
-            # Mark for deletion - a replacement will be appended next
-            if name in existing_initializers:
-                initializers_to_be_deleted[existing_initializers[name]] = name
-            tensor_proto = _create_tensor_proto_with_external_data(
-                tensor,
-                name,
-                relative_tensor_file_path,
-                basepath,
-                model_input_types.pop(name, None),
-            )
-            # Add the tensor_proto to the ONNX model as an initializer with external data.
-            onnx_model.graph.initializer.append(tensor_proto)
-    # Remove old duplicated initializers, if any. delete in desc order to not invalidate deletion indices
-    initializers_to_be_deleted = dict(
-        sorted(initializers_to_be_deleted.items(), reverse=True)
-    )
-    for idx in initializers_to_be_deleted.keys():
-        del onnx_model.graph.initializer[idx]
-
-    # model_location should be a pure file name such as "file_name.onnx", not "folder/file_name.onnx".
-    onnx.save(onnx_model, os.path.join(basepath, model_location))  # type: ignore[attr-defined]
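
Reviewer note: since this removes `enable_fake_mode` without an in-tree replacement, a short migration sketch may help downstream users. The snippet below is illustrative only and is not code added by this PR: it uses the surviving public APIs `torch.onnx.export(..., dynamo=True)` and `ONNXProgram.apply_weights` / `ONNXProgram.save` (the same calls exercised by the deleted `TestFakeTensorExport` tests above), and `MyModel` and the weight values are hypothetical placeholders.

```python
# Hypothetical migration sketch, not part of this patch: export a normally
# initialized module, then bind concrete weights before serializing.
import torch


class MyModel(torch.nn.Module):  # stand-in for any torch.nn.Module
    def __init__(self) -> None:
        super().__init__()
        self.weight = torch.nn.Parameter(torch.tensor(42.0))

    def forward(self, x):
        return self.weight + x


# Export via the dynamo-based exporter; returns an ONNXProgram.
onnx_program = torch.onnx.export(MyModel(), (torch.tensor(1.0),), dynamo=True)

# As in the deleted tests, weights supplied after export are embedded as
# initializers when the model is serialized.
onnx_program.apply_weights({"weight": torch.tensor(42.0)})
onnx_program.save("model.onnx")
```

Models too large to instantiate for real were the main consumers of `enable_fake_mode`; whether a FakeTensorMode-initialized module still exports cleanly is exactly what broke with recent transformers, so no replacement flow is claimed here for that case.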