Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
Add torch.serialization.skip_data context manager (#134504)
## Semantics

The semantics are:

(1) By default, `torch.serialization.skip_data(materialize_fake_tensors=False)` makes `torch.save` skip writing storage bytes (but reserve space for them in the checkpoint).

```python
import torch
import torch.nn as nn

sd = nn.Linear(3, 5).state_dict()
with torch.serialization.skip_data():
    torch.save(sd, 'foo.pt')
print(torch.load('foo.pt', weights_only=True))
```

(2) With `torch.serialization.skip_data(materialize_fake_tensors=True)`, if a FakeTensor is passed to `torch.save`, the pickler treats these FakeTensors as being "materialized" (space is reserved in the checkpoint for the associated storage bytes, and when loading, the type will be Tensor instead of FakeTensor).

```python
import torch
import torch.nn as nn
from torch._subclasses.fake_tensor import FakeTensorMode

with FakeTensorMode():
    m = nn.Linear(3, 5, dtype=torch.float16, device='cuda')

sd = m.state_dict()
with torch.serialization.skip_data(materialize_fake_tensors=True):
    torch.save(sd, 'bla.pt')
print(torch.load('bla.pt', weights_only=True))
# OrderedDict([('weight', tensor([[0., 0., 0.],
#         [0., 0., 0.],
#         [0., 0., 0.],
#         [0., 0., 0.],
#         [0., 0., 0.]], device='cuda:0', dtype=torch.float16)), ('bias', tensor([0., 0., 0., 0., 0.], device='cuda:0', dtype=torch.float16))])
```

## Follow Ups

- [ ] `torch.load` semantics for the skip_data context manager
- [ ] Mechanism for getting offsets of storages saved via this method (for writing in a separate pass)

Differential Revision: [D62238610](https://our.internmc.facebook.com/intern/diff/D62238610)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/134504
Approved by: https://github.com/albanD
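A quick way to see that space is *reserved* rather than omitted (an editor's illustration, not part of the PR): a `skip_data` checkpoint should come out roughly the same size as a normal one, since only the writing of the storage bytes is skipped.

```python
import os
import tempfile

import torch
import torch.nn as nn

sd = nn.Linear(1000, 1000).state_dict()

with tempfile.TemporaryDirectory() as d:
    full = os.path.join(d, "full.pt")
    skipped = os.path.join(d, "skipped.pt")
    torch.save(sd, full)
    with torch.serialization.skip_data():
        torch.save(sd, skipped)
    # Space for the ~4 MB of float32 storage is reserved in both files, so
    # the two checkpoints end up nearly the same size; the difference is
    # that the storage bytes in skipped.pt were never actually written.
    print(os.path.getsize(full), os.path.getsize(skipped))
```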
Commit a096f2899d (parent dbeb8a1691), committed by PyTorch MergeBot.
@@ -398,3 +398,4 @@ The following utility functions are related to serialization:
 .. autofunction:: clear_safe_globals
 .. autofunction:: get_safe_globals
 .. autoclass:: safe_globals
+.. autoclass:: skip_data
@@ -540,7 +540,7 @@ class TestCppExtensionOpenRgistration(common.TestCase):
         # call _fused_adamw_ with undefined tensor.
         self.module.fallback_with_undefined_tensor()
 
-    def test_open_device_numpy_serialization_map_location(self):
+    def test_open_device_numpy_serialization(self):
         torch.utils.rename_privateuse1_backend("foo")
         device = self.module.custom_device()
         default_protocol = torch.serialization.DEFAULT_PROTOCOL
@@ -553,6 +553,7 @@ class TestCppExtensionOpenRgistration(common.TestCase):
            self.assertTrue(
                rebuild_func is torch._utils._rebuild_device_tensor_from_numpy
            )
+        # Test map_location
        with TemporaryFileName() as f:
            torch.save(sd, f)
            with safe_globals(
@@ -569,6 +570,15 @@ class TestCppExtensionOpenRgistration(common.TestCase):
                sd_loaded = torch.load(f, map_location="cpu")
            self.assertTrue(sd_loaded["x"].is_cpu)
 
+        # Test metadata_only
+        with TemporaryFileName() as f:
+            with self.assertRaisesRegex(
+                RuntimeError,
+                "Cannot serialize tensors on backends with no storage under skip_data context manager",
+            ):
+                with torch.serialization.skip_data():
+                    torch.save(sd, f)
+
 
 if __name__ == "__main__":
     common.run_tests()
@@ -1,5 +1,6 @@
 # Owner(s): ["module: serialization"]
 
+import contextlib
 import copy
 import gc
 import gzip
@@ -19,6 +20,7 @@ from itertools import product
 from pathlib import Path
 
 import torch
+from torch._subclasses.fake_tensor import FakeTensorMode, FakeTensorConverter
 from torch._utils import _rebuild_tensor
 from torch._utils_internal import get_file_path_2
 from torch.serialization import (
@@ -27,6 +29,7 @@ from torch.serialization import (
     LoadEndianness,
     safe_globals,
     set_default_load_endianness,
+    skip_data,
     SourceChangeWarning,
 )
 from torch.testing._internal.common_device_type import instantiate_device_type_tests
@@ -4212,6 +4215,91 @@ class TestSerialization(TestCase, SerializationMixin):
             sd_loaded_ref = torch.load(f)
             self.assertEqual(sd_loaded, sd_loaded_ref)
 
+    @parametrize("materialize_fake", (True, False))
+    def test_skip_data_serialization(self, materialize_fake):
+        # Create one tensor that uses each of the paths in __reduce_ex__ that should work
+        t_device = "cuda" if torch.cuda.is_available() else "cpu"
+        t_v2 = torch.randn(2, 3, device=t_device)
+        t_v3 = torch.randn(2, 3, dtype=torch.complex32, device=t_device)
+        i = torch.tensor([[0, 1, 1],
+                          [2, 0, 2]])
+        v = torch.tensor([3, 4, 5], dtype=torch.float32)
+        if not materialize_fake:
+            # FakeTensorConverter messes up sizes of i and v for the sparse tensor
+            st = torch.sparse_coo_tensor(i, v, (2, 4))
+        tt = TwoTensor(torch.randn(2, device=t_device), torch.randn(2, device=t_device))
+
+        mode, converter = FakeTensorMode(), FakeTensorConverter()
+
+        def fn(t):
+            return converter.from_real_tensor(mode, t) if materialize_fake else t
+
+        sd = {'t_v2': fn(t_v2), 't_v3': fn(t_v3), 'tt': fn(tt)}
+        sd_expected = {
+            't_v2': torch.zeros(2, 3, device=t_device),
+            't_v3': torch.zeros(2, 3, dtype=torch.complex32, device=t_device),
+            'tt': TwoTensor(torch.zeros(2, device=t_device), torch.zeros(2, device=t_device)),
+        }
+
+        if not materialize_fake:
+            sd['st'] = st
+            sd_expected['st'] = torch.sparse_coo_tensor(torch.zeros(2, 3), torch.zeros(3), (2, 4))
+
+        with BytesIOContext() as f:
+            with skip_data(materialize_fake_tensors=materialize_fake):
+                torch.save(sd, f)
+            f.seek(0)
+            with safe_globals([TwoTensor]):
+                sd_loaded = torch.load(f, weights_only=True)
+            self.assertEqual(sd_loaded, sd_expected, exact_device=True)
+            self.assertFalse(getattr(torch.serialization._serialization_tls, "materialize_fake_tensors", False))
+            self.assertFalse(getattr(torch.serialization._serialization_tls, "skip_data", False))
+
+        # Test that without materialize_fake_tensor, behavior for fake_tensors is not altered by ctx
+        if not materialize_fake:
+            ft = converter.from_real_tensor(mode, torch.randn(2, device=t_device))
+            with self.assertRaisesRegex(AttributeError, "Can't pickle local object 'WeakValueDictionary.__init__.<locals>.remove'"):
+                with skip_data(), BytesIOContext() as f:
+                    torch.save(ft, f)
+
+    @parametrize("materialize_fake", (True, False))
+    def test_skip_data_serialization_preserves_views(self, materialize_fake):
+        ctx = FakeTensorMode if materialize_fake else contextlib.nullcontext
+        with ctx():
+            t = torch.randn(2, 3)
+            t_view = t.view(-1)
+            t_slice = t[1]
+        sd = {'t': t, 't_view': t_view, 't_slice': t_slice}
+        with BytesIOContext() as f:
+            with skip_data(materialize_fake_tensors=materialize_fake):
+                torch.save(sd, f)
+            f.seek(0)
+            sd_loaded = torch.load(f, weights_only=True)
+            self.assertTrue(id(sd_loaded['t_view'].untyped_storage()) == id(sd_loaded['t'].untyped_storage()))
+            self.assertTrue(id(sd_loaded['t_slice'].untyped_storage()) == id(sd_loaded['t'].untyped_storage()))
+
+    def test_skip_data_serialization_error_cases(self):
+        def _save_load(t):
+            with BytesIOContext() as f:
+                with skip_data():
+                    torch.save(t, f)
+                f.seek(0)
+                torch.load(f, weights_only=True)
+
+        nt = torch.nested.nested_tensor([torch.randn(2), torch.randn(3)])
+        t = torch.randn(2, 3, device="meta")
+        with self.assertRaisesRegex(RuntimeError, "Cannot serialize nested tensor under skip_data context manager"):
+            _save_load(nt)
+
+        with self.assertWarnsRegex(UserWarning, "meta device under skip_data context manager is a no-op"):
+            _save_load(t)
+
+        with self.assertRaisesRegex(RuntimeError, "Please call torch.load outside the skip_data context manager"):
+            with skip_data(), BytesIOContext() as f:
+                torch.save(torch.randn(2, 3), f)
+                f.seek(0)
+                torch.load(f, weights_only=True)
+
     def run(self, *args, **kwargs):
         with serialization_method(use_zip=True):
             return super().run(*args, **kwargs)
@@ -209,8 +209,19 @@ class Tensor(torch._C.TensorBase):
         return new_tensor
 
     def __reduce_ex__(self, proto):
+        materialize_fake_tensors = (
+            torch.serialization._serialization_tls.materialize_fake_tensors
+        )
         state = torch._utils._get_obj_state(self)
-        if type(self) is Tensor and not state:
+        # Ignore all state when using FakeTensor with skip_data(materialize_fake_tensors) because FakeTensor has
+        # some state that cannot be pickled
+        if (
+            # TODO: remove hasattr, it's a hack to support versions of torch that
+            # don't have _subclasses
+            hasattr(torch, "_subclasses")
+            and type(self) is torch._subclasses.fake_tensor.FakeTensor
+            and materialize_fake_tensors
+        ) or (type(self) is Tensor and not state):
             # Fast path for regular tensor without Python state.
             return self._reduce_ex_internal(proto)
         if has_torch_function_unary(self):
@@ -251,6 +262,12 @@ class Tensor(torch._C.TensorBase):
         # See Note [Don't serialize hooks]
         warn_if_has_hooks(self)
         backward_hooks: Dict[Any, Any] = OrderedDict()
+
+        skip_data = torch.serialization._serialization_tls.skip_data
+        materialize_fake_tensors = (
+            torch.serialization._serialization_tls.materialize_fake_tensors
+        )
+
         # Note: Numpy array is chosen to be the rebuild component for XLA, MTIA, MAIA Tensors.
         # We considered a few options:
         # 1. CPU tensor can't be used here.
@@ -268,6 +285,10 @@ class Tensor(torch._C.TensorBase):
             # Convert BFloat16 tesors to Float32 before conversion to numpy, as numpy doesn't
             # support BFloat16. The rebuild tensor from numpy takes in the original self.dtype,
             # this would reconstruct the BFloat16 tensor from numpy.
+            if skip_data:
+                raise RuntimeError(
+                    "Cannot serialize tensors on backends with no storage under skip_data context manager"
+                )
             numpy_tensor = (
                 self.cpu().numpy()
                 if self.dtype != torch.bfloat16
@@ -280,6 +301,10 @@ class Tensor(torch._C.TensorBase):
         if self.device.type == "meta":
             # NB: This implementation BREAKS storage sharing. Current
             # hypothesis is that no one cares for meta tensors.
+            if skip_data:
+                warnings.warn(
+                    "Serializing tensors on the meta device under skip_data context manager is a no-op"
+                )
             arg_meta = (
                 self.dtype,
                 tuple(self.size()),
@@ -288,6 +313,10 @@ class Tensor(torch._C.TensorBase):
             )
             return (torch._utils._rebuild_meta_tensor_no_storage, arg_meta)
         if self.is_quantized:
+            if skip_data:
+                raise RuntimeError(
+                    "Cannot serialize qtensor under skip_data context manager, file an issue if you need this feature"
+                )
             # quantizer_params can be different type based on torch attribute
             quantizer_params: Union[
                 Tuple[torch.qscheme, float, int], Tuple[Any, Tensor, Tensor, int]
@@ -369,6 +398,10 @@ class Tensor(torch._C.TensorBase):
             )
             return (torch._utils._rebuild_sparse_tensor, args_sparse_compressed)
         elif self.is_nested:
+            if skip_data:
+                raise RuntimeError(
+                    "Cannot serialize nested tensor under skip_data context manager, file an issue if you need this feature"
+                )
             args_nested = (
                 # NB: values() currently returns the storage as a buffer in an unsafe way.
                 # Ideally, we'd use a private API for this instead. TODO: Switch to this if
@@ -383,14 +416,30 @@ class Tensor(torch._C.TensorBase):
             type(self) is not torch.Tensor
             and type(self).__torch_dispatch__ is not torch.Tensor.__torch_dispatch__
             and (
-                isinstance(
-                    self,
-                    (
-                        torch._subclasses.fake_tensor.FakeTensor,
-                        torch._subclasses.functional_tensor.FunctionalTensor,
-                    ),
-                )
-                or self.data_ptr() == 0
+                isinstance(self, torch._subclasses.functional_tensor.FunctionalTensor)
+                or (
+                    not isinstance(self, torch._subclasses.fake_tensor.FakeTensor)
+                    and self.data_ptr() == 0
+                )
             )
         ):
+            arg_wrapper_subclass = (
+                type(self),
+                self.dtype,
+                tuple(self.size()),
+                self.stride(),
+                self.storage_offset(),
+                self.layout,
+                self.device,
+                self.requires_grad,
+            )
+            return (torch._utils._rebuild_wrapper_subclass, arg_wrapper_subclass)
+        elif (
+            type(self) is not torch.Tensor
+            and type(self).__torch_dispatch__ is not torch.Tensor.__torch_dispatch__
+            and (
+                isinstance(self, torch._subclasses.fake_tensor.FakeTensor)
+                and not (skip_data and materialize_fake_tensors)
+            )
+        ):
             arg_wrapper_subclass = (
@@ -418,6 +467,16 @@ class Tensor(torch._C.TensorBase):
                 dtype=self.dtype,
                 _internal=True,
             )  # type: ignore[assignment]
+
+            # TODO: remove hasattr, it's a hack to support versions of torch that
+            # don't have _subclasses
+            if (
+                hasattr(torch, "_subclasses")
+                and isinstance(self, torch._subclasses.fake_tensor.FakeTensor)
+                and skip_data
+            ):
+                storage._fake_device = self.device
+
             args = (
                 storage,
                 self.storage_offset(),
@@ -3,7 +3,6 @@ import copyreg
 import functools
 import logging
 import sys
-import threading
 import traceback
 import warnings
 from collections import defaultdict
@@ -109,16 +108,13 @@ def _get_async_or_non_blocking(function_name, non_blocking, kwargs):
     return kwargs["async"]
 
 
-_thread_local_state = threading.local()
-
-
 def _get_restore_location(device):
     """Return the map_location location.
 
     Used for rebuild functions where the tensor device is distinct from the storage
     """
 
-    map_location = getattr(_thread_local_state, "map_location", None)
+    map_location = torch.serialization._serialization_tls.map_location
     if map_location is None:
         return device
     else:
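From the user's side the behavior is unchanged by this refactor; `map_location` simply travels through `torch.serialization._serialization_tls` instead of a `threading.local` owned by `torch._utils`. A minimal sketch of the path that exercises it (editor's illustration, not part of the PR):

```python
import tempfile

import torch

# During this load, _load stashes map_location in _serialization_tls, and
# _get_restore_location consults it when rebuilding tensors whose storage
# device is distinct from the tensor device (wrapper subclasses, numpy-backed
# tensors on third-party devices).
with tempfile.NamedTemporaryFile() as f:
    torch.save(torch.randn(2, 3), f.name)
    t = torch.load(f.name, map_location="cpu", weights_only=True)
    print(t.device)  # cpu
```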
@@ -11,6 +11,7 @@ import struct
 import sys
 import tarfile
 import tempfile
+import threading
 import warnings
 from contextlib import closing, contextmanager
 from enum import Enum
@@ -60,6 +61,7 @@ __all__ = [
     "get_safe_globals",
     "add_safe_globals",
     "safe_globals",
+    "skip_data",
 ]
 
 
@@ -87,6 +89,22 @@ else:
     MAP_SHARED, MAP_PRIVATE = None, None  # type: ignore[assignment]
 
 
+# _serialization_tls is used to store thread local state specific to serialization
+# that needs to be propagated to other files, in particular we use this for
+# (1) map_location (needed for wrapper subclasses/third party devices to torch._utils)
+# (2) skip_data (needed for torch.Tensor.__reduce_ex__ for skip_data ctx)
+# (3) materialize_fake_tensors (needed for torch.Tensor.__reduce_ex__ for skip_data ctx)
+class _SerializationLocal(threading.local):
+    def __init__(self):
+        super().__init__()
+        self.map_location: Optional[MAP_LOCATION] = None
+        self.skip_data: bool = False
+        self.materialize_fake_tensors: bool = False
+
+
+_serialization_tls = _SerializationLocal()
+
+
 class SourceChangeWarning(Warning):
     pass
 
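Because `_SerializationLocal` subclasses `threading.local`, its `__init__` runs once per thread on first access, so each thread starts from the defaults above. A quick sketch of the isolation this buys (editor's illustration, not part of the PR):

```python
import threading

import torch
from torch.serialization import skip_data

seen = []

def worker():
    # A new thread gets a fresh _SerializationLocal, so the skip_data flag
    # set on the main thread is not visible here.
    seen.append(torch.serialization._serialization_tls.skip_data)

with skip_data():
    th = threading.Thread(target=worker)
    th.start()
    th.join()

print(seen)  # [False]: the flag did not leak across threads
```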
@@ -268,6 +286,47 @@ class safe_globals(_weights_only_unpickler._safe_globals):
     """
 
 
+class skip_data:
+    """
+    Context-manager that skips writing storage bytes for ``torch.save`` calls.
+
+    Storages will still be saved, but the space that their bytes would usually be written to
+    will be empty space. The storage bytes can then be populated in a separate pass.
+
+    .. warning::
+        The ``skip_data`` context manager is an early prototype and is subject to change.
+
+    Args:
+        materialize_fake_tensors: Whether to materialize FakeTensors.
+
+    Example:
+        >>> # xdoctest: +SKIP("NamedTemporaryFile on Windows")
+        >>> import tempfile
+        >>> t = torch.randn(2, 3)
+        >>> with tempfile.NamedTemporaryFile() as f:
+        ...     with torch.serialization.skip_data():
+        ...         torch.save(t, f.name)
+        ...     torch.load(f.name, weights_only=True)
+        tensor([[0., 0., 0.],
+                [0., 0., 0.]])
+    """
+
+    def __init__(self, materialize_fake_tensors: bool = False):
+        self.materialize_fake_tensors = materialize_fake_tensors
+
+    def __enter__(self):
+        global _serialization_tls
+        self._old_skip_data = _serialization_tls.skip_data
+        self._old_materialize_fake_tensors = _serialization_tls.materialize_fake_tensors
+        _serialization_tls.skip_data = True
+        _serialization_tls.materialize_fake_tensors = self.materialize_fake_tensors
+
+    def __exit__(self, type, value, tb):
+        global _serialization_tls
+        _serialization_tls.skip_data = self._old_skip_data
+        _serialization_tls.materialize_fake_tensors = self._old_materialize_fake_tensors
+
+
 def _is_zipfile(f) -> bool:
     # This is a stricter implementation than zipfile.is_zipfile().
     # zipfile.is_zipfile() is True if the magic number appears anywhere in the
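Since `__enter__` stashes the previous thread-local values and `__exit__` restores them, the flags are reset even if the body raises, and nested `skip_data` contexts unwind correctly. A small sketch, relying only on the `_serialization_tls` attribute the tests above also inspect:

```python
import torch
from torch.serialization import skip_data

tls = torch.serialization._serialization_tls

try:
    with skip_data(materialize_fake_tensors=True):
        assert tls.skip_data and tls.materialize_fake_tensors
        raise ValueError("simulated failure mid-save")
except ValueError:
    pass

# __exit__ ran despite the exception, so both thread-local flags are back
# to their previous values and later torch.save calls behave normally.
assert not tls.skip_data and not tls.materialize_fake_tensors
```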
@@ -797,6 +856,11 @@ def save(
             )
             return
     else:
+        global _serialization_tls
+        if _serialization_tls.skip_data:
+            raise RuntimeError(
+                "Cannot use skip_data=True with _use_new_zipfile_serialization=False"
+            )
         with _open_file_like(f, "wb") as opened_file:
             _legacy_save(obj, opened_file, pickle_module, pickle_protocol)
 
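The legacy (non-zipfile) format has no notion of reserving space for storage bytes, hence the guard above. What a caller sees (editor's illustration, grounded in the diff):

```python
import io

import torch

with torch.serialization.skip_data():
    try:
        torch.save(
            torch.randn(2, 3),
            io.BytesIO(),
            _use_new_zipfile_serialization=False,
        )
    except RuntimeError as e:
        # Cannot use skip_data=True with _use_new_zipfile_serialization=False
        print(e)
```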
@@ -955,7 +1019,13 @@ def _legacy_save(obj, f, pickle_module, pickle_protocol) -> None:
     )
 
 
-def _save(obj, zip_file, pickle_module, pickle_protocol, _disable_byteorder_record):
+def _save(
+    obj,
+    zip_file,
+    pickle_module,
+    pickle_protocol,
+    _disable_byteorder_record,
+):
     serialized_storages = {}
     id_map: Dict[int, str] = {}
 
@@ -990,7 +1060,7 @@ def _save(
             # If storage is allocated, ensure that any other saved storages
             # pointing to the same data all have the same dtype. If storage is
             # not allocated, don't perform this check
-            if storage.data_ptr() != 0:
+            if str(storage.device) != "meta" and storage.data_ptr() != 0:
                 if storage.data_ptr() in storage_dtypes:
                     if storage_dtype != storage_dtypes[storage.data_ptr()]:
                         raise RuntimeError(
@@ -1001,7 +1071,10 @@ def _save(
                     storage_dtypes[storage.data_ptr()] = storage_dtype
 
             storage_key = id_map.setdefault(storage._cdata, str(len(id_map)))
-            location = location_tag(storage)
+            if hasattr(obj, "_fake_device") and obj._fake_device is not None:
+                location = str(obj._fake_device)
+            else:
+                location = location_tag(storage)
             serialized_storages[storage_key] = storage
 
             return ("storage", storage_type, storage_key, location, storage_numel)
@@ -1027,14 +1100,18 @@ def _save(
     for key in sorted(serialized_storages.keys()):
         name = f"data/{key}"
         storage = serialized_storages[key]
-        # given that we copy things around anyway, we might use storage.cpu()
-        # this means to that to get tensors serialized, you need to implement
-        # .cpu() on the underlying Storage
-        if storage.device.type != "cpu":
-            storage = storage.cpu()
-        # Now that it is on the CPU we can directly copy it into the zip file
         num_bytes = storage.nbytes()
-        zip_file.write_record(name, storage, num_bytes)
+        global _serialization_tls
+        if _serialization_tls.skip_data:
+            zip_file.write_record_metadata(name, num_bytes)
+        else:
+            # given that we copy things around anyway, we might use storage.cpu()
+            # this means to that to get tensors serialized, you need to implement
+            # .cpu() on the underlying Storage
+            if storage.device.type != "cpu":
+                storage = storage.cpu()
+            # Now that it is on the CPU we can directly copy it into the zip file
+            zip_file.write_record(name, storage, num_bytes)
 
 
 def load(
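Under `skip_data`, `write_record_metadata` emits each `data/<key>` record at its full declared size without writing the actual tensor bytes. A sketch of inspecting the result (editor's illustration; assumes the checkpoint remains a standard zip readable by Python's `zipfile`):

```python
import tempfile
import zipfile

import torch
import torch.nn as nn

sd = nn.Linear(3, 5).state_dict()

with tempfile.NamedTemporaryFile(suffix=".pt") as f:
    with torch.serialization.skip_data():
        torch.save(sd, f.name)
    with zipfile.ZipFile(f.name) as zf:
        for info in zf.infolist():
            # Each data/<key> record is present at its full storage size
            # even though no real tensor bytes were written for it.
            print(info.filename, info.file_size)
```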
@@ -1184,6 +1261,14 @@ def load(
         updated_message += message
         return updated_message + DOCS_MESSAGE
 
+    global _serialization_tls
+    skip_data = _serialization_tls.skip_data
+    if skip_data:
+        raise RuntimeError(
+            "`torch.load` called within a torch.serialization.skip_data context manager "
+            "is not supported yet. Please call torch.load outside the skip_data context manager."
+        )
+
     if weights_only is None:
         weights_only, warn_weights_only = False, True
     else:
@@ -1758,9 +1843,10 @@ def _load(
     unpickler.persistent_load = persistent_load
     # Needed for tensors where storage device and rebuild tensor device are
     # not connected (wrapper subclasses and tensors rebuilt using numpy)
-    torch._utils._thread_local_state.map_location = map_location
+    global _serialization_tls
+    _serialization_tls.map_location = map_location
     result = unpickler.load()
-    del torch._utils._thread_local_state.map_location
+    _serialization_tls.map_location = None
 
     torch._utils._validate_loaded_sparse_tensors()
     torch._C._log_api_usage_metadata(
@@ -39,6 +39,8 @@ class _StorageBase:
     is_sparse: _bool = False
     is_sparse_csr: _bool = False
     device: torch.device
+    # Used when stashing FakeTensor device onto storage in torch.save(metadata_only=True)
+    _fake_device: _Optional[torch.device] = None
 
     def __init__(self, *args, **kwargs):
         pass
@@ -649,6 +651,8 @@ def _get_device_from_module(module: str):
 
 class TypedStorage:
     is_sparse: _bool = False
+    # Used when stashing FakeTensor device onto storage in torch.save(metadata_only=True)
+    _fake_device: _Optional[torch.device] = None
 
     dtype: torch.dtype
 