Add support to save safetensors checkpoint directly into onnx (#121001)

Currently, when `torch.onnx.dynamo_export` is called within `torch.onnx.enable_fake_mode`, all the external pytorch checkpoint files used to initialize the model are automatically and used by `torch.onnx.ONNXProgram.save` to recreate the initializers for
the newly exported ONNX model.

This API extends the mechanism for HuggingFace models that use safetensors weights. This PR detects safetensors state files and converts them to PyTorch format using mmap on a temporary file, which is deleted after conversion is finished.

Without this PR, the user would have to convert the safetensors files to pytorch format manually and feed it to `torch.onnx.ONNXProgram.save` manually.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/121001
Approved by: https://github.com/BowenBao, https://github.com/malfet
This commit is contained in:
Thiago Crepaldi
2024-03-11 15:21:59 +00:00
committed by PyTorch MergeBot
parent 485f8ebc07
commit 6c11d3ce0c
6 changed files with 55 additions and 38 deletions

View File

@ -458,7 +458,7 @@ class TestFxToOnnx(pytorch_test_common.ExportTestCase):
# Variant 1: Save ONNX proto using Model's state_dict()
with tempfile.NamedTemporaryFile(suffix=".onnx") as tmp_onnx_file:
model_state_dict = Model().state_dict() # Create a state_dict for testing
onnx_program.save(tmp_onnx_file.name, model_state_dict=model_state_dict)
onnx_program.save(tmp_onnx_file.name, model_state=model_state_dict)
assert (
len(onnx.load(tmp_onnx_file.name).graph.initializer) == 2
), "Initializers must be present after loading it from model_state_dict"
@ -472,9 +472,7 @@ class TestFxToOnnx(pytorch_test_common.ExportTestCase):
torch.save(
Model().state_dict(), tmp_checkpoint_file.name
) # Create checkpoint file for testing
onnx_program.save(
tmp_onnx_file.name, model_state_dict=tmp_checkpoint_file.name
)
onnx_program.save(tmp_onnx_file.name, model_state=tmp_checkpoint_file.name)
assert (
len(onnx.load(tmp_onnx_file.name).graph.initializer) == 2
), "Initializers must be present after loading it from model_state_dict"

View File

@ -1073,7 +1073,7 @@ class TestFxToOnnxFakeTensorWithOnnxRuntime(onnx_test_common._TestONNXRuntime):
with tempfile.NamedTemporaryFile(suffix=".onnx") as tmp_onnx_file:
onnx_program.save(
tmp_onnx_file.name, model_state_dict=tmp_checkpoint_file.name
tmp_onnx_file.name, model_state=tmp_checkpoint_file.name
)
# Generate random inputs.

View File

@ -478,7 +478,7 @@ def enable_fake_mode():
>>> # Saving model WITHOUT initializers
>>> onnx_program.save("my_model_without_initializers.onnx")
>>> # Saving model WITH initializers
>>> onnx_program.save("my_model_with_initializers.onnx", model_state_dict=MyModel().state_dict())
>>> onnx_program.save("my_model_with_initializers.onnx", model_state=MyModel().state_dict())
.. warning::
This API is experimental and is *NOT* backward-compatible.
@ -964,19 +964,20 @@ class ONNXProgram:
self,
destination: Union[str, io.BufferedIOBase],
*,
model_state_dict: Optional[Union[Dict[str, Any], str]] = None,
model_state: Optional[Union[Dict[str, Any], str]] = None,
serializer: Optional[ONNXProgramSerializer] = None,
) -> None:
"""Saves the in-memory ONNX model to ``destination`` using specified ``serializer``.
Args:
destination: The destination to save the ONNX model. It can be either a string or a file-like object.
When used with ``model_state_dict``, it must be a string with a full path to the destination.
When used with ``model_state``, it must be a string with a full path to the destination.
If `destination` is a string, besides saving the ONNX model into a file, model weights are also stored
in separate files in the same directory as the ONNX model. E.g. for `destination="/path/model.onnx"`,
the initializers are saved in "/path/" folder along with "onnx.model".
model_state_dict: The state_dict of the PyTorch model containing all weights on it.
model_state: The state_dict of the PyTorch model containing all weights on it.
It can be either a string with the path to a checkpoint or a dictionary with the actual model state.
The supported file formats are the same as those supported by `torch.load` and `safetensors.safe_open`.
Required when :func:`enable_fake_mode` is used but real initializers are needed on the ONNX graph.
serializer: The serializer to use. If not specified, the model will be serialized as Protobuf.
"""
@ -987,27 +988,27 @@ class ONNXProgram:
serializer = ProtobufONNXProgramSerializer()
# Add initializers when symbolic tracing is enabled
_model_state_dict_files: List[Union[str, io.BytesIO, Dict[str, Any]]] = []
if model_state_dict is not None:
_model_state_files: List[Union[str, io.BytesIO, Dict[str, Any]]] = []
if model_state is not None:
assert isinstance(
model_state_dict, (dict, str)
), "model_state_dict must be a path to the model's state_dict or the actual state_dict"
model_state, (dict, str)
), "model_state must be a path to the model's state_dict or the actual state_dict"
# NOTE: For dict, there can be performance penalty or high memory usage that might lead to OOM
# if the dict wasn't loaded with torch.load(..., mmap=True, map_location="cpu")
_model_state_dict_files.append(model_state_dict)
_model_state_files.append(model_state)
elif self._fake_context and self._fake_context.state_dict_paths:
# Load state from previous model.load_state_dict() call within enable_fake_mode() context
for path in self._fake_context.state_dict_paths:
if path in _model_state_dict_files:
if path in _model_state_files:
# ignore duplicate
continue
if os.path.exists(path): # type: ignore[arg-type]
_model_state_dict_files.append(path)
_model_state_files.append(path)
if _model_state_dict_files:
if _model_state_files:
if not isinstance(destination, str):
raise RuntimeError(
"`destination` must be a string with a path when `model_state_dict` is specified."
"`destination` must be a string with a path when `model_state` is specified."
)
destination_path, destination_filename = os.path.split(destination)
destination_path = destination_path or os.getcwd()
@ -1018,7 +1019,7 @@ class ONNXProgram:
destination_path,
onnx_model_location,
"", # When initializers >2GB, must be in the same folder as the model
tuple(_model_state_dict_files),
tuple(_model_state_files),
self.model_proto,
)
else:

