Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 21:14:14 +08:00
Use __qualname__ in add_safe_globals and update Unpickling error raised for Unsupported GLOBAL (#146815)
- Fixes #146814

  Change

  ```python
  for f in _marked_safe_globals_set:
      module, name = f.__module__, f.__name__
  ```

  to

  ```python
  for f in _marked_safe_globals_set:
      module, name = f.__module__, f.__qualname__
  ```

  so that two globals sharing the same ``__name__`` (for example, identically named nested classes) no longer overwrite each other's key string. A test is also added:

  ```
  python test/test_serialization.py TestSerialization.test_serialization_nested_class
  ```

- Fixes #146886

Pull Request resolved: https://github.com/pytorch/pytorch/pull/146815
Approved by: https://github.com/mikaylagawarecki
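As a quick illustration (not part of the diff), here is a minimal sketch of why ``__qualname__`` matters here: two nested classes can share the same ``__name__``, so keying an allowlist on it collapses them into a single entry. The mock classes mirror the ones added to test/test_serialization.py below.

```python
# Two distinct nested classes whose __name__ is identical ("Nested").
class ClassAMock:
    class Nested:
        pass


class ClassBMock:
    class Nested:
        pass


# Keyed by __name__, the second entry silently overwrites the first.
by_name = {f"{c.__module__}.{c.__name__}": c
           for c in (ClassAMock.Nested, ClassBMock.Nested)}
assert len(by_name) == 1

# Keyed by __qualname__, the keys stay distinct:
# "...ClassAMock.Nested" and "...ClassBMock.Nested".
by_qualname = {f"{c.__module__}.{c.__qualname__}": c
               for c in (ClassAMock.Nested, ClassBMock.Nested)}
assert len(by_qualname) == 2
```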
committed by PyTorch MergeBot
parent f4e4cfcb91
commit 2190ca7f47
test/test_serialization.py
@@ -117,6 +117,13 @@ class FilelikeMock:
     def was_called(self, name):
         return name in self.calls
 
+class ClassAMock:
+    class Nested:
+        pass
+
+class ClassBMock:
+    class Nested:
+        pass
 
 def up_size(size):
     return (*size[:-1], size[-1] * 2)
@@ -1156,6 +1163,21 @@ class TestSerialization(TestCase, SerializationMixin):
         _test_save_load_attr(t)
         _test_save_load_attr(torch.nn.Parameter(t))
 
+    def test_serialization_nested_class(self) -> None:
+        with tempfile.NamedTemporaryFile() as checkpoint:
+            torch.save(
+                dict(
+                    a_nested=ClassAMock.Nested(),
+                    b_nested=ClassBMock.Nested(),
+                ),
+                checkpoint
+            )
+            checkpoint.seek(0)
+            torch.serialization.add_safe_globals(
+                [ClassAMock, ClassBMock, getattr, ClassAMock.Nested, ClassBMock.Nested]
+            )
+            torch.load(checkpoint, weights_only=True)
+
     def test_weights_only_assert(self):
         class HelloWorld:
             def __reduce__(self):
@@ -1168,7 +1190,7 @@ class TestSerialization(TestCase, SerializationMixin):
             self.assertIsNone(torch.load(f, weights_only=False))
             f.seek(0)
             # Safe load should assert
-            with self.assertRaisesRegex(pickle.UnpicklingError, "Unsupported global: GLOBAL builtins.print"):
+            with self.assertRaisesRegex(pickle.UnpicklingError, "Unsupported global: GLOBAL print"):
                 torch.load(f, weights_only=True)
             with torch.serialization.safe_globals([print]):
                 f.seek(0)
@@ -1260,8 +1282,12 @@ class TestSerialization(TestCase, SerializationMixin):
             torch.save(sd, f, pickle_protocol=pickle_protocol)
             f.seek(0)
             if unsafe_global:
-                with self.assertRaisesRegex(pickle.UnpicklingError,
-                                            r"use `torch.serialization.add_safe_globals\(\[TwoTensor\]\)` or .* to allowlist"):
+                with self.assertRaisesRegex(
+                    pickle.UnpicklingError,
+                    "use `torch.serialization.add_safe_globals"
+                    r"\(\[torch.testing._internal.two_tensor.TwoTensor\]\)`"
+                    " or .* to allowlist"
+                ):
                     torch.load(f, weights_only=True)
             else:
                 with self.assertRaisesRegex(pickle.UnpicklingError,
torch/_weights_only_unpickler.py
@@ -141,7 +141,7 @@ def _get_user_allowed_globals():
             f, name = f
             rc[name] = f
         else:
-            module, name = f.__module__, f.__name__
+            module, name = f.__module__, f.__qualname__
             rc[f"{module}.{name}"] = f
     return rc
 
@@ -361,10 +361,21 @@ class Unpickler:
                         "``torch.distributed.tensor`` must be imported to load DTensors"
                     )
                 else:
+                    builtins_name = "builtins"
+                    if (
+                        builtins_name in full_path
+                        and builtins_name == full_path[: len(builtins_name)]
+                    ):
+                        full_path = full_path[len(builtins_name) :]
+                        full_path = (
+                            full_path[1:]
+                            if len(full_path) > 0 and full_path[0] == "."
+                            else builtins_name + full_path
+                        )
                     raise UnpicklingError(
                         f"Unsupported global: GLOBAL {full_path} was not an allowed global by default. "
-                        f"Please use `torch.serialization.add_safe_globals([{name}])` or the "
-                        f"`torch.serialization.safe_globals([{name}])` context manager to allowlist this global "
+                        f"Please use `torch.serialization.add_safe_globals([{full_path}])` or the "
+                        f"`torch.serialization.safe_globals([{full_path}])` context manager to allowlist this global "
                         "if you trust this class/function."
                     )
             elif key[0] == NEWOBJ[0]:
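A standalone sketch of the prefix handling added in the hunk above, which lets the error message report e.g. ``builtins.print`` as ``print`` while leaving other paths untouched. The helper name `strip_builtins_prefix` and the example paths are illustrative, not part of the diff.

```python
def strip_builtins_prefix(full_path: str) -> str:
    # Hypothetical helper mirroring the logic in the hunk above.
    builtins_name = "builtins"
    if (
        builtins_name in full_path
        and builtins_name == full_path[: len(builtins_name)]
    ):
        full_path = full_path[len(builtins_name) :]
        full_path = (
            full_path[1:]
            if len(full_path) > 0 and full_path[0] == "."
            else builtins_name + full_path
        )
    return full_path


# A genuine builtins member loses the "builtins." prefix.
assert strip_builtins_prefix("builtins.print") == "print"
# A path that merely starts with the letters "builtins" is put back
# together unchanged.
assert strip_builtins_prefix("builtins_ext.foo") == "builtins_ext.foo"
# Paths from other modules pass through untouched.
assert strip_builtins_prefix("torch.nn.Linear") == "torch.nn.Linear"
```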
torch/serialization.py
@@ -274,9 +274,9 @@ def add_safe_globals(safe_globals: list[Union[Callable, tuple[Callable, str]]])
         (function/class, string) where string is the full path of the function/class.
 
     Within the serialized format, each function is identified with its full
-    path as ``{__module__}.{__name__}``. When calling this API, you can provide this
+    path as ``{__module__}.{__qualname__}``. When calling this API, you can provide this
     full path that should match the one in the checkpoint otherwise the default
-    ``{fn.__module__}.{fn.__name__}`` will be used.
+    ``{fn.__module__}.{fn.__qualname__}`` will be used.
 
     Args:
         safe_globals (List[Union[Callable, Tuple[Callable, str]]]): list of globals to mark as safe
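For reference, a small usage sketch based on the docstring above. The `Outer.Nested` class, the full-path string, and the checkpoint file name are hypothetical, chosen only to show the two accepted argument forms.

```python
import torch


class Outer:
    class Nested:
        pass


# Plain callables are keyed by the default {__module__}.{__qualname__},
# e.g. "__main__.Outer.Nested" for the class above.
torch.serialization.add_safe_globals([Outer.Nested])

# The tuple form supplies the exact full path recorded in the checkpoint
# when it differs from the default (the path shown is hypothetical).
torch.serialization.add_safe_globals(
    [(Outer.Nested, "some_training_module.Outer.Nested")]
)

# With the globals allowlisted, weights_only loading can resolve them:
# obj = torch.load("checkpoint.pt", weights_only=True)  # hypothetical file
```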