Overlap scheduler improvements (#165318)
Bucketing a number of smallish improvements:

- Account for bucketing in the overlap calculation: if an in-flight collective exists with the same bucket key, reduce the new collective's estimated time by its latency.
- Update compute domination so that we order based on compute index, as opposed to compute depth, so we never reorder compute. This makes it a bit easier to reason about memory and pre-fetching, although we could explore reordering in the future.
- When we wait on a collective, force all collectives on the same process group that were enqueued prior to it to wait as well.

Better memory handling:

- Pre-fetch limiting: when scheduling collectives for overlap, only pre-fetch up to a certain distance, then schedule off-path collectives (which are typically memory reducing). A small sketch of these guards follows the commit metadata below.
- When we are above peak memory, schedule waits.

TODO:

- For each compute node, we know its original memory in the graph; we could limit pre-fetching that crosses peak memory.
- By scheduling off-path collectives for overlap, we reduce memory, but if there weren't enough compute for overlap, we would need to proactively schedule them. Not an issue yet on the examples.
- Make the hard-coded constants configurable and clean up enablement (can be done in a subsequent PR).

Results on small llama 2D, backward: 578 of 618 potentially hideable collectives hidden; original mem 14.4 GB, rescheduled mem 15.9 GB. Forward: 254/256 potentially hideable collectives hidden; original mem 5.8 GB, rescheduled mem 5.8 GB.

WIP: adding tests.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165318
Approved by: https://github.com/ezyang, https://github.com/IvanKobzarev
ghstack dependencies: #164738, #164783, #164944, #164945, #165059
Committed by: PyTorch MergeBot
Parent: bc1f2108d7
Commit: b3f6d49b69
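The two memory guards described in the message are compact enough to sketch. Below is a minimal, self-contained illustration (not the real implementation) of the pre-fetch window and the force-wait check; plain integers stand in for fx nodes, and the constants mirror names in the diff but the values are illustrative:

import sys

MAX_COMPUTE_PRE_FETCH = 5  # mirrors max_compute_pre_fetch in the diff
GB = 1024 * 1024 * 1024

def may_prefetch(coll_domination_index: int, current_compute_index: int) -> bool:
    """Off-path collectives (domination index == sys.maxsize) are always
    schedulable; on-path ones only within the pre-fetch window."""
    if coll_domination_index == sys.maxsize:  # off the compute path
        return True
    return (coll_domination_index - current_compute_index) <= MAX_COMPUTE_PRE_FETCH

def should_force_wait(in_flight_bytes: int, max_in_flight_bytes: int,
                      current_mem: int, original_peak: int) -> bool:
    """Force the oldest wait when too much is in flight, or when we exceed
    the original graph's peak memory by more than 1 GB."""
    return in_flight_bytes >= max_in_flight_bytes or (current_mem - original_peak) > GB

# A collective dominating compute #12 is not pre-fetched at compute #3.
assert not may_prefetch(12, 3)
assert should_force_wait(int(0.6 * GB), int(0.5 * GB), 0, 0)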
@@ -70,6 +70,8 @@ def get_patches():
         "force_disable_caches": True,
         # Messes up existing test strings
         "test_configs.aten_fx_overlap_insert_overlap_deps": False,
+        # interferes with testing / custom estimation
+        "test_configs.assume_bucketing_reduces_latency": False,
     }
@@ -364,6 +366,8 @@ def get_bucket_patches(compute_multiplier=1.0):
         "force_disable_caches": True,
         # messes up test strings
         "test_configs.aten_fx_overlap_insert_overlap_deps": False,
+        # interferes with testing / custom estimation
+        "test_configs.assume_bucketing_reduces_latency": False,
     }
@@ -579,7 +583,7 @@ class TestComputeCommReorderingBucketing(TestComputeCommReorderingMultiProc):
     @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
     @torch._inductor.config.patch(get_bucket_patches(2.0))
-    def test_bucketing_split_for_overlap_blocking(self):
+    def test_bucketing_split_for_overlap_blocking_no_deps(self):
        """Test that 4 independent all-gathers split into 2+2 buckets for better overlap with compute."""

        def func(a, b, c, d, *, ranks):
@@ -252,11 +252,6 @@ class TestMemoryTracker(InductorTestCase):
             if node.op not in ("placeholder", "get_attr", "output")
         ]

-        if len(compute_nodes) < 3:
-            self.skipTest(
-                f"Need at least 3 compute nodes, got {len(compute_nodes)}"
-            )
-
         # Test original order: zeros_like, add, sum
         # zeros gets freed after sum (last use of zeros)
         memory_tracker1 = MemoryTracker(fx_graph.graph, device_filter=device_filter)
@@ -356,7 +356,9 @@ def estimate_fx_collective_size(fx_node: torch.fx.Node) -> int:
     return size


-def estimate_nccl_collective_runtime_from_fx_node(fx_node: torch.fx.Node) -> float:
+def estimate_nccl_collective_runtime_from_fx_node(
+    fx_node: torch.fx.Node, override_size: Optional[int] = None
+) -> float:
     """
     Returns estimated NCCL collective runtime in nanoseconds (ns).
@@ -371,7 +373,10 @@ def estimate_nccl_collective_runtime_from_fx_node(
     """
    from torch.distributed.distributed_c10d import _get_group_size_by_name

-    tensor_storage_size_bytes = estimate_fx_collective_size(fx_node)
+    if override_size is None:
+        tensor_storage_size_bytes = estimate_fx_collective_size(fx_node)
+    else:
+        tensor_storage_size_bytes = override_size

    assert not isinstance(fx_node.target, str)
    opt_args_kwargs = normalize_function(
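Why an override_size hook: NCCL-style cost models decompose runtime into a fixed latency term plus a bandwidth term, so estimating at size 0 isolates the latency that bucketing can save. A toy cost model showing the idea (the constants are illustrative, not NCCL's real parameters):

LATENCY_NS = 20_000           # assumed fixed per-collective latency
BANDWIDTH_BYTES_PER_NS = 50   # assumed effective bus bandwidth

def estimate_runtime_ns(size_bytes: int) -> float:
    # runtime ~= fixed latency + transfer time
    return LATENCY_NS + size_bytes / BANDWIDTH_BYTES_PER_NS

full = estimate_runtime_ns(32 * 1024 * 1024)  # full message cost
latency_only = estimate_runtime_ns(0)         # the override_size=0 path
assert latency_only == LATENCY_NS             # size-0 estimate = pure latency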
@@ -2072,6 +2072,9 @@ class test_configs:
     # to be migrated when ready for use
     aten_fx_overlap_preserving_bucketing = False

+    # mostly disabled in testing
+    assume_bucketing_reduces_latency = True
+
     # to be migrated when ready for use
     # runtime estimation function for ops
     # for user-defined estimation function, pass in the function handle
@@ -34,6 +34,15 @@ def _rs_group_key(node: torch.fx.Node) -> tuple[str, str, torch.dtype]:
     return (group_name, reduce_op, dtype)


+def bucket_key(node: torch.fx.Node) -> Optional[object]:
+    if is_all_gather_into_tensor(node):
+        return _ag_group_key(node)
+    elif is_reduce_scatter_tensor(node):
+        return _rs_group_key(node)
+    else:
+        return None
+
+
 def bucket_cap_mb_by_bucket_idx_default(bucket_id: int) -> float:
     """
     Determine the size of a bucket based on its ID.
@@ -1,12 +1,10 @@
 from collections import defaultdict
-from typing import Optional

 import torch
 import torch.fx as fx
 from torch._inductor.augmented_graph_helper import AugmentedGraphHelper
 from torch._inductor.fx_passes.bucketing import (
-    _ag_group_key,
-    _rs_group_key,
+    bucket_key,
     is_all_gather_into_tensor as is_all_gather,
     is_reduce_scatter_tensor as is_reduce_scatter,
     is_wait_tensor,
@@ -15,15 +13,6 @@ from torch._inductor.fx_passes.overlap_scheduling import CollBucket, CollectiveInfo
 from torch.utils._ordered_set import OrderedSet


-def bucket_key(node: torch.fx.Node) -> Optional[object]:
-    if is_all_gather(node):
-        return _ag_group_key(node)
-    elif is_reduce_scatter(node):
-        return _rs_group_key(node)
-    else:
-        return None
-
-
 class OverlapPreservingBucketer:
     """
     Buckets collective operations while preserving compute-collective overlap relationships.
@@ -17,15 +17,31 @@ from torch._inductor.fx_passes.memory_estimator import (
     build_memory_profile,
     MemoryTracker,
 )
 from torch.fx.operator_schemas import normalize_function
 from torch.utils._mode_utils import no_dispatch
 from torch.utils._ordered_set import OrderedSet


 log = logging.getLogger(__name__)

+from torch._inductor.fx_passes.bucketing import bucket_key
+
 from ..pattern_matcher import stable_topological_sort


+def get_group_name(n: fx.Node) -> str:
+    """Extract the group name from a collective operation node."""
+    opt_args_kwargs = normalize_function(
+        n.target,  # type: ignore[arg-type]
+        args=n.args,
+        kwargs=n.kwargs,
+        normalize_to_only_use_kwargs=True,
+    )
+    assert opt_args_kwargs is not None
+    _, kwargs = opt_args_kwargs
+    return kwargs["group_name"]
+
+
 def get_custom_estimation(n: fx.Node) -> Optional[float]:
     runtime_estimation = torch._inductor.config.test_configs.estimate_aten_runtime
     if runtime_estimation == "default":
@@ -35,12 +51,13 @@ def get_custom_estimation(n: fx.Node) -> Optional[float]:
     return runtime_estimation(n)


-def estimate_collective_time(n: fx.Node) -> float:
+def estimate_collective_time(n: fx.Node, override_size: Optional[int] = None) -> float:
+    """Estimate the runtime of a collective operation, optionally with an overridden size."""
     if (est := get_custom_estimation(n)) is not None:
         return est

     return torch._inductor.comm_analysis.estimate_nccl_collective_runtime_from_fx_node(
-        n
+        n, override_size
     )
@@ -55,6 +72,10 @@ def estimate_fx_collective_size(fx_node: torch.fx.Node) -> int:


 def is_compute_node(n: fx.Node) -> bool:
+    """
+    Should we consider this node computationally expensive?
+    Currently uses flop registration, but we could expand more generally.
+    """
     return (
         getattr(n.target, "overloadpacket", None)
         in torch.utils.flop_counter.flop_registry
@@ -173,6 +194,11 @@ class CollBucket:
     total_bytes: int = 0


+def gb_to_bytes(gb: float) -> int:
+    """Convert gigabytes to bytes."""
+    return int(gb * 1024 * 1024 * 1024)
+
+
 class OverlapScheduler:
     """
     Scheduler that reorders operations to maximize compute-collective overlap.
@@ -180,7 +206,7 @@ class OverlapScheduler:
     The reordering is done as a scheduling pass. We maintain a priority queue of
     schedulable nodes. The nodes are ranked by:

-    1) the compute node depth they dominate. this allows reordering locally, such as with
+    1) the compute node index they dominate. this allows reordering locally, such as with
     parallel mms, and also allows overlapping reduce scatter nodes outputs in the backward
     with compute by deferring their waits.
@@ -200,15 +226,16 @@ class OverlapScheduler:
     def __init__(
         self,
         gm: torch.fx.GraphModule,
-        max_in_flight_gb: float = 2.0,
+        max_in_flight_gb: float = 0.5,
         compute_overlap_multipler: float = 2.0,
         max_coll_distance: int = 1000,
+        max_compute_pre_fetch: int = 5,
     ):
         self.gm = gm
         self.graph = gm.graph
         self.compute_overlap_multipler = compute_overlap_multipler
         self.max_node_distance = max_coll_distance
-        self.max_in_flight_bytes: int = int(max_in_flight_gb * 1024 * 1024 * 1024)
+        self.max_in_flight_bytes: int = gb_to_bytes(max_in_flight_gb)

         # Build structures
         stable_topological_sort(self.graph)
@@ -231,8 +258,9 @@ class OverlapScheduler:
         self.wait_to_start: dict[fx.Node, fx.Node] = {}
         self._identify_collectives()

-        self.compute_depth = self._calculate_compute_node_depth()
+        self.compute_index_domination = self._calculate_compute_node_domination_index()
         self.compute_nodes = [n for n in self.nodes if is_compute_node(n)]
+        self.current_compute_index = 0

         # Scheduling state
         self.potentially_hidden_collectives = (
@@ -249,6 +277,7 @@ class OverlapScheduler:
         self.in_flight: dict[fx.Node, CollectiveInfo] = {}  # start -> info
         self.in_flight_bytes = 0
         self.scheduled: OrderedSet[fx.Node] = OrderedSet()
+        self.max_compute_pre_fetch = max_compute_pre_fetch

     def _collect_node_ancestors(self) -> dict[fx.Node, OrderedSet[fx.Node]]:
         """Collect all ancestors for each node."""
@@ -260,6 +289,10 @@ class OverlapScheduler:

         return ancestors

+    def off_compute_path(self, n: fx.Node) -> bool:
+        """Check if a node is off the compute path (doesn't block any compute)."""
+        return self.compute_index_domination[n] == sys.maxsize
+
     def _identify_collectives(self) -> None:
         """Identify all collective operations."""
         for node in self.nodes:
@@ -278,51 +311,30 @@ class OverlapScheduler:
             self.wait_to_start[node] = start
             self.unscheduled_collectives.add(start)

-    def _calculate_compute_node_depth(self) -> dict[fx.Node, int]:
-        """Compute forward depth and minimum dominance depth (infinity if blocks no compute)."""
-
-        # First pass: forward compute depth
-        in_degree: dict[fx.Node, int] = {}
-        compute_depth: dict[fx.Node, int] = {}
-        queue: list[fx.Node] = []
-        for node in self.graph.nodes:
-            num_inputs = len(node.all_input_nodes)
-            if num_inputs == 0:
-                queue.append(node)
-            else:
-                in_degree[node] = num_inputs
-
-        while queue:
-            node = queue.pop()
-
-            max_input_depth = max(
-                (compute_depth[inp] for inp in node.all_input_nodes), default=0
-            )
-            compute_depth[node] = max_input_depth + is_compute_node(node)
-
-            for use in node.users:
-                in_degree[use] -= 1
-                if in_degree[use] == 0:
-                    queue.append(use)
-
-        # Second pass: minimum dominance (what's the earliest compute this blocks)
-        compute_depth_dominance: dict[fx.Node, int] = {}
-
-        for node in reversed(self.graph.nodes):
-            if is_compute_node(node):
-                # consider compute nodes to be at their own depth
-                dominance = compute_depth[node]
-            else:
-                # For non-compute nodes, find minimum compute they block
-                dominance = min(
-                    (compute_depth_dominance[succ] for succ in node.users),
-                    default=sys.maxsize,
-                )
-
-            compute_depth_dominance[node] = dominance
-
-        return compute_depth_dominance
+    def _calculate_compute_node_domination_index(self) -> dict[fx.Node, int]:
+        """
+        Compute the topological index of the earliest compute node each node dominates.
+
+        Compute nodes are assigned indices based on their topological order (0, 1, 2, ...).
+        For each node, returns the minimum index of compute nodes it blocks/dominates.
+        Returns sys.maxsize if the node doesn't block any compute nodes.
+        """
+        compute_node_index: dict[fx.Node, int] = {}
+        for node in self.graph.nodes:
+            if is_compute_node(node):
+                compute_node_index[node] = len(compute_node_index)
+
+        domination_index: dict[fx.Node, int] = {}
+        for node in reversed(self.graph.nodes):
+            if node in compute_node_index:
+                # Compute nodes dominate themselves (return their own index)
+                domination_index[node] = compute_node_index[node]
+            else:
+                # For non-compute nodes, find minimum compute they block
+                domination_index[node] = min(
+                    (domination_index[succ] for succ in node.users), default=sys.maxsize
+                )
+
+        return domination_index

     def _align_compute_nodes_runtime_estimations_across_all_distributed_ranks(
         self,
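The new pass is just a reverse-topological min-propagation. A self-contained sketch on a toy graph, where plain strings and dicts stand in for fx nodes and their users edges (names are made up):

import sys

topo_order = ["ag_start", "mm0", "ag_wait", "mm1", "rs_start"]
users = {
    "ag_start": ["ag_wait"],
    "mm0": ["mm1"],
    "ag_wait": ["mm1"],
    "mm1": [],
    "rs_start": [],  # blocks no compute: off the compute path
}
compute_nodes = {"mm0": 0, "mm1": 1}  # compute indices in topological order

domination_index = {}
for node in reversed(topo_order):
    if node in compute_nodes:
        domination_index[node] = compute_nodes[node]  # dominates itself
    else:
        domination_index[node] = min(
            (domination_index[u] for u in users[node]), default=sys.maxsize
        )

assert domination_index["ag_start"] == 1            # blocks mm1 via its wait
assert domination_index["rs_start"] == sys.maxsize  # off-path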
@@ -435,6 +447,13 @@ class OverlapScheduler:
         self.scheduled.add(node)
         self.memory_tracker.schedule_node(node)

+        log.debug(
+            "Scheduled node %s: current_memory=%d bytes, total_scheduled=%d",
+            node.name,
+            self.memory_tracker.get_current_memory_bytes(),
+            len(self.scheduled),
+        )
+
         for user in node.users:
             self.in_degree[user] -= 1
             if self.in_degree[user] == 0:
@@ -445,15 +464,20 @@ class OverlapScheduler:

         if is_wait_tensor(node):
             info = self.collective_info[self.wait_to_start[node]]
-            # TODO: we could consider even deferring waits that are not potentially hidden
-            # so as to overlap comm with itself. although exposed comms should bucketed with each other.
-            overlappable = info.is_exposed and node in self.potentially_hidden_waits
+            # defer waits locally if they are exposed.
+            compute_local_priority = int(info.is_exposed)
         else:
-            overlappable = self.in_overlappable_collective_unary_chain(node)
+            # if we're scheduling this collective via its queue, then it was not
+            # pre-fetched. we might as well maximize overlap for the
+            # local, non-mm nodes prior to the next compute node.
+            if self.in_overlappable_collective_unary_chain(node):
+                compute_local_priority = -1
+            else:
+                compute_local_priority = 0

         return (
-            self.compute_depth[node],  # what depth compute it blocks
-            overlappable,  # Defer hideable collective ops
+            self.compute_index_domination[node],  # what index compute it blocks
+            compute_local_priority,  # collective_start=-1, wait=1, or neither=0
             self.node_idx[node],  # Original order for stability
         )
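Since Python tuples compare lexicographically, pushing these tuples into a heap yields exactly this ranking. A toy illustration (node names are made up; a trailing name field is appended for display only):

import sys
import heapq

# (compute_index_domination, compute_local_priority, node_idx, name)
heap = [
    (3, 0, 10, "mm_dep"),
    (3, -1, 12, "ag_start"),              # hideable collective: wins the tie
    (sys.maxsize, 1, 5, "exposed_wait"),  # exposed off-path wait: deferred
]
heapq.heapify(heap)
assert heapq.heappop(heap)[3] == "ag_start"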
@@ -473,7 +497,7 @@ class OverlapScheduler:
                 return False

             if user in self.unscheduled_collectives:
-                return user in self.potentially_hidden_collectives
+                return True

             if not self.is_cheap_fn(user):
                 return False
@@ -484,7 +508,11 @@ class OverlapScheduler:

     def _should_force_wait_for_memory(self) -> bool:
         """Check if we need to force a wait due to memory pressure"""
-        return self.in_flight_bytes >= self.max_in_flight_bytes
+        if not self.in_flight:
+            return False
+        return self.in_flight_bytes >= self.max_in_flight_bytes or (
+            self.memory_tracker.current_memory_bytes - self.original_peak_memory
+        ) > gb_to_bytes(1.0)

     def _force_oldest_wait(self) -> None:
         """Schedule the oldest in flight wait"""
@@ -493,6 +521,12 @@ class OverlapScheduler:
     def _handle_collective_start(self, node: fx.Node) -> None:
         """Handle scheduling a collective start."""
         info = self.collective_info[node]
+
+        if self.should_assume_bucketed(node):
+            latency = estimate_collective_time(node, 0)
+            assert latency <= info.exposed_time_ms
+            info.exposed_time_ms = info.exposed_time_ms - latency
+
         self.in_flight[node] = info
         self.in_flight_bytes += info.size_bytes
         self.unscheduled_collectives.discard(node)
@@ -502,8 +536,22 @@ class OverlapScheduler:
         """Handle scheduling a wait."""
         assert node in self.wait_to_start
         coll_start = self.wait_to_start[node]
-
         assert coll_start in self.in_flight
+
+        # Scheduling a wait of a collective also forces the wait
+        # of every node enqueued prior to the collective on the
+        # same process group
+        group_name = get_group_name(coll_start)
+        to_schedule: list[fx.Node] = []
+        for in_flight_coll in self.in_flight:
+            if in_flight_coll == coll_start:
+                break
+            if get_group_name(in_flight_coll) == group_name:
+                to_schedule.append(in_flight_coll)
+
+        for coll_to_schedule in to_schedule:
+            self._handle_wait(self.collective_info[coll_to_schedule].wait_node)
+
         self.in_flight_bytes -= self.in_flight[coll_start].size_bytes
         del self.in_flight[coll_start]
         self._schedule(node)
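The forced-wait rule leans on the fact that collectives on one process group complete in enqueue order, so waiting on one implies everything enqueued earlier on the same group has also finished. A sketch of that bookkeeping, with plain (name, group) tuples standing in for in-flight collectives (names are made up):

in_flight = [("ag0", "pg0"), ("rs0", "pg1"), ("ag1", "pg0")]

def forced_waits(target: str, target_group: str) -> list:
    """Earlier in-flight collectives on the same group that a wait on
    `target` also completes."""
    forced = []
    for name, group in in_flight:
        if name == target:
            break
        if group == target_group:
            forced.append(name)
    return forced

# Waiting on ag1 (pg0) forces ag0 (also pg0) but not rs0 (pg1).
assert forced_waits("ag1", "pg0") == ["ag0"]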
@@ -514,6 +562,7 @@ class OverlapScheduler:
         compute_time = benchmark_node(node)
         available_compute = compute_time * self.compute_overlap_multipler

+        # TODO: separate overlap time per process group
         # First reduce exposed time of in-flight collectives
         for info in self.in_flight.values():
             if info.exposed_time_ms == 0:
@@ -531,6 +580,7 @@ class OverlapScheduler:
         self._schedule_collectives_for_overlap(node, available_compute)

         self._schedule(node)
+        self.current_compute_index += 1

     def _schedule_collectives_for_overlap(
         self, compute_node: fx.Node, available_compute_time: float
@@ -538,15 +588,39 @@ class OverlapScheduler:
         """Opportunistically schedule collectives that can be hidden by compute."""
         compute_ancestors = self.node_ancestors[compute_node]

-        # copy unscheduled_collectives to local because we modify it during iteration
+        # Filter collectives by distance and compute index domination
+        possible_collectives = []
         for collective in self.unscheduled_collectives:
             distance = abs(self.node_idx[compute_node] - self.node_idx[collective])
             if distance > self.max_node_distance:
                 break

+            # Skip collectives that are too far ahead in compute index, but allow scheduling
+            # collectives which are off compute path (which typically release memory)
+            # TODO: we could potentially be more strict about limiting the amount of
+            # pre-fetched memory before memory peak, and adjust allowed collective mem.
+            if not self.off_compute_path(collective):
+                if (
+                    self.compute_index_domination[collective]
+                    - self.current_compute_index
+                ) > self.max_compute_pre_fetch:
+                    continue
+
+            possible_collectives.append(collective)
+
+        possible_collectives = sorted(
+            possible_collectives,
+            key=lambda n: (self.compute_index_domination[n], self.node_idx[n]),
+        )
+
+        log.debug(
+            "Scheduling collectives for overlap: compute_node=%s, available_time=%.2f ms, candidates=%d, current_memory=%d bytes",
+            compute_node.name,
+            available_compute_time,
+            len(possible_collectives),
+            self.memory_tracker.current_memory_bytes,
+        )
+
+        for collective in possible_collectives:
             if available_compute_time == 0:
                 break
@@ -575,15 +649,25 @@ class OverlapScheduler:
             if path is None:
                 continue

+            log.debug(
+                "Overlapping collective %s with compute %s: coll_domination=%d, current_depth=%d",
+                collective.name,
+                compute_node.name,
+                self.compute_index_domination[collective],
+                self.current_compute_index,
+            )
+
             # Schedule path to this collective
             self._schedule_path_to_collective(path, compute_node)
+            self._handle_collective_start(collective)

             # Update the exposed time for this newly scheduled collective
-            overlap_amount = min(info.estimated_time_ms, available_compute_time)
+            # after scheduling, which will account for latency reduction of bucketing
+            overlap_amount = min(available_compute_time, info.exposed_time_ms)
             info.exposed_time_ms -= overlap_amount
             if info.exposed_time_ms == 0:
                 info.hiding_node = compute_node
             available_compute_time -= overlap_amount
-            self._handle_collective_start(collective)

     def _find_schedulable_path(
         self, target: fx.Node, curr_compute_node: Optional[fx.Node]
@@ -618,6 +702,24 @@ class OverlapScheduler:

         return unscheduled_ancestors

+    def should_assume_bucketed(self, node: fx.Node) -> bool:
+        """
+        Check if there's an in-flight collective that can be bucketed with the given node. If so, assume they will bucket.
+        This is an optimistic heuristic to account for latency reduction with bucketing. The two nodes may not get bucketed.
+        """
+        if not torch._inductor.config.test_configs.assume_bucketing_reduces_latency:
+            return False
+
+        key = bucket_key(node)
+        if key is None:
+            return False
+
+        for in_flight_coll in self.in_flight.keys():
+            if bucket_key(in_flight_coll) == key:
+                return True
+
+        return False
+
     def _get_oldest_wait(self) -> fx.Node:
         oldest_start = next(iter(self.in_flight))
         return self.collective_info[oldest_start].wait_node
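Putting should_assume_bucketed together with the size-0 estimate in _handle_collective_start: when a same-key collective is already in flight, one latency term is optimistically shaved off the new collective's exposed time. A sketch with illustrative numbers (the tuple key and timings are made up, standing in for bucket_key() results):

LATENCY_MS = 0.02  # assumed fixed latency, i.e. estimate_collective_time(node, 0)

in_flight_keys = {("ag", "pg0", "bf16")}  # bucket keys of in-flight collectives
new_key = ("ag", "pg0", "bf16")
exposed_time_ms = 0.15                    # estimated exposed time before discount

if new_key in in_flight_keys:             # the should_assume_bucketed check
    exposed_time_ms -= LATENCY_MS         # optimistic bucketing discount

assert abs(exposed_time_ms - 0.13) < 1e-9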
@@ -633,10 +735,18 @@ class OverlapScheduler:
         self, path: OrderedSet[fx.Node], curr_compute_node: fx.Node
     ) -> None:
         """Schedule all nodes needed to reach a collective."""
+
+        assert all(n not in self.scheduled for n in path)
         for node in sorted(path, key=lambda n: self.node_idx[n]):
             assert not (is_compute_node(node) or node in self.unscheduled_collectives)

             if is_wait_tensor(node):
+                # When we schedule wait tensors, we also force realization of all
+                # collectives enqueued prior to their corresponding collective.
+                # It's possible the scheduling of one wait tensor here has forced
+                # another in the path. If so, skip scheduling it.
+                if node in self.scheduled:
+                    continue
+
                 info = self.collective_info[self.wait_to_start[node]]
                 assert info.hiding_node != curr_compute_node
                 self._handle_wait(node)
@@ -672,15 +782,17 @@ class OverlapScheduler:
         counters["inductor"]["overlap_scheduling_potentially_hidden"] += len(
             potentially_hidden_collectives
         )
+        counters["inductor"]["overlap_original_mem"] = self.original_peak_memory
+        counters["inductor"]["rescheduled_mem"] = self.memory_tracker.peak_memory

         log.info(
-            "Overlap scheduling: total exposed %s, total bad exposed %s, total potentially hidden %s",
+            "Overlap scheduling results: exposed=%d, bad_exposed=%d, potentially_hidden=%d, "
+            "original_peak_memory=%d bytes, rescheduled_peak_memory=%d bytes",
             len(exposed),
             len(bad_exposed),
             len(potentially_hidden_collectives),
+            self.original_peak_memory,
+            self.memory_tracker.peak_memory,
         )

         self.reorder_graph()