Add utility to get all unsafe globals in checkpoint (no pickletools dependency) (#139221)

Fixes https://github.com/pytorch/pytorch/issues/129698

This is the same change as https://github.com/pytorch/pytorch/pull/139106, but implemented without the pickletools dependency.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/139221
Approved by: https://github.com/malfet
ghstack dependencies: #138936
This commit is contained in:
Mikayla Gawarecki
2024-11-01 09:50:25 -07:00
committed by PyTorch MergeBot
parent f3b485eb2a
commit ea0e09b3f3
4 changed files with 144 additions and 10 deletions

View File

@ -4394,6 +4394,33 @@ class TestSerialization(TestCase, SerializationMixin):
with zipfile.ZipFile(f) as zip_file:
zip_file.extractall(path=temp_dir)
def test_get_unsafe_globals_in_checkpoint(self):
    """Check the globals reported by ``get_unsafe_globals_in_checkpoint``.

    A ``TwoTensor`` is saved to an in-memory buffer and inspected twice:

    1. With the unmodified allowlist, only the ``TwoTensor`` class itself is
       expected to be reported.
    2. With ``_get_allowed_globals`` temporarily replaced by a function that
       returns an empty dict, every global listed in
       ``expected_all_global_strs`` is expected to be reported.
    """
    t = torch.randn(2, 3)
    tt = TwoTensor(t, t)
    expected_unsafe_global_strs = {"torch.testing._internal.two_tensor.TwoTensor"}
    expected_all_global_strs = {
        "torch.testing._internal.two_tensor.TwoTensor",
        "torch._utils._rebuild_wrapper_subclass",
        "torch._tensor._rebuild_from_type_v2",
        "torch.serialization._get_layout",
        "torch.float32",
        "torch.device",
        "torch._utils._rebuild_tensor_v2",
        "torch.FloatStorage",
        "collections.OrderedDict",
    }
    with BytesIOContext() as f:
        torch.save(tt, f)

        # First pass: unmodified allowlist.
        f.seek(0)
        unsafe_globals = torch.serialization.get_unsafe_globals_in_checkpoint(f)
        self.assertEqual(set(unsafe_globals), expected_unsafe_global_strs)

        # Second pass: empty allowlist. Capture the original *before* the
        # ``try`` so the ``finally`` clause always has a value to restore and
        # cannot itself fail with an unbound local, masking the real error.
        f.seek(0)
        old_get_allowed_globals = torch._weights_only_unpickler._get_allowed_globals
        try:
            torch._weights_only_unpickler._get_allowed_globals = lambda: dict()  # noqa: PIE807
            unsafe_all_globals = torch.serialization.get_unsafe_globals_in_checkpoint(f)
            self.assertEqual(set(unsafe_all_globals), expected_all_global_strs)
        finally:
            # Always undo the monkeypatch, even when the assertion fails.
            torch._weights_only_unpickler._get_allowed_globals = old_get_allowed_globals
def run(self, *args, **kwargs):
    """Delegate to the parent ``run`` inside ``serialization_method(use_zip=True)``.

    All positional and keyword arguments are forwarded unchanged, and the
    parent's return value is returned as-is.
    """
    zip_serialization = serialization_method(use_zip=True)
    with zip_serialization:
        return super().run(*args, **kwargs)