Add example for torch.serialization.add_safe_globals (#129396)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/129396
Approved by: https://github.com/albanD, https://github.com/malfet
ghstack dependencies: #129239
This commit is contained in:
Mikayla Gawarecki
2024-06-25 19:56:05 -07:00
committed by PyTorch MergeBot
parent 303ad8d7f5
commit 3b531eace7

View File

@ -209,6 +209,22 @@ def add_safe_globals(safe_globals: List[Any]) -> None:
Args:
safe_globals (List[Any]): list of globals to mark as safe
Example:
>>> # xdoctest: +SKIP("Can't torch.save(t, ...) as doctest thinks MyTensor is defined on torch.serialization")
>>> import tempfile
>>> class MyTensor(torch.Tensor):
... pass
>>> t = MyTensor(torch.randn(2, 3))
>>> with tempfile.NamedTemporaryFile() as f:
... torch.save(t, f.name)
# Running `torch.load(f.name, weights_only=True)` will fail with
# Unsupported global: GLOBAL __main__.MyTensor was not an allowed global by default.
# Check the code and make sure MyTensor is safe to be used when loaded from an arbitrary checkpoint.
... torch.serialization.add_safe_globals([MyTensor])
... torch.load(f.name, weights_only=True)
# MyTensor([[-0.5024, -1.8152, -0.5455],
# [-0.8234, 2.0500, -0.3657]])
"""
_weights_only_unpickler._add_safe_globals(safe_globals)