Add torch.serialization.skip_data context manager (#134504)

## Semantic

The semantic is
(1) By default `torch.serialization.skip_data(materialize_fake_tensors=False)` will make `torch.save` skip writing storages (but reserve space for them in the checkpoint).

```python
import torch
import torch.nn as nn

sd = nn.Linear(3, 5).state_dict()
with torch.serialization.skip_data():
    torch.save(sd, 'foo.pt')
print(torch.load('foo.pt', weights_only=True))
```

(2)  With `torch.serialization.skip_data(materialize_fake_tensors=True)`If FakeTensor is passed to `torch.save` the pickler will treat these FakeTensors as being "materialized" space will be reserved in the checkpoint for the associated storage bytes, and when loading the type will be Tensor instead of FakeTensor)

```python
import torch
import torch.nn as nn
from torch._subclasses.fake_tensor import FakeTensorMode

with FakeTensorMode():
    m = nn.Linear(3, 5, dtype=torch.float16, device='cuda')

sd = m.state_dict()
with torch.serialization.skip_data(materialize_fake_tensors=True):
    torch.save(sd, 'bla.pt')
print(torch.load('bla.pt', weights_only=True))
# OrderedDict([('weight', tensor([[0., 0., 0.],
#        [0., 0., 0.],
#        [0., 0., 0.],
#        [0., 0., 0.],
#        [0., 0., 0.]], device='cuda:0', dtype=torch.float16)), ('bias', tensor([0., 0., 0., 0., 0.], device='cuda:0', dtype=torch.float16))])

```

## Follow Ups

- [ ] `torch.load` semantic for skip_data context manager
- [ ] Mechanism for getting offsets of storages saved via this method (for writing in a separate pass)

Differential Revision: [D62238610](https://our.internmc.facebook.com/intern/diff/D62238610)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/134504
Approved by: https://github.com/albanD
This commit is contained in:
Mikayla Gawarecki
2024-09-05 06:50:12 -07:00
committed by PyTorch MergeBot
parent dbeb8a1691
commit a096f2899d
7 changed files with 270 additions and 26 deletions

View File

@ -1,5 +1,6 @@
# Owner(s): ["module: serialization"]
import contextlib
import copy
import gc
import gzip
@ -19,6 +20,7 @@ from itertools import product
from pathlib import Path
import torch
from torch._subclasses.fake_tensor import FakeTensorMode, FakeTensorConverter
from torch._utils import _rebuild_tensor
from torch._utils_internal import get_file_path_2
from torch.serialization import (
@ -27,6 +29,7 @@ from torch.serialization import (
LoadEndianness,
safe_globals,
set_default_load_endianness,
skip_data,
SourceChangeWarning,
)
from torch.testing._internal.common_device_type import instantiate_device_type_tests
@ -4212,6 +4215,91 @@ class TestSerialization(TestCase, SerializationMixin):
sd_loaded_ref = torch.load(f)
self.assertEqual(sd_loaded, sd_loaded_ref)
@parametrize("materialize_fake", (True, False))
def test_skip_data_serialization(self, materialize_fake):
# Create one tensor that uses each of the paths in __reduce_ex__ that should work
t_device = "cuda" if torch.cuda.is_available() else "cpu"
t_v2 = torch.randn(2, 3, device=t_device)
t_v3 = torch.randn(2, 3, dtype=torch.complex32, device=t_device)
i = torch.tensor([[0, 1, 1],
[2, 0, 2]])
v = torch.tensor([3, 4, 5], dtype=torch.float32)
if not materialize_fake:
# FakeTensorConverter messes up sizes of i and v for the sparse tensor
st = torch.sparse_coo_tensor(i, v, (2, 4))
tt = TwoTensor(torch.randn(2, device=t_device), torch.randn(2, device=t_device))
mode, converter = FakeTensorMode(), FakeTensorConverter()
def fn(t):
return converter.from_real_tensor(mode, t) if materialize_fake else t
sd = {'t_v2': fn(t_v2), 't_v3': fn(t_v3), 'tt': fn(tt)}
sd_expected = {
't_v2': torch.zeros(2, 3, device=t_device),
't_v3': torch.zeros(2, 3, dtype=torch.complex32, device=t_device),
'tt': TwoTensor(torch.zeros(2, device=t_device), torch.zeros(2, device=t_device)),
}
if not materialize_fake:
sd['st'] = st
sd_expected['st'] = torch.sparse_coo_tensor(torch.zeros(2, 3), torch.zeros(3), (2, 4))
with BytesIOContext() as f:
with skip_data(materialize_fake_tensors=materialize_fake):
torch.save(sd, f)
f.seek(0)
with safe_globals([TwoTensor]):
sd_loaded = torch.load(f, weights_only=True)
self.assertEqual(sd_loaded, sd_expected, exact_device=True)
self.assertFalse(getattr(torch.serialization._serialization_tls, "materialize_fake_tensors", False))
self.assertFalse(getattr(torch.serialization._serialization_tls, "skip_data", False))
# Test that without materialize_fake_tensor, behavior for fake_tensors is not altered by ctx
if not materialize_fake:
ft = converter.from_real_tensor(mode, torch.randn(2, device=t_device))
with self.assertRaisesRegex(AttributeError, "Can't pickle local object 'WeakValueDictionary.__init__.<locals>.remove'"):
with skip_data(), BytesIOContext() as f:
torch.save(ft, f)
@parametrize("materialize_fake", (True, False))
def test_skip_data_serialization_preserves_views(self, materialize_fake):
ctx = FakeTensorMode if materialize_fake else contextlib.nullcontext
with ctx():
t = torch.randn(2, 3)
t_view = t.view(-1)
t_slice = t[1]
sd = {'t': t, 't_view': t_view, 't_slice': t_slice}
with BytesIOContext() as f:
with skip_data(materialize_fake_tensors=materialize_fake):
torch.save(sd, f)
f.seek(0)
sd_loaded = torch.load(f, weights_only=True)
self.assertTrue(id(sd_loaded['t_view'].untyped_storage()) == id(sd_loaded['t'].untyped_storage()))
self.assertTrue(id(sd_loaded['t_slice'].untyped_storage()) == id(sd_loaded['t'].untyped_storage()))
def test_skip_data_serialization_error_cases(self):
def _save_load(t):
with BytesIOContext() as f:
with skip_data():
torch.save(t, f)
f.seek(0)
torch.load(f, weights_only=True)
nt = torch.nested.nested_tensor([torch.randn(2), torch.randn(3)])
t = torch.randn(2, 3, device="meta")
with self.assertRaisesRegex(RuntimeError, "Cannot serialize nested tensor under skip_data context manager"):
_save_load(nt)
with self.assertWarnsRegex(UserWarning, "meta device under skip_data context manager is a no-op"):
_save_load(t)
with self.assertRaisesRegex(RuntimeError, "Please call torch.load outside the skip_data context manager"):
with skip_data(), BytesIOContext() as f:
torch.save(torch.randn(2, 3), f)
f.seek(0)
torch.load(f, weights_only=True)
def run(self, *args, **kwargs):
with serialization_method(use_zip=True):
return super().run(*args, **kwargs)