Mirror of https://github.com/pytorch/pytorch.git, last synced 2025-10-20 12:54:11 +08:00.
Add torch.serialization.safe_globals context manager (#127939)
Add context manager mentioned in https://github.com/pytorch/pytorch/pull/127808#pullrequestreview-2096298486 Pull Request resolved: https://github.com/pytorch/pytorch/pull/127939 Approved by: https://github.com/albanD
This commit is contained in:
committed by
PyTorch MergeBot
parent
f0d7164cb9
commit
7c289c2a5c
@ -397,3 +397,4 @@ The following utility functions are related to serialization:
|
||||
.. autofunction:: add_safe_globals
|
||||
.. autofunction:: clear_safe_globals
|
||||
.. autofunction:: get_safe_globals
|
||||
.. autoclass:: safe_globals
|
||||
|
@ -4329,6 +4329,29 @@ class TestSubclassSerialization(TestCase):
|
||||
finally:
|
||||
torch.serialization.clear_safe_globals()
|
||||
|
||||
def test_safe_globals_context_manager_weights_only(self):
    '''
    Tests torch.serialization.safe_globals context manager
    '''
    # TwoTensor is a tensor subclass, so a weights_only=True load of it
    # succeeds only while it is on the safe-globals allowlist.
    t = TwoTensor(torch.randn(2, 3), torch.randn(2, 3))
    p = torch.nn.Parameter(t)
    sd = OrderedDict([('t', t), ('p', p)])

    try:
        # Register an unrelated global first so we can check below that the
        # context manager leaves previously registered entries in place.
        torch.serialization.add_safe_globals([TestEmptySubclass])
        with tempfile.NamedTemporaryFile() as f:
            torch.save(sd, f)
            # Inside the context TwoTensor is allowlisted, so the load succeeds.
            with torch.serialization.safe_globals([TwoTensor]):
                f.seek(0)
                torch.load(f, weights_only=True)
            # After exiting, only the entry registered before the context
            # should remain on the allowlist.
            self.assertTrue(torch.serialization.get_safe_globals() == [TestEmptySubclass])
            f.seek(0)
            # TwoTensor is no longer allowlisted, so the same load now fails.
            with self.assertRaisesRegex(pickle.UnpicklingError,
                                        "Unsupported global: GLOBAL torch.testing._internal.two_tensor.TwoTensor"):
                torch.load(f, weights_only=True)
    finally:
        # Reset process-wide serialization state so other tests are unaffected.
        torch.serialization.clear_safe_globals()
|
||||
|
||||
@unittest.skipIf(not torch.cuda.is_available(), "map_location loads to cuda")
|
||||
def test_tensor_subclass_map_location(self):
|
||||
t = TwoTensor(torch.randn(2, 3), torch.randn(2, 3))
|
||||
|
@ -90,6 +90,24 @@ def _clear_safe_globals():
|
||||
_marked_safe_globals_list = []
|
||||
|
||||
|
||||
def _remove_safe_globals(globals_to_remove: List[Any]):
|
||||
global _marked_safe_globals_list
|
||||
_marked_safe_globals_list = list(
|
||||
set(_marked_safe_globals_list) - set(globals_to_remove)
|
||||
)
|
||||
|
||||
|
||||
class _safe_globals:
    """Context manager that temporarily marks ``safe_globals`` as safe for
    ``weights_only`` load and restores the previous allowlist on exit.

    Args:
        safe_globals: globals to allow for the duration of the context.
    """

    def __init__(self, safe_globals: List[Any]):
        self.safe_globals = safe_globals

    def __enter__(self):
        # Snapshot the current allowlist so __exit__ can restore it exactly.
        # Merely removing ``self.safe_globals`` on exit would also deregister
        # entries that were already allowlisted before entering the context.
        self._previous_allowlist = list(_marked_safe_globals_list)
        _add_safe_globals(self.safe_globals)
        return self

    def __exit__(self, type, value, tb):
        # Restore the snapshot; plain reassignment is the same mutation
        # strategy _remove_safe_globals uses on this module-level list.
        global _marked_safe_globals_list
        _marked_safe_globals_list = self._previous_allowlist
        # Returning None (falsy) lets any in-context exception propagate.
|
||||
|
||||
|
||||
# Separate from _get_allowed_globals because of the lru_cache on _get_allowed_globals
|
||||
# For example if user had a script like
|
||||
# torch.load(file_a)
|
||||
|
@ -57,6 +57,7 @@ __all__ = [
|
||||
"clear_safe_globals",
|
||||
"get_safe_globals",
|
||||
"add_safe_globals",
|
||||
"safe_globals",
|
||||
]
|
||||
|
||||
|
||||
@ -230,6 +231,32 @@ def add_safe_globals(safe_globals: List[Any]) -> None:
|
||||
_weights_only_unpickler._add_safe_globals(safe_globals)
|
||||
|
||||
|
||||
class safe_globals(_weights_only_unpickler._safe_globals):
    r"""Context-manager that adds certain globals as safe for ``weights_only`` load.

    Args:
        safe_globals: List of globals for weights_only load.

    Example:
        >>> # xdoctest: +SKIP("Can't torch.save(t, ...) as doctest thinks MyTensor is defined on torch.serialization")
        >>> import tempfile
        >>> class MyTensor(torch.Tensor):
        ...     pass
        >>> t = MyTensor(torch.randn(2, 3))
        >>> with tempfile.NamedTemporaryFile() as f:
        ...     torch.save(t, f.name)
        # Running `torch.load(f.name, weights_only=True)` will fail with
        # Unsupported global: GLOBAL __main__.MyTensor was not an allowed global by default.
        # Check the code and make sure MyTensor is safe to be used when loaded from an arbitrary checkpoint.
        ...     with torch.serialization.safe_globals([MyTensor]):
        ...         torch.load(f.name, weights_only=True)
        # MyTensor([[-0.5024, -1.8152, -0.5455],
        # [-0.8234, 2.0500, -0.3657]])
        >>> assert torch.serialization.get_safe_globals() == []
    """
    # All behavior is inherited from _weights_only_unpickler._safe_globals;
    # this subclass exists only to expose it publicly with documentation.
    # (No `pass` needed: the docstring alone is a valid class body.)
|
||||
|
||||
|
||||
def _is_zipfile(f) -> bool:
|
||||
# This is a stricter implementation than zipfile.is_zipfile().
|
||||
# zipfile.is_zipfile() is True if the magic number appears anywhere in the
|
||||
|
Reference in New Issue
Block a user