mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
Add example for torch.serialization.add_safe_globals (#129396)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/129396 Approved by: https://github.com/albanD ghstack dependencies: #129244, #129251, #129239
This commit is contained in:
committed by
PyTorch MergeBot
parent
381ce0821c
commit
f18becaaf1
@ -209,6 +209,22 @@ def add_safe_globals(safe_globals: List[Any]) -> None:
|
||||
|
||||
Args:
|
||||
safe_globals (List[Any]): list of globals to mark as safe
|
||||
|
||||
Example:
|
||||
>>> # xdoctest: +SKIP("Can't torch.save(t, ...) as doctest thinks MyTensor is defined on torch.serialization")
|
||||
>>> import tempfile
|
||||
>>> class MyTensor(torch.Tensor):
|
||||
... pass
|
||||
>>> t = MyTensor(torch.randn(2, 3))
|
||||
>>> with tempfile.NamedTemporaryFile() as f:
|
||||
... torch.save(t, f.name)
|
||||
# Running `torch.load(f.name, weights_only=True)` will fail with
|
||||
# Unsupported global: GLOBAL __main__.MyTensor was not an allowed global by default.
|
||||
# Check the code and make sure MyTensor is safe to be used when loaded from an arbitrary checkpoint.
|
||||
... torch.serialization.add_safe_globals([MyTensor])
|
||||
... torch.load(f.name, weights_only=True)
|
||||
# MyTensor([[-0.5024, -1.8152, -0.5455],
|
||||
# [-0.8234, 2.0500, -0.3657]])
|
||||
"""
|
||||
_weights_only_unpickler._add_safe_globals(safe_globals)
|
||||
|
||||
|
Reference in New Issue
Block a user