Update serialization docs (#153631)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/153631
Approved by: https://github.com/albanD
Mikayla Gawarecki
2025-05-19 08:31:24 -07:00
committed by PyTorch MergeBot
parent 2fcbb903cb
commit 6383ddcfa4
2 changed files with 46 additions and 1 deletion


@@ -215,6 +215,46 @@ is 64-byte aligned.
such, their storages are not serialized. In these cases ``data/`` might not exist
in the checkpoint.

.. _layout-control:

Layout Control
--------------

The ``mmap`` argument in :func:`torch.load` allows for lazy loading of tensor storages, as
sketched below. In addition, there are some advanced features that allow for more fine-grained
control and manipulation of a ``torch.save`` checkpoint.
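
A minimal sketch of lazy loading via ``mmap`` (the file name here is illustrative):

.. code-block:: python

    import torch

    sd = {"weight": torch.randn(1000, 1000)}
    torch.save(sd, "checkpoint.pt")

    # With mmap=True, the file is memory-mapped rather than read into CPU
    # memory up front; storage bytes are paged in when first accessed.
    loaded = torch.load("checkpoint.pt", mmap=True)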

The :class:`torch.serialization.skip_data` context manager enables

* Saving a checkpoint with ``torch.save`` that includes empty space for data bytes
  to be written later.
* Loading a checkpoint with ``torch.load`` and filling in the data bytes of tensors
  later (see the sketch after this list).
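
A minimal sketch of both, assuming a single-tensor checkpoint (the file name is illustrative):

.. code-block:: python

    import torch

    t = torch.randn(2, 3)

    # Serialize metadata only; the checkpoint contains reserved, empty
    # space where the storage bytes would otherwise be written.
    with torch.serialization.skip_data():
        torch.save(t, "reserved.pt")

    # Load the checkpoint without reading storage bytes; the returned
    # tensor's data is uninitialized until filled in separately.
    with torch.serialization.skip_data():
        loaded = torch.load("reserved.pt", weights_only=True)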

To inspect tensor metadata in a ``torch.save`` checkpoint without allocating memory for the
storage data, use ``torch.load`` within the ``FakeTensorMode`` context manager. Like
``skip_data`` above, this skips loading storage data; in addition, it tags each storage with its
offset within the checkpoint, enabling direct checkpoint manipulation.

.. code-block:: python

    import torch
    import torch.nn as nn
    from torch._subclasses.fake_tensor import FakeTensorMode

    m = nn.Linear(10, 10)
    torch.save(m.state_dict(), "checkpoint.pt")

    with FakeTensorMode() as mode:
        fake_sd = torch.load("checkpoint.pt")

    for k, v in fake_sd.items():
        print(f"key={k}, dtype={v.dtype}, shape={v.shape}, stride={v.stride()}, storage_offset={v.storage_offset()}")
        # offset of the storage within the checkpoint
        print(f"key={k}, checkpoint_offset={v.untyped_storage()._checkpoint_offset}")

For more information, `this tutorial <https://docs.pytorch.org/tutorials/prototype/gpu_direct_storage.html>`_
offers a comprehensive example of using these features to manipulate a checkpoint.

.. _weights-only:

``torch.load`` with ``weights_only=True``


@@ -924,6 +924,8 @@ def save(
See also: :ref:`saving-loading-tensors`

See :ref:`layout-control` for more advanced tools to manipulate a checkpoint.

Args:
obj: saved object
f: a file-like object (has to implement write and flush) or a string or
@@ -1313,6 +1315,8 @@ def load(
User extensions can register their own location tags and tagging and
deserialization methods using :func:`torch.serialization.register_package`.
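
As a rough sketch, following the pattern in the :func:`torch.serialization.register_package`
docs (the ``ipu`` backend here is hypothetical):

.. code-block:: python

    import torch

    def ipu_tag(obj):
        # Tag storages living on the hypothetical device; returning None
        # defers to the remaining registered taggers.
        if obj.device.type == "ipu":
            return "ipu"

    def ipu_deserialize(obj, location):
        # Restore storages tagged above; returning None defers to the
        # remaining registered deserializers.
        if location.startswith("ipu"):
            return obj.to(device="ipu")

    # Lower priority values are consulted before higher ones.
    torch.serialization.register_package(11, ipu_tag, ipu_deserialize)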

See :ref:`layout-control` for more advanced tools to manipulate a checkpoint.

Args:
f: a file-like object (has to implement :meth:`read`, :meth:`readline`, :meth:`tell`, and :meth:`seek`),
or a string or os.PathLike object containing a file name
@@ -1328,7 +1332,8 @@ def load(
Typically, tensor storages in the file will first be moved from disk to CPU memory, after which they
are moved to the location that they were tagged with when saving, or specified by ``map_location``. This
second step is a no-op if the final location is CPU. When the ``mmap`` flag is set, instead of copying the
tensor storages from disk to CPU memory in the first step, ``f`` is mmaped, which means tensor storages
will be lazily loaded when their data is accessed.
pickle_load_args: (Python 3 only) optional keyword arguments passed over to
:func:`pickle_module.load` and :func:`pickle_module.Unpickler`, e.g.,
:attr:`errors=...`.