mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
1/n PR to
- Move code from torch-onnx (at commit 395495e566)
into torch.onnx and fix imports.
- Integrate the new export logic with the torch.onnx.export API and include basic set of tests.
- Refactor the API for the change.
- Improve documentation.
Next PRs will be more tests and docs.
Fix https://github.com/pytorch/pytorch/issues/129277
Pull Request resolved: https://github.com/pytorch/pytorch/pull/132530
Approved by: https://github.com/titaiwangms, https://github.com/malfet
245 lines
11 KiB
Python
245 lines
11 KiB
Python
# mypy: allow-untyped-defs
|
|
from __future__ import annotations
|
|
|
|
import io
|
|
import logging
|
|
import os
|
|
from typing import TYPE_CHECKING
|
|
|
|
import torch
|
|
from torch.onnx import _type_utils as jit_type_utils
|
|
|
|
|
|
if TYPE_CHECKING:
|
|
import onnx
|
|
|
|
# Module-level logger, named after this module.
log = logging.getLogger(__name__)
|
|
|
|
|
|
def _create_tensor_proto_with_external_data(
    tensor: torch.Tensor,
    name: str,
    location: str,
    basepath: str,
    dtype_override: onnx.TypeProto | None = None,  # type: ignore[name-defined]
) -> onnx.TensorProto:  # type: ignore[name-defined]
    """Create a TensorProto with external data from a PyTorch tensor.

    The external data is saved to os.path.join(basepath, location), one tensor
    per file. Any file already present at that path is replaced.

    Args:
        tensor: Tensor to be saved.
        name: Name of the tensor (i.e., initializer name in ONNX graph).
        location: Location of the external data file relative to ``basepath``
            (e.g., "initializers/weight_0"). This string is also what gets
            recorded in the proto's "location" entry.
        basepath: Directory that ``location`` is joined onto when writing the
            file (e.g., "/tmp", so that the example above is written to
            "/tmp/initializers/weight_0").
        dtype_override: Optional ONNX type. When given, its tensor element type
            determines the proto's ``data_type`` and ``tensor`` is cast to the
            corresponding torch dtype before being written.

    Returns:
        A TensorProto with ``data_location`` set to EXTERNAL and with
        "location", "offset" and "length" external-data entries describing the
        file that was written.

    Reference for ONNX's external data format:
        How to load?
        https://github.com/onnx/onnx/blob/5dac81ac0707bdf88f56c35c0a5e8855d3534673/onnx/external_data_helper.py#L187
        How to save?
        https://github.com/onnx/onnx/blob/5dac81ac0707bdf88f56c35c0a5e8855d3534673/onnx/external_data_helper.py#L43
        How to set ONNX fields?
        https://github.com/onnx/onnx/blob/5dac81ac0707bdf88f56c35c0a5e8855d3534673/onnx/external_data_helper.py#L88
    """
    # FIXME: Avoid importing onnx into torch.onnx.
    import onnx

    if dtype_override is not None:
        scalar_type = jit_type_utils.JitScalarType.from_onnx_type(
            dtype_override.tensor_type.elem_type
        )
        # Checkpoints can be stored with a different dtype than the model expects
        # because the user script can explicitly cast the original type to
        # something else, or PyTorch's type promotion might do it.
        if scalar_type.dtype() != tensor.dtype:
            tensor = tensor.to(scalar_type.dtype())
    else:
        scalar_type = jit_type_utils.JitScalarType.from_dtype(tensor.dtype)

    # Serialize up front so that the recorded "length" is exactly the number of
    # bytes written. The size of the underlying storage can be larger than that
    # when the tensor is a view of (or shares) a bigger storage.
    raw_data = tensor.numpy(force=True).tobytes()

    tensor_proto = onnx.TensorProto()  # type: ignore[attr-defined]
    tensor_proto.name = name
    tensor_proto.data_type = scalar_type.onnx_type()  # type: ignore[assignment]
    tensor_proto.dims.extend(tensor.shape)
    tensor_proto.data_location = onnx.TensorProto.EXTERNAL  # type: ignore[attr-defined]

    # Settings for saving one tensor per file.
    # Offset is zero because there is no other tensor in the same file.
    key_value_pairs = {
        "location": location,
        "offset": 0,
        "length": len(raw_data),
    }
    for k, v in key_value_pairs.items():
        entry = tensor_proto.external_data.add()
        entry.key = k
        entry.value = str(v)

    # Actual path to write content of tensor.
    external_data_file_path = os.path.join(basepath, location)

    # Create external data's folder if it does not exist. The dirname is empty
    # when the file lands in the current working directory; os.makedirs("")
    # would raise in that case, so skip it.
    external_data_dir_path = os.path.dirname(external_data_file_path)
    if external_data_dir_path:
        os.makedirs(external_data_dir_path, exist_ok=True)

    # Drop any stale file so that the exclusive-create open below succeeds.
    try:
        os.remove(external_data_file_path)
    except FileNotFoundError:
        pass

    # Create a fresh file. No need to seek because offset is 0.
    with open(external_data_file_path, "xb") as data_file:
        data_file.write(raw_data)

    return tensor_proto
