Add torch.serialization.skip_data context manager (#134504)

## Semantics

The semantics are as follows:

(1) By default, `torch.serialization.skip_data(materialize_fake_tensors=False)` makes `torch.save` skip writing storage bytes (but reserve space for them in the checkpoint).

```python
import torch
import torch.nn as nn

sd = nn.Linear(3, 5).state_dict()
with torch.serialization.skip_data():
    torch.save(sd, 'foo.pt')
print(torch.load('foo.pt', weights_only=True))
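# The loaded tensors come back zero-filled, since only space was reserved for
# their bytes: 'weight' is a 5x3 tensor of zeros and 'bias' a 5-element zero tensor.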
```
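
A mechanism for retrieving the offsets of the reserved storage regions is still a follow-up (see below), but one way to convince yourself that space really is reserved is to list the checkpoint's zip records. This is a minimal illustrative sketch, not part of this PR; it assumes the archive prefix matches the file stem (`foo`) and that the standard-library `zipfile` module can list records written by PyTorch's zip writer.

```python
import zipfile

import torch
import torch.nn as nn

sd = nn.Linear(3, 5).state_dict()
with torch.serialization.skip_data():
    torch.save(sd, 'foo.pt')

# Storage records (e.g. 'foo/data/<key>') are listed with their full byte
# sizes even though no storage bytes were actually written to them.
with zipfile.ZipFile('foo.pt') as zf:
    for info in zf.infolist():
        print(info.filename, info.file_size)
```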

(2) With `torch.serialization.skip_data(materialize_fake_tensors=True)`, if FakeTensors are passed to `torch.save`, the pickler treats them as "materialized": space is reserved in the checkpoint for the associated storage bytes, and when loading, the type will be `Tensor` instead of `FakeTensor`.

```python
import torch
import torch.nn as nn
from torch._subclasses.fake_tensor import FakeTensorMode

with FakeTensorMode():
    m = nn.Linear(3, 5, dtype=torch.float16, device='cuda')

sd = m.state_dict()
with torch.serialization.skip_data(materialize_fake_tensors=True):
    torch.save(sd, 'bla.pt')
print(torch.load('bla.pt', weights_only=True))
# OrderedDict([('weight', tensor([[0., 0., 0.],
#        [0., 0., 0.],
#        [0., 0., 0.],
#        [0., 0., 0.],
#        [0., 0., 0.]], device='cuda:0', dtype=torch.float16)), ('bias', tensor([0., 0., 0., 0., 0.], device='cuda:0', dtype=torch.float16))])

```
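
In effect, (2) allows a checkpoint with the correct metadata (shapes, dtypes, devices) and reserved storage regions to be produced from a model built entirely under `FakeTensorMode`, without allocating any real device memory; the actual bytes can then be filled in separately (see the follow-ups below).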

## Follow Ups

- [ ] `torch.load` semantics for the `skip_data` context manager
- [ ] Mechanism for getting offsets of storages saved via this method (for writing in a separate pass)

Differential Revision: [D62238610](https://our.internmc.facebook.com/intern/diff/D62238610)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/134504
Approved by: https://github.com/albanD
Mikayla Gawarecki
2024-09-05 06:50:12 -07:00
committed by PyTorch MergeBot
parent dbeb8a1691
commit a096f2899d
7 changed files with 270 additions and 26 deletions

View File

@ -398,3 +398,4 @@ The following utility functions are related to serialization:
.. autofunction:: clear_safe_globals
.. autofunction:: get_safe_globals
.. autoclass:: safe_globals
.. autoclass:: skip_data

View File

@ -540,7 +540,7 @@ class TestCppExtensionOpenRgistration(common.TestCase):
# call _fused_adamw_ with undefined tensor.
self.module.fallback_with_undefined_tensor()
def test_open_device_numpy_serialization_map_location(self):
def test_open_device_numpy_serialization(self):
torch.utils.rename_privateuse1_backend("foo")
device = self.module.custom_device()
default_protocol = torch.serialization.DEFAULT_PROTOCOL
@ -553,6 +553,7 @@ class TestCppExtensionOpenRgistration(common.TestCase):
self.assertTrue(
rebuild_func is torch._utils._rebuild_device_tensor_from_numpy
)
# Test map_location
with TemporaryFileName() as f:
torch.save(sd, f)
with safe_globals(
@ -569,6 +570,15 @@ class TestCppExtensionOpenRgistration(common.TestCase):
sd_loaded = torch.load(f, map_location="cpu")
self.assertTrue(sd_loaded["x"].is_cpu)
# Test metadata_only
with TemporaryFileName() as f:
with self.assertRaisesRegex(
RuntimeError,
"Cannot serialize tensors on backends with no storage under skip_data context manager",
):
with torch.serialization.skip_data():
torch.save(sd, f)
if __name__ == "__main__":
common.run_tests()

View File

@ -1,5 +1,6 @@
# Owner(s): ["module: serialization"]
import contextlib
import copy
import gc
import gzip
@ -19,6 +20,7 @@ from itertools import product
from pathlib import Path
import torch
from torch._subclasses.fake_tensor import FakeTensorMode, FakeTensorConverter
from torch._utils import _rebuild_tensor
from torch._utils_internal import get_file_path_2
from torch.serialization import (
@ -27,6 +29,7 @@ from torch.serialization import (
LoadEndianness,
safe_globals,
set_default_load_endianness,
skip_data,
SourceChangeWarning,
)
from torch.testing._internal.common_device_type import instantiate_device_type_tests
@ -4212,6 +4215,91 @@ class TestSerialization(TestCase, SerializationMixin):
sd_loaded_ref = torch.load(f)
self.assertEqual(sd_loaded, sd_loaded_ref)
@parametrize("materialize_fake", (True, False))
def test_skip_data_serialization(self, materialize_fake):
# Create one tensor that uses each of the paths in __reduce_ex__ that should work
t_device = "cuda" if torch.cuda.is_available() else "cpu"
t_v2 = torch.randn(2, 3, device=t_device)
t_v3 = torch.randn(2, 3, dtype=torch.complex32, device=t_device)
i = torch.tensor([[0, 1, 1],
[2, 0, 2]])
v = torch.tensor([3, 4, 5], dtype=torch.float32)
if not materialize_fake:
# FakeTensorConverter messes up sizes of i and v for the sparse tensor
st = torch.sparse_coo_tensor(i, v, (2, 4))
tt = TwoTensor(torch.randn(2, device=t_device), torch.randn(2, device=t_device))
mode, converter = FakeTensorMode(), FakeTensorConverter()
def fn(t):
return converter.from_real_tensor(mode, t) if materialize_fake else t
sd = {'t_v2': fn(t_v2), 't_v3': fn(t_v3), 'tt': fn(tt)}
sd_expected = {
't_v2': torch.zeros(2, 3, device=t_device),
't_v3': torch.zeros(2, 3, dtype=torch.complex32, device=t_device),
'tt': TwoTensor(torch.zeros(2, device=t_device), torch.zeros(2, device=t_device)),
}
if not materialize_fake:
sd['st'] = st
sd_expected['st'] = torch.sparse_coo_tensor(torch.zeros(2, 3), torch.zeros(3), (2, 4))
with BytesIOContext() as f:
with skip_data(materialize_fake_tensors=materialize_fake):
torch.save(sd, f)
f.seek(0)
with safe_globals([TwoTensor]):
sd_loaded = torch.load(f, weights_only=True)
self.assertEqual(sd_loaded, sd_expected, exact_device=True)
self.assertFalse(getattr(torch.serialization._serialization_tls, "materialize_fake_tensors", False))
self.assertFalse(getattr(torch.serialization._serialization_tls, "skip_data", False))
# Test that without materialize_fake_tensor, behavior for fake_tensors is not altered by ctx
if not materialize_fake:
ft = converter.from_real_tensor(mode, torch.randn(2, device=t_device))
with self.assertRaisesRegex(AttributeError, "Can't pickle local object 'WeakValueDictionary.__init__.<locals>.remove'"):
with skip_data(), BytesIOContext() as f:
torch.save(ft, f)
@parametrize("materialize_fake", (True, False))
def test_skip_data_serialization_preserves_views(self, materialize_fake):
ctx = FakeTensorMode if materialize_fake else contextlib.nullcontext
with ctx():
t = torch.randn(2, 3)
t_view = t.view(-1)
t_slice = t[1]
sd = {'t': t, 't_view': t_view, 't_slice': t_slice}
with BytesIOContext() as f:
with skip_data(materialize_fake_tensors=materialize_fake):
torch.save(sd, f)
f.seek(0)
sd_loaded = torch.load(f, weights_only=True)
self.assertTrue(id(sd_loaded['t_view'].untyped_storage()) == id(sd_loaded['t'].untyped_storage()))
self.assertTrue(id(sd_loaded['t_slice'].untyped_storage()) == id(sd_loaded['t'].untyped_storage()))
def test_skip_data_serialization_error_cases(self):
def _save_load(t):
with BytesIOContext() as f:
with skip_data():
torch.save(t, f)
f.seek(0)
torch.load(f, weights_only=True)
nt = torch.nested.nested_tensor([torch.randn(2), torch.randn(3)])
t = torch.randn(2, 3, device="meta")
with self.assertRaisesRegex(RuntimeError, "Cannot serialize nested tensor under skip_data context manager"):
_save_load(nt)
with self.assertWarnsRegex(UserWarning, "meta device under skip_data context manager is a no-op"):
_save_load(t)
with self.assertRaisesRegex(RuntimeError, "Please call torch.load outside the skip_data context manager"):
with skip_data(), BytesIOContext() as f:
torch.save(torch.randn(2, 3), f)
f.seek(0)
torch.load(f, weights_only=True)
def run(self, *args, **kwargs):
with serialization_method(use_zip=True):
return super().run(*args, **kwargs)

View File

@ -209,8 +209,19 @@ class Tensor(torch._C.TensorBase):
return new_tensor
def __reduce_ex__(self, proto):
materialize_fake_tensors = (
torch.serialization._serialization_tls.materialize_fake_tensors
)
state = torch._utils._get_obj_state(self)
if type(self) is Tensor and not state:
# Ignore all state when using FakeTensor with skip_data(materialize_fake_tensors) because FakeTensor has
# some state that cannot be pickled
if (
# TODO: remove hasattr, it's a hack to support versions of torch that
# don't have _subclasses
hasattr(torch, "_subclasses")
and type(self) is torch._subclasses.fake_tensor.FakeTensor
and materialize_fake_tensors
) or (type(self) is Tensor and not state):
# Fast path for regular tensor without Python state.
return self._reduce_ex_internal(proto)
if has_torch_function_unary(self):
@ -251,6 +262,12 @@ class Tensor(torch._C.TensorBase):
# See Note [Don't serialize hooks]
warn_if_has_hooks(self)
backward_hooks: Dict[Any, Any] = OrderedDict()
skip_data = torch.serialization._serialization_tls.skip_data
materialize_fake_tensors = (
torch.serialization._serialization_tls.materialize_fake_tensors
)
# Note: Numpy array is chosen to be the rebuild component for XLA, MTIA, MAIA Tensors.
# We considered a few options:
# 1. CPU tensor can't be used here.
@ -268,6 +285,10 @@ class Tensor(torch._C.TensorBase):
# Convert BFloat16 tesors to Float32 before conversion to numpy, as numpy doesn't
# support BFloat16. The rebuild tensor from numpy takes in the original self.dtype,
# this would reconstruct the BFloat16 tensor from numpy.
if skip_data:
raise RuntimeError(
"Cannot serialize tensors on backends with no storage under skip_data context manager"
)
numpy_tensor = (
self.cpu().numpy()
if self.dtype != torch.bfloat16
@ -280,6 +301,10 @@ class Tensor(torch._C.TensorBase):
if self.device.type == "meta":
# NB: This implementation BREAKS storage sharing. Current
# hypothesis is that no one cares for meta tensors.
if skip_data:
warnings.warn(
"Serializing tensors on the meta device under skip_data context manager is a no-op"
)
arg_meta = (
self.dtype,
tuple(self.size()),
@ -288,6 +313,10 @@ class Tensor(torch._C.TensorBase):
)
return (torch._utils._rebuild_meta_tensor_no_storage, arg_meta)
if self.is_quantized:
if skip_data:
raise RuntimeError(
"Cannot serialize qtensor under skip_data context manager, file an issue if you need this feature"
)
# quantizer_params can be different type based on torch attribute
quantizer_params: Union[
Tuple[torch.qscheme, float, int], Tuple[Any, Tensor, Tensor, int]
@ -369,6 +398,10 @@ class Tensor(torch._C.TensorBase):
)
return (torch._utils._rebuild_sparse_tensor, args_sparse_compressed)
elif self.is_nested:
if skip_data:
raise RuntimeError(
"Cannot serialize nested tensor under skip_data context manager, file an issue if you need this feature"
)
args_nested = (
# NB: values() currently returns the storage as a buffer in an unsafe way.
# Ideally, we'd use a private API for this instead. TODO: Switch to this if
@ -383,14 +416,30 @@ class Tensor(torch._C.TensorBase):
type(self) is not torch.Tensor
and type(self).__torch_dispatch__ is not torch.Tensor.__torch_dispatch__
and (
isinstance(
self,
(
torch._subclasses.fake_tensor.FakeTensor,
torch._subclasses.functional_tensor.FunctionalTensor,
),
isinstance(self, torch._subclasses.functional_tensor.FunctionalTensor)
or (
not isinstance(self, torch._subclasses.fake_tensor.FakeTensor)
and self.data_ptr() == 0
)
or self.data_ptr() == 0
)
):
arg_wrapper_subclass = (
type(self),
self.dtype,
tuple(self.size()),
self.stride(),
self.storage_offset(),
self.layout,
self.device,
self.requires_grad,
)
return (torch._utils._rebuild_wrapper_subclass, arg_wrapper_subclass)
elif (
type(self) is not torch.Tensor
and type(self).__torch_dispatch__ is not torch.Tensor.__torch_dispatch__
and (
isinstance(self, torch._subclasses.fake_tensor.FakeTensor)
and not (skip_data and materialize_fake_tensors)
)
):
arg_wrapper_subclass = (
@ -418,6 +467,16 @@ class Tensor(torch._C.TensorBase):
dtype=self.dtype,
_internal=True,
) # type: ignore[assignment]
# TODO: remove hasattr, it's a hack to support versions of torch that
# don't have _subclasses
if (
hasattr(torch, "_subclasses")
and isinstance(self, torch._subclasses.fake_tensor.FakeTensor)
and skip_data
):
storage._fake_device = self.device
args = (
storage,
self.storage_offset(),

View File

@ -3,7 +3,6 @@ import copyreg
import functools
import logging
import sys
import threading
import traceback
import warnings
from collections import defaultdict
@ -109,16 +108,13 @@ def _get_async_or_non_blocking(function_name, non_blocking, kwargs):
return kwargs["async"]
_thread_local_state = threading.local()
def _get_restore_location(device):
"""Return the map_location location.
Used for rebuild functions where the tensor device is distinct from the storage
"""
map_location = getattr(_thread_local_state, "map_location", None)
map_location = torch.serialization._serialization_tls.map_location
if map_location is None:
return device
else:

View File

@ -11,6 +11,7 @@ import struct
import sys
import tarfile
import tempfile
import threading
import warnings
from contextlib import closing, contextmanager
from enum import Enum
@ -60,6 +61,7 @@ __all__ = [
"get_safe_globals",
"add_safe_globals",
"safe_globals",
"skip_data",
]
@ -87,6 +89,22 @@ else:
MAP_SHARED, MAP_PRIVATE = None, None # type: ignore[assignment]
# _serialization_tls is used to store thread local state specific to serialization
# that needs to be propagated to other files, in particular we use this for
# (1) map_location (needed for wrapper subclasses/third party devices to torch._utils)
# (2) skip_data (needed for torch.Tensor.__reduce_ex__ for skip_data ctx)
# (3) materialize_fake_tensors (needed for torch.Tensor.__reduce_ex__ for skip_data ctx)
class _SerializationLocal(threading.local):
def __init__(self):
super().__init__()
self.map_location: Optional[MAP_LOCATION] = None
self.skip_data: bool = False
self.materialize_fake_tensors: bool = False
_serialization_tls = _SerializationLocal()
class SourceChangeWarning(Warning):
pass
@ -268,6 +286,47 @@ class safe_globals(_weights_only_unpickler._safe_globals):
"""
class skip_data:
"""
Context-manager that skips writing storage bytes for ``torch.save`` calls.
Storages will still be saved, but the space that their bytes would usually be written to
will be empty space. The storage bytes can then be populated in a separate pass.
.. warning::
The ``skip_data`` context manager is an early prototype and is subject to change.
Args:
materialize_fake_tensors: Whether to materialize FakeTensors.
Example:
>>> # xdoctest: +SKIP("NamedTemporaryFile on Windows")
>>> import tempfile
>>> t = torch.randn(2, 3)
>>> with tempfile.NamedTemporaryFile() as f:
... with torch.serialization.skip_data():
... torch.save(t, f.name)
... torch.load(f.name, weights_only=True)
tensor([[0., 0., 0.],
[0., 0., 0.]])
"""
def __init__(self, materialize_fake_tensors: bool = False):
self.materialize_fake_tensors = materialize_fake_tensors
def __enter__(self):
global _serialization_tls
self._old_skip_data = _serialization_tls.skip_data
self._old_materialize_fake_tensors = _serialization_tls.materialize_fake_tensors
_serialization_tls.skip_data = True
_serialization_tls.materialize_fake_tensors = self.materialize_fake_tensors
def __exit__(self, type, value, tb):
global _serialization_tls
_serialization_tls.skip_data = self._old_skip_data
_serialization_tls.materialize_fake_tensors = self._old_materialize_fake_tensors
def _is_zipfile(f) -> bool:
# This is a stricter implementation than zipfile.is_zipfile().
# zipfile.is_zipfile() is True if the magic number appears anywhere in the
@ -797,6 +856,11 @@ def save(
)
return
else:
global _serialization_tls
if _serialization_tls.skip_data:
raise RuntimeError(
"Cannot use skip_data=True with _use_new_zipfile_serialization=False"
)
with _open_file_like(f, "wb") as opened_file:
_legacy_save(obj, opened_file, pickle_module, pickle_protocol)
@ -955,7 +1019,13 @@ def _legacy_save(obj, f, pickle_module, pickle_protocol) -> None:
)
def _save(obj, zip_file, pickle_module, pickle_protocol, _disable_byteorder_record):
def _save(
obj,
zip_file,
pickle_module,
pickle_protocol,
_disable_byteorder_record,
):
serialized_storages = {}
id_map: Dict[int, str] = {}
@ -990,7 +1060,7 @@ def _save(obj, zip_file, pickle_module, pickle_protocol, _disable_byteorder_reco
# If storage is allocated, ensure that any other saved storages
# pointing to the same data all have the same dtype. If storage is
# not allocated, don't perform this check
if storage.data_ptr() != 0:
if str(storage.device) != "meta" and storage.data_ptr() != 0:
if storage.data_ptr() in storage_dtypes:
if storage_dtype != storage_dtypes[storage.data_ptr()]:
raise RuntimeError(
@ -1001,7 +1071,10 @@ def _save(obj, zip_file, pickle_module, pickle_protocol, _disable_byteorder_reco
storage_dtypes[storage.data_ptr()] = storage_dtype
storage_key = id_map.setdefault(storage._cdata, str(len(id_map)))
location = location_tag(storage)
if hasattr(obj, "_fake_device") and obj._fake_device is not None:
location = str(obj._fake_device)
else:
location = location_tag(storage)
serialized_storages[storage_key] = storage
return ("storage", storage_type, storage_key, location, storage_numel)
@ -1027,14 +1100,18 @@ def _save(obj, zip_file, pickle_module, pickle_protocol, _disable_byteorder_reco
for key in sorted(serialized_storages.keys()):
name = f"data/{key}"
storage = serialized_storages[key]
# given that we copy things around anyway, we might use storage.cpu()
# this means to that to get tensors serialized, you need to implement
# .cpu() on the underlying Storage
if storage.device.type != "cpu":
storage = storage.cpu()
# Now that it is on the CPU we can directly copy it into the zip file
num_bytes = storage.nbytes()
zip_file.write_record(name, storage, num_bytes)
global _serialization_tls
if _serialization_tls.skip_data:
zip_file.write_record_metadata(name, num_bytes)
else:
# given that we copy things around anyway, we might use storage.cpu()
# this means to that to get tensors serialized, you need to implement
# .cpu() on the underlying Storage
if storage.device.type != "cpu":
storage = storage.cpu()
# Now that it is on the CPU we can directly copy it into the zip file
zip_file.write_record(name, storage, num_bytes)
def load(
@ -1184,6 +1261,14 @@ def load(
updated_message += message
return updated_message + DOCS_MESSAGE
global _serialization_tls
skip_data = _serialization_tls.skip_data
if skip_data:
raise RuntimeError(
"`torch.load` called within a torch.serialization.skip_data context manager "
"is not supported yet. Please call torch.load outside the skip_data context manager."
)
if weights_only is None:
weights_only, warn_weights_only = False, True
else:
@ -1758,9 +1843,10 @@ def _load(
unpickler.persistent_load = persistent_load
# Needed for tensors where storage device and rebuild tensor device are
# not connected (wrapper subclasses and tensors rebuilt using numpy)
torch._utils._thread_local_state.map_location = map_location
global _serialization_tls
_serialization_tls.map_location = map_location
result = unpickler.load()
del torch._utils._thread_local_state.map_location
_serialization_tls.map_location = None
torch._utils._validate_loaded_sparse_tensors()
torch._C._log_api_usage_metadata(

View File

@ -39,6 +39,8 @@ class _StorageBase:
is_sparse: _bool = False
is_sparse_csr: _bool = False
device: torch.device
# Used when stashing FakeTensor device onto storage in torch.save(metadata_only=True)
_fake_device: _Optional[torch.device] = None
def __init__(self, *args, **kwargs):
pass
@ -649,6 +651,8 @@ def _get_device_from_module(module: str):
class TypedStorage:
is_sparse: _bool = False
# Used when stashing FakeTensor device onto storage in torch.save(metadata_only=True)
_fake_device: _Optional[torch.device] = None
dtype: torch.dtype