[inductor][bucketing] Fx collectives bucketing of multiple dtypes (#162470)

Adds bucketing of collectives with multiple dtypes, so that they can be processed in one bucketed collective.

The first target is bucketing bf16 together with f32, but the pass can already be used with other dtypes.

For now, multi-dtype bucketing is supported only in "custom_ops" mode; the non-custom_ops path needs additional work on the inductor side.
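
For illustration, a minimal sketch of enabling the multi-dtype path through the inductor config, mirroring the new test below (the mode string "all_custom_ops_multidtype" comes from the test's parametrization; treat the exact flag values as assumptions):

    import torch

    # Sketch: bucket bf16 and f32 all_gathers into a single bucketed collective.
    with torch._inductor.config.patch(
        {
            "bucket_all_gathers_fx": "all_custom_ops_multidtype",
            "reorder_for_compute_comm_overlap": False,
        }
    ):
        compiled = torch.compile(func)  # func issues bf16 and f32 all_gathers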

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162470
Approved by: https://github.com/eellison
Author: IvanKobzarev
Date: 2025-10-16 06:41:59 -07:00
Committed by: PyTorch MergeBot
Parent: 1a34ff4e04
Commit: 7d87d7052e
4 changed files with 141 additions and 34 deletions

File: test/distributed/test_inductor_collectives.py

@@ -1804,6 +1804,63 @@ class TestCollectivesInductor(DynamoDistributedSingleProcTestCase):
correct = f(*inputs, **self.get_world_trs())
assert same(out, correct), f"{out} vs {correct}"
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@unittest.skipIf(not SM80OrLater, "bfloat16")
@parametrize("bucket_mode", ["all_custom_ops_multidtype"])
def test_all_gather_bucket_multidtype(self, bucket_mode):
def func(x, w, ag_0, ag_1, *, tag, ranks, group_size):
# do an unrelated matmul
y = torch.mm(x, w)
group_name = (
torch.distributed.distributed_c10d._get_default_group().group_name
)
ag_0_w = torch.ops._c10d_functional.all_gather_into_tensor(
ag_0, group_size, group_name
)
ag_0_out = torch.ops.c10d_functional.wait_tensor(ag_0_w)
ag_0_out = ag_0_out * 2
ag_1_w = torch.ops._c10d_functional.all_gather_into_tensor(
ag_1, group_size, group_name
)
ag_1_out = torch.ops.c10d_functional.wait_tensor(ag_1_w)
return y, ag_0_out, ag_1_out
x = torch.ones(4, 384, device="cuda", dtype=torch.float32)
w = torch.ones(384, 512, device="cuda", dtype=torch.float32)
ag_0 = torch.ones(384, 512, device="cuda", dtype=torch.bfloat16)
ag_1 = torch.ones(384, 512, device="cuda", dtype=torch.float32)
inputs = [x, w, ag_0, ag_1]
correct = func(*inputs, **self.get_world_trs())
with torch._inductor.config.patch(
{
"bucket_all_gathers_fx": bucket_mode,
"reorder_for_compute_comm_overlap": False,
}
):
compiled = torch.compile(func)
code = run_and_get_triton_code(compiled, *inputs, **self.get_world_trs())
(
FileCheck()
.check_count(
"torch.ops._c10d_functional.all_gather_into_tensor_out.default(",
count=1,
exactly=True,
)
.run(code)
)
out = compiled(*inputs, **self.get_world_trs())
_, y_ag0, y_ag1 = out
assert y_ag0.dtype == ag_0.dtype
assert y_ag1.dtype == ag_1.dtype
assert same(out, correct), f"{out} vs {correct}"
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@unittest.skipIf(not SM80OrLater, "bfloat16")
@parametrize("bucket_mode", ["all", "all_custom_ops"])

File: torch/_inductor/fx_passes/bucketing.py

@@ -1,7 +1,8 @@
import collections
import logging
import operator
from collections import defaultdict
from typing import Any, Callable
from typing import Any, Callable, Literal, TypeAlias
import torch
import torch.distributed as dist
@@ -17,16 +18,24 @@ from torch.utils._ordered_set import OrderedSet
logger: logging.Logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
BucketMode: TypeAlias = Literal["default", "custom_ops", "custom_ops_multidtype"]
# Helper functions moved to top for better organization
def _ag_group_key(node: torch.fx.Node) -> tuple[str, torch.dtype]:
def _ag_group_key(node: torch.fx.Node) -> tuple[str, torch.dtype]: # type: ignore[name-defined]
_, group_size, group_name = node.args
dtype = node.meta["val"].dtype
assert isinstance(group_name, str)
return (group_name, dtype)
def _rs_group_key(node: torch.fx.Node) -> tuple[str, str, torch.dtype]:
def _ag_group_key_multidtype(node: torch.fx.Node) -> tuple[str]:
_, group_size, group_name = node.args
assert isinstance(group_name, str)
return (group_name,)
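# Illustration (not part of this diff): the multi-dtype key drops the dtype
# component, so a bf16 and an f32 all_gather on the same process group can
# share one bucket, while the per-dtype key above keeps them apart:
#   _ag_group_key:            ("pg0", torch.bfloat16) != ("pg0", torch.float32)
#   _ag_group_key_multidtype: ("pg0",) == ("pg0",)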
def _rs_group_key(node: torch.fx.Node) -> tuple[str, str, torch.dtype]: # type: ignore[name-defined]
_, reduce_op, group_size, group_name = node.args
dtype = node.meta["val"].dtype
assert isinstance(group_name, str)
@@ -53,6 +62,11 @@ def bucket_key(node: torch.fx.Node) -> object | None:
return None
def pick_bucket_dtype(dtypes: list[torch.dtype]) -> torch.dtype: # type: ignore[name-defined]
assert len(dtypes) > 0
return min(dtypes, key=operator.attrgetter("itemsize"))
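# Illustration (not part of this diff): the bucket dtype is the smallest
# itemsize among the inputs, so a mixed bf16/f32 bucket is addressed in
# bf16-sized elements and every f32 input spans an even number of them.
assert pick_bucket_dtype([torch.bfloat16, torch.float32]) is torch.bfloat16
assert torch.float32.itemsize // torch.bfloat16.itemsize == 2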
def bucket_cap_mb_by_bucket_idx_default(bucket_id: int) -> float:
"""
Determine the size of a bucket based on its ID.
@@ -69,7 +83,7 @@ def bucket_cap_mb_by_bucket_idx_default(bucket_id: int) -> float:
def bucket_all_gather(
gm: torch.fx.GraphModule,
bucket_cap_mb_by_bucket_idx: Callable[[int], float] | None = None,
mode: str | None = None,
mode: BucketMode = "default",
) -> None:
if bucket_cap_mb_by_bucket_idx is None:
from torch._inductor.fx_passes.bucketing import (
@@ -77,7 +91,7 @@
)
bucket_cap_mb_by_bucket_idx = bucket_cap_mb_by_bucket_idx_default
ag_buckets = bucket_all_gather_by_mb(gm, bucket_cap_mb_by_bucket_idx)
ag_buckets = bucket_all_gather_by_mb(gm, bucket_cap_mb_by_bucket_idx, None, mode)
if len(ag_buckets) == 0:
return
merge_all_gather(gm, ag_buckets, mode)
@@ -86,7 +100,7 @@
def bucket_reduce_scatter(
gm: torch.fx.GraphModule,
bucket_cap_mb_by_bucket_idx: Callable[[int], float] | None = None,
mode: str | None = None,
mode: BucketMode = "default",
) -> None:
if bucket_cap_mb_by_bucket_idx is None:
from torch._inductor.fx_passes.bucketing import (
@@ -94,7 +108,9 @@
)
bucket_cap_mb_by_bucket_idx = bucket_cap_mb_by_bucket_idx_default
rs_buckets = bucket_reduce_scatter_by_mb(gm, bucket_cap_mb_by_bucket_idx)
rs_buckets = bucket_reduce_scatter_by_mb(
gm, bucket_cap_mb_by_bucket_idx, None, mode
)
if len(rs_buckets) == 0:
return
merge_reduce_scatter(gm, rs_buckets, mode)
@@ -252,6 +268,7 @@ def bucket_all_gather_by_mb(
gm: torch.fx.GraphModule,
bucket_cap_mb_by_bucket_idx: Callable[[int], float],
filter_wait_node: Callable[[torch.fx.Node], bool] | None = None,
mode: BucketMode = "default",
) -> list[list[torch.fx.Node]]:
"""
Identifies all all_gather nodes and groups them into buckets,
@@ -271,11 +288,15 @@
list[list[torch.fx.Node]]: List of buckets, where each bucket is a list of all_gather nodes.
"""
group_key_fn = (
_ag_group_key_multidtype if mode and "multidtype" in mode else _ag_group_key
)
return greedy_bucket_collective_by_mb(
gm,
bucket_cap_mb_by_bucket_idx,
is_all_gather_into_tensor,
_ag_group_key,
group_key_fn,
filter_wait_node,
)
@@ -284,6 +305,7 @@ def bucket_reduce_scatter_by_mb(
gm: torch.fx.GraphModule,
bucket_cap_mb_by_bucket_idx: Callable[[int], float],
filter_wait_node: Callable[[torch.fx.Node], bool] | None = None,
mode: BucketMode = "default",
) -> list[list[torch.fx.Node]]:
"""
Identifies all reduce_scatter nodes and groups them into buckets,
@@ -301,6 +323,10 @@
list[list[torch.fx.Node]]: List of buckets, where each bucket is a list of reduce_scatter nodes.
"""
assert "multidtype" not in mode, (
"reduce scatter bucketing does not support multidtype"
)
return greedy_bucket_collective_by_mb(
gm,
bucket_cap_mb_by_bucket_idx,
@@ -439,13 +465,17 @@ def _pre_bucket_all_gather(
dtype: torch.dtype, # type: ignore[name-defined]
rank: int,
) -> torch.Tensor:
ins_split_sizes = [ag_in.numel() for ag_in in ag_ins]
ins_split_sizes_bytes = [ag_in.numel() * ag_in.element_size() for ag_in in ag_ins]
bucket_dtype_size_bytes = dtype.itemsize
ins_split_sizes = [
_bytes // bucket_dtype_size_bytes for _bytes in ins_split_sizes_bytes
]
ag_input_numel = sum(ins_split_sizes)
device = ag_ins[0].device
new_ag_out = torch.empty(ag_input_numel * group_size, dtype=dtype, device=device)
new_ag_in = new_ag_out.narrow(0, ag_input_numel * rank, ag_input_numel)
foreach_copy_dsts = torch.split(new_ag_in, ins_split_sizes)
ag_ins_flattened = [ag_in.reshape(-1) for ag_in in ag_ins]
ag_ins_flattened = [ag_in.reshape(-1).view(dtype) for ag_in in ag_ins]
torch._foreach_copy_(foreach_copy_dsts, ag_ins_flattened)
return new_ag_out
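# Worked example of the byte arithmetic above (illustration, not part of this
# diff): split sizes are computed in bytes, then converted to element counts
# of the bucket dtype; .view(dtype) reinterprets the flat bytes bitwise, so
# the foreach copy into the bucket performs no dtype conversion.
_f32_in = torch.ones(4, 8, dtype=torch.float32)    # 32 elems = 128 bytes
_bf16_in = torch.ones(4, 8, dtype=torch.bfloat16)  # 32 elems = 64 bytes
_sizes = [
    t.numel() * t.element_size() // torch.bfloat16.itemsize
    for t in (_f32_in, _bf16_in)
]
assert _sizes == [64, 32]
assert _f32_in.reshape(-1).view(torch.bfloat16).numel() == 64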
@@ -457,7 +487,11 @@ def _pre_bucket_all_gather_fake(
dtype: torch.dtype, # type: ignore[name-defined]
rank: int,
) -> torch.Tensor:
ins_split_sizes = [ag_in.numel() for ag_in in ag_ins]
ins_split_sizes_bytes = [ag_in.numel() * ag_in.element_size() for ag_in in ag_ins]
bucket_dtype_size_bytes = dtype.itemsize
ins_split_sizes = [
_bytes // bucket_dtype_size_bytes for _bytes in ins_split_sizes_bytes
]
ag_input_numel = sum(ins_split_sizes)
device = ag_ins[0].device
new_ag_out = torch.empty(ag_input_numel * group_size, dtype=dtype, device=device)
@@ -468,14 +502,28 @@ _pre_bucket_all_gather.register_fake(_pre_bucket_all_gather_fake)
def all_gather_merge_fn_to_trace_custom_ops(
ag_ins: list[torch.Tensor],
_ag_ins: list[torch.Tensor],
group_size: int,
group_name: str,
dtype: torch.dtype, # type: ignore[name-defined]
out_dtypes: list[torch.dtype], # type: ignore[name-defined]
rank: int,
) -> list[torch.Tensor]:
ag_ins = [
torch._prims.convert_element_type(_ag_in, out_dtype)
if _ag_in.dtype != out_dtype
else _ag_in
for _ag_in, out_dtype in zip(_ag_ins, out_dtypes)
]
ins_sizes = [ag_in.shape for ag_in in ag_ins]
ins_split_sizes = [ag_in.numel() for ag_in in ag_ins]
ins_split_sizes_bytes = [
ag_in.numel() * out_dtype.itemsize
for ag_in, out_dtype in zip(ag_ins, out_dtypes)
]
bucket_dtype_size_bytes = dtype.itemsize
ins_split_sizes = [
_bytes // bucket_dtype_size_bytes for _bytes in ins_split_sizes_bytes
]
ag_input_numel = sum(ins_split_sizes)
new_ag_out = torch.ops.bucketing._pre_bucket_all_gather(
ag_ins, group_size, group_name, dtype, rank
@@ -487,14 +535,14 @@ def all_gather_merge_fn_to_trace_custom_ops(
)
)
new_ag_out_reshaped = wait_tensor.reshape(group_size, -1)
outs = torch.split_with_sizes(
outs_bucket_dtype = torch.split_with_sizes(
new_ag_out_reshaped,
ins_split_sizes,
dim=1,
)
outs_reshaped = [
o.reshape((shape[0] * group_size,) + shape[1:])
for o, shape in zip(outs, ins_sizes)
o.view(out_dtype).reshape((shape[0] * group_size,) + shape[1:])
for o, shape, out_dtype in zip(outs_bucket_dtype, ins_sizes, out_dtypes)
]
return outs_reshaped
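# Round-trip illustration of the split/view/reshape step above (not part of
# this diff; torch.cat stands in for the gathered bucket, so no process group
# is needed): an f32 shard packed into a bf16-typed bucket is recovered
# bit-exactly by viewing its byte-sized split back as f32.
_group_size = 2
_x = torch.randn(3, 4)                       # f32 input shard
_flat = _x.reshape(-1).view(torch.bfloat16)  # 12 f32 elems -> 24 bf16 slots
_gathered = torch.cat([_flat, _flat]).reshape(_group_size, -1)
(_out,) = torch.split_with_sizes(_gathered, [_flat.numel()], dim=1)
_recovered = _out.view(torch.float32).reshape(_group_size * 3, 4)
assert torch.equal(_recovered, torch.cat([_x, _x]))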
@@ -504,6 +552,7 @@ def all_gather_merge_fn_to_trace(
group_size: int,
group_name: str,
dtype: torch.dtype, # type: ignore[name-defined]
out_dtypes: list[torch.dtype], # type: ignore[name-defined]
rank: int,
) -> list[torch.Tensor]:
ins_sizes = [ag_in.shape for ag_in in ag_ins]
@@ -538,6 +587,7 @@ def all_gather_merge_fn_to_trace_functional(
group_size: int,
group_name: str,
dtype: torch.dtype, # type: ignore[name-defined]
out_dtypes: list[torch.dtype], # type: ignore[name-defined]
rank: int,
use_fsdp_ag_copy_in: bool = False,
) -> list[torch.Tensor]:
@@ -733,7 +783,7 @@ def process_collective_bucket(
def merge_reduce_scatter_bucket(
g: torch.fx.Graph,
rs_nodes: list[torch.fx.Node],
mode: str | None = None,
mode: BucketMode = "default",
insert_before: torch.fx.Node | None = None,
wait_insertion_point: torch.fx.Node | None = None,
) -> tuple[list[torch.fx.Node], dict[torch.fx.Node, torch.fx.Node]]:
@@ -826,29 +876,27 @@ def merge_all_reduce_bucket(
def merge_all_gather_bucket(
g: torch.fx.Graph,
ag_nodes: list[torch.fx.Node],
mode: str | None = None,
mode: BucketMode = "default",
insert_before: torch.fx.Node | None = None,
wait_insertion_point: torch.fx.Node | None = None,
) -> tuple[list[torch.fx.Node], dict[torch.fx.Node, torch.fx.Node]]:
from torch.distributed.distributed_c10d import _resolve_process_group
ag0 = ag_nodes[0]
ag0_val = ag0.meta["val"]
_, group_size, group_name = ag0.args
dtype = ag0_val.dtype
assert isinstance(group_name, str)
_ag_dtypes: list[torch.dtype] = [] # type: ignore[name-defined]
for n in ag_nodes:
assert (
n.args[1] == group_size
and n.args[2] == group_name
and n.meta["val"].dtype == dtype
)
assert n.args[1] == group_size and n.args[2] == group_name
_ag_dtypes.append(n.meta["val"].dtype)
bucket_dtype = pick_bucket_dtype(_ag_dtypes)
# Choose merge function based on mode
ag_merge_fn = all_gather_merge_fn_to_trace
if mode and "custom_ops" in mode:
ag_merge_fn = all_gather_merge_fn_to_trace_custom_ops
if mode is not None and "custom_ops" in mode:
ag_merge_fn = all_gather_merge_fn_to_trace_custom_ops # type: ignore[assignment]
# Process bucket with lazy input collection
rank: int = dist.get_rank(_resolve_process_group(group_name))
@@ -858,7 +906,8 @@ def merge_all_gather_bucket(
pytree.tree_map(lambda node: node.meta["val"], bucket_ins),
group_size,
group_name,
dtype,
bucket_dtype,
_ag_dtypes,
rank,
)
@@ -874,7 +923,7 @@
def merge_reduce_scatter(
gm: torch.fx.GraphModule,
rs_buckets: list[list[torch.fx.Node]],
mode: str | None = None,
mode: BucketMode = "default",
) -> None:
"""
Merges specified buckets of reduce_scatter to joint reduce_scatter.
@@ -898,7 +947,7 @@
def merge_all_gather(
gm: torch.fx.GraphModule,
ag_buckets: list[list[torch.fx.Node]],
mode: str | None = None,
mode: BucketMode = "default",
) -> None:
"""
Merges specified buckets of all_gather to joint all_gather.

File: torch/_inductor/fx_passes/fsdp.py

@@ -5,6 +5,7 @@ import torch
from torch._inductor.fx_passes.bucketing import (
bucket_all_gather_by_mb,
bucket_reduce_scatter_by_mb,
BucketMode,
merge_all_gather,
merge_reduce_scatter,
)
@@ -56,7 +57,7 @@ def is_fsdp_reduce_scatter_wait(wait: torch.fx.Node) -> bool:
def bucket_fsdp_all_gather(
gm: torch.fx.GraphModule,
bucket_cap_mb_by_bucket_idx: Callable[[int], float] | None = None,
mode: str | None = None,
mode: BucketMode = "default",
) -> None:
"""
Bucketing pass for SimpleFSDP all_gather ops.
@@ -86,7 +87,7 @@
def bucket_fsdp_reduce_scatter(
gm: torch.fx.GraphModule,
bucket_cap_mb_by_bucket_idx: Callable[[int], float] | None = None,
mode: str | None = None,
mode: BucketMode = "default",
) -> None:
"""
Bucketing pass for SimpleFSDP reduce_scatter ops.

File: torch/_inductor/fx_passes/post_grad.py

@@ -216,7 +216,7 @@ def post_grad_passes(gm: torch.fx.GraphModule, is_inference: bool):
lambda graph: p(
graph.owning_module,
config.bucket_reduce_scatters_fx_bucket_size_determinator,
config.bucket_reduce_scatters_fx,
config.bucket_reduce_scatters_fx, # type: ignore[arg-type]
)
)
collectives_bucketing = True
@@ -236,7 +236,7 @@ def post_grad_passes(gm: torch.fx.GraphModule, is_inference: bool):
lambda graph: p(
graph.owning_module,
config.bucket_all_gathers_fx_bucket_size_determinator,
config.bucket_all_gathers_fx,
config.bucket_all_gathers_fx, # type: ignore[arg-type]
)
)
collectives_bucketing = True