Update serialization docs (#153631)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/153631 Approved by: https://github.com/albanD
committed by PyTorch MergeBot
parent 2fcbb903cb
commit 6383ddcfa4
@@ -215,6 +215,46 @@ is 64-byte aligned.

such, their storages are not serialized. In these cases ``data/`` might not exist
in the checkpoint.

.. _layout-control:

Layout Control
--------------

The ``mmap`` argument in :func:`torch.load` allows for lazy loading of tensor storages.
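
As a rough illustration (not part of the diff; the checkpoint name and tensor below are invented),
lazy loading with ``mmap=True`` might look like this:

.. code-block:: python

    import torch

    # Save a checkpoint the usual way.
    torch.save({"weight": torch.randn(1000, 1000)}, "big_checkpoint.pt")

    # With mmap=True the file is memory-mapped rather than read into CPU
    # memory up front; tensor bytes are paged in lazily when accessed.
    sd = torch.load("big_checkpoint.pt", mmap=True, weights_only=True)
    print(sd["weight"].sum())  # touching the data pulls the pages in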

In addition, there are some advanced features that allow for more fine-grained
control and manipulation of a ``torch.save`` checkpoint.

The :class:`torch.serialization.skip_data` context manager enables

* Saving a checkpoint with ``torch.save`` that includes empty space for data bytes
  to be written later.
* Loading a checkpoint with ``torch.load`` and filling in the data bytes of tensors later.
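
As a minimal sketch of the saving side (the state dict and file name are invented for
illustration), reserving space for data bytes without writing them might look like:

.. code-block:: python

    import torch

    sd = {"weight": torch.randn(10, 10), "bias": torch.randn(10)}

    # Under skip_data, torch.save writes the checkpoint structure and
    # reserves (but does not fill) the space for each storage's bytes,
    # so the data can be written into the file later.
    with torch.serialization.skip_data():
        torch.save(sd, "skeleton.pt")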

To inspect tensor metadata in a ``torch.save`` checkpoint without allocating memory for storage
data, use ``torch.load`` within the ``FakeTensorMode`` context manager. In addition to skipping
the loading of storage data (as ``skip_data`` does), it tags each storage with its offset within
the checkpoint, enabling direct checkpoint manipulation.

.. code-block:: python

    import torch
    import torch.nn as nn
    from torch._subclasses.fake_tensor import FakeTensorMode

    m = nn.Linear(10, 10)
    torch.save(m.state_dict(), "checkpoint.pt")

    with FakeTensorMode() as mode:
        fake_sd = torch.load("checkpoint.pt")

    for k, v in fake_sd.items():
        print(f"key={k}, dtype={v.dtype}, shape={v.shape}, stride={v.stride()}, storage_offset={v.storage_offset()}")
        # offset of the storage in the checkpoint
        print(f"key={k}, checkpoint_offset={v.untyped_storage()._checkpoint_offset}")

For more information, `this tutorial <https://docs.pytorch.org/tutorials/prototype/gpu_direct_storage.html>`_
offers a comprehensive example of using these features to manipulate a checkpoint.

.. _weights-only:

``torch.load`` with ``weights_only=True``

@@ -924,6 +924,8 @@ def save(

    See also: :ref:`saving-loading-tensors`

    See :ref:`layout-control` for more advanced tools to manipulate a checkpoint.

    Args:
        obj: saved object
        f: a file-like object (has to implement write and flush) or a string or

@@ -1313,6 +1315,8 @@ def load(

    User extensions can register their own location tags and tagging and
    deserialization methods using :func:`torch.serialization.register_package`.

    See :ref:`layout-control` for more advanced tools to manipulate a checkpoint.

    Args:
        f: a file-like object (has to implement :meth:`read`, :meth:`readline`, :meth:`tell`, and :meth:`seek`),
            or a string or os.PathLike object containing a file name

@@ -1328,7 +1332,8 @@ def load(

            Typically, tensor storages in the file will first be moved from disk to CPU memory, after which they
            are moved to the location that they were tagged with when saving, or specified by ``map_location``. This
            second step is a no-op if the final location is CPU. When the ``mmap`` flag is set, instead of copying the
            tensor storages from disk to CPU memory in the first step, ``f`` is mmaped, which means tensor storages
            will be lazily loaded when their data is accessed.
        pickle_load_args: (Python 3 only) optional keyword arguments passed over to
            :func:`pickle_module.load` and :func:`pickle_module.Unpickler`, e.g.,
            :attr:`errors=...`.