Add utility to get all unsafe globals in checkpoint (no pickletools dependency) (#139221)

Fixes https://github.com/pytorch/pytorch/issues/129698

Reimplements https://github.com/pytorch/pytorch/pull/139106 without the pickletools dependency.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/139221
Approved by: https://github.com/malfet
ghstack dependencies: #138936
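
A usage sketch (hedged: it assumes the public entry point added alongside this change is torch.serialization.get_unsafe_globals_in_checkpoint; only the private scanner get_globals_in_pkl appears in the hunks below):

    import torch
    from torch.serialization import get_unsafe_globals_in_checkpoint

    # Scan a checkpoint for GLOBALs that the weights_only unpickler would not
    # allow by default, without executing any pickle bytecode.
    unsafe = get_unsafe_globals_in_checkpoint("checkpoint.pt")
    print(unsafe)  # strings like "mymodule.MyClass", candidates for
                   # torch.serialization.add_safe_globals(...)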
Author: Mikayla Gawarecki
Date: 2024-11-01 09:50:25 -07:00
Committed by: PyTorch MergeBot
parent f3b485eb2a
commit ea0e09b3f3
4 changed files with 144 additions and 10 deletions

torch/_weights_only_unpickler.py

@@ -68,7 +68,7 @@ from pickle import (
 )
 from struct import unpack
 from sys import maxsize
-from typing import Any, Dict, List
+from typing import Any, Callable, Dict, List, Set, Tuple
 
 import torch
 from torch._utils import IMPORT_MAPPING, NAME_MAPPING
@@ -207,6 +207,83 @@ def _get_allowed_globals():
     return rc
 
 
+def _read_global_instruction(readline: Callable) -> Tuple[str, str]:
+    module = readline()[:-1].decode("utf-8")
+    name = readline()[:-1].decode("utf-8")
+    # Patch since torch.save default protocol is 2
+    # users will be running this code in python > 3
+    if (module, name) in NAME_MAPPING:
+        module, name = NAME_MAPPING[(module, name)]
+    elif module in IMPORT_MAPPING:
+        module = IMPORT_MAPPING[module]
+    return module, name
+
+
+def get_globals_in_pkl(file) -> Set[str]:
+    globals_in_checkpoint = set()
+    protocol = None
+    read = file.read
+    readline = file.readline
+    op_to_bytes_to_read = {
+        NEWOBJ[0]: 0,
+        REDUCE[0]: 0,
+        BUILD[0]: 0,
+        APPEND[0]: 0,
+        APPENDS[0]: 0,
+        SETITEM[0]: 0,
+        SETITEMS[0]: 0,
+        MARK[0]: 0,
+        TUPLE[0]: 0,
+        TUPLE1[0]: 0,
+        TUPLE2[0]: 0,
+        TUPLE3[0]: 0,
+        NONE[0]: 0,
+        NEWFALSE[0]: 0,
+        NEWTRUE[0]: 0,
+        EMPTY_TUPLE[0]: 0,
+        EMPTY_LIST[0]: 0,
+        EMPTY_DICT[0]: 0,
+        EMPTY_SET[0]: 0,
+        BINPERSID[0]: 0,
+        BININT[0]: 4,
+        BININT1[0]: 1,
+        BININT2[0]: 2,
+        BINFLOAT[0]: 8,
+        BINGET[0]: 1,
+        LONG_BINGET[0]: 4,
+        BINPUT[0]: 1,
+        LONG_BINPUT[0]: 4,
+    }
+    while True:
+        key = read(1)
+        if not key:
+            raise EOFError
+        assert isinstance(key, bytes_types)
+        if key[0] == GLOBAL[0]:
+            module, name = _read_global_instruction(readline)
+            globals_in_checkpoint.add(f"{module}.{name}")
+        elif key[0] in op_to_bytes_to_read:
+            bytes_to_read = op_to_bytes_to_read[key[0]]
+            if bytes_to_read:
+                read(bytes_to_read)
+        # ops where bytes to read depends on the data
+        elif key[0] == BINUNICODE[0]:
+            strlen = unpack("<I", read(4))[0]
+            if strlen > maxsize:
+                raise UnpicklingError("String is too long")
+            read(strlen)
+        elif key[0] in {SHORT_BINSTRING[0], LONG1[0]}:
+            strlen = read(1)[0]
+            read(strlen)
+        # first and last op
+        elif key[0] == PROTO[0]:
+            protocol = read(1)[0]
+        elif key[0] == STOP[0]:
+            return globals_in_checkpoint
+        else:
+            raise UnpicklingError(f"Unsupported operand {key[0]}")
+
+
 class Unpickler:
     def __init__(self, file, *, encoding: str = "bytes"):
         self.encoding = encoding
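
As a minimal demo of the scanner itself (a hypothetical sketch, not part of this diff; it assumes importing the private helper from torch._weights_only_unpickler):

    import io
    import pickle
    from collections import OrderedDict

    from torch._weights_only_unpickler import get_globals_in_pkl

    # torch.save uses pickle protocol 2, where classes are referenced via the
    # GLOBAL opcode ("c" followed by "module\nname\n"); the scanner records
    # them without constructing any objects.
    data = pickle.dumps(OrderedDict(), protocol=2)
    print(get_globals_in_pkl(io.BytesIO(data)))  # {'collections.OrderedDict'}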
@@ -232,15 +309,7 @@ class Unpickler:
             assert isinstance(key, bytes_types)
             # Risky operators
             if key[0] == GLOBAL[0]:
-                module = readline()[:-1].decode("utf-8")
-                name = readline()[:-1].decode("utf-8")
-                # Patch since torch.save default protocol is 2
-                # users will be running this code in python > 3
-                if self.proto == 2:
-                    if (module, name) in NAME_MAPPING:
-                        module, name = NAME_MAPPING[(module, name)]
-                    elif module in IMPORT_MAPPING:
-                        module = IMPORT_MAPPING[module]
+                module, name = _read_global_instruction(self.readline)
                 full_path = f"{module}.{name}"
                 if module in _blocklisted_modules:
                     raise UnpicklingError(
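
To illustrate the compat remapping the shared helper performs (a sketch under the assumption that torch._utils's IMPORT_MAPPING mirrors CPython's _compat_pickle, mapping "__builtin__" to "builtins"):

    import io

    from torch._weights_only_unpickler import _read_global_instruction

    # A GLOBAL opcode's payload is two newline-terminated fields, module then
    # name; Python 2 era names are remapped via NAME_MAPPING/IMPORT_MAPPING,
    # just as the old inline code in Unpickler.load did before this refactor.
    payload = io.BytesIO(b"__builtin__\nset\n")
    print(_read_global_instruction(payload.readline))  # ('builtins', 'set')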