Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 21:14:14 +08:00
[bucketing] Support case of several pgs in graph (#158632)
Main changes:
- Bucket collectives only when they belong to the same process group, matching on `group_name`.
- Support non-contiguous groups such as [0,2,4,6] or [0,1,3,5] via a `rank_idx_dict` used inside the pass for slice indices and other per-rank operations.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/158632
Approved by: https://github.com/wconstab
Committed by: PyTorch MergeBot
Parent: 1b772de397
Commit: 371ffaf415
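The underlying issue: in a process group built from non-contiguous global ranks, a rank's position inside the group differs from the rank itself, so per-rank offsets must be computed from that position. A minimal sketch using the same mapping the new `_rank_idx_dict` helper builds (the example group is illustrative, not from the PR):

```python
ranks = [0, 2, 4, 6]  # global ranks forming one process group
rank_idx_dict = {rank: idx for idx, rank in enumerate(ranks)}
assert rank_idx_dict == {0: 0, 2: 1, 4: 2, 6: 3}
# Global rank 4 owns the shard at position 2, not position 4.
```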
@@ -1,6 +1,7 @@
 import logging
 import math
 import operator
+from collections import defaultdict
 from typing import Any, Callable, Optional, Union
 
 import torch
@@ -77,13 +78,9 @@ def bucket_all_gather_by_mb(
 ) -> list[list[torch.fx.Node]]:
     """
     Identifies all all_gather nodes and groups them into buckets based on size limit `all_gather_bucket_cap_mb_callback`.
-
-
     Returns a list of buckets, where each bucket is a list of all_gather nodes.
     """
-
     node_list = gm.graph.nodes
-
     # Prerequisite: Check if there is any all_gather node
     found_all_gather = False
     for node in node_list:
@@ -92,48 +89,53 @@ def bucket_all_gather_by_mb(
             break
     if not found_all_gather:
         return []
 
-    ag_nodes: list[torch.fx.Node] = []
-
+    group_name_ag_nodes: dict[tuple[str, torch.dtype], list[torch.fx.Node]] = (  # type: ignore[name-defined]
+        defaultdict(list)
+    )
     # Step 1: Find all all_gather nodes
     for node in node_list:
         if is_wait_tensor(node) and is_all_gather_into_tensor(node.args[0]):
             if (filter_wait_node is None) or filter_wait_node(node):
                 ag_node = node.args[0]
-                ag_nodes.append(ag_node)
-
+                _, group_size, group_name = ag_node.args
+                dtype = ag_node.meta["val"].dtype
+                assert isinstance(group_name, str)
+                group_name_ag_nodes[(group_name, dtype)].append(ag_node)
     # Step 2: Put all_gather nodes into buckets
     ag_buckets: list[list[torch.fx.Node]] = []
+    for (group_name, dtype), ag_nodes in group_name_ag_nodes.items():
         cur_bucket: list[torch.fx.Node] = []
         cur_bucket_recursive_users: OrderedSet[torch.fx.Node] = OrderedSet()
         cur_bucket_size_bytes: int = 0
         cur_bucket_id: int = 0
         # Convert MiB to bytes
         all_gather_bucket_size_bytes = int(
             all_gather_bucket_cap_mb_callback(cur_bucket_id) * 1024 * 1024
         )
         for ag_node in ag_nodes:
             assert is_all_gather_into_tensor(ag_node)
             if ag_node in cur_bucket_recursive_users:
                 # We can not bucket successors with the node
                 continue
             assert "val" in ag_node.meta
-            ag_output_size_bytes = (
-                ag_node.meta["val"].numel()
-                * torch.finfo(ag_node.meta["val"].dtype).bits
-                // 8
-            )
+            ag_n_val = ag_node.meta["val"]
+            ag_output_size_bytes = ag_n_val.numel() * ag_n_val.element_size()
             if (
-                cur_bucket_size_bytes + ag_output_size_bytes > all_gather_bucket_size_bytes
+                cur_bucket_size_bytes + ag_output_size_bytes
+                > all_gather_bucket_size_bytes
                 and cur_bucket
             ):
                 # Current bucket is full, create new bucket
+                if len(cur_bucket) > 1:
                     ag_buckets.append(cur_bucket)
                 cur_bucket = []
                 cur_bucket_size_bytes = 0
                 cur_bucket_id += 1
             cur_bucket_size_bytes += ag_output_size_bytes
             cur_bucket.append(ag_node)
-    if cur_bucket:
+            find_recursive_users_of_fx_node(ag_node, cur_bucket_recursive_users)
+        if len(cur_bucket) > 1:
             # add remaining nodes in the last bucket
             ag_buckets.append(cur_bucket)
 
     return ag_buckets
 
 
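The byte-size computation above switches from `torch.finfo(dtype).bits // 8` to `Tensor.element_size()`, which reads the itemsize directly and does not assume a floating-point dtype. A quick equivalence check (illustrative, not part of the PR):

```python
import torch

t = torch.empty(8, dtype=torch.bfloat16)
old_bytes = t.numel() * torch.finfo(t.dtype).bits // 8  # finfo: float dtypes only
new_bytes = t.numel() * t.element_size()                # works for any dtype
assert old_bytes == new_bytes == 16
```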
@@ -143,13 +145,9 @@ def bucket_reduce_scatter_by_mb(
 ) -> list[list[torch.fx.Node]]:
     """
     Identifies all reduce_scatter nodes and groups them into buckets based on size limit `reduce_scatter_bucket_cap_mb_callback`.
-
-
     Returns a list of buckets, where each bucket is a list of reduce_scatter nodes.
     """
-
     node_list = list(gm.graph.nodes)
-
     # Prerequisite: Check if there is any reduce_scatter node
     found_reduce_scatter = False
     for node in node_list:
@@ -158,18 +156,23 @@ def bucket_reduce_scatter_by_mb(
             break
     if not found_reduce_scatter:
         return []
 
-    rs_nodes: list[torch.fx.Node] = []
-
+    group_name_rs_nodes: dict[tuple[str, str, torch.dtype], list[torch.fx.Node]] = (  # type: ignore[name-defined]
+        defaultdict(list)
+    )
     # Step 1: Find all reduce_scatter nodes
     for node in node_list:
         if is_wait_tensor(node) and is_reduce_scatter_tensor(node.args[0]):
             rs_node = node.args[0]
-            rs_nodes.append(rs_node)
-
+            _, reduce_op, group_size, group_name = rs_node.args
+            dtype = rs_node.meta["val"].dtype
+            assert isinstance(group_name, str)
+            assert isinstance(reduce_op, str)
+            group_name_rs_nodes[(group_name, reduce_op, dtype)].append(rs_node)
     # Step 2: Put reduce_scatter nodes into buckets
     rs_buckets: list[list[torch.fx.Node]] = []
+    for (group_name, reduce_op, dtype), rs_nodes in group_name_rs_nodes.items():
         cur_bucket: list[torch.fx.Node] = []
         cur_bucket_recursive_users: OrderedSet[torch.fx.Node] = OrderedSet()
         cur_bucket_size_bytes: int = 0
         cur_bucket_id: int = 0
         # Convert MiB to bytes
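Reduce_scatter buckets are keyed on (group_name, reduce_op, dtype), so collectives from different process groups, or with different reduce ops or dtypes, are never fused together. A minimal sketch of this grouping step, with made-up node names:

```python
from collections import defaultdict

buckets_by_key: dict[tuple[str, str], list[str]] = defaultdict(list)
for node, group_name, reduce_op in [
    ("rs1", "pg0", "avg"),
    ("rs2", "pg1", "avg"),
    ("rs3", "pg0", "avg"),
]:
    buckets_by_key[(group_name, reduce_op)].append(node)
# buckets_by_key == {("pg0", "avg"): ["rs1", "rs3"], ("pg1", "avg"): ["rs2"]}
```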
@@ -178,13 +181,13 @@ def bucket_reduce_scatter_by_mb(
         )
         for rs_node in rs_nodes:
             assert is_reduce_scatter_tensor(rs_node)
             if rs_node in cur_bucket_recursive_users:
                 # We can not bucket successors with the node
                 continue
             rs_input = rs_node.args[0]
             assert "val" in rs_input.meta  # type: ignore[union-attr]
-            rs_input_size_bytes = (
-                rs_input.meta["val"].numel()  # type: ignore[union-attr]
-                * torch.finfo(rs_input.meta["val"].dtype).bits  # type: ignore[union-attr]
-                // 8
-            )
+            rs_in_val = rs_input.meta["val"]  # type: ignore[union-attr]
+            rs_input_size_bytes = rs_in_val.numel() * rs_in_val.element_size()
             if (
                 cur_bucket_size_bytes + rs_input_size_bytes
                 > reduce_scatter_bucket_size_bytes
@@ -198,6 +201,7 @@ def bucket_reduce_scatter_by_mb(
                     f"{cur_bucket_size_bytes} + {rs_input_size_bytes},"
                     f"bucket_cap = {reduce_scatter_bucket_size_bytes}"
                 )
+                if len(cur_bucket) > 1:
                     rs_buckets.append(cur_bucket)
                 cur_bucket = []
                 cur_bucket_size_bytes = 0
@@ -207,6 +211,7 @@ def bucket_reduce_scatter_by_mb(
                 )
             cur_bucket_size_bytes += rs_input_size_bytes
             cur_bucket.append(rs_node)
+            find_recursive_users_of_fx_node(rs_node, cur_bucket_recursive_users)
         if cur_bucket:
             # add remaining nodes in the last bucket
             logger.info(
@@ -214,8 +219,8 @@ def bucket_reduce_scatter_by_mb(
                 f"total_size = {cur_bucket_size_bytes}, "
                 f"bucket_cap = {reduce_scatter_bucket_size_bytes}"
             )
+            if len(cur_bucket) > 1:
                 rs_buckets.append(cur_bucket)
 
     return rs_buckets
 
-
@@ -260,6 +265,18 @@ def env_lookup(  # type: ignore[no-untyped-def]
     return env[x]
 
 
+def _rank_idx_dict(group_name: str) -> dict[int, int]:
+    from torch.distributed.distributed_c10d import (
+        _resolve_process_group,
+        get_process_group_ranks,
+    )
+
+    pg = _resolve_process_group(group_name)
+    ranks = get_process_group_ranks(pg)
+    rank_idx_dict: dict[int, int] = {rank: idx for idx, rank in enumerate(ranks)}
+    return rank_idx_dict
+
+
 def merge_all_gather(
     gm: torch.fx.GraphModule, ag_buckets: list[list[torch.fx.Node]]
 ) -> None:
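The new helper resolves a registered group name to its member ranks and returns the rank-to-position mapping; both merge passes then memoize the result per `group_name` in a `group_name_to_rank_idx_dict`. A self-contained sketch of the same lookup (the function name `rank_idx` is illustrative; it requires an initialized process group registered under `group_name`):

```python
from torch.distributed.distributed_c10d import (
    _resolve_process_group,
    get_process_group_ranks,
)

def rank_idx(group_name: str, rank: int) -> int:
    # Position of `rank` within the (possibly non-contiguous) group,
    # e.g. rank 4 in group [0, 2, 4, 6] -> index 2.
    pg = _resolve_process_group(group_name)
    return get_process_group_ranks(pg).index(rank)
```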
@@ -297,15 +314,13 @@ def merge_all_gather(
     bucket_id_is_scheduled = {}
     cast_bucket_id_is_scheduled = {}
-    _, group_size, group_name = next(iter(ag_node_to_wait_node.keys())).args
-
+    group_name_to_rank_idx_dict: dict[str, dict[int, int]] = {}
+
     for bucket_id, ag_bucket in enumerate(ag_buckets):
         ag_input_nodes = []
         wait_nodes = []
         for ag_node in ag_bucket:
             assert (
                 ag_node in ag_node_to_wait_node
                 and ag_node.args[1] == group_size
                 and ag_node.args[2] == group_name
             )
             ag_input_nodes.append(ag_node.args[0])
             wait_nodes.append(ag_node_to_wait_node[ag_node])
         bucket_id_to_bucketed_op_info[bucket_id] = (
@@ -314,6 +329,8 @@ def merge_all_gather(
             group_name,
             wait_nodes,
         )
+        if group_name not in group_name_to_rank_idx_dict:
+            group_name_to_rank_idx_dict[group_name] = _rank_idx_dict(group_name)  # type: ignore[arg-type, index]
 
     ag_wait_nodes = list(ag_node_to_wait_node.values())
     ag_and_wait_nodes = OrderedSet(ag_nodes + ag_wait_nodes)
@@ -334,9 +351,6 @@ def merge_all_gather(
         ag_input_nodes, group_size, group_name, orig_wait_nodes = (
             bucket_id_to_bucketed_op_info[bucket_id]
         )
-        # device = ag_input_nodes[0].meta["val"].device
-        # rank = device.index
-        # dtype = ag_input_nodes[0].meta["val"].dtype
         if all(
             n.op == "call_function"  # type: ignore[union-attr]
             and n.target == torch.ops.prims.convert_element_type.default  # type: ignore[union-attr]
@@ -398,6 +412,7 @@ def merge_all_gather(
         ag_input_nodes, group_size, group_name, orig_wait_nodes = (
             bucket_id_to_bucketed_op_info[bucket_id]
        )
+        rank_idx_dict = group_name_to_rank_idx_dict[group_name]  # type: ignore[index]
         device = ag_input_nodes[0].meta["val"].device  # type: ignore[union-attr]
         rank = device.index
         dtype = ag_input_nodes[0].meta["val"].dtype  # type: ignore[union-attr]
@@ -468,7 +483,7 @@ def merge_all_gather(
                 all_gather_output,
                 inp_split_sizes,
                 all_gather_input_numel,
-                rank,
+                rank_idx_dict[rank],
             ),
             {},
         )
@@ -585,6 +600,7 @@ def merge_reduce_scatter(
     # Prepare bucketed operation info
     bucket_id_to_bucketed_op_info = {}
     bucket_id_is_scheduled = {}
+    group_name_to_rank_idx_dict: dict[str, dict[int, int]] = {}
     for bucket_id, rs_bucket in enumerate(rs_buckets):
         _, reduce_op, group_size, group_name = next(
             iter(rs_node_to_wait_node.keys())
@@ -612,6 +628,8 @@ def merge_reduce_scatter(
             wait_nodes,
             wait_node_recursive_users,
         )
+        if group_name not in group_name_to_rank_idx_dict:
+            group_name_to_rank_idx_dict[group_name] = _rank_idx_dict(group_name)  # type: ignore[arg-type, index]
 
     new_graph: torch.fx.Graph = torch.fx.Graph()
     env: dict[torch.fx.Node, torch.fx.Node] = {}
@@ -624,10 +642,12 @@ def merge_reduce_scatter(
         elif node in rs_node_to_wait_node:
             assert node in rs_node_to_bucket_id
             bucket_id = rs_node_to_bucket_id[node]
-            if (
+            if not (
                 bucket_id not in bucket_id_is_scheduled
                 and rs_buckets[bucket_id][-1] == node
             ):
+                continue
+
             # If we are at the last node in the bucket, we can start to schedule the bucketed reduce_scatter node
             (
                 rs_input_nodes,
@@ -637,35 +657,36 @@ def merge_reduce_scatter(
                 orig_wait_nodes,
                 orig_wait_node_recursive_users,
             ) = bucket_id_to_bucketed_op_info[bucket_id]
+            rank_idx_dict = group_name_to_rank_idx_dict[group_name]  # type: ignore[index]
             # parents of rs have been scheduled, so we can directly use the env
             unsharded_grads = [env[x] for x in rs_input_nodes]  # type: ignore[index]
             reduce_dtype = unsharded_grads[0].meta["val"].dtype
             # Only float32 and bfloat16 are supported for now.
             # To support fp16, please see FSDP2 `_get_gradient_divide_factors`.
             assert reduce_dtype in (
-                torch.float32,
-                torch.bfloat16,
+                torch.float32,  # type: ignore[attr-defined]
+                torch.bfloat16,  # type: ignore[attr-defined]
             ), f"reduce_dtype {reduce_dtype} is not supported"
             assert all(
                 grad.meta["val"].dtype == reduce_dtype for grad in unsharded_grads
             )
             device = unsharded_grads[0].meta["val"].device
             rank = device.index
+            rank_idx = rank_idx_dict[rank]
             shard_dim = 0
 
             def _get_dim0_padded_size(
-                tensor_size: torch.Size, dim0_factor: int
-            ) -> torch.Size:
-                padded_dim0 = math.ceil(tensor_size[0] / dim0_factor) * dim0_factor
+                tensor_size: torch.Size,
+                dim0_factor: int,  # type: ignore[name-defined]
+            ) -> torch.Size:  # type: ignore[name-defined]
+                padded_dim0 = math.ceil(tensor_size[0] / dim0_factor) * dim0_factor  # type: ignore[attr-defined]
                 return torch.Size([padded_dim0]) + tensor_size[1:]
 
             padded_unsharded_sizes = tuple(
                 _get_dim0_padded_size(grad.meta["val"].size(), group_size)  # type: ignore[arg-type]
                 for grad in unsharded_grads
             )
-            reduce_scatter_input_numel = sum(
-                s.numel() for s in padded_unsharded_sizes
-            )
+            reduce_scatter_input_numel = sum(s.numel() for s in padded_unsharded_sizes)
 
             """
             NOTE: the relationship between the next few nodes is tricky:
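`_get_dim0_padded_size` pads dim0 up to a multiple of the group size so every rank receives an equally sized shard. A worked example with illustrative numbers:

```python
import math

tensor_dim0, group_size = 10, 4
padded_dim0 = math.ceil(tensor_dim0 / group_size) * group_size
assert padded_dim0 == 12  # each of the 4 ranks then owns 12 // 4 == 3 rows
```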
@@ -737,7 +758,7 @@ def merge_reduce_scatter(
                 group_size,  # type: ignore[arg-type]
                 dim=shard_dim,
             )
-            sharded_param = chunks[rank]
+            sharded_param = chunks[rank_idx]
             sharded_size = sharded_param.size()
             contiguous_sharded_stride = (
                 torch._prims_common.make_contiguous_strides_for(sharded_size)
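Indexing the chunk by position-in-group rather than global rank is what makes non-contiguous groups work: there are only len(ranks) shards, so a raw global rank can point past the end. A small illustration (values are made up):

```python
import torch

ranks = [0, 2, 4, 6]                              # members of one process group
rank_idx = {r: i for i, r in enumerate(ranks)}
output = torch.arange(8)
chunks = torch.chunk(output, len(ranks), dim=0)   # 4 shards of 2 elements each
shard = chunks[rank_idx[4]]                       # global rank 4 owns chunks[2]
# chunks[4] would raise IndexError: only len(ranks) == 4 shards exist
```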
@@ -763,15 +784,11 @@ def merge_reduce_scatter(
                 new_sharded_grads, orig_wait_nodes
             ):
                 env[orig_wait_node] = new_sharded_grad  # noqa: PERF403
-            for user in sorted(
-                orig_wait_node_recursive_users, key=lambda x: order[x]
-            ):
+            for user in sorted(orig_wait_node_recursive_users, key=lambda x: order[x]):
                 # We skip output node here, because output node will be inserted (later)
                 # as the last node in the new graph.
                 if user.op != "output":
-                    node_copy(
-                        env, new_graph, user, lambda x: env_lookup(env, x, user)
-                    )
+                    node_copy(env, new_graph, user, lambda x: env_lookup(env, x, user))
             bucket_id_is_scheduled[bucket_id] = True
         else:
             continue