diff --git a/docs/source/notes/serialization.rst b/docs/source/notes/serialization.rst index 019865e3b535..7834563f26e4 100644 --- a/docs/source/notes/serialization.rst +++ b/docs/source/notes/serialization.rst @@ -215,6 +215,46 @@ is 64-byte aligned. such, their storages are not serialized. In these cases ``data/`` might not exist in the checkpoint. +.. _layout-control: + +Layout Control +-------------- + +The ``mmap`` argument in :func:`torch.load` allows for lazy loading of tensor storages. + +In addition, there are some advanced features that allow for more fine-grained +control and manipulation of a ``torch.save`` checkpoint. + +The :class:`torch.serialization.skip_data` context manager enables + * Saving a checkpoint with ``torch.save`` that includes empty space for data bytes + to be written later. + * Loading a checkpoint with ``torch.load`` and filling in the data bytes of tensors later. + +To inspect tensor metadata in a ``torch.save`` checkpoint without allocating memory for storage +data, use ``torch.load`` within the ``FakeTensorMode`` context manager. On top of skipping loading +storage data similar to ``skip_data`` above, it additionally tags storages with their offset within +the checkpoint, enabling direct checkpoint manipulation. + +.. code-block:: python + + import torch.nn as nn + from torch._subclasses.fake_tensor import FakeTensorMode + + m = nn.Linear(10, 10) + torch.save(m.state_dict(), "checkpoint.pt") + + with FakeTensorMode() as mode: + fake_sd = torch.load("checkpoint.pt") + + for k, v in fake_sd.items(): + print(f"key={k}, dtype={v.dtype}, shape={v.shape}, stride={v.stride()}, storage_offset={v.storage_offset()}") + # offset of the storage in the checkpoint + print(f"key={k}, checkpoint_offset={v.untyped_storage()._checkpoint_offset}") + +For more information, `this tutorial `_ +offers a comprehensive example of using these features to manipulate a checkpoint. + + .. _weights-only: ``torch.load`` with ``weights_only=True`` diff --git a/torch/serialization.py b/torch/serialization.py index 8dbd510b0b49..5ad421437518 100644 --- a/torch/serialization.py +++ b/torch/serialization.py @@ -924,6 +924,8 @@ def save( See also: :ref:`saving-loading-tensors` + See :ref:`layout-control` for more advanced tools to manipulate a checkpoint. + Args: obj: saved object f: a file-like object (has to implement write and flush) or a string or @@ -1313,6 +1315,8 @@ def load( User extensions can register their own location tags and tagging and deserialization methods using :func:`torch.serialization.register_package`. + See :ref:`layout-control` for more advanced tools to manipulate a checkpoint. + Args: f: a file-like object (has to implement :meth:`read`, :meth:`readline`, :meth:`tell`, and :meth:`seek`), or a string or os.PathLike object containing a file name @@ -1328,7 +1332,8 @@ def load( Typically, tensor storages in the file will first be moved from disk to CPU memory, after which they are moved to the location that they were tagged with when saving, or specified by ``map_location``. This second step is a no-op if the final location is CPU. When the ``mmap`` flag is set, instead of copying the - tensor storages from disk to CPU memory in the first step, ``f`` is mmaped. + tensor storages from disk to CPU memory in the first step, ``f`` is mmaped, which means tensor storages + will be lazily loaded when their data is accessed. pickle_load_args: (Python 3 only) optional keyword arguments passed over to :func:`pickle_module.load` and :func:`pickle_module.Unpickler`, e.g., :attr:`errors=...`.