During comms reordering, sink-waits-iterative observed that the previous runtime
estimations were significantly off for collectives and matmuls. This adds optional usage of:
- c10d.time_estimator for collectives, which is based on the NCCL estimator
- benchmark mode for matmuls only, as they are highly dependent on the mm backend
The logic is mostly copied from Ruisi's PRs for inductor simple_fsdp
https://github.com/pytorch/pytorch/pull/157572
These estimation corrections are applied in the default `BaseSchedulerNode.estimate_runtime()`.

Differential Revision: [D81152294](https://our.internmc.facebook.com/intern/diff/D81152294)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/161405
Approved by: https://github.com/eellison
# mypy: allow-untyped-defs
|
|
# pyre-strict
|
|
from __future__ import annotations
|
|
|
|
import heapq
|
|
import importlib
|
|
import itertools
|
|
import logging
|
|
import operator
|
|
import sys
|
|
import time
|
|
from collections import defaultdict
|
|
from dataclasses import dataclass
|
|
from typing import Any, Optional, TYPE_CHECKING, Union
|
|
|
|
import torch
|
|
from torch._logging import trace_structured
|
|
from torch.multiprocessing.reductions import StorageWeakRef
|
|
from torch.utils._ordered_set import OrderedSet
|
|
|
|
from . import config, ir
|
|
from .dependencies import WeakDep
|
|
|
|
|
|
if TYPE_CHECKING:
|
|
from .ir import IRNode, Operation
|
|
from .scheduler import SchedulerBuffer
|
|
|
|
from .memory import (
|
|
estimate_peak_memory,
|
|
estimate_peak_memory_allocfree,
|
|
FreeableInputBuffer,
|
|
get_freeable_input_buf,
|
|
SNodeMemory,
|
|
)
|
|
from .utils import (
|
|
contains_collective,
|
|
contains_wait,
|
|
find_recursive_deps_of_node,
|
|
find_recursive_users_of_node,
|
|
is_collective,
|
|
is_fallback_op,
|
|
is_wait,
|
|
)
|
|
from .virtualized import V
|
|
|
|
|
|
log = logging.getLogger(__name__)
|
|
overlap_log = torch._logging.getArtifactLogger(__name__, "overlap")
|
|
|
|
if TYPE_CHECKING:
|
|
from torch._inductor.scheduler import BaseSchedulerNode
|
|
|
|
|
|
def align_runtime_estimations_across_all_distributed_ranks(
|
|
snodes: list[BaseSchedulerNode],
|
|
):
|
|
runtime_estimations = {}
|
|
for snode in snodes:
|
|
runtime_estimations[snode] = snode.get_estimated_runtime()
|
|
import torch.distributed as dist
|
|
from torch.distributed.distributed_c10d import _get_default_group
|
|
|
|
world_size = dist.get_world_size()
|
|
pg = _get_default_group()
|
|
gathered_runtime_estimations: list[list[float]] = [[] for _ in range(world_size)]
|
|
dist.all_gather_object(
|
|
gathered_runtime_estimations, list(runtime_estimations.values()), pg
|
|
)
|
|
median_runtime_estimations = torch.median(
|
|
torch.tensor(gathered_runtime_estimations), dim=0
|
|
).values.tolist()
|
|
for i in range(len(snodes)):
|
|
snodes[i].override_estimated_runtime = median_runtime_estimations[i]
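# Illustrative sketch (made-up numbers, 3 ranks x 2 snodes) of the alignment above:
# every rank contributes its local estimates and then adopts the per-node median, so all
# ranks reorder against identical runtime numbers.
#
#   import torch
#   gathered = [[10.0, 5.0],   # rank 0
#               [12.0, 4.0],   # rank 1
#               [11.0, 9.0]]   # rank 2
#   medians = torch.median(torch.tensor(gathered), dim=0).values.tolist()
#   assert medians == [11.0, 5.0]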
|
|
|
|
|
|
def sink_waits(snodes: list[BaseSchedulerNode]) -> list[BaseSchedulerNode]:
|
|
"""
|
|
Greedily schedules waits as late as possible.
|
|
"""
|
|
return _schedule_for_comm(
|
|
snodes, raise_comms=False, sink_waits=True, reorder_for_overlap=False
|
|
)
|
|
|
|
|
|
def raise_comms(snodes: list[BaseSchedulerNode]) -> list[BaseSchedulerNode]:
|
|
"""
|
|
Greedily schedules comms as early as possible.
|
|
"""
|
|
return _schedule_for_comm(
|
|
snodes, raise_comms=True, sink_waits=False, reorder_for_overlap=False
|
|
)
|
|
|
|
|
|
def reorder_compute_for_overlap(
|
|
snodes: list[BaseSchedulerNode],
|
|
) -> list[BaseSchedulerNode]:
|
|
"""
|
|
This achieves the following overall scheduling procedure:
|
|
Step 1: Given that we've currently scheduled comm N, we now schedule all compute nodes
|
|
that are required for comm N + 1 but do not depend on comm N, to run at the same time with comm N.
|
|
Step 2: If all those compute nodes are sufficient to overlap comm N, we're done.
|
|
Otherwise, we now need to look elsewhere to find compute that overlaps with comm N.
|
|
We prioritize compute nodes that are needed sooner.
|
|
Step 3: We schedule the compute nodes dependent on comm N and required for comm N + 1.
|
|
Step 4: We schedule comm N + 1.
|
|
Repeat this for subsequent comm nodes.
|
|
"""
|
|
return _schedule_for_comm(
|
|
snodes, raise_comms=True, sink_waits=True, reorder_for_overlap=True
|
|
)
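# Illustrative sketch (hypothetical node names) of the procedure described above:
#
#   before: [comm0, compute_a, wait0, comm1, compute_b, wait1]
#            (compute_a feeds comm1 and does not depend on comm0; compute_b is unrelated)
#   after:  [comm0, compute_a, compute_b, wait0, comm1, wait1]
#            step 1 overlaps comm0 with compute_a, step 2 pulls in compute_b to finish
#            hiding comm0, then wait0 and comm1 are scheduled (steps 3-4).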
|
|
|
|
|
|
def reorder_communication_preserving_peak_memory(
|
|
snodes: list[BaseSchedulerNode],
|
|
) -> list[BaseSchedulerNode]:
|
|
"""
|
|
Reorders communication ops relative to computation ops to improve communication-compute overlapping and hide comm
|
|
latency. Stops moving a particular op if it reaches a point that would have increased the peak memory footprint.
|
|
|
|
Currently, it follows these heuristics (subject to change or tuning):
|
|
- never reorders collectives relative to one another, for SPMD safety
|
|
- has an option for per-collective prefetch limit, but does not enable it by default
|
|
- limits the total number of reorder steps to some factor of the graph size to prevent worst-case quadratic
|
|
performance
|
|
|
|
Prerequisite: sink_comms_and_waits - ensure comm and wait nodes are scheduled as late as possible, respecting data
|
|
dependencies. That allows reorder_communication_preserving_peak_memory to take a best case peak-memory snapshot,
|
|
and then monotonically improve latency by moving collectives backward in time.
|
|
|
|
Peak memory impact is computed in an iterative fashion. First, memory use at each timestep is computed, and global
|
|
peak memory is computed as a max over timesteps. Then, when swapping any two adjacent nodes, only the curr-memory
|
|
for the earlier of the nodes after the swap is affected. This enables checking step by step whether a swap is
|
|
peak-memory-safe, and bailing out if not. Example:
|
|
|
|
0 n0 C0
|
|
1 n1 C0 + Allocs(n1) - Frees(n1)
|
|
2 n2 C0 + Allocs(n1) - Frees(n1) + Allocs(n2) - Frees(n2)
|
|
|
|
0 n0 C0
|
|
1 n2 C0 + Allocs(n2) - Frees(n2) <-- After moving n2 to Time 1, only time1 memory changes
|
|
2 n1 C0 + Allocs(n2) - Frees(n2) + Allocs(n1) - Frees(n1)
|
|
|
|
"""
|
|
reordered_snodes, node_stats = (
|
|
_reorder_communication_preserving_peak_memory_internal(snodes)
|
|
)
|
|
|
|
return reordered_snodes
|
|
|
|
|
|
@dataclass
|
|
class ReorderInfo:
|
|
"""
|
|
Debug info describing how an individual snode was reordered
|
|
"""
|
|
|
|
initial_exposed: float = -1
|
|
final_exposed: float = -1
|
|
limiting_factor: str = "None"
|
|
moves: int = 0
|
|
grouped: int = 0
|
|
grouped_info: str = ""
|
|
|
|
@property
|
|
def improvement(self):
|
|
return self.initial_exposed - self.final_exposed
|
|
|
|
|
|
def is_gemm_like(node: Optional[Union[IRNode, Operation]]) -> bool:
|
|
if node is None:
|
|
return False
|
|
|
|
if is_fallback_op(
|
|
node, # type: ignore[arg-type]
|
|
torch.ops.aten._scaled_dot_product_flash_attention.default,
|
|
):
|
|
return True
|
|
|
|
if (
|
|
python_kernel_name := getattr(node, "python_kernel_name", None)
|
|
) and "extern_kernels" in python_kernel_name:
|
|
return True
|
|
return False
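# Illustrative examples of what the check above treats as gemm-like (assumption: typical
# Inductor kernel names): a fallback aten._scaled_dot_product_flash_attention call, or an
# extern kernel whose python_kernel_name contains "extern_kernels", e.g.
# "extern_kernels.mm" / "extern_kernels.addmm".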
|
|
|
|
|
|
def contains_gemm_like(snode: BaseSchedulerNode) -> bool:
|
|
from torch._inductor.scheduler import GroupedSchedulerNode
|
|
|
|
if isinstance(snode, GroupedSchedulerNode):
|
|
return any(contains_gemm_like(x) for x in snode.snodes)
|
|
else:
|
|
return is_gemm_like(snode.node)
|
|
|
|
|
|
def _temp_group_visit_leaves(snode, fn):
|
|
from torch._inductor.scheduler import GroupedSchedulerNode
|
|
|
|
if isinstance(snode, GroupedSchedulerNode) and snode.temp_grouping:
|
|
for _snode in snode.snodes:
|
|
fn(_snode)
|
|
else:
|
|
fn(snode)
|
|
|
|
|
|
def _group_name(snode, with_bufs=False) -> str:
|
|
ret = ""
|
|
for n in snode.snodes:
|
|
if ret:
|
|
ret += "_"
|
|
ret += n.get_name()
|
|
if with_bufs:
|
|
ret += f"{list(snode.get_buffer_names())}"
|
|
return ret
|
|
|
|
|
|
def _is_fake_dep(d):
|
|
return isinstance(d, WeakDep) and d.is_fake
|
|
|
|
|
|
def _group_names(gns: list[BaseSchedulerNode]) -> str:
|
|
return "~".join([gn.get_name() for gn in gns])
|
|
|
|
|
|
def _initialize_memory_tracking(snodes, graph_inputs, graph_outputs):
|
|
"""Initialize memory tracking data structures"""
|
|
name_to_freeable_input_buf = get_freeable_input_buf(snodes, graph_inputs)
|
|
peak_memory, snodes_curr_memory, snodes_allocfree, buf_to_snode_last_use = (
|
|
estimate_peak_memory_allocfree(
|
|
snodes, name_to_freeable_input_buf, graph_outputs
|
|
)
|
|
)
|
|
_curr_memory = dict(zip(snodes, snodes_curr_memory))
|
|
_curr_memory[None] = (0, 0)
|
|
return (
|
|
peak_memory,
|
|
_curr_memory,
|
|
snodes_allocfree,
|
|
buf_to_snode_last_use,
|
|
name_to_freeable_input_buf,
|
|
)
|
|
|
|
|
|
def _initialize_double_linked_list(
|
|
snodes: list[BaseSchedulerNode],
|
|
) -> tuple[
|
|
dict[BaseSchedulerNode, Optional[BaseSchedulerNode]],
|
|
dict[BaseSchedulerNode, Optional[BaseSchedulerNode]],
|
|
BaseSchedulerNode,
|
|
]:
|
|
"""Create double-linked list structure from snodes"""
|
|
_prev = {}
|
|
_next = {}
|
|
for i, snode in enumerate(snodes):
|
|
_prev[snode] = snodes[i - 1] if i > 0 else None
|
|
_next[snode] = snodes[i + 1] if i < len(snodes) - 1 else None
|
|
_head = snodes[0]
|
|
return _prev, _next, _head
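# Illustrative sketch (hypothetical nodes a, b, c) of the layout produced above:
#
#   _prev, _next, _head = _initialize_double_linked_list([a, b, c])
#   # _prev == {a: None, b: a, c: b}
#   # _next == {a: b, b: c, c: None}
#   # _head is a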
|
|
|
|
|
|
def _reorder_communication_preserving_peak_memory_internal(
|
|
snodes: list[BaseSchedulerNode],
|
|
) -> tuple[list[BaseSchedulerNode], dict[BaseSchedulerNode, ReorderInfo]]:
|
|
"""
|
|
Internal testing helper that also returns debug info.
|
|
Returns:
|
|
- reordered snodes list
|
|
- dict {snode: ReorderInfo}
|
|
"""
|
|
has_collectives = False
|
|
for snode in snodes:
|
|
if contains_collective(snode):
|
|
has_collectives = True
|
|
break
|
|
if not has_collectives:
|
|
return snodes, {}
|
|
|
|
from torch._inductor.scheduler import GroupedSchedulerNode
|
|
|
|
original_snodes_num = len(snodes)
|
|
# heuristic to avoid degenerating to quadratic time
|
|
graph_inputs: OrderedSet[str] = OrderedSet(V.graph.graph_inputs.keys())
|
|
graph_outputs: OrderedSet[str] = OrderedSet(V.graph.get_output_names())
|
|
(
|
|
peak_memory,
|
|
_curr_memory,
|
|
snodes_allocfree,
|
|
buf_to_snode_last_use,
|
|
name_to_freeable_input_buf,
|
|
) = _initialize_memory_tracking(snodes, graph_inputs, graph_outputs)
|
|
runtimes: dict[BaseSchedulerNode, float] = {
|
|
snode: estimate_op_runtime(snode) for snode in snodes
|
|
}
|
|
# debug stats
|
|
stats: dict[BaseSchedulerNode, ReorderInfo] = {}
|
|
|
|
def exposed_communication_time(
|
|
collective_snode: BaseSchedulerNode, remaining_snodes: list[BaseSchedulerNode]
|
|
) -> float:
|
|
# assumes a linear schedule and computes the overlap of the collective with the remaining nodes
|
|
comm_time = estimate_op_runtime(collective_snode)
|
|
compute_time = 0.0
|
|
for snode in remaining_snodes:
|
|
if contains_collective(snode):
|
|
continue
|
|
if contains_wait(snode):
|
|
# TODO - if the wait is for a collective that started before this collective or on another stream,
|
|
# we can ignore it. Otherwise, it's the end of the road for overlap opportunities
|
|
break
|
|
|
|
def accumulate_time(_snode: BaseSchedulerNode) -> None:
|
|
nonlocal compute_time
|
|
compute_time += runtimes[_snode]
|
|
|
|
_temp_group_visit_leaves(snode, accumulate_time)
|
|
return max(0, comm_time - compute_time)
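# Illustrative numbers (made up): if the collective is estimated at 500_000 ns and the
# compute nodes scheduled after it (up to the first wait) sum to 300_000 ns, the exposed
# communication time is max(0, 500_000 - 300_000) = 200_000 ns; a fully hidden collective
# returns 0.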
|
|
|
|
total_moves = 0
|
|
|
|
_prev, _next, _head = _initialize_double_linked_list(snodes)
|
|
|
|
def _group_nodes(
|
|
head: Optional[BaseSchedulerNode], tail: Optional[BaseSchedulerNode]
|
|
) -> list[BaseSchedulerNode]:
|
|
ret = []
|
|
n = head
|
|
while True:
|
|
if n is not None:
|
|
ret.append(n)
|
|
if n == tail:
|
|
break
|
|
n = _next[n] # type: ignore[index]
|
|
return ret
|
|
|
|
def _perform_double_linked_list_swap(candidate, group_head, group_tail):
|
|
# swap (candidate, group_head...group_tail)
|
|
# Before:
|
|
# candidate_prev -0-> candidate -1-> group_head...group_tail -2-> group_tail_next
|
|
# After:
|
|
# candidate_prev -0-> group_head...group_tail -1-> candidate -2-> group_tail_next
|
|
# 0
|
|
candidate_prev = _prev[candidate]
|
|
if candidate_prev:
|
|
_next[candidate_prev] = group_head
|
|
_prev[group_head] = candidate_prev
|
|
|
|
# 2
|
|
group_tail_next = _next[group_tail]
|
|
if group_tail_next:
|
|
_prev[group_tail_next] = candidate
|
|
_next[candidate] = group_tail_next
|
|
|
|
# 1
|
|
_prev[candidate] = group_tail
|
|
_next[group_tail] = candidate
|
|
|
|
nonlocal _head
|
|
if _head == candidate:
|
|
_head = group_head
|
|
|
|
def _calculate_potential_peak_memory(
|
|
candidate, group_ns, group_n_to_bufs_after_swap_dealloc_by_candidate
|
|
):
|
|
# Caching calculations of memory for group nodes and candidate,
|
|
# to apply without recalculation after swap.
|
|
_post_alloc_update: dict[BaseSchedulerNode, int] = {}
|
|
potential_peak: int = 0
|
|
if not group_n_to_bufs_after_swap_dealloc_by_candidate:
|
|
# Not accounting for buffer last-use changes
|
|
potential_peak = max(
|
|
group_peak_memory - candidate_delta_mem,
|
|
_curr_memory[group_tail][1]
|
|
- candidate_delta_mem
|
|
+ candidate_allocfree.size_alloc,
|
|
)
|
|
return potential_peak, _post_alloc_update
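# Illustrative numbers (made up) for the early-return branch above: with candidate
# size_alloc=10 and size_free=2 (candidate_delta_mem=8), group_peak_memory=100 and
# _curr_memory[group_tail][1]=95, the potential peak is max(100 - 8, 95 - 8 + 10) = 97,
# and _post_alloc_update is returned empty because no per-node snapshots need patching.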
|
|
|
|
# If the candidate ends up after the group, the starting memory level of the group nodes
# changes by -(candidate.size_alloc - candidate.size_free)
|
|
mem_after_reorder_delta: int = -candidate_delta_mem
|
|
for gn in gns:
|
|
gn_post_alloc_mem = _curr_memory[gn][0] + mem_after_reorder_delta
|
|
_post_alloc_update[gn] = gn_post_alloc_mem
|
|
potential_peak = max(potential_peak, gn_post_alloc_mem)
|
|
|
|
bufs = group_n_to_bufs_after_swap_dealloc_by_candidate.get(gn, None)
|
|
if bufs is not None:
|
|
for buf in bufs:
|
|
# Candidate will deallocate those buffers
|
|
mem_after_reorder_delta += buf.mpi_buffer.size_free
|
|
|
|
candidate_mem_post_alloc = (
|
|
_curr_memory[group_tail][1]
|
|
+ mem_after_reorder_delta
|
|
+ candidate_allocfree.size_alloc
|
|
)
|
|
_post_alloc_update[candidate] = candidate_mem_post_alloc
|
|
potential_peak = max(potential_peak, candidate_mem_post_alloc)
|
|
return potential_peak, _post_alloc_update
|
|
|
|
def _update_memory_tracking_after_swap(
|
|
candidate,
|
|
gns,
|
|
group_n_to_bufs_after_swap_dealloc_by_candidate,
|
|
_post_alloc_update,
|
|
):
|
|
if not group_n_to_bufs_after_swap_dealloc_by_candidate:
|
|
for gn in gns:
|
|
cm = _curr_memory[gn]
|
|
_curr_memory[gn] = (
|
|
cm[0] - candidate_delta_mem,
|
|
cm[1] - candidate_delta_mem,
|
|
)
|
|
_candidate_post_alloc_mem = (
|
|
_curr_memory[group_tail][1] + candidate_allocfree.size_alloc
|
|
)
|
|
_candidate_post_free_mem = (
|
|
_candidate_post_alloc_mem - candidate_allocfree.size_free
|
|
)
|
|
_curr_memory[candidate] = (
|
|
_candidate_post_alloc_mem,
|
|
_candidate_post_free_mem,
|
|
)
|
|
return
|
|
|
|
# Candidate becomes last use of some bufs
|
|
for (
|
|
gn,
|
|
bufs,
|
|
) in group_n_to_bufs_after_swap_dealloc_by_candidate.items():
|
|
for buf in bufs:
|
|
buf_to_snode_last_use[buf] = candidate
|
|
|
|
size_free_to_move_to_candidate_sum: int = 0
|
|
for n in gns:
|
|
_gn_post_alloc_mem: int = _post_alloc_update[n]
|
|
size_free_to_move_to_candidate: int = sum(
|
|
buf.mpi_buffer.size_free
|
|
for buf in group_n_to_bufs_after_swap_dealloc_by_candidate[n]
|
|
)
|
|
size_free_to_move_to_candidate_sum += size_free_to_move_to_candidate
|
|
# group node does not deallocate this after swap
|
|
snodes_allocfree[n].size_free -= size_free_to_move_to_candidate
|
|
gn_post_free_mem: int = _gn_post_alloc_mem - snodes_allocfree[n].size_free
|
|
_curr_memory[n] = (_gn_post_alloc_mem, gn_post_free_mem)
|
|
_candidate_post_alloc_mem = _post_alloc_update[candidate]
|
|
snodes_allocfree[candidate].size_free += size_free_to_move_to_candidate_sum
|
|
candidate_post_free_mem = (
|
|
_candidate_post_alloc_mem - snodes_allocfree[candidate].size_free
|
|
)
|
|
_curr_memory[candidate] = (
|
|
_candidate_post_alloc_mem,
|
|
candidate_post_free_mem,
|
|
)
|
|
|
|
debug_num_collectives_to_reorder: Optional[int] = (
|
|
config.reorder_iterative_debug_limit_to_reorder
|
|
)
|
|
|
|
num_processed_collectives: int = 0
|
|
curr = _head
|
|
debug_iterative_memory_recompute = config.reorder_iterative_debug_memory_recompute
|
|
iterative_recompute_error = False
|
|
|
|
while _next[curr] is not None:
|
|
if iterative_recompute_error:
|
|
break
|
|
if contains_collective(curr):
|
|
if debug_num_collectives_to_reorder is not None and (
|
|
num_processed_collectives >= debug_num_collectives_to_reorder
|
|
):
|
|
break
|
|
num_processed_collectives += 1
|
|
|
|
info = stats[curr] = ReorderInfo()
|
|
info.initial_exposed = info.final_exposed = exposed_communication_time(
|
|
curr, _group_nodes(_next[curr], None)
|
|
)
|
|
|
|
candidate = _prev[curr]
|
|
group_head = curr
|
|
group_tail = curr
|
|
group_peak_memory = _curr_memory[curr][0] # post_alloc memory
|
|
while candidate is not None:
|
|
if contains_collective(candidate):
|
|
info.limiting_factor = "collective ordering"
|
|
break
|
|
|
|
gns: list[BaseSchedulerNode] = _group_nodes(group_head, group_tail)
|
|
group = GroupedSchedulerNode(
|
|
curr.scheduler,
|
|
gns,
|
|
temp_grouping=True,
|
|
)
|
|
|
|
# We can have multiple deps with the same name.
# Since we ignore WeakDep(is_fake=True), filter them out first
# so a fake dep does not overwrite a real one.
|
|
data_deps = {
|
|
d.name: d for d in group.unmet_dependencies if not _is_fake_dep(d)
|
|
}
|
|
|
|
candidate_outs = candidate.get_outputs()
|
|
data_dep = None
|
|
for o in candidate_outs:
|
|
if d := data_deps.get(o.get_name(), None):
|
|
data_dep = d
|
|
break
|
|
|
|
if data_dep is not None:
|
|
|
|
def is_groupable(
|
|
candidate: BaseSchedulerNode,
|
|
) -> tuple[bool, Optional[str]]:
|
|
# preserve ordering
|
|
if contains_collective(candidate):
|
|
return False, "contains_collective"
|
|
|
|
if contains_gemm_like(candidate):
|
|
return False, "contains_gemm_like"
|
|
return True, None
|
|
|
|
is_groupable_result, grouping_reason = is_groupable(candidate)
|
|
if is_groupable_result:
|
|
group_head = candidate
|
|
group_peak_memory = max(
|
|
group_peak_memory, _curr_memory[candidate][0]
|
|
)
|
|
info.grouped += 1
|
|
info.grouped_info = _group_names(gns)
|
|
candidate = _prev[candidate]
|
|
continue
|
|
else:
|
|
msg = (
|
|
f"data dependency {data_dep}(dep_names:{list(data_deps.keys())})"
|
|
f"\n candidate:{candidate.get_name()}(outs:{[candidate.get_buffer_names()]})"
|
|
f"dep on {_group_names(gns)}"
|
|
f"\n non_group_reason:{grouping_reason}"
|
|
)
|
|
info.limiting_factor = msg
|
|
break
|
|
|
|
candidate_allocfree: SNodeMemory = snodes_allocfree[candidate]
|
|
candidate_delta_mem: int = (
|
|
candidate_allocfree.size_alloc - candidate_allocfree.size_free
|
|
)
|
|
# The candidate and one of the group nodes are successors of the same buffer,
# and the last use of that buffer happens in the group nodes.
# That last use deallocates it.
# If we swap [candidate [group]] to [[group] candidate],
# the candidate becomes the last use
# and deallocates this buffer instead of the group node.
# We need to update size_free for the group node and the candidate accordingly,
# and recalculate post_alloc, post_free for them.
#
# A buf that changes its last-use snode
# will, after the swap, be deallocated only by the candidate,
# while before it was deallocated by the group node.
|
|
group_n_to_bufs_after_swap_dealloc_by_candidate: dict[
|
|
BaseSchedulerNode, list[Union[FreeableInputBuffer, Any]]
|
|
] = defaultdict(list)
|
|
for (
|
|
buf,
|
|
snode_last_use,
|
|
) in buf_to_snode_last_use.items():
|
|
succ_nodes = buf.mpi_buffer.succ_nodes
|
|
if candidate not in succ_nodes:
|
|
continue
|
|
|
|
if not any(gn == snode_last_use for gn in gns):
|
|
continue
|
|
|
|
group_n_to_bufs_after_swap_dealloc_by_candidate[
|
|
snode_last_use
|
|
].append(buf)
|
|
|
|
potential_peak, _post_alloc_update = _calculate_potential_peak_memory(
|
|
candidate, gns, group_n_to_bufs_after_swap_dealloc_by_candidate
|
|
)
|
|
|
|
if potential_peak > peak_memory:
|
|
info.limiting_factor = (
|
|
f"peak memory new:{potential_peak} vs base:{peak_memory}"
|
|
)
|
|
break
|
|
info.moves += 1
|
|
total_moves += 1
|
|
|
|
_perform_double_linked_list_swap(candidate, group_head, group_tail)
|
|
|
|
info.final_exposed = exposed_communication_time(
|
|
curr, _group_nodes(_next[curr], None)
|
|
)
|
|
|
|
_update_memory_tracking_after_swap(
|
|
candidate,
|
|
gns,
|
|
group_n_to_bufs_after_swap_dealloc_by_candidate,
|
|
_post_alloc_update,
|
|
)
|
|
|
|
if debug_iterative_memory_recompute:
|
|
# Compare iteratively recomputed memory data
|
|
# with full run of estimate_peak_memory
|
|
|
|
from .comms_debug import _debug_iterative_memory_recompute
|
|
|
|
iterative_recompute_error = _debug_iterative_memory_recompute(
|
|
candidate,
|
|
gns,
|
|
_group_names(gns),
|
|
_group_nodes(_head, None),
|
|
name_to_freeable_input_buf,
|
|
graph_outputs,
|
|
peak_memory,
|
|
_curr_memory,
|
|
snodes_allocfree,
|
|
"reorder_communication_preserving_peak_memory",
|
|
group_n_to_bufs_after_swap_dealloc_by_candidate,
|
|
)
|
|
if iterative_recompute_error:
|
|
break
|
|
candidate = _prev[group_head]
|
|
curr = _next[curr] # type: ignore[assignment]
|
|
|
|
node_stats = stats
|
|
improvement = {snode: node_stats[snode].improvement for snode in node_stats}
|
|
total_improvement = sum([improvement[snode] for snode in improvement])
|
|
total_moves = sum([node_stats[snode].moves for snode in node_stats])
|
|
|
|
reorder_log_str = (
|
|
f"reorder_communication_preserving_peak_memory improved overlap by {total_improvement} ns"
|
|
f" after {total_moves} reorders.\n"
|
|
)
|
|
headers = [
|
|
"Collective node",
|
|
"initial exposed",
|
|
"final exposed",
|
|
"improvement",
|
|
"limiting factor",
|
|
"moves",
|
|
"grouped",
|
|
"grouped_info",
|
|
]
|
|
rows = [
|
|
[
|
|
node_summary(snode),
|
|
node_info.initial_exposed,
|
|
node_info.final_exposed,
|
|
node_info.improvement,
|
|
node_info.limiting_factor,
|
|
node_info.moves,
|
|
node_info.grouped,
|
|
node_info.grouped_info,
|
|
]
|
|
for snode, node_info in node_stats.items()
|
|
]
|
|
if importlib.util.find_spec("tabulate"):
|
|
from tabulate import tabulate
|
|
|
|
reorder_log_str += tabulate(
|
|
rows,
|
|
headers=headers,
|
|
)
|
|
else:
|
|
reorder_log_str += (
|
|
"Please `pip install tabulate` to nicely render overlap stats.\n"
|
|
)
|
|
reorder_log_str += str(headers) + "\n"
|
|
reorder_log_str += "\n".join(map(str, rows))
|
|
|
|
new_snodes = _group_nodes(_head, None)
|
|
assert len(new_snodes) == original_snodes_num
|
|
new_peak_memory, _, _, _ = estimate_peak_memory_allocfree(
|
|
new_snodes, name_to_freeable_input_buf, graph_outputs
|
|
)
|
|
reorder_log_str += f"\n peak_memory_before:{peak_memory}"
|
|
reorder_log_str += f"\n peak_memory_after:{new_peak_memory}"
|
|
|
|
overlap_log.info(reorder_log_str)
|
|
trace_structured(
|
|
"artifact",
|
|
metadata_fn=lambda: {
|
|
"name": "reorder_communication_preserving_peak_memory",
|
|
"encoding": "string",
|
|
},
|
|
payload_fn=lambda: reorder_log_str,
|
|
)
|
|
|
|
return new_snodes, stats
|
|
|
|
|
|
def _schedule_for_comm(
|
|
snodes: list[BaseSchedulerNode],
|
|
raise_comms: bool,
|
|
sink_waits: bool,
|
|
reorder_for_overlap: bool,
|
|
) -> list[BaseSchedulerNode]:
|
|
"""
|
|
Schedule `snodes` for various comm optimization objectives.
|
|
|
|
Args:
|
|
snodes: the nodes to be scheduled.
|
|
raise_comms: whether to greedily schedule collectives as early as possible
|
|
sink_waits: whether to greedily schedule waits as late as possible
|
|
reorder_for_overlap: whether to reorder compute nodes to
|
|
optimize for compute/communication overlapping.
|
|
|
|
Returns:
|
|
The new schedule order.
|
|
|
|
Some notes on the synergy between different options:
|
|
- `raise_comms` provides more overlapping opportunities for `reorder_compute_for_overlap`.
|
|
- When both `raise_comms` and `sink_waits` are `True`, `raise_comms` is prioritized.
|
|
"""
|
|
# We assign each node a tuple of scores (score_0, score_1, score_2),
|
|
# decreasing in importance, with a lower value indicating a higher ranking:
|
|
#
|
|
# - score_0: the lowest comm_idx among the comm nodes that the node blocks.
|
|
# If a node doesn't block any comm nodes, its score_0 is set to
|
|
# sys.maxsize. This score ensures that comm nodes get scheduled as early as
|
|
# possible.
|
|
# - score_1: 1 if the node is a wait node, 0 otherwise. This score ensures
|
|
# that wait nodes are deferred as late as possible.
|
|
# - score_2: the index of the node in the original topological order. This
|
|
# score provides stability in case of ties.
|
|
#
|
|
# When only raise_comms is True, only score_0 and score_2 are considered.
|
|
# When only sink_waits is True, only score_1 and score_2 are considered.
|
|
# When neither is True, the original order is yielded.
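# Example (hypothetical, raise_comms=True and sink_waits=True): the first collective at
# original index 7 gets score (0, 0, 7); a compute node at index 5 that blocks it gets
# (0, 0, 5) and is pulled ahead of unrelated compute, which keeps (sys.maxsize, 0, idx);
# a wait node at index 3 that blocks no later collective gets (sys.maxsize, 1, 3) and is
# deferred.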
|
|
buf_name_to_snode = {}
|
|
name_to_fused_node = {}
|
|
scores_0, scores_1, scores_2 = {}, {}, {}
|
|
for idx, snode in enumerate(snodes):
|
|
for buf_name in snode.get_buffer_names():
|
|
buf_name_to_snode[buf_name] = snode
|
|
|
|
for op_name in snode.get_operation_names():
|
|
name_to_fused_node[op_name] = snode
|
|
name_to_fused_node[snode.get_name()] = snode
|
|
|
|
node_name = snode.get_name()
|
|
scores_0[node_name] = sys.maxsize
|
|
scores_1[node_name] = 0
|
|
scores_2[node_name] = idx
|
|
|
|
comm_idx = 0
|
|
for snode in snodes:
|
|
if raise_comms and contains_collective(snode):
|
|
scores_0[snode.get_name()] = comm_idx
|
|
for ancestor in snode.ancestors:
|
|
anc_fused_name = name_to_fused_node[ancestor].get_name()
|
|
scores_0[anc_fused_name] = min(scores_0[anc_fused_name], comm_idx)
|
|
comm_idx += 1
|
|
elif sink_waits and contains_wait(snode):
|
|
scores_1[snode.get_name()] = 1
|
|
|
|
class Runnable:
|
|
def __init__(self, snode) -> None:
|
|
self.snode = snode
|
|
name = next(iter(snode.get_operation_names()))
|
|
fused_name = name_to_fused_node[name].get_name()
|
|
self.score = (
|
|
scores_0[fused_name],
|
|
scores_1[fused_name],
|
|
scores_2[fused_name],
|
|
)
|
|
|
|
def __lt__(self, other):
|
|
return self.score < other.score
|
|
|
|
unmet_deps: dict[BaseSchedulerNode, OrderedSet[str]] = {
|
|
snode: OrderedSet(dep.name for dep in snode.unmet_dependencies)
|
|
for snode in snodes
|
|
}
|
|
|
|
ready: list[Runnable] = []
|
|
buffer_users: dict[str, OrderedSet[BaseSchedulerNode]] = defaultdict(OrderedSet)
|
|
snode_to_cost = {snode: estimate_op_runtime(snode) for snode in snodes}
|
|
|
|
for snode, deps in unmet_deps.items():
|
|
if len(deps) == 0:
|
|
heapq.heappush(ready, Runnable(snode))
|
|
for dep in deps:
|
|
buffer_users[dep].add(snode)
|
|
|
|
scheduled = []
|
|
|
|
def schedule(snode):
|
|
"""
|
|
Schedules `snode` and puts all unblocked nodes onto the ready queue.
|
|
"""
|
|
scheduled.append(snode)
|
|
for buf_name in snode.get_buffer_names():
|
|
for snode in buffer_users[buf_name]:
|
|
unmet_deps[snode].remove(buf_name)
|
|
if len(unmet_deps[snode]) == 0:
|
|
heapq.heappush(ready, Runnable(snode))
|
|
|
|
def get_overlapping_candidate():
|
|
"""
|
|
Return the next node in the ready queue that's neither a collective nor
a wait.
|
|
"""
|
|
candidates = [
|
|
x
|
|
for x in ready
|
|
if not contains_collective(x.snode) and not contains_wait(x.snode)
|
|
]
|
|
if len(candidates) == 0:
|
|
return None
|
|
return min(candidates, key=lambda x: x.score)
|
|
|
|
def schedule_collective_for_overlap(snode):
|
|
"""
|
|
Schedules collective node `snode`, along with one or more compute nodes
|
|
to overlap with it. The strategy is described in the comment of
|
|
`reorder_compute_for_overlap`.
|
|
"""
|
|
assert contains_collective(snode)
|
|
schedule(snode)
|
|
|
|
collective_cost = snode_to_cost[snode]
|
|
while (
|
|
collective_cost > 0
|
|
and (candidate := get_overlapping_candidate()) is not None
|
|
):
|
|
ready.remove(candidate)
|
|
schedule(candidate.snode)
|
|
collective_cost -= snode_to_cost[candidate.snode]
|
|
heapq.heapify(ready)
|
|
|
|
while len(ready):
|
|
snode = heapq.heappop(ready).snode
|
|
if reorder_for_overlap and contains_collective(snode):
|
|
schedule_collective_for_overlap(snode)
|
|
else:
|
|
schedule(snode)
|
|
|
|
for snode, deps in unmet_deps.items():
|
|
assert len(deps) == 0, (
|
|
f"Detected unscheduled nodes. Nodes with unmet dependencies: {unmet_deps}"
|
|
)
|
|
return scheduled
|
|
|
|
|
|
def decide_global_ordering_of_comms(
|
|
nodes: list[BaseSchedulerNode], name_to_buf, name_to_fused_node
|
|
) -> list[BaseSchedulerNode]:
|
|
"""
|
|
Decide the global ordering of comms by just enforcing the ordering that's in the input graph
|
|
(might not be the same ordering as the eager mode program).
|
|
TODO: Come up with a better approach
|
|
"""
|
|
if not torch.distributed.is_available():
|
|
return nodes
|
|
|
|
comm_nodes = [n for n in nodes if contains_collective(n)]
|
|
|
|
for i in range(1, len(comm_nodes)):
|
|
# Enforce ordering by making previous comm a `WeakDep` dependency of the next comm
|
|
mutating_buf = next(iter(comm_nodes[i].get_buffer_names()))
|
|
for buf in comm_nodes[i - 1].get_buffer_names():
|
|
comm_nodes[i].add_fake_dep(
|
|
WeakDep(buf, mutating_buf=mutating_buf, is_fake=True)
|
|
)
|
|
|
|
return nodes
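# Illustrative sketch: with collectives c1, c2, c3 appearing in that order in the graph,
# the loop above gives c2 fake WeakDeps on c1's buffers and c3 on c2's, so every rank
# issues the collectives in the same relative order without introducing real data flow.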
|
|
|
|
|
|
@dataclass
|
|
class SinkWaitInfo:
|
|
grouped: int = 0
|
|
grouped_info: str = ""
|
|
moves: int = 0
|
|
moves_info: str = ""
|
|
limiting_factor: str = "None"
|
|
|
|
|
|
def _sink_waits_iterative_internal(
|
|
snodes: list[BaseSchedulerNode],
|
|
) -> tuple[list[BaseSchedulerNode], dict[BaseSchedulerNode, SinkWaitInfo]]:
|
|
from torch._inductor.scheduler import GroupedSchedulerNode
|
|
|
|
original_snodes_num = len(snodes)
|
|
if original_snodes_num == 0:
|
|
return snodes, {}
|
|
graph_inputs: OrderedSet[str] = OrderedSet(V.graph.graph_inputs.keys())
|
|
graph_outputs: OrderedSet[str] = OrderedSet(V.graph.get_output_names())
|
|
(
|
|
peak_memory,
|
|
_curr_memory,
|
|
snodes_allocfree,
|
|
buf_to_snode_last_use,
|
|
name_to_freeable_input_buf,
|
|
) = _initialize_memory_tracking(snodes, graph_inputs, graph_outputs)
|
|
|
|
_prev, _next, _head = _initialize_double_linked_list(snodes)
|
|
|
|
stats: dict[BaseSchedulerNode, SinkWaitInfo] = {}
|
|
|
|
def _group_nodes(
|
|
head: Optional[BaseSchedulerNode], tail: Optional[BaseSchedulerNode]
|
|
) -> list[BaseSchedulerNode]:
|
|
ret = []
|
|
n = head
|
|
while True:
|
|
if n is not None:
|
|
ret.append(n)
|
|
if n == tail:
|
|
break
|
|
n = _next[n] # type: ignore[index]
|
|
return ret
|
|
|
|
def _calculate_potential_peak_memory(
|
|
candidate, group_ns, group_n_to_bufs_after_swap_dealloc_instead_of_candidate
|
|
):
|
|
pre_group_mem = (
|
|
_curr_memory[group_head][0] - snodes_allocfree[group_head].size_alloc
|
|
)
|
|
# Stash memory tracing updates to not recompute them after swap
|
|
_post_alloc_update: dict[BaseSchedulerNode, int] = {}
|
|
_size_free_delta_update: dict[BaseSchedulerNode, int] = {}
|
|
|
|
potential_peak = 0
|
|
if not group_n_to_bufs_after_swap_dealloc_instead_of_candidate:
|
|
# Not accounting for buffer liveness changes
|
|
potential_peak = max(
|
|
group_peak_memory + candidate_delta_mem,
|
|
pre_group_mem + candidate_allocfree.size_alloc,
|
|
)
|
|
return potential_peak, _post_alloc_update, _size_free_delta_update
|
|
|
|
candidate_post_alloc = pre_group_mem + candidate_allocfree.size_alloc
|
|
_post_alloc_update[candidate] = candidate_post_alloc
|
|
potential_peak = candidate_post_alloc
|
|
candidate_size_free_to_move = sum(
|
|
buf.mpi_buffer.size_free # type: ignore[attr-defined]
|
|
for buf in itertools.chain.from_iterable(
|
|
group_n_to_bufs_after_swap_dealloc_instead_of_candidate.values()
|
|
)
|
|
)
|
|
_size_free_delta_update[candidate] = -candidate_size_free_to_move
|
|
delta_mem = candidate_delta_mem + candidate_size_free_to_move
|
|
for gn in gns:
|
|
gn_post_alloc = _curr_memory[gn][0] + delta_mem
|
|
_post_alloc_update[gn] = gn_post_alloc
|
|
potential_peak = max(potential_peak, gn_post_alloc)
|
|
gn_size_free_to_add = 0
|
|
if gn in group_n_to_bufs_after_swap_dealloc_instead_of_candidate:
|
|
bufs = group_n_to_bufs_after_swap_dealloc_instead_of_candidate[gn]
|
|
for buf in bufs:
|
|
gn_size_free_to_add += buf.mpi_buffer.size_free
|
|
_size_free_delta_update[gn] = gn_size_free_to_add
|
|
delta_mem -= gn_size_free_to_add
|
|
return potential_peak, _post_alloc_update, _size_free_delta_update
|
|
|
|
def _perform_double_linked_list_swap(candidate, group_head, group_tail):
|
|
# group_head_prev -0-> candidate -1-> group_head...group_tail -2-> candidate_next
|
|
# 0:
|
|
group_head_prev = _prev[group_head]
|
|
if group_head_prev:
|
|
_next[group_head_prev] = candidate
|
|
_prev[candidate] = group_head_prev
|
|
|
|
# 2:
|
|
candidate_next = _next[candidate]
|
|
if candidate_next:
|
|
_prev[candidate_next] = group_tail
|
|
_next[group_tail] = candidate_next
|
|
|
|
# 1:
|
|
_prev[group_head] = candidate
|
|
_next[candidate] = group_head
|
|
nonlocal _head
|
|
if group_head == _head:
|
|
_head = candidate
|
|
|
|
def _update_memory_tracking_after_swap(
|
|
candidate,
|
|
gns,
|
|
group_n_to_bufs_after_swap_dealloc_instead_of_candidate,
|
|
_post_alloc_update,
|
|
_size_free_delta_update,
|
|
):
|
|
group_head = gns[0]
|
|
pre_group_mem = (
|
|
_curr_memory[group_head][0] - snodes_allocfree[group_head].size_alloc
|
|
)
|
|
if not group_n_to_bufs_after_swap_dealloc_instead_of_candidate:
|
|
candidate_post_alloc = pre_group_mem + candidate_allocfree.size_alloc
|
|
_curr_memory[candidate] = (
|
|
candidate_post_alloc,
|
|
candidate_post_alloc - candidate_allocfree.size_free,
|
|
)
|
|
for gn in gns:
|
|
cm = _curr_memory[gn]
|
|
_curr_memory[gn] = (
|
|
cm[0] + candidate_delta_mem,
|
|
cm[1] + candidate_delta_mem,
|
|
)
|
|
return
|
|
|
|
for n in [candidate, *gns]:
|
|
post_alloc = _post_alloc_update[n]
|
|
snodes_allocfree[n].size_free += _size_free_delta_update[n]
|
|
_curr_memory[n] = (
|
|
post_alloc,
|
|
post_alloc - snodes_allocfree[n].size_free,
|
|
)
|
|
|
|
curr = snodes[-1]
|
|
|
|
processed_waits = OrderedSet() # type: ignore[var-annotated]
|
|
debug_iterative_memory_recompute = config.reorder_iterative_debug_memory_recompute
|
|
debug_num_sink_waits_to_reorder: Optional[int] = (
|
|
config.sink_waits_iterative_debug_limit_to_sink
|
|
)
|
|
|
|
iterative_recompute_error = False
|
|
|
|
while _prev[curr] is not None:
|
|
if iterative_recompute_error:
|
|
break
|
|
if (
|
|
debug_num_sink_waits_to_reorder is not None
|
|
and len(processed_waits) >= debug_num_sink_waits_to_reorder
|
|
):
|
|
break
|
|
|
|
if contains_wait(curr) and curr not in processed_waits:
|
|
processed_waits.add(curr)
|
|
info = stats[curr] = SinkWaitInfo()
|
|
candidate = _next[curr]
|
|
wait_snode = curr
|
|
group_head = curr
|
|
group_tail = curr
|
|
group_peak_memory = _curr_memory[curr][0]
|
|
while candidate is not None:
|
|
if iterative_recompute_error:
|
|
break
|
|
gns: list[BaseSchedulerNode] = _group_nodes(group_head, group_tail)
|
|
group = GroupedSchedulerNode(
|
|
wait_snode.scheduler,
|
|
gns,
|
|
temp_grouping=True,
|
|
)
|
|
|
|
# We can have multiple deps with the same name.
# Since we ignore WeakDep(is_fake=True), filter them out first
# so a fake dep does not overwrite a real one.
|
|
data_deps = {
|
|
d.name: d
|
|
for d in candidate.unmet_dependencies
|
|
if not _is_fake_dep(d)
|
|
}
|
|
|
|
group_outs = group.get_outputs()
|
|
data_dep = None
|
|
for o in group_outs:
|
|
if d := data_deps.get(o.get_name(), None):
|
|
data_dep = d
|
|
break
|
|
# 1. If we have a data_dep, we cannot swap => try to group.
# 2. If the swap candidate and the current node both contain collectives => try to group.
|
|
if data_dep is not None or (
|
|
both_contain_comms := (
|
|
contains_collective(group) and contains_collective(candidate)
|
|
)
|
|
):
|
|
|
|
def is_groupable(snode):
|
|
# We do not want to group with collectives, to avoid reordering them forward.
|
|
if contains_collective(snode):
|
|
return (
|
|
False,
|
|
f"candidate contains collective {snode.get_name()}",
|
|
)
|
|
if contains_gemm_like(snode):
|
|
return (
|
|
False,
|
|
f"candidate contains gemm_like {snode.get_name()}",
|
|
)
|
|
return True, None
|
|
|
|
is_grp, grp_reason = is_groupable(candidate)
|
|
if is_grp:
|
|
group_tail = candidate
|
|
group_peak_memory = max(
|
|
group_peak_memory, _curr_memory[candidate][0]
|
|
)
|
|
info.grouped += 1
|
|
info.grouped_info = _group_names(gns)
|
|
candidate = _next[candidate]
|
|
continue
|
|
elif (data_dep is None) and both_contain_comms:
|
|
info.limiting_factor = (
|
|
f"collective ordering {_group_names(gns)}"
|
|
f" with candidate:{candidate.get_name()}"
|
|
)
|
|
break
|
|
else:
|
|
info.limiting_factor = (
f"data dependency {data_dep}(dep_names:{list(data_deps.keys())})"
f"\n candidate:{candidate.get_name()}(outs:{[candidate.get_buffer_names()]})"
f"dep on {gns}"
f"\n outs:{[o.get_name() for o in group_outs]}"
f"\n non_group_reason:{grp_reason}"
)
|
|
break
|
|
candidate_allocfree: SNodeMemory = snodes_allocfree[candidate]
|
|
candidate_delta_mem = (
|
|
candidate_allocfree.size_alloc - candidate_allocfree.size_free
|
|
)
|
|
# [group] candidate -> candidate [group]
# Check for buffers that have successors in the group and whose last successor is the candidate.
#
# A buf that changes its last-use snode
# was deallocated by the candidate,
# but after the swap it will be deallocated by a group node.
|
|
group_n_to_bufs_after_swap_dealloc_instead_of_candidate: dict[
|
|
BaseSchedulerNode, list[Union[FreeableInputBuffer, SchedulerBuffer]]
|
|
] = defaultdict(list)
|
|
for (
|
|
buf,
|
|
snode_last_use,
|
|
) in buf_to_snode_last_use.items():
|
|
succ_nodes = buf.mpi_buffer.succ_nodes
|
|
if snode_last_use != candidate: # noqa: E711
|
|
continue
|
|
# candidate is last use of buf
|
|
last_succ_gn = None
|
|
for gn in gns:
|
|
if gn in succ_nodes:
|
|
last_succ_gn = gn
|
|
if last_succ_gn is None:
|
|
continue
|
|
|
|
# gn has successors of buf that after potential swap will become
|
|
# last use of buf and start deallocating buf instead of candidate
|
|
group_n_to_bufs_after_swap_dealloc_instead_of_candidate[
|
|
last_succ_gn
|
|
].append(buf)
|
|
|
|
potential_peak, _post_alloc_update, _size_free_delta_update = (
|
|
_calculate_potential_peak_memory(
|
|
candidate,
|
|
gns,
|
|
group_n_to_bufs_after_swap_dealloc_instead_of_candidate,
|
|
)
|
|
)
|
|
if potential_peak > peak_memory:
|
|
info.limiting_factor = (
|
|
f"peak memory new:{potential_peak} vs base:{peak_memory}"
|
|
)
|
|
break
|
|
|
|
info.moves += 1
|
|
info.moves_info += f"+{candidate.get_name()}"
|
|
|
|
_perform_double_linked_list_swap(candidate, group_head, group_tail)
|
|
|
|
_update_memory_tracking_after_swap(
|
|
candidate,
|
|
gns,
|
|
group_n_to_bufs_after_swap_dealloc_instead_of_candidate,
|
|
_post_alloc_update,
|
|
_size_free_delta_update,
|
|
)
|
|
|
|
if debug_iterative_memory_recompute:
|
|
from .comms_debug import _debug_iterative_memory_recompute
|
|
|
|
iterative_recompute_error = _debug_iterative_memory_recompute(
|
|
candidate,
|
|
gns,
|
|
_group_names(gns),
|
|
_group_nodes(_head, None),
|
|
name_to_freeable_input_buf,
|
|
graph_outputs,
|
|
peak_memory,
|
|
_curr_memory,
|
|
snodes_allocfree,
|
|
"sink_waits_iterative",
|
|
group_n_to_bufs_after_swap_dealloc_instead_of_candidate,
|
|
)
|
|
if iterative_recompute_error:
|
|
break
|
|
|
|
candidate = _next[group_tail]
|
|
curr = _prev[curr] # type: ignore[assignment]
|
|
|
|
headers = [
|
|
"Wait node",
|
|
"grouped",
|
|
"grouped_info",
|
|
"moves",
|
|
"moves_info",
|
|
"limiting factor",
|
|
]
|
|
rows = [
|
|
[
|
|
node_summary(snode),
|
|
info.grouped,
|
|
info.grouped_info,
|
|
info.moves,
|
|
info.moves_info,
|
|
info.limiting_factor,
|
|
]
|
|
for snode, info in stats.items()
|
|
]
|
|
log_str = ""
|
|
if importlib.util.find_spec("tabulate"):
|
|
from tabulate import tabulate
|
|
|
|
log_str += tabulate(
|
|
rows,
|
|
headers=headers,
|
|
)
|
|
else:
|
|
log_str += "Please `pip install tabulate` to nicely render overlap stats.\n"
|
|
log_str += str(headers) + "\n"
|
|
log_str += "\n".join(map(str, rows))
|
|
overlap_log.info(log_str)
|
|
new_snodes = _group_nodes(_head, None)
|
|
assert len(new_snodes) == original_snodes_num
|
|
new_peak_memory, _, _, _ = estimate_peak_memory_allocfree(
|
|
new_snodes, name_to_freeable_input_buf, graph_outputs
|
|
)
|
|
log_str += f"\n sink_waits_iterative peak_memory_before:{peak_memory}"
|
|
log_str += f"\n sink_waits_iterative peak_memory_after:{new_peak_memory}"
|
|
trace_structured(
|
|
"artifact",
|
|
metadata_fn=lambda: {
|
|
"name": "sink_waits_iterative_info",
|
|
"encoding": "string",
|
|
},
|
|
payload_fn=lambda: log_str,
|
|
)
|
|
return new_snodes, stats
|
|
|
|
|
|
def sink_waits_iterative(
|
|
snodes: list[BaseSchedulerNode],
|
|
) -> list[BaseSchedulerNode]:
|
|
return _sink_waits_iterative_internal(snodes)[0]
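# Illustrative sketch (hypothetical nodes) of the effect of the iterative sinking above:
#
#   before: [ag0, wait0, mm0, relu0, mm1]   # wait0 blocks on ag0 immediately
#   after:  [ag0, mm0, relu0, wait0, mm1]   # wait0 sinks past compute that does not read
#                                           # the gathered result; it stops at mm1 if mm1
#                                           # consumes that result (gemm-like nodes are not
#                                           # grouped) or if peak memory would increase.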
|
|
|
|
|
|
def estimate_op_runtime(snode: BaseSchedulerNode) -> float:
|
|
"""
|
|
Returns estimated op runtime in nanoseconds (ns)
|
|
"""
|
|
if config.estimate_op_runtime == "default":
|
|
runtime = snode.get_estimated_runtime()
|
|
else:
|
|
assert callable(config.estimate_op_runtime)
|
|
runtime = config.estimate_op_runtime(snode)
|
|
return runtime
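# Minimal sketch of plugging in a custom estimator through the config hook used above
# (any callable taking a scheduler node and returning nanoseconds works; the numbers
# below are made up):
#
#   import torch._inductor.config as inductor_config
#
#   def my_estimator(snode) -> float:
#       return 1e6 if contains_collective(snode) else 1e4
#
#   inductor_config.estimate_op_runtime = my_estimator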
|
|
|
|
|
|
def node_summary(snode):
|
|
snodes = snode.get_nodes()
|
|
if len(snodes) == 1:
|
|
detail = ""
|
|
if isinstance(snode.node, (ir.ExternKernelOut, ir._CollectiveKernel)):
|
|
outs_str = f"outs:{[o.get_name() for o in snode.get_outputs()]}"
|
|
ins_str = f"ins:{[d.name for d in snode.unmet_dependencies]}"
|
|
detail = f" {snode.get_name()} ({snode.node.python_kernel_name})\n {outs_str}\n ({ins_str})"
|
|
layouts = [child.node.get_output_spec() for child in snode.get_nodes()]
|
|
out_tensor_info = ",".join(
|
|
[
|
|
f" (size={layout.size}, stride={layout.stride})"
|
|
if isinstance(layout, ir.Layout)
|
|
else ""
|
|
for layout in layouts
|
|
]
|
|
)
|
|
try:
|
|
node_name = snode.node.maybe_get_name()
|
|
except AttributeError:
|
|
# TODO: node_summary was written without FusedSchedulerNode in mind, generally needs to be hardened
|
|
node_name = ""
|
|
return f"{snode.node.__class__.__name__}{detail}{out_tensor_info} ({node_name}) ({snode.get_estimated_runtime():.0f} ns)"
|
|
|
|
# Flatten the summaries for Fused/Foreach/Grouped nodes
|
|
summaries = []
|
|
for child_snode in snodes:
|
|
summaries.append(node_summary(child_snode))
|
|
return f"{snode.__class__.__name__}: {', '.join(summaries)}"
|
|
|
|
|
|
def visualize_overlap(order):
|
|
# TODO - this function probably doesn't do a very good job estimating the runtime because it doesn't carefully model
|
|
# streams and overlap. For now it's mostly useful as a debug visualization.
|
|
|
|
total_est_runtime: float = 0.0
|
|
cur_comm_node = None
|
|
|
|
def step_log(step, msg):
|
|
overlap_log.debug(f"{step:>6}: {msg}") # noqa: G004
|
|
|
|
for step, snode in enumerate(order):
|
|
if cur_comm_node is None:
|
|
if contains_collective(snode):
|
|
total_est_runtime += estimate_op_runtime(snode)
|
|
cur_comm_node = snode.node
|
|
elif is_wait(snode.node):
|
|
# raise AssertionError(
|
|
# "Wait is not expected when there is no collective running"
|
|
# )
|
|
pass
|
|
else: # exposed compute op
|
|
total_est_runtime += estimate_op_runtime(snode)
|
|
step_log(step, f"{node_summary(snode)}")
|
|
else: # cur_comm_node is not None
|
|
if contains_collective(snode):
|
|
total_est_runtime += estimate_op_runtime(snode)
|
|
cur_comm_node = snode.node
|
|
step_log(step, f"{node_summary(snode)}") # noqa: G004
|
|
elif is_wait(snode.node): # end of this comm op
|
|
step_log(step, f"{node_summary(snode)}")
|
|
cur_comm_node = None
|
|
else: # overlapped compute op
|
|
step_log(step, f"| {node_summary(snode)}")
|
|
overlap_log.debug(
|
|
f"Est. runtime (ms): {total_est_runtime / 1000 / 1000}" # noqa: G004
|
|
)
|
|
|
|
|
|
def reorder_compute_and_comm_for_overlap(
|
|
snodes: list[BaseSchedulerNode],
|
|
) -> list[BaseSchedulerNode]:
|
|
order = snodes
|
|
graph_inputs: OrderedSet[str] = OrderedSet(V.graph.graph_inputs.keys())
|
|
graph_outputs: OrderedSet[str] = OrderedSet(V.graph.get_output_names())
|
|
for p in config.reorder_for_compute_comm_overlap_passes:
|
|
if isinstance(p, str) and p in globals():
|
|
p = globals()[p] # it is a builtin pass
|
|
assert callable(p), (
|
|
f"Invalid reorder_compute_and_comm_for_overlap pass: {p} is not callable"
|
|
)
|
|
peak_memory, _ = estimate_peak_memory(
|
|
snodes, get_freeable_input_buf(snodes, graph_inputs), graph_outputs
|
|
)
|
|
if torch.distributed.get_rank() == 0:
|
|
overlap_log.debug(
|
|
f"==== Visualize overlap before reordering pass {p}, {peak_memory=} ====" # noqa: G004
|
|
)
|
|
try:
|
|
visualize_overlap(order)
|
|
except Exception as e:
|
|
overlap_log.debug("", exc_info=e)
|
|
t0 = time.time()
|
|
order = p(order) # type: ignore[operator]
|
|
t = time.time() - t0
|
|
if torch.distributed.get_rank() == 0:
|
|
overlap_log.debug(
|
|
f"==== Visualize overlap after reordering pass {p} (ran in {t} sec)====" # noqa: G004
|
|
)
|
|
try:
|
|
visualize_overlap(order)
|
|
except Exception as e:
|
|
overlap_log.debug("", exc_info=e)
|
|
peak_memory, _ = estimate_peak_memory(
|
|
snodes, get_freeable_input_buf(snodes, graph_inputs), graph_outputs
|
|
)
|
|
overlap_log.debug(f"final {peak_memory=}")  # noqa: G004
|
|
return order
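# Minimal sketch (assuming the default Inductor config names) of enabling the builtin
# passes that this driver looks up in globals():
#
#   import torch._inductor.config as inductor_config
#   inductor_config.reorder_for_compute_comm_overlap = True
#   inductor_config.reorder_for_compute_comm_overlap_passes = [
#       "sink_waits_iterative",
#       "reorder_communication_preserving_peak_memory",
#   ]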
|
|
|
|
|
|
def remove_fsdp2_unsharded_param_graph_input_usage(graph: torch.fx.Graph):
|
|
"""
|
|
This FX graph pass replaces uses of FSDP2 unsharded params with their corresponding
|
|
graph intermediates that were fsdp.copy_'d into the unsharded params in the original graph.
|
|
|
|
NOTE: This pass can only be applied to FSDP2 unsharded params that follow the pattern
(or repetitions of it): `resize_(full) -> copy_ -> resize_(0)`. Because of this, in the partial-graph case
where `resize_(full) -> copy_` is in one graph and `resize_(0)` is in another graph, we can't
remove these resize and copy ops and thus we will have worse performance there.
|
|
|
|
In other words, "do we try to remove all the resize_(full) -> copy_ -> resize_(0) nodes for this unsharded param"
|
|
is actually a per-unsharded-param decision, since for each unsharded param, we look at its resize sequence pattern
|
|
(in `check_resize_pattern()`) to determine if its set of resize and copy nodes can be removed.
|
|
"""
|
|
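# Illustrative pattern (hypothetical FX node names) that this pass rewrites:
#
#   resize_storage_bytes_(unsharded_param, full_numel)   # resize_(full)
#   fsdp.copy_(unsharded_param, Y)                        # copy_
#   ... ops reading unsharded_param ...                   # rewritten to read Y instead
#   resize_storage_bytes_(unsharded_param, 0)             # resize_(0)
#
# after which the copy_ and resize_ nodes themselves are erased.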
node_list = list(graph.nodes)
|
|
|
|
# Find all graph inputs and their resize counts
|
|
graph_input_to_resized_to_full_node_idxes = defaultdict(list)
|
|
graph_input_to_resized_to_0_node_idxes = defaultdict(list)
|
|
for idx, node in enumerate(node_list):
|
|
if (
|
|
node.op == "call_function"
|
|
and node.target == torch.ops.inductor.resize_storage_bytes_.default
|
|
):
|
|
assert node.args[0].op == "placeholder", f"""\
|
|
Resize can only operate on graph inputs, but got {node} which is resizing non-graph-input {node.args[0]}
|
|
"""
|
|
graph_input = node.args[0]
|
|
new_size = node.args[1]
|
|
if new_size > 0:
|
|
graph_input_to_resized_to_full_node_idxes[graph_input].append(idx)
|
|
else:
|
|
graph_input_to_resized_to_0_node_idxes[graph_input].append(idx)
|
|
|
|
def check_resize_pattern(graph_input):
|
|
# Check the number of resize-to-full and resize-to-0 nodes are equal,
|
|
# and that for each (resize-to-full, resize-to-0) pair, the resize-to-full node
|
|
# always happens before the resize-to-0 node.
|
|
# This is the precondition for being able to remove all the resize and copy nodes
|
|
# for this specific unsharded param.
|
|
resized_to_full_idxes = graph_input_to_resized_to_full_node_idxes.get(
|
|
graph_input, []
|
|
)
|
|
resized_to_0_idxes = graph_input_to_resized_to_0_node_idxes.get(graph_input, [])
|
|
|
|
if not len(resized_to_full_idxes) == len(resized_to_0_idxes):
|
|
log.warning(
|
|
f"""
|
|
Unequal number of resize-to-full and resize-to-0 nodes for graph input {graph_input}:
|
|
{len(resized_to_full_idxes)} vs. {len(resized_to_0_idxes)}.
|
|
Skipping `remove_fsdp2_unsharded_param_graph_input_usage` FX graph pass.
|
|
""" # noqa: G004
|
|
)
|
|
return False
|
|
|
|
# Check the sequence: (resize_to_full -> resize_to_0)+
|
|
for resize_to_full_idx, resize_to_0_idx in zip(
|
|
resized_to_full_idxes, resized_to_0_idxes
|
|
):
|
|
if resize_to_full_idx >= resize_to_0_idx:
|
|
log.warning(
|
|
f"""
|
|
For graph input {graph_input}: resize-to-full node {node_list[resize_to_full_idx]} at index {resize_to_full_idx}
|
|
happens after resize-to-0 node {node_list[resize_to_0_idx]} at index {resize_to_0_idx}.
|
|
Skipping `remove_fsdp2_unsharded_param_graph_input_usage` FX graph pass for that unsharded param.
|
|
""" # noqa: G004
|
|
)
|
|
return False
|
|
return True
|
|
|
|
# Find all eligible unsharded params and their corresponding graph intermediates.
|
|
unsharded_param_to_fsdp_copy_node_idxes = defaultdict(list)
|
|
for idx, node in enumerate(node_list):
|
|
if node.op == "call_function" and node.target == torch.ops.fsdp.copy_.default:
|
|
fsdp_copy_node = node
|
|
unsharded_param = node.args[0]
|
|
assert unsharded_param.op == "placeholder", f"""
|
|
Assumed all FSDP2 `unsharded_param`s to be graph inputs, but that's not true!
|
|
Offending node: {unsharded_param}. Graph: {graph}
|
|
"""
|
|
if check_resize_pattern(unsharded_param):
|
|
unsharded_param_to_fsdp_copy_node_idxes[unsharded_param].append(idx)
|
|
|
|
def is_allowed_mutation(node):
|
|
return (
|
|
node.target == torch.ops.fsdp.copy_.default
|
|
or node.target == torch.ops.inductor.resize_storage_bytes_.default
|
|
)
|
|
|
|
def is_node_mutating_unsharded_param_or_its_alias(node, unsharded_params):
|
|
# Check whether the node is mutating any of the unsharded params or their aliases.
|
|
mutated_arg_idxes = (
|
|
[
|
|
i
|
|
for i, x in enumerate(node.target._schema.arguments)
|
|
if x.alias_info is not None and x.alias_info.is_write
|
|
]
|
|
if isinstance(node.target, torch._ops.OpOverload)
|
|
else []
|
|
)
|
|
mutated_node_arg_storages = OrderedSet(
|
|
[
|
|
StorageWeakRef(node.args[i].meta["val"].untyped_storage())
|
|
for i in mutated_arg_idxes
|
|
]
|
|
)
|
|
storages_of_unsharded_params = OrderedSet(
|
|
[
|
|
StorageWeakRef(unsharded_param.meta["val"].untyped_storage())
|
|
for unsharded_param in unsharded_params
|
|
]
|
|
)
|
|
return len(mutated_node_arg_storages & storages_of_unsharded_params) > 0
|
|
|
|
# Check no user mutation on any unsharded_param
|
|
for node in node_list:
|
|
if (
|
|
node.op == "call_function"
|
|
and isinstance(node.target, torch._ops.OpOverload)
|
|
and node.target._schema.is_mutable
|
|
and not is_allowed_mutation(node)
|
|
):
|
|
assert not is_node_mutating_unsharded_param_or_its_alias(
|
|
node, unsharded_param_to_fsdp_copy_node_idxes.keys()
|
|
), f"""\
|
|
User mutation on FSDP2 unsharded param is not allowed when Traceable FSDP2 is used. Violating node: {node}
|
|
"""
|
|
|
|
# For each `fsdp.copy_(unsharded_param, Y)`, replace downstream usage of `unsharded_param` with `Y`.
|
|
#
|
|
# NOTE: Because of "layer reuse" use case, there could be multiple `fsdp.copy_` to the same `unsharded_param` graph input.
|
|
# e.g.
|
|
# ```
|
|
# fsdp_copy_1 = fsdp.copy_(unsharded_param_1, Y1)
|
|
# ... (use of unsharded_param_1) -> Subgraph 1
|
|
# fsdp_copy_2 = fsdp.copy_(unsharded_param_1, Y2)
|
|
# ... (use of unsharded_param_1) -> Subgraph 2
|
|
# fsdp_copy_3 = fsdp.copy_(unsharded_param_1, Y3)
|
|
# ... (use of unsharded_param_1) -> Subgraph 3
|
|
# ```
|
|
# We must do the replacement only within each subgraph.
|
|
for (
|
|
unsharded_param,
|
|
fsdp_copy_node_idxes,
|
|
) in unsharded_param_to_fsdp_copy_node_idxes.items():
|
|
for i, fsdp_copy_node_idx in enumerate(fsdp_copy_node_idxes):
|
|
fsdp_copy_node = node_list[fsdp_copy_node_idx]
|
|
assert fsdp_copy_node.args[0] is unsharded_param
|
|
_, replacement = fsdp_copy_node.args
|
|
# subgraph_start_idx is exclusive
|
|
subgraph_start_idx = fsdp_copy_node_idx + 1
|
|
# subgraph_end_idx is exclusive (also intentionally don't replace args in return op)
|
|
subgraph_end_idx = (
|
|
fsdp_copy_node_idxes[i + 1]
|
|
if i < len(fsdp_copy_node_idxes) - 1
|
|
else len(node_list) - 1
|
|
)
|
|
subgraph_nodes = node_list[subgraph_start_idx:subgraph_end_idx]
|
|
assert not any(
|
|
is_node_mutating_unsharded_param_or_its_alias(node, [unsharded_param])
|
|
for node in subgraph_nodes
|
|
), f"""\
|
|
Assumed no ops mutating unsharded param {unsharded_param} in subgraph {subgraph_nodes}, but it's not true!
|
|
Graph: {graph}
|
|
"""
|
|
for node in subgraph_nodes:
|
|
if (
|
|
node.op == "call_function"
|
|
and unsharded_param in node.args
|
|
and node.target != torch.ops.inductor.resize_storage_bytes_.default
|
|
): # TODO(yf225): implement replacement in kwargs
|
|
new_args = tuple(
|
|
replacement if arg is unsharded_param else arg
|
|
for arg in node.args
|
|
)
|
|
node.args = new_args
|
|
|
|
# Delete `fsdp.copy_(unsharded_param, Y)` nodes
|
|
for (
|
|
unsharded_param,
|
|
fsdp_copy_node_idxes,
|
|
) in unsharded_param_to_fsdp_copy_node_idxes.items():
|
|
for i, fsdp_copy_node_idx in enumerate(fsdp_copy_node_idxes):
|
|
fsdp_copy_node = node_list[fsdp_copy_node_idx]
|
|
graph.erase_node(fsdp_copy_node)
|
|
|
|
# Delete `resize_(unsharded_param, ...)` nodes
|
|
for node in node_list:
|
|
if (
|
|
node.op == "call_function"
|
|
and node.target == torch.ops.inductor.resize_storage_bytes_.default
|
|
and node.args[0] in unsharded_param_to_fsdp_copy_node_idxes
|
|
):
|
|
graph.erase_node(node)
|
|
|
|
|
|
def reinplace_fsdp_all_gather(graph: torch.fx.Graph) -> None:
|
|
try:
|
|
import torch.distributed.fsdp._fully_shard._fsdp_collectives
|
|
|
|
assert torch.distributed.is_available()
|
|
# Assert existence of these ops
|
|
assert (
|
|
torch.ops._c10d_functional.all_gather_into_tensor
|
|
and torch.ops._c10d_functional.all_gather_into_tensor_out
|
|
)
|
|
except (ImportError, AttributeError, AssertionError):
|
|
return
|
|
|
|
from .pattern_matcher import (
|
|
CallFunction,
|
|
KeywordArg,
|
|
Match,
|
|
PatternMatcherPass,
|
|
register_graph_pattern,
|
|
)
|
|
|
|
"""
|
|
all_gather_copy_in = torch.ops.fsdp.all_gather_copy_in.default(...);
|
|
getitem = all_gather_copy_in[0];
|
|
(getitem_1 = all_gather_copy_in[1];) # optional
|
|
|
|
all_gather_into_tensor = torch.ops._c10d_functional.all_gather_into_tensor.default(getitem, ...);
|
|
|
|
->
|
|
|
|
all_gather_copy_in = torch.ops.fsdp.all_gather_copy_in.default(...);
|
|
getitem = all_gather_copy_in[0];
|
|
getitem_1 = all_gather_copy_in[1];
|
|
|
|
all_gather_into_tensor = torch.ops._c10d_functional.all_gather_into_tensor_out.default(getitem, ..., out=getitem_1);
|
|
"""
|
|
|
|
def remove_unused_getitem(g):
|
|
# Remove `getitem_X = all_gather_copy_in[1]` which is never used.
|
|
node_list = list(g.nodes)
|
|
for n in node_list:
|
|
if (
|
|
n.target == operator.getitem
|
|
and n.args[0].target is torch.ops.fsdp.all_gather_copy_in.default
|
|
and n.args[1] == 1
|
|
):
|
|
g.erase_node(n)
|
|
|
|
graph_pass = PatternMatcherPass()
|
|
|
|
@register_graph_pattern(
|
|
CallFunction(
|
|
torch.ops._c10d_functional.all_gather_into_tensor.default,
|
|
CallFunction(
|
|
operator.getitem,
|
|
CallFunction(
|
|
torch.ops.fsdp.all_gather_copy_in.default,
|
|
KeywordArg("all_gather_inputs"),
|
|
KeywordArg("all_gather_output"),
|
|
KeywordArg("inp_split_sizes"),
|
|
KeywordArg("all_gather_input_numel"),
|
|
KeywordArg("rank"),
|
|
),
|
|
KeywordArg("item_idx"),
|
|
),
|
|
KeywordArg("group_size"),
|
|
KeywordArg("group_name"),
|
|
),
|
|
pass_dict=graph_pass,
|
|
extra_check=lambda match: match.kwargs["item_idx"] == 0,
|
|
)
|
|
def reinplace_all_gather(match: Match, *args, **kwargs):
|
|
def repl(
|
|
*args,
|
|
):
|
|
copy_in_args = args[:-2]
|
|
group_size = args[-2]
|
|
group_name = args[-1]
|
|
all_gather_copy_in = torch.ops.fsdp.all_gather_copy_in.default(
|
|
*copy_in_args
|
|
)
|
|
getitem = all_gather_copy_in[0]
|
|
getitem_1 = all_gather_copy_in[1]
|
|
all_gather_into_tensor = (
|
|
torch.ops._c10d_functional.all_gather_into_tensor_out.default(
|
|
getitem, group_size, group_name, out=getitem_1
|
|
)
|
|
)
|
|
return all_gather_into_tensor
|
|
|
|
match.replace_by_example(
|
|
repl,
|
|
[
|
|
kwargs["all_gather_inputs"],
|
|
kwargs["all_gather_output"],
|
|
kwargs["inp_split_sizes"],
|
|
kwargs["all_gather_input_numel"],
|
|
kwargs["rank"],
|
|
kwargs["group_size"],
|
|
kwargs["group_name"],
|
|
],
|
|
)
|
|
|
|
remove_unused_getitem(graph)
|
|
graph_pass.apply(graph) # type: ignore[arg-type]
|
|
|
|
|
|
def get_op_idx(snode):
|
|
assert not isinstance(
|
|
snode,
|
|
(
|
|
torch._inductor.scheduler.FusedSchedulerNode,
|
|
torch._inductor.scheduler.GroupedSchedulerNode,
|
|
),
|
|
)
|
|
return int(snode.get_name()[2:])
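# Example (assuming the usual "op<N>" scheduler node naming implied by the slice above):
# a node named "op12" yields 12, which is only used below to recover the original
# operation order when sorting related nodes.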
|
|
|
|
|
|
def enforce_comm_ordering_for_fsdp(
|
|
snodes: list[torch._inductor.scheduler.BaseSchedulerNode],
|
|
name_to_buf: dict[str, torch._inductor.scheduler.SchedulerBuffer],
|
|
name_to_fused_node: dict[str, BaseSchedulerNode],
|
|
) -> list[torch._inductor.scheduler.BaseSchedulerNode]:
|
|
from . import scheduler
|
|
|
|
new_order: list[BaseSchedulerNode] = []
|
|
scheduled = OrderedSet[Any]()
|
|
ag_exists = False
|
|
rs_exists = False
|
|
ag_grouped_node_to_wait_grouped_node = {}
|
|
rs_grouped_node_to_wait_grouped_node = {}
|
|
snode_name_to_final_snode = {}
|
|
|
|
def _create_group_node(snodes_to_group):
|
|
group_node = scheduler.GroupedSchedulerNode.create(snodes_to_group)
|
|
for snode in snodes_to_group:
|
|
snode_name_to_final_snode[snode.get_name()] = group_node
|
|
snode_name_to_final_snode[group_node.get_name()] = group_node
|
|
return group_node
|
|
|
|
# Create grouped nodes for specific sets of ops
|
|
for snode in snodes:
|
|
# Case 1: Handle AllGather
|
|
if is_collective(
|
|
snode.node, op=torch.ops._c10d_functional.all_gather_into_tensor_out.default
|
|
) and any(
|
|
is_fallback_op(
|
|
name_to_fused_node[x].node, torch.ops.fsdp.all_gather_copy_in.default
|
|
)
|
|
for x in snode.ancestors
|
|
):
|
|
ag_exists = True
|
|
ag_snode = snode
|
|
ag_related_snode_set: OrderedSet[scheduler.BaseSchedulerNode] = OrderedSet()
|
|
|
|
# Find the "cast + copy_in + getitem + all_gather" code block
|
|
find_recursive_deps_of_node(
|
|
ag_snode,
|
|
ag_related_snode_set,
|
|
name_to_buf,
|
|
name_to_fused_node,
|
|
)
|
|
|
|
# Find the "all_gather + all_gather_wait_tensor + copy_out" code block
|
|
allowed_ops = OrderedSet(
|
|
[
|
|
torch.ops._c10d_functional.all_gather_into_tensor_out.default,
|
|
torch.ops._c10d_functional.wait_tensor.default,
|
|
torch.ops.fsdp.split_with_sizes_copy.default,
|
|
]
|
|
)
|
|
find_recursive_users_of_node(
|
|
ag_snode,
|
|
ag_related_snode_set,
|
|
name_to_buf,
|
|
name_to_fused_node,
|
|
criteria_cb=lambda x: not (
|
|
isinstance(x, scheduler.NopKernelSchedulerNode)
|
|
or (
|
|
isinstance(x, scheduler.ExternKernelSchedulerNode)
|
|
and x.node.op_overload in allowed_ops # type: ignore[union-attr]
|
|
)
|
|
),
|
|
)
|
|
|
|
# sort nodes by original operation order
|
|
ag_related_snodes = sorted(
|
|
ag_related_snode_set, key=lambda x: get_op_idx(x)
|
|
)
|
|
|
|
# In the "reuse layer" case, some ops in the 2nd all-gather code block could also
|
|
# depend on ops in the 1st all-gather code block, and we don't want to group them together.
|
|
end_idx_of_current_ag_block = len(ag_related_snodes)
|
|
copy_out_count = 0
|
|
for i in range(len(ag_related_snodes)):
|
|
cur_snode = ag_related_snodes[i]
|
|
if is_fallback_op(
|
|
cur_snode.node, torch.ops.fsdp.split_with_sizes_copy.default
|
|
):
|
|
copy_out_count += 1
|
|
if copy_out_count > 1:
|
|
end_idx_of_current_ag_block = i
|
|
break
|
|
|
|
ag_related_snodes = ag_related_snodes[:end_idx_of_current_ag_block]
|
|
|
|
# Group "cast + copy_in + getitem + all_gather" into one GroupedSchedulerNode
|
|
wait_node_idx = None
|
|
for i in range(len(ag_related_snodes) - 1):
|
|
if isinstance(ag_related_snodes[i + 1].node, ir._WaitKernel):
|
|
wait_node_idx = i + 1
|
|
break
|
|
assert wait_node_idx is not None
|
|
ag_group_node = _create_group_node(ag_related_snodes[:wait_node_idx])
|
|
|
|
# Group "all_gather_wait_tensor + copy_out" into one GroupedSchedulerNode
|
|
ag_wait_group_node = _create_group_node(ag_related_snodes[wait_node_idx:])
|
|
|
|
ag_grouped_node_to_wait_grouped_node[ag_group_node] = ag_wait_group_node
|
|
|
|
# Case 2: Handle ReduceScatter
|
|
elif is_fallback_op(snode.node, torch.ops.fsdp.chunk_cat.default):
|
|
rs_exists = True
|
|
rs_snode = snode
|
|
|
|
# Find the "reduce_scatter copy-in + reduce_scatter comm + reduce_scatter wait" code block
|
|
rs_related_snode_set: OrderedSet[scheduler.BaseSchedulerNode] = OrderedSet()
|
|
find_recursive_users_of_node(
|
|
rs_snode,
|
|
rs_related_snode_set,
|
|
name_to_buf,
|
|
name_to_fused_node,
|
|
)
|
|
|
|
# sort nodes by original operation order
|
|
rs_related_snodes = sorted(
|
|
rs_related_snode_set, key=lambda x: get_op_idx(x)
|
|
)
|
|
|
|
# Group "reduce_scatter copy-in + reduce_scatter comm" into one GroupedSchedulerNode
|
|
wait_node_idx = None
|
|
for i in range(len(rs_related_snodes) - 1):
|
|
if isinstance(rs_related_snodes[i + 1].node, ir._WaitKernel):
|
|
wait_node_idx = i + 1
|
|
break
|
|
assert wait_node_idx is not None
|
|
rs_group_node = _create_group_node(rs_related_snodes[:wait_node_idx])
|
|
|
|
# Group "reduce_scatter wait + related output nodes" into one GroupedSchedulerNode
|
|
rs_wait_group_node = _create_group_node(rs_related_snodes[wait_node_idx:])
|
|
|
|
rs_grouped_node_to_wait_grouped_node[rs_group_node] = rs_wait_group_node
|
|
|
|
assert len(snode_name_to_final_snode) > 0
|
|
if ag_exists:
|
|
assert len(ag_grouped_node_to_wait_grouped_node) > 0
|
|
if rs_exists:
|
|
assert len(rs_grouped_node_to_wait_grouped_node) > 0
|
|
|
|
# Build the new node schedule, taking GroupedSchedulerNode into account
|
|
for snode in snodes:
|
|
if snode.get_name() in snode_name_to_final_snode:
|
|
snode = snode_name_to_final_snode[snode.get_name()]
|
|
if snode in scheduled:
|
|
continue
|
|
new_order.append(snode)
|
|
scheduled.add(snode)
|
|
|
|
# Enforce AllGather ordering: previous AllGather's "wait then copy_out" group node must run
|
|
# before next AllGather's "copy_in then AG" group node
|
|
prev_ag_wait = None
|
|
for ag_group_node, wait_group_node in ag_grouped_node_to_wait_grouped_node.items():
|
|
if prev_ag_wait is not None:
|
|
mutating_buf = next(iter(ag_group_node.get_buffer_names()))
|
|
for o in prev_ag_wait.get_outputs():
|
|
ag_group_node.add_fake_dep(
|
|
WeakDep(o.get_name(), mutating_buf=mutating_buf, is_fake=True)
|
|
)
|
|
prev_ag_wait = wait_group_node
|
|
|
|
# Enforce ReduceScatter ordering: previous ReduceScatter's "wait" group node must run
|
|
# before next ReduceScatter's "copy_in then RS" group node
|
|
prev_rs_wait = None
|
|
for rs_group_node, wait_group_node in rs_grouped_node_to_wait_grouped_node.items():
|
|
if prev_rs_wait is not None:
|
|
mutating_buf = next(iter(rs_group_node.get_buffer_names()))
|
|
for o in prev_rs_wait.get_outputs():
|
|
rs_group_node.add_fake_dep(
|
|
WeakDep(o.get_name(), mutating_buf=mutating_buf, is_fake=True)
|
|
)
|
|
prev_rs_wait = wait_group_node
|
|
|
|
return new_order # type: ignore[return-value]
|