Add Comm-Compute Preserving Bucketer (#163960)

tl;dr: performs bucketing while preserving comm-compute overlap.

In comm-compute overlap we will have a graph with:

```
def foo(...):
    ag = all_gather(...)
    hiding_compute = mm(...)
    wait(ag)
```

There is no explicit dependency between the hiding compute and the collective. To preserve overlap, we want to add implicit dependencies: wait->hiding_compute (the wait depends on the hiding compute) and hiding_compute->all_gather (the hiding compute depends on the collective start).
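
The ordering these implicit dependencies enforce can be sketched with plain Python sets. This is an illustration only: `add_overlap_deps` and the `deps` mapping are hypothetical names, not part of this PR, and nodes are plain strings rather than `torch.fx.Node` objects.

```python
from collections import defaultdict
from graphlib import TopologicalSorter


def add_overlap_deps(deps, start, hiding_compute, wait):
    """Record implicit edges so the order start -> hiding_compute -> wait is kept.

    `deps` maps each node to the set of nodes it depends on (hypothetical layout).
    """
    deps[hiding_compute].add(start)  # compute is issued after the collective start
    deps[wait].add(hiding_compute)   # wait comes after the compute that hides it


# Explicit data dependency from the graph above: wait(ag) consumes ag.
deps = defaultdict(set, {"wait": {"ag"}})
add_overlap_deps(deps, start="ag", hiding_compute="mm", wait="wait")

# With the implicit edges there is exactly one valid schedule.
order = list(TopologicalSorter(deps).static_order())
print(order)  # ['ag', 'mm', 'wait']
```

Without the two implicit edges, `mm` would be unordered relative to `ag` and `wait`, so a scheduler would be free to place it where it no longer hides the collective.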

Additionally, while bucketing, we will merge collective starts together and collective waits together. In this case, we want to treat the merged nodes as a single subgraph: each node in the merged set will have the union of all deps in the set.
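
A minimal sketch of the "union of all deps" rule, using the same node-to-dependencies mapping idea. `merge_deps` is a hypothetical helper, and dropping edges between members of the same merged set is an assumption of this sketch, not something the PR description states.

```python
def merge_deps(deps, merged):
    """Give every node in `merged` the union of the set's dependencies."""
    group = set(merged)
    union = set()
    for node in group:
        union |= deps.get(node, set())
    # Assumption: edges between members of the merged set are dropped,
    # since the set is treated as one subgraph.
    union -= group
    for node in group:
        deps[node] = set(union)
    return deps


deps = {
    "ag1": {"x"},
    "ag2": {"y"},
    "wait1": {"ag1", "mm1"},  # mm1 hides ag1
    "wait2": {"ag2", "mm2"},  # mm2 hides ag2
}
merge_deps(deps, ["ag1", "ag2"])      # merged collective starts
merge_deps(deps, ["wait1", "wait2"])  # merged collective waits

print(sorted(deps["ag1"]))    # ['x', 'y']
print(sorted(deps["wait1"]))  # ['ag1', 'ag2', 'mm1', 'mm2']
```

After the merge, both waits depend on both hiding computes, which is what keeps the bucketed wait from being scheduled before either `mm1` or `mm2`.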

We perform bucketing while augmenting the graph with these relationships. This can be done separately from comm-compute overlap, so long as the hiding compute relationships are passed in.
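
To illustrate why the augmented relationships matter during bucketing, here is one way a legality check could look. It is a sketch under stated assumptions, not the search this PR implements: `depends_on` and `can_bucket` are hypothetical names, and the rule "two collectives may be bucketed only if neither reaches the other through explicit or implicit dependencies" is this sketch's simplification.

```python
def depends_on(deps, node, target):
    """Return True if `target` is reachable from `node` via dependency edges."""
    stack, seen = [node], set()
    while stack:
        cur = stack.pop()
        if cur == target:
            return True
        if cur in seen:
            continue
        seen.add(cur)
        stack.extend(deps.get(cur, ()))
    return False


def can_bucket(deps, a, b):
    """Hypothetical check: merging is allowed only if a and b are unordered."""
    return not depends_on(deps, a, b) and not depends_on(deps, b, a)


deps = {
    "mm1": {"ag1"},           # implicit: mm1 is the hiding compute for ag1
    "wait1": {"ag1", "mm1"},  # explicit on ag1, implicit on mm1
    "ag2": {"mm1"},           # explicit: ag2 gathers the output of mm1
    "ag3": set(),             # unrelated collective
}

print(can_bucket(deps, "ag1", "ag2"))  # False
print(can_bucket(deps, "ag1", "ag3"))  # True
```

In the first case, merging `ag1` into `ag2` would force `ag1` to start after `mm1`, so `mm1` could no longer hide it. The implicit `mm1 -> ag1` edge is what makes that conflict visible to the check.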

TODO:
- need to instrument the FX graph so Inductor respects these relationships
- the compile time of the bucketing search can be sped up significantly by limiting which portion of the graph we traverse
- more memory-aware handling

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163960
Approved by: https://github.com/ruisizhang123, https://github.com/v0i0, https://github.com/IvanKobzarev
ghstack dependencies: #163215, #163754, #163959
eellison
2025-09-29 15:36:45 -07:00
committed by PyTorch MergeBot
parent 92108f4abd
commit 7d59e37434
5 changed files with 703 additions and 16 deletions


@@ -570,7 +570,7 @@ def process_collective_bucket(
trace_args_fn: Callable[[list[torch.fx.Node]], tuple[Any, ...]],
insert_before: Optional[torch.fx.Node] = None,
wait_insertion_point: Optional[torch.fx.Node] = None,
) -> dict[torch.fx.Node, torch.fx.Node]:
) -> tuple[list[torch.fx.Node], dict[torch.fx.Node, torch.fx.Node]]:
"""
Process a single bucket of collective operation nodes with flexible insertion control.
@@ -583,6 +583,7 @@ def process_collective_bucket(
wait_insertion_point: If provided, move all nodes from wait() onwards to before this node
Returns:
new_nodes: List of all newly inserted nodes
replacements: Dictionary mapping old wait nodes to new output nodes
"""
# Collect inputs and waits from current bucket
@@ -650,15 +651,16 @@ def process_collective_bucket(
for pre_node in reversed(ag_node_to_pre_nodes[node]):
g.erase_node(pre_node)
return replacements
return new_nodes, replacements
def merge_reduce_scatter_bucket(
g: torch.fx.Graph,
rs_nodes: list[torch.fx.Node],
mode: Optional[str] = None,
insert_before: Optional[torch.fx.Node] = None,
wait_insertion_point: Optional[torch.fx.Node] = None,
) -> None:
) -> tuple[list[torch.fx.Node], dict[torch.fx.Node, torch.fx.Node]]:
# Validate bucket consistency
rs0 = rs_nodes[0]
rs0_val = rs0.meta["val"]
@@ -692,11 +694,12 @@ def merge_reduce_scatter_bucket(
device,
)
process_collective_bucket(
return process_collective_bucket(
g,
rs_nodes,
rs_merge_fn,
create_trace_args,
insert_before=insert_before,
wait_insertion_point=wait_insertion_point,
)
@@ -705,8 +708,9 @@ def merge_all_gather_bucket(
g: torch.fx.Graph,
ag_nodes: list[torch.fx.Node],
mode: Optional[str] = None,
insert_before: Optional[torch.fx.Node] = None,
wait_insertion_point: Optional[torch.fx.Node] = None,
) -> None:
) -> tuple[list[torch.fx.Node], dict[torch.fx.Node, torch.fx.Node]]:
from torch.distributed.distributed_c10d import _resolve_process_group
ag0 = ag_nodes[0]
@@ -739,7 +743,7 @@ def merge_all_gather_bucket(
rank,
)
process_collective_bucket(
return process_collective_bucket(
g,
ag_nodes,
ag_merge_fn,