PEP585 update - torch/_functorch (#145139)

See #145101 for details.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/145139
Approved by: https://github.com/bobrenjc93
Author: Aaron Orenstein
Date: 2025-01-17 22:09:55 -08:00
Committed by: PyTorch MergeBot
Parent: 10e4d3aebb
Commit: 78bff1e8c1
27 changed files with 345 additions and 357 deletions
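
For reference, every hunk below applies the same mechanical PEP 585 migration: subscripted generics imported from typing (List, Dict, Tuple, Set, Type, DefaultDict) become the built-in or collections classes (list, dict, tuple, set, type, collections.defaultdict), and abstract containers such as Sequence and Iterable are imported from collections.abc instead of typing (in a couple of files only under TYPE_CHECKING, since they are needed purely for annotations). A minimal sketch of the pattern follows, reusing the greedy_knapsack signature from one of the hunks below; the module-level variables at the end are illustrative only:

# Before (typing generics, pre-PEP 585):
#     from typing import Dict, List, Optional, Tuple
#     def greedy_knapsack(
#         memory: List[float], runtimes: List[float], max_memory: float
#     ) -> Tuple[float, List[int], List[int]]: ...

# After: built-in generics; abstract containers come from collections.abc,
# while Optional/Union/Any/Callable stay in typing.
import collections
from collections.abc import Sequence
from typing import Optional


def greedy_knapsack(
    memory: list[float], runtimes: list[float], max_memory: float
) -> tuple[float, list[int], list[int]]:
    # Body elided; only the annotations change in this migration.
    ...


# typing.DefaultDict / typing.Type map to the runtime classes themselves,
# which are subscriptable on Python 3.9+.
counts: collections.defaultdict[str, int] = collections.defaultdict(int)
window: Optional[Sequence[float]] = None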


@ -1,4 +1,4 @@
from typing import Any, Dict, List, Optional, Tuple
from typing import Any, Optional
import networkx as nx
@ -20,18 +20,18 @@ class GraphInfoProvider:
def __init__(
self,
graph_nodes_in_order: List[str],
graph_edges: List[Tuple[str, str]],
all_recomputable_banned_nodes: List[str],
all_node_runtimes: Optional[Dict[str, float]] = None,
all_node_memories: Optional[Dict[str, float]] = None,
recorded_knapsack_input_memories: Optional[List[float]] = None,
recorded_knapsack_input_runtimes: Optional[List[float]] = None,
graph_nodes_in_order: list[str],
graph_edges: list[tuple[str, str]],
all_recomputable_banned_nodes: list[str],
all_node_runtimes: Optional[dict[str, float]] = None,
all_node_memories: Optional[dict[str, float]] = None,
recorded_knapsack_input_memories: Optional[list[float]] = None,
recorded_knapsack_input_runtimes: Optional[list[float]] = None,
joint_graph: Optional[Graph] = None,
):
self.graph_nodes_in_order = graph_nodes_in_order
self.graph_edges = graph_edges
self.all_node_runtimes: Dict[str, float] = dict()
self.all_node_runtimes: dict[str, float] = dict()
if all_node_runtimes is None:
if recorded_knapsack_input_runtimes is None:
raise ValueError(
@ -43,7 +43,7 @@ class GraphInfoProvider:
}
else:
self.all_node_runtimes.update(all_node_runtimes)
self.all_node_memories: Dict[str, float] = dict()
self.all_node_memories: dict[str, float] = dict()
if all_node_memories is None:
if recorded_knapsack_input_memories is None:
raise ValueError(
@ -59,7 +59,7 @@ class GraphInfoProvider:
self.all_recomputable_banned_nodes_set = set(all_recomputable_banned_nodes)
self.recorded_knapsack_input_memories = recorded_knapsack_input_memories
self.recorded_knapsack_input_runtimes = recorded_knapsack_input_runtimes
self._lazily_initialized_graphs: Dict[str, Any] = {
self._lazily_initialized_graphs: dict[str, Any] = {
self.__RECOMPUTABLE_NODE_ONLY_GRAPH: None,
self.__RECOMPUTABLE_NODE_ONLY_GRAPH_WITH_LARGER_GRAPH_CONTEXT: None,
self.__FULL_NX_JOINT_GRAPH: None,
@ -70,9 +70,9 @@ class GraphInfoProvider:
def inialize_from_graph(
cls,
joint_graph: Graph,
all_recomputable_banned_nodes: List[Node],
recorded_knapsack_input_memories: List[float],
recorded_knapsack_input_runtimes: List[float],
all_recomputable_banned_nodes: list[Node],
recorded_knapsack_input_memories: list[float],
recorded_knapsack_input_runtimes: list[float],
) -> "GraphInfoProvider":
"""
Enables initialization from a joint graph.
@ -144,7 +144,7 @@ class GraphInfoProvider:
for node_name in self.all_recomputable_banned_nodes_set
)
def get_knapsack_memory_input(self) -> List[float]:
def get_knapsack_memory_input(self) -> list[float]:
return (
self.recorded_knapsack_input_memories
if self.recorded_knapsack_input_memories
@ -154,7 +154,7 @@ class GraphInfoProvider:
]
)
def get_knapsack_runtime_input(self) -> List[float]:
def get_knapsack_runtime_input(self) -> list[float]:
return (
self.recorded_knapsack_input_runtimes
if self.recorded_knapsack_input_runtimes
@ -224,7 +224,7 @@ class GraphInfoProvider:
def _recreate_psuedo_joint_graph(self) -> Graph:
# Create a dictionary to store the dependencies of each node
node_dependencies: Dict[str, List[str]] = {
node_dependencies: dict[str, list[str]] = {
node: [] for node in self.graph_nodes_in_order
}
for a, b in self.graph_edges:
@ -234,7 +234,7 @@ class GraphInfoProvider:
joint_graph = Graph()
# Create nodes in the graph
nodes: Dict[str, Node] = {}
nodes: dict[str, Node] = {}
for node_name in self.graph_nodes_in_order:
input_nodes = [nodes[dep] for dep in node_dependencies[node_name]]
if input_nodes:


@ -1,11 +1,9 @@
from typing import List, Tuple
import torch
def greedy_knapsack(
memory: List[float], runtimes: List[float], max_memory: float
) -> Tuple[float, List[int], List[int]]:
memory: list[float], runtimes: list[float], max_memory: float
) -> tuple[float, list[int], list[int]]:
n = len(runtimes)
items = list(range(n))
@ -28,8 +26,8 @@ def greedy_knapsack(
def ilp_knapsack(
memory: List[float], runtimes: List[float], max_memory: float
) -> Tuple[float, List[int], List[int]]:
memory: list[float], runtimes: list[float], max_memory: float
) -> tuple[float, list[int], list[int]]:
import numpy as np
try:
@ -64,8 +62,8 @@ def ilp_knapsack(
def dp_knapsack(
memory: List[float], runtime: List[float], max_memory: float
) -> Tuple[float, List[int], List[int]]:
memory: list[float], runtime: list[float], max_memory: float
) -> tuple[float, list[int], list[int]]:
# Scaling factor to convert floating point weights to integers
S = 10000


@ -1,5 +1,5 @@
from collections import deque
from typing import Callable, Dict, List, Set, Tuple
from typing import Callable
import networkx as nx
import numpy as np
@ -25,10 +25,10 @@ class KnapsackEvaluator:
def _get_backward_memory_from_topologically_sorted_graph(
self,
node_graph: nx.DiGraph,
node_memories: Dict[str, float],
saved_nodes_set: Set[str],
node_memories: dict[str, float],
saved_nodes_set: set[str],
peak_memory_after_forward_pass: float,
) -> List[Tuple[float, str]]:
) -> list[tuple[float, str]]:
"""
Simulates the backward pass and keeps track of the peak memory usage.
@ -108,7 +108,7 @@ class KnapsackEvaluator:
return current_memory
def _validate_all_indexes_accounted_for_in_provided_output(
self, saved_nodes_idxs: List[int], recomputable_node_idxs: List[int]
self, saved_nodes_idxs: list[int], recomputable_node_idxs: list[int]
) -> None:
"""
Validate that all indexes are accounted for in the provided output.
@ -132,10 +132,10 @@ class KnapsackEvaluator:
def evaluate_knapsack_output(
self,
saved_nodes_idxs: List[int],
recomputable_node_idxs: List[int],
saved_nodes_idxs: list[int],
recomputable_node_idxs: list[int],
account_for_backward_pass: bool = False,
) -> Dict[str, float]:
) -> dict[str, float]:
"""
Evaluate the theoretical runtime and peak memory usage of a given checkpointing strategy.
Args:
@ -188,10 +188,10 @@ class KnapsackEvaluator:
def evaluate_distribution_of_results_for_knapsack_algo(
self,
knapsack_algo: Callable[
[List[float], List[float], float], Tuple[float, List[int], List[int]]
[list[float], list[float], float], tuple[float, list[int], list[int]]
],
memory_budget_values: List[float],
) -> List[Dict[str, float]]:
memory_budget_values: list[float],
) -> list[dict[str, float]]:
"""
Evaluates the distribution of results for a given knapsack algorithm.
Args:
@ -216,7 +216,7 @@ class KnapsackEvaluator:
def get_knee_point_memory_budget(
self,
knapsack_algo: Callable[
[List[float], List[float], float], Tuple[float, List[int], List[int]]
[list[float], list[float], float], tuple[float, list[int], list[int]]
],
max_mem_budget: float = 0.1,
min_mem_budget: float = 0.001,


@ -14,7 +14,7 @@ import pickle
import shutil
import time
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional, Tuple, TYPE_CHECKING, Union
from typing import Any, Callable, Optional, TYPE_CHECKING, Union
import torch
from torch._dynamo.trace_rules import torch_non_c_binding_in_graph_functions
@ -273,7 +273,7 @@ class AOTAutogradCacheDetails(FxGraphHashDetails):
class AOTAutogradCachePickler(FxGraphCachePickler):
def __init__(self, gm: torch.fx.GraphModule):
super().__init__(gm)
self.dispatch_table: Dict
self.dispatch_table: dict
self.dispatch_table.update(
{
AOTConfig: functools.partial(self._reduce_aot_config),
@ -313,7 +313,7 @@ def autograd_cache_key(
config: AOTConfig,
fx_config: _CompileFxKwargs,
# TODO: add args and parameters
) -> Tuple[str, List[str]]:
) -> tuple[str, list[str]]:
"""
Generate a unique hash of the FX graph for caching.
"""
@ -399,7 +399,7 @@ class CompiledBackward(FXGraphCacheLoadable):
"""
# Used by AOTDispatchAutograd.post_compile
backward_state_indices: List[int]
backward_state_indices: list[int]
num_symints_saved_for_bw_: int
def is_backward(self):
@ -424,14 +424,14 @@ class AOTAutogradCacheEntry:
runtime_metadata: ViewAndMutationMeta
# Wrappers that run after each aot_dispatch_* function
dispatch_wrappers: List[CompilerWrapper]
dispatch_wrappers: list[CompilerWrapper]
# Used by AOTSubclassWrapper
maybe_subclass_meta: Optional[SubclassMeta]
num_fw_outs_saved_for_bw: Optional[int]
# Used by RuntimeWrapepr
indices_of_inps_to_detach: List[int]
indices_of_inps_to_detach: list[int]
# Time taken to trace/compile the forward
# forward_time_taken includes AOTAutograd tracing time + inductor compilation time
@ -442,7 +442,7 @@ class AOTAutogradCacheEntry:
# Turn cache entry into the original callable
def wrap_post_compile(
self,
args: List[torch.Tensor],
args: list[torch.Tensor],
aot_config: AOTConfig,
fx_config: _CompileFxKwargs,
) -> Callable:
@ -675,9 +675,9 @@ class AOTAutogradCache:
gm = mod.gm if isinstance(mod, torch._dynamo.utils.GmWrapper) else mod
with sanitize_gm_for_cache(gm):
compiled_fn = None
cache_info: Dict[str, Any] = {}
cache_info: dict[str, Any] = {}
cache_key = None
debug_lines: List[str] = []
debug_lines: list[str] = []
cache_event_time = time.time_ns()
cache_state = None
fx_config: _CompileFxKwargs = {"cudagraphs": cudagraphs}


@ -12,7 +12,7 @@ import collections
import contextlib
import logging
from functools import wraps
from typing import Callable, DefaultDict, Dict, List, Optional, Set
from typing import Callable, Optional
import torch
import torch.utils._pytree as pytree
@ -150,13 +150,13 @@ def run_functionalized_fw_and_collect_metadata(
# TODO: refactor to kill this flag
is_train: bool = False,
# Note: this is guaranteed to be set when running under dynamo
static_input_indices: Optional[List[int]] = None,
static_input_indices: Optional[list[int]] = None,
pre_dispatch: bool = False,
# is_export is technically only needed to avoid using functionalization V2
# during analysis
is_export: bool = False,
) -> Callable[..., ViewAndMutationMeta]:
memo: Dict[Tensor, Tensor] = {}
memo: dict[Tensor, Tensor] = {}
def _to_fun(t):
if isinstance(t, Tensor):
@ -173,8 +173,8 @@ def run_functionalized_fw_and_collect_metadata(
# This function is meant to be run with the forward, which expects a flat list of tensor/symint/other args.
assert all(isinstance(a, tuple(KNOWN_TYPES)) for a in flat_args)
input_info: List[InputAliasInfo] = []
output_info: List[OutputAliasInfo] = []
input_info: list[InputAliasInfo] = []
output_info: list[OutputAliasInfo] = []
prior_grad_enabled = torch.is_grad_enabled()
prior_autocast_states = _get_autocast_states()
@ -275,15 +275,16 @@ def run_functionalized_fw_and_collect_metadata(
out_tensor_ids = {id(o): i for i, o in enumerate(flat_f_outs)}
# Keep track of which outputs alias other outputs
out_tensor_alias_counts: DefaultDict = collections.defaultdict(int)
out_tensor_alias_counts: collections.defaultdict = collections.defaultdict(int)
# This tells us, for a given group of outputs that alias each other,
# whether they e.g. all came from an unbind call
num_aliased_tensors_that_are_multi_output_views: DefaultDict = (
num_aliased_tensors_that_are_multi_output_views: collections.defaultdict = (
collections.defaultdict(int)
)
out_storage_to_metadata_key_to_tensors: DefaultDict[
Optional[StorageWeakRef], DefaultDict[MetadataKey, Set[torch.Tensor]]
out_storage_to_metadata_key_to_tensors: collections.defaultdict[
Optional[StorageWeakRef],
collections.defaultdict[MetadataKey, set[torch.Tensor]],
] = collections.defaultdict(lambda: collections.defaultdict(set))
curr_storage = None
@ -382,8 +383,8 @@ def run_functionalized_fw_and_collect_metadata(
].add(o)
# maps the id of an intermediate base to its index in the output of the compiled forward
intermediate_base_tensor_id_to_output_idx: Dict[int, int] = {}
intermediate_bases: List[torch.Tensor] = []
intermediate_base_tensor_id_to_output_idx: dict[int, int] = {}
intermediate_bases: list[torch.Tensor] = []
# Why Do We Care If Storage Changed?
# It's important to understand the implications of storage changes in complex scenarios. Take this example:
#


@ -5,7 +5,7 @@ pathways, taking into account the AOTConfig and the collected ViewAndMutationMet
"""
import dataclasses
from typing import Any, Dict, List, Optional, Tuple
from typing import Any, Optional
import torch
import torch.utils._pytree as pytree
@ -72,11 +72,11 @@ def _detach_and_copy_item_memo(t):
def aot_dispatch_base_graph(
flat_fn,
flat_args: List[Tensor],
flat_args: list[Tensor],
aot_config: AOTConfig,
*,
fw_metadata: ViewAndMutationMeta,
) -> Tuple[torch.fx.GraphModule, List[Any], Optional[SubclassMeta]]:
) -> tuple[torch.fx.GraphModule, list[Any], Optional[SubclassMeta]]:
# aot_dispatch_base requires functionalization, but doesn't need to handle as many cases as the autograd case.
# The cases that aot_dispatch_base doesn't need to handle include:
# - outputs that are aliases of graph intermediates
@ -133,7 +133,7 @@ def aot_dispatch_base_graph(
if aot_config.is_export and mod_when_exporting_non_strict is not None:
# For any buffer that is assigned, we want to associate it to the final proxy node
# that it is assigned to. This node can then be added as a buffer mutation output.
assigned_buffers: Dict[str, str] = {}
assigned_buffers: dict[str, str] = {}
hook = register_buffer_assignment_hook(
mod_when_exporting_non_strict, assigned_buffers
)
@ -250,11 +250,11 @@ def aot_dispatch_base_graph(
# the same storage, so long as they have separate TensorImpls.)
def aot_dispatch_autograd_graph(
flat_fn,
flat_args: List[Any],
flat_args: list[Any],
aot_config: AOTConfig,
*,
fw_metadata: ViewAndMutationMeta,
) -> Tuple[torch.fx.GraphModule, Tuple[List[Any], List[Any]], Optional[SubclassMeta]]:
) -> tuple[torch.fx.GraphModule, tuple[list[Any], list[Any]], Optional[SubclassMeta]]:
# traced_tangents corresponds to the set of outputs in the traced forward that should get grad_outputs in the traced backward.
# It includes outputs of the original forward, *and* any updated inputs due to input mutations.
# However, it does *not* include any outputs that are aliases of inputs or intermediates, or any metadata-only input mutations.


@ -9,7 +9,7 @@ This file contains utilities related to functionalization in AOTAutograd:
from __future__ import annotations
from dataclasses import dataclass
from typing import Optional, Tuple
from typing import Optional
import torch
from torch import Tensor
@ -337,11 +337,11 @@ class MetadataKey:
This should be equal whenever has_same_metadata would return True
"""
size: Tuple[SymIntEqByExpr, ...]
size: tuple[SymIntEqByExpr, ...]
layout: torch.layout
is_sparse: bool
# these are empty when is_sparse
stride: Optional[Tuple[SymIntEqByExpr, ...]]
stride: Optional[tuple[SymIntEqByExpr, ...]]
storage_offset: Optional[SymIntEqByExpr]
is_conj: bool
is_neg: bool


@ -12,7 +12,7 @@ In particular, the following analyses are provided:
import contextlib
import itertools
from typing import Any, Dict, List, Optional, Tuple, Union
from typing import Any, Optional, Union
import torch
import torch.utils._pytree as pytree
@ -40,8 +40,8 @@ zip = strict_zip
def remove_dupe_metadata(
m: ViewAndMutationMeta,
keep_arg_mask: List[bool],
add_dupe_map: List[int],
keep_arg_mask: list[bool],
add_dupe_map: list[int],
) -> ViewAndMutationMeta:
assert len(m.input_info) == len(keep_arg_mask)
# Easy invariant: the first argument should never be a dupe (it will be kept)
@ -104,12 +104,12 @@ def create_synthetic_base_metadata(
m: ViewAndMutationMeta,
# Maps each outer argument idx to its inner idx (or, if this outer arg is generated from a
# synthetic base, you get a tuple of (i, TensorMeta), telling you the base tensor idx, and view metadata)
synthetic_base_info: List[Union[int, Tuple[int, torch.Tensor]]],
outer_args: List[Any],
inner_args: List[Any],
) -> Tuple[ViewAndMutationMeta, List[int]]:
synthetic_base_info: list[Union[int, tuple[int, torch.Tensor]]],
outer_args: list[Any],
inner_args: list[Any],
) -> tuple[ViewAndMutationMeta, list[int]]:
# maps inner arg indices to outer arg indices
synthetic_base_to_indices: Dict[int, List[int]] = {}
synthetic_base_to_indices: dict[int, list[int]] = {}
for inner_idx in range(len(inner_args)):
outer_aliased_indices_of_current_base_arg = [
outer_idx
@ -348,10 +348,10 @@ def create_graph_signature(
in_spec: pytree.TreeSpec,
out_spec: pytree.TreeSpec,
*,
user_args_flat: List[Tensor],
params_and_buffers_flat: List[Tensor],
param_names: List[str],
buffer_names: List[str],
user_args_flat: list[Tensor],
params_and_buffers_flat: list[Tensor],
param_names: list[str],
buffer_names: list[str],
trace_joint: bool,
num_user_fw_outs: Optional[int],
loss_index: Optional[int],


@ -14,7 +14,7 @@ import logging
import time
import traceback
from contextlib import nullcontext
from typing import Any, Callable, List, Optional, Sequence, Tuple
from typing import Any, Callable, Optional, TYPE_CHECKING
import torch
import torch.utils.dlpack
@ -72,6 +72,10 @@ from .utils import (
)
if TYPE_CHECKING:
from collections.abc import Sequence
zip = strict_zip
log = logging.getLogger(__name__)
@ -82,10 +86,10 @@ aten = torch.ops.aten
# Returns a Callable and a ViewAndMutationMeta.
# Currently, only export needs the ViewAndMutationMeta after this function.
DispatchReturn = Tuple[Callable, ViewAndMutationMeta]
DispatchReturn = tuple[Callable, ViewAndMutationMeta]
def _create_wrappers_for_dispatch(needs_autograd: bool) -> List[CompilerWrapper]:
def _create_wrappers_for_dispatch(needs_autograd: bool) -> list[CompilerWrapper]:
"""
Wrappers that run on every dispatch function
"""
@ -96,7 +100,7 @@ def _create_wrappers_for_dispatch(needs_autograd: bool) -> List[CompilerWrapper]
# bits of aot_autograd, and doesn't need to do any specific wrapping.
def aot_dispatch_export(
flat_fn: Callable,
flat_args: List[Any],
flat_args: list[Any],
aot_config: AOTConfig,
*,
fw_metadata: ViewAndMutationMeta,
@ -136,7 +140,7 @@ def aot_dispatch_export(
def aot_dispatch_base(
flat_fn,
flat_args: List[Any],
flat_args: list[Any],
aot_config: AOTConfig,
*,
fw_metadata: ViewAndMutationMeta,
@ -283,11 +287,11 @@ def aot_dispatch_base(
def collect_fw_donated_buffer_idxs(
fw_ins: List[Optional[FakeTensor]],
user_fw_outs: List[Optional[FakeTensor]],
bw_outs: List[Optional[FakeTensor]],
saved_tensors: List[FakeTensor],
) -> List[int]:
fw_ins: list[Optional[FakeTensor]],
user_fw_outs: list[Optional[FakeTensor]],
bw_outs: list[Optional[FakeTensor]],
saved_tensors: list[FakeTensor],
) -> list[int]:
"""
Checks if the saved tensors are donated buffers, which means a saved tensor is not
an alias of any tensors in fw_ins, user_fw_outs, and bw_outs.
@ -317,7 +321,7 @@ def collect_bw_donated_buffer_idxs(
fw_module: torch.fx.GraphModule,
bw_module: torch.fx.GraphModule,
fw_metadata: ViewAndMutationMeta,
) -> List[int]:
) -> list[int]:
"""
Collects backward donated buffer indexes from fw_module and bw_module.
"""
@ -372,7 +376,7 @@ def collect_bw_donated_buffer_idxs(
def aot_dispatch_autograd(
flat_fn,
flat_args: List[Any],
flat_args: list[Any],
aot_config: AOTConfig,
*,
fw_metadata: ViewAndMutationMeta,
@ -541,7 +545,7 @@ def aot_dispatch_autograd(
# and we will end up with a zero grad at x.
# If we later backprop through the second output, this will also require backprop'ing through x.
# Meaning we'll need to use `retain_graph=True` to be able to backprop through x the second time.
_indices_of_inps_to_detach: List[int] = []
_indices_of_inps_to_detach: list[int] = []
# reversed() since we expect output at end of graph
bw_output = next(reversed(bw_module.graph.find_nodes(op="output")))
@ -888,7 +892,7 @@ def aot_dispatch_autograd(
)
if config.debug_assert:
flat_requires_grad: List[Optional[bool]] = [
flat_requires_grad: list[Optional[bool]] = [
a.requires_grad if isinstance(a, Tensor) else None for a in flat_args
]
compiled_fn = DebugAssertWrapper(


@ -6,14 +6,13 @@ compilation, capturing user-friendly tracebacks, and debug messages.
import collections
from contextlib import contextmanager
from typing import List, Tuple
import torch
import torch.fx.traceback as fx_traceback
# This is a list since looking forward, we can have this arbitrarily nested.
graph_being_compiled: List[str] = []
graph_being_compiled: list[str] = []
# TODO: It would be nice to reset the numbering every time aot_id goes
# up, but this is annoying to do right now (because we don't know if
# an aot_id will come back from the dead), so right now this also happens
@ -28,7 +27,7 @@ def set_model_name(name):
model_name = name
def get_aot_compilation_context() -> Tuple[List[str], str, int]:
def get_aot_compilation_context() -> tuple[list[str], str, int]:
return list(graph_being_compiled), model_name, nth_graph
@ -70,7 +69,7 @@ def track_graph_compiling(aot_config, graph_name):
callback_set = False
def setup_stacktrace_preservation_hooks(roots: List):
def setup_stacktrace_preservation_hooks(roots: list):
def iter_graph(roots):
if not roots:
return


@ -13,7 +13,7 @@ import pprint
from contextlib import nullcontext
from dataclasses import dataclass, field
from functools import wraps
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union
from typing import Any, Callable, Optional, TYPE_CHECKING, Union
import torch
import torch.utils.dlpack
@ -68,6 +68,10 @@ from .utils import (
)
if TYPE_CHECKING:
from collections.abc import Sequence
zip = strict_zip
@ -86,11 +90,11 @@ class CompilerWrapper:
def pre_compile(
self,
flat_fn,
flat_args: List[Tensor],
flat_args: list[Tensor],
aot_config: AOTConfig,
*,
fw_metadata: ViewAndMutationMeta,
) -> Tuple[Callable, List[Tensor], ViewAndMutationMeta]:
) -> tuple[Callable, list[Tensor], ViewAndMutationMeta]:
"""
Process the inputs to the compiler_fn. You can pass in extra metadata via kwargs.
Args:
@ -129,7 +133,7 @@ class CompilerWrapper:
# - the autograd cases inserts TensorAlias wrapper objects for outputs that alias inputs
@dataclass
class RuntimeWrapper(CompilerWrapper):
indices_of_inps_to_detach: List[int]
indices_of_inps_to_detach: list[int]
trace_joint: bool
disable_amp: bool
@ -244,7 +248,7 @@ def _create_runtime_wrapper(
compiled_fn,
*,
runtime_metadata: ViewAndMutationMeta,
indices_of_inps_to_detach: List[int],
indices_of_inps_to_detach: list[int],
trace_joint: bool,
keep_input_mutations: bool,
disable_amp: bool,
@ -282,7 +286,7 @@ def _create_runtime_wrapper(
for info in runtime_metadata.output_info
)
def runtime_wrapper(args: List[Any]):
def runtime_wrapper(args: list[Any]):
# stash a ref to each input tensor we plan to use after the compiled function
orig_inputs = {i: args[i] for i in epilogue_args_idx}
@ -454,7 +458,7 @@ class FunctionalizedRngRuntimeWrapper(CompilerWrapper):
aot_config,
*,
fw_metadata,
) -> Tuple[Callable, List[Tensor], ViewAndMutationMeta]:
) -> tuple[Callable, list[Tensor], ViewAndMutationMeta]:
if config.functionalize_rng_ops:
# Update example inputs for the fw_compiler
fake_mode = detect_fake_mode()
@ -473,7 +477,7 @@ class FunctionalizedRngRuntimeWrapper(CompilerWrapper):
runtime_metadata: ViewAndMutationMeta,
):
@wraps(compiled_fn)
def wrapper(runtime_args: List[Any]):
def wrapper(runtime_args: list[Any]):
if runtime_metadata.is_rng_op_functionalized:
# Add the seed and offset to args
seed, offset = CUDARngStateHelper.get_torch_state_as_tuple()
@ -513,10 +517,10 @@ class FunctionalizedRngRuntimeWrapper(CompilerWrapper):
@dataclass
class FakifiedOutWrapper(CompilerWrapper):
out_metas: List[torch.Tensor] = field(default_factory=list)
out_metas: list[torch.Tensor] = field(default_factory=list)
# TracingContext.fwd_output_strides
# Generated from actually doing compile
fwd_output_strides: Optional[List[List[int]]] = None
fwd_output_strides: Optional[list[list[int]]] = None
needs_post_compile: bool = True
def pre_compile(
@ -526,7 +530,7 @@ class FakifiedOutWrapper(CompilerWrapper):
aot_config,
*,
fw_metadata,
) -> Tuple[Callable, List[Tensor], ViewAndMutationMeta]:
) -> tuple[Callable, list[Tensor], ViewAndMutationMeta]:
tracing_context = torch._guards.TracingContext.try_get()
if tracing_context and tracing_context.fakify_first_call:
self.out_metas = [
@ -598,7 +602,7 @@ class AOTDispatchSubclassWrapper(CompilerWrapper):
def pre_compile(
self,
flat_fn,
flat_args: List[Tensor],
flat_args: list[Tensor],
aot_config: AOTConfig,
*,
fw_metadata: ViewAndMutationMeta,
@ -626,7 +630,7 @@ class AOTDispatchSubclassWrapper(CompilerWrapper):
subclass_metas = runtime_metadata.subclass_fw_graph_out_meta
@wraps(compiled_fn)
def inner_fn(args: List[Any]):
def inner_fn(args: list[Any]):
unwrapped_args = runtime_unwrap_tensor_subclasses(
args,
subclass_metas=runtime_metadata.subclass_inp_meta,
@ -661,7 +665,7 @@ class EffectTokensWrapper(CompilerWrapper):
num_tokens = len(runtime_metadata.tokens)
@wraps(compiled_fn)
def inner_fn(args: List[Any]):
def inner_fn(args: list[Any]):
if num_tokens > 0:
# Pass in forward effect tokens (See Note [Side-Effectful Tokens in AOTAutograd])
old_args = args
@ -764,9 +768,9 @@ class EffectTokensWrapper(CompilerWrapper):
#
@dataclass
class AOTDedupeWrapper(CompilerWrapper):
keep_arg_mask: List[bool] = field(default_factory=list)
add_dupe_map: List[int] = field(default_factory=list)
old_input_metadata: List[InputAliasInfo] = field(default_factory=list)
keep_arg_mask: list[bool] = field(default_factory=list)
add_dupe_map: list[int] = field(default_factory=list)
old_input_metadata: list[InputAliasInfo] = field(default_factory=list)
needs_post_compile: bool = True
# NB: Hot path, avoid set lookups here
@ -780,11 +784,11 @@ class AOTDedupeWrapper(CompilerWrapper):
def pre_compile(
self,
flat_fn,
flat_args: List[Tensor],
flat_args: list[Tensor],
aot_config: AOTConfig,
*,
fw_metadata: ViewAndMutationMeta,
) -> Tuple[Callable, List[Tensor], ViewAndMutationMeta]:
) -> tuple[Callable, list[Tensor], ViewAndMutationMeta]:
# Use information about whether or not flat_fn mutates its arguments
# or not to handle dupe args
@ -869,10 +873,10 @@ class AOTDedupeWrapper(CompilerWrapper):
# ]
# keep_arg_mask = [True, True, False, True]
seen_args: Dict[Tensor, int] = {}
seen_args: dict[Tensor, int] = {}
# Implicitly map duped arg position (list index) to de-duped arg position
keep_arg_mask: List[bool] = []
add_dupe_map: List[int] = []
keep_arg_mask: list[bool] = []
add_dupe_map: list[int] = []
duped_arg_len = len(flat_args)
j = 0 # index into deduped_flat_args
@ -950,7 +954,7 @@ class AOTDedupeWrapper(CompilerWrapper):
return compiled_fn
@wraps(compiled_fn)
def wrapped_compiled_fn(args: List[Any]):
def wrapped_compiled_fn(args: list[Any]):
deduped_args = self.remove_dupe_args(args)
args.clear()
return compiled_fn(deduped_args)
@ -966,7 +970,7 @@ class AOTDedupeWrapper(CompilerWrapper):
def debugged_compiled_fn(args):
# Test that the computed remove/add arg functions are an inverse
new_args = self.add_dupe_args(self.remove_dupe_args(args))
seen: Dict[Any, None] = {}
seen: dict[Any, None] = {}
for i, (x, y) in enumerate(zip(new_args, args)):
seen[y] = None
assert x is y, format_guard_bug_msg(
@ -1008,16 +1012,16 @@ class AOTSyntheticBaseWrapper(CompilerWrapper):
# the synthetic base code prohibits more cases in the autograd case than the inference case.
trace_joint: bool # TODO: refactor trace_joint
needs_post_compile: bool = True
aliased_arg_idx_with_metadata_mutations: List[int] = field(default_factory=list)
aliased_arg_idx_with_metadata_mutations: list[int] = field(default_factory=list)
def pre_compile(
self,
flat_fn,
flat_args: List[Any],
flat_args: list[Any],
aot_config: AOTConfig,
*,
fw_metadata: ViewAndMutationMeta,
) -> Tuple[Callable, List[Tensor], ViewAndMutationMeta]:
) -> tuple[Callable, list[Tensor], ViewAndMutationMeta]:
is_inference = not self.trace_joint
flat_args_with_synthetic_bases, synthetic_base_info = merge_view_inputs(
aot_config,
@ -1069,7 +1073,7 @@ class AOTSyntheticBaseWrapper(CompilerWrapper):
)
replay_views = config.view_replay_for_aliased_outputs
def _unpack_synthetic_bases(primals: Tuple[Any, ...]) -> List[Any]:
def _unpack_synthetic_bases(primals: tuple[Any, ...]) -> list[Any]:
f_args_inner = []
for inner_idx_or_tuple in synthetic_base_info:
if isinstance(inner_idx_or_tuple, int):
@ -1249,12 +1253,12 @@ class AOTSyntheticBaseWrapper(CompilerWrapper):
# f(c_base, b_base, a, d)
def merge_view_inputs(
aot_config: AOTConfig,
fwd_inputs: List[Any],
mutated_input_info: List[InputAliasInfo],
fwd_inputs: list[Any],
mutated_input_info: list[InputAliasInfo],
*,
# The autograd case currently has more restrictions than the inference case.
is_inference: bool,
) -> Tuple[List[Any], Optional[List[Union[int, Tuple[int, torch.Tensor]]]]]:
) -> tuple[list[Any], Optional[list[Union[int, tuple[int, torch.Tensor]]]]]:
def _are_differentiable_views(view1, view2):
if view1 is view2:
return True
@ -1278,7 +1282,7 @@ def merge_view_inputs(
# Return early when there are no mutations.
return fwd_inputs, None
storage_ref_to_idx: Dict[StorageWeakRef, List[int]] = collections.defaultdict(list)
storage_ref_to_idx: dict[StorageWeakRef, list[int]] = collections.defaultdict(list)
base_args = []
other_args = []
for i, inpt in enumerate(fwd_inputs):
@ -1293,7 +1297,7 @@ def merge_view_inputs(
# - another int (corresponding to the index in the argument list of the element from the outer calling convention)
# - idx, view_tensor, where we can generate the new output with view_tensor._view_func(old_args[idx])
# idx corresponds to which synthetic base from the outer calling context to view
inner_calling_convention_meta: Dict[int, Union[int, Tuple[int, torch.Tensor]]] = {}
inner_calling_convention_meta: dict[int, Union[int, tuple[int, torch.Tensor]]] = {}
for aliased_input_indices in storage_ref_to_idx.values():
if len(aliased_input_indices) <= 1 or not any(
# We only care about mutations that affect all aliases,
@ -1429,8 +1433,8 @@ def merge_view_inputs(
old_idx = arg_to_old_idx_map[make_hashable(other_arg)]
inner_calling_convention_meta[old_idx] = new_idx
# post process into a list
post_processed_calling_convention_meta: List[
Union[int, Tuple[int, torch.Tensor]]
post_processed_calling_convention_meta: list[
Union[int, tuple[int, torch.Tensor]]
] = [-1 for _ in range(len(inner_calling_convention_meta))]
for k, v in inner_calling_convention_meta.items():
post_processed_calling_convention_meta[k] = v
@ -1443,7 +1447,7 @@ def merge_view_inputs(
@dataclass
class AutogradLazyBackwardCompileInfo:
bw_module: Callable
placeholder_list: List[Any]
placeholder_list: list[Any]
saved_context: Optional[TracingContext]
saved_compile_context: Optional[CompileContext]
@ -1782,9 +1786,9 @@ To fix this, your tensor subclass must implement the dunder method __force_to_sa
compiled_bw_func, # bw_module after compilation + wrappers
maybe_subclass_meta: Optional[SubclassMeta],
num_symints_saved_for_bw_: int,
backward_state_indices: List[int],
backward_state_indices: list[int],
disable_amp: bool,
indices_of_inps_to_detach: List[int],
indices_of_inps_to_detach: list[int],
lazy_backward_info: Optional[AutogradLazyBackwardCompileInfo],
aot_config: AOTConfig,
*,
@ -2099,7 +2103,7 @@ To fix this, your tensor subclass must implement the dunder method __force_to_sa
@dataclass
class DebugAssertWrapper(CompilerWrapper):
flat_requires_grad: List[Optional[bool]] = field(default_factory=list)
flat_requires_grad: list[Optional[bool]] = field(default_factory=list)
def post_compile(
self,
@ -2109,7 +2113,7 @@ class DebugAssertWrapper(CompilerWrapper):
runtime_metadata: ViewAndMutationMeta,
):
@wraps(compiled_fn)
def debug_compiled_function(args: List[Any]):
def debug_compiled_function(args: list[Any]):
# TODO: Check aliasing relationships
# TODO: Check strides for metadata mutation
# (NB: ideally, this logic is factored out of this function and
@ -2135,13 +2139,13 @@ class DebugAssertWrapper(CompilerWrapper):
def pre_compile(
wrappers: List[CompilerWrapper],
wrappers: list[CompilerWrapper],
flat_fn: Callable,
flat_args: List[Any],
flat_args: list[Any],
aot_config: AOTConfig,
*,
fw_metadata: ViewAndMutationMeta,
) -> Tuple[Callable, List[Tensor], ViewAndMutationMeta]:
) -> tuple[Callable, list[Tensor], ViewAndMutationMeta]:
"""
Runs a sequence of wrappers on the given function and arguments.
Mutates wrappers in place.
@ -2154,12 +2158,12 @@ def pre_compile(
def post_compile(
wrappers: List[CompilerWrapper],
wrappers: list[CompilerWrapper],
compiled_fn: Callable,
aot_config: AOTConfig,
*,
runtime_metadata: ViewAndMutationMeta,
) -> Tuple[Callable, ViewAndMutationMeta]:
) -> tuple[Callable, ViewAndMutationMeta]:
"""
Runs a sequence of wrappers on the given function. Should be called after pre_compile()
"""


@ -7,9 +7,10 @@ input/output types, metadata, config, function signatures etc.
import collections
import dataclasses
import functools
from collections.abc import Iterable
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Callable, Dict, Iterable, List, NewType, Optional, Set, Union
from typing import Any, Callable, NewType, Optional, Union
import torch
import torch.utils._pytree as pytree
@ -89,7 +90,7 @@ class OutputAliasInfo:
# here, this refers to the index of the *direct* traced
base_idx: Optional[int]
# If it is a Tensor, what the dynamic dims are (otherwise is None)
dynamic_dims: Optional[Set[int]]
dynamic_dims: Optional[set[int]]
# requires_grad
requires_grad: bool
# FunctionalTensorWrapper that represents this output.
@ -190,7 +191,7 @@ class SubclassCreationMeta:
# meta and attrs are produced by the subclass's __tensor_flatten__.
# We need to keep them around along with outer_size / outer_stride to plumb them
# into __tensor_unflatten__
attrs: Dict[str, Union["SubclassCreationMeta", PlainTensorMeta]]
attrs: dict[str, Union["SubclassCreationMeta", PlainTensorMeta]]
outer_size: Iterable[Union[None, int, torch.SymInt]]
outer_stride: Iterable[Union[None, int, torch.SymInt]]
meta: Any
@ -318,11 +319,11 @@ class SubclassCreationMeta:
class ViewAndMutationMeta:
# length = # user inputs
# This gives us info about every input, and what sort of mutation happened to it (if any)
input_info: List[InputAliasInfo]
input_info: list[InputAliasInfo]
# length = # user outputs
# This gives us info about every output (mostly around whether it aliases other tensors)
output_info: List[OutputAliasInfo]
output_info: list[OutputAliasInfo]
# length = the number of intermediate bases appended as outputs to the end of the forward graph.
# Note: this is not necessarily the same thing as:
@ -341,7 +342,7 @@ class ViewAndMutationMeta:
# Their only use today is to pass them as a best-guess for tangents when tracing the joint.
# Stashing them as part of our "metadata" makes it simpler if we want to run our analysis
# pass once, and re-use the output throughout AOTAutograd
traced_tangents: List[Any]
traced_tangents: list[Any]
# Each of these is a list telling us about subclasses for the inputs/outputs/grad_outs
# They are used throughout AOTDispatch to tell us how to generate a list of subclass tensors,
@ -355,7 +356,7 @@ class ViewAndMutationMeta:
# inputs[3] and inputs[4] of the plain-tensor graph".
# length = # user inputs
subclass_inp_meta: List[Union[PlainTensorMeta, SubclassCreationMeta]]
subclass_inp_meta: list[Union[PlainTensorMeta, SubclassCreationMeta]]
# So, the full set of outputs to the forward graph looks something like:
# (*mutated_inps, *user_outs, *intermediate_bases, *saved_for_bw_tensors)
# where the first 3 of those 4 can be subclasses
@ -363,9 +364,9 @@ class ViewAndMutationMeta:
# and not user visible, so there's no point in wrapping/unwrapping them at runtime).
# This list contains subclass information on all of the fw graph outputs
# except for saved_for_bw_tensors.
subclass_fw_graph_out_meta: List[Union[PlainTensorMeta, SubclassCreationMeta]]
subclass_fw_graph_out_meta: list[Union[PlainTensorMeta, SubclassCreationMeta]]
# length = # backward graph inputs
subclass_tangent_meta: List[Union[PlainTensorMeta, SubclassCreationMeta]]
subclass_tangent_meta: list[Union[PlainTensorMeta, SubclassCreationMeta]]
# TODO: we should kill this
# (need to default it to not break internal)
is_train: bool = False
@ -375,7 +376,7 @@ class ViewAndMutationMeta:
# At runtime, we don't keep the traced_tangents around since they're not serializable.
# Instead, we keep any necessary subclass metadata necessary about each traced_tangent.
# This list is generated after calling make_runtime_safe().
traced_tangent_metas: Optional[List[Any]] = None
traced_tangent_metas: Optional[list[Any]] = None
num_symints_saved_for_bw: Optional[int] = None
@ -393,12 +394,12 @@ class ViewAndMutationMeta:
deterministic: Optional[bool] = None
# Keeps track of which input indices store parameters (which we will treat as static)
static_input_indices: List[int] = field(default_factory=list)
static_input_indices: list[int] = field(default_factory=list)
# Map of effect type (ex. _EffectType.ORDERED) to token. If there are
# side-effectful operators, FunctionalTensorMode will populate this
# dictionary telling us how many tokens we will need during tracing.
tokens: Dict[Any, torch.Tensor] = field(default_factory=dict)
tokens: dict[Any, torch.Tensor] = field(default_factory=dict)
# Only filled in if/when we trace the joint function
# If an input requires grad and is mutated in the backward, it is only safe to keep the mutation
@ -406,14 +407,14 @@ class ViewAndMutationMeta:
# (grad mode is disabled by default when users run the backward, but can be turned on with create_graph=True)
# At runtime during the backward, we use this list of indices to error properly if we find out
# that it was not safe to include a backward mutation in the graph.
indices_of_inputs_that_requires_grad_with_mutations_in_bw: List[int] = field(
indices_of_inputs_that_requires_grad_with_mutations_in_bw: list[int] = field(
default_factory=list
)
# Indexes of saved tensors which are donated buffer.
# Donated buffer means the tensor is not alias of any forward user input, forward user output,
# and backward output.
bw_donated_idxs: Optional[List[int]] = None
bw_donated_idxs: Optional[list[int]] = None
# Number of tokens used in backward, appended at the end of backward outputs.
# Filled after tracing joint function.
@ -670,7 +671,7 @@ class SubclassMeta:
#
# Optional field because we don't compute for inference graphs
grad_input_metas: Optional[
List[Union[PlainTensorMeta, SubclassCreationMeta]]
list[Union[PlainTensorMeta, SubclassCreationMeta]]
] = None
def __init__(self) -> None:
@ -704,8 +705,8 @@ class BackwardSignature:
Each string name is the `node.name` of the corresponding node in the fx graph.
"""
gradients_to_parameters: Dict[str, str]
gradients_to_user_inputs: Dict[str, str]
gradients_to_parameters: dict[str, str]
gradients_to_user_inputs: dict[str, str]
loss_output: str
@ -732,29 +733,29 @@ class GraphSignature:
a signature on the backward section of the joint graph.
"""
parameters: List[FQN]
buffers: List[FQN]
parameters: list[FQN]
buffers: list[FQN]
user_inputs: List[GraphInputName]
user_outputs: List[GraphOutputName]
inputs_to_parameters: Dict[GraphInputName, FQN]
inputs_to_buffers: Dict[GraphInputName, FQN]
user_inputs: list[GraphInputName]
user_outputs: list[GraphOutputName]
inputs_to_parameters: dict[GraphInputName, FQN]
inputs_to_buffers: dict[GraphInputName, FQN]
# If the user's module mutates a buffer,
# it's represented in the graph as an extra graph output.
# This dict is a mapping from
# "graph outputs that correspond to updated buffers"
# to the FQN names of those mutated buffers.
buffers_to_mutate: Dict[GraphOutputName, FQN]
user_inputs_to_mutate: Dict[GraphOutputName, GraphInputName]
buffers_to_mutate: dict[GraphOutputName, FQN]
user_inputs_to_mutate: dict[GraphOutputName, GraphInputName]
in_spec: pytree.TreeSpec
out_spec: pytree.TreeSpec
backward_signature: Optional[BackwardSignature]
input_tokens: List[GraphInputName]
output_tokens: List[GraphOutputName]
input_tokens: list[GraphInputName]
output_tokens: list[GraphOutputName]
@classmethod
def from_tracing_metadata(
@ -762,11 +763,11 @@ class GraphSignature:
*,
in_spec: pytree.TreeSpec,
out_spec: pytree.TreeSpec,
graph_input_names: List[str],
graph_output_names: List[str],
graph_input_names: list[str],
graph_output_names: list[str],
view_mutation_metadata: ViewAndMutationMeta,
named_parameters: List[str],
named_buffers: List[str],
named_parameters: list[str],
named_buffers: list[str],
num_user_inputs: int,
num_user_outputs: int,
loss_index: Optional[int],
@ -877,15 +878,15 @@ class AOTConfig:
fw_compiler: Callable
bw_compiler: Callable
partition_fn: Callable
decompositions: Dict[OpOverload, Callable]
decompositions: dict[OpOverload, Callable]
num_params_buffers: int
aot_id: int
keep_inference_input_mutations: bool
is_export: bool = False
no_tangents: bool = False
dynamic_shapes: bool = False
aot_autograd_arg_pos_to_source: Optional[List[Source]] = None
static_input_indices: Optional[List[int]] = None
aot_autograd_arg_pos_to_source: Optional[list[Source]] = None
static_input_indices: Optional[list[int]] = None
inference_compiler: Optional[Callable] = None
enable_log: bool = True
# this is always false outside of export.


@ -1,5 +1,5 @@
import dataclasses
from typing import Any, Dict, List, Tuple
from typing import Any
import torch
from torch.utils._python_dispatch import is_traceable_wrapper_subclass
@ -13,13 +13,13 @@ class SubclassCreationMeta:
start_idx: int
num_tensors: int
class_type: Any
attrs: Dict[str, "SubclassCreationMeta"]
attrs: dict[str, "SubclassCreationMeta"]
metadata: Any
class UnwrapTensorSubclass(torch.nn.Module):
def forward(self, *tensors) -> torch.Tensor: # type: ignore[no-untyped-def]
todo: List[torch.Tensor] = list(tensors)
todo: list[torch.Tensor] = list(tensors)
def _unwrap_tensor_subclasses(subclass_meta, tensors, offset): # type: ignore[no-untyped-def]
if subclass_meta is None:
@ -35,9 +35,9 @@ class UnwrapTensorSubclass(torch.nn.Module):
return _unwrap_tensor_subclasses(self.subclass_meta, todo, 0)[0]
def right_inverse(self, tensor: torch.Tensor) -> List[torch.Tensor]:
def right_inverse(self, tensor: torch.Tensor) -> list[torch.Tensor]:
assert type(tensor) is not torch.Tensor
plain_tensors: List[torch.Tensor] = []
plain_tensors: list[torch.Tensor] = []
def _create_subclass_meta(tensor, idx, plain_tensor_container): # type: ignore[no-untyped-def]
if type(tensor) is torch.Tensor:
@ -79,7 +79,7 @@ def unwrap_tensor_subclass_parameters(module: torch.nn.Module) -> torch.nn.Modul
becomes: {"parametrizations.p2.original0": torch.Tensor, "parametrizations.p2.original1": torch.Tensor}
"""
name_param: List[Tuple[str, torch.nn.Parameter]] = list(
name_param: list[tuple[str, torch.nn.Parameter]] = list(
module.named_parameters(recurse=False)
)
for name, param in name_param:


@ -7,7 +7,8 @@ and this includes tensor subclasses that implement __torch_dispatch__.
import collections
import typing
from typing import Any, Dict, Iterable, List, Optional, Tuple, Type, TypeVar, Union
from collections.abc import Iterable
from typing import Any, Optional, TypeVar, Union
import torch
import torch.utils._pytree as pytree
@ -59,7 +60,7 @@ def maybe_suggest_memory_format(
def get_subclass_typing_container(
tensor_subclass: torch.Tensor,
) -> Dict[Type[torch.Tensor], List[Type[torch.Tensor]]]:
) -> dict[type[torch.Tensor], list[type[torch.Tensor]]]:
"""
Given a subclass, returns a recursive dictionary mapping each
inner tensors to its' subclass types.
@ -74,7 +75,7 @@ def get_subclass_typing_container(
inner_tensor = getattr(tensor_subclass, key)
_get_types_for_subclass(inner_tensor)
tracker: Dict[Any, List[Any]] = collections.defaultdict(list)
tracker: dict[Any, list[Any]] = collections.defaultdict(list)
_get_types_for_subclass(tensor_subclass)
return tracker
@ -134,13 +135,13 @@ def create_subclass_metadata(
# computes metadata about "how to reconstruct the current list of subclasses,
# if we were given their flattened dense tensors instead"
def create_subclass_meta(
curr_args: Union[List[Any], Tuple[Any, ...]],
curr_args: Union[list[Any], tuple[Any, ...]],
*,
count_symints: bool = True,
with_memory_format: bool = False,
) -> List[Union[PlainTensorMeta, SubclassCreationMeta]]:
) -> list[Union[PlainTensorMeta, SubclassCreationMeta]]:
idx = 0
infos: List[Union[PlainTensorMeta, SubclassCreationMeta]] = []
infos: list[Union[PlainTensorMeta, SubclassCreationMeta]] = []
for a in curr_args:
if is_traceable_wrapper_subclass(a):
assert isinstance(a, Tensor)
@ -173,7 +174,7 @@ def filter_symints(lst: Iterable[Union[int, SymInt]]):
return [s for s in lst if symint_check(s)]
def compute_symint_placeholders(lst: Iterable[Union[None, int, SymInt]]) -> List[bool]:
def compute_symint_placeholders(lst: Iterable[Union[None, int, SymInt]]) -> list[bool]:
# Non-nested symints are replaced with None in `make_runtime_safe()`
return [s is None for s in lst]
@ -192,7 +193,7 @@ def compute_symint_placeholders(lst: Iterable[Union[None, int, SymInt]]) -> List
# primals (but not tangents) on entry to the forward. See the runtime version of
# this function below.
def unwrap_tensor_subclasses(
wrapped_args: List[Union[Tensor, int]],
wrapped_args: list[Union[Tensor, int]],
*,
append_symints: bool,
):
@ -213,7 +214,7 @@ def unwrap_tensor_subclasses(
out.extend(filter_symints(t.size()))
out.extend(filter_symints(t.stride()))
xs_inner: List[Union[int, Tensor, SymInt]] = []
xs_inner: list[Union[int, Tensor, SymInt]] = []
for x in wrapped_args:
flatten_subclass(typing.cast(Tensor, x), out=xs_inner)
@ -224,10 +225,10 @@ def unwrap_tensor_subclasses(
# subclass_metas is needed at runtime to compute which indices are symints in
# the outer_size/outer_stride
def runtime_unwrap_tensor_subclasses(
wrapped_args: List[Union[Tensor, int]],
wrapped_args: list[Union[Tensor, int]],
*,
append_symints: bool,
subclass_metas: Optional[List[Union[PlainTensorMeta, SubclassCreationMeta]]] = None,
subclass_metas: Optional[list[Union[PlainTensorMeta, SubclassCreationMeta]]] = None,
):
def flatten_subclass(x: Tensor, meta: Optional[SubclassCreationMeta], *, out):
if not is_traceable_wrapper_subclass(x):
@ -262,7 +263,7 @@ def runtime_unwrap_tensor_subclasses(
)
return out
xs_inner: List[Union[int, Tensor, SymInt]] = []
xs_inner: list[Union[int, Tensor, SymInt]] = []
if append_symints:
assert subclass_metas is not None
@ -319,13 +320,13 @@ def remap_unwrapped_subclass_arg_indices(wrapped_args, static_input_indices):
# Turns a flattened list of tensor arguments into (maybe) subclass tensors.
# This function is used both at trace time and runtime, so we have an is_runtime flag telling us which context we're in.
def wrap_tensor_subclasses(
unwrapped_args: Union[Tuple[Any, ...], List[Any]],
unwrapped_args: Union[tuple[Any, ...], list[Any]],
*,
subclass_metas: List[Union[PlainTensorMeta, SubclassCreationMeta]],
subclass_metas: list[Union[PlainTensorMeta, SubclassCreationMeta]],
num_fw_outs_saved_for_bw: Optional[int] = None,
included_subclass_symints: bool = False,
is_runtime: bool = False,
) -> Tuple[Any, ...]:
) -> tuple[Any, ...]:
wrapped_args = []
num_args_tallied = 0
for subclass_meta in subclass_metas:
@ -386,7 +387,7 @@ def wrap_tensor_subclasses(
# - when is_joint_structure is False, args is [*primals]
def wrap_tensor_subclasses_maybe_joint(
unwrapped_args, *, is_joint_structure: bool, meta: ViewAndMutationMeta
) -> Union[Tuple[Any, ...], List[Any]]:
) -> Union[tuple[Any, ...], list[Any]]:
# Since this function is re-used for both inference and joint graphs,
if is_joint_structure:
assert isinstance(unwrapped_args, tuple) and len(unwrapped_args) == 2
@ -417,7 +418,7 @@ def wrap_tensor_subclasses_maybe_joint(
def compute_inner_mutated_inp_indices_from_subclass_meta(
fw_metadata: ViewAndMutationMeta,
inner_metadata: ViewAndMutationMeta,
) -> List[int]:
) -> list[int]:
# Note: [Recomputing subclass mutation handling]
#
# Generally, if a subclass requires grad, its components will not require grad.


@ -14,7 +14,7 @@ It does so by:
import warnings
from contextlib import contextmanager, nullcontext
from functools import wraps
from typing import Any, Callable, List, Tuple, Union
from typing import Any, Callable, Union
from unittest.mock import patch
import torch
@ -190,7 +190,7 @@ def fn_prepped_for_autograd(
# otherwise, when we compute autograd.grad(), we will not take those input mutations into account
# (the way this is handled is that we ensure any inputs that normally get mutated are cloned first)
def create_joint(fn: Callable, *, aot_config: AOTConfig) -> Any:
def inner_fn(primals: List[Any], tangents: List[Any]):
def inner_fn(primals: list[Any], tangents: list[Any]):
outs, tangent_mask = fn(*primals)
assert len(tangent_mask) == len(outs)
@ -232,7 +232,7 @@ def create_joint(fn: Callable, *, aot_config: AOTConfig) -> Any:
if config.functionalize_rng_ops:
PhiloxStateTracker.mark_beginning_of_backward()
backward_out: Tuple[Tensor, ...] = ()
backward_out: tuple[Tensor, ...] = ()
# Call the backwards pass
if grad_primals:
functional_tensor_mode = torch.utils._python_dispatch._detect_infra_mode(
@ -733,7 +733,7 @@ def handle_effect_tokens_fn(
# In particular, we need this to tell the partitioner how many dense forward outputs there are.
def aot_dispatch_subclass(
flat_fn_maybe_joint,
args: List[Any],
args: list[Any],
*,
is_joint_structure: bool,
meta: ViewAndMutationMeta,


@ -8,7 +8,7 @@ import operator
import warnings
from contextlib import nullcontext
from functools import wraps
from typing import Any, Callable, List, Optional, Tuple, Union
from typing import Any, Callable, Optional, Union
import torch
import torch.utils._pytree as pytree
@ -114,7 +114,7 @@ def make_boxed_compiler(compiler):
def call_func_at_runtime_with_args(
f, args: Union[Tuple[Any], List[Any]], steal_args=False, disable_amp=False
f, args: Union[tuple[Any], list[Any]], steal_args=False, disable_amp=False
):
if not steal_args:
args = list(args)
@ -156,7 +156,7 @@ class PytreeThunk:
if self.spec.is_leaf():
self.is_really_simple = True
def unflatten(self, x: List[Any]) -> Any:
def unflatten(self, x: list[Any]) -> Any:
if self.is_really_simple:
return x[0]
if self.is_simple:
@ -168,7 +168,7 @@ class PytreeThunk:
# Creates a function that returns flattened inputs and outputs
# Also returns the output tree spec, which is needed to recover the "unflattened"
# output tree structure later.
def create_tree_flattened_fn(fn, args, kwargs=None) -> Tuple[Callable, PytreeThunk]:
def create_tree_flattened_fn(fn, args, kwargs=None) -> tuple[Callable, PytreeThunk]:
if kwargs is None:
kwargs = {}
# Save the args_spec for flat_tensor_args to unflatten while tracing


@ -1,22 +1,10 @@
# mypy: ignore-errors
import itertools
from collections.abc import KeysView, Sequence
from contextlib import contextmanager, nullcontext
from functools import partial, wraps
from typing import (
Any,
Callable,
Dict,
KeysView,
List,
NewType,
Optional,
Protocol,
Sequence,
Tuple,
Type,
TypeVar,
)
from typing import Any, Callable, NewType, Optional, Protocol, TypeVar
from unittest.mock import patch
import torch
@ -447,7 +435,7 @@ AOT_COUNTER = itertools.count()
aot_autograd_decompositions = {}
FakifiedFlatArgs = NewType("FakifiedFlatArgs", List[Any])
FakifiedFlatArgs = NewType("FakifiedFlatArgs", list[Any])
TOutputCode = TypeVar("TOutputCode", bound=OutputCode)
@ -477,7 +465,7 @@ class SerializableAOTDispatchCompiler(AOTDispatchCompiler):
def __init__(
self,
output_code_ty: Type[TOutputCode],
output_code_ty: type[TOutputCode],
compiler_fn: Callable[[torch.fx.GraphModule, Sequence[InputType]], TOutputCode],
):
self.output_code_ty = output_code_ty
@ -492,7 +480,7 @@ class SerializableAOTDispatchCompiler(AOTDispatchCompiler):
def process_inputs(
flat_args: List[Any],
flat_args: list[Any],
aot_config: AOTConfig,
fake_mode: FakeTensorMode,
shape_env: Optional[ShapeEnv],
@ -560,8 +548,8 @@ def process_inputs(
def construct_fake_mode(
flat_args: List[Any], aot_config: AOTConfig
) -> Tuple[FakeTensorMode, Optional[ShapeEnv]]:
flat_args: list[Any], aot_config: AOTConfig
) -> tuple[FakeTensorMode, Optional[ShapeEnv]]:
fake_mode = detect_fake_mode(flat_args)
if fake_mode is None:
shape_env = ShapeEnv() if aot_config.dynamic_shapes else None
@ -577,7 +565,7 @@ def create_aot_dispatcher_function(
aot_config: AOTConfig,
fake_mode: FakeTensorMode,
shape_env: Optional[ShapeEnv],
) -> Tuple[Callable, ViewAndMutationMeta]:
) -> tuple[Callable, ViewAndMutationMeta]:
with dynamo_timed("create_aot_dispatcher_function", log_pt2_compile_event=True):
return _create_aot_dispatcher_function(
flat_fn, fake_flat_args, aot_config, fake_mode, shape_env
@ -590,7 +578,7 @@ def _create_aot_dispatcher_function(
aot_config: AOTConfig,
fake_mode: FakeTensorMode,
shape_env: Optional[ShapeEnv],
) -> Tuple[Callable, ViewAndMutationMeta]:
) -> tuple[Callable, ViewAndMutationMeta]:
"""
Traces the forward and backward graphs of the attr:`flat_fn` to generate a
joint graph. The joint graph is an Fx graph with Aten ops. Please refer to
@ -843,7 +831,7 @@ def aot_function(
fw_compiler: Callable,
bw_compiler: Optional[Callable] = None,
partition_fn: Callable = default_partition,
decompositions: Optional[Dict] = None,
decompositions: Optional[dict] = None,
num_params_buffers: int = 0,
keep_inference_input_mutations: bool = False,
inference_compiler: Optional[Callable] = None,
@ -1008,7 +996,7 @@ def aot_module(mod: nn.Module, *args, **kwargs) -> nn.Module:
def _try_get_metadata_from_dynamo(
mod: torch.nn.Module, param_keys: KeysView[str], full_args_num: int
) -> Tuple[Optional[List[torch._guards.Source]], List[int]]:
) -> tuple[Optional[list[torch._guards.Source]], list[int]]:
"""
Metadata is forwarded from Dynamo to AOTDispatch via special fields on GraphModule.
We first verify that `mod` does come from Dynamo, then we handle cases where
@ -1074,7 +1062,7 @@ def aot_module_simplified(
fw_compiler: AOTDispatchCompiler,
bw_compiler: Optional[AOTDispatchCompiler] = None,
partition_fn: Callable = default_partition,
decompositions: Optional[Dict] = None,
decompositions: Optional[dict] = None,
keep_inference_input_mutations=False,
inference_compiler: Optional[AOTDispatchCompiler] = None,
cudagraphs: Optional[BoxedBool] = None,
@ -1186,7 +1174,7 @@ def aot_module_simplified(
# the inputs so that they can be freed before the end of this scope.
# For overhead reasons, this is not the default wrapper, see comment:
# https://github.com/pytorch/pytorch/pull/122535/files#r1560096481
def boxed_forward(runtime_args: List[Any]):
def boxed_forward(runtime_args: list[Any]):
flat_args = []
flat_args.extend(params_flat)
flat_args.extend(runtime_args)
@ -1204,7 +1192,7 @@ def aot_module_simplified(
# historically returned a function that was not the boxed calling
# convention. This should get fixed...
# NB: GraphModule/nn.Module rely on the non-boxed calling convention here
def forward(*runtime_args: Tuple[Any]):
def forward(*runtime_args: tuple[Any]):
full_args = []
full_args.extend(params_flat)
full_args.extend(runtime_args)
@ -1222,7 +1210,7 @@ def aot_export_module(
mod: nn.Module,
args,
*,
decompositions: Optional[Dict] = None,
decompositions: Optional[dict] = None,
# If true, we'll return a joint forward-backward graph,
# As well as metadata on the loss + gradients in the backward.
trace_joint: bool,
@ -1233,7 +1221,7 @@ def aot_export_module(
# If None, will be infered from inputs and mod.graph.nodes if mod is a graph module, but the inferred result might be wrong.
dynamic_shapes: Optional[bool] = None,
kwargs=None,
) -> Tuple[torch.fx.GraphModule, GraphSignature]:
) -> tuple[torch.fx.GraphModule, GraphSignature]:
"""
This function takes in a module, and returns:
(1) an FX graph that can be exported
@ -1434,7 +1422,7 @@ def aot_export_joint_simple(
# it will assume that parms/buffers are static.
# With the new inferred dynamic shapes API, maybe this doesn't matter?
num_params_buffers: int = 0,
decompositions: Optional[Dict] = None,
decompositions: Optional[dict] = None,
) -> torch.fx.GraphModule:
"""
A simplified version of export. Used by higher order operators.
@ -1530,7 +1518,7 @@ def _aot_export_function(
args,
*,
num_params_buffers: int = 0,
decompositions: Optional[Dict] = None,
decompositions: Optional[dict] = None,
# If we're exporting a joint graph and we don't want any tangent inputs in the graph
# (because we are backpropping through a scalar 1 loss),
# we need to explicitly specify not to include tangents in the graph.
@ -1543,7 +1531,7 @@ def _aot_export_function(
# If None, `dynamic_shapes` will be infered from inputs, but the inferred result might be wrong.
dynamic_shapes: Optional[bool] = None,
kwargs=None,
) -> Tuple[torch.fx.GraphModule, ViewAndMutationMeta, pytree.TreeSpec, pytree.TreeSpec]:
) -> tuple[torch.fx.GraphModule, ViewAndMutationMeta, pytree.TreeSpec, pytree.TreeSpec]:
kwargs = kwargs or {}
flat_fn, out_spec = create_tree_flattened_fn(func, args, kwargs)

View File

@ -1,5 +1,5 @@
# mypy: allow-untyped-defs
from typing import NamedTuple, Tuple
from typing import NamedTuple
import torch
import torch.utils._pytree as pytree
@ -580,7 +580,7 @@ def get_tangents_in_dims(input_dims, tangents):
# Wraps a ctx object. Forwards all attr accesses to the underlying object
# except for the attrs in _pt_attrs
class WrappedCtx:
_pt_reserved_attrs: Tuple[str, ...] = ("_pt_reserved_attrs", "_pt_inner_ctx")
_pt_reserved_attrs: tuple[str, ...] = ("_pt_reserved_attrs", "_pt_inner_ctx")
def __init__(self, ctx):
if not isinstance(ctx, WrappedCtx):
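
The class above forwards attribute access to an inner ctx while reserving a couple of names for itself. A minimal sketch of that forwarding pattern, using a hypothetical _ForwardingCtx rather than the real WrappedCtx:

class _ForwardingCtx:
    _reserved: tuple[str, ...] = ("_reserved", "_inner")

    def __init__(self, inner):
        object.__setattr__(self, "_inner", inner)

    def __getattr__(self, name):
        # Only called when normal lookup fails, i.e. for non-reserved attrs.
        return getattr(object.__getattribute__(self, "_inner"), name)

    def __setattr__(self, name, value):
        if name in self._reserved:
            object.__setattr__(self, name, value)
        else:
            setattr(self._inner, name, value)
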

View File

@ -10,7 +10,7 @@ documentation.
import textwrap
import warnings
from typing import Any, Callable, Optional, Tuple, Union
from typing import Any, Callable, Optional, Union
import torch._functorch.apis as apis
import torch._functorch.eager_transforms as _impl
@ -98,7 +98,7 @@ def jvp(
def jacrev(
func: Callable,
argnums: Union[int, Tuple[int]] = 0,
argnums: Union[int, tuple[int]] = 0,
*,
has_aux=False,
chunk_size: Optional[int] = None,
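
argnums accepts a single index or a tuple of indices; with a tuple, jacrev returns a tuple of Jacobians. A short usage sketch:

import torch
from torch.func import jacrev

def f(x, y):
    return (x * y).sum()

x, y = torch.randn(3), torch.randn(3)
jac_x = jacrev(f, argnums=0)(x, y)               # d f / d x, shape (3,)
jac_x2, jac_y = jacrev(f, argnums=(0, 1))(x, y)  # tuple of Jacobians
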

View File

@ -8,7 +8,7 @@
import contextlib
from functools import partial, wraps
from typing import Any, Callable, List, Optional, Tuple, Union
from typing import Any, Callable, Optional, Union
import torch
import torch.autograd.forward_ad as fwAD
@ -428,7 +428,7 @@ def error_if_complex(func_name, args, is_input):
@exposed_in("torch.func")
def jacrev(
func: Callable,
argnums: Union[int, Tuple[int]] = 0,
argnums: Union[int, tuple[int]] = 0,
*,
has_aux=False,
chunk_size: Optional[int] = None,
@ -911,7 +911,7 @@ def assert_flat_tuple_of_tensors(elts: Any, api: str, argname: str) -> None:
)
def assert_non_empty_tensor_output(output: List[Any], api: str) -> None:
def assert_non_empty_tensor_output(output: list[Any], api: str) -> None:
if (len(output) == 1 and output[0] is None) or len(output) < 1:
raise RuntimeError(
f"{api}: Expected f to be a function that has non-empty output (got output = {output})"
@ -946,7 +946,7 @@ def assert_output_is_tensor_or_tensors(output: Any, api: str) -> None:
def assert_non_empty_list_of_tensors(
output: List[torch.Tensor], api: str, argname: str
output: list[torch.Tensor], api: str, argname: str
) -> None:
if len(output) == 0:
raise RuntimeError(f"{api}: Expected {argname} to contain at least one Tensor.")
@ -1676,7 +1676,7 @@ def functionalize(func: Callable, *, remove: str = "mutations") -> Callable:
@exposed_in("torch.func")
def linearize(func: Callable, *primals) -> Tuple[Any, Callable]:
def linearize(func: Callable, *primals) -> tuple[Any, Callable]:
"""
Returns the value of ``func`` at ``primals`` and linear approximation
at ``primals``.
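
A short usage sketch of linearize: it evaluates func once and returns a jvp callable that reuses the saved linearization point.

import torch
from torch.func import linearize

def f(x):
    return x.sin()

x = torch.randn(3)
out, jvp_fn = linearize(f, x)   # out == f(x)
tangent = torch.randn(3)
delta = jvp_fn(tangent)         # equivalent to jvp(f, (x,), (tangent,))[1]
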

View File

@ -1,5 +1,6 @@
# mypy: allow-untyped-defs
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
from collections.abc import Sequence
from typing import Any, Optional, Union
import torch
import torch.nn as nn
@ -10,9 +11,9 @@ from torch._functorch.utils import exposed_in
@exposed_in("torch.func")
def functional_call(
module: "torch.nn.Module",
parameter_and_buffer_dicts: Union[Dict[str, Tensor], Sequence[Dict[str, Tensor]]],
args: Optional[Union[Any, Tuple]] = None,
kwargs: Optional[Dict[str, Any]] = None,
parameter_and_buffer_dicts: Union[dict[str, Tensor], Sequence[dict[str, Tensor]]],
args: Optional[Union[Any, tuple]] = None,
kwargs: Optional[dict[str, Any]] = None,
*,
tie_weights: bool = True,
strict: bool = False,
@ -126,7 +127,7 @@ def functional_call(
"Expected all elements of parameter_and_buffer_dicts to be dictionaries"
)
all_keys = [k for d in parameter_and_buffer_dicts for k in d.keys()]
all_keys_counter: Dict[str, int] = {}
all_keys_counter: dict[str, int] = {}
for k in all_keys:
v = all_keys_counter.get(k, 0)
all_keys_counter[k] = v + 1
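
functional_call accepts either a single dict or a sequence of dicts (hence the duplicate-key counting above). A small usage sketch with a single dict:

import torch
import torch.nn as nn
from torch.func import functional_call

model = nn.Linear(3, 2)
# Run the module with detached copies of its parameters, without mutating it.
params = {name: p.detach().clone() for name, p in model.named_parameters()}
x = torch.randn(1, 3)
out = functional_call(model, params, (x,))
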
@ -157,7 +158,7 @@ def functional_call(
@exposed_in("torch.func")
def stack_module_state(
models: Union[Sequence[nn.Module], nn.ModuleList],
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
) -> tuple[dict[str, Any], dict[str, Any]]:
"""stack_module_state(models) -> params, buffers
Prepares a list of torch.nn.Modules for ensembling with :func:`vmap`.
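
A hedged ensembling sketch built on the signature above, assuming the usual pairing with functional_call and vmap (in practice a meta-device deepcopy is often used as the base module):

import torch
import torch.nn as nn
from torch.func import functional_call, stack_module_state, vmap

models = [nn.Linear(4, 2) for _ in range(3)]
params, buffers = stack_module_state(models)  # each entry gains a leading dim of 3
base = models[0]

def run_one(p, b, x):
    return functional_call(base, (p, b), (x,))

x = torch.randn(3, 4)                         # one input row per ensemble member
out = vmap(run_one)(params, buffers, x)       # shape (3, 2)
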
@ -238,7 +239,7 @@ def stack_module_state(
def construct_stacked_leaf(
tensors: Union[Tuple[Tensor, ...], List[Tensor]], name: str
tensors: Union[tuple[Tensor, ...], list[Tensor]], name: str
) -> Tensor:
all_requires_grad = all(t.requires_grad for t in tensors)
none_requires_grad = all(not t.requires_grad for t in tensors)

View File

@ -6,7 +6,7 @@ import os
import sys
from dataclasses import dataclass
from functools import partial, wraps
from typing import Callable, List
from typing import Callable
import torch
import torch.fx as fx
@ -22,8 +22,8 @@ is_tuple = object()
@dataclass
class LoadTensorMeta:
size: List[int]
stride: List[int]
size: list[int]
stride: list[int]
dtype: torch.dtype
device: torch.device
@ -164,7 +164,7 @@ def is_power_of_two(n):
@dataclass
class ReproState:
graph: fx.Graph
inps: List[torch.Tensor]
inps: list[torch.Tensor]
def __post_init__(self):
ph_nodes = get_placeholders(self.graph)

View File

@ -6,18 +6,8 @@
# LICENSE file in the root directory of this source tree.
import copy
from typing import (
Any,
Callable,
Dict,
Iterable,
List,
NoReturn,
Sequence,
Tuple,
Type,
Union,
)
from collections.abc import Iterable, Sequence
from typing import Any, Callable, NoReturn, Union
import torch
import torch.nn as nn
@ -41,9 +31,9 @@ def raise_parameter_tying_error() -> NoReturn:
def create_names_map(
named_params: Union[Dict[str, Tensor], Iterable[Tuple[str, Tensor]]],
tied_named_params: Union[Dict[str, Tensor], Iterable[Tuple[str, Tensor]]],
) -> Dict[str, List[str]]:
named_params: Union[dict[str, Tensor], Iterable[tuple[str, Tensor]]],
tied_named_params: Union[dict[str, Tensor], Iterable[tuple[str, Tensor]]],
) -> dict[str, list[str]]:
"""
named_params is a dictionary of tensors: {'A': A, 'B': B}
tied_named_params is another dictionary of tensors {'A': A, 'B': B, 'B_tied': B}
@ -59,7 +49,7 @@ def create_names_map(
tied_tensors_dict_keys = set(tied_named_params.keys())
assert tensors_dict_keys.issubset(tied_tensors_dict_keys)
tensor_to_mapping: Dict[Tensor, Tuple[str, List[str]]] = {}
tensor_to_mapping: dict[Tensor, tuple[str, list[str]]] = {}
for key, tensor in named_params.items():
tensor_to_mapping[tensor] = (key, [])
for key, tensor in tied_named_params.items():
@ -70,9 +60,9 @@ def create_names_map(
def _extract_members(
mod: nn.Module,
named_members: Callable[..., Iterable[Tuple[str, Tensor]]],
named_members: Callable[..., Iterable[tuple[str, Tensor]]],
subclass: Callable[[Tensor], Tensor],
) -> Tuple[Tuple[Tensor, ...], Tuple[str, ...], Dict[str, List[str]]]:
) -> tuple[tuple[Tensor, ...], tuple[str, ...], dict[str, list[str]]]:
all_named_members = tuple(named_members(remove_duplicate=False))
unique_named_members = tuple(named_members(remove_duplicate=True))
names_map = create_names_map(unique_named_members, all_named_members)
@ -95,7 +85,7 @@ def _extract_members(
def extract_weights(
mod: nn.Module,
) -> Tuple[Tuple[Tensor, ...], Tuple[str, ...], Dict[str, List[str]]]:
) -> tuple[tuple[Tensor, ...], tuple[str, ...], dict[str, list[str]]]:
"""
This function removes all the Parameters from the model and
returns them as a tuple as well as their original attribute names.
@ -109,7 +99,7 @@ def extract_weights(
def extract_buffers(
mod: nn.Module,
) -> Tuple[Tuple[Tensor, ...], Tuple[str, ...], Dict[str, List[str]]]:
) -> tuple[tuple[Tensor, ...], tuple[str, ...], dict[str, list[str]]]:
return _extract_members(mod, mod.named_buffers, lambda x: x)
@ -131,9 +121,9 @@ def load_weights(
def _swap_state(
mod: nn.Module, names_map: Dict[str, List[str]], elems: Iterable[Tensor]
) -> List[Tensor]:
result: List[Tensor] = []
mod: nn.Module, names_map: dict[str, list[str]], elems: Iterable[Tensor]
) -> list[Tensor]:
result: list[Tensor] = []
accessor = NamedMemberAccessor(mod)
for (_, attr_names), elem in zip(names_map.items(), elems):
for i, attr_name in enumerate(attr_names):
@ -261,10 +251,10 @@ class FunctionalModuleWithBuffers(nn.Module):
def __init__(
self,
stateless_model: nn.Module,
param_names: Tuple[str, ...],
buffer_names: Tuple[str, ...],
param_names_map: Dict[str, List[str]],
buffer_names_map: Dict[str, List[str]],
param_names: tuple[str, ...],
buffer_names: tuple[str, ...],
param_names_map: dict[str, list[str]],
buffer_names_map: dict[str, list[str]],
) -> None:
super().__init__()
self.stateless_model = stateless_model
@ -277,7 +267,7 @@ class FunctionalModuleWithBuffers(nn.Module):
@staticmethod
def _create_from(
model: nn.Module, disable_autograd_tracking: bool = False
) -> Tuple["FunctionalModuleWithBuffers", Tuple[Tensor, ...], Tuple[Tensor, ...]]:
) -> tuple["FunctionalModuleWithBuffers", tuple[Tensor, ...], tuple[Tensor, ...]]:
# TODO: We don't need to copy the model to create a stateless copy
model_copy = copy.deepcopy(model)
params, param_names, param_names_map = extract_weights(model_copy)
@ -317,8 +307,8 @@ class FunctionalModule(nn.Module):
def __init__(
self,
stateless_model: nn.Module,
param_names: Tuple[str, ...],
names_map: Dict[str, List[str]],
param_names: tuple[str, ...],
names_map: dict[str, list[str]],
) -> None:
super().__init__()
self.stateless_model = stateless_model
@ -328,7 +318,7 @@ class FunctionalModule(nn.Module):
@staticmethod
def _create_from(
model: nn.Module, disable_autograd_tracking: bool = False
) -> Tuple["FunctionalModule", Tuple[Tensor, ...]]:
) -> tuple["FunctionalModule", tuple[Tensor, ...]]:
# TODO: We don't need to copy the model to create a stateless copy
model_copy = copy.deepcopy(model)
params, param_names, names_map = extract_weights(model_copy)
@ -349,7 +339,7 @@ class FunctionalModule(nn.Module):
def make_functional(
model: nn.Module, disable_autograd_tracking: bool = False
) -> Tuple[FunctionalModule, Tuple[Tensor, ...]]:
) -> tuple[FunctionalModule, tuple[Tensor, ...]]:
"""make_functional(model, disable_autograd_tracking=False) -> func, params
Given a ``torch.nn.Module``, :func:`make_functional` extracts the state
@ -419,7 +409,7 @@ def make_functional(
def make_functional_with_buffers(
model: nn.Module, disable_autograd_tracking: bool = False
) -> Tuple[FunctionalModuleWithBuffers, Tuple[Tensor, ...], Tuple[Tensor, ...]]:
) -> tuple[FunctionalModuleWithBuffers, tuple[Tensor, ...], tuple[Tensor, ...]]:
"""make_functional_with_buffers(model, disable_autograd_tracking=False) -> func, params, buffers
Given a ``torch.nn.Module``, make_functional_with_buffers extracts the
@ -479,8 +469,8 @@ def make_functional_with_buffers(
def transpose_stack(
tuple_of_tuple_of_tensors: Tuple[Tuple[Tensor, ...], ...]
) -> Tuple[Tensor, ...]:
tuple_of_tuple_of_tensors: tuple[tuple[Tensor, ...], ...]
) -> tuple[Tensor, ...]:
tuple_of_tuple_of_tensors = tuple(zip(*tuple_of_tuple_of_tensors))
results = tuple(
torch.stack(shards).detach() for shards in tuple_of_tuple_of_tensors
@ -490,7 +480,7 @@ def transpose_stack(
def combine_state_for_ensemble(
models: Sequence[nn.Module],
) -> Tuple[FunctionalModuleWithBuffers, Tuple[Tensor, ...], Tuple[Tensor, ...]]:
) -> tuple[FunctionalModuleWithBuffers, tuple[Tensor, ...], tuple[Tensor, ...]]:
"""combine_state_for_ensemble(models) -> func, params, buffers
Prepares a list of torch.nn.Modules for ensembling with :func:`vmap`.
@ -551,8 +541,8 @@ def combine_state_for_ensemble(
def functional_init(
model_class: Type[nn.Module],
ensemble_shape: Union[Tuple[()], Tuple[int]] = (),
model_class: type[nn.Module],
ensemble_shape: Union[tuple[()], tuple[int]] = (),
device: torch.types.Device = "cpu",
):
def wrapped(*args, **kwargs):
@ -578,8 +568,8 @@ def functional_init(
def functional_init_with_buffers(
model_class: Type[nn.Module],
ensemble_shape: Union[Tuple[()], Tuple[int]] = (),
model_class: type[nn.Module],
ensemble_shape: Union[tuple[()], tuple[int]] = (),
device: torch.types.Device = "cpu",
):
def wrapped(*args, **kwargs):

View File

@ -9,7 +9,7 @@ import operator
import os
from collections import defaultdict
from dataclasses import dataclass, replace
from typing import Callable, Dict, List, Optional, Set, Tuple, TYPE_CHECKING, Union
from typing import Callable, Optional, TYPE_CHECKING, Union
import torch
import torch._inductor.inductor_prims
@ -55,11 +55,11 @@ prims = torch.ops.prims
class OpTypes:
"""Class for keeping track of different operator categories"""
fusible_ops: Set[Callable]
compute_intensive_ops: Set[Callable]
random_ops: Set[Callable]
view_ops: Set[Callable]
recomputable_ops: Set[Callable]
fusible_ops: set[Callable]
compute_intensive_ops: set[Callable]
random_ops: set[Callable]
view_ops: set[Callable]
recomputable_ops: set[Callable]
def is_fusible(self, node: fx.Node):
return get_aten_target(node) in self.fusible_ops
@ -81,14 +81,14 @@ class OpTypes:
class NodeInfo:
# Be careful about iterating over these explicitly, as their order may not
# be deterministic
inputs: List[fx.Node]
_required_fw_nodes: Set[fx.Node]
required_bw_nodes: Set[fx.Node]
unclaimed_nodes: Set[fx.Node]
fw_order: Dict[fx.Node, int]
inputs: list[fx.Node]
_required_fw_nodes: set[fx.Node]
required_bw_nodes: set[fx.Node]
unclaimed_nodes: set[fx.Node]
fw_order: dict[fx.Node, int]
@functools.cached_property
def required_fw_nodes(self) -> List[fx.Node]:
def required_fw_nodes(self) -> list[fx.Node]:
return sorted(
(n for n in self._required_fw_nodes), key=lambda n: self.fw_order[n]
)
@ -158,8 +158,8 @@ InvalidNode = InvalidNodeBase()
def _extract_graph_with_inputs_outputs(
joint_graph: fx.Graph,
inputs: List[fx.Node],
outputs: List[fx.Node],
inputs: list[fx.Node],
outputs: list[fx.Node],
subgraph: Optional[str] = None,
) -> fx.Graph:
"""
@ -272,7 +272,7 @@ def _must_be_in_backward(node: fx.Node) -> bool:
def _extract_fwd_bwd_outputs(
joint_module: fx.GraphModule, *, num_fwd_outputs
) -> Tuple[List[fx.Node], List[fx.Node]]:
) -> tuple[list[fx.Node], list[fx.Node]]:
outputs = pytree.arg_tree_leaves(
*(node.args for node in joint_module.graph.find_nodes(op="output"))
)
@ -281,7 +281,7 @@ def _extract_fwd_bwd_outputs(
return fwd_outputs, bwd_outputs
def _remove_by_name(saved_values: List[fx.Node], name: str):
def _remove_by_name(saved_values: list[fx.Node], name: str):
for saved_value in saved_values:
if saved_value.name == name:
saved_values.remove(saved_value)
@ -290,11 +290,11 @@ def _remove_by_name(saved_values: List[fx.Node], name: str):
def _extract_fwd_bwd_modules(
joint_module: fx.GraphModule,
saved_values: List[fx.Node],
saved_sym_nodes: List[fx.Node],
saved_values: list[fx.Node],
saved_sym_nodes: list[fx.Node],
*,
num_fwd_outputs: int,
) -> Tuple[fx.GraphModule, fx.GraphModule]:
) -> tuple[fx.GraphModule, fx.GraphModule]:
fwd_outputs, bwd_outputs = _extract_fwd_bwd_outputs(
joint_module, num_fwd_outputs=num_fwd_outputs
)
@ -326,7 +326,7 @@ def _extract_fwd_bwd_modules(
# we propagate all symbols which are referenced by backwards inputs.
# These are not directly used in the graph but are required for downstream
# sizevar assignment
saved_symbols: Set[sympy.Symbol] = set()
saved_symbols: set[sympy.Symbol] = set()
saved_sym_nodes_binding = []
saved_sym_nodes_derived = []
@ -389,7 +389,7 @@ def _extract_fwd_bwd_modules(
def default_partition(
joint_module: fx.GraphModule, _joint_inputs, *, num_fwd_outputs
) -> Tuple[fx.GraphModule, fx.GraphModule]:
) -> tuple[fx.GraphModule, fx.GraphModule]:
"""
Partitions the :attr:`joint_module` in a manner that closely resembles the
behavior observed in the original ``.forward()`` and ``.backward()`` of the
@ -512,7 +512,7 @@ def _size_of(node: fx.Node) -> int:
def _count_ops(graph: fx.Graph):
from collections import defaultdict
cnt: Dict[str, int] = defaultdict(int)
cnt: dict[str, int] = defaultdict(int)
for node in graph.nodes:
if node.op == "call_function":
cnt[node.target.__name__] += 1
@ -537,7 +537,7 @@ def pointwise_ops():
return ops
def sort_depths(args, depth_map: Dict[fx.Node, int]) -> List[Tuple[fx.Node, int]]:
def sort_depths(args, depth_map: dict[fx.Node, int]) -> list[tuple[fx.Node, int]]:
arg_depths = {
arg: depth_map[arg] for arg in args if isinstance(arg, torch.fx.node.Node)
}
@ -568,7 +568,7 @@ def reordering_to_mimic_autograd_engine(gm: fx.GraphModule) -> fx.GraphModule:
"""
new_graph = fx.Graph()
env: Dict[fx.Node, fx.Node] = {}
env: dict[fx.Node, fx.Node] = {}
# Add new placeholder nodes in the order specified by the inputs
for node in gm.graph.find_nodes(op="placeholder"):
@ -623,7 +623,7 @@ def functionalize_rng_ops(
fw_module: fx.GraphModule,
bw_module: fx.GraphModule,
num_sym_nodes: int,
) -> Tuple[fx.GraphModule, fx.GraphModule]:
) -> tuple[fx.GraphModule, fx.GraphModule]:
# During user-driven activation checkpointing, we have to ensure that a rng
# op in fwd yields the same output as the recomputed rng op in the bwd. To
# do this, we use functionalize wrappers to wrap the random ops and share
@ -1074,12 +1074,12 @@ def solve_min_cut(
# backwards pass instead of only relying on whether it's unfusible in the
# forwards.
def find_first_unfusible(start_nodes: List[fx.Node], max_range: int) -> int:
def find_first_unfusible(start_nodes: list[fx.Node], max_range: int) -> int:
"""
Finds the first unfusible node in the chain of nodes starting from
`start_nodes` and returns its position.
"""
sorted_nodes: List[Tuple[int, fx.Node, bool]] = []
sorted_nodes: list[tuple[int, fx.Node, bool]] = []
for n in start_nodes:
heapq.heappush(sorted_nodes, (node_info.get_fw_order(n), n, True))
@ -1182,7 +1182,7 @@ def solve_min_cut(
raise
reachable, non_reachable = partition
cutset: Set[Tuple[str, str]] = set()
cutset: set[tuple[str, str]] = set()
for u, nbrs in ((n, nx_graph[n]) for n in reachable):
cutset.update((u, v) for v in nbrs if v in non_reachable)
@ -1219,7 +1219,7 @@ def visualize_min_cut_graph(nx_graph):
def get_default_op_list() -> OpTypes:
default_recomputable_ops: List[Callable] = [
default_recomputable_ops: list[Callable] = [
aten.add,
aten.sub,
aten.div,
@ -1392,12 +1392,12 @@ def get_name_to_node(graph: fx.Graph):
def _optimize_runtime_with_given_memory(
joint_graph: fx.Graph,
memory: List[float],
runtimes: List[float],
memory: list[float],
runtimes: list[float],
max_memory: float,
node_info: NodeInfo,
all_recomputable_banned_nodes: List[fx.Node],
) -> Tuple[float, List[int], List[int]]:
all_recomputable_banned_nodes: list[fx.Node],
) -> tuple[float, list[int], list[int]]:
SOLVER = config.activation_memory_budget_solver
if SOLVER == "greedy":
return greedy_knapsack(memory, runtimes, max_memory)
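
The "greedy" solver branch above treats saved activations as a knapsack: each candidate costs memory if saved and runtime if recomputed. A hedged sketch of such a heuristic (greedy_knapsack_sketch is illustrative only, not the real greedy_knapsack):

def greedy_knapsack_sketch(memory: list[float], runtimes: list[float], max_memory: float):
    # Prefer saving candidates with the highest recompute cost per unit of memory.
    order = sorted(
        range(len(memory)),
        key=lambda i: runtimes[i] / max(memory[i], 1e-12),
        reverse=True,
    )
    saved, recomputed, used, recompute_runtime = [], [], 0.0, 0.0
    for i in order:
        if used + memory[i] <= max_memory:
            used += memory[i]
            saved.append(i)              # keep the activation, pay its memory
        else:
            recomputed.append(i)         # drop it, pay its runtime in the backward
            recompute_runtime += runtimes[i]
    return recompute_runtime, sorted(saved), sorted(recomputed)

# e.g. greedy_knapsack_sketch([1.0, 2.0, 0.5], [3.0, 1.0, 2.0], max_memory=2.0)
# -> (1.0, [0, 2], [1])
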
@ -1490,7 +1490,7 @@ def choose_saved_values_set(
joint_graph: fx.Graph,
node_info: NodeInfo,
memory_budget=1,
) -> List[fx.Node]:
) -> list[fx.Node]:
if memory_budget > 1 or memory_budget < 0:
raise RuntimeError(
f"The valid ranges for memory budget are 0 <= m <= 1. The provided value is {memory_budget}"
@ -1523,7 +1523,7 @@ def choose_saved_values_set(
if memory_budget == 1:
return runtime_optimized_saved_values
def estimate_activations_size(saved_values: List[fx.Node]) -> float:
def estimate_activations_size(saved_values: list[fx.Node]) -> float:
return sum(map(_size_of, saved_values)) / 1e9
min_act_size = estimate_activations_size(node_info.inputs)
@ -1535,7 +1535,7 @@ def choose_saved_values_set(
def get_normalized_size(sz):
return (sz / 1e9) / (max_act_size - min_act_size)
def get_mem_ratio(activations: List[fx.Node]):
def get_mem_ratio(activations: list[fx.Node]):
return (estimate_activations_size(activations) - min_act_size) / (
max_act_size - min_act_size
)
@ -1567,7 +1567,7 @@ def choose_saved_values_set(
input_storages = {get_node_storage(node) for node in node_info.inputs}
def get_recomputable_banned_nodes(banned_nodes: Set[fx.Node]) -> List[fx.Node]:
def get_recomputable_banned_nodes(banned_nodes: set[fx.Node]) -> list[fx.Node]:
return [
i
for i in banned_nodes
@ -1729,7 +1729,7 @@ def min_cut_rematerialization_partition(
compiler="inductor",
*,
num_fwd_outputs,
) -> Tuple[fx.GraphModule, fx.GraphModule]:
) -> tuple[fx.GraphModule, fx.GraphModule]:
"""
Partitions the joint graph such that the backward recomputes the forward.
Recomputing helps in trading off memory bandwidth with computation.
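
A hedged usage sketch showing where this partitioner plugs in, assuming the functorch.compile entry points (aot_function, nop) keep their documented signatures:

import torch
from functorch.compile import aot_function, min_cut_rematerialization_partition, nop

def f(x):
    return x.sin().cos().sum()

compiled = aot_function(
    f,
    fw_compiler=nop,                    # return the traced graphs unchanged
    bw_compiler=nop,
    partition_fn=min_cut_rematerialization_partition,
)
out = compiled(torch.randn(8, requires_grad=True))
out.backward()
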
@ -1799,7 +1799,7 @@ def min_cut_rematerialization_partition(
forward_only_graph = _extract_graph_with_inputs_outputs(
joint_module.graph, inputs, fwd_outputs, "forward"
)
required_fw_nodes: Set[fx.Node] = {
required_fw_nodes: set[fx.Node] = {
name_to_node[node.name]
for node in forward_only_graph.nodes
if node.op != "output"
@ -1886,7 +1886,7 @@ def min_cut_rematerialization_partition(
}
remat_nodes = fw_module_nodes & bw_module_nodes
counts: Dict[str, int] = defaultdict(int)
counts: dict[str, int] = defaultdict(int)
for node in fw_module.graph.nodes:
if node.name in remat_nodes and hasattr(node.target, "_overloadpacket"):
counts[str(node.target._overloadpacket)] += 1
@ -1906,7 +1906,7 @@ def draw_graph(
fname: str,
figname: str = "fx_graph",
clear_meta: bool = True,
prog: Optional[Union[str, List[str]]] = None,
prog: Optional[Union[str, list[str]]] = None,
parse_stack_trace: bool = False,
dot_graph_shape: Optional[str] = None,
) -> None:

View File

@ -1,7 +1,7 @@
# mypy: allow-untyped-defs
import contextlib
from abc import ABC, abstractmethod
from typing import Any, List, Tuple
from typing import Any
import torch
import torch.utils._pytree as pytree
@ -258,14 +258,14 @@ def retrieve_current_functorch_interpreter() -> FuncTorchInterpreter:
return coerce_cinterpreter(interpreter)
def retrieve_all_functorch_interpreters() -> List[FuncTorchInterpreter]:
def retrieve_all_functorch_interpreters() -> list[FuncTorchInterpreter]:
cis = torch._C._functorch.get_interpreter_stack()
if cis is None:
return []
return [coerce_cinterpreter(ci) for ci in cis]
def compare_functorch_state(states: List[Tuple[Any, ...]]) -> bool:
def compare_functorch_state(states: list[tuple[Any, ...]]) -> bool:
# There are four possible cases covered here:
# 1. Current stack empty AND stack when generated not empty -> Invalidate
# 2. Current stack not empty AND stack when generated empty -> Invalidate

View File

@ -1,5 +1,6 @@
import contextlib
from typing import Any, Generator, Tuple, Union
from collections.abc import Generator
from typing import Any, Union
import torch
from torch._C._functorch import (
@ -28,7 +29,7 @@ def enable_single_level_autograd_function() -> Generator[None, None, None]:
set_single_level_autograd_function_allowed(prev_state)
def unwrap_dead_wrappers(args: Tuple[Any, ...]) -> Tuple[Any, ...]:
def unwrap_dead_wrappers(args: tuple[Any, ...]) -> tuple[Any, ...]:
# NB: doesn't use tree_map_only for performance reasons
result = tuple(
unwrap_if_dead(arg) if isinstance(arg, torch.Tensor) else arg for arg in args
@ -36,4 +37,4 @@ def unwrap_dead_wrappers(args: Tuple[Any, ...]) -> Tuple[Any, ...]:
return result
argnums_t = Union[int, Tuple[int, ...]]
argnums_t = Union[int, tuple[int, ...]]

View File

@ -12,7 +12,7 @@ import itertools
import os
import threading
from functools import partial
from typing import Any, Callable, List, Optional, Tuple, Union
from typing import Any, Callable, Optional, Union
import torch
from torch import Tensor
@ -32,8 +32,8 @@ from torch.utils._pytree import (
)
in_dims_t = Union[int, Tuple]
out_dims_t = Union[int, Tuple[int, ...]]
in_dims_t = Union[int, tuple]
out_dims_t = Union[int, tuple[int, ...]]
def doesnt_support_saved_tensors_hooks(f):
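
in_dims_t and out_dims_t above describe what vmap's in_dims and out_dims accept: a single int or a (possibly nested) tuple with one entry per input. A usage sketch:

import torch
from torch.func import vmap

x = torch.randn(3, 5)   # batched along dim 0
w = torch.randn(5)      # shared across the batch

# in_dims is a tuple with one entry per positional input; None means "not batched".
out = vmap(lambda xi, wi: xi @ wi, in_dims=(0, None))(x, w)
print(out.shape)        # torch.Size([3])
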
@ -52,7 +52,7 @@ def doesnt_support_saved_tensors_hooks(f):
# Checks that all args-to-be-batched have the same batch dim size
def _validate_and_get_batch_size(
flat_in_dims: List[Optional[int]], flat_args: List
flat_in_dims: list[Optional[int]], flat_args: list
) -> int:
batch_sizes = [
arg.size(in_dim)
@ -69,7 +69,7 @@ def _validate_and_get_batch_size(
return batch_sizes[0]
def _num_outputs(batched_outputs: Union[Tensor, Tuple[Tensor, ...]]) -> int:
def _num_outputs(batched_outputs: Union[Tensor, tuple[Tensor, ...]]) -> int:
if isinstance(batched_outputs, tuple):
return len(batched_outputs)
return 1
@ -81,7 +81,7 @@ def _num_outputs(batched_outputs: Union[Tensor, Tuple[Tensor, ...]]) -> int:
def _as_tuple(
value: Any, num_elements: int, error_message_lambda: Callable[[], str]
) -> Tuple:
) -> tuple:
if not isinstance(value, tuple):
return (value,) * num_elements
if len(value) != num_elements:
@ -90,8 +90,8 @@ def _as_tuple(
def _process_batched_inputs(
in_dims: in_dims_t, args: Tuple, func: Callable
) -> Tuple[int, List[Any], List[Any], TreeSpec]:
in_dims: in_dims_t, args: tuple, func: Callable
) -> tuple[int, list[Any], list[Any], TreeSpec]:
if not isinstance(in_dims, int) and not isinstance(in_dims, tuple):
raise ValueError(
f"vmap({_get_name(func)}, in_dims={in_dims}, ...)(<inputs>): "
@ -152,8 +152,8 @@ def _process_batched_inputs(
def _create_batched_inputs(
flat_in_dims: List[Any], flat_args: List[Any], vmap_level: int, args_spec
) -> Tuple:
flat_in_dims: list[Any], flat_args: list[Any], vmap_level: int, args_spec
) -> tuple:
# See NOTE [Ignored _remove_batch_dim, _add_batch_dim]
batched_inputs = [
arg if in_dim is None else _add_batch_dim(arg, in_dim, vmap_level)
@ -186,12 +186,12 @@ def _maybe_remove_batch_dim(name, batched_output, vmap_level, batch_size, out_di
# Undoes the batching (and any batch dimensions) associated with the `vmap_level`.
def _unwrap_batched(
batched_outputs: Union[Tensor, Tuple[Tensor, ...]],
batched_outputs: Union[Tensor, tuple[Tensor, ...]],
out_dims: out_dims_t,
vmap_level: int,
batch_size: int,
func: Callable,
) -> Tuple:
) -> tuple:
flat_batched_outputs, output_spec = tree_flatten(batched_outputs)
def incompatible_error():