mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
Revert "Add torch.serialization.skip_data context manager (#134504)"
This reverts commit 94db935749b8de99d8c3ab23fb880c67c8f3e67a. Reverted https://github.com/pytorch/pytorch/pull/134504 on behalf of https://github.com/kit1980 due to See D62082697 ([comment](https://github.com/pytorch/pytorch/pull/134504#issuecomment-2327542276))
This commit is contained in:
@ -1,6 +1,5 @@
|
||||
# Owner(s): ["module: serialization"]
|
||||
|
||||
import contextlib
|
||||
import copy
|
||||
import gc
|
||||
import gzip
|
||||
@ -20,7 +19,6 @@ from itertools import product
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from torch._subclasses.fake_tensor import FakeTensorMode, FakeTensorConverter
|
||||
from torch._utils import _rebuild_tensor
|
||||
from torch._utils_internal import get_file_path_2
|
||||
from torch.serialization import (
|
||||
@ -29,7 +27,6 @@ from torch.serialization import (
|
||||
LoadEndianness,
|
||||
safe_globals,
|
||||
set_default_load_endianness,
|
||||
skip_data,
|
||||
SourceChangeWarning,
|
||||
)
|
||||
from torch.testing._internal.common_device_type import instantiate_device_type_tests
|
||||
@ -4215,91 +4212,6 @@ class TestSerialization(TestCase, SerializationMixin):
|
||||
sd_loaded_ref = torch.load(f)
|
||||
self.assertEqual(sd_loaded, sd_loaded_ref)
|
||||
|
||||
@parametrize("materialize_fake", (True, False))
|
||||
def test_skip_data_serialization(self, materialize_fake):
|
||||
# Create one tensor that uses each of the paths in __reduce_ex__ that should work
|
||||
t_device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
t_v2 = torch.randn(2, 3, device=t_device)
|
||||
t_v3 = torch.randn(2, 3, dtype=torch.complex32, device=t_device)
|
||||
i = torch.tensor([[0, 1, 1],
|
||||
[2, 0, 2]])
|
||||
v = torch.tensor([3, 4, 5], dtype=torch.float32)
|
||||
if not materialize_fake:
|
||||
# FakeTensorConverter messes up sizes of i and v for the sparse tensor
|
||||
st = torch.sparse_coo_tensor(i, v, (2, 4))
|
||||
tt = TwoTensor(torch.randn(2, device=t_device), torch.randn(2, device=t_device))
|
||||
|
||||
mode, converter = FakeTensorMode(), FakeTensorConverter()
|
||||
|
||||
def fn(t):
|
||||
return converter.from_real_tensor(mode, t) if materialize_fake else t
|
||||
|
||||
sd = {'t_v2': fn(t_v2), 't_v3': fn(t_v3), 'tt': fn(tt)}
|
||||
sd_expected = {
|
||||
't_v2': torch.zeros(2, 3, device=t_device),
|
||||
't_v3': torch.zeros(2, 3, dtype=torch.complex32, device=t_device),
|
||||
'tt': TwoTensor(torch.zeros(2, device=t_device), torch.zeros(2, device=t_device)),
|
||||
}
|
||||
|
||||
if not materialize_fake:
|
||||
sd['st'] = st
|
||||
sd_expected['st'] = torch.sparse_coo_tensor(torch.zeros(2, 3), torch.zeros(3), (2, 4))
|
||||
|
||||
with BytesIOContext() as f:
|
||||
with skip_data(materialize_fake_tensors=materialize_fake):
|
||||
torch.save(sd, f)
|
||||
f.seek(0)
|
||||
with safe_globals([TwoTensor]):
|
||||
sd_loaded = torch.load(f, weights_only=True)
|
||||
self.assertEqual(sd_loaded, sd_expected, exact_device=True)
|
||||
self.assertFalse(getattr(torch.serialization._serialization_tls, "materialize_fake_tensors", False))
|
||||
self.assertFalse(getattr(torch.serialization._serialization_tls, "skip_data", False))
|
||||
|
||||
# Test that without materialize_fake_tensor, behavior for fake_tensors is not altered by ctx
|
||||
if not materialize_fake:
|
||||
ft = converter.from_real_tensor(mode, torch.randn(2, device=t_device))
|
||||
with self.assertRaisesRegex(AttributeError, "Can't pickle local object 'WeakValueDictionary.__init__.<locals>.remove'"):
|
||||
with skip_data(), BytesIOContext() as f:
|
||||
torch.save(ft, f)
|
||||
|
||||
@parametrize("materialize_fake", (True, False))
|
||||
def test_skip_data_serialization_preserves_views(self, materialize_fake):
|
||||
ctx = FakeTensorMode if materialize_fake else contextlib.nullcontext
|
||||
with ctx():
|
||||
t = torch.randn(2, 3)
|
||||
t_view = t.view(-1)
|
||||
t_slice = t[1]
|
||||
sd = {'t': t, 't_view': t_view, 't_slice': t_slice}
|
||||
with BytesIOContext() as f:
|
||||
with skip_data(materialize_fake_tensors=materialize_fake):
|
||||
torch.save(sd, f)
|
||||
f.seek(0)
|
||||
sd_loaded = torch.load(f, weights_only=True)
|
||||
self.assertTrue(id(sd_loaded['t_view'].untyped_storage()) == id(sd_loaded['t'].untyped_storage()))
|
||||
self.assertTrue(id(sd_loaded['t_slice'].untyped_storage()) == id(sd_loaded['t'].untyped_storage()))
|
||||
|
||||
def test_skip_data_serialization_error_cases(self):
|
||||
def _save_load(t):
|
||||
with BytesIOContext() as f:
|
||||
with skip_data():
|
||||
torch.save(t, f)
|
||||
f.seek(0)
|
||||
torch.load(f, weights_only=True)
|
||||
|
||||
nt = torch.nested.nested_tensor([torch.randn(2), torch.randn(3)])
|
||||
t = torch.randn(2, 3, device="meta")
|
||||
with self.assertRaisesRegex(RuntimeError, "Cannot serialize nested tensor under skip_data context manager"):
|
||||
_save_load(nt)
|
||||
|
||||
with self.assertWarnsRegex(UserWarning, "meta device under skip_data context manager is a no-op"):
|
||||
_save_load(t)
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, "Please call torch.load outside the skip_data context manager"):
|
||||
with skip_data(), BytesIOContext() as f:
|
||||
torch.save(torch.randn(2, 3), f)
|
||||
f.seek(0)
|
||||
torch.load(f, weights_only=True)
|
||||
|
||||
def run(self, *args, **kwargs):
|
||||
with serialization_method(use_zip=True):
|
||||
return super().run(*args, **kwargs)
|
||||
|
Reference in New Issue
Block a user