Files
pytorch/torch/_inductor/fx_passes/bucketing.py
IvanKobzarev 371ffaf415 [bucketing] Support case of several pgs in graph (#158632)
Main changes:
- bucketing collectives only from the same process_group by group_name
- Support groups whose ranks are not contiguous, such as [0,2,4,6] or [0,1,3,5], by using `rank_idx_dict` to map a rank to its index within the group for in-pass operations (slice indices, etc.).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/158632
Approved by: https://github.com/wconstab
2025-07-22 14:50:39 +00:00

800 lines
33 KiB
Python

import logging
import math
import operator
from collections import defaultdict
from typing import Any, Callable, Optional, Union
import torch
from torch._dispatch.python import enable_python_dispatcher
from torch._inductor.virtualized import V
from torch.utils._ordered_set import OrderedSet
# Module-level logger used by the bucketing passes in this file.
# NOTE: the level is set to INFO at import time, as a module side effect.
logger: logging.Logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
def bucket_size_determinator(bucket_id: int) -> float:
    """
    Return the size to use for the bucket identified by ``bucket_id``.

    Args:
        bucket_id (int): The ID of the bucket. It is currently ignored.

    Returns:
        float: The size of the bucket; the same constant for every bucket.
    """
    size_for_every_bucket = 2000.0
    return size_for_every_bucket
def bucket_all_gather(
    gm: torch.fx.GraphModule, all_gather_bucket_cap_mb_callback: Callable[[int], float]
) -> None:
    """
    Bucket the all_gather collectives of ``gm`` and merge each bucket in place.

    Args:
        gm: Graph module whose graph is searched and, if buckets are found, rewritten.
        all_gather_bucket_cap_mb_callback: Maps a bucket id to that bucket's size cap;
            forwarded unchanged to `bucket_all_gather_by_mb`.
    """
    buckets = bucket_all_gather_by_mb(gm, all_gather_bucket_cap_mb_callback)
    # Only rewrite the graph when at least one bucket was formed.
    if buckets:
        merge_all_gather(gm, buckets)
def bucket_reduce_scatter(
    gm: torch.fx.GraphModule,
    reduce_scatter_bucket_cap_mb_callback: Callable[[int], float],
) -> None:
    """
    Bucket the reduce_scatter collectives of ``gm`` and merge each bucket in place.

    Args:
        gm: Graph module whose graph is searched and, if buckets are found, rewritten.
        reduce_scatter_bucket_cap_mb_callback: Maps a bucket id to that bucket's size
            cap; forwarded unchanged to `bucket_reduce_scatter_by_mb`.
    """
    buckets = bucket_reduce_scatter_by_mb(gm, reduce_scatter_bucket_cap_mb_callback)
    # Only rewrite the graph when at least one bucket was formed.
    if buckets:
        merge_reduce_scatter(gm, buckets)
def is_all_gather_into_tensor(node: torch.fx.Node) -> bool: # type: ignore[arg-type]
return (
node.op == "call_function"
and node.target == torch.ops._c10d_functional.all_gather_into_tensor.default
)
def is_reduce_scatter_tensor(node: torch.fx.Node) -> bool:
return (
node.op == "call_function"
and node.target == torch.ops._c10d_functional.reduce_scatter_tensor.default
)
def is_wait_tensor(node: torch.fx.Node) -> bool:
return (
node.op == "call_function"
and node.target == torch.ops._c10d_functional.wait_tensor.default
)
def is_wait_tensor_from_all_gather_into_tensor(node: torch.fx.Node) -> bool:
    """Return True when ``node`` is a wait_tensor whose first argument is an all_gather."""
    if not is_wait_tensor(node):
        return False
    # Only inspect the waited-on argument once we know this is a wait_tensor call.
    return is_all_gather_into_tensor(node.args[0])  # type: ignore[arg-type]
def bucket_all_gather_by_mb(
gm: torch.fx.GraphModule,
all_gather_bucket_cap_mb_callback: Callable[[int], float],
filter_wait_node: Optional[Callable[[torch.fx.Node], bool]] = None,
) -> list[list[torch.fx.Node]]:
"""
Identifies all all_gather nodes and groups them into buckets based on size limit `all_gather_bucket_cap_mb_callback`.
Returns a list of buckets, where each bucket is a list of all_gather nodes.
"""
node_list = gm.graph.nodes
# Prerequisite: Check if there is any all_gather node
found_all_gather = False
for node in node_list:
if is_all_gather_into_tensor(node):
found_all_gather = True
break
if not found_all_gather:
return []
group_name_ag_nodes: dict[tuple[str, torch.dtype], list[torch.fx.Node]] = ( # type: ignore[name-defined]
defaultdict(list)
)
# Step 1: Find all all_gather nodes
for node in node_list:
if is_wait_tensor(node) and is_all_gather_into_tensor(node.args[0]):
if (filter_wait_node is None) or filter_wait_node(node):
ag_node = node.args[0]
_, group_size, group_name = ag_node.args
dtype = ag_node.meta["val"].dtype
assert isinstance(group_name, str)
group_name_ag_nodes[(group_name, dtype)].append(ag_node)
# Step 2: Put all_gather nodes into buckets
ag_buckets: list[list[torch.fx.Node]] = []
for (group_name, dtype), ag_nodes in group_name_ag_nodes.items():
cur_bucket: list[torch.fx.Node] = []
cur_bucket_recursive_users: OrderedSet[torch.fx.Node] = OrderedSet()
cur_bucket_size_bytes: int = 0
cur_bucket_id: int = 0
all_gather_bucket_size_bytes = int(
all_gather_bucket_cap_mb_callback(cur_bucket_id) * 1024 * 1024
)
for ag_node in ag_nodes:
assert is_all_gather_into_tensor(ag_node)
if ag_node in cur_bucket_recursive_users:
# We can not bucket successors with the node
continue
assert "val" in ag_node.meta
ag_n_val = ag_node.meta["val"]
ag_output_size_bytes = ag_n_val.numel() * ag_n_val.element_size()
if (
cur_bucket_size_bytes + ag_output_size_bytes
> all_gather_bucket_size_bytes
and cur_bucket
):
# Current bucket is full, create new bucket
if len(cur_bucket) > 1:
ag_buckets.append(cur_bucket)
cur_bucket = []
cur_bucket_size_bytes = 0
cur_bucket_id += 1
cur_bucket_size_bytes += ag_output_size_bytes
cur_bucket.append(ag_node)
find_recursive_users_of_fx_node(ag_node, cur_bucket_recursive_users)
if len(cur_bucket) > 1:
# add remaining nodes in the last bucket
ag_buckets.append(cur_bucket)
return ag_buckets
def bucket_reduce_scatter_by_mb(
gm: torch.fx.GraphModule,
reduce_scatter_bucket_cap_mb_callback: Callable[[int], float],
) -> list[list[torch.fx.Node]]:
"""
Identifies all reduce_scatter nodes and groups them into buckets based on size limit `reduce_scatter_bucket_cap_mb_callback`.
Returns a list of buckets, where each bucket is a list of reduce_scatter nodes.
"""
node_list = list(gm.graph.nodes)
# Prerequisite: Check if there is any reduce_scatter node
found_reduce_scatter = False
for node in node_list:
if is_reduce_scatter_tensor(node):
found_reduce_scatter = True
break
if not found_reduce_scatter:
return []
group_name_rs_nodes: dict[tuple[str, str, torch.dtype], list[torch.fx.Node]] = ( # type: ignore[name-defined]
defaultdict(list)
)
# Step 1: Find all reduce_scatter nodes
for node in node_list:
if is_wait_tensor(node) and is_reduce_scatter_tensor(node.args[0]):
rs_node = node.args[0]
_, reduce_op, group_size, group_name = rs_node.args
dtype = rs_node.meta["val"].dtype
assert isinstance(group_name, str)
assert isinstance(reduce_op, str)
group_name_rs_nodes[(group_name, reduce_op, dtype)].append(rs_node)
# Step 2: Put reduce_scatter nodes into buckets
rs_buckets: list[list[torch.fx.Node]] = []
for (group_name, reduce_op, dtype), rs_nodes in group_name_rs_nodes.items():
cur_bucket: list[torch.fx.Node] = []
cur_bucket_recursive_users: OrderedSet[torch.fx.Node] = OrderedSet()
cur_bucket_size_bytes: int = 0
cur_bucket_id: int = 0
# Convert MiB to bytes
reduce_scatter_bucket_size_bytes = int(
reduce_scatter_bucket_cap_mb_callback(cur_bucket_id) * 1024 * 1024
)
for rs_node in rs_nodes:
assert is_reduce_scatter_tensor(rs_node)
if rs_node in cur_bucket_recursive_users:
# We can not bucket successors with the node
continue
rs_input = rs_node.args[0]
assert "val" in rs_input.meta # type: ignore[union-attr]
rs_in_val = rs_input.meta["val"] # type: ignore[union-attr]
rs_input_size_bytes = rs_in_val.numel() * rs_in_val.element_size()
if (
cur_bucket_size_bytes + rs_input_size_bytes
> reduce_scatter_bucket_size_bytes
and cur_bucket
):
# Current bucket is full, create new bucket
total_size = cur_bucket_size_bytes + rs_input_size_bytes
logger.info(
f"Reduce scatter bucket {cur_bucket_id} full: " # noqa: G004
f"total_size = {total_size} = cur_bucket_size_bytes + rs_input_size_bytes = "
f"{cur_bucket_size_bytes} + {rs_input_size_bytes},"
f"bucket_cap = {reduce_scatter_bucket_size_bytes}"
)
if len(cur_bucket) > 1:
rs_buckets.append(cur_bucket)
cur_bucket = []
cur_bucket_size_bytes = 0
cur_bucket_id += 1
reduce_scatter_bucket_size_bytes = int(
reduce_scatter_bucket_cap_mb_callback(cur_bucket_id) * 1024 * 1024
)
cur_bucket_size_bytes += rs_input_size_bytes
cur_bucket.append(rs_node)
find_recursive_users_of_fx_node(rs_node, cur_bucket_recursive_users)
if cur_bucket:
# add remaining nodes in the last bucket
logger.info(
f"Reduce scatter last bucket {cur_bucket_id}: " # noqa: G004
f"total_size = {cur_bucket_size_bytes}, "
f"bucket_cap = {reduce_scatter_bucket_size_bytes}"
)
if len(cur_bucket) > 1:
rs_buckets.append(cur_bucket)
return rs_buckets
def node_copy(  # type: ignore[no-untyped-def]
    env,
    new_graph,
    node: torch.fx.Node,
    arg_transform: Callable[[torch.fx.Node], torch.fx.node.Argument],
) -> torch.fx.Node:
    """
    Copy ``node`` into ``new_graph`` at most once, memoizing the copy in ``env``.

    Args:
        env: Mapping from original nodes to their copies; updated on first copy.
        new_graph: Graph receiving the copy (via its `node_copy` method).
        node: The original node to copy.
        arg_transform: Maps each original argument node to its replacement.

    Returns:
        The copy of ``node``: the existing one when ``node`` is already in ``env``.
    """
    if node in env:
        # Already scheduled earlier; reuse the existing copy.
        return env[node]
    copied = new_graph.node_copy(node, arg_transform=arg_transform)
    env[node] = copied
    return copied
def new_graph_call_function(  # type: ignore[no-untyped-def]
    new_graph,
    target: Callable[..., Any],
    args: Optional[tuple[torch.fx.node.Argument, ...]] = None,
    kwargs: Optional[dict[str, torch.fx.node.Argument]] = None,
    type_expr: Optional[Any] = None,
) -> torch.fx.Node:
    """
    Add a call_function node to ``new_graph`` and populate its ``meta["val"]``.

    The value is produced by running ``target`` under ``V.fake_mode`` on the
    ``meta["val"]`` of every `torch.fx.Node` found in ``args`` / ``kwargs``.

    Args:
        new_graph: Graph in which the node is created (via its `call_function`).
        target: The callable the new node invokes.
        args: Positional arguments of the call; ``None`` means no positional arguments.
        kwargs: Keyword arguments of the call; ``None`` means no keyword arguments.
        type_expr: Optional type annotation forwarded to `call_function`.

    Returns:
        The newly created node, with ``meta["val"]`` set.
    """
    from torch.utils._pytree import tree_map_only

    # The defaults are None, which can not be star-unpacked below: normalize them.
    call_args = () if args is None else args
    call_kwargs = {} if kwargs is None else kwargs

    new_node = new_graph.call_function(
        target, call_args, call_kwargs, type_expr=type_expr
    )
    # Swap every fx node for the value recorded in its meta["val"], so that the
    # target can be evaluated without running the graph.
    args_val = tree_map_only(torch.fx.Node, lambda x: x.meta["val"], call_args)
    kwargs_val = tree_map_only(torch.fx.Node, lambda x: x.meta["val"], call_kwargs)
    with V.fake_mode, enable_python_dispatcher():
        new_fake_tensor = target(*args_val, **kwargs_val)
    new_node.meta["val"] = new_fake_tensor
    return new_node
def env_lookup(  # type: ignore[no-untyped-def]
    env, x: torch.fx.Node, node_user: Union[torch.fx.Node, str]
) -> torch.fx.Node:
    """
    Return the entry of ``env`` for ``x``, asserting that it exists.

    Args:
        env: Mapping from original nodes to their replacements.
        x: The node to look up.
        node_user: The node (or a description) being created that needs ``x``;
            only used in the assertion message.
    """
    is_known = x in env
    assert is_known, (
        f"Dependent node {x} not in env when creating downstream node {node_user}"
    )
    return env[x]
def _rank_idx_dict(group_name: str) -> dict[int, int]:
    """
    Map every rank of the process group ``group_name`` to its position in the group.

    The position is the index of the rank in the list returned by
    `get_process_group_ranks` for the resolved group.
    """
    from torch.distributed.distributed_c10d import (
        _resolve_process_group,
        get_process_group_ranks,
    )

    group_ranks = get_process_group_ranks(_resolve_process_group(group_name))
    return {rank: position for position, rank in enumerate(group_ranks)}
def merge_all_gather(
gm: torch.fx.GraphModule, ag_buckets: list[list[torch.fx.Node]]
) -> None:
"""
Transforms the graph to use bucketed all_gather operations based on `ag_buckets`.
"""
assert len(ag_buckets) > 0
ag_nodes: list[torch.fx.Node] = []
cast_nodes: list[torch.fx.Node] = []
ag_node_to_wait_node: dict[torch.fx.Node, torch.fx.Node] = {}
ag_node_to_bucket_id = {}
cast_node_to_bucket_id = {}
# Map nodes to buckets and identify wait nodes
for bucket_id, bucket in enumerate(ag_buckets):
for ag_node in bucket:
assert len(ag_node.users) == 1, (
f"Expect only one user for {ag_node}, but got {ag_node.users}"
)
wait_node = next(iter(ag_node.users))
ag_node_to_wait_node[ag_node] = wait_node
ag_nodes.append(ag_node)
ag_node_to_bucket_id[ag_node] = bucket_id
if (
ag_node.args[0].op == "call_function" # type: ignore[union-attr]
and ag_node.args[0].target # type: ignore[union-attr]
== torch.ops.prims.convert_element_type.default
):
cast_nodes.append(ag_node.args[0]) # type: ignore[arg-type]
cast_node_to_bucket_id[ag_node.args[0]] = bucket_id # type: ignore[arg-type]
# Step 3: Create new (bucketed) all_gather nodes
bucket_id_to_bucketed_op_info = {}
bucket_id_is_scheduled = {}
cast_bucket_id_is_scheduled = {}
_, group_size, group_name = next(iter(ag_node_to_wait_node.keys())).args
group_name_to_rank_idx_dict: dict[str, dict[int, int]] = {}
for bucket_id, ag_bucket in enumerate(ag_buckets):
ag_input_nodes = []
wait_nodes = []
for ag_node in ag_bucket:
ag_input_nodes.append(ag_node.args[0])
wait_nodes.append(ag_node_to_wait_node[ag_node])
bucket_id_to_bucketed_op_info[bucket_id] = (
ag_input_nodes,
group_size,
group_name,
wait_nodes,
)
if group_name not in group_name_to_rank_idx_dict:
group_name_to_rank_idx_dict[group_name] = _rank_idx_dict(group_name) # type: ignore[arg-type, index]
ag_wait_nodes = list(ag_node_to_wait_node.values())
ag_and_wait_nodes = OrderedSet(ag_nodes + ag_wait_nodes)
cast_nodes = OrderedSet(cast_nodes)
new_graph: torch.fx.Graph = torch.fx.Graph()
env: dict[torch.fx.Node, torch.fx.Node] = {}
node_list = gm.graph.nodes
for node in node_list:
if node not in ag_and_wait_nodes and node not in cast_nodes:
# not cast-before-all_gather, all_gather or its wait_tensor - schedule it normally
node_copy(env, new_graph, node, lambda x: env_lookup(env, x, node))
elif node in cast_nodes:
# batch cast nodes together into one foreach_copy node
assert node in cast_node_to_bucket_id
bucket_id = cast_node_to_bucket_id[node]
if bucket_id not in cast_bucket_id_is_scheduled:
ag_input_nodes, group_size, group_name, orig_wait_nodes = (
bucket_id_to_bucketed_op_info[bucket_id]
)
if all(
n.op == "call_function" # type: ignore[union-attr]
and n.target == torch.ops.prims.convert_element_type.default # type: ignore[union-attr]
for n in ag_input_nodes
):
param_all_gather_inputs = [
new_graph_call_function(
new_graph,
torch.ops.aten.empty.memory_format,
(n.meta["val"].shape,), # type: ignore[union-attr]
{
"dtype": n.args[1], # type: ignore[union-attr]
"device": n.meta["val"].device, # type: ignore[union-attr]
"pin_memory": False,
},
)
for n in ag_input_nodes
]
for pp, n in zip(param_all_gather_inputs, ag_input_nodes):
pp.meta = n.meta.copy() # type: ignore[union-attr]
cast_input_nodes = [env[n.args[0]] for n in ag_input_nodes] # type: ignore[union-attr, index]
foreach_copy = new_graph_call_function(
new_graph,
torch.ops.aten._foreach_copy.default,
(param_all_gather_inputs, cast_input_nodes),
{},
)
foreach_copy.meta["val"] = [n.meta["val"] for n in ag_input_nodes] # type: ignore[union-attr]
getitems = [
new_graph_call_function(
new_graph,
operator.getitem,
(foreach_copy, i),
{},
)
for i in range(len(ag_input_nodes))
]
for new_n, old_n in zip(getitems, ag_input_nodes):
env[old_n] = new_n # type: ignore[index] # noqa: PERF403
else:
param_all_gather_inputs_orig = [
node_copy(
env,
new_graph,
ag_input_node, # type: ignore[arg-type]
lambda x: env_lookup(env, x, ag_input_node), # type: ignore[arg-type]
)
for ag_input_node in ag_input_nodes
]
cast_bucket_id_is_scheduled[bucket_id] = True
else:
continue
elif node in ag_node_to_wait_node:
assert node in ag_node_to_bucket_id
bucket_id = ag_node_to_bucket_id[node]
if bucket_id not in bucket_id_is_scheduled:
ag_input_nodes, group_size, group_name, orig_wait_nodes = (
bucket_id_to_bucketed_op_info[bucket_id]
)
rank_idx_dict = group_name_to_rank_idx_dict[group_name] # type: ignore[index]
device = ag_input_nodes[0].meta["val"].device # type: ignore[union-attr]
rank = device.index
dtype = ag_input_nodes[0].meta["val"].dtype # type: ignore[union-attr]
# TODO: if we want to support mixed dtype in the same bucket,
# we need to first view all all_gather inputs as uint8 (common denominator),
# then do the all_gather, then view the output back to the original dtype.
# Look at FSDP2 to see how to do this.
assert all(n.meta["val"].dtype == dtype for n in ag_input_nodes), ( # type: ignore[union-attr]
"All all_gather inputs in the same bucket must have the same dtype"
)
# must schedule all the all_gather input nodes first, before the bucketed all_gather node
param_all_gather_inputs_orig = [
node_copy(
env,
new_graph,
ag_input_node, # type: ignore[arg-type]
lambda x: env_lookup(env, x, ag_input_node), # type: ignore[arg-type]
)
for ag_input_node in ag_input_nodes
]
# schedule the bucketed all_gather node
param_all_gather_inputs_flattened = [
new_graph_call_function(
new_graph, torch.ops.aten.reshape.default, (n, [-1]), {}
)
for n in param_all_gather_inputs_orig
]
inp_split_sizes = [
n.meta["val"].numel() for n in param_all_gather_inputs_orig
]
param_all_gather_outputs = [
new_graph_call_function(
new_graph,
torch.ops.aten.empty.memory_format,
([n.meta["val"].numel() * group_size],),
{
"dtype": n.meta["val"].dtype,
"device": n.meta["val"].device,
"pin_memory": False,
},
)
for n in param_all_gather_inputs_orig
]
# TODO: This assumes dim-0 sharding.
# If we need to support sharding on another dim, we should look at how FSDP2 does it
# (e.g. search for `shard_dim` in FSDP2 codebase)
param_all_gather_outputs_shape_orig = [
(n.meta["val"].shape[0] * group_size,) + n.meta["val"].shape[1:]
for n in param_all_gather_inputs_orig
]
all_gather_input_numel = sum(inp_split_sizes)
all_gather_output = new_graph_call_function(
new_graph,
torch.ops.aten.empty.memory_format,
([all_gather_input_numel * group_size],),
{
"dtype": dtype,
"device": device,
"pin_memory": False,
},
)
all_gather_copy_in = new_graph_call_function(
new_graph,
torch.ops.fsdp.all_gather_copy_in.default,
(
param_all_gather_inputs_flattened,
all_gather_output,
inp_split_sizes,
all_gather_input_numel,
rank_idx_dict[rank],
),
{},
)
all_gather_input = new_graph_call_function(
new_graph,
operator.getitem,
(all_gather_copy_in, 0),
{},
)
all_gather_into_tensor_out = new_graph_call_function(
new_graph,
torch.ops._c10d_functional.all_gather_into_tensor_out.default,
(all_gather_input, group_size, group_name),
{"out": all_gather_output},
)
wait_tensor = new_graph_call_function(
new_graph,
torch.ops._c10d_functional.wait_tensor.default,
(all_gather_into_tensor_out,),
{},
)
all_gather_output_reshaped = new_graph_call_function(
new_graph,
torch.ops.aten.reshape.default,
(wait_tensor, [group_size, -1]),
{},
)
outs_flattened = [
new_graph_call_function(
new_graph,
torch.ops.aten.reshape.default,
(n, [group_size, -1]),
{},
)
for n in param_all_gather_outputs
]
split_with_sizes_copy = new_graph_call_function( # noqa: F841
new_graph,
torch.ops.fsdp.split_with_sizes_copy.default,
(all_gather_output_reshaped, inp_split_sizes),
{"dim": 1, "out": outs_flattened},
)
outs = [
new_graph_call_function(
new_graph,
torch.ops.aten.reshape.default,
(n, orig_shape),
{},
)
for n, orig_shape in zip(
outs_flattened, param_all_gather_outputs_shape_orig
)
]
assert len(orig_wait_nodes) == len(outs)
assert len(orig_wait_nodes) > 0
for out, orig_wait_node in zip(outs, orig_wait_nodes):
env[orig_wait_node] = out # noqa: PERF403
bucket_id_is_scheduled[bucket_id] = True
else:
continue
gm.graph = new_graph
def find_recursive_users_of_fx_node(node, collected_node_set, criteria_cb=None) -> None:  # type: ignore[no-untyped-def]
    """
    Add all transitive users of ``node`` to ``collected_node_set``.

    The traversal is depth-first and iterative, so that long chains of users do not
    exceed the Python recursion limit. Nodes are visited in the same order as a
    recursive depth-first walk over ``node.users`` would visit them.

    Args:
        node: The node whose users are collected; ``node`` itself is not added.
        collected_node_set: Set-like object (supports ``in`` and ``add``) that is
            updated in place. Nodes already in it are neither re-added nor expanded.
        criteria_cb: Optional predicate. A node for which it returns a truthy value is
            not expanded: its users are not visited through it (the node itself is
            still added if reached as a user).
    """
    if criteria_cb and criteria_cb(node):
        return
    # Stack of partially consumed iterators over `users`: one entry per node on the
    # current depth-first path.
    pending = [iter(node.users)]
    while pending:
        for user_node in pending[-1]:
            if user_node in collected_node_set:
                continue
            collected_node_set.add(user_node)
            if criteria_cb and criteria_cb(user_node):
                continue
            # Descend into this user first; its parent's iterator is resumed later.
            pending.append(iter(user_node.users))
            break
        else:
            # Users of the node on top of the stack are exhausted
            pending.pop()
def merge_reduce_scatter(
gm: torch.fx.GraphModule, rs_buckets: list[list[torch.fx.Node]]
) -> None:
"""
Transforms the graph to use bucketed reduce_scatter operations based on `rs_buckets`.
"""
assert len(rs_buckets) > 0
rs_nodes: list[torch.fx.Node] = []
rs_node_to_wait_node: dict[torch.fx.Node, torch.fx.Node] = {}
rs_node_to_bucket_id = {}
# Map nodes to buckets and identify wait nodes
for bucket_id, bucket in enumerate(rs_buckets):
for rs_node in bucket:
assert is_reduce_scatter_tensor(rs_node), (
f"Expected reduce_scatter node, got {rs_node}"
)
# Find the wait_tensor node that uses this reduce_scatter node
wait_nodes = list(rs_node.users)
assert len(wait_nodes) == 1, (
f"Expected exactly one user for {rs_node}, got {wait_nodes}"
)
wait_node = wait_nodes[0]
assert is_wait_tensor(wait_node), (
f"Expected wait_tensor node, got {wait_node}"
)
rs_node_to_wait_node[rs_node] = wait_node
rs_nodes.append(rs_node)
rs_node_to_bucket_id[rs_node] = bucket_id
order = {x: i for i, x in enumerate(gm.graph.nodes)}
rs_wait_nodes = list(rs_node_to_wait_node.values())
rs_and_its_recursive_users = OrderedSet(rs_nodes + rs_wait_nodes)
# Prepare bucketed operation info
bucket_id_to_bucketed_op_info = {}
bucket_id_is_scheduled = {}
group_name_to_rank_idx_dict: dict[str, dict[int, int]] = {}
for bucket_id, rs_bucket in enumerate(rs_buckets):
_, reduce_op, group_size, group_name = next(
iter(rs_node_to_wait_node.keys())
).args
rs_input_nodes = []
wait_nodes = []
wait_node_recursive_users = OrderedSet() # type: ignore[var-annotated]
for rs_node in rs_bucket:
assert (
rs_node in rs_node_to_wait_node
and rs_node.args[1] == reduce_op
and rs_node.args[2] == group_size
and rs_node.args[3] == group_name
)
rs_input_nodes.append(rs_node.args[0])
wait_node = rs_node_to_wait_node[rs_node]
wait_nodes.append(wait_node)
find_recursive_users_of_fx_node(wait_node, wait_node_recursive_users)
rs_and_its_recursive_users |= wait_node_recursive_users
bucket_id_to_bucketed_op_info[bucket_id] = (
rs_input_nodes,
reduce_op,
group_size,
group_name,
wait_nodes,
wait_node_recursive_users,
)
if group_name not in group_name_to_rank_idx_dict:
group_name_to_rank_idx_dict[group_name] = _rank_idx_dict(group_name) # type: ignore[arg-type, index]
new_graph: torch.fx.Graph = torch.fx.Graph()
env: dict[torch.fx.Node, torch.fx.Node] = {}
node_list = list(gm.graph.nodes)
for node in node_list:
if node not in rs_and_its_recursive_users:
# not reduce_scatter or its (recursive) users - schedule it normally
node_copy(env, new_graph, node, lambda x: env_lookup(env, x, node))
elif node in rs_node_to_wait_node:
assert node in rs_node_to_bucket_id
bucket_id = rs_node_to_bucket_id[node]
if not (
bucket_id not in bucket_id_is_scheduled
and rs_buckets[bucket_id][-1] == node
):
continue
# If we are at the last node in the bucket, we can start to schedule the bucketed reduce_scatter node
(
rs_input_nodes,
reduce_op,
group_size,
group_name,
orig_wait_nodes,
orig_wait_node_recursive_users,
) = bucket_id_to_bucketed_op_info[bucket_id]
rank_idx_dict = group_name_to_rank_idx_dict[group_name] # type: ignore[index]
# parents of rs have been scheduled, so we can directly use the env
unsharded_grads = [env[x] for x in rs_input_nodes] # type: ignore[index]
reduce_dtype = unsharded_grads[0].meta["val"].dtype
# Only float32 and bfloat16 are supported for now.
# To support fp16, please see FSDP2 `_get_gradient_divide_factors`.
assert reduce_dtype in (
torch.float32, # type: ignore[attr-defined]
torch.bfloat16, # type: ignore[attr-defined]
), f"reduce_dtype {reduce_dtype} is not supported"
assert all(
grad.meta["val"].dtype == reduce_dtype for grad in unsharded_grads
)
device = unsharded_grads[0].meta["val"].device
rank = device.index
rank_idx = rank_idx_dict[rank]
shard_dim = 0
def _get_dim0_padded_size(
tensor_size: torch.Size,
dim0_factor: int, # type: ignore[name-defined]
) -> torch.Size: # type: ignore[name-defined]
padded_dim0 = math.ceil(tensor_size[0] / dim0_factor) * dim0_factor # type: ignore[attr-defined]
return torch.Size([padded_dim0]) + tensor_size[1:]
padded_unsharded_sizes = tuple(
_get_dim0_padded_size(grad.meta["val"].size(), group_size) # type: ignore[arg-type]
for grad in unsharded_grads
)
reduce_scatter_input_numel = sum(s.numel() for s in padded_unsharded_sizes)
"""
NOTE: the relationship between the next few nodes is tricky:
- reduce_scatter_input_reshaped is a view of reduce_scatter_input
(same storage, same # elems, different shape).
- chunk_cat writes into reduce_scatter_input_reshaped,
which indirectly writes into reduce_scatter_input
(since they share the same storage).
- reduce_scatter_tensor reads from reduce_scatter_input.
"""
reduce_scatter_input = new_graph_call_function(
new_graph,
torch.ops.aten.empty.memory_format,
([reduce_scatter_input_numel],),
{
"dtype": reduce_dtype,
"device": device,
"pin_memory": False,
},
)
reduce_scatter_input_reshaped = new_graph_call_function(
new_graph,
torch.ops.aten.reshape.default,
(reduce_scatter_input, [group_size, -1]),
{},
)
new_graph_call_function(
new_graph,
torch.ops.fsdp.chunk_cat.default,
(unsharded_grads,),
{
"dim": 0,
"num_chunks": group_size,
"out": reduce_scatter_input_reshaped,
},
)
reduce_scatter_tensor = new_graph_call_function(
new_graph,
torch.ops._c10d_functional.reduce_scatter_tensor.default,
(reduce_scatter_input, reduce_op, group_size, group_name),
{},
)
wait_tensor = new_graph_call_function(
new_graph,
torch.ops._c10d_functional.wait_tensor.default,
(reduce_scatter_tensor,),
{},
)
def _chunk_with_empty(
tensor: torch.Tensor, num_chunks: int, dim: int
) -> list[torch.Tensor]:
chunks = list(torch.chunk(tensor, num_chunks, dim=dim))
while len(chunks) < num_chunks:
chunks.append(chunks[0].new_empty(0))
return chunks
reduce_output = wait_tensor
# View out and accumulate sharded gradients
new_sharded_grads = []
flat_grad_offset = 0 # [0, reduce_scatter_output_numel - 1]
for padded_unsharded_size, unsharded_grad in zip(
padded_unsharded_sizes, unsharded_grads
):
# NOTE: we only care about the shape of tensors in `chunks`, so using meta tensor here
chunks = _chunk_with_empty(
torch.empty_like(unsharded_grad.meta["val"], device="meta"),
group_size, # type: ignore[arg-type]
dim=shard_dim,
)
sharded_param = chunks[rank_idx]
sharded_size = sharded_param.size()
contiguous_sharded_stride = (
torch._prims_common.make_contiguous_strides_for(sharded_size)
)
# Assume even sharding for Shard(i), i > 0; otherwise would require
# copy-out for contiguous strides
new_sharded_grad = new_graph_call_function(
new_graph,
torch.ops.aten.as_strided.default,
(reduce_output,),
{
"size": sharded_size,
"stride": contiguous_sharded_stride,
"storage_offset": flat_grad_offset,
},
)
new_sharded_grads.append(new_sharded_grad)
padded_sharded_numel = padded_unsharded_size.numel() // group_size # type: ignore[operator]
flat_grad_offset += padded_sharded_numel # type: ignore[assignment]
assert len(orig_wait_nodes) == len(new_sharded_grads)
assert len(orig_wait_nodes) > 0
for new_sharded_grad, orig_wait_node in zip(
new_sharded_grads, orig_wait_nodes
):
env[orig_wait_node] = new_sharded_grad # noqa: PERF403
for user in sorted(orig_wait_node_recursive_users, key=lambda x: order[x]):
# We skip output node here, because output node will be inserted (later)
# as the last node in the new graph.
if user.op != "output":
node_copy(env, new_graph, user, lambda x: env_lookup(env, x, user))
bucket_id_is_scheduled[bucket_id] = True
else:
continue
assert node_list[-1].op == "output"
# Finally, insert the output node
output_node = node_list[-1]
node_copy(env, new_graph, output_node, lambda x: env_lookup(env, x, output_node))
gm.graph = new_graph