Fix allowlisting of builtins for weights_only unpickler (#129244)

Since we use [`DEFAULT_PROTOCOL=2`](https://github.com/pytorch/pytorch/blob/main/torch/serialization.py#L62), some functions/classes that were renamed from python 2-->3 will be pickled with their python2 name. This PR ensures that when a mod `GLOBAL <python2_mod>.<python2_name> ` is encountered, [following the strategy used by pickle](https://github.com/python/cpython/blob/main/Lib/pickle.py#L1590C13-L1593C63) it is properly mapped to `<python3_mod>.<python3_name>`.

This fix ensures that `add_safe_globals` works properly for such functions/classes (i.e. users will allowlist the python3 func and the weights_only unpickler will do the appropriate translation when checking whether a class was allowlisted).

An example is as follows:
`__builtin__` was named to `builtins`, see the [release notes for Python 3.0](https://docs.python.org/3/whatsnew/3.0.html)

> Renamed module `__builtin__` to [`builtins`](https://docs.python.org/3/library/builtins.html#module-builtins) (removing the underscores, adding an ‘s’). The __builtins__ variable found in most global namespaces is unchanged. To modify a builtin, you should use [builtins](https://docs.python.org/3/library/builtins.html#module-builtins), not `__builtins__`!

However, since we use [`DEFAULT_PROTOCOL=2`](https://github.com/pytorch/pytorch/blob/main/torch/serialization.py#L62), builtins will be pickled with their module string as `__builtin__`.

```python
>>> import pickle
>>> import pickletools
>>> print.__module__
'builtins'
>>> with open('print.pkl', 'wb') as f:
>>>      pickle.dump(print, f, protocol=2) # 2 because this is the default protocol used by pytorch
>>> with open('print.pkl', 'rb') as f:
>>>     pickletools.dis(f)
0: \x80 PROTO      2
2: c    GLOBAL     '__builtin__ print' # pickle saves the module string as __builtin__ !!! :(
21: q    BINPUT     0
23: .    STOP
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/129244
Approved by: https://github.com/albanD
This commit is contained in:
Mikayla Gawarecki
2024-06-24 18:08:26 -07:00
committed by PyTorch MergeBot
parent aa4ee2cb9e
commit 1bb1e3463c
2 changed files with 43 additions and 4 deletions

View File

@ -23,6 +23,7 @@
# weights = torch.load(buf, weights_only = True)
import functools as _functools
import warnings
from collections import Counter, OrderedDict
from pickle import (
APPEND,
@ -67,6 +68,16 @@ from struct import unpack
from sys import maxsize
from typing import Any, Dict, List
try:
# We rely on this module in private cPython which provides dicts of
# modules/functions that had their names changed from Python 2 to 3
has_compat_pickle = True
from _compat_pickle import IMPORT_MAPPING, NAME_MAPPING
except ImportError:
# To prevent warning on import torch, we warn in the Unpickler.load below
has_compat_pickle = False
IMPORT_MAPPING, NAME_MAPPING = dict(), dict()
import torch
_marked_safe_globals_list: List[Any] = []
@ -97,7 +108,8 @@ def _clear_safe_globals():
def _get_user_allowed_globals():
rc: Dict[str, Any] = {}
for f in _marked_safe_globals_list:
rc[f"{f.__module__}.{f.__name__}"] = f
module, name = f.__module__, f.__name__
rc[f"{module}.{name}"] = f
return rc
@ -170,12 +182,20 @@ class Unpickler:
self.readline = file.readline
self.read = file.read
self.memo: Dict[int, Any] = {}
self.proto: int = -1
def load(self):
"""Read a pickled object representation from the open file.
Return the reconstituted object hierarchy specified in the file.
"""
if not has_compat_pickle:
warnings.warn(
"Could not import IMPORT_MAPPING and NAME_MAPPING from _compat_pickle. "
"If the default `pickle_protocol` was used at `torch.save` time, any functions or "
"classes that are in these maps might not behave correctly if allowlisted via "
"`torch.serialization.add_safe_globals()`."
)
self.metastack = []
self.stack: List[Any] = []
self.append = self.stack.append
@ -190,6 +210,13 @@ class Unpickler:
if key[0] == GLOBAL[0]:
module = readline()[:-1].decode("utf-8")
name = readline()[:-1].decode("utf-8")
# Patch since torch.save default protocol is 2
# users will be running this code in python > 3
if self.proto == 2 and has_compat_pickle:
if (module, name) in NAME_MAPPING:
module, name = NAME_MAPPING[(module, name)]
elif module in IMPORT_MAPPING:
module = IMPORT_MAPPING[module]
full_path = f"{module}.{name}"
if full_path in _get_allowed_globals():
self.append(_get_allowed_globals()[full_path])
@ -334,8 +361,14 @@ class Unpickler:
self.append(decode_long(data))
# First and last deserializer ops
elif key[0] == PROTO[0]:
# Read and ignore proto version
read(1)[0]
self.proto = read(1)[0]
if self.proto != 2:
warnings.warn(
f"Detected pickle protocol {self.proto} in the checkpoint, which was "
"not the default pickle protocol used by `torch.load` (2). The weights_only "
"Unpickler might not support all instructions implemented by this protocol, "
"please file an issue for adding support if you encounter this."
)
elif key[0] == STOP[0]:
rc = self.stack.pop()
return rc