mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Add utility to get all unsafe globals in checkpoint (no pickletools dependency) (#139221)
Fixes https://github.com/pytorch/pytorch/issues/129698 https://github.com/pytorch/pytorch/pull/139106 without pickletools Pull Request resolved: https://github.com/pytorch/pytorch/pull/139221 Approved by: https://github.com/malfet ghstack dependencies: #138936
This commit is contained in:
committed by
PyTorch MergeBot
parent
f3b485eb2a
commit
ea0e09b3f3
@ -4394,6 +4394,33 @@ class TestSerialization(TestCase, SerializationMixin):
|
||||
with zipfile.ZipFile(f) as zip_file:
|
||||
zip_file.extractall(path=temp_dir)
|
||||
|
||||
def test_get_unsafe_globals_in_checkpoint(self):
|
||||
t = torch.randn(2, 3)
|
||||
tt = TwoTensor(t, t)
|
||||
expected_unsafe_global_strs = {"torch.testing._internal.two_tensor.TwoTensor"}
|
||||
expected_all_global_strs = {"torch.testing._internal.two_tensor.TwoTensor",
|
||||
"torch._utils._rebuild_wrapper_subclass",
|
||||
"torch._tensor._rebuild_from_type_v2",
|
||||
"torch.serialization._get_layout",
|
||||
"torch.float32",
|
||||
"torch.device",
|
||||
"torch._utils._rebuild_tensor_v2",
|
||||
"torch.FloatStorage",
|
||||
"collections.OrderedDict"}
|
||||
with BytesIOContext() as f:
|
||||
torch.save(tt, f)
|
||||
f.seek(0)
|
||||
unsafe_globals = torch.serialization.get_unsafe_globals_in_checkpoint(f)
|
||||
self.assertEqual(set(unsafe_globals), expected_unsafe_global_strs)
|
||||
f.seek(0)
|
||||
try:
|
||||
old_get_allowed_globals = torch._weights_only_unpickler._get_allowed_globals
|
||||
torch._weights_only_unpickler._get_allowed_globals = lambda: dict() # noqa: PIE807
|
||||
unsafe_all_globals = torch.serialization.get_unsafe_globals_in_checkpoint(f)
|
||||
self.assertEqual(set(unsafe_all_globals), expected_all_global_strs)
|
||||
finally:
|
||||
torch._weights_only_unpickler._get_allowed_globals = old_get_allowed_globals
|
||||
|
||||
def run(self, *args, **kwargs):
|
||||
with serialization_method(use_zip=True):
|
||||
return super().run(*args, **kwargs)
|
||||
|
Reference in New Issue
Block a user