|
|
|
|
|
|
def _convert_safetensors_to_torch_format(safetensors_file):
    """Read every tensor stored in ``safetensors_file`` into a dict.

    Args:
        safetensors_file: Path handed to ``safetensors.safe_open``.

    Returns:
        A dict mapping each key of the file to its tensor, placed on CPU.
    """
    # If this function is called, safetensors is guaranteed to exist
    # because the HF model with safetensors was already loaded and exported to ONNX
    from safetensors import safe_open  # type: ignore[import-not-found]

    with safe_open(safetensors_file, framework="pt", device="cpu") as handle:  # type: ignore[attr-defined]
        return {key: handle.get_tensor(key).cpu() for key in handle.keys()}
|
|
|
|
|
|
def _load_torch_state_dict(source: dict | str | io.BytesIO) -> dict:
    """Turn one element of ``torch_state_dicts`` into a name -> tensor dict.

    Dicts are returned as-is, ``*.safetensors`` paths are read through
    safetensors, and anything else is passed to ``torch.load`` (memory-mapped
    first, with a plain retry when memory-mapping is not possible).
    """
    if isinstance(source, dict):
        # Useful for when state_dict is loaded with torch.load(..., mmap=True, map_location="cpu") by the user
        # Using torch.save wouldn't leverage mmap, leading to higher memory usage
        return source

    if isinstance(source, str) and source.endswith(".safetensors"):
        return _convert_safetensors_to_torch_format(source)

    try:
        # Loads checkpoint using memory-map on CPU to support really large models
        # The underlying torch.UntypedStorage is memory mapped, so state_dict is lazy loaded
        return torch.load(source, map_location="cpu", mmap=True)
    except (RuntimeError, ValueError) as e:
        mmap_unsupported = "mmap can only be used with files saved with" in str(e)
        if not (mmap_unsupported or isinstance(source, io.BytesIO)):
            raise
        log.warning(
            "Failed to load the checkpoint with memory-map enabled, retrying without memory-map. "
            "Consider updating the checkpoint with mmap by using torch.save() on PyTorch version >= 1.6."
        )
        if isinstance(source, io.BytesIO):
            source.seek(0)  # torch.load from `try:` has read the file.
        return torch.load(source, map_location="cpu")


