Add torch.serialization.safe_globals context manager (#127939)

Add context manager mentioned in https://github.com/pytorch/pytorch/pull/127808#pullrequestreview-2096298486

Pull Request resolved: https://github.com/pytorch/pytorch/pull/127939
Approved by: https://github.com/albanD
This commit is contained in:
Mikayla Gawarecki
2024-07-12 09:16:10 -07:00
committed by PyTorch MergeBot
parent f0d7164cb9
commit 7c289c2a5c
4 changed files with 69 additions and 0 deletions

View File

@ -397,3 +397,4 @@ The following utility functions are related to serialization:
.. autofunction:: add_safe_globals
.. autofunction:: clear_safe_globals
.. autofunction:: get_safe_globals
.. autoclass:: safe_globals

View File

@ -4329,6 +4329,29 @@ class TestSubclassSerialization(TestCase):
finally:
torch.serialization.clear_safe_globals()
def test_safe_globals_context_manager_weights_only(self):
    '''
    Tests torch.serialization.safe_globals context manager
    '''
    # Build a small state dict containing a tensor subclass and a Parameter
    # wrapping it, so weights_only load must allowlist TwoTensor.
    tensor = TwoTensor(torch.randn(2, 3), torch.randn(2, 3))
    param = torch.nn.Parameter(tensor)
    state_dict = OrderedDict([('t', tensor), ('p', param)])
    try:
        # Pre-register an unrelated safe global so we can verify the
        # context manager leaves previously registered globals intact.
        torch.serialization.add_safe_globals([TestEmptySubclass])
        with tempfile.NamedTemporaryFile() as checkpoint:
            torch.save(state_dict, checkpoint)
            # Inside the context, TwoTensor is temporarily allowed.
            with torch.serialization.safe_globals([TwoTensor]):
                checkpoint.seek(0)
                torch.load(checkpoint, weights_only=True)
            # On exit, only the pre-registered global must remain.
            remaining = torch.serialization.get_safe_globals()
            self.assertTrue(remaining == [TestEmptySubclass])
            checkpoint.seek(0)
            # Outside the context, loading TwoTensor must fail again.
            with self.assertRaisesRegex(pickle.UnpicklingError,
                                        "Unsupported global: GLOBAL torch.testing._internal.two_tensor.TwoTensor"):
                torch.load(checkpoint, weights_only=True)
    finally:
        torch.serialization.clear_safe_globals()
@unittest.skipIf(not torch.cuda.is_available(), "map_location loads to cuda")
def test_tensor_subclass_map_location(self):
t = TwoTensor(torch.randn(2, 3), torch.randn(2, 3))

View File

@ -90,6 +90,24 @@ def _clear_safe_globals():
_marked_safe_globals_list = []
def _remove_safe_globals(globals_to_remove: List[Any]):
    """Unregister the given globals from the weights_only safe list.

    Note: like the original, this goes through a ``set``, so the surviving
    entries come back in arbitrary order and duplicates are collapsed.
    """
    global _marked_safe_globals_list
    survivors = set(_marked_safe_globals_list).difference(globals_to_remove)
    _marked_safe_globals_list = list(survivors)
class _safe_globals:
    """Context manager that temporarily registers globals as safe for
    ``weights_only`` load, and unregisters them again on exit.

    Args:
        safe_globals: globals to allow for the duration of the ``with`` block.
    """

    def __init__(self, safe_globals: List[Any]):
        # Kept as a public attribute so __exit__ removes exactly what
        # __enter__ added.
        self.safe_globals = safe_globals

    def __enter__(self):
        _add_safe_globals(self.safe_globals)

    def __exit__(self, exc_type, exc_value, exc_tb):
        # Always unregister, even if the body raised.
        _remove_safe_globals(self.safe_globals)
# Separate from _get_allowed_globals because of the lru_cache on _get_allowed_globals
# For example if user had a script like
# torch.load(file_a)

View File

@ -57,6 +57,7 @@ __all__ = [
"clear_safe_globals",
"get_safe_globals",
"add_safe_globals",
"safe_globals",
]
@ -230,6 +231,32 @@ def add_safe_globals(safe_globals: List[Any]) -> None:
_weights_only_unpickler._add_safe_globals(safe_globals)
class safe_globals(_weights_only_unpickler._safe_globals):
    r"""Context-manager that adds certain globals as safe for ``weights_only`` load.

    While the ``with`` block is active, the given globals are allowed by the
    ``weights_only`` unpickler; they are removed again on exit.

    Args:
        safe_globals: List of globals for weights_only load.

    Example:
        >>> # xdoctest: +SKIP("Can't torch.save(t, ...) as doctest thinks MyTensor is defined on torch.serialization")
        >>> import tempfile
        >>> class MyTensor(torch.Tensor):
        ...     pass
        >>> t = MyTensor(torch.randn(2, 3))
        >>> with tempfile.NamedTemporaryFile() as f:
        ...     torch.save(t, f.name)
        # Running `torch.load(f.name, weights_only=True)` will fail with
        # Unsupported global: GLOBAL __main__.MyTensor was not an allowed global by default.
        # Check the code and make sure MyTensor is safe to be used when loaded from an arbitrary checkpoint.
        ...     with torch.serialization.safe_globals([MyTensor]):
        ...         torch.load(f.name, weights_only=True)
        # MyTensor([[-0.5024, -1.8152, -0.5455],
        #           [-0.8234,  2.0500, -0.3657]])
        >>> assert torch.serialization.get_safe_globals() == []
    """
    # All behavior is inherited from _weights_only_unpickler._safe_globals;
    # this public subclass exists only to carry the documented API name.
    pass
def _is_zipfile(f) -> bool:
# This is a stricter implementation than zipfile.is_zipfile().
# zipfile.is_zipfile() is True if the magic number appears anywhere in the