Add Comm-Compute Preserving Bucketer (#163960)
tl;dr: performs bucketing while preserving comm-compute overlap.

With comm-compute overlap we will have a graph such as:

```
def foo(...):
    ag = all_gather(...)
    hiding_compute = mm(...)
    wait(ag)
```

There is no explicit dependency between the hiding compute and the collectives, but we want to add implicit dependencies from wait -> hiding_compute and from hiding_compute -> all_gather to preserve the overlap.

Additionally, while bucketing we will merge collective starts and collective waits together. In that case we want to treat the merged nodes as a single subgraph: each node in the merged set gets the union of all deps in the set.

We perform bucketing while augmenting the graph with these relationships. This can be done separately from comm-compute overlap, so long as the hiding-compute relationships are passed in.

TODO:
- instrument the fx graph so inductor respects these relationships
- the compile time of the bucketing search can be sped up significantly by limiting what portion of the graph we traverse through
- more memory-aware handling

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163960
Approved by: https://github.com/ruisizhang123, https://github.com/v0i0, https://github.com/IvanKobzarev
ghstack dependencies: #163215, #163754, #163959
Committed by PyTorch MergeBot · parent 92108f4abd · commit 7d59e37434
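To make the dependency augmentation described above concrete, here is a minimal illustrative sketch in fx terms. It is not the PR's implementation: `augmented_deps`, `hiding`, and `merge_sets` are hypothetical names, and the assumption that a wait's first input is its collective start is for illustration only.

```python
# Minimal illustrative sketch, NOT the PR's implementation: augment an fx
# graph's dependency view with the implicit comm-compute overlap edges
# described above, and give every node in a merged start/wait set the
# union of the whole set's dependencies.
import torch.fx


def augmented_deps(
    graph: torch.fx.Graph,
    hiding: dict[torch.fx.Node, torch.fx.Node],  # hypothetical: wait node -> hiding compute
    merge_sets: list[set[torch.fx.Node]],        # hypothetical: nodes bucketed together
) -> dict[torch.fx.Node, set[torch.fx.Node]]:
    # Start from the explicit data dependencies recorded in the graph.
    deps: dict[torch.fx.Node, set[torch.fx.Node]] = {
        n: set(n.all_input_nodes) for n in graph.nodes
    }
    # Implicit overlap edges: the wait depends on the hiding compute, and
    # the hiding compute depends on the collective start, so no reordering
    # can move the compute out from between the start and the wait.
    for wait, compute in hiding.items():
        deps[wait].add(compute)
        start = wait.all_input_nodes[0]  # assumes the start feeds the wait directly
        deps[compute].add(start)
    # Merged starts (or waits) act as one subgraph: every member takes the
    # union of all members' dependencies, minus the members themselves.
    for group in merge_sets:
        union = set().union(*(deps[n] for n in group)) - group
        for n in group:
            deps[n] = set(union)
    return deps
```

A bucketing search that respects this augmented dependency map cannot reorder a hiding compute out from between its collective start and wait, which is exactly the overlap-preservation property the commit message describes.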
@@ -570,7 +570,7 @@ def process_collective_bucket(
     trace_args_fn: Callable[[list[torch.fx.Node]], tuple[Any, ...]],
     insert_before: Optional[torch.fx.Node] = None,
     wait_insertion_point: Optional[torch.fx.Node] = None,
-) -> dict[torch.fx.Node, torch.fx.Node]:
+) -> tuple[list[torch.fx.Node], dict[torch.fx.Node, torch.fx.Node]]:
     """
     Process a single bucket of collective operation nodes with flexible insertion control.
 
@@ -583,6 +583,7 @@ def process_collective_bucket(
         wait_insertion_point: If provided, move all nodes from wait() onwards to before this node
 
     Returns:
+        new_nodes: List of all newly inserted nodes
         replacements: Dictionary mapping old wait nodes to new output nodes
     """
     # Collect inputs and waits from current bucket
@@ -650,15 +651,16 @@ def process_collective_bucket(
         for pre_node in reversed(ag_node_to_pre_nodes[node]):
             g.erase_node(pre_node)
 
-    return replacements
+    return new_nodes, replacements
 
 
 def merge_reduce_scatter_bucket(
     g: torch.fx.Graph,
     rs_nodes: list[torch.fx.Node],
     mode: Optional[str] = None,
     insert_before: Optional[torch.fx.Node] = None,
-) -> None:
+    wait_insertion_point: Optional[torch.fx.Node] = None,
+) -> tuple[list[torch.fx.Node], dict[torch.fx.Node, torch.fx.Node]]:
     # Validate bucket consistency
     rs0 = rs_nodes[0]
     rs0_val = rs0.meta["val"]
@@ -692,11 +694,12 @@ def merge_reduce_scatter_bucket(
         device,
     )
 
-    process_collective_bucket(
+    return process_collective_bucket(
         g,
         rs_nodes,
         rs_merge_fn,
         create_trace_args,
         insert_before=insert_before,
+        wait_insertion_point=wait_insertion_point,
     )
 
@@ -705,8 +708,9 @@ def merge_all_gather_bucket(
     g: torch.fx.Graph,
     ag_nodes: list[torch.fx.Node],
     mode: Optional[str] = None,
     insert_before: Optional[torch.fx.Node] = None,
-) -> None:
+    wait_insertion_point: Optional[torch.fx.Node] = None,
+) -> tuple[list[torch.fx.Node], dict[torch.fx.Node, torch.fx.Node]]:
     from torch.distributed.distributed_c10d import _resolve_process_group
 
     ag0 = ag_nodes[0]
@@ -739,7 +743,7 @@ def merge_all_gather_bucket(
         rank,
     )
 
-    process_collective_bucket(
+    return process_collective_bucket(
         g,
         ag_nodes,
         ag_merge_fn,