[bucketing] Bucket only adjacent collectives to prevent reordering (#159983)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159983
Approved by: https://github.com/wconstab, https://github.com/eellison
This commit is contained in:
IvanKobzarev
2025-08-07 03:15:48 -07:00
committed by PyTorch MergeBot
parent 4d5b3f2d5a
commit f33ce40bc0
2 changed files with 49 additions and 17 deletions

View File

@ -93,6 +93,12 @@ def greedy_bucket_collective_by_mb(
node_group_key: Callable[[torch.fx.Node], Any],
filter_wait_node: Optional[Callable[[torch.fx.Node], bool]] = None,
) -> list[list[torch.fx.Node]]:
"""
Bucketing adjacent collectives with equal node_group_key.
We cannot bucket non-adjacent collectives,
as this would effectively change the order of collectives.
Reordering can lead to a different order on different ranks.
"""
g = gm.graph
found_candidates = False
for node in g.nodes:
@ -102,10 +108,12 @@ def greedy_bucket_collective_by_mb(
if not found_candidates:
return []
nodes_groups: dict[Any, list[torch.fx.Node]] = defaultdict(list)
nodes_successors: dict[torch.fx.Node, OrderedSet[torch.fx.Node]] = defaultdict(
OrderedSet
)
nodes_groups: list[list[torch.fx.Node]] = []
cur_group: list[torch.fx.Node] = []
cur_group_key = None
for node in g.nodes:
for n, successors in nodes_successors.items():
@ -115,10 +123,19 @@ def greedy_bucket_collective_by_mb(
if (filter_wait_node is None) or filter_wait_node(node):
coll_node = node.args[0]
group_key = node_group_key(coll_node)
nodes_groups[group_key].append(coll_node)
if group_key == cur_group_key:
cur_group.append(coll_node)
else:
if len(cur_group) > 1:
nodes_groups.append(cur_group)
cur_group = [coll_node]
cur_group_key = group_key
if len(cur_group) > 1:
nodes_groups.append(cur_group)
buckets: list[list[torch.fx.Node]] = []
for nodes in nodes_groups.values():
for nodes in nodes_groups:
cur_bucket: list[torch.fx.Node] = []
cur_bucket_successors: OrderedSet[torch.fx.Node] = OrderedSet()
cur_bucket_size_bytes: int = 0
@ -128,7 +145,7 @@ def greedy_bucket_collective_by_mb(
)
for node in nodes:
if node in cur_bucket_successors:
# We can not bucket successors with the node
# We cannot bucket successors with the node
continue
assert "val" in node.meta
n_val = node.meta["val"]
@ -163,7 +180,7 @@ def bucket_all_gather_by_mb(
Args:
gm (torch.fx.GraphModule): GraphModule where to bucket all_gathers.
bucket_cap_mb_bucket_idx (Callable[[int], float]): Callable to specify cap of the bucket
bucket_cap_mb_by_bucket_idx (Callable[[int], float]): Callable to specify cap of the bucket
in megabytes by bucket idx. The idea of `bucket_cap_mb_by_bucket_idx` is to allow
to specify different sizes of the buckets at the start,
as first all_gather is usually exposed. Interface of bucket_cap_mb_by_bucket_idx
@ -201,14 +218,14 @@ def bucket_reduce_scatter_by_mb(
Args:
gm (torch.fx.GraphModule): GraphModule where to bucket reduce_scatters.
bucket_cap_mb_bucket_idx (Callable[[int], float]): Callable to specify cap of the bucket
bucket_cap_mb_by_bucket_idx (Callable[[int], float]): Callable to specify cap of the bucket
in megabytes by bucket idx. The idea of `bucket_cap_mb_by_bucket_idx` is to allow
to specify different sizes of the buckets.
filter_wait_node (Optional[Callable[[torch.fx.Node], bool]]): If specified,
only reduce_scatter nodes with wait_node that satisfy `filter_wait_node` will be bucketed.
Returns:
list[list[torch.fx.Node]]: List of buckets, where each bucket is a list of all_gather nodes.
list[list[torch.fx.Node]]: List of buckets, where each bucket is a list of reduce_scatter nodes.
"""
def _rs_group_key(node: torch.fx.Node) -> tuple[str, str, torch.dtype]: