Use __qualname__ in add_safe_globals and update Unpickling error raised for Unsupported GLOBAL (#146815)

- Fixes #146814

Change
```python
for f in _marked_safe_globals_set:
    module, name = f.__module__, f.__name__
```
to

```python
for f in _marked_safe_globals_set:
    module, name = f.__module__, f.__qualname__
```
for avoiding same key string overwrite.

A test is also added.
```
python test/test_serialization.py TestSerialization.test_serialization_nested_class
```

- Fixes #146886
Pull Request resolved: https://github.com/pytorch/pytorch/pull/146815
Approved by: https://github.com/mikaylagawarecki
This commit is contained in:
Hanson-HSChang
2025-02-21 18:04:59 +00:00
committed by PyTorch MergeBot
parent f4e4cfcb91
commit 2190ca7f47
3 changed files with 45 additions and 8 deletions

View File

@ -117,6 +117,13 @@ class FilelikeMock:
def was_called(self, name):
return name in self.calls
class ClassAMock:
class Nested:
pass
class ClassBMock:
class Nested:
pass
def up_size(size):
return (*size[:-1], size[-1] * 2)
@ -1156,6 +1163,21 @@ class TestSerialization(TestCase, SerializationMixin):
_test_save_load_attr(t)
_test_save_load_attr(torch.nn.Parameter(t))
def test_serialization_nested_class(self) -> None:
with tempfile.NamedTemporaryFile() as checkpoint:
torch.save(
dict(
a_nested=ClassAMock.Nested(),
b_nested=ClassBMock.Nested(),
),
checkpoint
)
checkpoint.seek(0)
torch.serialization.add_safe_globals(
[ClassAMock, ClassBMock, getattr, ClassAMock.Nested, ClassBMock.Nested]
)
torch.load(checkpoint, weights_only=True)
def test_weights_only_assert(self):
class HelloWorld:
def __reduce__(self):
@ -1168,7 +1190,7 @@ class TestSerialization(TestCase, SerializationMixin):
self.assertIsNone(torch.load(f, weights_only=False))
f.seek(0)
# Safe load should assert
with self.assertRaisesRegex(pickle.UnpicklingError, "Unsupported global: GLOBAL builtins.print"):
with self.assertRaisesRegex(pickle.UnpicklingError, "Unsupported global: GLOBAL print"):
torch.load(f, weights_only=True)
with torch.serialization.safe_globals([print]):
f.seek(0)
@ -1260,8 +1282,12 @@ class TestSerialization(TestCase, SerializationMixin):
torch.save(sd, f, pickle_protocol=pickle_protocol)
f.seek(0)
if unsafe_global:
with self.assertRaisesRegex(pickle.UnpicklingError,
r"use `torch.serialization.add_safe_globals\(\[TwoTensor\]\)` or .* to allowlist"):
with self.assertRaisesRegex(
pickle.UnpicklingError,
"use `torch.serialization.add_safe_globals"
r"\(\[torch.testing._internal.two_tensor.TwoTensor\]\)`"
" or .* to allowlist"
):
torch.load(f, weights_only=True)
else:
with self.assertRaisesRegex(pickle.UnpicklingError,

View File

@ -141,7 +141,7 @@ def _get_user_allowed_globals():
f, name = f
rc[name] = f
else:
module, name = f.__module__, f.__name__
module, name = f.__module__, f.__qualname__
rc[f"{module}.{name}"] = f
return rc
@ -361,10 +361,21 @@ class Unpickler:
"``torch.distributed.tensor`` must be imported to load DTensors"
)
else:
builtins_name = "builtins"
if (
builtins_name in full_path
and builtins_name == full_path[: len(builtins_name)]
):
full_path = full_path[len(builtins_name) :]
full_path = (
full_path[1:]
if len(full_path) > 0 and full_path[0] == "."
else builtins_name + full_path
)
raise UnpicklingError(
f"Unsupported global: GLOBAL {full_path} was not an allowed global by default. "
f"Please use `torch.serialization.add_safe_globals([{name}])` or the "
f"`torch.serialization.safe_globals([{name}])` context manager to allowlist this global "
f"Please use `torch.serialization.add_safe_globals([{full_path}])` or the "
f"`torch.serialization.safe_globals([{full_path}])` context manager to allowlist this global "
"if you trust this class/function."
)
elif key[0] == NEWOBJ[0]:

View File

@ -274,9 +274,9 @@ def add_safe_globals(safe_globals: list[Union[Callable, tuple[Callable, str]]])
(function/class, string) where string is the full path of the function/class.
Within the serialized format, each function is identified with its full
path as ``{__module__}.{__name__}``. When calling this API, you can provide this
path as ``{__module__}.{__qualname__}``. When calling this API, you can provide this
full path that should match the one in the checkpoint otherwise the default
``{fn.__module__}.{fn.__name__}`` will be used.
``{fn.__module__}.{fn.__qualname__}`` will be used.
Args:
safe_globals (List[Union[Callable, Tuple[Callable, str]]]): list of globals to mark as safe