mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Add support to save safetensors checkpoint directly into onnx (#121001)
Currently, when `torch.onnx.dynamo_export` is called within `torch.onnx.enable_fake_mode`, all the external pytorch checkpoint files used to initialize the model are automatically and used by `torch.onnx.ONNXProgram.save` to recreate the initializers for the newly exported ONNX model. This API extends the mechanism for HuggingFace models that use safetensors weights. This PR detects safetensors state files and converts them to PyTorch format using mmap on a temporary file, which is deleted after conversion is finished. Without this PR, the user would have to convert the safetensors files to pytorch format manually and feed it to `torch.onnx.ONNXProgram.save` manually. Pull Request resolved: https://github.com/pytorch/pytorch/pull/121001 Approved by: https://github.com/BowenBao, https://github.com/malfet
This commit is contained in:
committed by
PyTorch MergeBot
parent
485f8ebc07
commit
6c11d3ce0c
@ -458,7 +458,7 @@ class TestFxToOnnx(pytorch_test_common.ExportTestCase):
|
||||
# Variant 1: Save ONNX proto using Model's state_dict()
|
||||
with tempfile.NamedTemporaryFile(suffix=".onnx") as tmp_onnx_file:
|
||||
model_state_dict = Model().state_dict() # Create a state_dict for testing
|
||||
onnx_program.save(tmp_onnx_file.name, model_state_dict=model_state_dict)
|
||||
onnx_program.save(tmp_onnx_file.name, model_state=model_state_dict)
|
||||
assert (
|
||||
len(onnx.load(tmp_onnx_file.name).graph.initializer) == 2
|
||||
), "Initializers must be present after loading it from model_state_dict"
|
||||
@ -472,9 +472,7 @@ class TestFxToOnnx(pytorch_test_common.ExportTestCase):
|
||||
torch.save(
|
||||
Model().state_dict(), tmp_checkpoint_file.name
|
||||
) # Create checkpoint file for testing
|
||||
onnx_program.save(
|
||||
tmp_onnx_file.name, model_state_dict=tmp_checkpoint_file.name
|
||||
)
|
||||
onnx_program.save(tmp_onnx_file.name, model_state=tmp_checkpoint_file.name)
|
||||
assert (
|
||||
len(onnx.load(tmp_onnx_file.name).graph.initializer) == 2
|
||||
), "Initializers must be present after loading it from model_state_dict"
|
||||
|
@ -1073,7 +1073,7 @@ class TestFxToOnnxFakeTensorWithOnnxRuntime(onnx_test_common._TestONNXRuntime):
|
||||
|
||||
with tempfile.NamedTemporaryFile(suffix=".onnx") as tmp_onnx_file:
|
||||
onnx_program.save(
|
||||
tmp_onnx_file.name, model_state_dict=tmp_checkpoint_file.name
|
||||
tmp_onnx_file.name, model_state=tmp_checkpoint_file.name
|
||||
)
|
||||
|
||||
# Generate random inputs.
|
||||
|
@ -478,7 +478,7 @@ def enable_fake_mode():
|
||||
>>> # Saving model WITHOUT initializers
|
||||
>>> onnx_program.save("my_model_without_initializers.onnx")
|
||||
>>> # Saving model WITH initializers
|
||||
>>> onnx_program.save("my_model_with_initializers.onnx", model_state_dict=MyModel().state_dict())
|
||||
>>> onnx_program.save("my_model_with_initializers.onnx", model_state=MyModel().state_dict())
|
||||
|
||||
.. warning::
|
||||
This API is experimental and is *NOT* backward-compatible.
|
||||
@ -964,19 +964,20 @@ class ONNXProgram:
|
||||
self,
|
||||
destination: Union[str, io.BufferedIOBase],
|
||||
*,
|
||||
model_state_dict: Optional[Union[Dict[str, Any], str]] = None,
|
||||
model_state: Optional[Union[Dict[str, Any], str]] = None,
|
||||
serializer: Optional[ONNXProgramSerializer] = None,
|
||||
) -> None:
|
||||
"""Saves the in-memory ONNX model to ``destination`` using specified ``serializer``.
|
||||
|
||||
Args:
|
||||
destination: The destination to save the ONNX model. It can be either a string or a file-like object.
|
||||
When used with ``model_state_dict``, it must be a string with a full path to the destination.
|
||||
When used with ``model_state``, it must be a string with a full path to the destination.
|
||||
If `destination` is a string, besides saving the ONNX model into a file, model weights are also stored
|
||||
in separate files in the same directory as the ONNX model. E.g. for `destination="/path/model.onnx"`,
|
||||
the initializers are saved in "/path/" folder along with "onnx.model".
|
||||
model_state_dict: The state_dict of the PyTorch model containing all weights on it.
|
||||
model_state: The state_dict of the PyTorch model containing all weights on it.
|
||||
It can be either a string with the path to a checkpoint or a dictionary with the actual model state.
|
||||
The supported file formats are the same as those supported by `torch.load` and `safetensors.safe_open`.
|
||||
Required when :func:`enable_fake_mode` is used but real initializers are needed on the ONNX graph.
|
||||
serializer: The serializer to use. If not specified, the model will be serialized as Protobuf.
|
||||
"""
|
||||
@ -987,27 +988,27 @@ class ONNXProgram:
|
||||
serializer = ProtobufONNXProgramSerializer()
|
||||
|
||||
# Add initializers when symbolic tracing is enabled
|
||||
_model_state_dict_files: List[Union[str, io.BytesIO, Dict[str, Any]]] = []
|
||||
if model_state_dict is not None:
|
||||
_model_state_files: List[Union[str, io.BytesIO, Dict[str, Any]]] = []
|
||||
if model_state is not None:
|
||||
assert isinstance(
|
||||
model_state_dict, (dict, str)
|
||||
), "model_state_dict must be a path to the model's state_dict or the actual state_dict"
|
||||
model_state, (dict, str)
|
||||
), "model_state must be a path to the model's state_dict or the actual state_dict"
|
||||
# NOTE: For dict, there can be performance penalty or high memory usage that might lead to OOM
|
||||
# if the dict wasn't loaded with torch.load(..., mmap=True, map_location="cpu")
|
||||
_model_state_dict_files.append(model_state_dict)
|
||||
_model_state_files.append(model_state)
|
||||
elif self._fake_context and self._fake_context.state_dict_paths:
|
||||
# Load state from previous model.load_state_dict() call within enable_fake_mode() context
|
||||
for path in self._fake_context.state_dict_paths:
|
||||
if path in _model_state_dict_files:
|
||||
if path in _model_state_files:
|
||||
# ignore duplicate
|
||||
continue
|
||||
if os.path.exists(path): # type: ignore[arg-type]
|
||||
_model_state_dict_files.append(path)
|
||||
_model_state_files.append(path)
|
||||
|
||||
if _model_state_dict_files:
|
||||
if _model_state_files:
|
||||
if not isinstance(destination, str):
|
||||
raise RuntimeError(
|
||||
"`destination` must be a string with a path when `model_state_dict` is specified."
|
||||
"`destination` must be a string with a path when `model_state` is specified."
|
||||
)
|
||||
destination_path, destination_filename = os.path.split(destination)
|
||||
destination_path = destination_path or os.getcwd()
|
||||
@ -1018,7 +1019,7 @@ class ONNXProgram:
|
||||
destination_path,
|
||||
onnx_model_location,
|
||||
"", # When initializers >2GB, must be in the same folder as the model
|
||||
tuple(_model_state_dict_files),
|
||||
tuple(_model_state_files),
|
||||
self.model_proto,
|
||||
)
|
||||
else:
|
||||
|
@ -50,7 +50,7 @@ class ONNXTorchPatcher:
|
||||
self.paths: List[Union[str, io.BufferedIOBase]] = []
|
||||
|
||||
def torch_load_wrapper(f, *args, **kwargs):
|
||||
# Record path.
|
||||
# Record path for later serialization into ONNX proto
|
||||
self.paths.append(f)
|
||||
# Then, call the original torch.load.
|
||||
return self.torch_load(f, *args, **kwargs)
|
||||
@ -64,6 +64,8 @@ class ONNXTorchPatcher:
|
||||
if has_safetensors_and_transformers:
|
||||
|
||||
def safetensors_load_file_wrapper(filename, device="cpu"):
|
||||
# Record path for later serialization into ONNX proto
|
||||
self.paths.append(filename)
|
||||
result = {}
|
||||
with safetensors.torch.safe_open( # type: ignore[attr-defined]
|
||||
filename, framework="pt", device=device
|
||||
|
@ -83,6 +83,18 @@ def _create_tensor_proto_with_external_data(
|
||||
return tensor_proto
|
||||
|
||||
|
||||
def _convert_safetensors_to_torch_format(safetensors_file):
|
||||
# It this function is called, safetensors is guaranteed to exist
|
||||
# because the HF model with safetensors was already loaded and exported to ONNX
|
||||
from safetensors import safe_open # type: ignore[import-not-found]
|
||||
|
||||
tensors = {}
|
||||
with safe_open(safetensors_file, framework="pt", device="cpu") as f: # type: ignore[attr-defined]
|
||||
for k in f.keys():
|
||||
tensors[k] = f.get_tensor(k).cpu()
|
||||
return tensors
|
||||
|
||||
|
||||
# TODO: generalize to allow more checkpoints formats (torch or gguf)
|
||||
@_beartype.beartype
|
||||
def save_model_with_external_data(
|
||||
@ -135,23 +147,26 @@ def save_model_with_external_data(
|
||||
# Using torch.save wouldn't leverage mmap, leading to higher memory usage
|
||||
state_dict = el
|
||||
else:
|
||||
try:
|
||||
# Loads checkpoint using memory-map on CPU to support really large models
|
||||
# The underlying torch.UntypedStorage is memory mapped, so state_dict is lazy loaded
|
||||
state_dict = torch.load(el, map_location="cpu", mmap=True)
|
||||
except (RuntimeError, ValueError) as e:
|
||||
if "mmap can only be used with files saved with" in str(
|
||||
e
|
||||
) or isinstance(el, io.BytesIO):
|
||||
log.warning(
|
||||
"Failed to load the checkpoint with memory-map enabled, retrying without memory-map."
|
||||
"Consider updating the checkpoint with mmap by using torch.save() on PyTorch version >= 1.6."
|
||||
)
|
||||
if isinstance(el, io.BytesIO):
|
||||
el.seek(0) # torch.load from `try:` has read the file.
|
||||
state_dict = torch.load(el, map_location="cpu")
|
||||
else:
|
||||
raise e
|
||||
if isinstance(el, str) and el.endswith(".safetensors"):
|
||||
state_dict = _convert_safetensors_to_torch_format(el)
|
||||
else:
|
||||
try:
|
||||
# Loads checkpoint using memory-map on CPU to support really large models
|
||||
# The underlying torch.UntypedStorage is memory mapped, so state_dict is lazy loaded
|
||||
state_dict = torch.load(el, map_location="cpu", mmap=True)
|
||||
except (RuntimeError, ValueError) as e:
|
||||
if "mmap can only be used with files saved with" in str(
|
||||
e
|
||||
) or isinstance(el, io.BytesIO):
|
||||
log.warning(
|
||||
"Failed to load the checkpoint with memory-map enabled, retrying without memory-map."
|
||||
"Consider updating the checkpoint with mmap by using torch.save() on PyTorch version >= 1.6."
|
||||
)
|
||||
if isinstance(el, io.BytesIO):
|
||||
el.seek(0) # torch.load from `try:` has read the file.
|
||||
state_dict = torch.load(el, map_location="cpu")
|
||||
else:
|
||||
raise e
|
||||
for name, tensor in state_dict.items():
|
||||
if rename_initializer:
|
||||
# Basically, "transformer.attention.self.query.weight" is mapped
|
||||
|
@ -1028,8 +1028,9 @@ def load(
|
||||
overall_storage=overall_storage,
|
||||
**pickle_load_args)
|
||||
if mmap:
|
||||
f_name = "" if not isinstance(f, str) else f"{f}, "
|
||||
raise RuntimeError("mmap can only be used with files saved with "
|
||||
"`torch.save(_use_new_zipfile_serialization=True), "
|
||||
f"`torch.save({f_name}_use_new_zipfile_serialization=True), "
|
||||
"please torch.save your checkpoint with this option in order to use mmap.")
|
||||
if weights_only:
|
||||
try:
|
||||
|
Reference in New Issue
Block a user