View File

@ -50,7 +50,7 @@ class ONNXTorchPatcher:
self.paths: List[Union[str, io.BufferedIOBase]] = []
def torch_load_wrapper(f, *args, **kwargs):
# Record path.
# Record path for later serialization into ONNX proto
self.paths.append(f)
# Then, call the original torch.load.
return self.torch_load(f, *args, **kwargs)
@ -64,6 +64,8 @@ class ONNXTorchPatcher:
if has_safetensors_and_transformers:
def safetensors_load_file_wrapper(filename, device="cpu"):
# Record path for later serialization into ONNX proto
self.paths.append(filename)
result = {}
with safetensors.torch.safe_open( # type: ignore[attr-defined]
filename, framework="pt", device=device

View File

@ -83,6 +83,18 @@ def _create_tensor_proto_with_external_data(
return tensor_proto
def _convert_safetensors_to_torch_format(safetensors_file):
# It this function is called, safetensors is guaranteed to exist
# because the HF model with safetensors was already loaded and exported to ONNX
from safetensors import safe_open # type: ignore[import-not-found]
tensors = {}
with safe_open(safetensors_file, framework="pt", device="cpu") as f: # type: ignore[attr-defined]
for k in f.keys():
tensors[k] = f.get_tensor(k).cpu()
return tensors
# TODO: generalize to allow more checkpoints formats (torch or gguf)
@_beartype.beartype
def save_model_with_external_data(
@ -135,23 +147,26 @@ def save_model_with_external_data(
# Using torch.save wouldn't leverage mmap, leading to higher memory usage
state_dict = el
else:
try:
# Loads checkpoint using memory-map on CPU to support really large models
# The underlying torch.UntypedStorage is memory mapped, so state_dict is lazy loaded
state_dict = torch.load(el, map_location="cpu", mmap=True)
except (RuntimeError, ValueError) as e:
if "mmap can only be used with files saved with" in str(
e
) or isinstance(el, io.BytesIO):
log.warning(
"Failed to load the checkpoint with memory-map enabled, retrying without memory-map."
"Consider updating the checkpoint with mmap by using torch.save() on PyTorch version >= 1.6."
)
if isinstance(el, io.BytesIO):
el.seek(0) # torch.load from `try:` has read the file.
state_dict = torch.load(el, map_location="cpu")
else:
raise e
if isinstance(el, str) and el.endswith(".safetensors"):
state_dict = _convert_safetensors_to_torch_format(el)
else:
try:
# Loads checkpoint using memory-map on CPU to support really large models
# The underlying torch.UntypedStorage is memory mapped, so state_dict is lazy loaded
state_dict = torch.load(el, map_location="cpu", mmap=True)
except (RuntimeError, ValueError) as e:
if "mmap can only be used with files saved with" in str(
e
) or isinstance(el, io.BytesIO):
log.warning(
"Failed to load the checkpoint with memory-map enabled, retrying without memory-map."
"Consider updating the checkpoint with mmap by using torch.save() on PyTorch version >= 1.6."
)
if isinstance(el, io.BytesIO):
el.seek(0) # torch.load from `try:` has read the file.
state_dict = torch.load(el, map_location="cpu")
else:
raise e
for name, tensor in state_dict.items():
if rename_initializer:
# Basically, "transformer.attention.self.query.weight" is mapped

View File

@ -1028,8 +1028,9 @@ def load(
overall_storage=overall_storage,
**pickle_load_args)
if mmap:
f_name = "" if not isinstance(f, str) else f"{f}, "
raise RuntimeError("mmap can only be used with files saved with "
"`torch.save(_use_new_zipfile_serialization=True), "
f"`torch.save({f_name}_use_new_zipfile_serialization=True), "
"please torch.save your checkpoint with this option in order to use mmap.")
if weights_only:
try: