Compare commits

...

2 Commits

Author  SHA1        Message                                        Date
        bd04dc7750  small changes [ghstack-poisoned]               2025-11-14 10:30:28 -08:00
        3b90bf36f9  Add multiple hiding nodes [ghstack-poisoned]   2025-11-14 09:04:39 -08:00
4 changed files with 236 additions and 58 deletions

View File

@@ -29,7 +29,7 @@ from torch.testing._internal.common_utils import skipIfRocm
 from torch.testing._internal.inductor_utils import HAS_GPU


-def estimate_aten_runtime(fx_node, compute_multiplier=1.0):
+def estimate_aten_runtime(fx_node, override_size=None, compute_multiplier=1.0):
     # for tests, assume a matmul can hide a single collective
     if "c10" in str(fx_node.target):
         return 1.0
@@ -1016,6 +1016,63 @@ class TestComputeCommReorderingBucketing(TestComputeCommReorderingMultiProc):
             correct = func(a, b, c)
             self.assertTrue(same(out, correct))

+    @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
+    @torch._inductor.config.patch(get_bucket_patches())
+    def test_multiple_hiding_nodes_bucketing(self):
+        """Test that collectives hidden by multiple compute ops can bucket together."""
+
+        # Use 0.5 compute multiplier so each collective needs 2 matmuls to be fully hidden
+        def estimate_with_half_compute(fx_node, override_size=None):
+            return estimate_aten_runtime(fx_node, override_size, compute_multiplier=0.5)
+
+        def func(a, b, *, ranks):
+            # Two all_gathers that will be hidden by multiple compute operations
+            ag1 = _functional_collectives.all_gather_tensor(a, 0, ranks)
+            ag2 = _functional_collectives.all_gather_tensor(b, 0, ranks)
+
+            # Multiple compute operations that can hide the collectives
+            # With 0.5 multiplier: mm1 and mm2 together hide ag1, mm2 and mm3 together hide ag2
+            mm1 = torch.matmul(a, a.T)
+            mm2 = torch.matmul(b, b.T)
+            mm3 = torch.matmul(a + b, (a + b).T)
+            return ag1.sum() + ag2.sum() + mm1.sum() + mm2.sum() + mm3.sum()
+
+        with _dynamo_dist_per_rank_init(
+            self.rank,
+            self.world_size,
+            self.backend(device_type),
+            fake_pg=not at_least_x_gpu(2),
+        ):
+            a = torch.ones(8, 8, dtype=torch.float, device=device_type)
+            b = torch.ones(8, 8, dtype=torch.float, device=device_type) * 2
+            ranks = list(range(self.world_size))
+
+            func_c = functools.partial(func, ranks=ranks)
+
+            # Patch with custom estimation that uses 0.5 multiplier
+            with torch._inductor.config.patch(
+                {
+                    "aten_distributed_optimizations.custom_runtime_estimation": estimate_with_half_compute
+                }
+            ):
+                compiled = torch.compile(func_c)
+                out, aten_graph_str = run_and_get_aten_graph(compiled, a, b)
+
+            # Should have 1 bucketed all_gather (both ag1 and ag2 bucketed together)
+            FileCheck().check_count(
+                "torch.ops._c10d_functional.wait_tensor.default", 1, exactly=True
+            ).run(aten_graph_str)
+
+            # Verify bucketed collective is scheduled before all matmuls
+            FileCheck().check("functional.all_gather_into_tensor").check(
+                "aten.mm"
+            ).check("aten.mm").check("aten.mm").check("wait_tensor").run(aten_graph_str)
+
+            # Verify correctness
+            correct = func(a, b, ranks=ranks)
+            self.assertTrue(same(out, correct))
+

 if __name__ == "__main__":
     from torch._dynamo.test_case import run_tests
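Note on the hiding arithmetic this test relies on: with compute_multiplier=0.5, each matmul is estimated at half the cost of a collective, so a collective's exposed time reaches zero only after two matmuls overlap it. A minimal sketch of that bookkeeping (illustrative values, not measured costs):

    # each collective is estimated at 1.0; each matmul at 0.5 (multiplier 0.5)
    exposed_ms = 1.0
    for mm_ms in (0.5, 0.5):  # e.g. mm1 and mm2 overlapping ag1
        exposed_ms -= min(exposed_ms, mm_ms)
    assert exposed_ms == 0.0  # fully hidden only after the second matmul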

View File

@@ -49,7 +49,8 @@ def build_collective_info(graph, hiding_annotations):
     """
     Build CollectiveInfo dict from manual hiding annotations.

-    hiding_annotations: dict mapping collective_start -> hiding_compute_node
+    hiding_annotations: dict mapping collective_start -> hiding_compute_node(s)
+        Can be a single node or a list/OrderedSet of nodes
     """
     from torch._inductor.fx_passes.overlap_scheduling import CollectiveInfo
@@ -65,12 +66,20 @@
     # Build CollectiveInfo for each collective
     for start_node, wait_node in start_to_wait.items():
-        hiding_node = hiding_annotations.get(start_node)
+        hiding_annotation = hiding_annotations.get(start_node)
+
+        # Convert to OrderedSet
+        hiding_nodes = OrderedSet()
+        if hiding_annotation is not None:
+            if isinstance(hiding_annotation, list | OrderedSet):
+                hiding_nodes = OrderedSet(hiding_annotation)
+            else:
+                hiding_nodes = OrderedSet([hiding_annotation])

         # Estimate size and time
         size_bytes = 16 * 4  # 4x4 tensor of floats
         estimated_time_ms = 1.0  # Dummy time
-        exposed_time_ms = 0.0 if hiding_node else 1.0  # Hidden if has hiding_node
+        exposed_time_ms = 0.0 if hiding_nodes else 1.0  # Hidden if has hiding_nodes

         collective_info[start_node] = CollectiveInfo(
             start_node=start_node,
@@ -78,7 +87,7 @@
             size_bytes=size_bytes,
             estimated_time_ms=estimated_time_ms,
             exposed_time_ms=exposed_time_ms,
-            hiding_node=hiding_node,
+            hiding_nodes=hiding_nodes,
         )

     return collective_info
@@ -567,6 +576,94 @@ class TestOverlapPreservingBucketing(InductorTestCase):
             graph_str
         )

+    def test_can_bucket_with_multiple_hiding_nodes(self):
+        """
+        Test that collectives with multiple hiding nodes CAN bucket.
+
+        Graph structure:
+            ag1_start -> ag2_start -> mm1 -> mm2 -> mm3 -> ag1_wait -> ag2_wait
+
+        Where:
+        - ag1 is hidden by mm1 and mm2
+        - ag2 is hidden by mm2 and mm3
+        - Both collectives share mm2 as a hiding node
+        """
+
+        def func(a, b):
+            group_name = "0"
+            group_size = 1
+
+            # Start both collectives
+            ag1 = torch.ops._c10d_functional.all_gather_into_tensor(
+                a, group_size, group_name
+            )
+            ag2 = torch.ops._c10d_functional.all_gather_into_tensor(
+                b, group_size, group_name
+            )
+
+            # Three compute operations that hide the collectives
+            mm1 = torch.mm(a, a)
+            mm2 = torch.mm(b, b)
+            mm3 = torch.mm(a + b, a + b)
+
+            # Wait for both
+            ag1_out = torch.ops._c10d_functional.wait_tensor(ag1)
+            ag2_out = torch.ops._c10d_functional.wait_tensor(ag2)
+            return ag1_out.sum() + ag2_out.sum() + mm1.sum() + mm2.sum() + mm3.sum()
+
+        # Use fake mode to trace without executing
+        with FakeTensorMode():
+            a = torch.ones(4, 4, device=self.device)
+            b = torch.ones(4, 4, device=self.device) * 2
+
+            # Trace with make_fx
+            traced = make_fx(func)(a, b)
+
+            # Find nodes using find_nodes
+            ag1, ag2 = traced.graph.find_nodes(
+                op="call_function",
+                target=torch.ops._c10d_functional.all_gather_into_tensor.default,
+            )
+            mm1, mm2, mm3 = traced.graph.find_nodes(
+                op="call_function", target=torch.ops.aten.mm.default
+            )
+
+            # Manually annotate hiding relationships with multiple hiding nodes
+            hiding_annotations = {
+                ag1: [mm1, mm2],  # ag1 is hidden by mm1 and mm2
+                ag2: [mm2, mm3],  # ag2 is hidden by mm2 and mm3
+            }
+
+            # Build collective info and ancestors
+            collective_info = build_collective_info(traced.graph, hiding_annotations)
+            node_ancestors = compute_ancestors(traced.graph)
+            scheduled = OrderedSet(traced.graph.nodes)
+
+            # Verify hiding_nodes are correctly set
+            self.assertEqual(len(collective_info[ag1].hiding_nodes), 2)
+            self.assertIn(mm1, collective_info[ag1].hiding_nodes)
+            self.assertIn(mm2, collective_info[ag1].hiding_nodes)
+            self.assertEqual(len(collective_info[ag2].hiding_nodes), 2)
+            self.assertIn(mm2, collective_info[ag2].hiding_nodes)
+            self.assertIn(mm3, collective_info[ag2].hiding_nodes)
+
+            # Run bucketing
+            from torch._inductor.fx_passes.overlap_preserving_bucketer import (
+                OverlapPreservingBucketer,
+            )
+
+            bucketer = OverlapPreservingBucketer(
+                traced.graph,
+                collective_info,
+                node_ancestors,
+                scheduled,
+            )
+            bucketer.bucket_collectives()
+
+            FileCheck().check_count(
+                "all_gather_into_tensor_out", 1, exactly=False
+            ).check_count("%mm", 3, exactly=True).run(str(traced.graph))
+

 if __name__ == "__main__":
     run_tests()
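For reference, the annotation handling added to build_collective_info can be read as a small standalone helper (a restatement of the hunk above, assuming PyTorch's internal OrderedSet from torch.utils._ordered_set):

    from torch.utils._ordered_set import OrderedSet

    def as_hiding_nodes(annotation):
        # Accept None, a single fx.Node, or a list/OrderedSet of nodes,
        # and normalize to an OrderedSet, as the updated helper does.
        if annotation is None:
            return OrderedSet()
        if isinstance(annotation, (list, OrderedSet)):
            return OrderedSet(annotation)
        return OrderedSet([annotation])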

View File

@@ -176,6 +176,7 @@ class OverlapPreservingBucketer:
         head = None
         prev_event = None
         position = 0
+        hiding_nodes = OrderedSet()

         for node in self.scheduled:
             node_type = None
@@ -183,11 +184,12 @@
             # Determine if this node is relevant for this PG
             if node in self.collective_info and get_group_name(node) == pg:
                 node_type = "starts"
+                hiding_nodes |= self.collective_info[node].hiding_nodes
             elif is_wait_tensor(node):
                 wait_input = node.args[0]
                 if isinstance(wait_input, fx.Node) and get_group_name(wait_input) == pg:
                     node_type = "waits"
-            elif is_compute_node(node):
+            elif is_compute_node(node) or node in hiding_nodes:
                 node_type = "compute"

             if node_type is None:
@@ -205,7 +207,6 @@
             prev_event = event
             position += 1
-
         return head

     def _populate_node_to_event(self, pg: str) -> None:
@@ -222,10 +223,12 @@
         Add hiding interval constraints: start -> compute -> wait.
         """
         for start, info in self.collective_info.items():
-            if info.hiding_node and not info.is_exposed:
+            if info.is_exposed:
+                continue
+            for hn in info.hiding_nodes:
                 # Enforce: start -> compute -> wait
-                self.aug_graph.add_extra_dep(n=info.hiding_node, dep=start)
-                self.aug_graph.add_extra_dep(n=info.wait_node, dep=info.hiding_node)
+                self.aug_graph.add_extra_dep(n=hn, dep=start)
+                self.aug_graph.add_extra_dep(n=info.wait_node, dep=hn)

     def bucket_collectives(self) -> None:
         """Main entry point for bucketing collectives."""
@@ -358,13 +361,13 @@
     def _get_intervals(
         self, event: PGEvent
-    ) -> tuple[Optional[tuple[int, int]], Optional[tuple[int, int]]]:
-        """Get (execution_interval, hiding_interval) for a collective event.
+    ) -> tuple[Optional[tuple[int, int]], list[tuple[int, int]]]:
+        """Get (execution_interval, hiding_intervals) for a collective event.

         Returns:
-            (execution_interval, hiding_interval) where:
+            (execution_interval, hiding_intervals) where:
             - execution_interval is (start_pos, wait_pos) or None
-            - hiding_interval is (start_pos, compute_pos) or None if no hiding node
+            - hiding_intervals is a list of (start_pos, compute_pos) tuples, one for each hiding node

         Works for both start and wait events by looking up the collective info.
         """
@@ -375,13 +378,13 @@
         elif event.is_wait:
             wait_input = event.node.args[0]
             if not isinstance(wait_input, fx.Node):
-                return None, None
+                return None, []
             coll = wait_input
         else:
-            return None, None
+            return None, []

         if coll not in self.collective_info:
-            return None, None
+            return None, []

         info = self.collective_info[coll]
         start_event = self.node_to_event[coll]
@@ -389,14 +392,17 @@
         execution_interval = (start_event.position, wait_event.position)

-        hiding_interval = None
-        if info.hiding_node:
-            hiding_interval = (
-                start_event.position,
-                self.node_to_event[info.hiding_node].position,
-            )
+        hiding_intervals = []
+        if info.hiding_nodes:
+            for hiding_node in info.hiding_nodes:
+                hiding_intervals.append(
+                    (
+                        start_event.position,
+                        self.node_to_event[hiding_node].position,
+                    )
+                )

-        return execution_interval, hiding_interval
+        return execution_interval, hiding_intervals

     def _preserves_hiding_intervals(
         self,
@@ -424,9 +430,9 @@
         # Collect hiding compute positions for the bucket
        bucket_hiding_compute_positions = []
         for coll in all_bucketed_colls:
-            if hiding_node := self.collective_info[coll].hiding_node:
+            for coll_hiding_node in self.collective_info[coll].hiding_nodes:
                 bucket_hiding_compute_positions.append(
-                    self.node_to_event[hiding_node].position
+                    self.node_to_event[coll_hiding_node].position
                 )

         # Get new positions
@@ -478,11 +484,10 @@
                 curr_event.node not in all_bucketed_colls
                 and curr_event.node not in all_bucketed_waits
             ):
-                exec_interval, hiding_interval = self._get_intervals(curr_event)
+                exec_interval, hiding_interval_list = self._get_intervals(curr_event)
                 if exec_interval:
                     execution_intervals.append(exec_interval)
-                if hiding_interval:
-                    hiding_intervals.append(hiding_interval)
+                hiding_intervals.extend(hiding_interval_list)
             curr_event = curr_event.next

         curr_event = new_wait_event.prev
@@ -491,11 +496,10 @@
                 curr_event.node not in all_bucketed_colls
                 and curr_event.node not in all_bucketed_waits
             ):
-                exec_interval, hiding_interval = self._get_intervals(curr_event)
+                exec_interval, hiding_interval_list = self._get_intervals(curr_event)
                 if exec_interval:
                     execution_intervals.append(exec_interval)
-                if hiding_interval:
-                    hiding_intervals.append(hiding_interval)
+                hiding_intervals.extend(hiding_interval_list)
             curr_event = curr_event.prev

         # Check: no hiding interval should be enclosed by any execution interval
@@ -659,12 +663,12 @@
             return True

         # Check if existing hiding node conflicts with candidate wait
-        if hiding_node := self.collective_info[coll].hiding_node:
-            if self._ancestor_dep(hiding_node, candidate_wait):
+        for old_hiding_node in self.collective_info[coll].hiding_nodes:
+            if self._ancestor_dep(old_hiding_node, candidate_wait):
                 return True

         # Check if candidate hiding node conflicts with existing wait
-        if new_hiding_node := candidate_info.hiding_node:
+        for new_hiding_node in candidate_info.hiding_nodes:
             if self._ancestor_dep(new_hiding_node, coll_wait):
                 return True
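The check that _preserves_hiding_intervals performs on these lists reduces to an interval-containment test; a simplified sketch (it ignores the repositioning of the bucketed collectives themselves, which the real method also accounts for):

    def encloses(outer: tuple[int, int], inner: tuple[int, int]) -> bool:
        # True if `inner` lies strictly inside `outer` (event-list positions).
        return outer[0] < inner[0] and inner[1] < outer[1]

    def hiding_preserved(execution_intervals, hiding_intervals) -> bool:
        # Reject a bucketing move if any execution interval would enclose
        # any (start_pos, compute_pos) hiding interval.
        return not any(
            encloses(e, h)
            for e in execution_intervals
            for h in hiding_intervals
        )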

View File

@@ -4,9 +4,9 @@ import itertools
 import logging
 import sys
 from collections import Counter, defaultdict
-from collections.abc import Callable, Iterable
-from dataclasses import dataclass
-from typing import Any
+from collections.abc import Iterable
+from dataclasses import dataclass, field
+from typing import Any, Callable

 import torch
 import torch.fx as fx
@@ -45,11 +45,12 @@ def get_group_name(n: fx.Node) -> str:

 def get_custom_estimation(
     n: fx.Node,
     custom_runtime_estimation: Callable[[fx.Node], float | None] | None = None,
+    override_size=None,
 ) -> float | None:
     if custom_runtime_estimation is None:
         return None

-    return custom_runtime_estimation(n)
+    return custom_runtime_estimation(n, override_size)


 def estimate_collective_time(
@@ -58,7 +59,9 @@
     custom_runtime_estimation: Callable[[fx.Node], float | None] | None = None,
 ) -> float:
     """Estimate the runtime of a collective operation, optionally with an overridden size."""
-    if (est := get_custom_estimation(n, custom_runtime_estimation)) is not None:
+    if (
+        est := get_custom_estimation(n, custom_runtime_estimation, override_size)
+    ) is not None:
         return est

     return torch._inductor.comm_analysis.estimate_nccl_collective_runtime_from_fx_node(
@@ -67,13 +70,28 @@

 def estimate_fx_collective_size(fx_node: torch.fx.Node) -> int:
-    size = 0
-    for node in fx_node.all_input_nodes:
-        if (t := node.meta.get("val")) is not None:
-            # todo - symbolic
-            size += t.numel() * t.element_size()
+    """Estimate the size of a collective operation in bytes.

-    return size
+    For all_gather and reduce_scatter, both input and output buffers are needed.
+    For all_reduce, it can typically be done in-place, so only one buffer is needed.
+    """
+    from torch._inductor.fx_passes.bucketing import (
+        is_all_reduce_tensor as is_all_reduce,
+    )
+
+    input_node = fx_node.args[0]
+    input_tensor = input_node.meta.get("val")
+    input_size = input_tensor.numel() * input_tensor.element_size()
+
+    output_tensor = fx_node.meta.get("val")
+    output_size = output_tensor.numel() * output_tensor.element_size()
+
+    # all_reduce can be in-place, so only needs one buffer
+    if is_all_reduce(fx_node):
+        return max(input_size, output_size)
+
+    # otherwise assume both live concurrently
+    return input_size + output_size


 def is_compute_node(n: fx.Node) -> bool:
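A quick sanity check of the new byte accounting, with assumed shapes: a float32 [4, 4] input is 64 bytes; all_gather over 2 ranks produces an [8, 4] output (128 bytes) and is charged both buffers, while all_reduce keeps the same shape and is charged a single buffer:

    # illustrative shapes, not taken from the PR
    input_size = 4 * 4 * 4        # float32 [4, 4] -> 64 bytes
    ag_output_size = 8 * 4 * 4    # gathered [8, 4] -> 128 bytes
    assert input_size + ag_output_size == 192  # all_gather: input + output
    assert max(input_size, input_size) == 64   # all_reduce: in-place, one buffer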
@@ -188,7 +206,7 @@ class CollectiveInfo:
     size_bytes: int
     estimated_time_ms: float
     exposed_time_ms: float  # How much of this collective is still exposed
-    hiding_node: fx.Node | None = None  # Node that hides this collective
+    hiding_nodes: OrderedSet[fx.Node] = field(default_factory=OrderedSet)

     @property
     def is_exposed(self) -> bool:
@@ -419,6 +437,8 @@ class OverlapScheduler:
                 self._handle_collective_start(node)
             elif is_wait_tensor(node):
                 self._handle_wait(node)
+            elif node.op == "placeholder":
+                self._schedule(node)
             else:
                 self._handle_other(node)
@@ -447,11 +467,13 @@
         additional_deps: dict[fx.Node, OrderedSet[fx.Node]] = defaultdict(OrderedSet)

         for start_node, info in self.collective_info.items():
-            if info.hiding_node and not info.is_exposed:
+            if info.is_exposed:
+                continue
+            for hn in info.hiding_nodes:
                 # Compute depends on collective start (compute must wait for collective to start)
-                additional_deps[info.hiding_node].add(start_node)
+                additional_deps[hn].add(start_node)
                 # Wait depends on compute (wait must wait for compute to finish)
-                additional_deps[info.wait_node].add(info.hiding_node)
+                additional_deps[info.wait_node].add(hn)

         # Apply effect tokens to preserve these dependencies
         if additional_deps:
@@ -592,9 +614,8 @@
             overlap_amount = min(info.exposed_time_ms, available_compute)
             info.exposed_time_ms -= overlap_amount
             available_compute -= overlap_amount
-            if info.exposed_time_ms == 0:
-                info.hiding_node = node
-            elif available_compute == 0:
+            info.hiding_nodes.add(node)
+            if available_compute == 0:
                 break

         # Then, look for unscheduled collectives we can overlap
@@ -687,8 +708,7 @@
             # after scheduling, which will account for latency reduction of bucketing
             overlap_amount = min(available_compute_time, info.exposed_time_ms)
             info.exposed_time_ms -= overlap_amount
-            if info.exposed_time_ms == 0:
-                info.hiding_node = compute_node
+            info.hiding_nodes.add(compute_node)
             available_compute_time -= overlap_amount

     def _find_schedulable_path(
@@ -715,7 +735,7 @@
                # it's fine to schedule it
                if is_wait_tensor(node):
                    info = self.collective_info[self.wait_to_start[node]]
-                   if info.hiding_node and info.hiding_node != curr_compute_node:
+                   if info.hiding_nodes and curr_compute_node not in info.hiding_nodes:
                        continue
                elif node not in self.potentially_hidden_waits:
                    continue
@@ -751,7 +771,7 @@
     ) -> bool:
         assert is_wait_tensor(wait_node)
         info = self.collective_info[self.wait_to_start[wait_node]]
-        return not info.is_exposed and info.hiding_node != compute_node
+        return not info.is_exposed and compute_node not in info.hiding_nodes

     def _schedule_path_to_collective(
         self, path: OrderedSet[fx.Node], curr_compute_node: fx.Node
@@ -770,7 +790,7 @@
                 continue

             info = self.collective_info[self.wait_to_start[node]]
-            assert info.hiding_node != curr_compute_node
+            assert curr_compute_node not in info.hiding_nodes
             self._handle_wait(node)
             continue
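Taken together, the scheduler-side change replaces the single hiding_node field with a set that accumulates every compute node contributing overlap. An abbreviated restatement of the updated dataclass (the is_exposed body is an assumption inferred from the surrounding exposed-time logic, not shown in this diff):

    from dataclasses import dataclass, field
    import torch.fx as fx
    from torch.utils._ordered_set import OrderedSet

    @dataclass
    class CollectiveInfo:
        start_node: fx.Node
        wait_node: fx.Node
        size_bytes: int
        estimated_time_ms: float
        exposed_time_ms: float  # how much of this collective is still exposed
        # previously: hiding_node: fx.Node | None = None
        hiding_nodes: OrderedSet[fx.Node] = field(default_factory=OrderedSet)

        @property
        def is_exposed(self) -> bool:
            return self.exposed_time_ms > 0  # assumed definition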