mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
PEP585 update - mostly toplevels (#145178)
See #145101 for details. Pull Request resolved: https://github.com/pytorch/pytorch/pull/145178 Approved by: https://github.com/bobrenjc93
This commit is contained in:
committed by
PyTorch MergeBot
parent
1ce533867f
commit
f2cfe8b59f
@ -68,7 +68,7 @@ from pickle import (
|
||||
)
|
||||
from struct import unpack
|
||||
from sys import maxsize
|
||||
from typing import Any, Callable, Dict, List, Set, Tuple, Union
|
||||
from typing import Any, Callable, Union
|
||||
|
||||
import torch
|
||||
from torch._utils import IMPORT_MAPPING, NAME_MAPPING
|
||||
@ -83,15 +83,15 @@ _blocklisted_modules = [
|
||||
"nt",
|
||||
]
|
||||
|
||||
_marked_safe_globals_set: Set[Union[Callable, Tuple[Callable, str]]] = set()
|
||||
_marked_safe_globals_set: set[Union[Callable, tuple[Callable, str]]] = set()
|
||||
|
||||
|
||||
def _add_safe_globals(safe_globals: List[Union[Callable, Tuple[Callable, str]]]):
|
||||
def _add_safe_globals(safe_globals: list[Union[Callable, tuple[Callable, str]]]):
|
||||
global _marked_safe_globals_set
|
||||
_marked_safe_globals_set = _marked_safe_globals_set.union(set(safe_globals))
|
||||
|
||||
|
||||
def _get_safe_globals() -> List[Union[Callable, Tuple[Callable, str]]]:
|
||||
def _get_safe_globals() -> list[Union[Callable, tuple[Callable, str]]]:
|
||||
global _marked_safe_globals_set
|
||||
return list(_marked_safe_globals_set)
|
||||
|
||||
@ -102,14 +102,14 @@ def _clear_safe_globals():
|
||||
|
||||
|
||||
def _remove_safe_globals(
|
||||
globals_to_remove: List[Union[Callable, Tuple[Callable, str]]],
|
||||
globals_to_remove: list[Union[Callable, tuple[Callable, str]]],
|
||||
):
|
||||
global _marked_safe_globals_set
|
||||
_marked_safe_globals_set = _marked_safe_globals_set - set(globals_to_remove)
|
||||
|
||||
|
||||
class _safe_globals:
|
||||
def __init__(self, safe_globals: List[Union[Callable, Tuple[Callable, str]]]):
|
||||
def __init__(self, safe_globals: list[Union[Callable, tuple[Callable, str]]]):
|
||||
self.safe_globals = safe_globals
|
||||
|
||||
def __enter__(self):
|
||||
@ -127,7 +127,7 @@ class _safe_globals:
|
||||
# the dynamic additions to safe_globals would not be picked up by
|
||||
# _get_allowed_globals due to the lru_cache
|
||||
def _get_user_allowed_globals():
|
||||
rc: Dict[str, Any] = {}
|
||||
rc: dict[str, Any] = {}
|
||||
for f in _marked_safe_globals_set:
|
||||
if isinstance(f, tuple):
|
||||
if len(f) != 2:
|
||||
@ -171,7 +171,7 @@ def _tensor_rebuild_functions():
|
||||
# Unpickling machinery
|
||||
@_functools.lru_cache(maxsize=1)
|
||||
def _get_allowed_globals():
|
||||
rc: Dict[str, Any] = {
|
||||
rc: dict[str, Any] = {
|
||||
"collections.OrderedDict": OrderedDict,
|
||||
"collections.Counter": Counter,
|
||||
"torch.nn.parameter.Parameter": torch.nn.Parameter,
|
||||
@ -221,7 +221,7 @@ def _get_allowed_globals():
|
||||
return rc
|
||||
|
||||
|
||||
def _read_global_instruction(readline: Callable) -> Tuple[str, str]:
|
||||
def _read_global_instruction(readline: Callable) -> tuple[str, str]:
|
||||
module = readline()[:-1].decode("utf-8")
|
||||
name = readline()[:-1].decode("utf-8")
|
||||
# Patch since torch.save default protocol is 2
|
||||
@ -233,7 +233,7 @@ def _read_global_instruction(readline: Callable) -> Tuple[str, str]:
|
||||
return module, name
|
||||
|
||||
|
||||
def get_globals_in_pkl(file) -> Set[str]:
|
||||
def get_globals_in_pkl(file) -> set[str]:
|
||||
globals_in_checkpoint = set()
|
||||
read = file.read
|
||||
readline = file.readline
|
||||
@ -302,7 +302,7 @@ class Unpickler:
|
||||
self.encoding = encoding
|
||||
self.readline = file.readline
|
||||
self.read = file.read
|
||||
self.memo: Dict[int, Any] = {}
|
||||
self.memo: dict[int, Any] = {}
|
||||
self.proto: int = -1
|
||||
|
||||
def load(self):
|
||||
@ -311,7 +311,7 @@ class Unpickler:
|
||||
Return the reconstituted object hierarchy specified in the file.
|
||||
"""
|
||||
self.metastack = []
|
||||
self.stack: List[Any] = []
|
||||
self.stack: list[Any] = []
|
||||
self.append = self.stack.append
|
||||
read = self.read
|
||||
while True:
|
||||
|
Reference in New Issue
Block a user