mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
[bucketing] Support case of several pgs in graph (#158632)
Main changes: - bucketing collectives only from the same process_group by group_name - Support of groups like [0,2,4,6], [0,1,3,5] using `rank_idx_dict` for in pass operations for slice idxs etc. Pull Request resolved: https://github.com/pytorch/pytorch/pull/158632 Approved by: https://github.com/wconstab
This commit is contained in:
committed by
PyTorch MergeBot
parent
1b772de397
commit
371ffaf415
@ -1,6 +1,7 @@
|
||||
import logging
|
||||
import math
|
||||
import operator
|
||||
from collections import defaultdict
|
||||
from typing import Any, Callable, Optional, Union
|
||||
|
||||
import torch
|
||||
@ -77,13 +78,9 @@ def bucket_all_gather_by_mb(
|
||||
) -> list[list[torch.fx.Node]]:
|
||||
"""
|
||||
Identifies all all_gather nodes and groups them into buckets based on size limit `all_gather_bucket_cap_mb_callback`.
|
||||
|
||||
|
||||
Returns a list of buckets, where each bucket is a list of all_gather nodes.
|
||||
"""
|
||||
|
||||
node_list = gm.graph.nodes
|
||||
|
||||
# Prerequisite: Check if there is any all_gather node
|
||||
found_all_gather = False
|
||||
for node in node_list:
|
||||
@ -92,48 +89,53 @@ def bucket_all_gather_by_mb(
|
||||
break
|
||||
if not found_all_gather:
|
||||
return []
|
||||
|
||||
ag_nodes: list[torch.fx.Node] = []
|
||||
|
||||
group_name_ag_nodes: dict[tuple[str, torch.dtype], list[torch.fx.Node]] = ( # type: ignore[name-defined]
|
||||
defaultdict(list)
|
||||
)
|
||||
# Step 1: Find all all_gather nodes
|
||||
for node in node_list:
|
||||
if is_wait_tensor(node) and is_all_gather_into_tensor(node.args[0]):
|
||||
if (filter_wait_node is None) or filter_wait_node(node):
|
||||
ag_node = node.args[0]
|
||||
ag_nodes.append(ag_node)
|
||||
|
||||
_, group_size, group_name = ag_node.args
|
||||
dtype = ag_node.meta["val"].dtype
|
||||
assert isinstance(group_name, str)
|
||||
group_name_ag_nodes[(group_name, dtype)].append(ag_node)
|
||||
# Step 2: Put all_gather nodes into buckets
|
||||
ag_buckets: list[list[torch.fx.Node]] = []
|
||||
cur_bucket: list[torch.fx.Node] = []
|
||||
cur_bucket_size_bytes: int = 0
|
||||
cur_bucket_id: int = 0
|
||||
# Convert MiB to bytes
|
||||
all_gather_bucket_size_bytes = int(
|
||||
all_gather_bucket_cap_mb_callback(cur_bucket_id) * 1024 * 1024
|
||||
)
|
||||
for ag_node in ag_nodes:
|
||||
assert is_all_gather_into_tensor(ag_node)
|
||||
assert "val" in ag_node.meta
|
||||
ag_output_size_bytes = (
|
||||
ag_node.meta["val"].numel()
|
||||
* torch.finfo(ag_node.meta["val"].dtype).bits
|
||||
// 8
|
||||
for (group_name, dtype), ag_nodes in group_name_ag_nodes.items():
|
||||
cur_bucket: list[torch.fx.Node] = []
|
||||
cur_bucket_recursive_users: OrderedSet[torch.fx.Node] = OrderedSet()
|
||||
cur_bucket_size_bytes: int = 0
|
||||
cur_bucket_id: int = 0
|
||||
all_gather_bucket_size_bytes = int(
|
||||
all_gather_bucket_cap_mb_callback(cur_bucket_id) * 1024 * 1024
|
||||
)
|
||||
if (
|
||||
cur_bucket_size_bytes + ag_output_size_bytes > all_gather_bucket_size_bytes
|
||||
and cur_bucket
|
||||
):
|
||||
# Current bucket is full, create new bucket
|
||||
for ag_node in ag_nodes:
|
||||
assert is_all_gather_into_tensor(ag_node)
|
||||
if ag_node in cur_bucket_recursive_users:
|
||||
# We can not bucket successors with the node
|
||||
continue
|
||||
assert "val" in ag_node.meta
|
||||
ag_n_val = ag_node.meta["val"]
|
||||
ag_output_size_bytes = ag_n_val.numel() * ag_n_val.element_size()
|
||||
if (
|
||||
cur_bucket_size_bytes + ag_output_size_bytes
|
||||
> all_gather_bucket_size_bytes
|
||||
and cur_bucket
|
||||
):
|
||||
# Current bucket is full, create new bucket
|
||||
if len(cur_bucket) > 1:
|
||||
ag_buckets.append(cur_bucket)
|
||||
cur_bucket = []
|
||||
cur_bucket_size_bytes = 0
|
||||
cur_bucket_id += 1
|
||||
cur_bucket_size_bytes += ag_output_size_bytes
|
||||
cur_bucket.append(ag_node)
|
||||
find_recursive_users_of_fx_node(ag_node, cur_bucket_recursive_users)
|
||||
if len(cur_bucket) > 1:
|
||||
# add remaining nodes in the last bucket
|
||||
ag_buckets.append(cur_bucket)
|
||||
cur_bucket = []
|
||||
cur_bucket_size_bytes = 0
|
||||
cur_bucket_id += 1
|
||||
cur_bucket_size_bytes += ag_output_size_bytes
|
||||
cur_bucket.append(ag_node)
|
||||
if cur_bucket:
|
||||
# add remaining nodes in the last bucket
|
||||
ag_buckets.append(cur_bucket)
|
||||
|
||||
return ag_buckets
|
||||
|
||||
|
||||
@ -143,13 +145,9 @@ def bucket_reduce_scatter_by_mb(
|
||||
) -> list[list[torch.fx.Node]]:
|
||||
"""
|
||||
Identifies all reduce_scatter nodes and groups them into buckets based on size limit `reduce_scatter_bucket_cap_mb_callback`.
|
||||
|
||||
|
||||
Returns a list of buckets, where each bucket is a list of reduce_scatter nodes.
|
||||
"""
|
||||
|
||||
node_list = list(gm.graph.nodes)
|
||||
|
||||
# Prerequisite: Check if there is any reduce_scatter node
|
||||
found_reduce_scatter = False
|
||||
for node in node_list:
|
||||
@ -158,64 +156,71 @@ def bucket_reduce_scatter_by_mb(
|
||||
break
|
||||
if not found_reduce_scatter:
|
||||
return []
|
||||
|
||||
rs_nodes: list[torch.fx.Node] = []
|
||||
|
||||
group_name_rs_nodes: dict[tuple[str, str, torch.dtype], list[torch.fx.Node]] = ( # type: ignore[name-defined]
|
||||
defaultdict(list)
|
||||
)
|
||||
# Step 1: Find all reduce_scatter nodes
|
||||
for node in node_list:
|
||||
if is_wait_tensor(node) and is_reduce_scatter_tensor(node.args[0]):
|
||||
rs_node = node.args[0]
|
||||
rs_nodes.append(rs_node)
|
||||
|
||||
_, reduce_op, group_size, group_name = rs_node.args
|
||||
dtype = rs_node.meta["val"].dtype
|
||||
assert isinstance(group_name, str)
|
||||
assert isinstance(reduce_op, str)
|
||||
group_name_rs_nodes[(group_name, reduce_op, dtype)].append(rs_node)
|
||||
# Step 2: Put reduce_scatter nodes into buckets
|
||||
rs_buckets: list[list[torch.fx.Node]] = []
|
||||
cur_bucket: list[torch.fx.Node] = []
|
||||
cur_bucket_size_bytes: int = 0
|
||||
cur_bucket_id: int = 0
|
||||
# Convert MiB to bytes
|
||||
reduce_scatter_bucket_size_bytes = int(
|
||||
reduce_scatter_bucket_cap_mb_callback(cur_bucket_id) * 1024 * 1024
|
||||
)
|
||||
for rs_node in rs_nodes:
|
||||
assert is_reduce_scatter_tensor(rs_node)
|
||||
rs_input = rs_node.args[0]
|
||||
assert "val" in rs_input.meta # type: ignore[union-attr]
|
||||
rs_input_size_bytes = (
|
||||
rs_input.meta["val"].numel() # type: ignore[union-attr]
|
||||
* torch.finfo(rs_input.meta["val"].dtype).bits # type: ignore[union-attr]
|
||||
// 8
|
||||
for (group_name, reduce_op, dtype), rs_nodes in group_name_rs_nodes.items():
|
||||
cur_bucket: list[torch.fx.Node] = []
|
||||
cur_bucket_recursive_users: OrderedSet[torch.fx.Node] = OrderedSet()
|
||||
cur_bucket_size_bytes: int = 0
|
||||
cur_bucket_id: int = 0
|
||||
# Convert MiB to bytes
|
||||
reduce_scatter_bucket_size_bytes = int(
|
||||
reduce_scatter_bucket_cap_mb_callback(cur_bucket_id) * 1024 * 1024
|
||||
)
|
||||
if (
|
||||
cur_bucket_size_bytes + rs_input_size_bytes
|
||||
> reduce_scatter_bucket_size_bytes
|
||||
and cur_bucket
|
||||
):
|
||||
# Current bucket is full, create new bucket
|
||||
total_size = cur_bucket_size_bytes + rs_input_size_bytes
|
||||
for rs_node in rs_nodes:
|
||||
assert is_reduce_scatter_tensor(rs_node)
|
||||
if rs_node in cur_bucket_recursive_users:
|
||||
# We can not bucket successors with the node
|
||||
continue
|
||||
rs_input = rs_node.args[0]
|
||||
assert "val" in rs_input.meta # type: ignore[union-attr]
|
||||
rs_in_val = rs_input.meta["val"] # type: ignore[union-attr]
|
||||
rs_input_size_bytes = rs_in_val.numel() * rs_in_val.element_size()
|
||||
if (
|
||||
cur_bucket_size_bytes + rs_input_size_bytes
|
||||
> reduce_scatter_bucket_size_bytes
|
||||
and cur_bucket
|
||||
):
|
||||
# Current bucket is full, create new bucket
|
||||
total_size = cur_bucket_size_bytes + rs_input_size_bytes
|
||||
logger.info(
|
||||
f"Reduce scatter bucket {cur_bucket_id} full: " # noqa: G004
|
||||
f"total_size = {total_size} = cur_bucket_size_bytes + rs_input_size_bytes = "
|
||||
f"{cur_bucket_size_bytes} + {rs_input_size_bytes},"
|
||||
f"bucket_cap = {reduce_scatter_bucket_size_bytes}"
|
||||
)
|
||||
if len(cur_bucket) > 1:
|
||||
rs_buckets.append(cur_bucket)
|
||||
cur_bucket = []
|
||||
cur_bucket_size_bytes = 0
|
||||
cur_bucket_id += 1
|
||||
reduce_scatter_bucket_size_bytes = int(
|
||||
reduce_scatter_bucket_cap_mb_callback(cur_bucket_id) * 1024 * 1024
|
||||
)
|
||||
cur_bucket_size_bytes += rs_input_size_bytes
|
||||
cur_bucket.append(rs_node)
|
||||
find_recursive_users_of_fx_node(rs_node, cur_bucket_recursive_users)
|
||||
if cur_bucket:
|
||||
# add remaining nodes in the last bucket
|
||||
logger.info(
|
||||
f"Reduce scatter bucket {cur_bucket_id} full: " # noqa: G004
|
||||
f"total_size = {total_size} = cur_bucket_size_bytes + rs_input_size_bytes = "
|
||||
f"{cur_bucket_size_bytes} + {rs_input_size_bytes},"
|
||||
f"Reduce scatter last bucket {cur_bucket_id}: " # noqa: G004
|
||||
f"total_size = {cur_bucket_size_bytes}, "
|
||||
f"bucket_cap = {reduce_scatter_bucket_size_bytes}"
|
||||
)
|
||||
rs_buckets.append(cur_bucket)
|
||||
cur_bucket = []
|
||||
cur_bucket_size_bytes = 0
|
||||
cur_bucket_id += 1
|
||||
reduce_scatter_bucket_size_bytes = int(
|
||||
reduce_scatter_bucket_cap_mb_callback(cur_bucket_id) * 1024 * 1024
|
||||
)
|
||||
cur_bucket_size_bytes += rs_input_size_bytes
|
||||
cur_bucket.append(rs_node)
|
||||
if cur_bucket:
|
||||
# add remaining nodes in the last bucket
|
||||
logger.info(
|
||||
f"Reduce scatter last bucket {cur_bucket_id}: " # noqa: G004
|
||||
f"total_size = {cur_bucket_size_bytes}, "
|
||||
f"bucket_cap = {reduce_scatter_bucket_size_bytes}"
|
||||
)
|
||||
rs_buckets.append(cur_bucket)
|
||||
|
||||
if len(cur_bucket) > 1:
|
||||
rs_buckets.append(cur_bucket)
|
||||
return rs_buckets
|
||||
|
||||
|
||||
@ -260,6 +265,18 @@ def env_lookup( # type: ignore[no-untyped-def]
|
||||
return env[x]
|
||||
|
||||
|
||||
def _rank_idx_dict(group_name: str) -> dict[int, int]:
|
||||
from torch.distributed.distributed_c10d import (
|
||||
_resolve_process_group,
|
||||
get_process_group_ranks,
|
||||
)
|
||||
|
||||
pg = _resolve_process_group(group_name)
|
||||
ranks = get_process_group_ranks(pg)
|
||||
rank_idx_dict: dict[int, int] = {rank: idx for idx, rank in enumerate(ranks)}
|
||||
return rank_idx_dict
|
||||
|
||||
|
||||
def merge_all_gather(
|
||||
gm: torch.fx.GraphModule, ag_buckets: list[list[torch.fx.Node]]
|
||||
) -> None:
|
||||
@ -297,15 +314,13 @@ def merge_all_gather(
|
||||
bucket_id_is_scheduled = {}
|
||||
cast_bucket_id_is_scheduled = {}
|
||||
_, group_size, group_name = next(iter(ag_node_to_wait_node.keys())).args
|
||||
|
||||
group_name_to_rank_idx_dict: dict[str, dict[int, int]] = {}
|
||||
|
||||
for bucket_id, ag_bucket in enumerate(ag_buckets):
|
||||
ag_input_nodes = []
|
||||
wait_nodes = []
|
||||
for ag_node in ag_bucket:
|
||||
assert (
|
||||
ag_node in ag_node_to_wait_node
|
||||
and ag_node.args[1] == group_size
|
||||
and ag_node.args[2] == group_name
|
||||
)
|
||||
ag_input_nodes.append(ag_node.args[0])
|
||||
wait_nodes.append(ag_node_to_wait_node[ag_node])
|
||||
bucket_id_to_bucketed_op_info[bucket_id] = (
|
||||
@ -314,6 +329,8 @@ def merge_all_gather(
|
||||
group_name,
|
||||
wait_nodes,
|
||||
)
|
||||
if group_name not in group_name_to_rank_idx_dict:
|
||||
group_name_to_rank_idx_dict[group_name] = _rank_idx_dict(group_name) # type: ignore[arg-type, index]
|
||||
|
||||
ag_wait_nodes = list(ag_node_to_wait_node.values())
|
||||
ag_and_wait_nodes = OrderedSet(ag_nodes + ag_wait_nodes)
|
||||
@ -334,9 +351,6 @@ def merge_all_gather(
|
||||
ag_input_nodes, group_size, group_name, orig_wait_nodes = (
|
||||
bucket_id_to_bucketed_op_info[bucket_id]
|
||||
)
|
||||
# device = ag_input_nodes[0].meta["val"].device
|
||||
# rank = device.index
|
||||
# dtype = ag_input_nodes[0].meta["val"].dtype
|
||||
if all(
|
||||
n.op == "call_function" # type: ignore[union-attr]
|
||||
and n.target == torch.ops.prims.convert_element_type.default # type: ignore[union-attr]
|
||||
@ -398,6 +412,7 @@ def merge_all_gather(
|
||||
ag_input_nodes, group_size, group_name, orig_wait_nodes = (
|
||||
bucket_id_to_bucketed_op_info[bucket_id]
|
||||
)
|
||||
rank_idx_dict = group_name_to_rank_idx_dict[group_name] # type: ignore[index]
|
||||
device = ag_input_nodes[0].meta["val"].device # type: ignore[union-attr]
|
||||
rank = device.index
|
||||
dtype = ag_input_nodes[0].meta["val"].dtype # type: ignore[union-attr]
|
||||
@ -468,7 +483,7 @@ def merge_all_gather(
|
||||
all_gather_output,
|
||||
inp_split_sizes,
|
||||
all_gather_input_numel,
|
||||
rank,
|
||||
rank_idx_dict[rank],
|
||||
),
|
||||
{},
|
||||
)
|
||||
@ -585,6 +600,7 @@ def merge_reduce_scatter(
|
||||
# Prepare bucketed operation info
|
||||
bucket_id_to_bucketed_op_info = {}
|
||||
bucket_id_is_scheduled = {}
|
||||
group_name_to_rank_idx_dict: dict[str, dict[int, int]] = {}
|
||||
for bucket_id, rs_bucket in enumerate(rs_buckets):
|
||||
_, reduce_op, group_size, group_name = next(
|
||||
iter(rs_node_to_wait_node.keys())
|
||||
@ -612,6 +628,8 @@ def merge_reduce_scatter(
|
||||
wait_nodes,
|
||||
wait_node_recursive_users,
|
||||
)
|
||||
if group_name not in group_name_to_rank_idx_dict:
|
||||
group_name_to_rank_idx_dict[group_name] = _rank_idx_dict(group_name) # type: ignore[arg-type, index]
|
||||
|
||||
new_graph: torch.fx.Graph = torch.fx.Graph()
|
||||
env: dict[torch.fx.Node, torch.fx.Node] = {}
|
||||
@ -624,155 +642,154 @@ def merge_reduce_scatter(
|
||||
elif node in rs_node_to_wait_node:
|
||||
assert node in rs_node_to_bucket_id
|
||||
bucket_id = rs_node_to_bucket_id[node]
|
||||
if (
|
||||
if not (
|
||||
bucket_id not in bucket_id_is_scheduled
|
||||
and rs_buckets[bucket_id][-1] == node
|
||||
):
|
||||
# If we are at the last node in the bucket, we can start to schedule the bucketed reduce_scatter node
|
||||
(
|
||||
rs_input_nodes,
|
||||
reduce_op,
|
||||
group_size,
|
||||
group_name,
|
||||
orig_wait_nodes,
|
||||
orig_wait_node_recursive_users,
|
||||
) = bucket_id_to_bucketed_op_info[bucket_id]
|
||||
# parents of rs have been scheduled, so we can directly use the env
|
||||
unsharded_grads = [env[x] for x in rs_input_nodes] # type: ignore[index]
|
||||
reduce_dtype = unsharded_grads[0].meta["val"].dtype
|
||||
# Only float32 and bfloat16 are supported for now.
|
||||
# To support fp16, please see FSDP2 `_get_gradient_divide_factors`.
|
||||
assert reduce_dtype in (
|
||||
torch.float32,
|
||||
torch.bfloat16,
|
||||
), f"reduce_dtype {reduce_dtype} is not supported"
|
||||
assert all(
|
||||
grad.meta["val"].dtype == reduce_dtype for grad in unsharded_grads
|
||||
)
|
||||
device = unsharded_grads[0].meta["val"].device
|
||||
rank = device.index
|
||||
shard_dim = 0
|
||||
continue
|
||||
|
||||
def _get_dim0_padded_size(
|
||||
tensor_size: torch.Size, dim0_factor: int
|
||||
) -> torch.Size:
|
||||
padded_dim0 = math.ceil(tensor_size[0] / dim0_factor) * dim0_factor
|
||||
return torch.Size([padded_dim0]) + tensor_size[1:]
|
||||
# If we are at the last node in the bucket, we can start to schedule the bucketed reduce_scatter node
|
||||
(
|
||||
rs_input_nodes,
|
||||
reduce_op,
|
||||
group_size,
|
||||
group_name,
|
||||
orig_wait_nodes,
|
||||
orig_wait_node_recursive_users,
|
||||
) = bucket_id_to_bucketed_op_info[bucket_id]
|
||||
rank_idx_dict = group_name_to_rank_idx_dict[group_name] # type: ignore[index]
|
||||
# parents of rs have been scheduled, so we can directly use the env
|
||||
unsharded_grads = [env[x] for x in rs_input_nodes] # type: ignore[index]
|
||||
reduce_dtype = unsharded_grads[0].meta["val"].dtype
|
||||
# Only float32 and bfloat16 are supported for now.
|
||||
# To support fp16, please see FSDP2 `_get_gradient_divide_factors`.
|
||||
assert reduce_dtype in (
|
||||
torch.float32, # type: ignore[attr-defined]
|
||||
torch.bfloat16, # type: ignore[attr-defined]
|
||||
), f"reduce_dtype {reduce_dtype} is not supported"
|
||||
assert all(
|
||||
grad.meta["val"].dtype == reduce_dtype for grad in unsharded_grads
|
||||
)
|
||||
device = unsharded_grads[0].meta["val"].device
|
||||
rank = device.index
|
||||
rank_idx = rank_idx_dict[rank]
|
||||
shard_dim = 0
|
||||
|
||||
padded_unsharded_sizes = tuple(
|
||||
_get_dim0_padded_size(grad.meta["val"].size(), group_size) # type: ignore[arg-type]
|
||||
for grad in unsharded_grads
|
||||
)
|
||||
reduce_scatter_input_numel = sum(
|
||||
s.numel() for s in padded_unsharded_sizes
|
||||
)
|
||||
def _get_dim0_padded_size(
|
||||
tensor_size: torch.Size,
|
||||
dim0_factor: int, # type: ignore[name-defined]
|
||||
) -> torch.Size: # type: ignore[name-defined]
|
||||
padded_dim0 = math.ceil(tensor_size[0] / dim0_factor) * dim0_factor # type: ignore[attr-defined]
|
||||
return torch.Size([padded_dim0]) + tensor_size[1:]
|
||||
|
||||
"""
|
||||
NOTE: the relationship between the next few nodes is tricky:
|
||||
- reduce_scatter_input_reshaped is a view of reduce_scatter_input
|
||||
(same storage, same # elems, different shape).
|
||||
- chunk_cat writes into reduce_scatter_input_reshaped,
|
||||
which indirectly writes into reduce_scatter_input
|
||||
(since they share the same storage).
|
||||
- reduce_scatter_tensor reads from reduce_scatter_input.
|
||||
"""
|
||||
reduce_scatter_input = new_graph_call_function(
|
||||
padded_unsharded_sizes = tuple(
|
||||
_get_dim0_padded_size(grad.meta["val"].size(), group_size) # type: ignore[arg-type]
|
||||
for grad in unsharded_grads
|
||||
)
|
||||
reduce_scatter_input_numel = sum(s.numel() for s in padded_unsharded_sizes)
|
||||
|
||||
"""
|
||||
NOTE: the relationship between the next few nodes is tricky:
|
||||
- reduce_scatter_input_reshaped is a view of reduce_scatter_input
|
||||
(same storage, same # elems, different shape).
|
||||
- chunk_cat writes into reduce_scatter_input_reshaped,
|
||||
which indirectly writes into reduce_scatter_input
|
||||
(since they share the same storage).
|
||||
- reduce_scatter_tensor reads from reduce_scatter_input.
|
||||
"""
|
||||
reduce_scatter_input = new_graph_call_function(
|
||||
new_graph,
|
||||
torch.ops.aten.empty.memory_format,
|
||||
([reduce_scatter_input_numel],),
|
||||
{
|
||||
"dtype": reduce_dtype,
|
||||
"device": device,
|
||||
"pin_memory": False,
|
||||
},
|
||||
)
|
||||
reduce_scatter_input_reshaped = new_graph_call_function(
|
||||
new_graph,
|
||||
torch.ops.aten.reshape.default,
|
||||
(reduce_scatter_input, [group_size, -1]),
|
||||
{},
|
||||
)
|
||||
new_graph_call_function(
|
||||
new_graph,
|
||||
torch.ops.fsdp.chunk_cat.default,
|
||||
(unsharded_grads,),
|
||||
{
|
||||
"dim": 0,
|
||||
"num_chunks": group_size,
|
||||
"out": reduce_scatter_input_reshaped,
|
||||
},
|
||||
)
|
||||
reduce_scatter_tensor = new_graph_call_function(
|
||||
new_graph,
|
||||
torch.ops._c10d_functional.reduce_scatter_tensor.default,
|
||||
(reduce_scatter_input, reduce_op, group_size, group_name),
|
||||
{},
|
||||
)
|
||||
|
||||
wait_tensor = new_graph_call_function(
|
||||
new_graph,
|
||||
torch.ops._c10d_functional.wait_tensor.default,
|
||||
(reduce_scatter_tensor,),
|
||||
{},
|
||||
)
|
||||
|
||||
def _chunk_with_empty(
|
||||
tensor: torch.Tensor, num_chunks: int, dim: int
|
||||
) -> list[torch.Tensor]:
|
||||
chunks = list(torch.chunk(tensor, num_chunks, dim=dim))
|
||||
while len(chunks) < num_chunks:
|
||||
chunks.append(chunks[0].new_empty(0))
|
||||
return chunks
|
||||
|
||||
reduce_output = wait_tensor
|
||||
# View out and accumulate sharded gradients
|
||||
new_sharded_grads = []
|
||||
flat_grad_offset = 0 # [0, reduce_scatter_output_numel - 1]
|
||||
for padded_unsharded_size, unsharded_grad in zip(
|
||||
padded_unsharded_sizes, unsharded_grads
|
||||
):
|
||||
# NOTE: we only care about the shape of tensors in `chunks`, so using meta tensor here
|
||||
chunks = _chunk_with_empty(
|
||||
torch.empty_like(unsharded_grad.meta["val"], device="meta"),
|
||||
group_size, # type: ignore[arg-type]
|
||||
dim=shard_dim,
|
||||
)
|
||||
sharded_param = chunks[rank_idx]
|
||||
sharded_size = sharded_param.size()
|
||||
contiguous_sharded_stride = (
|
||||
torch._prims_common.make_contiguous_strides_for(sharded_size)
|
||||
)
|
||||
# Assume even sharding for Shard(i), i > 0; otherwise would require
|
||||
# copy-out for contiguous strides
|
||||
new_sharded_grad = new_graph_call_function(
|
||||
new_graph,
|
||||
torch.ops.aten.empty.memory_format,
|
||||
([reduce_scatter_input_numel],),
|
||||
torch.ops.aten.as_strided.default,
|
||||
(reduce_output,),
|
||||
{
|
||||
"dtype": reduce_dtype,
|
||||
"device": device,
|
||||
"pin_memory": False,
|
||||
"size": sharded_size,
|
||||
"stride": contiguous_sharded_stride,
|
||||
"storage_offset": flat_grad_offset,
|
||||
},
|
||||
)
|
||||
reduce_scatter_input_reshaped = new_graph_call_function(
|
||||
new_graph,
|
||||
torch.ops.aten.reshape.default,
|
||||
(reduce_scatter_input, [group_size, -1]),
|
||||
{},
|
||||
)
|
||||
new_graph_call_function(
|
||||
new_graph,
|
||||
torch.ops.fsdp.chunk_cat.default,
|
||||
(unsharded_grads,),
|
||||
{
|
||||
"dim": 0,
|
||||
"num_chunks": group_size,
|
||||
"out": reduce_scatter_input_reshaped,
|
||||
},
|
||||
)
|
||||
reduce_scatter_tensor = new_graph_call_function(
|
||||
new_graph,
|
||||
torch.ops._c10d_functional.reduce_scatter_tensor.default,
|
||||
(reduce_scatter_input, reduce_op, group_size, group_name),
|
||||
{},
|
||||
)
|
||||
|
||||
wait_tensor = new_graph_call_function(
|
||||
new_graph,
|
||||
torch.ops._c10d_functional.wait_tensor.default,
|
||||
(reduce_scatter_tensor,),
|
||||
{},
|
||||
)
|
||||
|
||||
def _chunk_with_empty(
|
||||
tensor: torch.Tensor, num_chunks: int, dim: int
|
||||
) -> list[torch.Tensor]:
|
||||
chunks = list(torch.chunk(tensor, num_chunks, dim=dim))
|
||||
while len(chunks) < num_chunks:
|
||||
chunks.append(chunks[0].new_empty(0))
|
||||
return chunks
|
||||
|
||||
reduce_output = wait_tensor
|
||||
# View out and accumulate sharded gradients
|
||||
new_sharded_grads = []
|
||||
flat_grad_offset = 0 # [0, reduce_scatter_output_numel - 1]
|
||||
for padded_unsharded_size, unsharded_grad in zip(
|
||||
padded_unsharded_sizes, unsharded_grads
|
||||
):
|
||||
# NOTE: we only care about the shape of tensors in `chunks`, so using meta tensor here
|
||||
chunks = _chunk_with_empty(
|
||||
torch.empty_like(unsharded_grad.meta["val"], device="meta"),
|
||||
group_size, # type: ignore[arg-type]
|
||||
dim=shard_dim,
|
||||
)
|
||||
sharded_param = chunks[rank]
|
||||
sharded_size = sharded_param.size()
|
||||
contiguous_sharded_stride = (
|
||||
torch._prims_common.make_contiguous_strides_for(sharded_size)
|
||||
)
|
||||
# Assume even sharding for Shard(i), i > 0; otherwise would require
|
||||
# copy-out for contiguous strides
|
||||
new_sharded_grad = new_graph_call_function(
|
||||
new_graph,
|
||||
torch.ops.aten.as_strided.default,
|
||||
(reduce_output,),
|
||||
{
|
||||
"size": sharded_size,
|
||||
"stride": contiguous_sharded_stride,
|
||||
"storage_offset": flat_grad_offset,
|
||||
},
|
||||
)
|
||||
new_sharded_grads.append(new_sharded_grad)
|
||||
padded_sharded_numel = padded_unsharded_size.numel() // group_size # type: ignore[operator]
|
||||
flat_grad_offset += padded_sharded_numel # type: ignore[assignment]
|
||||
assert len(orig_wait_nodes) == len(new_sharded_grads)
|
||||
assert len(orig_wait_nodes) > 0
|
||||
for new_sharded_grad, orig_wait_node in zip(
|
||||
new_sharded_grads, orig_wait_nodes
|
||||
):
|
||||
env[orig_wait_node] = new_sharded_grad # noqa: PERF403
|
||||
for user in sorted(
|
||||
orig_wait_node_recursive_users, key=lambda x: order[x]
|
||||
):
|
||||
# We skip output node here, because output node will be inserted (later)
|
||||
# as the last node in the new graph.
|
||||
if user.op != "output":
|
||||
node_copy(
|
||||
env, new_graph, user, lambda x: env_lookup(env, x, user)
|
||||
)
|
||||
bucket_id_is_scheduled[bucket_id] = True
|
||||
new_sharded_grads.append(new_sharded_grad)
|
||||
padded_sharded_numel = padded_unsharded_size.numel() // group_size # type: ignore[operator]
|
||||
flat_grad_offset += padded_sharded_numel # type: ignore[assignment]
|
||||
assert len(orig_wait_nodes) == len(new_sharded_grads)
|
||||
assert len(orig_wait_nodes) > 0
|
||||
for new_sharded_grad, orig_wait_node in zip(
|
||||
new_sharded_grads, orig_wait_nodes
|
||||
):
|
||||
env[orig_wait_node] = new_sharded_grad # noqa: PERF403
|
||||
for user in sorted(orig_wait_node_recursive_users, key=lambda x: order[x]):
|
||||
# We skip output node here, because output node will be inserted (later)
|
||||
# as the last node in the new graph.
|
||||
if user.op != "output":
|
||||
node_copy(env, new_graph, user, lambda x: env_lookup(env, x, user))
|
||||
bucket_id_is_scheduled[bucket_id] = True
|
||||
else:
|
||||
continue
|
||||
assert node_list[-1].op == "output"
|
||||
|
Reference in New Issue
Block a user