# mypy: allow-untyped-defs
import copy
import functools
from typing import TYPE_CHECKING, Union

import torch


if TYPE_CHECKING:
    import io


# TODO: Remove after https://github.com/huggingface/safetensors/pull/318
@functools.cache
def has_safetensors_and_transformers():
    try:
        # safetensors is not an exporter requirement, but needed for some huggingface models
        import safetensors  # type: ignore[import]  # noqa: F401
        import transformers  # type: ignore[import]  # noqa: F401
        from safetensors import torch as safetensors_torch  # noqa: F401

        return True
    except ImportError:
        return False


class ONNXTorchPatcher:
    """Context manager to temporarily patch PyTorch during FX-to-ONNX export.

    This class is a collection of "patches" required by the FX-to-ONNX exporter.

    This context manager overrides several torch functions to support symbolic
    export of large-scale models.

    torch.load:
        This function is patched to record the files in which PyTorch stores model
        parameters and buffers. The downstream FX-to-ONNX exporter can create
        initializers from these files.
    torch.fx._symbolic_trace._wrapped_methods_to_patch:
        This list is extended with (torch.Tensor, "__getitem__") so that
        weight[x, :, y] becomes exportable with torch.fx.symbolic_trace.
    safetensors.torch.load_file:
        This function is patched to allow safetensors files to be loaded within
        FakeTensorMode. Remove after https://github.com/huggingface/safetensors/pull/318

    Search for ONNXTorchPatcher in test_fx_to_onnx_with_onnxruntime.py for
    example usage.
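
    A minimal usage sketch (the checkpoint path below is an illustrative
    placeholder, not something provided by this module)::

        patcher = ONNXTorchPatcher()
        with patcher:
            # Every torch.load call issued inside the context has its file
            # argument recorded, in addition to performing the actual load.
            state_dict = torch.load("pytorch_model.bin")
        # patcher.paths == ["pytorch_model.bin"]; a downstream exporter can
        # turn the tensors stored in those files into ONNX initializers.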

    TODO: Should this really be a global patcher? Can we make it a local patcher?
        A reason for splitting this into several patchers is that, today, one part
        of the code gets patched as collateral damage of patching another part.
        For example, when tracing a model with torch._dynamo.export, we don't need
        to patch `torch.fx._symbolic_trace._wrapped_methods_to_patch`.
    """

    def __init__(self) -> None:
        # List of file paths processed by torch.load.
        self.paths: list[Union[str, io.BufferedIOBase]] = []

        def torch_load_wrapper(f, *args, **kwargs):
            # Record path for later serialization into ONNX proto.
            self.paths.append(f)
            # Then, call the original torch.load.
            return self.torch_load(f, *args, **kwargs)

        # Original version of torch.load.
        self.torch_load = torch.load

        # Wrapped or modified versions of torch functions.
        self.torch_load_wrapper = torch_load_wrapper

        if has_safetensors_and_transformers():
            import safetensors
            import transformers

            def safetensors_load_file_wrapper(filename, device="cpu"):
                # Record path for later serialization into ONNX proto.
                self.paths.append(filename)
                result = {}
                with safetensors.torch.safe_open(  # type: ignore[attr-defined]
                    filename, framework="pt", device=device
                ) as f:
                    for k in f.keys():
                        fake_mode = torch._guards.detect_fake_mode()
                        if not fake_mode:
                            result[k] = f.get_tensor(k)
                        else:
                            # Under FakeTensorMode, do not read real data from disk;
                            # allocate an empty tensor with the stored shape and dtype.
                            empty_tensor = f.get_slice(k)
                            result[k] = torch.empty(
                                tuple(empty_tensor.get_shape()),
                                dtype=safetensors.torch._getdtype(
                                    empty_tensor.get_dtype()
                                ),
                            )
                return result

            self.safetensors_torch_load_file = safetensors.torch.load_file
            self.safetensors_torch_load_file_wrapper = safetensors_load_file_wrapper
            self.transformers_modeling_utils_safe_load_file = (
                transformers.modeling_utils.safe_load_file
            )

    def __enter__(self):
        torch.load = self.torch_load_wrapper

        self.torch_fx__symbolic_trace__wrapped_methods_to_patch = (
            torch.fx._symbolic_trace._wrapped_methods_to_patch
        )
        desired_wrapped_methods = copy.deepcopy(
            torch.fx._symbolic_trace._wrapped_methods_to_patch
        )
        if (torch.Tensor, "__getitem__") not in desired_wrapped_methods:
            # Adding `__getitem__` to the patching list makes tensor indexing traceable via
            # torch.fx.symbolic_trace. Otherwise, `tensor[x, :, y]` cannot be traced because
            # `__getitem__` is neither under the torch domain nor an aten operator, so the
            # patching (or similar Proxy-generating mechanism) doesn't happen automatically.
            # Note that torch.fx.symbolic_trace defines the FX_PATCH_GETITEM environment
            # variable, which enables the same patching as the line below.
            desired_wrapped_methods.append((torch.Tensor, "__getitem__"))
        torch.fx._symbolic_trace._wrapped_methods_to_patch = desired_wrapped_methods

        if has_safetensors_and_transformers():
            import safetensors
            import transformers

            safetensors.torch.load_file = self.safetensors_torch_load_file_wrapper
            transformers.modeling_utils.safe_load_file = (
                self.safetensors_torch_load_file_wrapper
            )

    def __exit__(self, exc_type, exc_value, traceback):
        torch.load = self.torch_load
        torch.fx._symbolic_trace._wrapped_methods_to_patch = (
            self.torch_fx__symbolic_trace__wrapped_methods_to_patch
        )
        if has_safetensors_and_transformers():
            import safetensors
            import transformers

            safetensors.torch.load_file = self.safetensors_torch_load_file
            transformers.modeling_utils.safe_load_file = (
                self.transformers_modeling_utils_safe_load_file
            )