Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)

PEP585 update - torch/_functorch (#145139)

See #145101 for details.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/145139
Approved by: https://github.com/bobrenjc93

Committed by: PyTorch MergeBot
Parent: 10e4d3aebb
Commit: 78bff1e8c1
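What the PEP 585 update means in practice: annotations spelled with the deprecated typing aliases (List, Dict, Tuple, Set, DefaultDict) are rewritten to the builtin generics standardized by PEP 585, which require Python 3.9 or newer. Names PEP 585 does not replace, such as Optional, Union and Callable, stay imported from typing, and DefaultDict is replaced by subscripting collections.defaultdict directly. The sketch below only illustrates the shape of the rewrite; the function pick and its parameters are hypothetical stand-ins, not code from this commit.

# Before the rewrite (deprecated typing aliases):
#     from typing import Dict, List, Optional, Tuple
#     def pick(memory: List[float], runtimes: List[float], budget: float,
#              overrides: Optional[Dict[int, float]] = None) -> Tuple[float, List[int]]: ...
#
# After the rewrite (PEP 585 builtin generics, Python 3.9+):
from typing import Optional


def pick(
    memory: list[float],
    runtimes: list[float],
    budget: float,
    overrides: Optional[dict[int, float]] = None,
) -> tuple[float, list[int]]:
    # Toy greedy selection: keep items while they fit in the memory budget.
    kept: list[int] = []
    total_runtime = 0.0
    used = 0.0
    for i, (mem, run) in enumerate(zip(memory, runtimes)):
        if overrides is not None and i in overrides:
            run = overrides[i]
        if used + mem <= budget:
            used += mem
            total_runtime += run
            kept.append(i)
    return total_runtime, kept

The diff below is reproduced from the rendered commit page; the per-line rewrites follow exactly this pattern.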
@@ -1,4 +1,4 @@
-from typing import Any, Dict, List, Optional, Tuple
+from typing import Any, Optional

 import networkx as nx

@@ -20,18 +20,18 @@ class GraphInfoProvider:

     def __init__(
         self,
-        graph_nodes_in_order: List[str],
-        graph_edges: List[Tuple[str, str]],
-        all_recomputable_banned_nodes: List[str],
-        all_node_runtimes: Optional[Dict[str, float]] = None,
-        all_node_memories: Optional[Dict[str, float]] = None,
-        recorded_knapsack_input_memories: Optional[List[float]] = None,
-        recorded_knapsack_input_runtimes: Optional[List[float]] = None,
+        graph_nodes_in_order: list[str],
+        graph_edges: list[tuple[str, str]],
+        all_recomputable_banned_nodes: list[str],
+        all_node_runtimes: Optional[dict[str, float]] = None,
+        all_node_memories: Optional[dict[str, float]] = None,
+        recorded_knapsack_input_memories: Optional[list[float]] = None,
+        recorded_knapsack_input_runtimes: Optional[list[float]] = None,
         joint_graph: Optional[Graph] = None,
     ):
         self.graph_nodes_in_order = graph_nodes_in_order
         self.graph_edges = graph_edges
-        self.all_node_runtimes: Dict[str, float] = dict()
+        self.all_node_runtimes: dict[str, float] = dict()
         if all_node_runtimes is None:
             if recorded_knapsack_input_runtimes is None:
                 raise ValueError(
@@ -43,7 +43,7 @@ class GraphInfoProvider:
             }
         else:
             self.all_node_runtimes.update(all_node_runtimes)
-        self.all_node_memories: Dict[str, float] = dict()
+        self.all_node_memories: dict[str, float] = dict()
         if all_node_memories is None:
             if recorded_knapsack_input_memories is None:
                 raise ValueError(
@@ -59,7 +59,7 @@ class GraphInfoProvider:
         self.all_recomputable_banned_nodes_set = set(all_recomputable_banned_nodes)
         self.recorded_knapsack_input_memories = recorded_knapsack_input_memories
         self.recorded_knapsack_input_runtimes = recorded_knapsack_input_runtimes
-        self._lazily_initialized_graphs: Dict[str, Any] = {
+        self._lazily_initialized_graphs: dict[str, Any] = {
             self.__RECOMPUTABLE_NODE_ONLY_GRAPH: None,
             self.__RECOMPUTABLE_NODE_ONLY_GRAPH_WITH_LARGER_GRAPH_CONTEXT: None,
             self.__FULL_NX_JOINT_GRAPH: None,
@@ -70,9 +70,9 @@ class GraphInfoProvider:
     def inialize_from_graph(
         cls,
         joint_graph: Graph,
-        all_recomputable_banned_nodes: List[Node],
-        recorded_knapsack_input_memories: List[float],
-        recorded_knapsack_input_runtimes: List[float],
+        all_recomputable_banned_nodes: list[Node],
+        recorded_knapsack_input_memories: list[float],
+        recorded_knapsack_input_runtimes: list[float],
     ) -> "GraphInfoProvider":
         """
         Enables initialization from a joint graph.
@@ -144,7 +144,7 @@ class GraphInfoProvider:
             for node_name in self.all_recomputable_banned_nodes_set
         )

-    def get_knapsack_memory_input(self) -> List[float]:
+    def get_knapsack_memory_input(self) -> list[float]:
         return (
             self.recorded_knapsack_input_memories
             if self.recorded_knapsack_input_memories
@@ -154,7 +154,7 @@ class GraphInfoProvider:
             ]
         )

-    def get_knapsack_runtime_input(self) -> List[float]:
+    def get_knapsack_runtime_input(self) -> list[float]:
         return (
             self.recorded_knapsack_input_runtimes
             if self.recorded_knapsack_input_runtimes
@@ -224,7 +224,7 @@ class GraphInfoProvider:

     def _recreate_psuedo_joint_graph(self) -> Graph:
         # Create a dictionary to store the dependencies of each node
-        node_dependencies: Dict[str, List[str]] = {
+        node_dependencies: dict[str, list[str]] = {
             node: [] for node in self.graph_nodes_in_order
         }
         for a, b in self.graph_edges:
@@ -234,7 +234,7 @@ class GraphInfoProvider:

         joint_graph = Graph()
         # Create nodes in the graph
-        nodes: Dict[str, Node] = {}
+        nodes: dict[str, Node] = {}
         for node_name in self.graph_nodes_in_order:
             input_nodes = [nodes[dep] for dep in node_dependencies[node_name]]
             if input_nodes:
@@ -1,11 +1,9 @@
-from typing import List, Tuple
-
 import torch


 def greedy_knapsack(
-    memory: List[float], runtimes: List[float], max_memory: float
-) -> Tuple[float, List[int], List[int]]:
+    memory: list[float], runtimes: list[float], max_memory: float
+) -> tuple[float, list[int], list[int]]:
     n = len(runtimes)
     items = list(range(n))

@@ -28,8 +26,8 @@ def greedy_knapsack(


 def ilp_knapsack(
-    memory: List[float], runtimes: List[float], max_memory: float
-) -> Tuple[float, List[int], List[int]]:
+    memory: list[float], runtimes: list[float], max_memory: float
+) -> tuple[float, list[int], list[int]]:
     import numpy as np

     try:
@@ -64,8 +62,8 @@ def ilp_knapsack(


 def dp_knapsack(
-    memory: List[float], runtime: List[float], max_memory: float
-) -> Tuple[float, List[int], List[int]]:
+    memory: list[float], runtime: list[float], max_memory: float
+) -> tuple[float, list[int], list[int]]:
     # Scaling factor to convert floating point weights to integers
     S = 10000

@@ -1,5 +1,5 @@
 from collections import deque
-from typing import Callable, Dict, List, Set, Tuple
+from typing import Callable

 import networkx as nx
 import numpy as np
@@ -25,10 +25,10 @@ class KnapsackEvaluator:
     def _get_backward_memory_from_topologically_sorted_graph(
         self,
         node_graph: nx.DiGraph,
-        node_memories: Dict[str, float],
-        saved_nodes_set: Set[str],
+        node_memories: dict[str, float],
+        saved_nodes_set: set[str],
         peak_memory_after_forward_pass: float,
-    ) -> List[Tuple[float, str]]:
+    ) -> list[tuple[float, str]]:
         """
         Simulates the backward pass and keeps track of the peak memory usage.

@@ -108,7 +108,7 @@ class KnapsackEvaluator:
         return current_memory

     def _validate_all_indexes_accounted_for_in_provided_output(
-        self, saved_nodes_idxs: List[int], recomputable_node_idxs: List[int]
+        self, saved_nodes_idxs: list[int], recomputable_node_idxs: list[int]
     ) -> None:
         """
         Validate that all indexes are accounted for in the provided output.
@@ -132,10 +132,10 @@ class KnapsackEvaluator:

     def evaluate_knapsack_output(
         self,
-        saved_nodes_idxs: List[int],
-        recomputable_node_idxs: List[int],
+        saved_nodes_idxs: list[int],
+        recomputable_node_idxs: list[int],
         account_for_backward_pass: bool = False,
-    ) -> Dict[str, float]:
+    ) -> dict[str, float]:
         """
         Evaluate the theoretical runtime and peak memory usage of a given checkpointing strategy.
         Args:
@@ -188,10 +188,10 @@ class KnapsackEvaluator:
     def evaluate_distribution_of_results_for_knapsack_algo(
         self,
         knapsack_algo: Callable[
-            [List[float], List[float], float], Tuple[float, List[int], List[int]]
+            [list[float], list[float], float], tuple[float, list[int], list[int]]
         ],
-        memory_budget_values: List[float],
-    ) -> List[Dict[str, float]]:
+        memory_budget_values: list[float],
+    ) -> list[dict[str, float]]:
         """
         Evaluates the distribution of results for a given knapsack algorithm.
         Args:
@@ -216,7 +216,7 @@ class KnapsackEvaluator:
     def get_knee_point_memory_budget(
         self,
         knapsack_algo: Callable[
-            [List[float], List[float], float], Tuple[float, List[int], List[int]]
+            [list[float], list[float], float], tuple[float, list[int], list[int]]
         ],
         max_mem_budget: float = 0.1,
         min_mem_budget: float = 0.001,
@ -14,7 +14,7 @@ import pickle
|
||||
import shutil
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple, TYPE_CHECKING, Union
|
||||
from typing import Any, Callable, Optional, TYPE_CHECKING, Union
|
||||
|
||||
import torch
|
||||
from torch._dynamo.trace_rules import torch_non_c_binding_in_graph_functions
|
||||
@ -273,7 +273,7 @@ class AOTAutogradCacheDetails(FxGraphHashDetails):
|
||||
class AOTAutogradCachePickler(FxGraphCachePickler):
|
||||
def __init__(self, gm: torch.fx.GraphModule):
|
||||
super().__init__(gm)
|
||||
self.dispatch_table: Dict
|
||||
self.dispatch_table: dict
|
||||
self.dispatch_table.update(
|
||||
{
|
||||
AOTConfig: functools.partial(self._reduce_aot_config),
|
||||
@ -313,7 +313,7 @@ def autograd_cache_key(
|
||||
config: AOTConfig,
|
||||
fx_config: _CompileFxKwargs,
|
||||
# TODO: add args and parameters
|
||||
) -> Tuple[str, List[str]]:
|
||||
) -> tuple[str, list[str]]:
|
||||
"""
|
||||
Generate a unique hash of the FX graph for caching.
|
||||
"""
|
||||
@ -399,7 +399,7 @@ class CompiledBackward(FXGraphCacheLoadable):
|
||||
"""
|
||||
|
||||
# Used by AOTDispatchAutograd.post_compile
|
||||
backward_state_indices: List[int]
|
||||
backward_state_indices: list[int]
|
||||
num_symints_saved_for_bw_: int
|
||||
|
||||
def is_backward(self):
|
||||
@ -424,14 +424,14 @@ class AOTAutogradCacheEntry:
|
||||
runtime_metadata: ViewAndMutationMeta
|
||||
|
||||
# Wrappers that run after each aot_dispatch_* function
|
||||
dispatch_wrappers: List[CompilerWrapper]
|
||||
dispatch_wrappers: list[CompilerWrapper]
|
||||
|
||||
# Used by AOTSubclassWrapper
|
||||
maybe_subclass_meta: Optional[SubclassMeta]
|
||||
num_fw_outs_saved_for_bw: Optional[int]
|
||||
|
||||
# Used by RuntimeWrapepr
|
||||
indices_of_inps_to_detach: List[int]
|
||||
indices_of_inps_to_detach: list[int]
|
||||
|
||||
# Time taken to trace/compile the forward
|
||||
# forward_time_taken includes AOTAutograd tracing time + inductor compilation time
|
||||
@ -442,7 +442,7 @@ class AOTAutogradCacheEntry:
|
||||
# Turn cache entry into the original callable
|
||||
def wrap_post_compile(
|
||||
self,
|
||||
args: List[torch.Tensor],
|
||||
args: list[torch.Tensor],
|
||||
aot_config: AOTConfig,
|
||||
fx_config: _CompileFxKwargs,
|
||||
) -> Callable:
|
||||
@ -675,9 +675,9 @@ class AOTAutogradCache:
|
||||
gm = mod.gm if isinstance(mod, torch._dynamo.utils.GmWrapper) else mod
|
||||
with sanitize_gm_for_cache(gm):
|
||||
compiled_fn = None
|
||||
cache_info: Dict[str, Any] = {}
|
||||
cache_info: dict[str, Any] = {}
|
||||
cache_key = None
|
||||
debug_lines: List[str] = []
|
||||
debug_lines: list[str] = []
|
||||
cache_event_time = time.time_ns()
|
||||
cache_state = None
|
||||
fx_config: _CompileFxKwargs = {"cudagraphs": cudagraphs}
|
||||
|
@ -12,7 +12,7 @@ import collections
|
||||
import contextlib
|
||||
import logging
|
||||
from functools import wraps
|
||||
from typing import Callable, DefaultDict, Dict, List, Optional, Set
|
||||
from typing import Callable, Optional
|
||||
|
||||
import torch
|
||||
import torch.utils._pytree as pytree
|
||||
@ -150,13 +150,13 @@ def run_functionalized_fw_and_collect_metadata(
|
||||
# TODO: refactor to kill this flag
|
||||
is_train: bool = False,
|
||||
# Note: this is guaranteed to be set when running under dynamo
|
||||
static_input_indices: Optional[List[int]] = None,
|
||||
static_input_indices: Optional[list[int]] = None,
|
||||
pre_dispatch: bool = False,
|
||||
# is_export is technically only needed to avoid using functionalization V2
|
||||
# during analysis
|
||||
is_export: bool = False,
|
||||
) -> Callable[..., ViewAndMutationMeta]:
|
||||
memo: Dict[Tensor, Tensor] = {}
|
||||
memo: dict[Tensor, Tensor] = {}
|
||||
|
||||
def _to_fun(t):
|
||||
if isinstance(t, Tensor):
|
||||
@ -173,8 +173,8 @@ def run_functionalized_fw_and_collect_metadata(
|
||||
# This function is meant to be run with the forward, which expects a flat list of tensor/symint/other args.
|
||||
assert all(isinstance(a, tuple(KNOWN_TYPES)) for a in flat_args)
|
||||
|
||||
input_info: List[InputAliasInfo] = []
|
||||
output_info: List[OutputAliasInfo] = []
|
||||
input_info: list[InputAliasInfo] = []
|
||||
output_info: list[OutputAliasInfo] = []
|
||||
|
||||
prior_grad_enabled = torch.is_grad_enabled()
|
||||
prior_autocast_states = _get_autocast_states()
|
||||
@ -275,15 +275,16 @@ def run_functionalized_fw_and_collect_metadata(
|
||||
out_tensor_ids = {id(o): i for i, o in enumerate(flat_f_outs)}
|
||||
|
||||
# Keep track of which outputs alias other outputs
|
||||
out_tensor_alias_counts: DefaultDict = collections.defaultdict(int)
|
||||
out_tensor_alias_counts: collections.defaultdict = collections.defaultdict(int)
|
||||
# This tells us, for a given group of outputs that alias each other,
|
||||
# whether they e.g. all came from an unbind call
|
||||
num_aliased_tensors_that_are_multi_output_views: DefaultDict = (
|
||||
num_aliased_tensors_that_are_multi_output_views: collections.defaultdict = (
|
||||
collections.defaultdict(int)
|
||||
)
|
||||
|
||||
out_storage_to_metadata_key_to_tensors: DefaultDict[
|
||||
Optional[StorageWeakRef], DefaultDict[MetadataKey, Set[torch.Tensor]]
|
||||
out_storage_to_metadata_key_to_tensors: collections.defaultdict[
|
||||
Optional[StorageWeakRef],
|
||||
collections.defaultdict[MetadataKey, set[torch.Tensor]],
|
||||
] = collections.defaultdict(lambda: collections.defaultdict(set))
|
||||
|
||||
curr_storage = None
|
||||
@ -382,8 +383,8 @@ def run_functionalized_fw_and_collect_metadata(
|
||||
].add(o)
|
||||
|
||||
# maps the id of an intermediate base to its index in the output of the compiled forward
|
||||
intermediate_base_tensor_id_to_output_idx: Dict[int, int] = {}
|
||||
intermediate_bases: List[torch.Tensor] = []
|
||||
intermediate_base_tensor_id_to_output_idx: dict[int, int] = {}
|
||||
intermediate_bases: list[torch.Tensor] = []
|
||||
# Why Do We Care If Storage Changed?
|
||||
# It's important to understand the implications of storage changes in complex scenarios. Take this example:
|
||||
#
|
||||
|
@ -5,7 +5,7 @@ pathways, taking into account the AOTConfig and the collected ViewAndMutationMet
|
||||
"""
|
||||
|
||||
import dataclasses
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
from typing import Any, Optional
|
||||
|
||||
import torch
|
||||
import torch.utils._pytree as pytree
|
||||
@ -72,11 +72,11 @@ def _detach_and_copy_item_memo(t):
|
||||
|
||||
def aot_dispatch_base_graph(
|
||||
flat_fn,
|
||||
flat_args: List[Tensor],
|
||||
flat_args: list[Tensor],
|
||||
aot_config: AOTConfig,
|
||||
*,
|
||||
fw_metadata: ViewAndMutationMeta,
|
||||
) -> Tuple[torch.fx.GraphModule, List[Any], Optional[SubclassMeta]]:
|
||||
) -> tuple[torch.fx.GraphModule, list[Any], Optional[SubclassMeta]]:
|
||||
# aot_dispatch_base requires functionalization, but doesn't need to handle as many cases as the autograd case.
|
||||
# The cases that aot_dispatch_base doesn't need to handle include:
|
||||
# - outputs that are aliases of graph intermediates
|
||||
@ -133,7 +133,7 @@ def aot_dispatch_base_graph(
|
||||
if aot_config.is_export and mod_when_exporting_non_strict is not None:
|
||||
# For any buffer that is assigned, we want to associate it to the final proxy node
|
||||
# that it is assigned to. This node can then be added as a buffer mutation output.
|
||||
assigned_buffers: Dict[str, str] = {}
|
||||
assigned_buffers: dict[str, str] = {}
|
||||
hook = register_buffer_assignment_hook(
|
||||
mod_when_exporting_non_strict, assigned_buffers
|
||||
)
|
||||
@ -250,11 +250,11 @@ def aot_dispatch_base_graph(
|
||||
# the same storage, so long as they have separate TensorImpls.)
|
||||
def aot_dispatch_autograd_graph(
|
||||
flat_fn,
|
||||
flat_args: List[Any],
|
||||
flat_args: list[Any],
|
||||
aot_config: AOTConfig,
|
||||
*,
|
||||
fw_metadata: ViewAndMutationMeta,
|
||||
) -> Tuple[torch.fx.GraphModule, Tuple[List[Any], List[Any]], Optional[SubclassMeta]]:
|
||||
) -> tuple[torch.fx.GraphModule, tuple[list[Any], list[Any]], Optional[SubclassMeta]]:
|
||||
# traced_tangents corresponds to the set of outputs in the traced forward that should get grad_outputs in the traced backward.
|
||||
# It includes outputs of the original forward, *and* any updated inputs due to input mutations.
|
||||
# However, it does *not* include any outputs that are aliases of inputs or intermediates, or any metadata-only input mutations.
|
||||
|
@@ -9,7 +9,7 @@ This file contains utilities related to functionalization in AOTAutograd:
 from __future__ import annotations

 from dataclasses import dataclass
-from typing import Optional, Tuple
+from typing import Optional

 import torch
 from torch import Tensor
@@ -337,11 +337,11 @@ class MetadataKey:
     This should be equal whenever has_same_metadata would return True
     """

-    size: Tuple[SymIntEqByExpr, ...]
+    size: tuple[SymIntEqByExpr, ...]
     layout: torch.layout
     is_sparse: bool
     # these are empty when is_sparse
-    stride: Optional[Tuple[SymIntEqByExpr, ...]]
+    stride: Optional[tuple[SymIntEqByExpr, ...]]
     storage_offset: Optional[SymIntEqByExpr]
     is_conj: bool
     is_neg: bool
@ -12,7 +12,7 @@ In particular, the following analyses are provided:
|
||||
|
||||
import contextlib
|
||||
import itertools
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.utils._pytree as pytree
|
||||
@ -40,8 +40,8 @@ zip = strict_zip
|
||||
|
||||
def remove_dupe_metadata(
|
||||
m: ViewAndMutationMeta,
|
||||
keep_arg_mask: List[bool],
|
||||
add_dupe_map: List[int],
|
||||
keep_arg_mask: list[bool],
|
||||
add_dupe_map: list[int],
|
||||
) -> ViewAndMutationMeta:
|
||||
assert len(m.input_info) == len(keep_arg_mask)
|
||||
# Easy invariant: the first argument should never be a dupe (it will be kept)
|
||||
@ -104,12 +104,12 @@ def create_synthetic_base_metadata(
|
||||
m: ViewAndMutationMeta,
|
||||
# Maps each outer argument idx to its inner idx (or, if this outer arg is generated from a
|
||||
# synthetic base, you get a tuple of (i, TensorMeta), telling you the base tensor idx, and view metadata)
|
||||
synthetic_base_info: List[Union[int, Tuple[int, torch.Tensor]]],
|
||||
outer_args: List[Any],
|
||||
inner_args: List[Any],
|
||||
) -> Tuple[ViewAndMutationMeta, List[int]]:
|
||||
synthetic_base_info: list[Union[int, tuple[int, torch.Tensor]]],
|
||||
outer_args: list[Any],
|
||||
inner_args: list[Any],
|
||||
) -> tuple[ViewAndMutationMeta, list[int]]:
|
||||
# maps inner arg indices to outer arg indices
|
||||
synthetic_base_to_indices: Dict[int, List[int]] = {}
|
||||
synthetic_base_to_indices: dict[int, list[int]] = {}
|
||||
for inner_idx in range(len(inner_args)):
|
||||
outer_aliased_indices_of_current_base_arg = [
|
||||
outer_idx
|
||||
@ -348,10 +348,10 @@ def create_graph_signature(
|
||||
in_spec: pytree.TreeSpec,
|
||||
out_spec: pytree.TreeSpec,
|
||||
*,
|
||||
user_args_flat: List[Tensor],
|
||||
params_and_buffers_flat: List[Tensor],
|
||||
param_names: List[str],
|
||||
buffer_names: List[str],
|
||||
user_args_flat: list[Tensor],
|
||||
params_and_buffers_flat: list[Tensor],
|
||||
param_names: list[str],
|
||||
buffer_names: list[str],
|
||||
trace_joint: bool,
|
||||
num_user_fw_outs: Optional[int],
|
||||
loss_index: Optional[int],
|
||||
|
@ -14,7 +14,7 @@ import logging
|
||||
import time
|
||||
import traceback
|
||||
from contextlib import nullcontext
|
||||
from typing import Any, Callable, List, Optional, Sequence, Tuple
|
||||
from typing import Any, Callable, Optional, TYPE_CHECKING
|
||||
|
||||
import torch
|
||||
import torch.utils.dlpack
|
||||
@ -72,6 +72,10 @@ from .utils import (
|
||||
)
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Sequence
|
||||
|
||||
|
||||
zip = strict_zip
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
@ -82,10 +86,10 @@ aten = torch.ops.aten
|
||||
|
||||
# Returns a Callable and a ViewAndMutationMeta.
|
||||
# Currently, only export needs the ViewAndMutationMeta after this function.
|
||||
DispatchReturn = Tuple[Callable, ViewAndMutationMeta]
|
||||
DispatchReturn = tuple[Callable, ViewAndMutationMeta]
|
||||
|
||||
|
||||
def _create_wrappers_for_dispatch(needs_autograd: bool) -> List[CompilerWrapper]:
|
||||
def _create_wrappers_for_dispatch(needs_autograd: bool) -> list[CompilerWrapper]:
|
||||
"""
|
||||
Wrappers that run on every dispatch function
|
||||
"""
|
||||
@ -96,7 +100,7 @@ def _create_wrappers_for_dispatch(needs_autograd: bool) -> List[CompilerWrapper]
|
||||
# bits of aot_autograd, and doesn't need to do any specific wrapping.
|
||||
def aot_dispatch_export(
|
||||
flat_fn: Callable,
|
||||
flat_args: List[Any],
|
||||
flat_args: list[Any],
|
||||
aot_config: AOTConfig,
|
||||
*,
|
||||
fw_metadata: ViewAndMutationMeta,
|
||||
@ -136,7 +140,7 @@ def aot_dispatch_export(
|
||||
|
||||
def aot_dispatch_base(
|
||||
flat_fn,
|
||||
flat_args: List[Any],
|
||||
flat_args: list[Any],
|
||||
aot_config: AOTConfig,
|
||||
*,
|
||||
fw_metadata: ViewAndMutationMeta,
|
||||
@ -283,11 +287,11 @@ def aot_dispatch_base(
|
||||
|
||||
|
||||
def collect_fw_donated_buffer_idxs(
|
||||
fw_ins: List[Optional[FakeTensor]],
|
||||
user_fw_outs: List[Optional[FakeTensor]],
|
||||
bw_outs: List[Optional[FakeTensor]],
|
||||
saved_tensors: List[FakeTensor],
|
||||
) -> List[int]:
|
||||
fw_ins: list[Optional[FakeTensor]],
|
||||
user_fw_outs: list[Optional[FakeTensor]],
|
||||
bw_outs: list[Optional[FakeTensor]],
|
||||
saved_tensors: list[FakeTensor],
|
||||
) -> list[int]:
|
||||
"""
|
||||
Checks if the saved tensors are donated buffers, which means a saved tensor is not
|
||||
an alias of any tensors in fw_ins, user_fw_outs, and bw_outs.
|
||||
@ -317,7 +321,7 @@ def collect_bw_donated_buffer_idxs(
|
||||
fw_module: torch.fx.GraphModule,
|
||||
bw_module: torch.fx.GraphModule,
|
||||
fw_metadata: ViewAndMutationMeta,
|
||||
) -> List[int]:
|
||||
) -> list[int]:
|
||||
"""
|
||||
Collects backward donated buffer indexes from fw_module and bw_module.
|
||||
"""
|
||||
@ -372,7 +376,7 @@ def collect_bw_donated_buffer_idxs(
|
||||
|
||||
def aot_dispatch_autograd(
|
||||
flat_fn,
|
||||
flat_args: List[Any],
|
||||
flat_args: list[Any],
|
||||
aot_config: AOTConfig,
|
||||
*,
|
||||
fw_metadata: ViewAndMutationMeta,
|
||||
@ -541,7 +545,7 @@ def aot_dispatch_autograd(
|
||||
# and we will end up with a zero grad at x.
|
||||
# If we later backprop through the second output, this will also require backprop'ing through x.
|
||||
# Meaning we'll need to use `retain_graph=True` to be able to backprop through x the second time.
|
||||
_indices_of_inps_to_detach: List[int] = []
|
||||
_indices_of_inps_to_detach: list[int] = []
|
||||
|
||||
# reversed() since we expect output at end of graph
|
||||
bw_output = next(reversed(bw_module.graph.find_nodes(op="output")))
|
||||
@ -888,7 +892,7 @@ def aot_dispatch_autograd(
|
||||
)
|
||||
|
||||
if config.debug_assert:
|
||||
flat_requires_grad: List[Optional[bool]] = [
|
||||
flat_requires_grad: list[Optional[bool]] = [
|
||||
a.requires_grad if isinstance(a, Tensor) else None for a in flat_args
|
||||
]
|
||||
compiled_fn = DebugAssertWrapper(
|
||||
|
@@ -6,14 +6,13 @@ compilation, capturing user-friendly tracebacks, and debug messages.

 import collections
 from contextlib import contextmanager
-from typing import List, Tuple

 import torch
 import torch.fx.traceback as fx_traceback


 # This is a list since looking forward, we can have this arbitrarily nested.
-graph_being_compiled: List[str] = []
+graph_being_compiled: list[str] = []
 # TODO: It would be nice to reset the numbering every time aot_id goes
 # up, but this is annoying to do right now (because we don't know if
 # an aot_id will come back from the dead), so right now this also happens
@@ -28,7 +27,7 @@ def set_model_name(name):
     model_name = name


-def get_aot_compilation_context() -> Tuple[List[str], str, int]:
+def get_aot_compilation_context() -> tuple[list[str], str, int]:
     return list(graph_being_compiled), model_name, nth_graph


@@ -70,7 +69,7 @@ def track_graph_compiling(aot_config, graph_name):
 callback_set = False


-def setup_stacktrace_preservation_hooks(roots: List):
+def setup_stacktrace_preservation_hooks(roots: list):
     def iter_graph(roots):
         if not roots:
             return
@ -13,7 +13,7 @@ import pprint
|
||||
from contextlib import nullcontext
|
||||
from dataclasses import dataclass, field
|
||||
from functools import wraps
|
||||
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union
|
||||
from typing import Any, Callable, Optional, TYPE_CHECKING, Union
|
||||
|
||||
import torch
|
||||
import torch.utils.dlpack
|
||||
@ -68,6 +68,10 @@ from .utils import (
|
||||
)
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Sequence
|
||||
|
||||
|
||||
zip = strict_zip
|
||||
|
||||
|
||||
@ -86,11 +90,11 @@ class CompilerWrapper:
|
||||
def pre_compile(
|
||||
self,
|
||||
flat_fn,
|
||||
flat_args: List[Tensor],
|
||||
flat_args: list[Tensor],
|
||||
aot_config: AOTConfig,
|
||||
*,
|
||||
fw_metadata: ViewAndMutationMeta,
|
||||
) -> Tuple[Callable, List[Tensor], ViewAndMutationMeta]:
|
||||
) -> tuple[Callable, list[Tensor], ViewAndMutationMeta]:
|
||||
"""
|
||||
Process the inputs to the compiler_fn. You can pass in extra metadata via kwargs.
|
||||
Args:
|
||||
@ -129,7 +133,7 @@ class CompilerWrapper:
|
||||
# - the autograd cases inserts TensorAlias wrapper objects for outputs that alias inputs
|
||||
@dataclass
|
||||
class RuntimeWrapper(CompilerWrapper):
|
||||
indices_of_inps_to_detach: List[int]
|
||||
indices_of_inps_to_detach: list[int]
|
||||
trace_joint: bool
|
||||
disable_amp: bool
|
||||
|
||||
@ -244,7 +248,7 @@ def _create_runtime_wrapper(
|
||||
compiled_fn,
|
||||
*,
|
||||
runtime_metadata: ViewAndMutationMeta,
|
||||
indices_of_inps_to_detach: List[int],
|
||||
indices_of_inps_to_detach: list[int],
|
||||
trace_joint: bool,
|
||||
keep_input_mutations: bool,
|
||||
disable_amp: bool,
|
||||
@ -282,7 +286,7 @@ def _create_runtime_wrapper(
|
||||
for info in runtime_metadata.output_info
|
||||
)
|
||||
|
||||
def runtime_wrapper(args: List[Any]):
|
||||
def runtime_wrapper(args: list[Any]):
|
||||
# stash a ref to each input tensor we plan to use after the compiled function
|
||||
orig_inputs = {i: args[i] for i in epilogue_args_idx}
|
||||
|
||||
@ -454,7 +458,7 @@ class FunctionalizedRngRuntimeWrapper(CompilerWrapper):
|
||||
aot_config,
|
||||
*,
|
||||
fw_metadata,
|
||||
) -> Tuple[Callable, List[Tensor], ViewAndMutationMeta]:
|
||||
) -> tuple[Callable, list[Tensor], ViewAndMutationMeta]:
|
||||
if config.functionalize_rng_ops:
|
||||
# Update example inputs for the fw_compiler
|
||||
fake_mode = detect_fake_mode()
|
||||
@ -473,7 +477,7 @@ class FunctionalizedRngRuntimeWrapper(CompilerWrapper):
|
||||
runtime_metadata: ViewAndMutationMeta,
|
||||
):
|
||||
@wraps(compiled_fn)
|
||||
def wrapper(runtime_args: List[Any]):
|
||||
def wrapper(runtime_args: list[Any]):
|
||||
if runtime_metadata.is_rng_op_functionalized:
|
||||
# Add the seed and offset to args
|
||||
seed, offset = CUDARngStateHelper.get_torch_state_as_tuple()
|
||||
@ -513,10 +517,10 @@ class FunctionalizedRngRuntimeWrapper(CompilerWrapper):
|
||||
|
||||
@dataclass
|
||||
class FakifiedOutWrapper(CompilerWrapper):
|
||||
out_metas: List[torch.Tensor] = field(default_factory=list)
|
||||
out_metas: list[torch.Tensor] = field(default_factory=list)
|
||||
# TracingContext.fwd_output_strides
|
||||
# Generated from actually doing compile
|
||||
fwd_output_strides: Optional[List[List[int]]] = None
|
||||
fwd_output_strides: Optional[list[list[int]]] = None
|
||||
needs_post_compile: bool = True
|
||||
|
||||
def pre_compile(
|
||||
@ -526,7 +530,7 @@ class FakifiedOutWrapper(CompilerWrapper):
|
||||
aot_config,
|
||||
*,
|
||||
fw_metadata,
|
||||
) -> Tuple[Callable, List[Tensor], ViewAndMutationMeta]:
|
||||
) -> tuple[Callable, list[Tensor], ViewAndMutationMeta]:
|
||||
tracing_context = torch._guards.TracingContext.try_get()
|
||||
if tracing_context and tracing_context.fakify_first_call:
|
||||
self.out_metas = [
|
||||
@ -598,7 +602,7 @@ class AOTDispatchSubclassWrapper(CompilerWrapper):
|
||||
def pre_compile(
|
||||
self,
|
||||
flat_fn,
|
||||
flat_args: List[Tensor],
|
||||
flat_args: list[Tensor],
|
||||
aot_config: AOTConfig,
|
||||
*,
|
||||
fw_metadata: ViewAndMutationMeta,
|
||||
@ -626,7 +630,7 @@ class AOTDispatchSubclassWrapper(CompilerWrapper):
|
||||
subclass_metas = runtime_metadata.subclass_fw_graph_out_meta
|
||||
|
||||
@wraps(compiled_fn)
|
||||
def inner_fn(args: List[Any]):
|
||||
def inner_fn(args: list[Any]):
|
||||
unwrapped_args = runtime_unwrap_tensor_subclasses(
|
||||
args,
|
||||
subclass_metas=runtime_metadata.subclass_inp_meta,
|
||||
@ -661,7 +665,7 @@ class EffectTokensWrapper(CompilerWrapper):
|
||||
num_tokens = len(runtime_metadata.tokens)
|
||||
|
||||
@wraps(compiled_fn)
|
||||
def inner_fn(args: List[Any]):
|
||||
def inner_fn(args: list[Any]):
|
||||
if num_tokens > 0:
|
||||
# Pass in forward effect tokens (See Note [Side-Effectful Tokens in AOTAutograd])
|
||||
old_args = args
|
||||
@ -764,9 +768,9 @@ class EffectTokensWrapper(CompilerWrapper):
|
||||
#
|
||||
@dataclass
|
||||
class AOTDedupeWrapper(CompilerWrapper):
|
||||
keep_arg_mask: List[bool] = field(default_factory=list)
|
||||
add_dupe_map: List[int] = field(default_factory=list)
|
||||
old_input_metadata: List[InputAliasInfo] = field(default_factory=list)
|
||||
keep_arg_mask: list[bool] = field(default_factory=list)
|
||||
add_dupe_map: list[int] = field(default_factory=list)
|
||||
old_input_metadata: list[InputAliasInfo] = field(default_factory=list)
|
||||
needs_post_compile: bool = True
|
||||
|
||||
# NB: Hot path, avoid set lookups here
|
||||
@ -780,11 +784,11 @@ class AOTDedupeWrapper(CompilerWrapper):
|
||||
def pre_compile(
|
||||
self,
|
||||
flat_fn,
|
||||
flat_args: List[Tensor],
|
||||
flat_args: list[Tensor],
|
||||
aot_config: AOTConfig,
|
||||
*,
|
||||
fw_metadata: ViewAndMutationMeta,
|
||||
) -> Tuple[Callable, List[Tensor], ViewAndMutationMeta]:
|
||||
) -> tuple[Callable, list[Tensor], ViewAndMutationMeta]:
|
||||
# Use information about whether or not flat_fn mutates its arguments
|
||||
# or not to handle dupe args
|
||||
|
||||
@ -869,10 +873,10 @@ class AOTDedupeWrapper(CompilerWrapper):
|
||||
# ]
|
||||
# keep_arg_mask = [True, True, False, True]
|
||||
|
||||
seen_args: Dict[Tensor, int] = {}
|
||||
seen_args: dict[Tensor, int] = {}
|
||||
# Implicitly map duped arg position (list index) to de-duped arg position
|
||||
keep_arg_mask: List[bool] = []
|
||||
add_dupe_map: List[int] = []
|
||||
keep_arg_mask: list[bool] = []
|
||||
add_dupe_map: list[int] = []
|
||||
duped_arg_len = len(flat_args)
|
||||
|
||||
j = 0 # index into deduped_flat_args
|
||||
@ -950,7 +954,7 @@ class AOTDedupeWrapper(CompilerWrapper):
|
||||
return compiled_fn
|
||||
|
||||
@wraps(compiled_fn)
|
||||
def wrapped_compiled_fn(args: List[Any]):
|
||||
def wrapped_compiled_fn(args: list[Any]):
|
||||
deduped_args = self.remove_dupe_args(args)
|
||||
args.clear()
|
||||
return compiled_fn(deduped_args)
|
||||
@ -966,7 +970,7 @@ class AOTDedupeWrapper(CompilerWrapper):
|
||||
def debugged_compiled_fn(args):
|
||||
# Test that the computed remove/add arg functions are an inverse
|
||||
new_args = self.add_dupe_args(self.remove_dupe_args(args))
|
||||
seen: Dict[Any, None] = {}
|
||||
seen: dict[Any, None] = {}
|
||||
for i, (x, y) in enumerate(zip(new_args, args)):
|
||||
seen[y] = None
|
||||
assert x is y, format_guard_bug_msg(
|
||||
@ -1008,16 +1012,16 @@ class AOTSyntheticBaseWrapper(CompilerWrapper):
|
||||
# the synthetic base code prohibits more cases in the autograd case than the inference case.
|
||||
trace_joint: bool # TODO: refactor trace_joint
|
||||
needs_post_compile: bool = True
|
||||
aliased_arg_idx_with_metadata_mutations: List[int] = field(default_factory=list)
|
||||
aliased_arg_idx_with_metadata_mutations: list[int] = field(default_factory=list)
|
||||
|
||||
def pre_compile(
|
||||
self,
|
||||
flat_fn,
|
||||
flat_args: List[Any],
|
||||
flat_args: list[Any],
|
||||
aot_config: AOTConfig,
|
||||
*,
|
||||
fw_metadata: ViewAndMutationMeta,
|
||||
) -> Tuple[Callable, List[Tensor], ViewAndMutationMeta]:
|
||||
) -> tuple[Callable, list[Tensor], ViewAndMutationMeta]:
|
||||
is_inference = not self.trace_joint
|
||||
flat_args_with_synthetic_bases, synthetic_base_info = merge_view_inputs(
|
||||
aot_config,
|
||||
@ -1069,7 +1073,7 @@ class AOTSyntheticBaseWrapper(CompilerWrapper):
|
||||
)
|
||||
replay_views = config.view_replay_for_aliased_outputs
|
||||
|
||||
def _unpack_synthetic_bases(primals: Tuple[Any, ...]) -> List[Any]:
|
||||
def _unpack_synthetic_bases(primals: tuple[Any, ...]) -> list[Any]:
|
||||
f_args_inner = []
|
||||
for inner_idx_or_tuple in synthetic_base_info:
|
||||
if isinstance(inner_idx_or_tuple, int):
|
||||
@ -1249,12 +1253,12 @@ class AOTSyntheticBaseWrapper(CompilerWrapper):
|
||||
# f(c_base, b_base, a, d)
|
||||
def merge_view_inputs(
|
||||
aot_config: AOTConfig,
|
||||
fwd_inputs: List[Any],
|
||||
mutated_input_info: List[InputAliasInfo],
|
||||
fwd_inputs: list[Any],
|
||||
mutated_input_info: list[InputAliasInfo],
|
||||
*,
|
||||
# The autograd case currently has more restrictions than the inference case.
|
||||
is_inference: bool,
|
||||
) -> Tuple[List[Any], Optional[List[Union[int, Tuple[int, torch.Tensor]]]]]:
|
||||
) -> tuple[list[Any], Optional[list[Union[int, tuple[int, torch.Tensor]]]]]:
|
||||
def _are_differentiable_views(view1, view2):
|
||||
if view1 is view2:
|
||||
return True
|
||||
@ -1278,7 +1282,7 @@ def merge_view_inputs(
|
||||
# Return early when there are no mutations.
|
||||
return fwd_inputs, None
|
||||
|
||||
storage_ref_to_idx: Dict[StorageWeakRef, List[int]] = collections.defaultdict(list)
|
||||
storage_ref_to_idx: dict[StorageWeakRef, list[int]] = collections.defaultdict(list)
|
||||
base_args = []
|
||||
other_args = []
|
||||
for i, inpt in enumerate(fwd_inputs):
|
||||
@ -1293,7 +1297,7 @@ def merge_view_inputs(
|
||||
# - another int (corresponding to the index in the argument list of the element from the outer calling convention)
|
||||
# - idx, view_tensor, where we can generate the new output with view_tensor._view_func(old_args[idx])
|
||||
# idx corresponds to which synthetic base from the outer calling context to view
|
||||
inner_calling_convention_meta: Dict[int, Union[int, Tuple[int, torch.Tensor]]] = {}
|
||||
inner_calling_convention_meta: dict[int, Union[int, tuple[int, torch.Tensor]]] = {}
|
||||
for aliased_input_indices in storage_ref_to_idx.values():
|
||||
if len(aliased_input_indices) <= 1 or not any(
|
||||
# We only care about mutations that affect all aliases,
|
||||
@ -1429,8 +1433,8 @@ def merge_view_inputs(
|
||||
old_idx = arg_to_old_idx_map[make_hashable(other_arg)]
|
||||
inner_calling_convention_meta[old_idx] = new_idx
|
||||
# post process into a list
|
||||
post_processed_calling_convention_meta: List[
|
||||
Union[int, Tuple[int, torch.Tensor]]
|
||||
post_processed_calling_convention_meta: list[
|
||||
Union[int, tuple[int, torch.Tensor]]
|
||||
] = [-1 for _ in range(len(inner_calling_convention_meta))]
|
||||
for k, v in inner_calling_convention_meta.items():
|
||||
post_processed_calling_convention_meta[k] = v
|
||||
@ -1443,7 +1447,7 @@ def merge_view_inputs(
|
||||
@dataclass
|
||||
class AutogradLazyBackwardCompileInfo:
|
||||
bw_module: Callable
|
||||
placeholder_list: List[Any]
|
||||
placeholder_list: list[Any]
|
||||
saved_context: Optional[TracingContext]
|
||||
saved_compile_context: Optional[CompileContext]
|
||||
|
||||
@ -1782,9 +1786,9 @@ To fix this, your tensor subclass must implement the dunder method __force_to_sa
|
||||
compiled_bw_func, # bw_module after compilation + wrappers
|
||||
maybe_subclass_meta: Optional[SubclassMeta],
|
||||
num_symints_saved_for_bw_: int,
|
||||
backward_state_indices: List[int],
|
||||
backward_state_indices: list[int],
|
||||
disable_amp: bool,
|
||||
indices_of_inps_to_detach: List[int],
|
||||
indices_of_inps_to_detach: list[int],
|
||||
lazy_backward_info: Optional[AutogradLazyBackwardCompileInfo],
|
||||
aot_config: AOTConfig,
|
||||
*,
|
||||
@ -2099,7 +2103,7 @@ To fix this, your tensor subclass must implement the dunder method __force_to_sa
|
||||
|
||||
@dataclass
|
||||
class DebugAssertWrapper(CompilerWrapper):
|
||||
flat_requires_grad: List[Optional[bool]] = field(default_factory=list)
|
||||
flat_requires_grad: list[Optional[bool]] = field(default_factory=list)
|
||||
|
||||
def post_compile(
|
||||
self,
|
||||
@ -2109,7 +2113,7 @@ class DebugAssertWrapper(CompilerWrapper):
|
||||
runtime_metadata: ViewAndMutationMeta,
|
||||
):
|
||||
@wraps(compiled_fn)
|
||||
def debug_compiled_function(args: List[Any]):
|
||||
def debug_compiled_function(args: list[Any]):
|
||||
# TODO: Check aliasing relationships
|
||||
# TODO: Check strides for metadata mutation
|
||||
# (NB: ideally, this logic is factored out of this function and
|
||||
@ -2135,13 +2139,13 @@ class DebugAssertWrapper(CompilerWrapper):
|
||||
|
||||
|
||||
def pre_compile(
|
||||
wrappers: List[CompilerWrapper],
|
||||
wrappers: list[CompilerWrapper],
|
||||
flat_fn: Callable,
|
||||
flat_args: List[Any],
|
||||
flat_args: list[Any],
|
||||
aot_config: AOTConfig,
|
||||
*,
|
||||
fw_metadata: ViewAndMutationMeta,
|
||||
) -> Tuple[Callable, List[Tensor], ViewAndMutationMeta]:
|
||||
) -> tuple[Callable, list[Tensor], ViewAndMutationMeta]:
|
||||
"""
|
||||
Runs a sequence of wrappers on the given function and arguments.
|
||||
Mutates wrappers in place.
|
||||
@ -2154,12 +2158,12 @@ def pre_compile(
|
||||
|
||||
|
||||
def post_compile(
|
||||
wrappers: List[CompilerWrapper],
|
||||
wrappers: list[CompilerWrapper],
|
||||
compiled_fn: Callable,
|
||||
aot_config: AOTConfig,
|
||||
*,
|
||||
runtime_metadata: ViewAndMutationMeta,
|
||||
) -> Tuple[Callable, ViewAndMutationMeta]:
|
||||
) -> tuple[Callable, ViewAndMutationMeta]:
|
||||
"""
|
||||
Runs a sequence of wrappers on the given function. Should be called after pre_compile()
|
||||
"""
|
||||
|
@ -7,9 +7,10 @@ input/output types, metadata, config, function signatures etc.
|
||||
import collections
|
||||
import dataclasses
|
||||
import functools
|
||||
from collections.abc import Iterable
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from typing import Any, Callable, Dict, Iterable, List, NewType, Optional, Set, Union
|
||||
from typing import Any, Callable, NewType, Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.utils._pytree as pytree
|
||||
@ -89,7 +90,7 @@ class OutputAliasInfo:
|
||||
# here, this refers to the index of the *direct* traced
|
||||
base_idx: Optional[int]
|
||||
# If it is a Tensor, what the dynamic dims are (otherwise is None)
|
||||
dynamic_dims: Optional[Set[int]]
|
||||
dynamic_dims: Optional[set[int]]
|
||||
# requires_grad
|
||||
requires_grad: bool
|
||||
# FunctionalTensorWrapper that represents this output.
|
||||
@ -190,7 +191,7 @@ class SubclassCreationMeta:
|
||||
# meta and attrs are produced by the subclass's __tensor_flatten__.
|
||||
# We need to keep them around along with outer_size / outer_stride to plumb them
|
||||
# into __tensor_unflatten__
|
||||
attrs: Dict[str, Union["SubclassCreationMeta", PlainTensorMeta]]
|
||||
attrs: dict[str, Union["SubclassCreationMeta", PlainTensorMeta]]
|
||||
outer_size: Iterable[Union[None, int, torch.SymInt]]
|
||||
outer_stride: Iterable[Union[None, int, torch.SymInt]]
|
||||
meta: Any
|
||||
@ -318,11 +319,11 @@ class SubclassCreationMeta:
|
||||
class ViewAndMutationMeta:
|
||||
# length = # user inputs
|
||||
# This gives us info about every input, and what sort of mutation happened to it (if any)
|
||||
input_info: List[InputAliasInfo]
|
||||
input_info: list[InputAliasInfo]
|
||||
|
||||
# length = # user outputs
|
||||
# This gives us info about every output (mostly around whether it aliases other tensors)
|
||||
output_info: List[OutputAliasInfo]
|
||||
output_info: list[OutputAliasInfo]
|
||||
|
||||
# length = the number of intermediate bases appended as outputs to the end of the forward graph.
|
||||
# Note: this is not necessarily the same thing as:
|
||||
@ -341,7 +342,7 @@ class ViewAndMutationMeta:
|
||||
# Their only use today is to pass them as a best-guess for tangents when tracing the joint.
|
||||
# Stashing them as part of our "metadata" makes it simpler if we want to run our analysis
|
||||
# pass once, and re-use the output throughout AOTAutograd
|
||||
traced_tangents: List[Any]
|
||||
traced_tangents: list[Any]
|
||||
|
||||
# Each of these is a list telling us about subclasses for the inputs/outputs/grad_outs
|
||||
# They are used throughout AOTDispatch to tell us how to generate a list of subclass tensors,
|
||||
@ -355,7 +356,7 @@ class ViewAndMutationMeta:
|
||||
# inputs[3] and inputs[4] of the plain-tensor graph".
|
||||
|
||||
# length = # user inputs
|
||||
subclass_inp_meta: List[Union[PlainTensorMeta, SubclassCreationMeta]]
|
||||
subclass_inp_meta: list[Union[PlainTensorMeta, SubclassCreationMeta]]
|
||||
# So, the full set of outputs to the forward graph looks something like:
|
||||
# (*mutated_inps, *user_outs, *intermediate_bases, *saved_for_bw_tensors)
|
||||
# where the first 3 of those 4 can be subclasses
|
||||
@ -363,9 +364,9 @@ class ViewAndMutationMeta:
|
||||
# and not user visible, so there's no point in wrapping/unwrapping them at runtime).
|
||||
# This list contains subclass information on all of the fw graph outputs
|
||||
# except for saved_for_bw_tensors.
|
||||
subclass_fw_graph_out_meta: List[Union[PlainTensorMeta, SubclassCreationMeta]]
|
||||
subclass_fw_graph_out_meta: list[Union[PlainTensorMeta, SubclassCreationMeta]]
|
||||
# length = # backward graph inputs
|
||||
subclass_tangent_meta: List[Union[PlainTensorMeta, SubclassCreationMeta]]
|
||||
subclass_tangent_meta: list[Union[PlainTensorMeta, SubclassCreationMeta]]
|
||||
# TODO: we should kill this
|
||||
# (need to default it to not break internal)
|
||||
is_train: bool = False
|
||||
@ -375,7 +376,7 @@ class ViewAndMutationMeta:
|
||||
# At runtime, we don't keep the traced_tangents around since they're not serializable.
|
||||
# Instead, we keep any necessary subclass metadata necessary about each traced_tangent.
|
||||
# This list is generated after calling make_runtime_safe().
|
||||
traced_tangent_metas: Optional[List[Any]] = None
|
||||
traced_tangent_metas: Optional[list[Any]] = None
|
||||
|
||||
num_symints_saved_for_bw: Optional[int] = None
|
||||
|
||||
@ -393,12 +394,12 @@ class ViewAndMutationMeta:
|
||||
deterministic: Optional[bool] = None
|
||||
|
||||
# Keeps track of which input indices store parameters (which we will treat as static)
|
||||
static_input_indices: List[int] = field(default_factory=list)
|
||||
static_input_indices: list[int] = field(default_factory=list)
|
||||
|
||||
# Map of effect type (ex. _EffectType.ORDERED) to token. If there are
|
||||
# side-effectful operators, FunctionalTensorMode will populate this
|
||||
# dictionary telling us how many tokens we will need during tracing.
|
||||
tokens: Dict[Any, torch.Tensor] = field(default_factory=dict)
|
||||
tokens: dict[Any, torch.Tensor] = field(default_factory=dict)
|
||||
|
||||
# Only filled in if/when we trace the joint function
|
||||
# If an input requires grad and is mutated in the backward, it is only safe to keep the mutation
|
||||
@ -406,14 +407,14 @@ class ViewAndMutationMeta:
|
||||
# (grad mode is disabled by default when users run the backward, but can be turned on with create_graph=True)
|
||||
# At runtime during the backward, we use this list of indices to error properly if we find out
|
||||
# that it was not safe to include a backward mutation in the graph.
|
||||
indices_of_inputs_that_requires_grad_with_mutations_in_bw: List[int] = field(
|
||||
indices_of_inputs_that_requires_grad_with_mutations_in_bw: list[int] = field(
|
||||
default_factory=list
|
||||
)
|
||||
|
||||
# Indexes of saved tensors which are donated buffer.
|
||||
# Donated buffer means the tensor is not alias of any forward user input, forward user output,
|
||||
# and backward output.
|
||||
bw_donated_idxs: Optional[List[int]] = None
|
||||
bw_donated_idxs: Optional[list[int]] = None
|
||||
|
||||
# Number of tokens used in backward, appended at the end of backward outputs.
|
||||
# Filled after tracing joint function.
|
||||
@ -670,7 +671,7 @@ class SubclassMeta:
|
||||
#
|
||||
# Optional field because we don't compute for inference graphs
|
||||
grad_input_metas: Optional[
|
||||
List[Union[PlainTensorMeta, SubclassCreationMeta]]
|
||||
list[Union[PlainTensorMeta, SubclassCreationMeta]]
|
||||
] = None
|
||||
|
||||
def __init__(self) -> None:
|
||||
@ -704,8 +705,8 @@ class BackwardSignature:
|
||||
Each string name is the `node.name` of the corresponding node in the fx graph.
|
||||
"""
|
||||
|
||||
gradients_to_parameters: Dict[str, str]
|
||||
gradients_to_user_inputs: Dict[str, str]
|
||||
gradients_to_parameters: dict[str, str]
|
||||
gradients_to_user_inputs: dict[str, str]
|
||||
loss_output: str
|
||||
|
||||
|
||||
@ -732,29 +733,29 @@ class GraphSignature:
|
||||
a signature on the backward section of the joint graph.
|
||||
"""
|
||||
|
||||
parameters: List[FQN]
|
||||
buffers: List[FQN]
|
||||
parameters: list[FQN]
|
||||
buffers: list[FQN]
|
||||
|
||||
user_inputs: List[GraphInputName]
|
||||
user_outputs: List[GraphOutputName]
|
||||
inputs_to_parameters: Dict[GraphInputName, FQN]
|
||||
inputs_to_buffers: Dict[GraphInputName, FQN]
|
||||
user_inputs: list[GraphInputName]
|
||||
user_outputs: list[GraphOutputName]
|
||||
inputs_to_parameters: dict[GraphInputName, FQN]
|
||||
inputs_to_buffers: dict[GraphInputName, FQN]
|
||||
|
||||
# If the user's module mutates a buffer,
|
||||
# it's represented in the graph as an extra graph output.
|
||||
# This dict is a mapping from
|
||||
# "graph outputs that correspond to updated buffers"
|
||||
# to the FQN names of those mutated buffers.
|
||||
buffers_to_mutate: Dict[GraphOutputName, FQN]
|
||||
user_inputs_to_mutate: Dict[GraphOutputName, GraphInputName]
|
||||
buffers_to_mutate: dict[GraphOutputName, FQN]
|
||||
user_inputs_to_mutate: dict[GraphOutputName, GraphInputName]
|
||||
|
||||
in_spec: pytree.TreeSpec
|
||||
out_spec: pytree.TreeSpec
|
||||
|
||||
backward_signature: Optional[BackwardSignature]
|
||||
|
||||
input_tokens: List[GraphInputName]
|
||||
output_tokens: List[GraphOutputName]
|
||||
input_tokens: list[GraphInputName]
|
||||
output_tokens: list[GraphOutputName]
|
||||
|
||||
@classmethod
|
||||
def from_tracing_metadata(
|
||||
@ -762,11 +763,11 @@ class GraphSignature:
|
||||
*,
|
||||
in_spec: pytree.TreeSpec,
|
||||
out_spec: pytree.TreeSpec,
|
||||
graph_input_names: List[str],
|
||||
graph_output_names: List[str],
|
||||
graph_input_names: list[str],
|
||||
graph_output_names: list[str],
|
||||
view_mutation_metadata: ViewAndMutationMeta,
|
||||
named_parameters: List[str],
|
||||
named_buffers: List[str],
|
||||
named_parameters: list[str],
|
||||
named_buffers: list[str],
|
||||
num_user_inputs: int,
|
||||
num_user_outputs: int,
|
||||
loss_index: Optional[int],
|
||||
@ -877,15 +878,15 @@ class AOTConfig:
|
||||
fw_compiler: Callable
|
||||
bw_compiler: Callable
|
||||
partition_fn: Callable
|
||||
decompositions: Dict[OpOverload, Callable]
|
||||
decompositions: dict[OpOverload, Callable]
|
||||
num_params_buffers: int
|
||||
aot_id: int
|
||||
keep_inference_input_mutations: bool
|
||||
is_export: bool = False
|
||||
no_tangents: bool = False
|
||||
dynamic_shapes: bool = False
|
||||
aot_autograd_arg_pos_to_source: Optional[List[Source]] = None
|
||||
static_input_indices: Optional[List[int]] = None
|
||||
aot_autograd_arg_pos_to_source: Optional[list[Source]] = None
|
||||
static_input_indices: Optional[list[int]] = None
|
||||
inference_compiler: Optional[Callable] = None
|
||||
enable_log: bool = True
|
||||
# this is always false outside of export.
|
||||
|
@ -1,5 +1,5 @@
|
||||
import dataclasses
|
||||
from typing import Any, Dict, List, Tuple
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
from torch.utils._python_dispatch import is_traceable_wrapper_subclass
|
||||
@ -13,13 +13,13 @@ class SubclassCreationMeta:
|
||||
start_idx: int
|
||||
num_tensors: int
|
||||
class_type: Any
|
||||
attrs: Dict[str, "SubclassCreationMeta"]
|
||||
attrs: dict[str, "SubclassCreationMeta"]
|
||||
metadata: Any
|
||||
|
||||
|
||||
class UnwrapTensorSubclass(torch.nn.Module):
|
||||
def forward(self, *tensors) -> torch.Tensor: # type: ignore[no-untyped-def]
|
||||
todo: List[torch.Tensor] = list(tensors)
|
||||
todo: list[torch.Tensor] = list(tensors)
|
||||
|
||||
def _unwrap_tensor_subclasses(subclass_meta, tensors, offset): # type: ignore[no-untyped-def]
|
||||
if subclass_meta is None:
|
||||
@ -35,9 +35,9 @@ class UnwrapTensorSubclass(torch.nn.Module):
|
||||
|
||||
return _unwrap_tensor_subclasses(self.subclass_meta, todo, 0)[0]
|
||||
|
||||
def right_inverse(self, tensor: torch.Tensor) -> List[torch.Tensor]:
|
||||
def right_inverse(self, tensor: torch.Tensor) -> list[torch.Tensor]:
|
||||
assert type(tensor) is not torch.Tensor
|
||||
plain_tensors: List[torch.Tensor] = []
|
||||
plain_tensors: list[torch.Tensor] = []
|
||||
|
||||
def _create_subclass_meta(tensor, idx, plain_tensor_container): # type: ignore[no-untyped-def]
|
||||
if type(tensor) is torch.Tensor:
|
||||
@ -79,7 +79,7 @@ def unwrap_tensor_subclass_parameters(module: torch.nn.Module) -> torch.nn.Modul
|
||||
becomes: {"parametrizations.p2.original0": torch.Tensor, "parametrizations.p2.original1": torch.Tensor}
|
||||
|
||||
"""
|
||||
name_param: List[Tuple[str, torch.nn.Parameter]] = list(
|
||||
name_param: list[tuple[str, torch.nn.Parameter]] = list(
|
||||
module.named_parameters(recurse=False)
|
||||
)
|
||||
for name, param in name_param:
|
||||
|
@ -7,7 +7,8 @@ and this includes tensor subclasses that implement __torch_dispatch__.
|
||||
|
||||
import collections
|
||||
import typing
|
||||
from typing import Any, Dict, Iterable, List, Optional, Tuple, Type, TypeVar, Union
|
||||
from collections.abc import Iterable
|
||||
from typing import Any, Optional, TypeVar, Union
|
||||
|
||||
import torch
|
||||
import torch.utils._pytree as pytree
|
||||
@ -59,7 +60,7 @@ def maybe_suggest_memory_format(
|
||||
|
||||
def get_subclass_typing_container(
|
||||
tensor_subclass: torch.Tensor,
|
||||
) -> Dict[Type[torch.Tensor], List[Type[torch.Tensor]]]:
|
||||
) -> dict[type[torch.Tensor], list[type[torch.Tensor]]]:
|
||||
"""
|
||||
Given a subclass, returns a recursive dictionary mapping each
|
||||
inner tensors to its' subclass types.
|
||||
@ -74,7 +75,7 @@ def get_subclass_typing_container(
|
||||
inner_tensor = getattr(tensor_subclass, key)
|
||||
_get_types_for_subclass(inner_tensor)
|
||||
|
||||
tracker: Dict[Any, List[Any]] = collections.defaultdict(list)
|
||||
tracker: dict[Any, list[Any]] = collections.defaultdict(list)
|
||||
_get_types_for_subclass(tensor_subclass)
|
||||
return tracker
|
||||
|
||||
@ -134,13 +135,13 @@ def create_subclass_metadata(
|
||||
# computes metadata about "how to reconstruct the current list of subclasses,
|
||||
# if we were given their flattened dense tensors instead"
|
||||
def create_subclass_meta(
|
||||
curr_args: Union[List[Any], Tuple[Any, ...]],
|
||||
curr_args: Union[list[Any], tuple[Any, ...]],
|
||||
*,
|
||||
count_symints: bool = True,
|
||||
with_memory_format: bool = False,
|
||||
) -> List[Union[PlainTensorMeta, SubclassCreationMeta]]:
|
||||
) -> list[Union[PlainTensorMeta, SubclassCreationMeta]]:
|
||||
idx = 0
|
||||
infos: List[Union[PlainTensorMeta, SubclassCreationMeta]] = []
|
||||
infos: list[Union[PlainTensorMeta, SubclassCreationMeta]] = []
|
||||
for a in curr_args:
|
||||
if is_traceable_wrapper_subclass(a):
|
||||
assert isinstance(a, Tensor)
|
||||
@ -173,7 +174,7 @@ def filter_symints(lst: Iterable[Union[int, SymInt]]):
|
||||
return [s for s in lst if symint_check(s)]
|
||||
|
||||
|
||||
def compute_symint_placeholders(lst: Iterable[Union[None, int, SymInt]]) -> List[bool]:
|
||||
def compute_symint_placeholders(lst: Iterable[Union[None, int, SymInt]]) -> list[bool]:
# Non-nested symints are replaced with None in `make_runtime_safe()`
return [s is None for s in lst]

@ -192,7 +193,7 @@ def compute_symint_placeholders(lst: Iterable[Union[None, int, SymInt]]) -> List
# primals (but not tangents) on entry to the forward. See the runtime version of
# this function below.
def unwrap_tensor_subclasses(
wrapped_args: List[Union[Tensor, int]],
wrapped_args: list[Union[Tensor, int]],
*,
append_symints: bool,
):
@ -213,7 +214,7 @@ def unwrap_tensor_subclasses(
out.extend(filter_symints(t.size()))
out.extend(filter_symints(t.stride()))

xs_inner: List[Union[int, Tensor, SymInt]] = []
xs_inner: list[Union[int, Tensor, SymInt]] = []

for x in wrapped_args:
flatten_subclass(typing.cast(Tensor, x), out=xs_inner)
@ -224,10 +225,10 @@ def unwrap_tensor_subclasses(
# subclass_metas is needed at runtime to compute which indices are symints in
# the outer_size/outer_stride
def runtime_unwrap_tensor_subclasses(
wrapped_args: List[Union[Tensor, int]],
wrapped_args: list[Union[Tensor, int]],
*,
append_symints: bool,
subclass_metas: Optional[List[Union[PlainTensorMeta, SubclassCreationMeta]]] = None,
subclass_metas: Optional[list[Union[PlainTensorMeta, SubclassCreationMeta]]] = None,
):
def flatten_subclass(x: Tensor, meta: Optional[SubclassCreationMeta], *, out):
if not is_traceable_wrapper_subclass(x):
@ -262,7 +263,7 @@ def runtime_unwrap_tensor_subclasses(
)
return out

xs_inner: List[Union[int, Tensor, SymInt]] = []
xs_inner: list[Union[int, Tensor, SymInt]] = []

if append_symints:
assert subclass_metas is not None
@ -319,13 +320,13 @@ def remap_unwrapped_subclass_arg_indices(wrapped_args, static_input_indices):
# Turns a flattened list of tensor arguments into (maybe) subclass tensors.
# This function is used both at trace time and runtime, so we have an is_runtime flag telling us which context we're in.
def wrap_tensor_subclasses(
unwrapped_args: Union[Tuple[Any, ...], List[Any]],
unwrapped_args: Union[tuple[Any, ...], list[Any]],
*,
subclass_metas: List[Union[PlainTensorMeta, SubclassCreationMeta]],
subclass_metas: list[Union[PlainTensorMeta, SubclassCreationMeta]],
num_fw_outs_saved_for_bw: Optional[int] = None,
included_subclass_symints: bool = False,
is_runtime: bool = False,
) -> Tuple[Any, ...]:
) -> tuple[Any, ...]:
wrapped_args = []
num_args_tallied = 0
for subclass_meta in subclass_metas:
@ -386,7 +387,7 @@ def wrap_tensor_subclasses(
# - when is_joint_structure is False, args is [*primals]
def wrap_tensor_subclasses_maybe_joint(
unwrapped_args, *, is_joint_structure: bool, meta: ViewAndMutationMeta
) -> Union[Tuple[Any, ...], List[Any]]:
) -> Union[tuple[Any, ...], list[Any]]:
# Since this function is re-used for both inference and joint graphs,
if is_joint_structure:
assert isinstance(unwrapped_args, tuple) and len(unwrapped_args) == 2
@ -417,7 +418,7 @@ def wrap_tensor_subclasses_maybe_joint(
def compute_inner_mutated_inp_indices_from_subclass_meta(
fw_metadata: ViewAndMutationMeta,
inner_metadata: ViewAndMutationMeta,
) -> List[int]:
) -> list[int]:
# Note: [Recomputing subclass mutation handling]
#
# Generally, if a subclass requires grad, its components will not require grad.
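
The hunks above (and the rest of this patch) all follow the same PEP 585 pattern: typing.List/Dict/Tuple/Set/Type annotations become the builtin generics, which are subscriptable at runtime on Python 3.9+. A minimal before/after sketch, using a hypothetical function name that is not taken from the patch:

from typing import Optional


def summarize_runtimes(runtimes: Optional[list[float]] = None) -> dict[str, float]:
    # list[float] / dict[str, float] replace typing.List[float] / typing.Dict[str, float];
    # no typing import is needed for the container types themselves on Python 3.9+.
    runtimes = runtimes or []
    return {f"node_{i}": r for i, r in enumerate(runtimes)}


print(summarize_runtimes([0.5, 1.25]))  # {'node_0': 0.5, 'node_1': 1.25}
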
@ -14,7 +14,7 @@ It does so by:
import warnings
from contextlib import contextmanager, nullcontext
from functools import wraps
from typing import Any, Callable, List, Tuple, Union
from typing import Any, Callable, Union
from unittest.mock import patch

import torch
@ -190,7 +190,7 @@ def fn_prepped_for_autograd(
# otherwise, when we compute autograd.grad(), we will not take those input mutations into account
# (the way this is handled is that we ensure any inputs that normally get mutated are cloned first)
def create_joint(fn: Callable, *, aot_config: AOTConfig) -> Any:
def inner_fn(primals: List[Any], tangents: List[Any]):
def inner_fn(primals: list[Any], tangents: list[Any]):
outs, tangent_mask = fn(*primals)

assert len(tangent_mask) == len(outs)
@ -232,7 +232,7 @@ def create_joint(fn: Callable, *, aot_config: AOTConfig) -> Any:

if config.functionalize_rng_ops:
PhiloxStateTracker.mark_beginning_of_backward()
backward_out: Tuple[Tensor, ...] = ()
backward_out: tuple[Tensor, ...] = ()
# Call the backwards pass
if grad_primals:
functional_tensor_mode = torch.utils._python_dispatch._detect_infra_mode(
@ -733,7 +733,7 @@ def handle_effect_tokens_fn(
# In particular, we need this to tell the partitioner how many dense forward outputs there are.
def aot_dispatch_subclass(
flat_fn_maybe_joint,
args: List[Any],
args: list[Any],
*,
is_joint_structure: bool,
meta: ViewAndMutationMeta,
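
For context on the create_joint hunk above: the joint function pairs the user forward with a backward computed via torch.autograd.grad, seeded by the incoming tangents. A simplified standalone sketch of that idea (an illustration only, not the actual AOTAutograd implementation; the name toy_joint and the multiply are made up):

import torch


def toy_joint(primals: list[torch.Tensor], tangents: list[torch.Tensor]):
    # Forward: stand-in for fn(*primals).
    primals = [p.detach().requires_grad_(p.requires_grad) for p in primals]
    outs = [primals[0] * primals[1]]
    # Backward: grads of the outputs w.r.t. the differentiable primals,
    # seeded with the incoming tangents (cotangents).
    grad_primals = [p for p in primals if p.requires_grad]
    backward_out: tuple[torch.Tensor, ...] = ()
    if grad_primals:
        backward_out = torch.autograd.grad(
            outs, grad_primals, grad_outputs=tangents, allow_unused=True
        )
    return outs, list(backward_out)


a = torch.randn(3, requires_grad=True)
b = torch.randn(3)
outs, grads = toy_joint([a, b], [torch.ones(3)])
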
@ -8,7 +8,7 @@ import operator
import warnings
from contextlib import nullcontext
from functools import wraps
from typing import Any, Callable, List, Optional, Tuple, Union
from typing import Any, Callable, Optional, Union

import torch
import torch.utils._pytree as pytree
@ -114,7 +114,7 @@ def make_boxed_compiler(compiler):

def call_func_at_runtime_with_args(
f, args: Union[Tuple[Any], List[Any]], steal_args=False, disable_amp=False
f, args: Union[tuple[Any], list[Any]], steal_args=False, disable_amp=False
):
if not steal_args:
args = list(args)
@ -156,7 +156,7 @@ class PytreeThunk:
if self.spec.is_leaf():
self.is_really_simple = True

def unflatten(self, x: List[Any]) -> Any:
def unflatten(self, x: list[Any]) -> Any:
if self.is_really_simple:
return x[0]
if self.is_simple:
@ -168,7 +168,7 @@ class PytreeThunk:
# Creates a function that returns flattened inputs and outputs
# Also returns the output tree spec, which is needed to recover the "unflattened"
# output tree structure later.
def create_tree_flattened_fn(fn, args, kwargs=None) -> Tuple[Callable, PytreeThunk]:
def create_tree_flattened_fn(fn, args, kwargs=None) -> tuple[Callable, PytreeThunk]:
if kwargs is None:
kwargs = {}
# Save the args_spec for flat_tensor_args to unflatten while tracing
@ -1,22 +1,10 @@
# mypy: ignore-errors

import itertools
from collections.abc import KeysView, Sequence
from contextlib import contextmanager, nullcontext
from functools import partial, wraps
from typing import (
Any,
Callable,
Dict,
KeysView,
List,
NewType,
Optional,
Protocol,
Sequence,
Tuple,
Type,
TypeVar,
)
from typing import Any, Callable, NewType, Optional, Protocol, TypeVar
from unittest.mock import patch

import torch
@ -447,7 +435,7 @@ AOT_COUNTER = itertools.count()

aot_autograd_decompositions = {}

FakifiedFlatArgs = NewType("FakifiedFlatArgs", List[Any])
FakifiedFlatArgs = NewType("FakifiedFlatArgs", list[Any])

TOutputCode = TypeVar("TOutputCode", bound=OutputCode)
@ -477,7 +465,7 @@ class SerializableAOTDispatchCompiler(AOTDispatchCompiler):

def __init__(
self,
output_code_ty: Type[TOutputCode],
output_code_ty: type[TOutputCode],
compiler_fn: Callable[[torch.fx.GraphModule, Sequence[InputType]], TOutputCode],
):
self.output_code_ty = output_code_ty
@ -492,7 +480,7 @@ class SerializableAOTDispatchCompiler(AOTDispatchCompiler):

def process_inputs(
flat_args: List[Any],
flat_args: list[Any],
aot_config: AOTConfig,
fake_mode: FakeTensorMode,
shape_env: Optional[ShapeEnv],
@ -560,8 +548,8 @@ def process_inputs(

def construct_fake_mode(
flat_args: List[Any], aot_config: AOTConfig
) -> Tuple[FakeTensorMode, Optional[ShapeEnv]]:
flat_args: list[Any], aot_config: AOTConfig
) -> tuple[FakeTensorMode, Optional[ShapeEnv]]:
fake_mode = detect_fake_mode(flat_args)
if fake_mode is None:
shape_env = ShapeEnv() if aot_config.dynamic_shapes else None
@ -577,7 +565,7 @@ def create_aot_dispatcher_function(
aot_config: AOTConfig,
fake_mode: FakeTensorMode,
shape_env: Optional[ShapeEnv],
) -> Tuple[Callable, ViewAndMutationMeta]:
) -> tuple[Callable, ViewAndMutationMeta]:
with dynamo_timed("create_aot_dispatcher_function", log_pt2_compile_event=True):
return _create_aot_dispatcher_function(
flat_fn, fake_flat_args, aot_config, fake_mode, shape_env
@ -590,7 +578,7 @@ def _create_aot_dispatcher_function(
aot_config: AOTConfig,
fake_mode: FakeTensorMode,
shape_env: Optional[ShapeEnv],
) -> Tuple[Callable, ViewAndMutationMeta]:
) -> tuple[Callable, ViewAndMutationMeta]:
"""
Traces the forward and backward graphs of the attr:`flat_fn` to generate a
joint graph. The joint graph is an Fx graph with Aten ops. Please refer to
@ -843,7 +831,7 @@ def aot_function(
fw_compiler: Callable,
bw_compiler: Optional[Callable] = None,
partition_fn: Callable = default_partition,
decompositions: Optional[Dict] = None,
decompositions: Optional[dict] = None,
num_params_buffers: int = 0,
keep_inference_input_mutations: bool = False,
inference_compiler: Optional[Callable] = None,
@ -1008,7 +996,7 @@ def aot_module(mod: nn.Module, *args, **kwargs) -> nn.Module:

def _try_get_metadata_from_dynamo(
mod: torch.nn.Module, param_keys: KeysView[str], full_args_num: int
) -> Tuple[Optional[List[torch._guards.Source]], List[int]]:
) -> tuple[Optional[list[torch._guards.Source]], list[int]]:
"""
Metadata is forwarded from Dynamo to AOTDispatch via special fields on GraphModule.
We first verify that `mod` does come from Dynamo, then we handle cases where
@ -1074,7 +1062,7 @@ def aot_module_simplified(
fw_compiler: AOTDispatchCompiler,
bw_compiler: Optional[AOTDispatchCompiler] = None,
partition_fn: Callable = default_partition,
decompositions: Optional[Dict] = None,
decompositions: Optional[dict] = None,
keep_inference_input_mutations=False,
inference_compiler: Optional[AOTDispatchCompiler] = None,
cudagraphs: Optional[BoxedBool] = None,
@ -1186,7 +1174,7 @@ def aot_module_simplified(
# the inputs so that they can be freed before the end of this scope.
# For overhead reasons, this is not the default wrapper, see comment:
# https://github.com/pytorch/pytorch/pull/122535/files#r1560096481
def boxed_forward(runtime_args: List[Any]):
def boxed_forward(runtime_args: list[Any]):
flat_args = []
flat_args.extend(params_flat)
flat_args.extend(runtime_args)
@ -1204,7 +1192,7 @@ def aot_module_simplified(
# historically returned a function that was not the boxed calling
# convention. This should get fixed...
# NB: GraphModule/nn.Module rely on the non-boxed calling convention here
def forward(*runtime_args: Tuple[Any]):
def forward(*runtime_args: tuple[Any]):
full_args = []
full_args.extend(params_flat)
full_args.extend(runtime_args)
@ -1222,7 +1210,7 @@ def aot_export_module(
mod: nn.Module,
args,
*,
decompositions: Optional[Dict] = None,
decompositions: Optional[dict] = None,
# If true, we'll return a joint forward-backward graph,
# As well as metadata on the loss + gradients in the backward.
trace_joint: bool,
@ -1233,7 +1221,7 @@ def aot_export_module(
# If None, will be inferred from inputs and mod.graph.nodes if mod is a graph module, but the inferred result might be wrong.
dynamic_shapes: Optional[bool] = None,
kwargs=None,
) -> Tuple[torch.fx.GraphModule, GraphSignature]:
) -> tuple[torch.fx.GraphModule, GraphSignature]:
"""
This function takes in a module, and returns:
(1) an FX graph that can be exported
@ -1434,7 +1422,7 @@ def aot_export_joint_simple(
# it will assume that params/buffers are static.
# With the new inferred dynamic shapes API, maybe this doesn't matter?
num_params_buffers: int = 0,
decompositions: Optional[Dict] = None,
decompositions: Optional[dict] = None,
) -> torch.fx.GraphModule:
"""
A simplified version of export. Used by higher order operators.
@ -1530,7 +1518,7 @@ def _aot_export_function(
args,
*,
num_params_buffers: int = 0,
decompositions: Optional[Dict] = None,
decompositions: Optional[dict] = None,
# If we're exporting a joint graph and we don't want any tangent inputs in the graph
# (because we are backpropping through a scalar 1 loss),
# we need to explicitly specify not to include tangents in the graph.
@ -1543,7 +1531,7 @@ def _aot_export_function(
# If None, `dynamic_shapes` will be inferred from inputs, but the inferred result might be wrong.
dynamic_shapes: Optional[bool] = None,
kwargs=None,
) -> Tuple[torch.fx.GraphModule, ViewAndMutationMeta, pytree.TreeSpec, pytree.TreeSpec]:
) -> tuple[torch.fx.GraphModule, ViewAndMutationMeta, pytree.TreeSpec, pytree.TreeSpec]:
kwargs = kwargs or {}

flat_fn, out_spec = create_tree_flattened_fn(func, args, kwargs)
@ -1,5 +1,5 @@
# mypy: allow-untyped-defs
from typing import NamedTuple, Tuple
from typing import NamedTuple

import torch
import torch.utils._pytree as pytree
@ -580,7 +580,7 @@ def get_tangents_in_dims(input_dims, tangents):
# Wraps a ctx object. Forwards all attr accesses to the underlying object
# except for the attrs in _pt_attrs
class WrappedCtx:
_pt_reserved_attrs: Tuple[str, ...] = ("_pt_reserved_attrs", "_pt_inner_ctx")
_pt_reserved_attrs: tuple[str, ...] = ("_pt_reserved_attrs", "_pt_inner_ctx")

def __init__(self, ctx):
if not isinstance(ctx, WrappedCtx):
@ -10,7 +10,7 @@ documentation.

import textwrap
import warnings
from typing import Any, Callable, Optional, Tuple, Union
from typing import Any, Callable, Optional, Union

import torch._functorch.apis as apis
import torch._functorch.eager_transforms as _impl
@ -98,7 +98,7 @@ def jvp(

def jacrev(
func: Callable,
argnums: Union[int, Tuple[int]] = 0,
argnums: Union[int, tuple[int]] = 0,
*,
has_aux=False,
chunk_size: Optional[int] = None,
@ -8,7 +8,7 @@

import contextlib
from functools import partial, wraps
from typing import Any, Callable, List, Optional, Tuple, Union
from typing import Any, Callable, Optional, Union

import torch
import torch.autograd.forward_ad as fwAD
@ -428,7 +428,7 @@ def error_if_complex(func_name, args, is_input):
@exposed_in("torch.func")
def jacrev(
func: Callable,
argnums: Union[int, Tuple[int]] = 0,
argnums: Union[int, tuple[int]] = 0,
*,
has_aux=False,
chunk_size: Optional[int] = None,
@ -911,7 +911,7 @@ def assert_flat_tuple_of_tensors(elts: Any, api: str, argname: str) -> None:
)

def assert_non_empty_tensor_output(output: List[Any], api: str) -> None:
def assert_non_empty_tensor_output(output: list[Any], api: str) -> None:
if (len(output) == 1 and output[0] is None) or len(output) < 1:
raise RuntimeError(
f"{api}: Expected f to be a function that has non-empty output (got output = {output})"
@ -946,7 +946,7 @@ def assert_output_is_tensor_or_tensors(output: Any, api: str) -> None:

def assert_non_empty_list_of_tensors(
output: List[torch.Tensor], api: str, argname: str
output: list[torch.Tensor], api: str, argname: str
) -> None:
if len(output) == 0:
raise RuntimeError(f"{api}: Expected {argname} to contain at least one Tensor.")
@ -1676,7 +1676,7 @@ def functionalize(func: Callable, *, remove: str = "mutations") -> Callable:

@exposed_in("torch.func")
def linearize(func: Callable, *primals) -> Tuple[Any, Callable]:
def linearize(func: Callable, *primals) -> tuple[Any, Callable]:
"""
Returns the value of ``func`` at ``primals`` and linear approximation
at ``primals``.
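
The jacrev signature touched above is the torch.func entry point; a small usage example (shapes chosen only for illustration):

import torch
from torch.func import jacrev


def f(x):
    return x.sin()


x = torch.randn(3)
jac = jacrev(f)(x)  # 3x3 Jacobian of the elementwise sin
print(torch.allclose(jac, torch.diag(x.cos())))  # True

Passing a tuple to argnums (e.g. argnums=(0, 1) for a two-argument function) returns one Jacobian per selected input, matching the Union[int, tuple[int]] annotation.
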
@ -1,5 +1,6 @@
# mypy: allow-untyped-defs
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
from collections.abc import Sequence
from typing import Any, Optional, Union

import torch
import torch.nn as nn
@ -10,9 +11,9 @@ from torch._functorch.utils import exposed_in
@exposed_in("torch.func")
def functional_call(
module: "torch.nn.Module",
parameter_and_buffer_dicts: Union[Dict[str, Tensor], Sequence[Dict[str, Tensor]]],
args: Optional[Union[Any, Tuple]] = None,
kwargs: Optional[Dict[str, Any]] = None,
parameter_and_buffer_dicts: Union[dict[str, Tensor], Sequence[dict[str, Tensor]]],
args: Optional[Union[Any, tuple]] = None,
kwargs: Optional[dict[str, Any]] = None,
*,
tie_weights: bool = True,
strict: bool = False,
@ -126,7 +127,7 @@ def functional_call(
"Expected all elements of parameter_and_buffer_dicts to be dictionaries"
)
all_keys = [k for d in parameter_and_buffer_dicts for k in d.keys()]
all_keys_counter: Dict[str, int] = {}
all_keys_counter: dict[str, int] = {}
for k in all_keys:
v = all_keys_counter.get(k, 0)
all_keys_counter[k] = v + 1
@ -157,7 +158,7 @@ def functional_call(
@exposed_in("torch.func")
def stack_module_state(
models: Union[Sequence[nn.Module], nn.ModuleList],
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
) -> tuple[dict[str, Any], dict[str, Any]]:
"""stack_module_state(models) -> params, buffers

Prepares a list of torch.nn.Modules for ensembling with :func:`vmap`.
@ -238,7 +239,7 @@ def stack_module_state(

def construct_stacked_leaf(
tensors: Union[Tuple[Tensor, ...], List[Tensor]], name: str
tensors: Union[tuple[Tensor, ...], list[Tensor]], name: str
) -> Tensor:
all_requires_grad = all(t.requires_grad for t in tensors)
none_requires_grad = all(not t.requires_grad for t in tensors)
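
As a reminder of how the two APIs touched above fit together, here is a usage sketch (the Linear sizes and the ensemble of three modules are arbitrary choices for illustration):

import torch
from torch.func import functional_call, stack_module_state, vmap

model = torch.nn.Linear(4, 2)
params = dict(model.named_parameters())
x = torch.randn(8, 4)
out = functional_call(model, params, (x,))  # stateless forward with an explicit dict

# Ensembling: stack the states of several modules, then vmap over the stacked dim.
models = [torch.nn.Linear(4, 2) for _ in range(3)]
stacked_params, stacked_buffers = stack_module_state(models)


def run(p, b, data):
    # A sequence of dicts is accepted, per the signature above.
    return functional_call(models[0], (p, b), (data,))


ens_out = vmap(run, in_dims=(0, 0, None))(stacked_params, stacked_buffers, x)
print(out.shape, ens_out.shape)  # torch.Size([8, 2]) torch.Size([3, 8, 2])
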
@ -6,7 +6,7 @@ import os
import sys
from dataclasses import dataclass
from functools import partial, wraps
from typing import Callable, List
from typing import Callable

import torch
import torch.fx as fx
@ -22,8 +22,8 @@ is_tuple = object()

@dataclass
class LoadTensorMeta:
size: List[int]
stride: List[int]
size: list[int]
stride: list[int]
dtype: torch.dtype
device: torch.device

@ -164,7 +164,7 @@ def is_power_of_two(n):
@dataclass
class ReproState:
graph: fx.Graph
inps: List[torch.Tensor]
inps: list[torch.Tensor]

def __post_init__(self):
ph_nodes = get_placeholders(self.graph)
@ -6,18 +6,8 @@
# LICENSE file in the root directory of this source tree.

import copy
from typing import (
Any,
Callable,
Dict,
Iterable,
List,
NoReturn,
Sequence,
Tuple,
Type,
Union,
)
from collections.abc import Iterable, Sequence
from typing import Any, Callable, NoReturn, Union

import torch
import torch.nn as nn
@ -41,9 +31,9 @@ def raise_parameter_tying_error() -> NoReturn:

def create_names_map(
named_params: Union[Dict[str, Tensor], Iterable[Tuple[str, Tensor]]],
tied_named_params: Union[Dict[str, Tensor], Iterable[Tuple[str, Tensor]]],
) -> Dict[str, List[str]]:
named_params: Union[dict[str, Tensor], Iterable[tuple[str, Tensor]]],
tied_named_params: Union[dict[str, Tensor], Iterable[tuple[str, Tensor]]],
) -> dict[str, list[str]]:
"""
named_params is a dictionary of tensors: {'A': A, 'B': B}
tied_named_params is another dictionary of tensors {'A': A, 'B': B, 'B_tied': B}
@ -59,7 +49,7 @@ def create_names_map(
tied_tensors_dict_keys = set(tied_named_params.keys())
assert tensors_dict_keys.issubset(tied_tensors_dict_keys)

tensor_to_mapping: Dict[Tensor, Tuple[str, List[str]]] = {}
tensor_to_mapping: dict[Tensor, tuple[str, list[str]]] = {}
for key, tensor in named_params.items():
tensor_to_mapping[tensor] = (key, [])
for key, tensor in tied_named_params.items():
@ -70,9 +60,9 @@ def create_names_map(

def _extract_members(
mod: nn.Module,
named_members: Callable[..., Iterable[Tuple[str, Tensor]]],
named_members: Callable[..., Iterable[tuple[str, Tensor]]],
subclass: Callable[[Tensor], Tensor],
) -> Tuple[Tuple[Tensor, ...], Tuple[str, ...], Dict[str, List[str]]]:
) -> tuple[tuple[Tensor, ...], tuple[str, ...], dict[str, list[str]]]:
all_named_members = tuple(named_members(remove_duplicate=False))
unique_named_members = tuple(named_members(remove_duplicate=True))
names_map = create_names_map(unique_named_members, all_named_members)
@ -95,7 +85,7 @@ def _extract_members(

def extract_weights(
mod: nn.Module,
) -> Tuple[Tuple[Tensor, ...], Tuple[str, ...], Dict[str, List[str]]]:
) -> tuple[tuple[Tensor, ...], tuple[str, ...], dict[str, list[str]]]:
"""
This function removes all the Parameters from the model and
return them as a tuple as well as their original attribute names.
@ -109,7 +99,7 @@ def extract_weights(

def extract_buffers(
mod: nn.Module,
) -> Tuple[Tuple[Tensor, ...], Tuple[str, ...], Dict[str, List[str]]]:
) -> tuple[tuple[Tensor, ...], tuple[str, ...], dict[str, list[str]]]:
return _extract_members(mod, mod.named_buffers, lambda x: x)

@ -131,9 +121,9 @@ def load_weights(

def _swap_state(
mod: nn.Module, names_map: Dict[str, List[str]], elems: Iterable[Tensor]
) -> List[Tensor]:
result: List[Tensor] = []
mod: nn.Module, names_map: dict[str, list[str]], elems: Iterable[Tensor]
) -> list[Tensor]:
result: list[Tensor] = []
accessor = NamedMemberAccessor(mod)
for (_, attr_names), elem in zip(names_map.items(), elems):
for i, attr_name in enumerate(attr_names):
@ -261,10 +251,10 @@ class FunctionalModuleWithBuffers(nn.Module):
def __init__(
self,
stateless_model: nn.Module,
param_names: Tuple[str, ...],
buffer_names: Tuple[str, ...],
param_names_map: Dict[str, List[str]],
buffer_names_map: Dict[str, List[str]],
param_names: tuple[str, ...],
buffer_names: tuple[str, ...],
param_names_map: dict[str, list[str]],
buffer_names_map: dict[str, list[str]],
) -> None:
super().__init__()
self.stateless_model = stateless_model
@ -277,7 +267,7 @@ class FunctionalModuleWithBuffers(nn.Module):
@staticmethod
def _create_from(
model: nn.Module, disable_autograd_tracking: bool = False
) -> Tuple["FunctionalModuleWithBuffers", Tuple[Tensor, ...], Tuple[Tensor, ...]]:
) -> tuple["FunctionalModuleWithBuffers", tuple[Tensor, ...], tuple[Tensor, ...]]:
# TODO: We don't need to copy the model to create a stateless copy
model_copy = copy.deepcopy(model)
params, param_names, param_names_map = extract_weights(model_copy)
@ -317,8 +307,8 @@ class FunctionalModule(nn.Module):
def __init__(
self,
stateless_model: nn.Module,
param_names: Tuple[str, ...],
names_map: Dict[str, List[str]],
param_names: tuple[str, ...],
names_map: dict[str, list[str]],
) -> None:
super().__init__()
self.stateless_model = stateless_model
@ -328,7 +318,7 @@ class FunctionalModule(nn.Module):
@staticmethod
def _create_from(
model: nn.Module, disable_autograd_tracking: bool = False
) -> Tuple["FunctionalModule", Tuple[Tensor, ...]]:
) -> tuple["FunctionalModule", tuple[Tensor, ...]]:
# TODO: We don't need to copy the model to create a stateless copy
model_copy = copy.deepcopy(model)
params, param_names, names_map = extract_weights(model_copy)
@ -349,7 +339,7 @@ class FunctionalModule(nn.Module):

def make_functional(
model: nn.Module, disable_autograd_tracking: bool = False
) -> Tuple[FunctionalModule, Tuple[Tensor, ...]]:
) -> tuple[FunctionalModule, tuple[Tensor, ...]]:
"""make_functional(model, disable_autograd_tracking=False) -> func, params

Given a ``torch.nn.Module``, :func:`make_functional` extracts the state
@ -419,7 +409,7 @@ def make_functional(

def make_functional_with_buffers(
model: nn.Module, disable_autograd_tracking: bool = False
) -> Tuple[FunctionalModuleWithBuffers, Tuple[Tensor, ...], Tuple[Tensor, ...]]:
) -> tuple[FunctionalModuleWithBuffers, tuple[Tensor, ...], tuple[Tensor, ...]]:
"""make_functional_with_buffers(model, disable_autograd_tracking=False) -> func, params, buffers

Given a ``torch.nn.Module``, make_functional_with_buffers extracts the
@ -479,8 +469,8 @@ def make_functional_with_buffers(

def transpose_stack(
tuple_of_tuple_of_tensors: Tuple[Tuple[Tensor, ...], ...]
) -> Tuple[Tensor, ...]:
tuple_of_tuple_of_tensors: tuple[tuple[Tensor, ...], ...]
) -> tuple[Tensor, ...]:
tuple_of_tuple_of_tensors = tuple(zip(*tuple_of_tuple_of_tensors))
results = tuple(
torch.stack(shards).detach() for shards in tuple_of_tuple_of_tensors
@ -490,7 +480,7 @@ def transpose_stack(

def combine_state_for_ensemble(
models: Sequence[nn.Module],
) -> Tuple[FunctionalModuleWithBuffers, Tuple[Tensor, ...], Tuple[Tensor, ...]]:
) -> tuple[FunctionalModuleWithBuffers, tuple[Tensor, ...], tuple[Tensor, ...]]:
"""combine_state_for_ensemble(models) -> func, params, buffers

Prepares a list of torch.nn.Modules for ensembling with :func:`vmap`.
@ -551,8 +541,8 @@ def combine_state_for_ensemble(

def functional_init(
model_class: Type[nn.Module],
ensemble_shape: Union[Tuple[()], Tuple[int]] = (),
model_class: type[nn.Module],
ensemble_shape: Union[tuple[()], tuple[int]] = (),
device: torch.types.Device = "cpu",
):
def wrapped(*args, **kwargs):
@ -578,8 +568,8 @@ def functional_init(

def functional_init_with_buffers(
model_class: Type[nn.Module],
ensemble_shape: Union[Tuple[()], Tuple[int]] = (),
model_class: type[nn.Module],
ensemble_shape: Union[tuple[()], tuple[int]] = (),
device: torch.types.Device = "cpu",
):
def wrapped(*args, **kwargs):
@ -9,7 +9,7 @@ import operator
import os
from collections import defaultdict
from dataclasses import dataclass, replace
from typing import Callable, Dict, List, Optional, Set, Tuple, TYPE_CHECKING, Union
from typing import Callable, Optional, TYPE_CHECKING, Union

import torch
import torch._inductor.inductor_prims
@ -55,11 +55,11 @@ prims = torch.ops.prims
class OpTypes:
"""Class for keeping track of different operator categories"""

fusible_ops: Set[Callable]
compute_intensive_ops: Set[Callable]
random_ops: Set[Callable]
view_ops: Set[Callable]
recomputable_ops: Set[Callable]
fusible_ops: set[Callable]
compute_intensive_ops: set[Callable]
random_ops: set[Callable]
view_ops: set[Callable]
recomputable_ops: set[Callable]

def is_fusible(self, node: fx.Node):
return get_aten_target(node) in self.fusible_ops
@ -81,14 +81,14 @@ class OpTypes:
class NodeInfo:
# Be careful about iterating over these explicitly, as their order may not
# be deterministic
inputs: List[fx.Node]
_required_fw_nodes: Set[fx.Node]
required_bw_nodes: Set[fx.Node]
unclaimed_nodes: Set[fx.Node]
fw_order: Dict[fx.Node, int]
inputs: list[fx.Node]
_required_fw_nodes: set[fx.Node]
required_bw_nodes: set[fx.Node]
unclaimed_nodes: set[fx.Node]
fw_order: dict[fx.Node, int]

@functools.cached_property
def required_fw_nodes(self) -> List[fx.Node]:
def required_fw_nodes(self) -> list[fx.Node]:
return sorted(
(n for n in self._required_fw_nodes), key=lambda n: self.fw_order[n]
)
@ -158,8 +158,8 @@ InvalidNode = InvalidNodeBase()

def _extract_graph_with_inputs_outputs(
joint_graph: fx.Graph,
inputs: List[fx.Node],
outputs: List[fx.Node],
inputs: list[fx.Node],
outputs: list[fx.Node],
subgraph: Optional[str] = None,
) -> fx.Graph:
"""
@ -272,7 +272,7 @@ def _must_be_in_backward(node: fx.Node) -> bool:

def _extract_fwd_bwd_outputs(
joint_module: fx.GraphModule, *, num_fwd_outputs
) -> Tuple[List[fx.Node], List[fx.Node]]:
) -> tuple[list[fx.Node], list[fx.Node]]:
outputs = pytree.arg_tree_leaves(
*(node.args for node in joint_module.graph.find_nodes(op="output"))
)
@ -281,7 +281,7 @@ def _extract_fwd_bwd_outputs(
return fwd_outputs, bwd_outputs

def _remove_by_name(saved_values: List[fx.Node], name: str):
def _remove_by_name(saved_values: list[fx.Node], name: str):
for saved_value in saved_values:
if saved_value.name == name:
saved_values.remove(saved_value)
@ -290,11 +290,11 @@ def _remove_by_name(saved_values: List[fx.Node], name: str):

def _extract_fwd_bwd_modules(
joint_module: fx.GraphModule,
saved_values: List[fx.Node],
saved_sym_nodes: List[fx.Node],
saved_values: list[fx.Node],
saved_sym_nodes: list[fx.Node],
*,
num_fwd_outputs: int,
) -> Tuple[fx.GraphModule, fx.GraphModule]:
) -> tuple[fx.GraphModule, fx.GraphModule]:
fwd_outputs, bwd_outputs = _extract_fwd_bwd_outputs(
joint_module, num_fwd_outputs=num_fwd_outputs
)
@ -326,7 +326,7 @@ def _extract_fwd_bwd_modules(
# we propagate all symbols which are referenced by backwards inputs.
# These are not directly used in the graph but are required for downstream
# sizevar assignment
saved_symbols: Set[sympy.Symbol] = set()
saved_symbols: set[sympy.Symbol] = set()
saved_sym_nodes_binding = []
saved_sym_nodes_derived = []

@ -389,7 +389,7 @@ def _extract_fwd_bwd_modules(

def default_partition(
joint_module: fx.GraphModule, _joint_inputs, *, num_fwd_outputs
) -> Tuple[fx.GraphModule, fx.GraphModule]:
) -> tuple[fx.GraphModule, fx.GraphModule]:
"""
Partitions the :attr:`joint_module` in a manner that closely resembles the
behavior observed in the original ``.forward()`` and ``.backward()`` of the
@ -512,7 +512,7 @@ def _size_of(node: fx.Node) -> int:
def _count_ops(graph: fx.Graph):
from collections import defaultdict

cnt: Dict[str, int] = defaultdict(int)
cnt: dict[str, int] = defaultdict(int)
for node in graph.nodes:
if node.op == "call_function":
cnt[node.target.__name__] += 1
@ -537,7 +537,7 @@ def pointwise_ops():
return ops

def sort_depths(args, depth_map: Dict[fx.Node, int]) -> List[Tuple[fx.Node, int]]:
def sort_depths(args, depth_map: dict[fx.Node, int]) -> list[tuple[fx.Node, int]]:
arg_depths = {
arg: depth_map[arg] for arg in args if isinstance(arg, torch.fx.node.Node)
}
@ -568,7 +568,7 @@ def reordering_to_mimic_autograd_engine(gm: fx.GraphModule) -> fx.GraphModule:
"""

new_graph = fx.Graph()
env: Dict[fx.Node, fx.Node] = {}
env: dict[fx.Node, fx.Node] = {}

# Add new placeholder nodes in the order specified by the inputs
for node in gm.graph.find_nodes(op="placeholder"):
@ -623,7 +623,7 @@ def functionalize_rng_ops(
fw_module: fx.GraphModule,
bw_module: fx.GraphModule,
num_sym_nodes: int,
) -> Tuple[fx.GraphModule, fx.GraphModule]:
) -> tuple[fx.GraphModule, fx.GraphModule]:
# During user-driven activation checkpointing, we have to ensure that a rng
# op in fwd yields the same output as the recomputed rng op in the bwd. To
# do this, we use functionalize wrappers to wrap the random ops and share
@ -1074,12 +1074,12 @@ def solve_min_cut(
# backwards pass instead of only relying on whether it's unfusible in the
# forwards.

def find_first_unfusible(start_nodes: List[fx.Node], max_range: int) -> int:
def find_first_unfusible(start_nodes: list[fx.Node], max_range: int) -> int:
"""
Finds the first unfusible node in the chain of nodes starting from
`start_nodes` and returns its position.
"""
sorted_nodes: List[Tuple[int, fx.Node, bool]] = []
sorted_nodes: list[tuple[int, fx.Node, bool]] = []
for n in start_nodes:
heapq.heappush(sorted_nodes, (node_info.get_fw_order(n), n, True))

@ -1182,7 +1182,7 @@ def solve_min_cut(
raise

reachable, non_reachable = partition
cutset: Set[Tuple[str, str]] = set()
cutset: set[tuple[str, str]] = set()
for u, nbrs in ((n, nx_graph[n]) for n in reachable):
cutset.update((u, v) for v in nbrs if v in non_reachable)

@ -1219,7 +1219,7 @@ def visualize_min_cut_graph(nx_graph):

def get_default_op_list() -> OpTypes:
default_recomputable_ops: List[Callable] = [
default_recomputable_ops: list[Callable] = [
aten.add,
aten.sub,
aten.div,
@ -1392,12 +1392,12 @@ def get_name_to_node(graph: fx.Graph):

def _optimize_runtime_with_given_memory(
joint_graph: fx.Graph,
memory: List[float],
runtimes: List[float],
memory: list[float],
runtimes: list[float],
max_memory: float,
node_info: NodeInfo,
all_recomputable_banned_nodes: List[fx.Node],
) -> Tuple[float, List[int], List[int]]:
all_recomputable_banned_nodes: list[fx.Node],
) -> tuple[float, list[int], list[int]]:
SOLVER = config.activation_memory_budget_solver
if SOLVER == "greedy":
return greedy_knapsack(memory, runtimes, max_memory)
@ -1490,7 +1490,7 @@ def choose_saved_values_set(
joint_graph: fx.Graph,
node_info: NodeInfo,
memory_budget=1,
) -> List[fx.Node]:
) -> list[fx.Node]:
if memory_budget > 1 or memory_budget < 0:
raise RuntimeError(
f"The valid ranges for memory budget are 0 <= m <= 1. The provided value is {memory_budget}"
@ -1523,7 +1523,7 @@ def choose_saved_values_set(
if memory_budget == 1:
return runtime_optimized_saved_values

def estimate_activations_size(saved_values: List[fx.Node]) -> float:
def estimate_activations_size(saved_values: list[fx.Node]) -> float:
return sum(map(_size_of, saved_values)) / 1e9

min_act_size = estimate_activations_size(node_info.inputs)
@ -1535,7 +1535,7 @@ def choose_saved_values_set(
def get_normalized_size(sz):
return (sz / 1e9) / (max_act_size - min_act_size)

def get_mem_ratio(activations: List[fx.Node]):
def get_mem_ratio(activations: list[fx.Node]):
return (estimate_activations_size(activations) - min_act_size) / (
max_act_size - min_act_size
)
@ -1567,7 +1567,7 @@ def choose_saved_values_set(

input_storages = {get_node_storage(node) for node in node_info.inputs}

def get_recomputable_banned_nodes(banned_nodes: Set[fx.Node]) -> List[fx.Node]:
def get_recomputable_banned_nodes(banned_nodes: set[fx.Node]) -> list[fx.Node]:
return [
i
for i in banned_nodes
@ -1729,7 +1729,7 @@ def min_cut_rematerialization_partition(
compiler="inductor",
*,
num_fwd_outputs,
) -> Tuple[fx.GraphModule, fx.GraphModule]:
) -> tuple[fx.GraphModule, fx.GraphModule]:
"""
Partitions the joint graph such that the backward recomputes the forward.
Recomputing helps in trading off memory bandwidth with computation.
@ -1799,7 +1799,7 @@ def min_cut_rematerialization_partition(
forward_only_graph = _extract_graph_with_inputs_outputs(
joint_module.graph, inputs, fwd_outputs, "forward"
)
required_fw_nodes: Set[fx.Node] = {
required_fw_nodes: set[fx.Node] = {
name_to_node[node.name]
for node in forward_only_graph.nodes
if node.op != "output"
@ -1886,7 +1886,7 @@ def min_cut_rematerialization_partition(
}
remat_nodes = fw_module_nodes & bw_module_nodes

counts: Dict[str, int] = defaultdict(int)
counts: dict[str, int] = defaultdict(int)
for node in fw_module.graph.nodes:
if node.name in remat_nodes and hasattr(node.target, "_overloadpacket"):
counts[str(node.target._overloadpacket)] += 1
@ -1906,7 +1906,7 @@ def draw_graph(
fname: str,
figname: str = "fx_graph",
clear_meta: bool = True,
prog: Optional[Union[str, List[str]]] = None,
prog: Optional[Union[str, list[str]]] = None,
parse_stack_trace: bool = False,
dot_graph_shape: Optional[str] = None,
) -> None:
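
The _optimize_runtime_with_given_memory hunk above dispatches to a knapsack-style solver (e.g. greedy_knapsack) to decide which activations to keep within the memory budget. A toy greedy variant with the same return shape, shown purely as an illustration and not the solver that ships in torch._functorch:

def toy_greedy_knapsack(
    memory: list[float], runtimes: list[float], max_memory: float
) -> tuple[float, list[int], list[int]]:
    # Save the activations that are most expensive to recompute per unit of
    # memory first; whatever no longer fits is recomputed in the backward.
    order = sorted(
        range(len(memory)),
        key=lambda i: runtimes[i] / max(memory[i], 1e-9),
        reverse=True,
    )
    recompute_runtime, used = 0.0, 0.0
    saved, recomputed = [], []
    for i in order:
        if used + memory[i] <= max_memory:
            used += memory[i]
            saved.append(i)
        else:
            recompute_runtime += runtimes[i]
            recomputed.append(i)
    return recompute_runtime, sorted(saved), sorted(recomputed)


print(toy_greedy_knapsack([1.0, 2.0, 0.5], [0.3, 0.1, 0.4], max_memory=2.0))
# (0.1, [0, 2], [1])
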
@ -1,7 +1,7 @@
# mypy: allow-untyped-defs
import contextlib
from abc import ABC, abstractmethod
from typing import Any, List, Tuple
from typing import Any

import torch
import torch.utils._pytree as pytree
@ -258,14 +258,14 @@ def retrieve_current_functorch_interpreter() -> FuncTorchInterpreter:
return coerce_cinterpreter(interpreter)

def retrieve_all_functorch_interpreters() -> List[FuncTorchInterpreter]:
def retrieve_all_functorch_interpreters() -> list[FuncTorchInterpreter]:
cis = torch._C._functorch.get_interpreter_stack()
if cis is None:
return []
return [coerce_cinterpreter(ci) for ci in cis]

def compare_functorch_state(states: List[Tuple[Any, ...]]) -> bool:
def compare_functorch_state(states: list[tuple[Any, ...]]) -> bool:
# There are four possible cases covered here:
# 1. Current stack empty AND stack when generated not empty -> Invalidate
# 2. Current stack not empty AND stack when generated empty -> Invalidate
@ -1,5 +1,6 @@
import contextlib
from typing import Any, Generator, Tuple, Union
from collections.abc import Generator
from typing import Any, Union

import torch
from torch._C._functorch import (
@ -28,7 +29,7 @@ def enable_single_level_autograd_function() -> Generator[None, None, None]:
set_single_level_autograd_function_allowed(prev_state)

def unwrap_dead_wrappers(args: Tuple[Any, ...]) -> Tuple[Any, ...]:
def unwrap_dead_wrappers(args: tuple[Any, ...]) -> tuple[Any, ...]:
# NB: doesn't use tree_map_only for performance reasons
result = tuple(
unwrap_if_dead(arg) if isinstance(arg, torch.Tensor) else arg for arg in args
@ -36,4 +37,4 @@ def unwrap_dead_wrappers(args: Tuple[Any, ...]) -> Tuple[Any, ...]:
return result

argnums_t = Union[int, Tuple[int, ...]]
argnums_t = Union[int, tuple[int, ...]]
@ -12,7 +12,7 @@ import itertools
import os
import threading
from functools import partial
from typing import Any, Callable, List, Optional, Tuple, Union
from typing import Any, Callable, Optional, Union

import torch
from torch import Tensor
@ -32,8 +32,8 @@ from torch.utils._pytree import (
)

in_dims_t = Union[int, Tuple]
out_dims_t = Union[int, Tuple[int, ...]]
in_dims_t = Union[int, tuple]
out_dims_t = Union[int, tuple[int, ...]]

def doesnt_support_saved_tensors_hooks(f):
@ -52,7 +52,7 @@ def doesnt_support_saved_tensors_hooks(f):

# Checks that all args-to-be-batched have the same batch dim size
def _validate_and_get_batch_size(
flat_in_dims: List[Optional[int]], flat_args: List
flat_in_dims: list[Optional[int]], flat_args: list
) -> int:
batch_sizes = [
arg.size(in_dim)
@ -69,7 +69,7 @@ def _validate_and_get_batch_size(
return batch_sizes[0]

def _num_outputs(batched_outputs: Union[Tensor, Tuple[Tensor, ...]]) -> int:
def _num_outputs(batched_outputs: Union[Tensor, tuple[Tensor, ...]]) -> int:
if isinstance(batched_outputs, tuple):
return len(batched_outputs)
return 1
@ -81,7 +81,7 @@ def _num_outputs(batched_outputs: Union[Tensor, Tuple[Tensor, ...]]) -> int:

def _as_tuple(
value: Any, num_elements: int, error_message_lambda: Callable[[], str]
) -> Tuple:
) -> tuple:
if not isinstance(value, tuple):
return (value,) * num_elements
if len(value) != num_elements:
@ -90,8 +90,8 @@ def _as_tuple(

def _process_batched_inputs(
in_dims: in_dims_t, args: Tuple, func: Callable
) -> Tuple[int, List[Any], List[Any], TreeSpec]:
in_dims: in_dims_t, args: tuple, func: Callable
) -> tuple[int, list[Any], list[Any], TreeSpec]:
if not isinstance(in_dims, int) and not isinstance(in_dims, tuple):
raise ValueError(
f"vmap({_get_name(func)}, in_dims={in_dims}, ...)(<inputs>): "
@ -152,8 +152,8 @@ def _process_batched_inputs(

def _create_batched_inputs(
flat_in_dims: List[Any], flat_args: List[Any], vmap_level: int, args_spec
) -> Tuple:
flat_in_dims: list[Any], flat_args: list[Any], vmap_level: int, args_spec
) -> tuple:
# See NOTE [Ignored _remove_batch_dim, _add_batch_dim]
batched_inputs = [
arg if in_dim is None else _add_batch_dim(arg, in_dim, vmap_level)
@ -186,12 +186,12 @@ def _maybe_remove_batch_dim(name, batched_output, vmap_level, batch_size, out_di

# Undoes the batching (and any batch dimensions) associated with the `vmap_level`.
def _unwrap_batched(
batched_outputs: Union[Tensor, Tuple[Tensor, ...]],
batched_outputs: Union[Tensor, tuple[Tensor, ...]],
out_dims: out_dims_t,
vmap_level: int,
batch_size: int,
func: Callable,
) -> Tuple:
) -> tuple:
flat_batched_outputs, output_spec = tree_flatten(batched_outputs)

def incompatible_error():
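
The in_dims_t/out_dims_t aliases above describe what torch.vmap accepts for its in_dims/out_dims arguments; a small usage example with arbitrary shapes:

import torch


def dot(x, w):
    return (x * w).sum(-1)


xs = torch.randn(5, 3)  # batched along dim 0
w = torch.randn(3)      # shared, not batched
out = torch.vmap(dot, in_dims=(0, None), out_dims=0)(xs, w)
print(out.shape)  # torch.Size([5])
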