mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
PEP585 update - torch/_inductor (#145198)
See #145101 for details. Pull Request resolved: https://github.com/pytorch/pytorch/pull/145198 Approved by: https://github.com/bobrenjc93
This commit is contained in:
committed by
PyTorch MergeBot
parent
2f9d378f7b
commit
bac62341eb
@ -4,7 +4,7 @@ import collections
|
||||
import dataclasses
|
||||
import heapq
|
||||
import logging
|
||||
from typing import Callable, Dict, List, TYPE_CHECKING, TypedDict, Union
|
||||
from typing import Callable, TYPE_CHECKING, TypedDict, Union
|
||||
|
||||
from torch._utils_internal import signpost_event
|
||||
from torch.utils._ordered_set import OrderedSet
|
||||
@ -61,9 +61,9 @@ class FreeableInputBuffer:
|
||||
|
||||
|
||||
def get_freeable_input_buf(
|
||||
nodes: List[BaseSchedulerNode],
|
||||
nodes: list[BaseSchedulerNode],
|
||||
graph_inputs: OrderedSet[str],
|
||||
) -> Dict[str, FreeableInputBuffer]:
|
||||
) -> dict[str, FreeableInputBuffer]:
|
||||
"""
|
||||
Create and keep track of all input buffers that can be freed during the program
|
||||
|
||||
@ -87,10 +87,10 @@ def get_freeable_input_buf(
|
||||
|
||||
# get freeable input buffers' successor nodes and their sizes
|
||||
# note that different deps can have the same name, so we use name as keys
|
||||
dep_name_to_succ_nodes: Dict[
|
||||
dep_name_to_succ_nodes: dict[
|
||||
str, OrderedSet[BaseSchedulerNode]
|
||||
] = collections.defaultdict(OrderedSet)
|
||||
dep_name_to_size: Dict[str, int] = dict()
|
||||
dep_name_to_size: dict[str, int] = dict()
|
||||
for node in nodes:
|
||||
for dep in node.read_writes.reads:
|
||||
if dep.name in graph_inputs and not dep.name.startswith(
|
||||
@ -100,7 +100,7 @@ def get_freeable_input_buf(
|
||||
dep_name_to_size[dep.name] = _dep_size_hint(dep)
|
||||
|
||||
# create FreeableInputBuffer objects and add them to the returned dictionary
|
||||
name_to_freeable_input_buf: Dict[str, FreeableInputBuffer] = dict()
|
||||
name_to_freeable_input_buf: dict[str, FreeableInputBuffer] = dict()
|
||||
for dep_name, succ_nodes in dep_name_to_succ_nodes.items():
|
||||
name_to_freeable_input_buf[dep_name] = FreeableInputBuffer(
|
||||
dep_name,
|
||||
@ -112,8 +112,8 @@ def get_freeable_input_buf(
|
||||
|
||||
|
||||
def compute_size_for_scheduler_buffer(
|
||||
name_to_buf: Dict[str, SchedulerBuffer]
|
||||
) -> Dict[str, tuple[int, int]]:
|
||||
name_to_buf: dict[str, SchedulerBuffer]
|
||||
) -> dict[str, tuple[int, int]]:
|
||||
"""
|
||||
Compute the size of each scheduler buffer, including (1) memory allocated when
|
||||
it is created and (2) memory deallocated when it is freed.
|
||||
@ -134,7 +134,7 @@ def compute_size_for_scheduler_buffer(
|
||||
from .ir import MultiOutput
|
||||
from .scheduler import OutputNode
|
||||
|
||||
sched_buf_to_size: Dict[str, tuple[int, int]] = dict()
|
||||
sched_buf_to_size: dict[str, tuple[int, int]] = dict()
|
||||
|
||||
def _compute_and_update_buf_size(
|
||||
sched_buf: SchedulerBuffer, user_of_MultiOutputLayout: bool = False
|
||||
@ -175,8 +175,8 @@ def compute_size_for_scheduler_buffer(
|
||||
|
||||
|
||||
def assign_memory_planning_info_for_scheduler_buffers(
|
||||
nodes: List[BaseSchedulerNode],
|
||||
name_to_buf: Dict[str, SchedulerBuffer],
|
||||
nodes: list[BaseSchedulerNode],
|
||||
name_to_buf: dict[str, SchedulerBuffer],
|
||||
) -> None:
|
||||
"""
|
||||
For each SchedulerBuffer, assign its size info and successor nodes.
|
||||
@ -187,7 +187,7 @@ def assign_memory_planning_info_for_scheduler_buffers(
|
||||
|
||||
# get buffer's successor nodes
|
||||
# note that different deps can have the same name, so we use name as keys
|
||||
dep_name_to_succ_nodes: Dict[
|
||||
dep_name_to_succ_nodes: dict[
|
||||
str, OrderedSet[BaseSchedulerNode]
|
||||
] = collections.defaultdict(OrderedSet)
|
||||
for node in nodes:
|
||||
@ -205,10 +205,10 @@ def assign_memory_planning_info_for_scheduler_buffers(
|
||||
|
||||
|
||||
def assign_memory_planning_info_for_scheduler_nodes(
|
||||
nodes: List[BaseSchedulerNode],
|
||||
name_to_fused_node: Dict[str, BaseSchedulerNode],
|
||||
name_to_buf: Dict[str, SchedulerBuffer],
|
||||
name_to_freeable_input_buf: Dict[str, FreeableInputBuffer],
|
||||
nodes: list[BaseSchedulerNode],
|
||||
name_to_fused_node: dict[str, BaseSchedulerNode],
|
||||
name_to_buf: dict[str, SchedulerBuffer],
|
||||
name_to_freeable_input_buf: dict[str, FreeableInputBuffer],
|
||||
) -> None:
|
||||
"""
|
||||
Assign to each scheduler node its predecessor and successor nodes.
|
||||
@ -243,10 +243,10 @@ def assign_memory_planning_info_for_scheduler_nodes(
|
||||
|
||||
|
||||
def estimate_peak_memory(
|
||||
nodes: List[BaseSchedulerNode],
|
||||
name_to_freeable_input_buf: Dict[str, FreeableInputBuffer],
|
||||
nodes: list[BaseSchedulerNode],
|
||||
name_to_freeable_input_buf: dict[str, FreeableInputBuffer],
|
||||
graph_outputs: OrderedSet[str],
|
||||
) -> tuple[int, List[int]]:
|
||||
) -> tuple[int, list[int]]:
|
||||
"""
|
||||
Given a list of nodes in their execution order, estimate the peak memory, by
|
||||
keeping track of the liveness of SchedulerBuffers and FreeableInputBuffers.
|
||||
@ -267,12 +267,12 @@ def estimate_peak_memory(
|
||||
|
||||
# get the execution step of each node, this will be used to determine
|
||||
# the end_step of buffers
|
||||
node_to_step: Dict[BaseSchedulerNode, int] = dict()
|
||||
node_to_step: dict[BaseSchedulerNode, int] = dict()
|
||||
for step, node in enumerate(nodes):
|
||||
node_to_step[node] = step
|
||||
|
||||
# get buffers' size and liveness information
|
||||
buf_info_list: List[BufferInfo] = []
|
||||
buf_info_list: list[BufferInfo] = []
|
||||
# 1. for freeable input buffers
|
||||
for buf_name, input_buf in name_to_freeable_input_buf.items():
|
||||
end_step = (
|
||||
@ -340,11 +340,11 @@ def estimate_peak_memory(
|
||||
|
||||
|
||||
def topological_sort_lpmf(
|
||||
nodes: List[BaseSchedulerNode],
|
||||
name_to_freeable_input_buf: Dict[str, FreeableInputBuffer],
|
||||
name_to_buf: Dict[str, SchedulerBuffer],
|
||||
nodes: list[BaseSchedulerNode],
|
||||
name_to_freeable_input_buf: dict[str, FreeableInputBuffer],
|
||||
name_to_buf: dict[str, SchedulerBuffer],
|
||||
graph_outputs: OrderedSet[str],
|
||||
) -> List[BaseSchedulerNode]:
|
||||
) -> list[BaseSchedulerNode]:
|
||||
"""
|
||||
A BFS-based greedy topological order. LPMF stands for "Least Peak Memory First".
|
||||
|
||||
@ -372,8 +372,8 @@ def topological_sort_lpmf(
|
||||
class BufferInfo(TypedDict):
|
||||
outdegree: int
|
||||
|
||||
node_info: Dict[BaseSchedulerNode, NodeInfo] = dict()
|
||||
buf_info: Dict[Union[SchedulerBuffer, FreeableInputBuffer], BufferInfo] = dict()
|
||||
node_info: dict[BaseSchedulerNode, NodeInfo] = dict()
|
||||
buf_info: dict[Union[SchedulerBuffer, FreeableInputBuffer], BufferInfo] = dict()
|
||||
|
||||
# compute nodes' number of unmet dependencies (for schedulability)
|
||||
# initialize the list of nodes ready to be scheduled
|
||||
@ -422,7 +422,7 @@ def topological_sort_lpmf(
|
||||
node_info[node]["memory_to_free"] += buf.mpi_buffer.size_free
|
||||
|
||||
# schedule nodes one at a time
|
||||
schedule: List[BaseSchedulerNode] = []
|
||||
schedule: list[BaseSchedulerNode] = []
|
||||
num_iters: int = 0
|
||||
while num_iters < len(nodes) and nodes_to_schedule:
|
||||
# select a node to schedule:
|
||||
@ -464,7 +464,7 @@ def topological_sort_lpmf(
|
||||
return schedule
|
||||
|
||||
|
||||
def topological_sort_bfs(nodes: List[BaseSchedulerNode]) -> List[BaseSchedulerNode]:
|
||||
def topological_sort_bfs(nodes: list[BaseSchedulerNode]) -> list[BaseSchedulerNode]:
|
||||
"""
|
||||
A BFS topological sort that selects nodes whose dependencies are executed the
|
||||
earliest. This follows a FIFO idea. Specifically, at every iteration, for each node
|
||||
@ -478,11 +478,11 @@ def topological_sort_bfs(nodes: List[BaseSchedulerNode]) -> List[BaseSchedulerNo
|
||||
indegree: int
|
||||
order: int
|
||||
|
||||
node_info: Dict[BaseSchedulerNode, NodeInfo] = dict()
|
||||
node_info: dict[BaseSchedulerNode, NodeInfo] = dict()
|
||||
|
||||
@dataclasses.dataclass
|
||||
class NodeWithPriority:
|
||||
priority: List[int]
|
||||
priority: list[int]
|
||||
node: BaseSchedulerNode
|
||||
|
||||
def __lt__(self, other: NodeWithPriority) -> bool:
|
||||
@ -490,7 +490,7 @@ def topological_sort_bfs(nodes: List[BaseSchedulerNode]) -> List[BaseSchedulerNo
|
||||
return self.node.mpi_node.index < other.node.mpi_node.index
|
||||
return self.priority < other.priority
|
||||
|
||||
def _node_priority(node: BaseSchedulerNode) -> List[int]:
|
||||
def _node_priority(node: BaseSchedulerNode) -> list[int]:
|
||||
# priority is the order in which predecessor nodes are executed
|
||||
assert node_info[node]["indegree"] == 0
|
||||
exec_orders = sorted(
|
||||
@ -502,7 +502,7 @@ def topological_sort_bfs(nodes: List[BaseSchedulerNode]) -> List[BaseSchedulerNo
|
||||
|
||||
# compute nodes' number of unmet dependencies (for schedulability)
|
||||
# initialize the list of nodes ready to be scheduled
|
||||
nodes_to_schedule: List[NodeWithPriority] = []
|
||||
nodes_to_schedule: list[NodeWithPriority] = []
|
||||
for node in nodes:
|
||||
node_info[node] = {"indegree": len(node.mpi_node.pred_nodes), "order": -1}
|
||||
if node_info[node]["indegree"] == 0:
|
||||
@ -511,7 +511,7 @@ def topological_sort_bfs(nodes: List[BaseSchedulerNode]) -> List[BaseSchedulerNo
|
||||
)
|
||||
|
||||
# schedule nodes one at a time
|
||||
schedule: List[BaseSchedulerNode] = []
|
||||
schedule: list[BaseSchedulerNode] = []
|
||||
num_iters: int = 0
|
||||
while num_iters < len(nodes) and nodes_to_schedule:
|
||||
# select a node to schedule
|
||||
@ -536,7 +536,7 @@ def topological_sort_bfs(nodes: List[BaseSchedulerNode]) -> List[BaseSchedulerNo
|
||||
return schedule
|
||||
|
||||
|
||||
def topological_sort_dfs(nodes: List[BaseSchedulerNode]) -> List[BaseSchedulerNode]:
|
||||
def topological_sort_dfs(nodes: list[BaseSchedulerNode]) -> list[BaseSchedulerNode]:
|
||||
"""
|
||||
This is a DFS topological sort. The setup is similar to `topological_sort_schedule`
|
||||
in scheduler.py. The difference is the order nodes are visited in the outer loop.
|
||||
@ -546,9 +546,9 @@ def topological_sort_dfs(nodes: List[BaseSchedulerNode]) -> List[BaseSchedulerNo
|
||||
the nodes in ascending order of this priority.
|
||||
"""
|
||||
seen: OrderedSet[BaseSchedulerNode] = OrderedSet()
|
||||
name_to_node: Dict[str, BaseSchedulerNode] = dict()
|
||||
result: List[BaseSchedulerNode] = []
|
||||
size_with_reads: Dict[BaseSchedulerNode, int] = dict()
|
||||
name_to_node: dict[str, BaseSchedulerNode] = dict()
|
||||
result: list[BaseSchedulerNode] = []
|
||||
size_with_reads: dict[BaseSchedulerNode, int] = dict()
|
||||
|
||||
def visit(n: BaseSchedulerNode) -> None:
|
||||
if n not in seen:
|
||||
@ -579,17 +579,17 @@ def topological_sort_dfs(nodes: List[BaseSchedulerNode]) -> List[BaseSchedulerNo
|
||||
|
||||
|
||||
def reorder_for_peak_memory(
|
||||
nodes: List[BaseSchedulerNode],
|
||||
name_to_buf: Dict[str, SchedulerBuffer],
|
||||
name_to_fused_node: Dict[str, BaseSchedulerNode],
|
||||
nodes: list[BaseSchedulerNode],
|
||||
name_to_buf: dict[str, SchedulerBuffer],
|
||||
name_to_fused_node: dict[str, BaseSchedulerNode],
|
||||
graph_inputs: OrderedSet[str],
|
||||
graph_outputs: OrderedSet[str],
|
||||
methods: List[Callable[..., List[BaseSchedulerNode]]] = [ # noqa: B006
|
||||
methods: list[Callable[..., list[BaseSchedulerNode]]] = [ # noqa: B006
|
||||
topological_sort_lpmf,
|
||||
topological_sort_bfs,
|
||||
topological_sort_dfs,
|
||||
],
|
||||
) -> List[BaseSchedulerNode]:
|
||||
) -> list[BaseSchedulerNode]:
|
||||
"""
|
||||
Try a few heuristic-based topological sort algorithms, and pick the one whose
|
||||
resulting topological order has the lowest peak memory estimation.
|
||||
@ -599,13 +599,13 @@ def reorder_for_peak_memory(
|
||||
|
||||
@dataclasses.dataclass
|
||||
class PeakMemoryResult:
|
||||
order: List[BaseSchedulerNode]
|
||||
order: list[BaseSchedulerNode]
|
||||
peak_memory: int
|
||||
method: str
|
||||
|
||||
# preparation -- as nodes are scheduled one at a time, these help
|
||||
# keep track of when a buffer can be freed, and when a node can be scheduled
|
||||
name_to_freeable_input_buf: Dict[str, FreeableInputBuffer] = get_freeable_input_buf(
|
||||
name_to_freeable_input_buf: dict[str, FreeableInputBuffer] = get_freeable_input_buf(
|
||||
nodes, graph_inputs
|
||||
)
|
||||
assign_memory_planning_info_for_scheduler_buffers(nodes, name_to_buf)
|
||||
@ -614,7 +614,7 @@ def reorder_for_peak_memory(
|
||||
)
|
||||
|
||||
# keep track of the peak memory estimates of different methods
|
||||
peak_memory_diff_methods: List[PeakMemoryResult] = []
|
||||
peak_memory_diff_methods: list[PeakMemoryResult] = []
|
||||
|
||||
# the default
|
||||
estimated_peak_memory, _ = estimate_peak_memory(
|
||||
|
Reference in New Issue
Block a user