Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
PEP585 update - mostly toplevels (#145178)
See #145101 for details.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/145178
Approved by: https://github.com/bobrenjc93
Committed by: PyTorch MergeBot
Parent: 1ce533867f
Commit: f2cfe8b59f
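For context, PEP 585 (Python 3.9+) lets the built-in container types be used directly as generic annotations, so typing.List, typing.Dict, typing.Set, and typing.Tuple can be written as list, dict, set, and tuple. A minimal sketch of the pattern this commit applies across torch/_guards.py (the surrounding module details are omitted here):

from typing import Optional

# Before (typing aliases, deprecated since Python 3.9):
#     from typing import Dict, List
#     guard_types: Optional[List[str]] = None
#     cache: Dict[str, str] = {}

# After (PEP 585 built-in generics, no typing import needed for the containers):
guard_types: Optional[list[str]] = None
cache: dict[str, str] = {}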
@@ -17,13 +17,9 @@ from dataclasses import dataclass
 from typing import (
     Any,
     Callable,
-    Dict,
     Generic,
-    List,
     NamedTuple,
     Optional,
-    Set,
-    Tuple,
     TYPE_CHECKING,
     TypeVar,
     Union,
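Note that the hunk above only drops the typing aliases that have built-in equivalents (Dict, List, Set, Tuple); Any, Callable, Optional, TYPE_CHECKING, TypeVar, and Union remain imported from typing because this change does not touch them. An illustrative sketch of that distinction (not part of the diff):

from typing import Any, Callable, Optional  # still imported from typing here

# PEP 585 built-in generics replace List/Dict/Set/Tuple:
sizes: list[int] = [1, 2, 3]
strides: dict[str, tuple[int, ...]] = {"x": (4, 1)}

# Callable's PEP 585 home is collections.abc.Callable, and Optional/Union have
# the PEP 604 "X | None" spelling, but this commit leaves those annotations as-is.
maybe_fn: Optional[Callable[[int], Any]] = None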
@@ -260,8 +256,8 @@ class Guard:
     create_fn: Callable[[GuardBuilderBase, Guard], None]

     # Export only. These values are written to at time of guard check_fn creation.
-    guard_types: Optional[List[str]] = None
-    code_list: Optional[List[str]] = None
+    guard_types: Optional[list[str]] = None
+    code_list: Optional[list[str]] = None
     obj_weakref: Optional[object] = None
     guarded_class_weakref: Optional[type] = None

@@ -448,8 +444,8 @@ overlapping with any other input, overlapping_sources represent tensors that eit

 @dataclasses.dataclass
 class StorageOverlap(GuardEnvExpr):
-    overlapping_sources: List[Source]
-    non_overlapping_sources: List[Source]
+    overlapping_sources: list[Source]
+    non_overlapping_sources: list[Source]


 """
@@ -478,7 +474,7 @@ class GuardsCheckpointState:
     The GuardCheckpointState - it is the T of Checkpointable[T] for GuardsContext
     """

-    dynamo_guards: Set[Guard] = set()
+    dynamo_guards: set[Guard] = set()

     def __init__(self, dynamo_guards):
         self.dynamo_guards = dynamo_guards
@@ -500,7 +496,7 @@ class GuardsCheckpointState:


 class ModuleContextCheckpointState:
-    nn_modules: Dict[str, torch.nn.Module] = {}
+    nn_modules: dict[str, torch.nn.Module] = {}

     def __init__(self, nn_modules):
         self.nn_modules = nn_modules
@@ -523,7 +519,7 @@ class ModuleContextCheckpointState:

 class ModuleContext(Checkpointable[ModuleContextCheckpointState]):
     def __init__(self) -> None:
-        self.nn_modules: Dict[str, Any] = {}
+        self.nn_modules: dict[str, Any] = {}

     def copy_graphstate(self):
         return ModuleContextCheckpointState(dict(self.nn_modules))
@@ -534,7 +530,7 @@ class ModuleContext(Checkpointable[ModuleContextCheckpointState]):


 class GlobalContextCheckpointState:
-    global_state: Dict[str, Tuple[Callable, ...]] = {}
+    global_state: dict[str, tuple[Callable, ...]] = {}

     def __init__(self, global_states):
         self.global_state = global_states
@@ -572,7 +568,7 @@ class GlobalContext(Checkpointable[GlobalContextCheckpointState]):
    }

     def __init__(self) -> None:
-        self.global_state: Dict[str, Tuple[Callable, ...]] = {}
+        self.global_state: dict[str, tuple[Callable, ...]] = {}

     def copy_graphstate(self):
         return GlobalContextCheckpointState(dict(self.global_state))
@@ -628,7 +624,7 @@ class GuardsSet:
                 guard.user_stack = TracingContext.extract_stack()
         self.inner.add(guard)

-    def update(self, *others: Set[Guard]):
+    def update(self, *others: set[Guard]):
         for o in others:
             for g in o:
                 self.add(g, skip=1)
@@ -641,7 +637,7 @@ class GuardsSet:
 class GuardsContext(Checkpointable[GuardsCheckpointState]):
     def __init__(self) -> None:
         self.dynamo_guards: GuardsSet = GuardsSet()
-        self.aotautograd_guards: List[GuardEnvExpr] = []
+        self.aotautograd_guards: list[GuardEnvExpr] = []

     def copy_graphstate(self):
         return GuardsCheckpointState(set(self.dynamo_guards.inner))
@@ -674,9 +670,9 @@ class HopSubgraphCache:

 class InvokeSubgraphCache(HopSubgraphCache):
     def __init__(self) -> None:
-        self.autograd_cache: Dict[str, Callable] = {}
-        self.proxy_dispatch_cache: Dict[str, Callable] = {}
-        self.dynamo_identifiers: Dict[str, str] = {}
+        self.autograd_cache: dict[str, Callable] = {}
+        self.proxy_dispatch_cache: dict[str, Callable] = {}
+        self.dynamo_identifiers: dict[str, str] = {}

     def add_dynamo_identifier(self, cache_key: str, identifier: str):
         self.dynamo_identifiers[cache_key] = identifier
@@ -748,7 +744,7 @@ class CompileContext:
         self.compile_id: Optional[CompileId] = compile_id
         self.attempt = 0
         # Verbose ShapeEnv guards produced.
-        self.shape_env_guards: List[str] = []
+        self.shape_env_guards: list[str] = []

     @staticmethod
     def current_compile_id():
@@ -816,7 +812,7 @@ class TracingContext:
         # careful not to accidentally induce guards on the SymInt if
         # you ever do change this in aot_autograd.py; you should check
         # on permutations preferentially.)
-        self.output_strides: Optional[List[Optional[Tuple[int, ...]]]] = None
+        self.output_strides: Optional[list[Optional[tuple[int, ...]]]] = None
         # When this is True, whenever we encounter an int in Dynamo tracing,
         # we will (1) force unspec it and (2) force it as a size-like unbacked
        # integer. This is currently used when processing certain lists of