mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Add utility to get all unsafe globals in checkpoint (no pickletools dependency) (#139221)
Fixes https://github.com/pytorch/pytorch/issues/129698 by implementing https://github.com/pytorch/pytorch/pull/139106 without the pickletools dependency. Pull Request resolved: https://github.com/pytorch/pytorch/pull/139221. Approved by: https://github.com/malfet. ghstack dependencies: #138936.
This commit is contained in:
committed by
PyTorch MergeBot
parent
f3b485eb2a
commit
ea0e09b3f3
@ -68,7 +68,7 @@ from pickle import (
|
||||
)
|
||||
from struct import unpack
|
||||
from sys import maxsize
|
||||
from typing import Any, Dict, List
|
||||
from typing import Any, Callable, Dict, List, Set, Tuple
|
||||
|
||||
import torch
|
||||
from torch._utils import IMPORT_MAPPING, NAME_MAPPING
|
||||
@ -207,6 +207,83 @@ def _get_allowed_globals():
|
||||
return rc
|
||||
|
||||
|
||||
def _read_global_instruction(readline: Callable) -> Tuple[str, str]:
|
||||
module = readline()[:-1].decode("utf-8")
|
||||
name = readline()[:-1].decode("utf-8")
|
||||
# Patch since torch.save default protocol is 2
|
||||
# users will be running this code in python > 3
|
||||
if (module, name) in NAME_MAPPING:
|
||||
module, name = NAME_MAPPING[(module, name)]
|
||||
elif module in IMPORT_MAPPING:
|
||||
module = IMPORT_MAPPING[module]
|
||||
return module, name
|
||||
|
||||
|
||||
def get_globals_in_pkl(file) -> Set[str]:
|
||||
globals_in_checkpoint = set()
|
||||
protocol = None
|
||||
read = file.read
|
||||
readline = file.readline
|
||||
op_to_bytes_to_read = {
|
||||
NEWOBJ[0]: 0,
|
||||
REDUCE[0]: 0,
|
||||
BUILD[0]: 0,
|
||||
APPEND[0]: 0,
|
||||
APPENDS[0]: 0,
|
||||
SETITEM[0]: 0,
|
||||
SETITEMS[0]: 0,
|
||||
MARK[0]: 0,
|
||||
TUPLE[0]: 0,
|
||||
TUPLE1[0]: 0,
|
||||
TUPLE2[0]: 0,
|
||||
TUPLE3[0]: 0,
|
||||
NONE[0]: 0,
|
||||
NEWFALSE[0]: 0,
|
||||
NEWTRUE[0]: 0,
|
||||
EMPTY_TUPLE[0]: 0,
|
||||
EMPTY_LIST[0]: 0,
|
||||
EMPTY_DICT[0]: 0,
|
||||
EMPTY_SET[0]: 0,
|
||||
BINPERSID[0]: 0,
|
||||
BININT[0]: 4,
|
||||
BININT1[0]: 1,
|
||||
BININT2[0]: 2,
|
||||
BINFLOAT[0]: 8,
|
||||
BINGET[0]: 1,
|
||||
LONG_BINGET[0]: 4,
|
||||
BINPUT[0]: 1,
|
||||
LONG_BINPUT[0]: 4,
|
||||
}
|
||||
while True:
|
||||
key = read(1)
|
||||
if not key:
|
||||
raise EOFError
|
||||
assert isinstance(key, bytes_types)
|
||||
if key[0] == GLOBAL[0]:
|
||||
module, name = _read_global_instruction(readline)
|
||||
globals_in_checkpoint.add(f"{module}.{name}")
|
||||
elif key[0] in op_to_bytes_to_read:
|
||||
bytes_to_read = op_to_bytes_to_read[key[0]]
|
||||
if bytes_to_read:
|
||||
read(bytes_to_read)
|
||||
# ops where bytes to read depends on the data
|
||||
elif key[0] == BINUNICODE[0]:
|
||||
strlen = unpack("<I", read(4))[0]
|
||||
if strlen > maxsize:
|
||||
raise UnpicklingError("String is too long")
|
||||
read(strlen)
|
||||
elif key[0] in {SHORT_BINSTRING[0], LONG1[0]}:
|
||||
strlen = read(1)[0]
|
||||
read(strlen)
|
||||
# first and last op
|
||||
elif key[0] == PROTO[0]:
|
||||
protocol = read(1)[0]
|
||||
elif key[0] == STOP[0]:
|
||||
return globals_in_checkpoint
|
||||
else:
|
||||
raise UnpicklingError(f"Unsupported operand {key[0]}")
|
||||
|
||||
|
||||
class Unpickler:
|
||||
def __init__(self, file, *, encoding: str = "bytes"):
|
||||
self.encoding = encoding
|
||||
@ -232,15 +309,7 @@ class Unpickler:
|
||||
assert isinstance(key, bytes_types)
|
||||
# Risky operators
|
||||
if key[0] == GLOBAL[0]:
|
||||
module = readline()[:-1].decode("utf-8")
|
||||
name = readline()[:-1].decode("utf-8")
|
||||
# Patch since torch.save default protocol is 2
|
||||
# users will be running this code in python > 3
|
||||
if self.proto == 2:
|
||||
if (module, name) in NAME_MAPPING:
|
||||
module, name = NAME_MAPPING[(module, name)]
|
||||
elif module in IMPORT_MAPPING:
|
||||
module = IMPORT_MAPPING[module]
|
||||
module, name = _read_global_instruction(self.readline)
|
||||
full_path = f"{module}.{name}"
|
||||
if module in _blocklisted_modules:
|
||||
raise UnpicklingError(
|
||||
|
Reference in New Issue
Block a user