mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Fix allowlisting of builtins for weights_only unpickler (#129244)
Since we use [`DEFAULT_PROTOCOL=2`](https://github.com/pytorch/pytorch/blob/main/torch/serialization.py#L62), some functions/classes that were renamed from python 2-->3 will be pickled with their python2 name. This PR ensures that when a mod `GLOBAL <python2_mod>.<python2_name> ` is encountered, [following the strategy used by pickle](https://github.com/python/cpython/blob/main/Lib/pickle.py#L1590C13-L1593C63) it is properly mapped to `<python3_mod>.<python3_name>`. This fix ensures that `add_safe_globals` works properly for such functions/classes (i.e. users will allowlist the python3 func and the weights_only unpickler will do the appropriate translation when checking whether a class was allowlisted). An example is as follows: `__builtin__` was named to `builtins`, see the [release notes for Python 3.0](https://docs.python.org/3/whatsnew/3.0.html) > Renamed module `__builtin__` to [`builtins`](https://docs.python.org/3/library/builtins.html#module-builtins) (removing the underscores, adding an ‘s’). The __builtins__ variable found in most global namespaces is unchanged. To modify a builtin, you should use [builtins](https://docs.python.org/3/library/builtins.html#module-builtins), not `__builtins__`! However, since we use [`DEFAULT_PROTOCOL=2`](https://github.com/pytorch/pytorch/blob/main/torch/serialization.py#L62), builtins will be pickled with their module string as `__builtin__`. ```python >>> import pickle >>> import pickletools >>> print.__module__ 'builtins' >>> with open('print.pkl', 'wb') as f: >>> pickle.dump(print, f, protocol=2) # 2 because this is the default protocol used by pytorch >>> with open('print.pkl', 'rb') as f: >>> pickletools.dis(f) 0: \x80 PROTO 2 2: c GLOBAL '__builtin__ print' # pickle saves the module string as __builtin__ !!! :( 21: q BINPUT 0 23: . STOP ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/129244 Approved by: https://github.com/albanD
This commit is contained in:
committed by
PyTorch MergeBot
parent
aa4ee2cb9e
commit
1bb1e3463c
@ -1040,8 +1040,14 @@ class TestSerialization(TestCase, SerializationMixin):
|
||||
self.assertIsNone(torch.load(f, weights_only=False))
|
||||
f.seek(0)
|
||||
# Safe load should assert
|
||||
with self.assertRaisesRegex(pickle.UnpicklingError, "Unsupported global: GLOBAL __builtin__.print"):
|
||||
with self.assertRaisesRegex(pickle.UnpicklingError, "Unsupported global: GLOBAL builtins.print"):
|
||||
torch.load(f, weights_only=True)
|
||||
try:
|
||||
torch.serialization.add_safe_globals([print])
|
||||
f.seek(0)
|
||||
torch.load(f, weights_only=True)
|
||||
finally:
|
||||
torch.serialization.clear_safe_globals()
|
||||
|
||||
@parametrize('weights_only', (False, True))
|
||||
def test_serialization_math_bits(self, weights_only):
|
||||
|
Reference in New Issue
Block a user