mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[bucketing] Bucket only adjacent collectives to prevent reordering (#159983)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/159983 Approved by: https://github.com/wconstab, https://github.com/eellison
This commit (f33ce40bc0) was committed by PyTorch MergeBot; parent commit: 4d5b3f2d5a.
@ -93,6 +93,12 @@ def greedy_bucket_collective_by_mb(
|
||||
node_group_key: Callable[[torch.fx.Node], Any],
|
||||
filter_wait_node: Optional[Callable[[torch.fx.Node], bool]] = None,
|
||||
) -> list[list[torch.fx.Node]]:
|
||||
"""
|
||||
Bucketing adjacent collectives with equal node_group_key.
|
||||
We can not bucket non adjacent collectives,
|
||||
as this will effectively change the order of collectives.
|
||||
Reordering can lead to different order on different ranks.
|
||||
"""
|
||||
g = gm.graph
|
||||
found_candidates = False
|
||||
for node in g.nodes:
|
||||
@ -102,10 +108,12 @@ def greedy_bucket_collective_by_mb(
|
||||
if not found_candidates:
|
||||
return []
|
||||
|
||||
nodes_groups: dict[Any, list[torch.fx.Node]] = defaultdict(list)
|
||||
nodes_successors: dict[torch.fx.Node, OrderedSet[torch.fx.Node]] = defaultdict(
|
||||
OrderedSet
|
||||
)
|
||||
nodes_groups: list[list[torch.fx.Node]] = []
|
||||
cur_group: list[torch.fx.Node] = []
|
||||
cur_group_key = None
|
||||
|
||||
for node in g.nodes:
|
||||
for n, successors in nodes_successors.items():
|
||||
@ -115,10 +123,19 @@ def greedy_bucket_collective_by_mb(
|
||||
if (filter_wait_node is None) or filter_wait_node(node):
|
||||
coll_node = node.args[0]
|
||||
group_key = node_group_key(coll_node)
|
||||
nodes_groups[group_key].append(coll_node)
|
||||
if group_key == cur_group_key:
|
||||
cur_group.append(coll_node)
|
||||
else:
|
||||
if len(cur_group) > 1:
|
||||
nodes_groups.append(cur_group)
|
||||
cur_group = [coll_node]
|
||||
cur_group_key = group_key
|
||||
|
||||
if len(cur_group) > 1:
|
||||
nodes_groups.append(cur_group)
|
||||
|
||||
buckets: list[list[torch.fx.Node]] = []
|
||||
for nodes in nodes_groups.values():
|
||||
for nodes in nodes_groups:
|
||||
cur_bucket: list[torch.fx.Node] = []
|
||||
cur_bucket_successors: OrderedSet[torch.fx.Node] = OrderedSet()
|
||||
cur_bucket_size_bytes: int = 0
|
||||
@ -128,7 +145,7 @@ def greedy_bucket_collective_by_mb(
|
||||
)
|
||||
for node in nodes:
|
||||
if node in cur_bucket_successors:
|
||||
# We can not bucket successors with the node
|
||||
# We cannot bucket successors with the node
|
||||
continue
|
||||
assert "val" in node.meta
|
||||
n_val = node.meta["val"]
|
||||
@ -163,7 +180,7 @@ def bucket_all_gather_by_mb(
|
||||
|
||||
Args:
|
||||
gm (torch.fx.GraphModule): GraphModule where to bucket all_gathers.
|
||||
bucket_cap_mb_bucket_idx (Callable[[int], float]): Callable to specify cap of the bucket
|
||||
bucket_cap_mb_by_bucket_idx (Callable[[int], float]): Callable to specify cap of the bucket
|
||||
in megabytes by bucket idx. The idea of `bucket_cap_mb_by_bucket_idx` is to allow
|
||||
to specify different sizes of the buckets at the start,
|
||||
as first all_gather is usually exposed. Interface of bucket_cap_mb_by_bucket_idx
|
||||
@ -201,14 +218,14 @@ def bucket_reduce_scatter_by_mb(
|
||||
|
||||
Args:
|
||||
gm (torch.fx.GraphModule): GraphModule where to bucket reduce_scatters.
|
||||
bucket_cap_mb_bucket_idx (Callable[[int], float]): Callable to specify cap of the bucket
|
||||
bucket_cap_mb_by_bucket_idx (Callable[[int], float]): Callable to specify cap of the bucket
|
||||
in megabytes by bucket idx. The idea of `bucket_cap_mb_by_bucket_idx` is to allow
|
||||
to specify different sizes of the buckets.
|
||||
filter_wait_node (Optional[Callable[[torch.fx.Node], bool]]): If specified,
|
||||
only reduce_scatter nodes with wait_node that satisfy `filter_wait_node` will be bucketed.
|
||||
|
||||
Returns:
|
||||
list[list[torch.fx.Node]]: List of buckets, where each bucket is a list of all_gather nodes.
|
||||
list[list[torch.fx.Node]]: List of buckets, where each bucket is a list of reduce_scatter nodes.
|
||||
"""
|
||||
|
||||
def _rs_group_key(node: torch.fx.Node) -> tuple[str, str, torch.dtype]:
|
||||
|
Reference in New Issue
Block a user