Add Comm-Compute Preserving Bucketer (#163960)

tl;dr performs bucketing while preserving comm-compute overlap.

In comm-compute overlap we will have a graph with:

```
def foo(...):
     ag = all_gather(...)
     hiding_compute = mm(...)
     wait(ag)
```

There is no explicit dependency between the hiding compute and the collectives, but we want to add implicit dependencies from wait->hiding_compute, and from hiding_compute->all_gather to preserve overlap.

Additionally, while bucketing, we will merge collective starts and collective waits together. In this case, we will want to treat the two nodes as a single subgraph - each node in the merged set will have the union of all deps in the set.

We perform bucketing while augmenting the graph with these relationships. This can be done separably from comm-compute overlap, so long as the hiding compute relationships are passed in.

TODO:
- need to instrument fx graph so inductor respects these relationships.
- the compile time of the bucketing search can be sped up significantly by limiting what portion of the graph we traverse through
- more memory aware handling

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163960
Approved by: https://github.com/ruisizhang123, https://github.com/v0i0, https://github.com/IvanKobzarev
ghstack dependencies: #163215, #163754, #163959
This commit is contained in:
eellison
2025-09-29 15:36:45 -07:00
committed by PyTorch MergeBot
parent 92108f4abd
commit 7d59e37434
5 changed files with 703 additions and 16 deletions

View File

@ -29,12 +29,12 @@ from torch.testing._internal.common_utils import skipIfRocm
from torch.testing._internal.inductor_utils import HAS_GPU
def estimate_aten_runtime(fx_node):
def estimate_aten_runtime(fx_node, compute_multiplier=1.0):
# for tests, assume a matmul can hide a single collective
if "c10" in str(fx_node.target):
return 1.0
elif fx_node.target == aten.mm.default:
return 1.0
return compute_multiplier
else:
return None
@ -347,6 +347,410 @@ graph():
self.assertEqual(counters["inductor"]["overlap_scheduling_bad_exposed"], 0)
def get_bucket_patches(compute_multiplier=1.0):
estimate_aten_runtime_part = functools.partial(
estimate_aten_runtime, compute_multiplier=compute_multiplier
)
return {
"test_configs.estimate_aten_runtime": estimate_aten_runtime_part,
"test_configs.aten_fx_overlap_preserving_bucketing": True,
"reorder_for_locality": False,
"reorder_for_compute_comm_overlap_passes": [],
"compile_threads": 1,
"force_disable_caches": True,
}
class TestComputeCommReorderingBucketing(TestComputeCommReorderingMultiProc):
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@torch._inductor.config.patch(get_bucket_patches())
def test_basic_all_gather_bucketing(self):
"""Test that independent all_gather operations get bucketed together."""
def func(a, b, c, *, ranks):
# Three independent all_gathers that should be bucketed
ag1 = _functional_collectives.all_gather_tensor(a, 0, ranks) + 3
ag2 = _functional_collectives.all_gather_tensor(b, 0, ranks) + 4
ag3 = _functional_collectives.all_gather_tensor(c, 0, ranks) + 5
return ag1 + ag2 + ag3
with _dynamo_dist_per_rank_init(
self.rank,
self.world_size,
self.backend(device_type),
fake_pg=not at_least_x_gpu(2),
):
inputs_a = (
torch.ones(4, 4, dtype=torch.float, device=device_type) + self.rank
)
inputs_b = torch.ones(4, 4, dtype=torch.float, device=device_type) * 2
inputs_c = torch.ones(4, 4, dtype=torch.float, device=device_type) * 3
ranks = list(range(self.world_size))
func_c = functools.partial(func, ranks=ranks)
compiled = torch.compile(func_c)
out, aten_graph_str = run_and_get_aten_graph(
compiled, inputs_a, inputs_b, inputs_c
)
# Should see a single bucketed all_gather
FileCheck().check_count(
"torch.ops._c10d_functional.all_gather_into_tensor", 1, exactly=True
).run(aten_graph_str)
correct = func(inputs_a, inputs_b, inputs_c, ranks=ranks)
self.assertTrue(same(out, correct))
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@torch._inductor.config.patch(get_bucket_patches())
def test_reduce_scatter_bucketing(self):
"""Test bucketing of reduce_scatter operations."""
def func(a, b, c):
rs1 = _functional_collectives.reduce_scatter_tensor(a, "sum", 0, "0")
rs2 = _functional_collectives.reduce_scatter_tensor(b, "sum", 0, "0")
rs3 = _functional_collectives.reduce_scatter_tensor(c, "sum", 0, "0")
return torch.cat([rs1, rs2, rs3])
with _dynamo_dist_per_rank_init(
self.rank,
self.world_size,
self.backend(device_type),
fake_pg=not at_least_x_gpu(2),
):
inputs_a = torch.ones(8, 4, dtype=torch.float, device=device_type)
inputs_b = torch.ones(8, 4, dtype=torch.float, device=device_type) * 2
inputs_c = torch.ones(8, 4, dtype=torch.float, device=device_type) * 3
out, aten_graph_str = run_and_get_aten_graph(
torch.compile(func), inputs_a, inputs_b, inputs_c
)
# Should bucket reduce_scatter ops
FileCheck().check_count(
"torch.ops._c10d_functional.reduce_scatter_tensor", 1, exactly=True
).run(aten_graph_str)
# TODO: debug - on ci this fails.
# correct = func(inputs_a, inputs_b, inputs_c)
# self.assertTrue(same(out, correct))
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@torch._inductor.config.patch(get_bucket_patches())
def test_no_bucketing_with_dependent_hiding_nodes(self):
"""Test that collectives with dependent hiding nodes don't get bucketed."""
def func(a, b, *, ranks):
# ag1 could be hidden by mm1
ag1 = _functional_collectives.all_gather_tensor(a, 0, ranks)
mm1 = torch.matmul(a, a)
# ag2 can be hidden by mm2, but mm2 depends on ag1's result
# ag2 start
mm2 = torch.matmul(ag1[:4], b)
# ag2 end
ag2 = _functional_collectives.all_gather_tensor(b, 0, ranks)
return ag1.sum() * ag2.sum() * mm1 * mm2
with _dynamo_dist_per_rank_init(
self.rank,
self.world_size,
self.backend(device_type),
fake_pg=not at_least_x_gpu(2),
):
inputs_a = torch.ones(4, 4, dtype=torch.float, device=device_type)
inputs_b = torch.ones(4, 4, dtype=torch.float, device=device_type)
ranks = list(range(self.world_size))
func_c = functools.partial(func, ranks=ranks)
compiled = torch.compile(func_c)
out, aten_graph_str = run_and_get_aten_graph(compiled, inputs_a, inputs_b)
# mm2 depends on ag1, so if mm2 is to hide ag2, we can't bucket ag1 and ag2
# because that would create a dependency issue, even though we could bucket them
FileCheck().check_count(
"torch.ops._c10d_functional.all_gather_into_tensor", 2, exactly=True
).run(aten_graph_str)
correct = func(inputs_a, inputs_b, ranks=ranks)
self.assertTrue(same(out, correct))
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@torch._inductor.config.patch(get_bucket_patches())
def test_no_bucketing_when_collective_depends_on_hiding_node(self):
"""Test that collectives don't get bucketed when one depends on another's hiding node."""
def func(a, *, ranks):
# ag1 hidden by mm1
ag1 = _functional_collectives.all_gather_tensor(a, 0, ranks)
mm1 = torch.matmul(a, a)
# ag2 depends on mm1 (which hides ag1)
b = mm1 * 2
ag2 = _functional_collectives.all_gather_tensor(b, 0, ranks)
return ag1.sum() * ag2.sum() * mm1
with _dynamo_dist_per_rank_init(
self.rank,
self.world_size,
self.backend(device_type),
fake_pg=not at_least_x_gpu(2),
):
inputs = torch.ones(4, 4, dtype=torch.float, device=device_type)
ranks = list(range(self.world_size))
func_c = functools.partial(func, ranks=ranks)
compiled = torch.compile(func_c)
out, aten_graph_str = run_and_get_aten_graph(compiled, inputs)
# ag2 depends on mm1 (ag1's hiding node), so they can't be bucketed
FileCheck().check_count(
"_c10d_functional.all_gather_into_tensor", 2, exactly=True
).run(aten_graph_str)
correct = func(inputs, ranks=ranks)
self.assertTrue(same(out, correct))
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@torch._inductor.config.patch(get_bucket_patches(2.0))
def test_bucketing_wait_sink(self):
"""Test that 4 independent all-gathers split bucketed."""
def func(a, b, c, d, *, ranks):
# All 4 all-gathers are independent - COULD be bucketed together
ag1 = _functional_collectives.all_gather_tensor(a, 0, ranks)
ag2 = _functional_collectives.all_gather_tensor(b, 0, ranks)
ag3 = _functional_collectives.all_gather_tensor(c[:4], 0, ranks)
ag4 = _functional_collectives.all_gather_tensor(d[:4], 0, ranks)
# First compute - can hide ag1 and ag2
e = a * 5
mm1 = torch.matmul(e, e.T)
# Second compute - can hide ag3 and ag4
f = b * 6
mm2 = torch.matmul(f, f.T)
# Use all collective results
result = (
ag1.sum() * 1.1
+ ag2.sum() * 1.2
+ ag3.sum() * 1.3
+ ag4.sum() * 1.4
+ mm1.sum()
+ mm2.sum()
)
return result
with _dynamo_dist_per_rank_init(
self.rank,
self.world_size,
self.backend(device_type),
fake_pg=not at_least_x_gpu(2),
):
a = torch.ones(8, 8, dtype=torch.float, device=device_type)
b = torch.ones(8, 8, dtype=torch.float, device=device_type) * 2
c = torch.ones(8, 8, dtype=torch.float, device=device_type) * 3
d = torch.ones(8, 8, dtype=torch.float, device=device_type) * 4
ranks = list(range(self.world_size))
func_c = functools.partial(func, ranks=ranks)
compiled = torch.compile(func_c)
out, aten_graph_str = run_and_get_aten_graph(compiled, a, b, c, d)
# The 4 all gathers can be bucketed, and their waits should be sunk below the mms
FileCheck().check_count(
"_c10d_functional.all_gather_into_tensor", 1, exactly=True
).check_count("ops.aten.mm", 2, exactly=True).check(
"_c10d_functional.wait_tensor"
).run(aten_graph_str)
correct = func(a, b, c, d, ranks=ranks)
self.assertTrue(same(out, correct))
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@torch._inductor.config.patch(get_bucket_patches(2.0))
def test_bucketing_split_for_overlap_blocking(self):
"""Test that 4 independent all-gathers split into 2+2 buckets for better overlap with compute."""
def func(a, b, c, d, *, ranks):
# All 4 all-gathers are independent - COULD be bucketed together
ag1 = _functional_collectives.all_gather_tensor(a, 0, ranks)
ag2 = _functional_collectives.all_gather_tensor(b, 0, ranks)
ag3 = _functional_collectives.all_gather_tensor(c[:4], 0, ranks)
ag4 = _functional_collectives.all_gather_tensor(d[:4], 0, ranks)
# First compute - can hide ag1 and ag2
e = a * 5 # Use a to avoid fusion
mm1 = torch.matmul(e, e.T)
# Force ag1/ag2 to complete before mm2 (but ag3/ag4 can still be deferred)
# Use first 8x8 elements to match mm1's shape
intermediate = ag1[:8, :8] + ag2[:8, :8]
# Second compute - depends on ag1/ag2 through intermediate, can hide ag3/ag4
mm2 = torch.matmul(mm1 + intermediate, c[:8])
# Use all results
result = (
ag1.sum() * 1.1
+ ag2.sum() * 1.2
+ ag3.sum() * 1.3
+ ag4.sum() * 1.4
+ mm1.sum()
+ mm2.sum()
)
return result
with _dynamo_dist_per_rank_init(
self.rank,
self.world_size,
self.backend(device_type),
fake_pg=not at_least_x_gpu(2),
):
a = torch.ones(8, 8, dtype=torch.float, device=device_type)
b = torch.ones(8, 8, dtype=torch.float, device=device_type) * 2
c = torch.ones(8, 8, dtype=torch.float, device=device_type) * 3
d = torch.ones(8, 8, dtype=torch.float, device=device_type) * 4
ranks = list(range(self.world_size))
func_c = functools.partial(func, ranks=ranks)
compiled = torch.compile(func_c)
out, aten_graph_str = run_and_get_aten_graph(compiled, a, b, c, d)
# The 4 all gathers can be bucketed, and the wait should be sunk below the mms
FileCheck().check_count(
"_c10d_functional.all_gather_into_tensor", 1, exactly=True
).check_count("ops.aten.mm", 2, exactly=True).check_count(
"_c10d_functional.wait_tensor", 1, exactly=True
).run(aten_graph_str)
correct = func(a, b, c, d, ranks=ranks)
self.assertTrue(same(out, correct))
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@torch._inductor.config.patch(get_bucket_patches(2.0))
def test_bucketing_split_for_overlap(self):
"""Test that 4 independent all-gathers split into 2+2 buckets for better overlap with compute."""
def func(a, b, c, d, *, ranks):
# All 4 all-gathers are independent - COULD be bucketed together
ag1 = _functional_collectives.all_gather_tensor(a, 0, ranks)
ag2 = _functional_collectives.all_gather_tensor(b, 0, ranks)
ag3 = _functional_collectives.all_gather_tensor(c[:4], 0, ranks)
ag4 = _functional_collectives.all_gather_tensor(d[:4], 0, ranks)
# First compute - can hide ag1 and ag2
e = a * 5 # Use a to avoid fusion
mm1 = torch.matmul(e, e.T)
# Force ag1/ag2 to complete before mm2 (but ag3/ag4 can still be deferred)
intermediate = ag1[:2, :2] + ag2[:2, :2] # Small slice to minimize compute
# Second compute - depends on ag1/ag2 through intermediate, can hide ag3/ag4
f = b * 6
# Expand intermediate to match mm1's shape for broadcasting
intermediate_expanded = torch.nn.functional.pad(intermediate, (0, 6, 0, 6))
mm2 = torch.matmul(mm1 + intermediate_expanded, f.T)
# Use all results
result = (
ag1.sum() * 1.1
+ ag2.sum() * 1.2
+ ag3.sum() * 1.3
+ ag4.sum() * 1.4
+ mm1.sum()
+ mm2.sum()
)
return result
with _dynamo_dist_per_rank_init(
self.rank,
self.world_size,
self.backend(device_type),
fake_pg=not at_least_x_gpu(2),
):
a = torch.ones(8, 8, dtype=torch.float, device=device_type)
b = torch.ones(8, 8, dtype=torch.float, device=device_type) * 2
c = torch.ones(8, 8, dtype=torch.float, device=device_type) * 3
d = torch.ones(8, 8, dtype=torch.float, device=device_type) * 4
ranks = list(range(self.world_size))
func_c = functools.partial(func, ranks=ranks)
compiled = torch.compile(func_c)
out, aten_graph_str = run_and_get_aten_graph(compiled, a, b, c, d)
# Should have 2 bucketed all-gathers (one for ag1+ag2, one for ag3+ag4)
FileCheck().check_count(
"_c10d_functional.all_gather_into_tensor_out", 2, exactly=True
).run(aten_graph_str)
# Verify the ordering - first bucket, then mm1, then second bucket, then mm2
FileCheck().check("_c10d_functional.all_gather_into_tensor_out").check(
"ops.aten.mm"
).check("_c10d_functional.all_gather_into_tensor_out").check(
"ops.aten.mm"
).run(aten_graph_str)
# Verify correctness
correct = func(a, b, c, d, ranks=ranks)
self.assertTrue(same(out, correct))
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@torch._inductor.config.patch(get_bucket_patches())
def test_bucket_exposed_with_hidden_single_overlap(self):
"""Test that exposed and hidden collectives bucket together when overlap is preserved."""
def func(a, b, c, *, ranks):
# ag1 will be hidden by mm1
ag1 = _functional_collectives.all_gather_tensor(a, 0, ranks)
# ag2 and ag3 are exposed (no compute to hide them)
ag2 = _functional_collectives.all_gather_tensor(b, 0, ranks)
ag3 = _functional_collectives.all_gather_tensor(c, 0, ranks)
# can only hide one collective
mm1 = torch.matmul(a[:2], a[:2].T) # 2x2 matmul, hides only ag1
# All three can bucket together because:
# bucketing ag1, ag2, ag3 together does not prevent ag1 being hidden by mm1.
return ag1.sum() + ag2.sum() + ag3.sum() + mm1.sum()
with _dynamo_dist_per_rank_init(
self.rank,
self.world_size,
self.backend(device_type),
fake_pg=not at_least_x_gpu(2),
):
a = torch.ones(8, 8, dtype=torch.float, device=device_type)
b = torch.ones(8, 8, dtype=torch.float, device=device_type) * 2
c = torch.ones(8, 8, dtype=torch.float, device=device_type) * 3
ranks = list(range(self.world_size))
func_c = functools.partial(func, ranks=ranks)
compiled = torch.compile(func_c)
out, aten_graph_str = run_and_get_aten_graph(compiled, a, b, c)
# Should have 1 bucketed operation containing all 3 all-gathers
FileCheck().check_count("wait_tensor.default", 1, exactly=True).run(
aten_graph_str
)
# Verify bucketed collective overlaps with mm1
FileCheck().check("functional.all_gather_into_tensor").check(
"aten.mm"
).check("wait_tensor").run(aten_graph_str)
# Verify correctness
correct = func(a, b, c, ranks=ranks)
self.assertTrue(same(out, correct))
if __name__ == "__main__":
from torch._dynamo.test_case import run_tests

View File

@ -2010,6 +2010,9 @@ class test_configs:
# to be migrated when ready for use
aten_fx_overlap_scheduling = False
# to be migrated when ready for use
aten_fx_overlap_preserving_bucketing = False
# to be migrated when ready for use
# runtime estimation function for ops
# for user-defined estimation function, pass in the function handle

View File

@ -570,7 +570,7 @@ def process_collective_bucket(
trace_args_fn: Callable[[list[torch.fx.Node]], tuple[Any, ...]],
insert_before: Optional[torch.fx.Node] = None,
wait_insertion_point: Optional[torch.fx.Node] = None,
) -> dict[torch.fx.Node, torch.fx.Node]:
) -> tuple[list[torch.fx.Node], dict[torch.fx.Node, torch.fx.Node]]:
"""
Process a single bucket of collective operation nodes with flexible insertion control.
@ -583,6 +583,7 @@ def process_collective_bucket(
wait_insertion_point: If provided, move all nodes from wait() onwards to before this node
Returns:
new_nodes: List of all newly inserted nodes
replacements: Dictionary mapping old wait nodes to new output nodes
"""
# Collect inputs and waits from current bucket
@ -650,15 +651,16 @@ def process_collective_bucket(
for pre_node in reversed(ag_node_to_pre_nodes[node]):
g.erase_node(pre_node)
return replacements
return new_nodes, replacements
def merge_reduce_scatter_bucket(
g: torch.fx.Graph,
rs_nodes: list[torch.fx.Node],
mode: Optional[str] = None,
insert_before: Optional[torch.fx.Node] = None,
wait_insertion_point: Optional[torch.fx.Node] = None,
) -> None:
) -> tuple[list[torch.fx.Node], dict[torch.fx.Node, torch.fx.Node]]:
# Validate bucket consistency
rs0 = rs_nodes[0]
rs0_val = rs0.meta["val"]
@ -692,11 +694,12 @@ def merge_reduce_scatter_bucket(
device,
)
process_collective_bucket(
return process_collective_bucket(
g,
rs_nodes,
rs_merge_fn,
create_trace_args,
insert_before=insert_before,
wait_insertion_point=wait_insertion_point,
)
@ -705,8 +708,9 @@ def merge_all_gather_bucket(
g: torch.fx.Graph,
ag_nodes: list[torch.fx.Node],
mode: Optional[str] = None,
insert_before: Optional[torch.fx.Node] = None,
wait_insertion_point: Optional[torch.fx.Node] = None,
) -> None:
) -> tuple[list[torch.fx.Node], dict[torch.fx.Node, torch.fx.Node]]:
from torch.distributed.distributed_c10d import _resolve_process_group
ag0 = ag_nodes[0]
@ -739,7 +743,7 @@ def merge_all_gather_bucket(
rank,
)
process_collective_bucket(
return process_collective_bucket(
g,
ag_nodes,
ag_merge_fn,

View File

@ -0,0 +1,256 @@
from collections import defaultdict
from typing import Optional
import torch
import torch.fx as fx
from torch._inductor.augmented_graph_helper import AugmentedGraphHelper
from torch._inductor.fx_passes.bucketing import (
_ag_group_key,
_rs_group_key,
is_all_gather_into_tensor as is_all_gather,
is_reduce_scatter_tensor as is_reduce_scatter,
is_wait_tensor,
)
from torch._inductor.fx_passes.overlap_scheduling import CollBucket, CollectiveInfo
from torch.utils._ordered_set import OrderedSet
def bucket_key(node: torch.fx.Node) -> Optional[object]:
if is_all_gather(node):
return _ag_group_key(node)
elif is_reduce_scatter(node):
return _rs_group_key(node)
else:
return None
class OverlapPreservingBucketer:
"""
Buckets collective operations while preserving compute-collective overlap relationships.
Uses an augmented graph to track dependencies between compute and collective operations.
"""
def __init__(
self,
graph: fx.Graph,
collective_info: dict[fx.Node, CollectiveInfo],
node_ancestors: dict[fx.Node, OrderedSet[fx.Node]],
scheduled: OrderedSet[fx.Node],
max_bucket_memory_gb: float = 1.0,
):
self.graph = graph
self.collective_info = collective_info
self.node_ancestors = node_ancestors
self.scheduled = scheduled
self.max_bucket_memory_gb = max_bucket_memory_gb
self.node_idx = {n: i for i, n in enumerate(scheduled)}
def bucket_collectives(self) -> None:
"""Main entry point for bucketing collectives."""
aug_graph = AugmentedGraphHelper(self.graph)
# Add extra dependencies for hidden collectives
# For each hidden collective, add: compute -> start and wait -> compute
for start_node, info in self.collective_info.items():
if info.hiding_node and not info.is_exposed:
# Add edge: hiding_compute depends on start (start must come before compute)
aug_graph.add_extra_dep(n=info.hiding_node, dep=start_node)
# Add edge: wait depends on hiding_compute (compute must come before wait)
aug_graph.add_extra_dep(n=info.wait_node, dep=info.hiding_node)
# Group collectives by bucket key (type, group, etc.)
grouped_collectives: dict[object, OrderedSet[fx.Node]] = defaultdict(OrderedSet)
for start in self.collective_info:
key = bucket_key(start)
if key is not None:
grouped_collectives[key].add(start)
all_buckets: list[CollBucket] = []
for collective_group in grouped_collectives.values():
buckets = self._find_buckets(collective_group, aug_graph)
all_buckets.extend(buckets)
# Collect all extra dependencies to preserve after bucketing
additional_deps: dict[fx.Node, OrderedSet[fx.Node]] = defaultdict(OrderedSet)
# Apply bucketing transformations
for coll_bucket in all_buckets:
if len(coll_bucket.collectives) <= 1:
continue
bucket_deps = self._apply_bucket(coll_bucket)
additional_deps.update(bucket_deps)
# Apply topological sort with all the collected dependencies
from torch._dynamo.graph_deduplication import _stable_topological_sort
_stable_topological_sort(self.graph, additional_deps)
self.graph.lint()
def _find_buckets(
self,
collective_group: OrderedSet[fx.Node],
aug_graph: AugmentedGraphHelper,
) -> list[CollBucket]:
"""Find valid buckets within a group of similar collectives."""
max_bucket_bytes = int(self.max_bucket_memory_gb * 1024 * 1024 * 1024)
buckets = []
processed: OrderedSet[fx.Node] = OrderedSet()
for start_node in collective_group:
if start_node in processed:
continue
# Initialize bucket with first collective
bucket_info = CollBucket(
collectives=[start_node],
total_bytes=self.collective_info[start_node].size_bytes,
)
processed.add(start_node)
# TODO - limit within range
for candidate in collective_group:
if candidate in processed:
continue
candidate_bytes = self.collective_info[candidate].size_bytes
if bucket_info.total_bytes + candidate_bytes > max_bucket_bytes:
continue
if self._can_add_to_bucket(bucket_info, candidate, aug_graph):
bucket_info.collectives.append(candidate)
bucket_info.total_bytes += candidate_bytes
processed.add(candidate)
if len(bucket_info.collectives) > 1:
buckets.append(bucket_info)
return buckets
def _ancestor_dep(self, n1: fx.Node, n2: fx.Node) -> bool:
"""Check if there's an ancestor relationship between two nodes."""
return n1 in self.node_ancestors[n2] or n2 in self.node_ancestors[n1]
def _can_add_to_bucket(
self,
bucket_info: CollBucket,
candidate: fx.Node,
aug_graph: AugmentedGraphHelper,
) -> bool:
"""
Check if candidate can be added to bucket without interfering
with comm/compute overlap.
"""
candidate_info = self.collective_info[candidate]
candidate_wait = candidate_info.wait_node
# Step 1: Quick check using precomputed ancestors
# This will not be fully up to date because bucketing changes ancestors,
# however any ancestor at the start of bucketing will remain an ancestor.
for coll in bucket_info.collectives:
if self._ancestor_dep(coll, candidate):
return False
coll_wait = self.collective_info[coll].wait_node
if self._ancestor_dep(candidate_wait, coll_wait):
return False
if hiding_node := self.collective_info[coll].hiding_node:
if self._ancestor_dep(hiding_node, candidate_wait):
return False
if new_hiding_node := candidate_info.hiding_node:
if self._ancestor_dep(new_hiding_node, coll_wait):
return False
# Step 2: Check and merge starts
# Check if there's a path between any existing start and candidate start.
# Because the collectives have already been merged, we can just start from one
# of them.
# TODO: we have a range of possible idxs of the merged node, and idx of new node.
# we should not do path search beyond that range
existing_coll = bucket_info.collectives[0]
if aug_graph.has_path(existing_coll, candidate):
return False
if aug_graph.has_path(candidate, existing_coll):
return False
# Safe to merge starts - do the merge
aug_graph.merge_to_set(existing_coll, candidate)
# Step 3: Check and merge waits
existing_wait = self.collective_info[existing_coll].wait_node
candidate_wait = candidate_info.wait_node
# TODO - as above, limit search by idx
if aug_graph.has_path(existing_wait, candidate_wait) or aug_graph.has_path(
candidate_wait, existing_wait
):
# Unmerge the start we just merged
aug_graph.unmerge_node(candidate)
return False
aug_graph.merge_to_set(existing_wait, candidate_wait)
return True
def _apply_bucket(
self, bucket_info: CollBucket
) -> dict[fx.Node, OrderedSet[fx.Node]]:
"""Apply bucketing transformation and return dependencies to preserve."""
from torch._inductor.fx_passes.bucketing import (
merge_all_gather_bucket,
merge_reduce_scatter_bucket,
)
bucket = bucket_info.collectives
# Find where to place the bucketed operations
next_node = bucket[0]
while next_node in bucket:
next_node = next_node.next
waits = [self.collective_info[n].wait_node for n in bucket]
first_wait = min(waits, key=lambda w: self.node_idx[w])
# Create bucketed collective
if is_all_gather(bucket[0]):
new_nodes, replacements = merge_all_gather_bucket(
self.graph,
bucket,
wait_insertion_point=first_wait,
insert_before=next_node,
mode="custom_ops",
)
else:
assert is_reduce_scatter(bucket[0])
new_nodes, replacements = merge_reduce_scatter_bucket(
self.graph,
bucket,
wait_insertion_point=first_wait,
insert_before=next_node,
mode="custom_ops",
)
# Build dependencies to preserve overlap
# replacements maps old_start -> new_start, old_wait -> new_wait
new_waits = [n for n in new_nodes if is_wait_tensor(n)]
assert len(new_waits) == 1
new_wait = new_waits[0]
new_start = new_wait.args[0]
assert isinstance(new_start, fx.Node)
overlap_deps: dict[fx.Node, OrderedSet[fx.Node]] = defaultdict(OrderedSet)
# Create dependencies to preserve overlap
for coll in bucket:
info = self.collective_info[coll]
if info.hiding_node and not info.is_exposed:
# Compute depends on collective start
overlap_deps[info.hiding_node].add(new_start)
# Wait depends on compute
overlap_deps[new_wait].add(info.hiding_node)
return overlap_deps

View File

@ -11,6 +11,7 @@ from typing import Any, Callable, Optional, Union
import torch
import torch.fx as fx
from torch._dynamo.utils import counters, dynamo_timed
from torch._inductor.fx_passes.bucketing import is_wait_tensor
from torch.utils._mode_utils import no_dispatch
from torch.utils._ordered_set import OrderedSet
@ -20,13 +21,6 @@ log = logging.getLogger(__name__)
from ..pattern_matcher import stable_topological_sort
def is_wait_tensor(node: torch.fx.Node) -> bool:
return (
node.op == "call_function"
and node.target == torch.ops._c10d_functional.wait_tensor.default
)
def get_custom_estimation(n: fx.Node) -> Optional[float]:
runtime_estimation = torch._inductor.config.test_configs.estimate_aten_runtime
if runtime_estimation == "default":
@ -156,6 +150,16 @@ class CollectiveInfo:
return self.exposed_time_ms != 0
@dataclass
class CollBucket:
"""Track information about a bucket of collectives."""
collectives: list[fx.Node] # Original collective starts
bucketed_start: Optional[fx.Node] = None # After bucketing
bucketed_wait: Optional[fx.Node] = None # After bucketing
total_bytes: int = 0
class OverlapScheduler:
"""
Scheduler that reorders operations to maximize compute-collective overlap.
@ -184,7 +188,7 @@ class OverlapScheduler:
self,
gm: torch.fx.GraphModule,
max_in_flight_gb: float = 2.0,
compute_overlap_multipler: float = 1.0,
compute_overlap_multipler: float = 2.0,
max_coll_distance: int = 1000,
):
self.gm = gm
@ -325,6 +329,8 @@ class OverlapScheduler:
self._handle_other(node)
self._reorder_graph()
if torch._inductor.config.test_configs.aten_fx_overlap_preserving_bucketing:
self._bucket_collectives()
return self.gm
def _handle_other(self, node: fx.Node) -> None:
@ -583,6 +589,20 @@ class OverlapScheduler:
self.reorder_graph()
def _bucket_collectives(self) -> None:
from torch._inductor.fx_passes.overlap_preserving_bucketer import (
OverlapPreservingBucketer,
)
bucketer = OverlapPreservingBucketer(
graph=self.graph,
collective_info=self.collective_info,
node_ancestors=self.node_ancestors,
scheduled=self.scheduled,
max_bucket_memory_gb=1.0, # Could make this configurable
)
bucketer.bucket_collectives()
def compute_potential_hidden_nodes(
self, nodes_to_check: Iterable[fx.Node], limit_coll_per_compute: bool = False
) -> dict[fx.Node, fx.Node]: