mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Fix allowlisting of builtins for weights_only unpickler (#129244)
Since we use [`DEFAULT_PROTOCOL=2`](https://github.com/pytorch/pytorch/blob/main/torch/serialization.py#L62), some functions/classes that were renamed from python 2-->3 will be pickled with their python2 name. This PR ensures that when a mod `GLOBAL <python2_mod>.<python2_name> ` is encountered, [following the strategy used by pickle](https://github.com/python/cpython/blob/main/Lib/pickle.py#L1590C13-L1593C63) it is properly mapped to `<python3_mod>.<python3_name>`. This fix ensures that `add_safe_globals` works properly for such functions/classes (i.e. users will allowlist the python3 func and the weights_only unpickler will do the appropriate translation when checking whether a class was allowlisted). An example is as follows: `__builtin__` was named to `builtins`, see the [release notes for Python 3.0](https://docs.python.org/3/whatsnew/3.0.html) > Renamed module `__builtin__` to [`builtins`](https://docs.python.org/3/library/builtins.html#module-builtins) (removing the underscores, adding an ‘s’). The __builtins__ variable found in most global namespaces is unchanged. To modify a builtin, you should use [builtins](https://docs.python.org/3/library/builtins.html#module-builtins), not `__builtins__`! However, since we use [`DEFAULT_PROTOCOL=2`](https://github.com/pytorch/pytorch/blob/main/torch/serialization.py#L62), builtins will be pickled with their module string as `__builtin__`. ```python >>> import pickle >>> import pickletools >>> print.__module__ 'builtins' >>> with open('print.pkl', 'wb') as f: >>> pickle.dump(print, f, protocol=2) # 2 because this is the default protocol used by pytorch >>> with open('print.pkl', 'rb') as f: >>> pickletools.dis(f) 0: \x80 PROTO 2 2: c GLOBAL '__builtin__ print' # pickle saves the module string as __builtin__ !!! :( 21: q BINPUT 0 23: . STOP ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/129244 Approved by: https://github.com/albanD
This commit is contained in:
committed by
PyTorch MergeBot
parent
aa4ee2cb9e
commit
1bb1e3463c
@ -23,6 +23,7 @@
|
||||
# weights = torch.load(buf, weights_only = True)
|
||||
|
||||
import functools as _functools
|
||||
import warnings
|
||||
from collections import Counter, OrderedDict
|
||||
from pickle import (
|
||||
APPEND,
|
||||
@ -67,6 +68,16 @@ from struct import unpack
|
||||
from sys import maxsize
|
||||
from typing import Any, Dict, List
|
||||
|
||||
try:
|
||||
# We rely on this module in private cPython which provides dicts of
|
||||
# modules/functions that had their names changed from Python 2 to 3
|
||||
has_compat_pickle = True
|
||||
from _compat_pickle import IMPORT_MAPPING, NAME_MAPPING
|
||||
except ImportError:
|
||||
# To prevent warning on import torch, we warn in the Unpickler.load below
|
||||
has_compat_pickle = False
|
||||
IMPORT_MAPPING, NAME_MAPPING = dict(), dict()
|
||||
|
||||
import torch
|
||||
|
||||
_marked_safe_globals_list: List[Any] = []
|
||||
@ -97,7 +108,8 @@ def _clear_safe_globals():
|
||||
def _get_user_allowed_globals():
|
||||
rc: Dict[str, Any] = {}
|
||||
for f in _marked_safe_globals_list:
|
||||
rc[f"{f.__module__}.{f.__name__}"] = f
|
||||
module, name = f.__module__, f.__name__
|
||||
rc[f"{module}.{name}"] = f
|
||||
return rc
|
||||
|
||||
|
||||
@ -170,12 +182,20 @@ class Unpickler:
|
||||
self.readline = file.readline
|
||||
self.read = file.read
|
||||
self.memo: Dict[int, Any] = {}
|
||||
self.proto: int = -1
|
||||
|
||||
def load(self):
|
||||
"""Read a pickled object representation from the open file.
|
||||
|
||||
Return the reconstituted object hierarchy specified in the file.
|
||||
"""
|
||||
if not has_compat_pickle:
|
||||
warnings.warn(
|
||||
"Could not import IMPORT_MAPPING and NAME_MAPPING from _compat_pickle. "
|
||||
"If the default `pickle_protocol` was used at `torch.save` time, any functions or "
|
||||
"classes that are in these maps might not behave correctly if allowlisted via "
|
||||
"`torch.serialization.add_safe_globals()`."
|
||||
)
|
||||
self.metastack = []
|
||||
self.stack: List[Any] = []
|
||||
self.append = self.stack.append
|
||||
@ -190,6 +210,13 @@ class Unpickler:
|
||||
if key[0] == GLOBAL[0]:
|
||||
module = readline()[:-1].decode("utf-8")
|
||||
name = readline()[:-1].decode("utf-8")
|
||||
# Patch since torch.save default protocol is 2
|
||||
# users will be running this code in python > 3
|
||||
if self.proto == 2 and has_compat_pickle:
|
||||
if (module, name) in NAME_MAPPING:
|
||||
module, name = NAME_MAPPING[(module, name)]
|
||||
elif module in IMPORT_MAPPING:
|
||||
module = IMPORT_MAPPING[module]
|
||||
full_path = f"{module}.{name}"
|
||||
if full_path in _get_allowed_globals():
|
||||
self.append(_get_allowed_globals()[full_path])
|
||||
@ -334,8 +361,14 @@ class Unpickler:
|
||||
self.append(decode_long(data))
|
||||
# First and last deserializer ops
|
||||
elif key[0] == PROTO[0]:
|
||||
# Read and ignore proto version
|
||||
read(1)[0]
|
||||
self.proto = read(1)[0]
|
||||
if self.proto != 2:
|
||||
warnings.warn(
|
||||
f"Detected pickle protocol {self.proto} in the checkpoint, which was "
|
||||
"not the default pickle protocol used by `torch.load` (2). The weights_only "
|
||||
"Unpickler might not support all instructions implemented by this protocol, "
|
||||
"please file an issue for adding support if you encounter this."
|
||||
)
|
||||
elif key[0] == STOP[0]:
|
||||
rc = self.stack.pop()
|
||||
return rc
|
||||
|
Reference in New Issue
Block a user