PEP585 update - mostly toplevels (#145178)

See #145101 for details.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/145178
Approved by: https://github.com/bobrenjc93
This commit is contained in:
Aaron Orenstein
2025-01-21 13:42:12 -08:00
committed by PyTorch MergeBot
parent 1ce533867f
commit f2cfe8b59f
39 changed files with 356 additions and 386 deletions

View File

@ -17,13 +17,9 @@ from dataclasses import dataclass
from typing import (
Any,
Callable,
Dict,
Generic,
List,
NamedTuple,
Optional,
Set,
Tuple,
TYPE_CHECKING,
TypeVar,
Union,
@ -260,8 +256,8 @@ class Guard:
create_fn: Callable[[GuardBuilderBase, Guard], None]
# Export only. These values are written to at time of guard check_fn creation.
guard_types: Optional[List[str]] = None
code_list: Optional[List[str]] = None
guard_types: Optional[list[str]] = None
code_list: Optional[list[str]] = None
obj_weakref: Optional[object] = None
guarded_class_weakref: Optional[type] = None
@ -448,8 +444,8 @@ overlapping with any other input, overlapping_sources represent tensors that eit
@dataclasses.dataclass
class StorageOverlap(GuardEnvExpr):
overlapping_sources: List[Source]
non_overlapping_sources: List[Source]
overlapping_sources: list[Source]
non_overlapping_sources: list[Source]
"""
@ -478,7 +474,7 @@ class GuardsCheckpointState:
The GuardCheckpointState - it is the T of Checkpointable[T] for GuardsContext
"""
dynamo_guards: Set[Guard] = set()
dynamo_guards: set[Guard] = set()
def __init__(self, dynamo_guards):
self.dynamo_guards = dynamo_guards
@ -500,7 +496,7 @@ class GuardsCheckpointState:
class ModuleContextCheckpointState:
nn_modules: Dict[str, torch.nn.Module] = {}
nn_modules: dict[str, torch.nn.Module] = {}
def __init__(self, nn_modules):
self.nn_modules = nn_modules
@ -523,7 +519,7 @@ class ModuleContextCheckpointState:
class ModuleContext(Checkpointable[ModuleContextCheckpointState]):
def __init__(self) -> None:
self.nn_modules: Dict[str, Any] = {}
self.nn_modules: dict[str, Any] = {}
def copy_graphstate(self):
return ModuleContextCheckpointState(dict(self.nn_modules))
@ -534,7 +530,7 @@ class ModuleContext(Checkpointable[ModuleContextCheckpointState]):
class GlobalContextCheckpointState:
global_state: Dict[str, Tuple[Callable, ...]] = {}
global_state: dict[str, tuple[Callable, ...]] = {}
def __init__(self, global_states):
self.global_state = global_states
@ -572,7 +568,7 @@ class GlobalContext(Checkpointable[GlobalContextCheckpointState]):
}
def __init__(self) -> None:
self.global_state: Dict[str, Tuple[Callable, ...]] = {}
self.global_state: dict[str, tuple[Callable, ...]] = {}
def copy_graphstate(self):
return GlobalContextCheckpointState(dict(self.global_state))
@ -628,7 +624,7 @@ class GuardsSet:
guard.user_stack = TracingContext.extract_stack()
self.inner.add(guard)
def update(self, *others: Set[Guard]):
def update(self, *others: set[Guard]):
for o in others:
for g in o:
self.add(g, skip=1)
@ -641,7 +637,7 @@ class GuardsSet:
class GuardsContext(Checkpointable[GuardsCheckpointState]):
def __init__(self) -> None:
self.dynamo_guards: GuardsSet = GuardsSet()
self.aotautograd_guards: List[GuardEnvExpr] = []
self.aotautograd_guards: list[GuardEnvExpr] = []
def copy_graphstate(self):
return GuardsCheckpointState(set(self.dynamo_guards.inner))
@ -674,9 +670,9 @@ class HopSubgraphCache:
class InvokeSubgraphCache(HopSubgraphCache):
def __init__(self) -> None:
self.autograd_cache: Dict[str, Callable] = {}
self.proxy_dispatch_cache: Dict[str, Callable] = {}
self.dynamo_identifiers: Dict[str, str] = {}
self.autograd_cache: dict[str, Callable] = {}
self.proxy_dispatch_cache: dict[str, Callable] = {}
self.dynamo_identifiers: dict[str, str] = {}
def add_dynamo_identifier(self, cache_key: str, identifier: str):
self.dynamo_identifiers[cache_key] = identifier
@ -748,7 +744,7 @@ class CompileContext:
self.compile_id: Optional[CompileId] = compile_id
self.attempt = 0
# Verbose ShapeEnv guards produced.
self.shape_env_guards: List[str] = []
self.shape_env_guards: list[str] = []
@staticmethod
def current_compile_id():
@ -816,7 +812,7 @@ class TracingContext:
# careful not to accidentally induce guards on the SymInt if
# you ever do change this in aot_autograd.py; you should check
# on permutations preferentially.)
self.output_strides: Optional[List[Optional[Tuple[int, ...]]]] = None
self.output_strides: Optional[list[Optional[tuple[int, ...]]]] = None
# When this is True, whenever we encounter an int in Dynamo tracing,
# we will (1) force unspec it and (2) force it as a size-like unbacked
# integer. This is currently used when processing certain lists of