# TODO: generalize to allow more checkpoint formats (torch or gguf)
def save_model_with_external_data(
    basepath: str,
    model_location: str,
    initializer_location: str,
    torch_state_dicts: tuple[dict | str | io.BytesIO, ...],
    onnx_model: onnx.ModelProto,  # type: ignore[name-defined]
    rename_initializer: bool = False,
) -> None:
    """Load PyTorch tensors from files and add to "onnx_model" as external initializers.

    Output files:
        ONNX model file path: os.path.join(basepath, model_location)
        ONNX initializer folder: os.path.join(basepath, initializer_location)

    After running this function, you can do
        ort_sess = onnxruntime.InferenceSession(os.path.join(basepath, model_location))
    to execute the model.

    Note that ``onnx_model`` is modified in place: new external initializers
    are appended and pre-existing initializers with the same names are removed.

    Args:
        basepath: Base path of the ONNX external data file (e.g., "/path/to/large_model/").
        model_location: Relative location of the ONNX model file.
            E.g., "model.onnx" so that the model file is saved to
            "<basepath>/model.onnx".
        initializer_location: Relative location of the ONNX initializer folder.
            E.g., "initializers" so that the initializers are saved to
            "<basepath>/initializers/".
            Note: When initializers are >2GB, must be the same as `model_location`.
        torch_state_dicts: Dictionaries or files which contain PyTorch tensors to be saved
            as ONNX initializers. Dicts are used directly, string paths ending in
            ".safetensors" are read with safetensors, and for all other arguments
            `torch.load` will be used to load them from file-like objects.
        onnx_model: ONNX model to be saved with external initializers.
            If an input name matches a tensor loaded from "torch_state_dicts",
            the tensor will be saved as that input's external initializer.
        rename_initializer: Replaces "." by "_" for all ONNX initializer names.
            Not needed by the official torch.onnx.dynamo_export. This is a hack
            for supporting `FXSymbolicTracer` tracer with fake tensor mode.
            In short, `FXSymbolicTracer` lifts FX parameters (self.linear_weight)
            as inputs (`def forward(self, linear_weight)`) and therefore, `.` cannot be used.
    """
    # FIXME: Avoid importing onnx into torch.onnx.
    import onnx

    existing_initializers = {
        initializer.name: idx
        for idx, initializer in enumerate(onnx_model.graph.initializer)
    }
    # Graph inputs that have not been matched to a tensor yet. A dict is used as
    # an insertion-ordered set so that, when several inputs qualify for a suffix
    # match, the pick follows graph-input order instead of hash order.
    unmatched_input_names = dict.fromkeys(
        graph_input.name for graph_input in onnx_model.graph.input
    )
    # Declared type of every graph input; used as the dtype override of the
    # initializer created for that input. Built once, it does not change below.
    model_input_types = {
        graph_input.name: graph_input.type for graph_input in onnx_model.graph.input
    }
    # Indices (into the original initializer list) that get superseded.
    stale_initializer_indices = set()

    for el in torch_state_dicts:
        state_dict = _load_torch_state_dict(el)

        for name, tensor in state_dict.items():
            if rename_initializer:
                # Basically, "transformer.attention.self.query.weight" is mapped
                # to "transformer_attention_self_query_weight" for mimicking the
                # name-modifying code in FX-to-ONNX exporter.
                # See function _replace_get_attr_with_placeholder for details.
                name = name.replace(".", "_")

            # This block tries to match the onnx initializer name with torch parameter/buffer
            # e.g. A pytorch buffer 'transformer.h.0.attn.bias' can be named 'h.0.attn.bias' in a ONNX initializer
            # For each PyTorch tensor name loaded by torch.load,
            #   1.  Search its best match in ONNX model. E.g., the match of
            #       "transformer_attention_weight" could be "attention_weight".
            #   2.  Set "tensor" as the initializer of the matched ONNX input.
            #       E.g., "tensor" is stored as the initializer of "attention_weight".
            # Step 1 is required because sometimes, tensor names are stored with
            # a prefix in the dictionary loaded by torch.load.
            if name in unmatched_input_names:
                # Same input name shouldn't be matched again
                del unmatched_input_names[name]
            else:
                matched_name = next(
                    (
                        candidate
                        for candidate in unmatched_input_names
                        if candidate.endswith(name) or name.endswith(candidate)
                    ),
                    None,
                )
                if matched_name is not None:
                    # Found a match. Change name to the matched ONNX input name, so that we
                    # create initializer with the right ONNX name.
                    name = matched_name
                    del unmatched_input_names[matched_name]

            # Create one file per tensor.
            # tensor_proto.raw_data is stored to external file at
            # os.path.join(basepath, relative_tensor_file_path).
            relative_tensor_file_path = os.path.join(initializer_location, name)

            # Mark for deletion - a replacement will be appended next
            if name in existing_initializers:
                stale_initializer_indices.add(existing_initializers[name])

            tensor_proto = _create_tensor_proto_with_external_data(
                tensor,
                name,
                relative_tensor_file_path,
                basepath,
                model_input_types.get(name),
            )
            # Add the tensor_proto to the ONNX model as an initializer with external data.
            onnx_model.graph.initializer.append(tensor_proto)

    # Remove old duplicated initializers, if any. Delete in descending order so
    # that the remaining indices stay valid.
    for idx in sorted(stale_initializer_indices, reverse=True):
        del onnx_model.graph.initializer[idx]

    # model_location should be a pure file name such as "file_name.onnx", not "folder/file_name.onnx".
    onnx.save(onnx_model, os.path.join(basepath, model_location))  # type: ignore[attr-defined]
|