PEP585 update - mostly toplevels (#145178)

See #145101 for details.
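
For context, PEP 585 makes the builtin container types usable directly as generic annotations on Python 3.9+, so the typing aliases (Dict, List, Set, Tuple) can be dropped. A minimal sketch of the pattern this commit applies (hypothetical names, not code from the diff):

    # Before: typing aliases required on Python < 3.9
    # from typing import Dict, List
    # cache: Dict[str, int] = {}
    # def pending() -> List[str]: ...

    # After (PEP 585): builtin generics, no typing import needed
    cache: dict[str, int] = {}

    def pending() -> list[str]:
        return sorted(cache)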

Pull Request resolved: https://github.com/pytorch/pytorch/pull/145178
Approved by: https://github.com/bobrenjc93
Author: Aaron Orenstein
Date: 2025-01-21 13:42:12 -08:00
Committed by: PyTorch MergeBot
Parent: 1ce533867f
Commit: f2cfe8b59f
39 changed files with 356 additions and 386 deletions


@@ -68,7 +68,7 @@ from pickle import (
 )
 from struct import unpack
 from sys import maxsize
-from typing import Any, Callable, Dict, List, Set, Tuple, Union
+from typing import Any, Callable, Union

 import torch
 from torch._utils import IMPORT_MAPPING, NAME_MAPPING
@@ -83,15 +83,15 @@ _blocklisted_modules = [
     "nt",
 ]

-_marked_safe_globals_set: Set[Union[Callable, Tuple[Callable, str]]] = set()
+_marked_safe_globals_set: set[Union[Callable, tuple[Callable, str]]] = set()


-def _add_safe_globals(safe_globals: List[Union[Callable, Tuple[Callable, str]]]):
+def _add_safe_globals(safe_globals: list[Union[Callable, tuple[Callable, str]]]):
     global _marked_safe_globals_set
     _marked_safe_globals_set = _marked_safe_globals_set.union(set(safe_globals))


-def _get_safe_globals() -> List[Union[Callable, Tuple[Callable, str]]]:
+def _get_safe_globals() -> list[Union[Callable, tuple[Callable, str]]]:
     global _marked_safe_globals_set
     return list(_marked_safe_globals_set)
@@ -102,14 +102,14 @@ def _clear_safe_globals():


 def _remove_safe_globals(
-    globals_to_remove: List[Union[Callable, Tuple[Callable, str]]],
+    globals_to_remove: list[Union[Callable, tuple[Callable, str]]],
 ):
     global _marked_safe_globals_set
     _marked_safe_globals_set = _marked_safe_globals_set - set(globals_to_remove)


 class _safe_globals:
-    def __init__(self, safe_globals: List[Union[Callable, Tuple[Callable, str]]]):
+    def __init__(self, safe_globals: list[Union[Callable, tuple[Callable, str]]]):
         self.safe_globals = safe_globals

     def __enter__(self):
@@ -127,7 +127,7 @@ class _safe_globals:
 # the dynamic additions to safe_globals would not be picked up by
 # _get_allowed_globals due to the lru_cache
 def _get_user_allowed_globals():
-    rc: Dict[str, Any] = {}
+    rc: dict[str, Any] = {}
     for f in _marked_safe_globals_set:
         if isinstance(f, tuple):
             if len(f) != 2:
@@ -171,7 +171,7 @@ def _tensor_rebuild_functions():
 # Unpickling machinery
 @_functools.lru_cache(maxsize=1)
 def _get_allowed_globals():
-    rc: Dict[str, Any] = {
+    rc: dict[str, Any] = {
         "collections.OrderedDict": OrderedDict,
         "collections.Counter": Counter,
         "torch.nn.parameter.Parameter": torch.nn.Parameter,
@@ -221,7 +221,7 @@ def _get_allowed_globals():
     return rc


-def _read_global_instruction(readline: Callable) -> Tuple[str, str]:
+def _read_global_instruction(readline: Callable) -> tuple[str, str]:
     module = readline()[:-1].decode("utf-8")
     name = readline()[:-1].decode("utf-8")
     # Patch since torch.save default protocol is 2
@@ -233,7 +233,7 @@ def _read_global_instruction(readline: Callable) -> Tuple[str, str]:
     return module, name


-def get_globals_in_pkl(file) -> Set[str]:
+def get_globals_in_pkl(file) -> set[str]:
     globals_in_checkpoint = set()
     read = file.read
     readline = file.readline
@@ -302,7 +302,7 @@ class Unpickler:
         self.encoding = encoding
         self.readline = file.readline
         self.read = file.read
-        self.memo: Dict[int, Any] = {}
+        self.memo: dict[int, Any] = {}
         self.proto: int = -1

     def load(self):
@@ -311,7 +311,7 @@ class Unpickler:
         Return the reconstituted object hierarchy specified in the file.
         """
         self.metastack = []
-        self.stack: List[Any] = []
+        self.stack: list[Any] = []
         self.append = self.stack.append
         read = self.read
         while True:
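
As a sanity check on the annotations above, here is a small standalone sketch (assumed names, not part of the diff) showing that PEP 585 generics in the same shapes the diff uses evaluate at runtime on Python 3.9+ without the typing aliases:

    # Standalone sketch: the annotation shapes from the diff, evaluated at
    # runtime without the Dict/List/Set/Tuple aliases.
    from typing import Any, Callable, Union

    # Same element type _marked_safe_globals_set uses in the diff.
    SafeGlobal = Union[Callable, tuple[Callable, str]]

    marked: set[SafeGlobal] = set()  # was Set[...] before this commit
    memo: dict[int, Any] = {}        # was Dict[int, Any]

    def add_safe_globals(safe_globals: list[SafeGlobal]) -> None:
        # Toy counterpart of _add_safe_globals from the diff.
        marked.update(safe_globals)

    add_safe_globals([(print, "print")])
    assert (print, "print") in marked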