Add torch.serialization.skip_data context manager (#134504)
## Semantics
The semantics are:
(1) By default, `torch.serialization.skip_data(materialize_fake_tensors=False)` makes `torch.save` skip writing storages (but reserve space for them in the checkpoint).
```python
import torch
import torch.nn as nn
sd = nn.Linear(3, 5).state_dict()
with torch.serialization.skip_data():
    torch.save(sd, 'foo.pt')
print(torch.load('foo.pt', weights_only=True))
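# The storages were skipped, so the loaded tensors contain placeholder values
# (zeros in practice, as in the example below) rather than the saved weights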
```
(2) With `torch.serialization.skip_data(materialize_fake_tensors=True)`, if a FakeTensor is passed to `torch.save`, the pickler will treat these FakeTensors as being "materialized": space will be reserved in the checkpoint for the associated storage bytes, and when loading, the type will be `Tensor` instead of `FakeTensor`.
```python
import torch
import torch.nn as nn
from torch._subclasses.fake_tensor import FakeTensorMode
with FakeTensorMode():
    m = nn.Linear(3, 5, dtype=torch.float16, device='cuda')
    sd = m.state_dict()
with torch.serialization.skip_data(materialize_fake_tensors=True):
    torch.save(sd, 'bla.pt')
print(torch.load('bla.pt', weights_only=True))
# OrderedDict([('weight', tensor([[0., 0., 0.],
# [0., 0., 0.],
# [0., 0., 0.],
# [0., 0., 0.],
# [0., 0., 0.]], device='cuda:0', dtype=torch.float16)), ('bias', tensor([0., 0., 0., 0., 0.], device='cuda:0', dtype=torch.float16))])
```
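Note that `torch.load` does not yet have special semantics under the context manager (see Follow Ups below); as the tests added in this PR check, calling `torch.load` inside `skip_data` raises a `RuntimeError` directing the caller to load outside the context manager.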
## Follow Ups
- [ ] `torch.load` semantics for the skip_data context manager
- [ ] Mechanism for getting offsets of storages saved via this method (for writing in a separate pass); a hypothetical sketch of this workflow follows below
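As a forward-looking illustration of that second follow-up, here is a hypothetical sketch of the two-pass workflow such a mechanism would enable. The `get_storage_offsets` helper and the storage write path are assumptions for illustration only, not an existing API:

```python
import torch
import torch.nn as nn

sd = nn.Linear(3, 5).state_dict()

# Pass 1: write the checkpoint structure, reserving space for the storage bytes.
with torch.serialization.skip_data():
    torch.save(sd, 'ckpt.pt')

# Pass 2 (hypothetical): with a mechanism for querying the reserved offsets,
# the actual bytes could be filled in later, e.g. in a separate pass or process:
#
#   offsets = torch.serialization.get_storage_offsets('ckpt.pt')  # hypothetical API
#   with open('ckpt.pt', 'r+b') as f:
#       for key, tensor in sd.items():
#           f.seek(offsets[key])
#           f.write(bytes(tensor.untyped_storage()))
```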
Differential Revision: [D62238610](https://our.internmc.facebook.com/intern/diff/D62238610)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/134504
Approved by: https://github.com/albanD
commit a096f2899d (parent dbeb8a1691)
committed by PyTorch MergeBot
```diff
@@ -1,5 +1,6 @@
# Owner(s): ["module: serialization"]

import contextlib
import copy
import gc
import gzip
@@ -19,6 +20,7 @@ from itertools import product
from pathlib import Path

import torch
from torch._subclasses.fake_tensor import FakeTensorMode, FakeTensorConverter
from torch._utils import _rebuild_tensor
from torch._utils_internal import get_file_path_2
from torch.serialization import (
@@ -27,6 +29,7 @@ from torch.serialization import (
    LoadEndianness,
    safe_globals,
    set_default_load_endianness,
    skip_data,
    SourceChangeWarning,
)
from torch.testing._internal.common_device_type import instantiate_device_type_tests
@@ -4212,6 +4215,91 @@ class TestSerialization(TestCase, SerializationMixin):
            sd_loaded_ref = torch.load(f)
            self.assertEqual(sd_loaded, sd_loaded_ref)

    @parametrize("materialize_fake", (True, False))
    def test_skip_data_serialization(self, materialize_fake):
        # Create one tensor that uses each of the paths in __reduce_ex__ that should work
        t_device = "cuda" if torch.cuda.is_available() else "cpu"
        t_v2 = torch.randn(2, 3, device=t_device)
        t_v3 = torch.randn(2, 3, dtype=torch.complex32, device=t_device)
        i = torch.tensor([[0, 1, 1],
                          [2, 0, 2]])
        v = torch.tensor([3, 4, 5], dtype=torch.float32)
        if not materialize_fake:
            # FakeTensorConverter messes up sizes of i and v for the sparse tensor
            st = torch.sparse_coo_tensor(i, v, (2, 4))
        tt = TwoTensor(torch.randn(2, device=t_device), torch.randn(2, device=t_device))

        mode, converter = FakeTensorMode(), FakeTensorConverter()

        def fn(t):
            return converter.from_real_tensor(mode, t) if materialize_fake else t

        sd = {'t_v2': fn(t_v2), 't_v3': fn(t_v3), 'tt': fn(tt)}
        sd_expected = {
            't_v2': torch.zeros(2, 3, device=t_device),
            't_v3': torch.zeros(2, 3, dtype=torch.complex32, device=t_device),
            'tt': TwoTensor(torch.zeros(2, device=t_device), torch.zeros(2, device=t_device)),
        }

        if not materialize_fake:
            sd['st'] = st
            sd_expected['st'] = torch.sparse_coo_tensor(torch.zeros(2, 3), torch.zeros(3), (2, 4))

        with BytesIOContext() as f:
            with skip_data(materialize_fake_tensors=materialize_fake):
                torch.save(sd, f)
            f.seek(0)
            with safe_globals([TwoTensor]):
                sd_loaded = torch.load(f, weights_only=True)
            self.assertEqual(sd_loaded, sd_expected, exact_device=True)
            self.assertFalse(getattr(torch.serialization._serialization_tls, "materialize_fake_tensors", False))
            self.assertFalse(getattr(torch.serialization._serialization_tls, "skip_data", False))

        # Test that without materialize_fake_tensor, behavior for fake_tensors is not altered by ctx
        if not materialize_fake:
            ft = converter.from_real_tensor(mode, torch.randn(2, device=t_device))
            with self.assertRaisesRegex(AttributeError, "Can't pickle local object 'WeakValueDictionary.__init__.<locals>.remove'"):
                with skip_data(), BytesIOContext() as f:
                    torch.save(ft, f)

    @parametrize("materialize_fake", (True, False))
    def test_skip_data_serialization_preserves_views(self, materialize_fake):
        ctx = FakeTensorMode if materialize_fake else contextlib.nullcontext
        with ctx():
            t = torch.randn(2, 3)
            t_view = t.view(-1)
            t_slice = t[1]
        sd = {'t': t, 't_view': t_view, 't_slice': t_slice}
        with BytesIOContext() as f:
            with skip_data(materialize_fake_tensors=materialize_fake):
                torch.save(sd, f)
            f.seek(0)
            sd_loaded = torch.load(f, weights_only=True)
            self.assertTrue(id(sd_loaded['t_view'].untyped_storage()) == id(sd_loaded['t'].untyped_storage()))
            self.assertTrue(id(sd_loaded['t_slice'].untyped_storage()) == id(sd_loaded['t'].untyped_storage()))

    def test_skip_data_serialization_error_cases(self):
        def _save_load(t):
            with BytesIOContext() as f:
                with skip_data():
                    torch.save(t, f)
                f.seek(0)
                torch.load(f, weights_only=True)

        nt = torch.nested.nested_tensor([torch.randn(2), torch.randn(3)])
        t = torch.randn(2, 3, device="meta")
        with self.assertRaisesRegex(RuntimeError, "Cannot serialize nested tensor under skip_data context manager"):
            _save_load(nt)

        with self.assertWarnsRegex(UserWarning, "meta device under skip_data context manager is a no-op"):
            _save_load(t)

        with self.assertRaisesRegex(RuntimeError, "Please call torch.load outside the skip_data context manager"):
            with skip_data(), BytesIOContext() as f:
                torch.save(torch.randn(2, 3), f)
                f.seek(0)
                torch.load(f, weights_only=True)

    def run(self, *args, **kwargs):
        with serialization_method(use_zip=True):
            return super().run(*args, **kwargs)
```