Fx collectives bucketing: add bucket all_reduce (#165351)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165351
Approved by: https://github.com/eellison
This commit is contained in:
IvanKobzarev
2025-10-16 03:51:46 -07:00
committed by PyTorch MergeBot
parent f06e669f6c
commit 9272437cde
2 changed files with 171 additions and 0 deletions

View File

@ -34,11 +34,21 @@ def _rs_group_key(node: torch.fx.Node) -> tuple[str, str, torch.dtype]:
return (group_name, reduce_op, dtype)
def _ar_group_key(node: torch.fx.Node) -> tuple[str, str, torch.dtype]:
    """Build the bucketing key for an all_reduce node.

    The node is expected to carry exactly three positional args; the second
    is taken as the reduce op and the third as the group name. The dtype is
    read from ``node.meta["val"]``.

    Returns:
        ``(group_name, reduce_op, dtype)``.

    Raises:
        AssertionError: if the group name or the reduce op is not a ``str``.
    """
    _input, op_name, pg_name = node.args
    val_dtype = node.meta["val"].dtype
    # Narrow the fx argument types to ``str`` for the declared return type.
    assert isinstance(pg_name, str)
    assert isinstance(op_name, str)
    return (pg_name, op_name, val_dtype)
def bucket_key(node: torch.fx.Node) -> object | None:
    """Return the bucketing key for a supported collective node.

    The node is checked, in order, against all_gather_into_tensor,
    reduce_scatter_tensor and all_reduce; the first matching kind decides
    which key builder is used. Any other node yields ``None``.
    """
    if is_all_gather_into_tensor(node):
        return _ag_group_key(node)
    if is_reduce_scatter_tensor(node):
        return _rs_group_key(node)
    if is_all_reduce_tensor(node):
        return _ar_group_key(node)
    return None
@ -111,6 +121,13 @@ def is_wait_tensor(node: torch.fx.Node) -> bool:
)
def is_all_reduce_tensor(node: torch.fx.Node) -> bool:
return (
node.op == "call_function"
and node.target == torch.ops._c10d_functional.all_reduce.default
)
def is_wait_tensor_from_all_gather_into_tensor(node: torch.fx.Node) -> bool:
    """Tell whether ``node`` is a wait_tensor whose first argument is an
    all_gather_into_tensor node."""
    if not is_wait_tensor(node):
        return False
    return is_all_gather_into_tensor(node.args[0])  # type: ignore[arg-type]
@ -293,6 +310,38 @@ def bucket_reduce_scatter_by_mb(
)
def bucket_all_reduce_by_mb(
    gm: torch.fx.GraphModule,
    bucket_cap_mb_by_bucket_idx: Callable[[int], float],
    filter_wait_node: Callable[[torch.fx.Node], bool] | None = None,
) -> list[list[torch.fx.Node]]:
    """Group the all_reduce nodes of ``gm`` into buckets.

    Delegates to ``greedy_bucket_collective_by_mb`` with
    ``is_all_reduce_tensor`` as the node selector and ``_ar_group_key`` as
    the grouping key.

    Args:
        gm: graph module to scan.
        bucket_cap_mb_by_bucket_idx: callable taking a bucket index (int)
            and returning a float; forwarded unchanged.
        filter_wait_node: optional predicate on nodes; forwarded unchanged.

    Returns:
        The list of buckets produced by the greedy bucketing helper.
    """
    buckets = greedy_bucket_collective_by_mb(
        gm,
        bucket_cap_mb_by_bucket_idx,
        is_all_reduce_tensor,
        _ar_group_key,
        filter_wait_node,
    )
    return buckets
def bucket_all_reduce(
    gm: torch.fx.GraphModule,
    bucket_cap_mb_by_bucket_idx: Callable[[int], float] | None = None,
    mode: str | None = None,
) -> None:
    """Bucket the all_reduce nodes of ``gm`` and merge every bucket.

    Args:
        gm: graph module whose graph is passed to the merge step.
        bucket_cap_mb_by_bucket_idx: callable taking a bucket index;
            when ``None``, ``bucket_cap_mb_by_bucket_idx_default`` from
            ``torch._inductor.fx_passes.bucketing`` is used.
        mode: forwarded unchanged to ``merge_all_reduce_bucket``.
    """
    cap_fn = bucket_cap_mb_by_bucket_idx
    if cap_fn is None:
        from torch._inductor.fx_passes.bucketing import (
            bucket_cap_mb_by_bucket_idx_default,
        )

        cap_fn = bucket_cap_mb_by_bucket_idx_default

    # With no buckets the loop body never runs, so nothing is merged.
    for bucket in bucket_all_reduce_by_mb(gm, cap_fn):
        merge_all_reduce_bucket(gm.graph, bucket, mode)
