PEP585 update - torch/_inductor (#145198)

See #145101 for details.
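For background: PEP 585 (Python 3.9) made the builtin containers generic, deprecating the typing aliases such as typing.Dict and typing.List. The change here mechanically rewrites those aliases to the builtins. A minimal sketch of the pattern, with an illustrative function that is not taken from the diff:

    # Before: container generics imported from typing
    from typing import Dict, List

    def tally(names: List[str]) -> Dict[str, int]:
        counts: Dict[str, int] = {}
        for name in names:
            counts[name] = counts.get(name, 0) + 1
        return counts

    # After (PEP 585): the builtins are subscriptable at runtime on
    # Python 3.9+; `from __future__ import annotations` makes the same
    # spelling safe in annotations on earlier 3.x as well.
    def tally(names: list[str]) -> dict[str, int]:
        counts: dict[str, int] = {}
        for name in names:
            counts[name] = counts.get(name, 0) + 1
        return counts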

Pull Request resolved: https://github.com/pytorch/pytorch/pull/145198
Approved by: https://github.com/bobrenjc93
Author: Aaron Orenstein
Date: 2025-01-20 12:27:30 -08:00
Committed by: PyTorch MergeBot
Parent: 2f9d378f7b
Commit: bac62341eb

34 changed files with 494 additions and 545 deletions


@@ -4,7 +4,7 @@ import collections
 import dataclasses
 import heapq
 import logging
-from typing import Callable, Dict, List, TYPE_CHECKING, TypedDict, Union
+from typing import Callable, TYPE_CHECKING, TypedDict, Union
 
 from torch._utils_internal import signpost_event
 from torch.utils._ordered_set import OrderedSet
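Note that Callable is still imported from typing in the hunk above; this commit migrates only the container generics. PEP 585 also covers Callable via collections.abc, which has been subscriptable since Python 3.9. A small illustrative sketch (not part of this diff):

    from collections.abc import Callable

    def apply_twice(fn: Callable[[int], int], x: int) -> int:
        # collections.abc.Callable subscripts just like typing.Callable
        return fn(fn(x))

    print(apply_twice(lambda v: v + 1, 0))  # 2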
@@ -61,9 +61,9 @@ class FreeableInputBuffer:
 
 
 def get_freeable_input_buf(
-    nodes: List[BaseSchedulerNode],
+    nodes: list[BaseSchedulerNode],
     graph_inputs: OrderedSet[str],
-) -> Dict[str, FreeableInputBuffer]:
+) -> dict[str, FreeableInputBuffer]:
     """
     Create and keep track of all input buffers that can be freed during the program
@@ -87,10 +87,10 @@ def get_freeable_input_buf(
 
     # get freeable input buffers' successor nodes and their sizes
     # note that different deps can have the same name, so we use name as keys
-    dep_name_to_succ_nodes: Dict[
+    dep_name_to_succ_nodes: dict[
         str, OrderedSet[BaseSchedulerNode]
     ] = collections.defaultdict(OrderedSet)
-    dep_name_to_size: Dict[str, int] = dict()
+    dep_name_to_size: dict[str, int] = dict()
     for node in nodes:
         for dep in node.read_writes.reads:
             if dep.name in graph_inputs and not dep.name.startswith(
@@ -100,7 +100,7 @@ def get_freeable_input_buf(
                 dep_name_to_size[dep.name] = _dep_size_hint(dep)
 
     # create FreeableInputBuffer objects and add them to the returned dictionary
-    name_to_freeable_input_buf: Dict[str, FreeableInputBuffer] = dict()
+    name_to_freeable_input_buf: dict[str, FreeableInputBuffer] = dict()
     for dep_name, succ_nodes in dep_name_to_succ_nodes.items():
         name_to_freeable_input_buf[dep_name] = FreeableInputBuffer(
             dep_name,
@@ -112,8 +112,8 @@ def get_freeable_input_buf(
 
 
 def compute_size_for_scheduler_buffer(
-    name_to_buf: Dict[str, SchedulerBuffer]
-) -> Dict[str, tuple[int, int]]:
+    name_to_buf: dict[str, SchedulerBuffer]
+) -> dict[str, tuple[int, int]]:
     """
     Compute the size of each scheduler buffer, including (1) memory allocated when
     it is created and (2) memory deallocated when it is freed.
@@ -134,7 +134,7 @@ def compute_size_for_scheduler_buffer(
     from .ir import MultiOutput
     from .scheduler import OutputNode
 
-    sched_buf_to_size: Dict[str, tuple[int, int]] = dict()
+    sched_buf_to_size: dict[str, tuple[int, int]] = dict()
 
     def _compute_and_update_buf_size(
         sched_buf: SchedulerBuffer, user_of_MultiOutputLayout: bool = False
@@ -175,8 +175,8 @@ def compute_size_for_scheduler_buffer(
 
 
 def assign_memory_planning_info_for_scheduler_buffers(
-    nodes: List[BaseSchedulerNode],
-    name_to_buf: Dict[str, SchedulerBuffer],
+    nodes: list[BaseSchedulerNode],
+    name_to_buf: dict[str, SchedulerBuffer],
 ) -> None:
     """
     For each SchedulerBuffer, assign its size info and successor nodes.
@@ -187,7 +187,7 @@ def assign_memory_planning_info_for_scheduler_buffers(
 
     # get buffer's successor nodes
     # note that different deps can have the same name, so we use name as keys
-    dep_name_to_succ_nodes: Dict[
+    dep_name_to_succ_nodes: dict[
         str, OrderedSet[BaseSchedulerNode]
     ] = collections.defaultdict(OrderedSet)
     for node in nodes:
@@ -205,10 +205,10 @@ def assign_memory_planning_info_for_scheduler_buffers(
 
 
 def assign_memory_planning_info_for_scheduler_nodes(
-    nodes: List[BaseSchedulerNode],
-    name_to_fused_node: Dict[str, BaseSchedulerNode],
-    name_to_buf: Dict[str, SchedulerBuffer],
-    name_to_freeable_input_buf: Dict[str, FreeableInputBuffer],
+    nodes: list[BaseSchedulerNode],
+    name_to_fused_node: dict[str, BaseSchedulerNode],
+    name_to_buf: dict[str, SchedulerBuffer],
+    name_to_freeable_input_buf: dict[str, FreeableInputBuffer],
 ) -> None:
     """
     Assign to each scheduler node its predecessor and successor nodes.
@@ -243,10 +243,10 @@ def assign_memory_planning_info_for_scheduler_nodes(
 
 
 def estimate_peak_memory(
-    nodes: List[BaseSchedulerNode],
-    name_to_freeable_input_buf: Dict[str, FreeableInputBuffer],
+    nodes: list[BaseSchedulerNode],
+    name_to_freeable_input_buf: dict[str, FreeableInputBuffer],
     graph_outputs: OrderedSet[str],
-) -> tuple[int, List[int]]:
+) -> tuple[int, list[int]]:
     """
     Given a list of nodes in their execution order, estimate the peak memory, by
     keeping track of the liveliness of SchedulerBuffers and FreeableInputBuffers.
@@ -267,12 +267,12 @@ def estimate_peak_memory(
 
     # get the execution step of each node, this will be used to determine
     # the end_step of buffers
-    node_to_step: Dict[BaseSchedulerNode, int] = dict()
+    node_to_step: dict[BaseSchedulerNode, int] = dict()
     for step, node in enumerate(nodes):
         node_to_step[node] = step
 
     # get buffers' size and liveliness information
-    buf_info_list: List[BufferInfo] = []
+    buf_info_list: list[BufferInfo] = []
     # 1. for freeable input buffers
     for buf_name, input_buf in name_to_freeable_input_buf.items():
         end_step = (
@@ -340,11 +340,11 @@ def estimate_peak_memory(
 
 
 def topological_sort_lpmf(
-    nodes: List[BaseSchedulerNode],
-    name_to_freeable_input_buf: Dict[str, FreeableInputBuffer],
-    name_to_buf: Dict[str, SchedulerBuffer],
+    nodes: list[BaseSchedulerNode],
+    name_to_freeable_input_buf: dict[str, FreeableInputBuffer],
+    name_to_buf: dict[str, SchedulerBuffer],
     graph_outputs: OrderedSet[str],
-) -> List[BaseSchedulerNode]:
+) -> list[BaseSchedulerNode]:
     """
     A bfs-based greedy topological order. LPMF stands for "Least Peak Memory First".
@@ -372,8 +372,8 @@ def topological_sort_lpmf(
     class BufferInfo(TypedDict):
         outdegree: int
 
-    node_info: Dict[BaseSchedulerNode, NodeInfo] = dict()
-    buf_info: Dict[Union[SchedulerBuffer, FreeableInputBuffer], BufferInfo] = dict()
+    node_info: dict[BaseSchedulerNode, NodeInfo] = dict()
+    buf_info: dict[Union[SchedulerBuffer, FreeableInputBuffer], BufferInfo] = dict()
 
     # compute nodes' number of unmet dependencies (for schedulability)
     # initialize the list of nodes ready to be scheduled
@@ -422,7 +422,7 @@ def topological_sort_lpmf(
                 node_info[node]["memory_to_free"] += buf.mpi_buffer.size_free
 
     # schedule nodes one at a time
-    schedule: List[BaseSchedulerNode] = []
+    schedule: list[BaseSchedulerNode] = []
     num_iters: int = 0
     while num_iters < len(nodes) and nodes_to_schedule:
         # select a node to schedule:
@@ -464,7 +464,7 @@ def topological_sort_lpmf(
     return schedule
 
 
-def topological_sort_bfs(nodes: List[BaseSchedulerNode]) -> List[BaseSchedulerNode]:
+def topological_sort_bfs(nodes: list[BaseSchedulerNode]) -> list[BaseSchedulerNode]:
     """
     A BFS topological sort that selects nodes whose dependencies are executed the
     earliest. This follows a FIFO idea. Specifically, at every iteration, for each node
@@ -478,11 +478,11 @@ def topological_sort_bfs(nodes: List[BaseSchedulerNode]) -> List[BaseSchedulerNo
         indegree: int
         order: int
 
-    node_info: Dict[BaseSchedulerNode, NodeInfo] = dict()
+    node_info: dict[BaseSchedulerNode, NodeInfo] = dict()
 
     @dataclasses.dataclass
     class NodeWithPriority:
-        priority: List[int]
+        priority: list[int]
         node: BaseSchedulerNode
 
         def __lt__(self, other: NodeWithPriority) -> bool:
@@ -490,7 +490,7 @@ def topological_sort_bfs(nodes: List[BaseSchedulerNode]) -> List[BaseSchedulerNo
                 return self.node.mpi_node.index < other.node.mpi_node.index
             return self.priority < other.priority
 
-    def _node_priority(node: BaseSchedulerNode) -> List[int]:
+    def _node_priority(node: BaseSchedulerNode) -> list[int]:
         # priority is the order in which predecessor nodes are executed
         assert node_info[node]["indegree"] == 0
         exec_orders = sorted(
@@ -502,7 +502,7 @@ def topological_sort_bfs(nodes: List[BaseSchedulerNode]) -> List[BaseSchedulerNo
 
     # compute nodes' number of unmet dependencies (for schedulability)
     # initialize the list of nodes ready to be scheduled
-    nodes_to_schedule: List[NodeWithPriority] = []
+    nodes_to_schedule: list[NodeWithPriority] = []
     for node in nodes:
         node_info[node] = {"indegree": len(node.mpi_node.pred_nodes), "order": -1}
         if node_info[node]["indegree"] == 0:
@@ -511,7 +511,7 @@ def topological_sort_bfs(nodes: List[BaseSchedulerNode]) -> List[BaseSchedulerNo
             )
 
     # schedule nodes one at a time
-    schedule: List[BaseSchedulerNode] = []
+    schedule: list[BaseSchedulerNode] = []
     num_iters: int = 0
     while num_iters < len(nodes) and nodes_to_schedule:
         # select a node to schedule
@@ -536,7 +536,7 @@ def topological_sort_bfs(nodes: List[BaseSchedulerNode]) -> List[BaseSchedulerNo
     return schedule
 
 
-def topological_sort_dfs(nodes: List[BaseSchedulerNode]) -> List[BaseSchedulerNode]:
+def topological_sort_dfs(nodes: list[BaseSchedulerNode]) -> list[BaseSchedulerNode]:
     """
     This is a DFS topological sort. The setup is similar to `topological_sort_schedule`
     in scheduler.py. The difference is the order nodes are visited in the outer loop.
@@ -546,9 +546,9 @@ def topological_sort_dfs(nodes: List[BaseSchedulerNode]) -> List[BaseSchedulerNo
     the nodes in ascending order of this priority.
     """
     seen: OrderedSet[BaseSchedulerNode] = OrderedSet()
-    name_to_node: Dict[str, BaseSchedulerNode] = dict()
-    result: List[BaseSchedulerNode] = []
-    size_with_reads: Dict[BaseSchedulerNode, int] = dict()
+    name_to_node: dict[str, BaseSchedulerNode] = dict()
+    result: list[BaseSchedulerNode] = []
+    size_with_reads: dict[BaseSchedulerNode, int] = dict()
 
     def visit(n: BaseSchedulerNode) -> None:
         if n not in seen:
@@ -579,17 +579,17 @@ def topological_sort_dfs(nodes: List[BaseSchedulerNode]) -> List[BaseSchedulerNo
 
 
 def reorder_for_peak_memory(
-    nodes: List[BaseSchedulerNode],
-    name_to_buf: Dict[str, SchedulerBuffer],
-    name_to_fused_node: Dict[str, BaseSchedulerNode],
+    nodes: list[BaseSchedulerNode],
+    name_to_buf: dict[str, SchedulerBuffer],
+    name_to_fused_node: dict[str, BaseSchedulerNode],
     graph_inputs: OrderedSet[str],
     graph_outputs: OrderedSet[str],
-    methods: List[Callable[..., List[BaseSchedulerNode]]] = [  # noqa: B006
+    methods: list[Callable[..., list[BaseSchedulerNode]]] = [  # noqa: B006
         topological_sort_lpmf,
         topological_sort_bfs,
         topological_sort_dfs,
     ],
-) -> List[BaseSchedulerNode]:
+) -> list[BaseSchedulerNode]:
     """
     Try a few heuristics based topological sort algorithms, and pick the one whose
     resulting topological order has the lowest peak memory estimation.
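The `# noqa: B006` in the hunk above suppresses flake8-bugbear's mutable-default-argument warning for `methods`: the default list of sort functions is only iterated, never mutated, so sharing a single list object across calls is intentional here. For contrast, a sketch of what B006 normally catches, using hypothetical helpers that are not from this diff:

    from typing import Optional

    def accumulate_bad(x: int, acc: list[int] = []) -> list[int]:  # noqa: B006
        acc.append(x)  # the one shared default list grows across calls
        return acc

    def accumulate_ok(x: int, acc: Optional[list[int]] = None) -> list[int]:
        if acc is None:
            acc = []  # fresh list per call
        acc.append(x)
        return acc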
@@ -599,13 +599,13 @@ def reorder_for_peak_memory(
 
     @dataclasses.dataclass
     class PeakMemoryResult:
-        order: List[BaseSchedulerNode]
+        order: list[BaseSchedulerNode]
         peak_memory: int
         method: str
 
     # preparation -- as nodes are scheduled one at a time, these help
     # keep track of when a buffer can be freed, and when a node can be scheduled
-    name_to_freeable_input_buf: Dict[str, FreeableInputBuffer] = get_freeable_input_buf(
+    name_to_freeable_input_buf: dict[str, FreeableInputBuffer] = get_freeable_input_buf(
        nodes, graph_inputs
     )
     assign_memory_planning_info_for_scheduler_buffers(nodes, name_to_buf)
@@ -614,7 +614,7 @@ def reorder_for_peak_memory(
     )
 
     # keep track of the peak memory estimates of different methods
-    peak_memory_diff_methods: List[PeakMemoryResult] = []
+    peak_memory_diff_methods: list[PeakMemoryResult] = []
 
     # the default
     estimated_peak_memory, _ = estimate_peak_memory(