[bucketing] custom_ops mode to hide inductor copies overhead (#161499)

Adding "_custom_ops" bucketing to temporary fallback to eager execution of for_each,
to workaround too many generated kernels on inductor side.

This PR also reverts parts of bucketing changes for cycles detection that resulted in accuracy problems.

Differential Revision: [D81152293](https://our.internmc.facebook.com/intern/diff/D81152293)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/161499
Approved by: https://github.com/eellison
This commit is contained in:
IvanKobzarev
2025-09-08 09:03:32 -07:00
committed by PyTorch MergeBot
parent 9c991b63ff
commit 8ec01f34e9
4 changed files with 197 additions and 30 deletions

View File

@ -1528,7 +1528,8 @@ class TestCollectivesInductor(DynamoDistributedSingleProcTestCase):
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@unittest.skipIf(not SM80OrLater, "bfloat16")
def test_all_gather_bucket(self):
@parametrize("bucket_mode", ["all", "all_custom_ops"])
def test_all_gather_bucket(self, bucket_mode):
def func(x, w, ag_0, ag_1, ag_2, ag_3, *, tag, ranks, group_size):
# do some unrelated matmuls
y = torch.mm(x, w)
@ -1576,7 +1577,7 @@ class TestCollectivesInductor(DynamoDistributedSingleProcTestCase):
with (
torch._inductor.config.patch(
{
"bucket_all_gathers_fx": "all",
"bucket_all_gathers_fx": bucket_mode,
"reorder_for_compute_comm_overlap": False,
"runtime_estimations_mms_benchmark": True,
}
@ -1595,7 +1596,9 @@ class TestCollectivesInductor(DynamoDistributedSingleProcTestCase):
# We want to make sure no unnecessary copy is made.
(
FileCheck()
.check_count(".all_gather_into_tensor_out.default(", 2, exactly=True)
.check("= torch.ops._c10d_functional.all_gather_into_tensor")
.check("torch.ops._c10d_functional.all_gather_into_tensor_out.default(")
.check("= torch.ops._c10d_functional.all_gather_into_tensor")
.run(code)
)
out = compiled(*inputs, **self.get_world_trs())
@ -1656,7 +1659,8 @@ class TestCollectivesInductor(DynamoDistributedSingleProcTestCase):
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@unittest.skipIf(not SM80OrLater, "bfloat16")
def test_reduce_scatter_bucket(self):
@parametrize("bucket_mode", ["all", "all_custom_ops"])
def test_reduce_scatter_bucket(self, bucket_mode):
def func(x, w, rs_0, rs_1, tag, ranks, group_size):
# do some unrelated matmuls
y = torch.mm(x, w)
@ -1697,7 +1701,7 @@ class TestCollectivesInductor(DynamoDistributedSingleProcTestCase):
with torch._inductor.config.patch(
{
"bucket_reduce_scatters_fx": "fsdp",
"bucket_reduce_scatters_fx": bucket_mode,
"reorder_for_compute_comm_overlap": False,
}
):
@ -1723,7 +1727,8 @@ class TestCollectivesInductor(DynamoDistributedSingleProcTestCase):
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@unittest.skipIf(not SM80OrLater, "bfloat16")
def test_reorder_peak_memory_bucketed(self):
@parametrize("bucket_mode", ["all", "all_custom_ops"])
def test_reorder_peak_memory_bucketed(self, bucket_mode):
"""
Simulate the case where a bucketing pass ran and grouped several inputs into one bucketed allgather.
Ensure the whole bucketed group including copy-ops get moved together rather than the copy ops preventing the
@ -1837,9 +1842,9 @@ class TestCollectivesInductor(DynamoDistributedSingleProcTestCase):
with (
torch._inductor.config.patch(
{
"bucket_all_gathers_fx": "all",
"bucket_all_gathers_fx": bucket_mode,
"bucket_all_gathers_fx_bucket_size_determinator": lambda _: 2,
"bucket_reduce_scatters_fx": "all",
"bucket_reduce_scatters_fx": bucket_mode,
"bucket_reduce_scatters_fx_bucket_size_determinator": lambda _: 2,
"reorder_for_compute_comm_overlap": True,
"reorder_for_compute_comm_overlap_passes": [

View File

@ -1,5 +1,6 @@
import collections
import logging
from collections import defaultdict
from typing import Any, Callable, Optional
import torch
@ -33,6 +34,7 @@ def bucket_cap_mb_by_bucket_idx_default(bucket_id: int) -> float:
def bucket_all_gather(
gm: torch.fx.GraphModule,
bucket_cap_mb_by_bucket_idx: Optional[Callable[[int], float]] = None,
mode: Optional[str] = None,
) -> None:
if bucket_cap_mb_by_bucket_idx is None:
from torch._inductor.fx_passes.bucketing import (
@ -43,13 +45,13 @@ def bucket_all_gather(
ag_buckets = bucket_all_gather_by_mb(gm, bucket_cap_mb_by_bucket_idx)
if len(ag_buckets) == 0:
return
merge_all_gather(gm, ag_buckets)
merge_all_gather(gm, ag_buckets, mode)
def bucket_reduce_scatter(
gm: torch.fx.GraphModule,
bucket_cap_mb_by_bucket_idx: Optional[Callable[[int], float]] = None,
mode: Optional[str] = None,
) -> None:
if bucket_cap_mb_by_bucket_idx is None:
from torch._inductor.fx_passes.bucketing import (
@ -60,7 +62,7 @@ def bucket_reduce_scatter(
rs_buckets = bucket_reduce_scatter_by_mb(gm, bucket_cap_mb_by_bucket_idx)
if len(rs_buckets) == 0:
return
merge_reduce_scatter(gm, rs_buckets)
merge_reduce_scatter(gm, rs_buckets, mode)
def is_all_gather_into_tensor(node: torch.fx.Node) -> bool: # type: ignore[arg-type]
@ -131,28 +133,46 @@ def greedy_bucket_collective_by_mb(
node_group_key: Callable[[torch.fx.Node], Any],
filter_wait_node: Optional[Callable[[torch.fx.Node], bool]] = None,
) -> list[list[torch.fx.Node]]:
if not gm.graph.find_nodes(
op="call_function", target=torch.ops._c10d_functional.wait_tensor.default
):
return []
"""
Bucketing adjacent collectives with equal node_group_key.
We can not bucket non adjacent collectives,
as this will effectively change the order of collectives.
Reordering can lead to different order on different ranks.
"""
g = gm.graph
found_candidates = False
for node in g.nodes:
if filter_node(node):
found_candidates = True
break
if not found_candidates:
return []
# TODO: pearce kelly algorithm for detecting cycles
node_descendents = collect_node_descendants(gm.graph)
node_groups: dict[Any, list[torch.fx.Node]] = collections.defaultdict(list)
nodes_groups: list[list[torch.fx.Node]] = []
cur_group: list[torch.fx.Node] = []
cur_group_key = None
for node in g.nodes:
if is_wait_tensor(node) and filter_node(node.args[0]):
if (filter_wait_node is None) or filter_wait_node(node):
coll_node = node.args[0]
group_key = node_group_key(coll_node)
node_groups[group_key].append(coll_node)
if group_key == cur_group_key:
cur_group.append(coll_node)
else:
if len(cur_group) > 1:
nodes_groups.append(cur_group)
cur_group = [coll_node]
cur_group_key = group_key
if len(cur_group) > 1:
nodes_groups.append(cur_group)
buckets: list[list[torch.fx.Node]] = []
for nodes in node_groups.values():
for nodes in nodes_groups:
cur_bucket: list[torch.fx.Node] = []
cur_bucket_descendents: OrderedSet[torch.fx.Node] = OrderedSet()
cur_bucket_size_bytes: int = 0
@ -261,6 +281,52 @@ def bucket_reduce_scatter_by_mb(
)
@torch.library.custom_op("bucketing::_pre_bucket_reduce_scatter", mutates_args={})
def _pre_bucket_reduce_scatter(
rs_ins: list[torch.Tensor],
group_size: int,
) -> torch.Tensor:
rs_ins_flattened = [x.view(group_size, -1) for x in rs_ins]
new_rs_in = torch.cat(rs_ins_flattened, dim=1).flatten()
return new_rs_in
def _pre_bucket_reduce_scatter_fake(
rs_ins: list[torch.Tensor],
group_size: int,
) -> torch.Tensor:
out_numel = sum(rs_in.numel() for rs_in in rs_ins)
return torch.empty((out_numel,), device=rs_ins[0].device, dtype=rs_ins[0].dtype)
_pre_bucket_reduce_scatter.register_fake(_pre_bucket_reduce_scatter_fake)
def reduce_scatter_merge_fn_to_trace_custom_ops(
rs_ins: list[torch.Tensor],
group_size: int,
group_name: str,
reduce_op: str,
reduce_dtype: torch.dtype, # type: ignore[name-defined]
device: torch.device, # type: ignore[name-defined]
) -> list[torch.Tensor]: # type: ignore[no-untyped-def]
new_out_sizes = [(x.shape[0] // group_size,) + x.shape[1:] for x in rs_ins]
new_out_numels = [x.numel() // group_size for x in rs_ins]
new_rs_in = torch.ops.bucketing._pre_bucket_reduce_scatter(rs_ins, group_size)
# TODO - either use torch.cat or make sure inductor foreach codegen
# fires more reliably
new_rs_out = torch.ops.c10d_functional.wait_tensor(
torch.ops._c10d_functional.reduce_scatter_tensor.default(
new_rs_in, reduce_op, group_size, group_name
)
)
new_out_flat = new_rs_out.split(new_out_numels, 0)
new_outs = [x.view(s) for x, s in zip(new_out_flat, new_out_sizes)]
return new_outs
def reduce_scatter_merge_fn_to_trace(
rs_ins: list[torch.Tensor],
group_size: int,
@ -276,8 +342,6 @@ def reduce_scatter_merge_fn_to_trace(
new_rs_in = torch.cat(rs_ins_flattened, dim=1).flatten()
# TODO - either use torch.cat or make sure inductor foreach codegen
# fires more reliably
new_rs_out = torch.ops.c10d_functional.wait_tensor(
torch.ops._c10d_functional.reduce_scatter_tensor.default(
new_rs_in, reduce_op, group_size, group_name
@ -288,6 +352,74 @@ def reduce_scatter_merge_fn_to_trace(
return new_outs
@torch.library.custom_op("bucketing::_pre_bucket_all_gather", mutates_args={})
def _pre_bucket_all_gather(
ag_ins: list[torch.Tensor],
group_size: int,
group_name: str,
dtype: torch.dtype, # type: ignore[name-defined]
rank: int,
) -> torch.Tensor:
ins_split_sizes = [ag_in.numel() for ag_in in ag_ins]
ag_input_numel = sum(ins_split_sizes)
device = ag_ins[0].device
new_ag_out = torch.empty(ag_input_numel * group_size, dtype=dtype, device=device)
new_ag_in = new_ag_out.narrow(0, ag_input_numel * rank, ag_input_numel)
foreach_copy_dsts = torch.split(new_ag_in, ins_split_sizes)
ag_ins_flattened = [ag_in.reshape(-1) for ag_in in ag_ins]
torch._foreach_copy_(foreach_copy_dsts, ag_ins_flattened)
return new_ag_out
def _pre_bucket_all_gather_fake(
ag_ins: list[torch.Tensor],
group_size: int,
group_name: str,
dtype: torch.dtype, # type: ignore[name-defined]
rank: int,
) -> torch.Tensor:
ins_split_sizes = [ag_in.numel() for ag_in in ag_ins]
ag_input_numel = sum(ins_split_sizes)
device = ag_ins[0].device
new_ag_out = torch.empty(ag_input_numel * group_size, dtype=dtype, device=device)
return new_ag_out
_pre_bucket_all_gather.register_fake(_pre_bucket_all_gather_fake)
def all_gather_merge_fn_to_trace_custom_ops(
ag_ins: list[torch.Tensor],
group_size: int,
group_name: str,
dtype: torch.dtype, # type: ignore[name-defined]
rank: int,
) -> list[torch.Tensor]:
ins_sizes = [ag_in.shape for ag_in in ag_ins]
ins_split_sizes = [ag_in.numel() for ag_in in ag_ins]
ag_input_numel = sum(ins_split_sizes)
new_ag_out = torch.ops.bucketing._pre_bucket_all_gather(
ag_ins, group_size, group_name, dtype, rank
)
new_ag_in = new_ag_out.narrow(0, ag_input_numel * rank, ag_input_numel)
wait_tensor = torch.ops.c10d_functional.wait_tensor(
torch.ops._c10d_functional.all_gather_into_tensor_out.default(
new_ag_in, group_size, group_name, out=new_ag_out
)
)
new_ag_out_reshaped = wait_tensor.reshape(group_size, -1)
outs = torch.split_with_sizes(
new_ag_out_reshaped,
ins_split_sizes,
dim=1,
)
outs_reshaped = [
o.reshape((shape[0] * group_size,) + shape[1:])
for o, shape in zip(outs, ins_sizes)
]
return outs_reshaped
def all_gather_merge_fn_to_trace(
ag_ins: list[torch.Tensor],
group_size: int,
@ -420,9 +552,17 @@ def _insert_fn_trace_before_node( # type: ignore[no-untyped-def]
def merge_reduce_scatter(
gm: torch.fx.GraphModule, rs_buckets: list[list[torch.fx.Node]]
gm: torch.fx.GraphModule,
rs_buckets: list[list[torch.fx.Node]],
mode: Optional[str] = None,
) -> None:
"""
Merges specified buckets of reduce_scatter to joint reduce_scatter.
"""
with dynamo_timed("fx.bucketing.merge_reduce_scatter", log_pt2_compile_event=True):
rs_merge_fn = reduce_scatter_merge_fn_to_trace
if mode and "custom_ops" in mode:
rs_merge_fn = reduce_scatter_merge_fn_to_trace_custom_ops
trace_structured(
"artifact",
metadata_fn=lambda: {
@ -469,7 +609,7 @@ def merge_reduce_scatter(
replacements = _insert_fn_trace_before_node(
g,
reduce_scatter_merge_fn_to_trace,
rs_merge_fn,
(
pytree.tree_map(lambda node: node.meta["val"], _rs_ins),
group_size,
@ -501,7 +641,9 @@ def merge_reduce_scatter(
def merge_all_gather(
gm: torch.fx.GraphModule, ag_buckets: list[list[torch.fx.Node]]
gm: torch.fx.GraphModule,
ag_buckets: list[list[torch.fx.Node]],
mode: Optional[str] = None,
) -> None: # type: ignore[union-attr]
"""
Merges specified buckets of all_gather to joint all_gather.
@ -509,6 +651,10 @@ def merge_all_gather(
with dynamo_timed("fx.bucketing.merge_all_gather", log_pt2_compile_event=True):
from torch.distributed.distributed_c10d import _resolve_process_group
ag_merge_fn = all_gather_merge_fn_to_trace
if mode and "custom_ops" in mode:
ag_merge_fn = all_gather_merge_fn_to_trace_custom_ops
trace_structured(
"artifact",
metadata_fn=lambda: {
@ -519,6 +665,8 @@ def merge_all_gather(
)
n_buckets = len(ag_buckets)
ag_node_to_pre_nodes = defaultdict(list)
ag_ins: list[list[torch.fx.Node]] = [[] for _ in range(n_buckets)]
ag_waits: list[list[torch.fx.Node]] = [[] for _ in range(n_buckets)]
for bucket_idx, ag_bucket in enumerate(ag_buckets):
@ -537,6 +685,14 @@ def merge_all_gather(
and ag_node.meta["val"].dtype == dtype
)
ag_node_in = ag_node.args[0]
if (
ag_node_in.op == "call_function" # type: ignore[union-attr]
and ag_node_in.target # type: ignore[union-attr]
== torch.ops.prims.convert_element_type.default # type: ignore[union-attr]
and len(ag_node_in.users) == 1 # type: ignore[union-attr]
):
ag_node_to_pre_nodes[ag_node].append(ag_node_in)
ag_node_in = ag_node_in.args[0] # type: ignore[union-attr]
ag_ins[bucket_idx].append(ag_node_in) # type: ignore[union-attr, arg-type]
ag_waits[bucket_idx].append(wait_node)
@ -558,7 +714,7 @@ def merge_all_gather(
replacements = _insert_fn_trace_before_node(
g,
all_gather_merge_fn_to_trace,
ag_merge_fn,
(
pytree.tree_map(lambda node: node.meta["val"], _ag_ins),
group_size,
@ -582,3 +738,5 @@ def merge_all_gather(
for ag_n, wait_n in zip(ag_buckets[bucket_idx], _ag_waits):
g.erase_node(wait_n)
g.erase_node(ag_n)
for n in reversed(ag_node_to_pre_nodes[ag_n]):
g.erase_node(n) # type: ignore[arg-type]

View File

@ -56,6 +56,7 @@ def is_fsdp_reduce_scatter_wait(wait: torch.fx.Node) -> bool:
def bucket_fsdp_all_gather(
gm: torch.fx.GraphModule,
bucket_cap_mb_by_bucket_idx: Optional[Callable[[int], float]] = None,
mode: Optional[str] = None,
) -> None:
"""
Bucketing pass for SimpleFSDP all_gather ops.
@ -79,12 +80,13 @@ def bucket_fsdp_all_gather(
)
if len(ag_buckets) == 0:
return
merge_all_gather(gm, ag_buckets)
merge_all_gather(gm, ag_buckets, mode)
def bucket_fsdp_reduce_scatter(
gm: torch.fx.GraphModule,
bucket_cap_mb_by_bucket_idx: Optional[Callable[[int], float]] = None,
mode: Optional[str] = None,
) -> None:
"""
Bucketing pass for SimpleFSDP reduce_scatter ops.
@ -109,4 +111,4 @@ def bucket_fsdp_reduce_scatter(
)
if len(rs_buckets) == 0:
return
merge_reduce_scatter(gm, rs_buckets)
merge_reduce_scatter(gm, rs_buckets, mode)

View File

@ -204,13 +204,14 @@ def post_grad_passes(gm: torch.fx.GraphModule, is_inference: bool):
p = (
bucket_fsdp_reduce_scatter
if config.bucket_reduce_scatters_fx == "fsdp"
if "fsdp" in config.bucket_reduce_scatters_fx
else bucket_reduce_scatter
)
GraphTransformObserver(gm, "bucket_reduce_scatters").apply_graph_pass(
lambda graph: p(
graph.owning_module,
config.bucket_reduce_scatters_fx_bucket_size_determinator,
config.bucket_reduce_scatters_fx,
)
)
collectives_bucketing = True
@ -223,13 +224,14 @@ def post_grad_passes(gm: torch.fx.GraphModule, is_inference: bool):
p = (
bucket_fsdp_all_gather # type: ignore[assignment]
if config.bucket_all_gathers_fx == "fsdp"
if "fsdp" in config.bucket_all_gathers_fx
else bucket_all_gather
)
GraphTransformObserver(gm, "bucket_all_gathers").apply_graph_pass(
lambda graph: p(
graph.owning_module,
config.bucket_all_gathers_fx_bucket_size_determinator,
config.bucket_all_gathers_fx,
)
)
collectives_bucketing = True