@torch.library.custom_op("bucketing::_pre_bucket_reduce_scatter", mutates_args={})
def _pre_bucket_reduce_scatter(
rs_ins: list[torch.Tensor],
@ -364,6 +413,24 @@ def reduce_scatter_merge_fn_to_trace(
return new_outs
def all_reduce_merge_fn_to_trace(
    ar_ins: list[torch.Tensor],
    group_name: str,
    reduce_op: str,
    reduce_dtype: torch.dtype,  # type: ignore[name-defined]
    device: torch.device,  # type: ignore[name-defined]
) -> list[torch.Tensor]:  # type: ignore[no-untyped-def]
    """Run a single all_reduce over the concatenation of several inputs.

    Every input is flattened to 1-D, the pieces are concatenated, one
    ``all_reduce`` followed by ``wait_tensor`` is issued on the merged
    buffer, and the result is split and viewed back to each input's shape.

    Args:
        ar_ins: tensors to reduce together.
        group_name: group name forwarded to the collective.
        reduce_op: reduce op name forwarded to the collective.
        reduce_dtype: not read in this body.
        device: not read in this body.

    Returns:
        One tensor per input, in input order, each with its input's shape.
    """
    # reshape(-1) instead of view(-1): view raises when an input's strides
    # cannot be flattened without a copy (e.g. a transposed tensor), while
    # reshape copies only in that case and is still a view otherwise.
    ar_ins_flattened = [x.reshape(-1) for x in ar_ins]
    new_ar_in = torch.cat(ar_ins_flattened)
    new_ar_out = torch.ops.c10d_functional.wait_tensor(
        torch.ops._c10d_functional.all_reduce.default(new_ar_in, reduce_op, group_name)
    )
    split_sizes = [x.numel() for x in ar_ins]
    new_outs_flat = new_ar_out.split(split_sizes)
    # The split pieces come from a freshly produced 1-D buffer, so viewing
    # them back to the original shapes is safe.
    new_outs = [x.view(ar_in.shape) for x, ar_in in zip(new_outs_flat, ar_ins)]
    return new_outs
@torch.library.custom_op("bucketing::_pre_bucket_all_gather", mutates_args={})
def _pre_bucket_all_gather(
ag_ins: list[torch.Tensor],
@ -713,6 +780,49 @@ def merge_reduce_scatter_bucket(
)
def merge_all_reduce_bucket(
g: torch.fx.Graph,
ar_nodes: list[torch.fx.Node],
mode: str | None = None,
insert_before: torch.fx.Node | None = None,
wait_insertion_point: torch.fx.Node | None = None,
) -> tuple[list[torch.fx.Node], dict[torch.fx.Node, torch.fx.Node]]:
ar0 = ar_nodes[0]
ar0_val = ar0.meta["val"]
_, reduce_op, group_name = ar0.args
reduce_dtype = ar0_val.dtype
device = ar0_val.device
for n in ar_nodes:
ar_val = n.meta["val"]
assert (
n.args[1] == reduce_op
and n.args[2] == group_name
and ar_val.device == device
and ar_val.dtype == reduce_dtype
)
ar_merge_fn = all_reduce_merge_fn_to_trace
def create_trace_args(bucket_ins: list[torch.fx.Node]) -> tuple[Any, ...]:
return (
pytree.tree_map(lambda node: node.meta["val"], bucket_ins),
group_name,
reduce_op,
reduce_dtype,
device,
)
return process_collective_bucket(
g,
ar_nodes,
ar_merge_fn,
create_trace_args,
insert_before=insert_before,
wait_insertion_point=wait_insertion_point,
)
def merge_all_gather_bucket(
g: torch.fx.Graph,
ag_nodes: list[torch.fx.Node],