Fix allowlisting of builtins for weights_only unpickler (#129244)

Since we use [`DEFAULT_PROTOCOL=2`](https://github.com/pytorch/pytorch/blob/main/torch/serialization.py#L62), some functions/classes that were renamed from python 2-->3 will be pickled with their python2 name. This PR ensures that when a mod `GLOBAL <python2_mod>.<python2_name> ` is encountered, [following the strategy used by pickle](https://github.com/python/cpython/blob/main/Lib/pickle.py#L1590C13-L1593C63) it is properly mapped to `<python3_mod>.<python3_name>`.

This fix ensures that `add_safe_globals` works properly for such functions/classes (i.e. users will allowlist the python3 func and the weights_only unpickler will do the appropriate translation when checking whether a class was allowlisted).

An example is as follows:
`__builtin__` was named to `builtins`, see the [release notes for Python 3.0](https://docs.python.org/3/whatsnew/3.0.html)

> Renamed module `__builtin__` to [`builtins`](https://docs.python.org/3/library/builtins.html#module-builtins) (removing the underscores, adding an ‘s’). The __builtins__ variable found in most global namespaces is unchanged. To modify a builtin, you should use [builtins](https://docs.python.org/3/library/builtins.html#module-builtins), not `__builtins__`!

However, since we use [`DEFAULT_PROTOCOL=2`](https://github.com/pytorch/pytorch/blob/main/torch/serialization.py#L62), builtins will be pickled with their module string as `__builtin__`.

```python
>>> import pickle
>>> import pickletools
>>> print.__module__
'builtins'
>>> with open('print.pkl', 'wb') as f:
>>>      pickle.dump(print, f, protocol=2) # 2 because this is the default protocol used by pytorch
>>> with open('print.pkl', 'rb') as f:
>>>     pickletools.dis(f)
0: \x80 PROTO      2
2: c    GLOBAL     '__builtin__ print' # pickle saves the module string as __builtin__ !!! :(
21: q    BINPUT     0
23: .    STOP
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/129244
Approved by: https://github.com/albanD
This commit is contained in:
Mikayla Gawarecki
2024-06-24 18:08:26 -07:00
committed by PyTorch MergeBot
parent aa4ee2cb9e
commit 1bb1e3463c
2 changed files with 43 additions and 4 deletions

View File

@ -1040,8 +1040,14 @@ class TestSerialization(TestCase, SerializationMixin):
self.assertIsNone(torch.load(f, weights_only=False))
f.seek(0)
# Safe load should assert
with self.assertRaisesRegex(pickle.UnpicklingError, "Unsupported global: GLOBAL __builtin__.print"):
with self.assertRaisesRegex(pickle.UnpicklingError, "Unsupported global: GLOBAL builtins.print"):
torch.load(f, weights_only=True)
try:
torch.serialization.add_safe_globals([print])
f.seek(0)
torch.load(f, weights_only=True)
finally:
torch.serialization.clear_safe_globals()
@parametrize('weights_only', (False, True))
def test_serialization_math_bits(self, weights_only):