Serialization semantics
=======================

This note describes how you can save and load PyTorch tensors and module states
in Python, and how to serialize Python modules so they can be loaded in C++.

.. contents:: Table of Contents

.. _saving-loading-tensors:

Saving and loading tensors
--------------------------

:func:`torch.save` and :func:`torch.load` let you easily save and load tensors:

::

    >>> t = torch.tensor([1., 2.])
    >>> torch.save(t, 'tensor.pt')
    >>> torch.load('tensor.pt')
    tensor([1., 2.])

By convention, PyTorch files are typically written with a '.pt' or '.pth' extension.

:func:`torch.save` and :func:`torch.load` use Python's pickle by default,
so you can also save multiple tensors as part of Python objects like tuples,
lists, and dicts:

::

    >>> d = {'a': torch.tensor([1., 2.]), 'b': torch.tensor([3., 4.])}
    >>> torch.save(d, 'tensor_dict.pt')
    >>> torch.load('tensor_dict.pt')
    {'a': tensor([1., 2.]), 'b': tensor([3., 4.])}

Custom data structures that include PyTorch tensors can also be saved if the
data structure is pickle-able.
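
For example (an illustrative sketch; ``TaggedTensors`` is a hypothetical class),
any pickle-able class whose attributes hold tensors can be passed to
:func:`torch.save` directly; note that the class must be importable when the
file is later loaded:

::

    >>> class TaggedTensors:  # hypothetical container; any pickle-able class works
            def __init__(self, tag, tensors):
                self.tag = tag
                self.tensors = tensors

    >>> torch.save(TaggedTensors('weights', [torch.ones(2)]), 'tagged.pt')
    >>> torch.load('tagged.pt').tag
    'weights'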

.. _preserve-storage-sharing:

Saving and loading tensors preserves views
------------------------------------------

Saving tensors preserves their view relationships:

::

    >>> numbers = torch.arange(1, 10)
    >>> evens = numbers[1::2]
    >>> torch.save([numbers, evens], 'tensors.pt')
    >>> loaded_numbers, loaded_evens = torch.load('tensors.pt')
    >>> loaded_evens *= 2
    >>> loaded_numbers
    tensor([ 1,  4,  3,  8,  5, 12,  7, 16,  9])

Behind the scenes, these tensors share the same "storage." See
`Tensor Views <https://pytorch.org/docs/main/tensor_view.html>`_ for more
on views and storage.
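
One quick way to confirm the sharing (a minimal check, not part of the original
snippet) is to compare the loaded tensors' storage data pointers:

::

    >>> loaded_numbers.untyped_storage().data_ptr() == loaded_evens.untyped_storage().data_ptr()
    True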

When PyTorch saves tensors it saves their storage objects and tensor
metadata separately. This is an implementation detail that may change in the
future, but it typically saves space and lets PyTorch easily
reconstruct the view relationships between the loaded tensors. In the above
snippet, for example, only a single storage is written to 'tensors.pt'.

In some cases, however, saving the current storage objects may be unnecessary
and create prohibitively large files. In the following snippet a storage much
larger than the saved tensor is written to a file:

::

    >>> large = torch.arange(1, 1000)
    >>> small = large[0:5]
    >>> torch.save(small, 'small.pt')
    >>> loaded_small = torch.load('small.pt')
    >>> loaded_small.storage().size()
    999

Instead of saving only the five values in the ``small`` tensor to 'small.pt',
the 999 values in the storage it shares with ``large`` were saved and loaded.

When saving tensors with fewer elements than their storage objects, the size of
the saved file can be reduced by first cloning the tensors. Cloning a tensor
produces a new tensor with a new storage object containing only the values
in the tensor:

::

    >>> large = torch.arange(1, 1000)
    >>> small = large[0:5]
    >>> torch.save(small.clone(), 'small.pt')  # saves a clone of small
    >>> loaded_small = torch.load('small.pt')
    >>> loaded_small.storage().size()
    5

Since the cloned tensors are independent of each other, however, they have
none of the view relationships the original tensors did. If both file size and
view relationships are important when saving tensors smaller than their
storage objects, then care must be taken to construct new tensors that minimize
the size of their storage objects but still have the desired view relationships
before saving, as sketched below.
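
One minimal sketch of such a construction (illustrative, not an official recipe):
clone a compact base tensor covering just the needed elements, then re-derive the
views from the clone before saving:

::

    >>> large = torch.arange(1, 1000)
    >>> compact = large[0:5].clone()  # new storage with only the 5 needed values
    >>> torch.save([compact, compact[2:5]], 'compact_views.pt')  # view re-derived from the clone
    >>> a, b = torch.load('compact_views.pt')
    >>> a.storage().size()  # only the 5 needed values were saved
    5
    >>> b.untyped_storage().data_ptr() == a.untyped_storage().data_ptr()  # view relationship kept
    True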

.. _saving-loading-python-modules:

Saving and loading torch.nn.Modules
-----------------------------------

See also: `Tutorial: Saving and loading modules <https://pytorch.org/tutorials/beginner/saving_loading_models.html>`_

In PyTorch, a module's state is frequently serialized using a 'state dict.'
A module's state dict contains all of its parameters and persistent buffers:

::

    >>> bn = torch.nn.BatchNorm1d(3, track_running_stats=True)
    >>> list(bn.named_parameters())
    [('weight', Parameter containing: tensor([1., 1., 1.], requires_grad=True)),
     ('bias', Parameter containing: tensor([0., 0., 0.], requires_grad=True))]

    >>> list(bn.named_buffers())
    [('running_mean', tensor([0., 0., 0.])),
     ('running_var', tensor([1., 1., 1.])),
     ('num_batches_tracked', tensor(0))]

    >>> bn.state_dict()
    OrderedDict([('weight', tensor([1., 1., 1.])),
                 ('bias', tensor([0., 0., 0.])),
                 ('running_mean', tensor([0., 0., 0.])),
                 ('running_var', tensor([1., 1., 1.])),
                 ('num_batches_tracked', tensor(0))])

For compatibility reasons, it is recommended to save a module's state dict
rather than the module itself. Modules even have a method,
:meth:`~torch.nn.Module.load_state_dict`, to restore their states from a state dict:

::

    >>> torch.save(bn.state_dict(), 'bn.pt')
    >>> bn_state_dict = torch.load('bn.pt')
    >>> new_bn = torch.nn.BatchNorm1d(3, track_running_stats=True)
    >>> new_bn.load_state_dict(bn_state_dict)
    <All keys matched successfully>

Note that the state dict is first loaded from its file with :func:`torch.load`
and the state then restored with :meth:`~torch.nn.Module.load_state_dict`.

Even custom modules and modules containing other modules have state dicts and
can use this pattern:

::

    # A module with two linear layers
    >>> class MyModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.l0 = torch.nn.Linear(4, 2)
                self.l1 = torch.nn.Linear(2, 1)

            def forward(self, input):
                out0 = self.l0(input)
                out0_relu = torch.nn.functional.relu(out0)
                return self.l1(out0_relu)

    >>> m = MyModule()
    >>> m.state_dict()
    OrderedDict([('l0.weight', tensor([[ 0.1400,  0.4563, -0.0271, -0.4406],
                                       [-0.3289,  0.2827,  0.4588,  0.2031]])),
                 ('l0.bias', tensor([ 0.0300, -0.1316])),
                 ('l1.weight', tensor([[0.6533, 0.3413]])),
                 ('l1.bias', tensor([-0.1112]))])

    >>> torch.save(m.state_dict(), 'mymodule.pt')
    >>> m_state_dict = torch.load('mymodule.pt')
    >>> new_m = MyModule()
    >>> new_m.load_state_dict(m_state_dict)
    <All keys matched successfully>

.. _serialized-file-format:

Serialized file format for ``torch.save``
-----------------------------------------

Since PyTorch 1.6.0, ``torch.save`` defaults to writing an uncompressed ZIP64
archive unless the user sets ``_use_new_zipfile_serialization=False``.

In this archive, the files are ordered as follows:

.. code-block:: text

    checkpoint.pth
    ├── data.pkl
    ├── byteorder  # added in PyTorch 2.1.0
    ├── data/
    │   ├── 0
    │   ├── 1
    │   ├── 2
    │   └── …
    └── version

The entries are as follows:

* ``data.pkl`` is the result of pickling the object passed to ``torch.save``,
  excluding any ``torch.Storage`` objects that it contains
* ``byteorder`` contains a string with the ``sys.byteorder`` value when saving ("little" or "big")
* ``data/`` contains all the storages in the object, where each storage is a separate file
* ``version`` contains a version number at save time that can be used at load time

When saving, PyTorch ensures that the local file header of each file is padded
to an offset that is a multiple of 64 bytes, so that the offset of each file's
data is 64-byte aligned.
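
Because the result is a standard ZIP archive, its layout can be inspected with
ordinary tools. A minimal sketch using Python's ``zipfile`` module (assuming a
checkpoint saved as ``'mymodule.pt'`` above; entry names may be nested under a
top-level prefix inside the archive):

::

    >>> import zipfile
    >>> zipfile.ZipFile('mymodule.pt').namelist()  # lists data.pkl, data/*, version, ...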

.. note::
    Tensors on certain devices such as XLA are serialized as pickled numpy arrays. As
    such, their storages are not serialized. In these cases ``data/`` might not exist
    in the checkpoint.

.. _serializing-python-modules:

Serializing torch.nn.Modules and loading them in C++
----------------------------------------------------

See also: `Tutorial: Loading a TorchScript Model in C++ <https://pytorch.org/tutorials/advanced/cpp_export.html>`_

ScriptModules can be serialized as a TorchScript program and loaded
using :func:`torch.jit.load`.
This serialization encodes all the modules' methods, submodules, parameters,
and attributes, and it allows the serialized program to be loaded in C++
(i.e. without Python).

The distinction between :func:`torch.jit.save` and :func:`torch.save` may not
be immediately clear. :func:`torch.save` saves Python objects with pickle.
This is especially useful for prototyping, researching, and training.
:func:`torch.jit.save`, on the other hand, serializes ScriptModules to a format
that can be loaded in Python or C++. This is useful when saving and loading C++
modules or for running modules trained in Python with C++, a common practice
when deploying PyTorch models.

To script, serialize, and load a module in Python:

::

    >>> scripted_module = torch.jit.script(MyModule())
    >>> torch.jit.save(scripted_module, 'mymodule.pt')
    >>> torch.jit.load('mymodule.pt')
    RecursiveScriptModule( original_name=MyModule
                           (l0): RecursiveScriptModule(original_name=Linear)
                           (l1): RecursiveScriptModule(original_name=Linear) )

Traced modules can also be saved with :func:`torch.jit.save`, with the caveat
that only the traced code path is serialized. The following example demonstrates
this:

::

    # A module with control flow
    >>> class ControlFlowModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.l0 = torch.nn.Linear(4, 2)
                self.l1 = torch.nn.Linear(2, 1)

            def forward(self, input):
                if input.dim() > 1:
                    return torch.tensor(0)

                out0 = self.l0(input)
                out0_relu = torch.nn.functional.relu(out0)
                return self.l1(out0_relu)

    >>> traced_module = torch.jit.trace(ControlFlowModule(), torch.randn(4))
    >>> torch.jit.save(traced_module, 'controlflowmodule_traced.pt')
    >>> loaded = torch.jit.load('controlflowmodule_traced.pt')
    >>> loaded(torch.randn(2, 4))
    tensor([[-0.1571], [-0.3793]], grad_fn=<AddBackward0>)

    >>> scripted_module = torch.jit.script(ControlFlowModule())
    >>> torch.jit.save(scripted_module, 'controlflowmodule_scripted.pt')
    >>> loaded = torch.jit.load('controlflowmodule_scripted.pt')
    >>> loaded(torch.randn(2, 4))
    tensor(0)

The above module has an if statement that is not triggered by the traced inputs,
and so is not part of the traced module and is not serialized with it.
The scripted module, however, contains the if statement and is serialized with it.
See the `TorchScript documentation <https://pytorch.org/docs/stable/jit.html>`_
for more on scripting and tracing.

Finally, to load the module in C++:

.. code-block:: cpp

    torch::jit::script::Module module;
    module = torch::jit::load("controlflowmodule_scripted.pt");

See the `PyTorch C++ API documentation <https://pytorch.org/cppdocs/>`_
for details about how to use PyTorch modules in C++.

.. _saving-loading-across-versions:

Saving and loading ScriptModules across PyTorch versions
---------------------------------------------------------

The PyTorch Team recommends saving and loading modules with the same version of
PyTorch. Older versions of PyTorch may not support newer modules, and newer
versions may have removed or modified older behavior. These changes are
explicitly described in
PyTorch's `release notes <https://github.com/pytorch/pytorch/releases>`_,
and modules relying on functionality that has changed may need to be updated
to continue working properly. In limited cases, detailed below, PyTorch will
preserve the historic behavior of serialized ScriptModules so they do not require
an update.

torch.div performing integer division
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

In PyTorch 1.5 and earlier :func:`torch.div` would perform floor division when
given two integer inputs:

::

    # PyTorch 1.5 (and earlier)
    >>> a = torch.tensor(5)
    >>> b = torch.tensor(3)
    >>> a / b
    tensor(1)

In PyTorch 1.7 and later, however, :func:`torch.div` always performs a true division
of its inputs, just like division in Python 3:

::

    # PyTorch 1.7
    >>> a = torch.tensor(5)
    >>> b = torch.tensor(3)
    >>> a / b
    tensor(1.6667)

The behavior of :func:`torch.div` is preserved in serialized ScriptModules.
That is, ScriptModules serialized with versions of PyTorch before 1.6 will continue
to see :func:`torch.div` perform floor division when given two integer inputs,
even when loaded with newer versions of PyTorch. ScriptModules using :func:`torch.div`
and serialized on PyTorch 1.6 and later cannot be loaded in earlier versions of
PyTorch, however, since those earlier versions do not understand the new behavior.
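
Eager-mode code that still needs the old integer behavior can request it
explicitly with the ``rounding_mode`` argument (available since PyTorch 1.8):

::

    >>> torch.div(torch.tensor(5), torch.tensor(3), rounding_mode='floor')
    tensor(1)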

torch.full always inferring a float dtype
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

In PyTorch 1.5 and earlier :func:`torch.full` always returned a float tensor,
regardless of the fill value it's given:

::

    # PyTorch 1.5 and earlier
    >>> torch.full((3,), 1)  # Note the integer fill value...
    tensor([1., 1., 1.])     # ...but float tensor!

In PyTorch 1.7 and later, however, :func:`torch.full` infers the returned tensor's
dtype from the fill value:

::

    # PyTorch 1.7
    >>> torch.full((3,), 1)
    tensor([1, 1, 1])

    >>> torch.full((3,), True)
    tensor([True, True, True])

    >>> torch.full((3,), 1.)
    tensor([1., 1., 1.])

    >>> torch.full((3,), 1 + 1j)
    tensor([1.+1.j, 1.+1.j, 1.+1.j])

The behavior of :func:`torch.full` is preserved in serialized ScriptModules. That is,
ScriptModules serialized with versions of PyTorch before 1.6 will continue to see
:func:`torch.full` return float tensors by default, even when given bool or
integer fill values. ScriptModules using :func:`torch.full` and serialized on PyTorch 1.6
and later cannot be loaded in earlier versions of PyTorch, however, since those
earlier versions do not understand the new behavior.
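
Code that must produce the same dtype on both sides of this change can pass
``dtype`` explicitly instead of relying on inference:

::

    >>> torch.full((3,), 1, dtype=torch.float)
    tensor([1., 1., 1.])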

.. _utility functions:

Utility functions
-----------------

The following utility functions are related to serialization:

.. currentmodule:: torch.serialization

.. autofunction:: register_package
.. autofunction:: get_default_load_endianness
.. autofunction:: set_default_load_endianness
.. autofunction:: get_default_mmap_options
.. autofunction:: set_default_mmap_options
.. autofunction:: add_safe_globals
.. autofunction:: clear_safe_globals
.. autofunction:: get_safe_globals
.. autoclass:: safe_globals
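
For example, the ``safe_globals`` context manager can be used to temporarily
allowlist a user-defined class during a ``weights_only`` load (a minimal sketch;
``MyTensorContainer`` is a hypothetical class):

::

    >>> import torch
    >>> from torch.serialization import safe_globals
    >>> class MyTensorContainer:  # hypothetical pickle-able container
            def __init__(self, t):
                self.t = t

    >>> torch.save(MyTensorContainer(torch.ones(2)), 'container.pt')
    >>> with safe_globals([MyTensorContainer]):  # allowlisted only inside this block
            loaded = torch.load('container.pt', weights_only=True)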