[bucketing] Support case of several pgs in graph (#158632)

Main changes:
- bucketing collectives only from the same process_group by group_name
- Support of groups like [0,2,4,6], [0,1,3,5] using `rank_idx_dict` for in pass operations for slice idxs etc.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/158632
Approved by: https://github.com/wconstab
This commit is contained in:
IvanKobzarev
2025-07-22 04:51:46 -07:00
committed by PyTorch MergeBot
parent 1b772de397
commit 371ffaf415

View File

@ -1,6 +1,7 @@
import logging
import math
import operator
from collections import defaultdict
from typing import Any, Callable, Optional, Union
import torch
@ -77,13 +78,9 @@ def bucket_all_gather_by_mb(
) -> list[list[torch.fx.Node]]:
"""
Identifies all all_gather nodes and groups them into buckets based on size limit `all_gather_bucket_cap_mb_callback`.
Returns a list of buckets, where each bucket is a list of all_gather nodes.
"""
node_list = gm.graph.nodes
# Prerequisite: Check if there is any all_gather node
found_all_gather = False
for node in node_list:
@ -92,48 +89,53 @@ def bucket_all_gather_by_mb(
break
if not found_all_gather:
return []
ag_nodes: list[torch.fx.Node] = []
group_name_ag_nodes: dict[tuple[str, torch.dtype], list[torch.fx.Node]] = ( # type: ignore[name-defined]
defaultdict(list)
)
# Step 1: Find all all_gather nodes
for node in node_list:
if is_wait_tensor(node) and is_all_gather_into_tensor(node.args[0]):
if (filter_wait_node is None) or filter_wait_node(node):
ag_node = node.args[0]
ag_nodes.append(ag_node)
_, group_size, group_name = ag_node.args
dtype = ag_node.meta["val"].dtype
assert isinstance(group_name, str)
group_name_ag_nodes[(group_name, dtype)].append(ag_node)
# Step 2: Put all_gather nodes into buckets
ag_buckets: list[list[torch.fx.Node]] = []
cur_bucket: list[torch.fx.Node] = []
cur_bucket_size_bytes: int = 0
cur_bucket_id: int = 0
# Convert MiB to bytes
all_gather_bucket_size_bytes = int(
all_gather_bucket_cap_mb_callback(cur_bucket_id) * 1024 * 1024
)
for ag_node in ag_nodes:
assert is_all_gather_into_tensor(ag_node)
assert "val" in ag_node.meta
ag_output_size_bytes = (
ag_node.meta["val"].numel()
* torch.finfo(ag_node.meta["val"].dtype).bits
// 8
for (group_name, dtype), ag_nodes in group_name_ag_nodes.items():
cur_bucket: list[torch.fx.Node] = []
cur_bucket_recursive_users: OrderedSet[torch.fx.Node] = OrderedSet()
cur_bucket_size_bytes: int = 0
cur_bucket_id: int = 0
all_gather_bucket_size_bytes = int(
all_gather_bucket_cap_mb_callback(cur_bucket_id) * 1024 * 1024
)
if (
cur_bucket_size_bytes + ag_output_size_bytes > all_gather_bucket_size_bytes
and cur_bucket
):
# Current bucket is full, create new bucket
for ag_node in ag_nodes:
assert is_all_gather_into_tensor(ag_node)
if ag_node in cur_bucket_recursive_users:
# We can not bucket successors with the node
continue
assert "val" in ag_node.meta
ag_n_val = ag_node.meta["val"]
ag_output_size_bytes = ag_n_val.numel() * ag_n_val.element_size()
if (
cur_bucket_size_bytes + ag_output_size_bytes
> all_gather_bucket_size_bytes
and cur_bucket
):
# Current bucket is full, create new bucket
if len(cur_bucket) > 1:
ag_buckets.append(cur_bucket)
cur_bucket = []
cur_bucket_size_bytes = 0
cur_bucket_id += 1
cur_bucket_size_bytes += ag_output_size_bytes
cur_bucket.append(ag_node)
find_recursive_users_of_fx_node(ag_node, cur_bucket_recursive_users)
if len(cur_bucket) > 1:
# add remaining nodes in the last bucket
ag_buckets.append(cur_bucket)
cur_bucket = []
cur_bucket_size_bytes = 0
cur_bucket_id += 1
cur_bucket_size_bytes += ag_output_size_bytes
cur_bucket.append(ag_node)
if cur_bucket:
# add remaining nodes in the last bucket
ag_buckets.append(cur_bucket)
return ag_buckets
@ -143,13 +145,9 @@ def bucket_reduce_scatter_by_mb(
) -> list[list[torch.fx.Node]]:
"""
Identifies all reduce_scatter nodes and groups them into buckets based on size limit `reduce_scatter_bucket_cap_mb_callback`.
Returns a list of buckets, where each bucket is a list of reduce_scatter nodes.
"""
node_list = list(gm.graph.nodes)
# Prerequisite: Check if there is any reduce_scatter node
found_reduce_scatter = False
for node in node_list:
@ -158,64 +156,71 @@ def bucket_reduce_scatter_by_mb(
break
if not found_reduce_scatter:
return []
rs_nodes: list[torch.fx.Node] = []
group_name_rs_nodes: dict[tuple[str, str, torch.dtype], list[torch.fx.Node]] = ( # type: ignore[name-defined]
defaultdict(list)
)
# Step 1: Find all reduce_scatter nodes
for node in node_list:
if is_wait_tensor(node) and is_reduce_scatter_tensor(node.args[0]):
rs_node = node.args[0]
rs_nodes.append(rs_node)
_, reduce_op, group_size, group_name = rs_node.args
dtype = rs_node.meta["val"].dtype
assert isinstance(group_name, str)
assert isinstance(reduce_op, str)
group_name_rs_nodes[(group_name, reduce_op, dtype)].append(rs_node)
# Step 2: Put reduce_scatter nodes into buckets
rs_buckets: list[list[torch.fx.Node]] = []
cur_bucket: list[torch.fx.Node] = []
cur_bucket_size_bytes: int = 0
cur_bucket_id: int = 0
# Convert MiB to bytes
reduce_scatter_bucket_size_bytes = int(
reduce_scatter_bucket_cap_mb_callback(cur_bucket_id) * 1024 * 1024
)
for rs_node in rs_nodes:
assert is_reduce_scatter_tensor(rs_node)
rs_input = rs_node.args[0]
assert "val" in rs_input.meta # type: ignore[union-attr]
rs_input_size_bytes = (
rs_input.meta["val"].numel() # type: ignore[union-attr]
* torch.finfo(rs_input.meta["val"].dtype).bits # type: ignore[union-attr]
// 8
for (group_name, reduce_op, dtype), rs_nodes in group_name_rs_nodes.items():
cur_bucket: list[torch.fx.Node] = []
cur_bucket_recursive_users: OrderedSet[torch.fx.Node] = OrderedSet()
cur_bucket_size_bytes: int = 0
cur_bucket_id: int = 0
# Convert MiB to bytes
reduce_scatter_bucket_size_bytes = int(
reduce_scatter_bucket_cap_mb_callback(cur_bucket_id) * 1024 * 1024
)
if (
cur_bucket_size_bytes + rs_input_size_bytes
> reduce_scatter_bucket_size_bytes
and cur_bucket
):
# Current bucket is full, create new bucket
total_size = cur_bucket_size_bytes + rs_input_size_bytes
for rs_node in rs_nodes:
assert is_reduce_scatter_tensor(rs_node)
if rs_node in cur_bucket_recursive_users:
# We can not bucket successors with the node
continue
rs_input = rs_node.args[0]
assert "val" in rs_input.meta # type: ignore[union-attr]
rs_in_val = rs_input.meta["val"] # type: ignore[union-attr]
rs_input_size_bytes = rs_in_val.numel() * rs_in_val.element_size()
if (
cur_bucket_size_bytes + rs_input_size_bytes
> reduce_scatter_bucket_size_bytes
and cur_bucket
):
# Current bucket is full, create new bucket
total_size = cur_bucket_size_bytes + rs_input_size_bytes
logger.info(
f"Reduce scatter bucket {cur_bucket_id} full: " # noqa: G004
f"total_size = {total_size} = cur_bucket_size_bytes + rs_input_size_bytes = "
f"{cur_bucket_size_bytes} + {rs_input_size_bytes},"
f"bucket_cap = {reduce_scatter_bucket_size_bytes}"
)
if len(cur_bucket) > 1:
rs_buckets.append(cur_bucket)
cur_bucket = []
cur_bucket_size_bytes = 0
cur_bucket_id += 1
reduce_scatter_bucket_size_bytes = int(
reduce_scatter_bucket_cap_mb_callback(cur_bucket_id) * 1024 * 1024
)
cur_bucket_size_bytes += rs_input_size_bytes
cur_bucket.append(rs_node)
find_recursive_users_of_fx_node(rs_node, cur_bucket_recursive_users)
if cur_bucket:
# add remaining nodes in the last bucket
logger.info(
f"Reduce scatter bucket {cur_bucket_id} full: " # noqa: G004
f"total_size = {total_size} = cur_bucket_size_bytes + rs_input_size_bytes = "
f"{cur_bucket_size_bytes} + {rs_input_size_bytes},"
f"Reduce scatter last bucket {cur_bucket_id}: " # noqa: G004
f"total_size = {cur_bucket_size_bytes}, "
f"bucket_cap = {reduce_scatter_bucket_size_bytes}"
)
rs_buckets.append(cur_bucket)
cur_bucket = []
cur_bucket_size_bytes = 0
cur_bucket_id += 1
reduce_scatter_bucket_size_bytes = int(
reduce_scatter_bucket_cap_mb_callback(cur_bucket_id) * 1024 * 1024
)
cur_bucket_size_bytes += rs_input_size_bytes
cur_bucket.append(rs_node)
if cur_bucket:
# add remaining nodes in the last bucket
logger.info(
f"Reduce scatter last bucket {cur_bucket_id}: " # noqa: G004
f"total_size = {cur_bucket_size_bytes}, "
f"bucket_cap = {reduce_scatter_bucket_size_bytes}"
)
rs_buckets.append(cur_bucket)
if len(cur_bucket) > 1:
rs_buckets.append(cur_bucket)
return rs_buckets
@ -260,6 +265,18 @@ def env_lookup( # type: ignore[no-untyped-def]
return env[x]
def _rank_idx_dict(group_name: str) -> dict[int, int]:
from torch.distributed.distributed_c10d import (
_resolve_process_group,
get_process_group_ranks,
)
pg = _resolve_process_group(group_name)
ranks = get_process_group_ranks(pg)
rank_idx_dict: dict[int, int] = {rank: idx for idx, rank in enumerate(ranks)}
return rank_idx_dict
def merge_all_gather(
gm: torch.fx.GraphModule, ag_buckets: list[list[torch.fx.Node]]
) -> None:
@ -297,15 +314,13 @@ def merge_all_gather(
bucket_id_is_scheduled = {}
cast_bucket_id_is_scheduled = {}
_, group_size, group_name = next(iter(ag_node_to_wait_node.keys())).args
group_name_to_rank_idx_dict: dict[str, dict[int, int]] = {}
for bucket_id, ag_bucket in enumerate(ag_buckets):
ag_input_nodes = []
wait_nodes = []
for ag_node in ag_bucket:
assert (
ag_node in ag_node_to_wait_node
and ag_node.args[1] == group_size
and ag_node.args[2] == group_name
)
ag_input_nodes.append(ag_node.args[0])
wait_nodes.append(ag_node_to_wait_node[ag_node])
bucket_id_to_bucketed_op_info[bucket_id] = (
@ -314,6 +329,8 @@ def merge_all_gather(
group_name,
wait_nodes,
)
if group_name not in group_name_to_rank_idx_dict:
group_name_to_rank_idx_dict[group_name] = _rank_idx_dict(group_name) # type: ignore[arg-type, index]
ag_wait_nodes = list(ag_node_to_wait_node.values())
ag_and_wait_nodes = OrderedSet(ag_nodes + ag_wait_nodes)
@ -334,9 +351,6 @@ def merge_all_gather(
ag_input_nodes, group_size, group_name, orig_wait_nodes = (
bucket_id_to_bucketed_op_info[bucket_id]
)
# device = ag_input_nodes[0].meta["val"].device
# rank = device.index
# dtype = ag_input_nodes[0].meta["val"].dtype
if all(
n.op == "call_function" # type: ignore[union-attr]
and n.target == torch.ops.prims.convert_element_type.default # type: ignore[union-attr]
@ -398,6 +412,7 @@ def merge_all_gather(
ag_input_nodes, group_size, group_name, orig_wait_nodes = (
bucket_id_to_bucketed_op_info[bucket_id]
)
rank_idx_dict = group_name_to_rank_idx_dict[group_name] # type: ignore[index]
device = ag_input_nodes[0].meta["val"].device # type: ignore[union-attr]
rank = device.index
dtype = ag_input_nodes[0].meta["val"].dtype # type: ignore[union-attr]
@ -468,7 +483,7 @@ def merge_all_gather(
all_gather_output,
inp_split_sizes,
all_gather_input_numel,
rank,
rank_idx_dict[rank],
),
{},
)
@ -585,6 +600,7 @@ def merge_reduce_scatter(
# Prepare bucketed operation info
bucket_id_to_bucketed_op_info = {}
bucket_id_is_scheduled = {}
group_name_to_rank_idx_dict: dict[str, dict[int, int]] = {}
for bucket_id, rs_bucket in enumerate(rs_buckets):
_, reduce_op, group_size, group_name = next(
iter(rs_node_to_wait_node.keys())
@ -612,6 +628,8 @@ def merge_reduce_scatter(
wait_nodes,
wait_node_recursive_users,
)
if group_name not in group_name_to_rank_idx_dict:
group_name_to_rank_idx_dict[group_name] = _rank_idx_dict(group_name) # type: ignore[arg-type, index]
new_graph: torch.fx.Graph = torch.fx.Graph()
env: dict[torch.fx.Node, torch.fx.Node] = {}
@ -624,155 +642,154 @@ def merge_reduce_scatter(
elif node in rs_node_to_wait_node:
assert node in rs_node_to_bucket_id
bucket_id = rs_node_to_bucket_id[node]
if (
if not (
bucket_id not in bucket_id_is_scheduled
and rs_buckets[bucket_id][-1] == node
):
# If we are at the last node in the bucket, we can start to schedule the bucketed reduce_scatter node
(
rs_input_nodes,
reduce_op,
group_size,
group_name,
orig_wait_nodes,
orig_wait_node_recursive_users,
) = bucket_id_to_bucketed_op_info[bucket_id]
# parents of rs have been scheduled, so we can directly use the env
unsharded_grads = [env[x] for x in rs_input_nodes] # type: ignore[index]
reduce_dtype = unsharded_grads[0].meta["val"].dtype
# Only float32 and bfloat16 are supported for now.
# To support fp16, please see FSDP2 `_get_gradient_divide_factors`.
assert reduce_dtype in (
torch.float32,
torch.bfloat16,
), f"reduce_dtype {reduce_dtype} is not supported"
assert all(
grad.meta["val"].dtype == reduce_dtype for grad in unsharded_grads
)
device = unsharded_grads[0].meta["val"].device
rank = device.index
shard_dim = 0
continue
def _get_dim0_padded_size(
tensor_size: torch.Size, dim0_factor: int
) -> torch.Size:
padded_dim0 = math.ceil(tensor_size[0] / dim0_factor) * dim0_factor
return torch.Size([padded_dim0]) + tensor_size[1:]
# If we are at the last node in the bucket, we can start to schedule the bucketed reduce_scatter node
(
rs_input_nodes,
reduce_op,
group_size,
group_name,
orig_wait_nodes,
orig_wait_node_recursive_users,
) = bucket_id_to_bucketed_op_info[bucket_id]
rank_idx_dict = group_name_to_rank_idx_dict[group_name] # type: ignore[index]
# parents of rs have been scheduled, so we can directly use the env
unsharded_grads = [env[x] for x in rs_input_nodes] # type: ignore[index]
reduce_dtype = unsharded_grads[0].meta["val"].dtype
# Only float32 and bfloat16 are supported for now.
# To support fp16, please see FSDP2 `_get_gradient_divide_factors`.
assert reduce_dtype in (
torch.float32, # type: ignore[attr-defined]
torch.bfloat16, # type: ignore[attr-defined]
), f"reduce_dtype {reduce_dtype} is not supported"
assert all(
grad.meta["val"].dtype == reduce_dtype for grad in unsharded_grads
)
device = unsharded_grads[0].meta["val"].device
rank = device.index
rank_idx = rank_idx_dict[rank]
shard_dim = 0
padded_unsharded_sizes = tuple(
_get_dim0_padded_size(grad.meta["val"].size(), group_size) # type: ignore[arg-type]
for grad in unsharded_grads
)
reduce_scatter_input_numel = sum(
s.numel() for s in padded_unsharded_sizes
)
def _get_dim0_padded_size(
tensor_size: torch.Size,
dim0_factor: int, # type: ignore[name-defined]
) -> torch.Size: # type: ignore[name-defined]
padded_dim0 = math.ceil(tensor_size[0] / dim0_factor) * dim0_factor # type: ignore[attr-defined]
return torch.Size([padded_dim0]) + tensor_size[1:]
"""
NOTE: the relationship between the next few nodes is tricky:
- reduce_scatter_input_reshaped is a view of reduce_scatter_input
(same storage, same # elems, different shape).
- chunk_cat writes into reduce_scatter_input_reshaped,
which indirectly writes into reduce_scatter_input
(since they share the same storage).
- reduce_scatter_tensor reads from reduce_scatter_input.
"""
reduce_scatter_input = new_graph_call_function(
padded_unsharded_sizes = tuple(
_get_dim0_padded_size(grad.meta["val"].size(), group_size) # type: ignore[arg-type]
for grad in unsharded_grads
)
reduce_scatter_input_numel = sum(s.numel() for s in padded_unsharded_sizes)
"""
NOTE: the relationship between the next few nodes is tricky:
- reduce_scatter_input_reshaped is a view of reduce_scatter_input
(same storage, same # elems, different shape).
- chunk_cat writes into reduce_scatter_input_reshaped,
which indirectly writes into reduce_scatter_input
(since they share the same storage).
- reduce_scatter_tensor reads from reduce_scatter_input.
"""
reduce_scatter_input = new_graph_call_function(
new_graph,
torch.ops.aten.empty.memory_format,
([reduce_scatter_input_numel],),
{
"dtype": reduce_dtype,
"device": device,
"pin_memory": False,
},
)
reduce_scatter_input_reshaped = new_graph_call_function(
new_graph,
torch.ops.aten.reshape.default,
(reduce_scatter_input, [group_size, -1]),
{},
)
new_graph_call_function(
new_graph,
torch.ops.fsdp.chunk_cat.default,
(unsharded_grads,),
{
"dim": 0,
"num_chunks": group_size,
"out": reduce_scatter_input_reshaped,
},
)
reduce_scatter_tensor = new_graph_call_function(
new_graph,
torch.ops._c10d_functional.reduce_scatter_tensor.default,
(reduce_scatter_input, reduce_op, group_size, group_name),
{},
)
wait_tensor = new_graph_call_function(
new_graph,
torch.ops._c10d_functional.wait_tensor.default,
(reduce_scatter_tensor,),
{},
)
def _chunk_with_empty(
tensor: torch.Tensor, num_chunks: int, dim: int
) -> list[torch.Tensor]:
chunks = list(torch.chunk(tensor, num_chunks, dim=dim))
while len(chunks) < num_chunks:
chunks.append(chunks[0].new_empty(0))
return chunks
reduce_output = wait_tensor
# View out and accumulate sharded gradients
new_sharded_grads = []
flat_grad_offset = 0 # [0, reduce_scatter_output_numel - 1]
for padded_unsharded_size, unsharded_grad in zip(
padded_unsharded_sizes, unsharded_grads
):
# NOTE: we only care about the shape of tensors in `chunks`, so using meta tensor here
chunks = _chunk_with_empty(
torch.empty_like(unsharded_grad.meta["val"], device="meta"),
group_size, # type: ignore[arg-type]
dim=shard_dim,
)
sharded_param = chunks[rank_idx]
sharded_size = sharded_param.size()
contiguous_sharded_stride = (
torch._prims_common.make_contiguous_strides_for(sharded_size)
)
# Assume even sharding for Shard(i), i > 0; otherwise would require
# copy-out for contiguous strides
new_sharded_grad = new_graph_call_function(
new_graph,
torch.ops.aten.empty.memory_format,
([reduce_scatter_input_numel],),
torch.ops.aten.as_strided.default,
(reduce_output,),
{
"dtype": reduce_dtype,
"device": device,
"pin_memory": False,
"size": sharded_size,
"stride": contiguous_sharded_stride,
"storage_offset": flat_grad_offset,
},
)
reduce_scatter_input_reshaped = new_graph_call_function(
new_graph,
torch.ops.aten.reshape.default,
(reduce_scatter_input, [group_size, -1]),
{},
)
new_graph_call_function(
new_graph,
torch.ops.fsdp.chunk_cat.default,
(unsharded_grads,),
{
"dim": 0,
"num_chunks": group_size,
"out": reduce_scatter_input_reshaped,
},
)
reduce_scatter_tensor = new_graph_call_function(
new_graph,
torch.ops._c10d_functional.reduce_scatter_tensor.default,
(reduce_scatter_input, reduce_op, group_size, group_name),
{},
)
wait_tensor = new_graph_call_function(
new_graph,
torch.ops._c10d_functional.wait_tensor.default,
(reduce_scatter_tensor,),
{},
)
def _chunk_with_empty(
tensor: torch.Tensor, num_chunks: int, dim: int
) -> list[torch.Tensor]:
chunks = list(torch.chunk(tensor, num_chunks, dim=dim))
while len(chunks) < num_chunks:
chunks.append(chunks[0].new_empty(0))
return chunks
reduce_output = wait_tensor
# View out and accumulate sharded gradients
new_sharded_grads = []
flat_grad_offset = 0 # [0, reduce_scatter_output_numel - 1]
for padded_unsharded_size, unsharded_grad in zip(
padded_unsharded_sizes, unsharded_grads
):
# NOTE: we only care about the shape of tensors in `chunks`, so using meta tensor here
chunks = _chunk_with_empty(
torch.empty_like(unsharded_grad.meta["val"], device="meta"),
group_size, # type: ignore[arg-type]
dim=shard_dim,
)
sharded_param = chunks[rank]
sharded_size = sharded_param.size()
contiguous_sharded_stride = (
torch._prims_common.make_contiguous_strides_for(sharded_size)
)
# Assume even sharding for Shard(i), i > 0; otherwise would require
# copy-out for contiguous strides
new_sharded_grad = new_graph_call_function(
new_graph,
torch.ops.aten.as_strided.default,
(reduce_output,),
{
"size": sharded_size,
"stride": contiguous_sharded_stride,
"storage_offset": flat_grad_offset,
},
)
new_sharded_grads.append(new_sharded_grad)
padded_sharded_numel = padded_unsharded_size.numel() // group_size # type: ignore[operator]
flat_grad_offset += padded_sharded_numel # type: ignore[assignment]
assert len(orig_wait_nodes) == len(new_sharded_grads)
assert len(orig_wait_nodes) > 0
for new_sharded_grad, orig_wait_node in zip(
new_sharded_grads, orig_wait_nodes
):
env[orig_wait_node] = new_sharded_grad # noqa: PERF403
for user in sorted(
orig_wait_node_recursive_users, key=lambda x: order[x]
):
# We skip output node here, because output node will be inserted (later)
# as the last node in the new graph.
if user.op != "output":
node_copy(
env, new_graph, user, lambda x: env_lookup(env, x, user)
)
bucket_id_is_scheduled[bucket_id] = True
new_sharded_grads.append(new_sharded_grad)
padded_sharded_numel = padded_unsharded_size.numel() // group_size # type: ignore[operator]
flat_grad_offset += padded_sharded_numel # type: ignore[assignment]
assert len(orig_wait_nodes) == len(new_sharded_grads)
assert len(orig_wait_nodes) > 0
for new_sharded_grad, orig_wait_node in zip(
new_sharded_grads, orig_wait_nodes
):
env[orig_wait_node] = new_sharded_grad # noqa: PERF403
for user in sorted(orig_wait_node_recursive_users, key=lambda x: order[x]):
# We skip output node here, because output node will be inserted (later)
# as the last node in the new graph.
if user.op != "output":
node_copy(env, new_graph, user, lambda x: env_lookup(env, x, user))
bucket_id_is_scheduled[bucket_id] = True
else:
continue
assert node_list[-1].op == "output"