From 9272437cde67fcbb7dde66373382f711fd189418 Mon Sep 17 00:00:00 2001
From: IvanKobzarev
Date: Thu, 16 Oct 2025 03:51:46 -0700
Subject: [PATCH] Fx collectives bucketing: add bucket all_reduce (#165351)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165351
Approved by: https://github.com/eellison
---
 test/distributed/test_inductor_collectives.py |  61 ++++++++++
 torch/_inductor/fx_passes/bucketing.py        | 110 ++++++++++++++++++
 2 files changed, 171 insertions(+)

diff --git a/test/distributed/test_inductor_collectives.py b/test/distributed/test_inductor_collectives.py
index 34a4879e5d73..c9e4cbaa7558 100644
--- a/test/distributed/test_inductor_collectives.py
+++ b/test/distributed/test_inductor_collectives.py
@@ -1743,6 +1743,67 @@ class TestCollectivesInductor(DynamoDistributedSingleProcTestCase):
         correct = f(*inputs, **self.get_world_trs())
         assert same(out, correct), f"{out} va {correct}"
 
+    @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
+    @unittest.skipIf(not SM80OrLater, "bfloat16")
+    @parametrize("bucket_mode", ["all"])
+    def test_all_reduce_bucket(self, bucket_mode):
+        def func(x, w, ar_0, ar_1, tag, ranks, group_size):
+            y = torch.mm(x, w)
+
+            group_name = (
+                torch.distributed.distributed_c10d._get_default_group().group_name
+            )
+            ar_0_out = torch.ops._c10d_functional.all_reduce.default(
+                ar_0, "sum", group_name
+            )
+            ar_1_out = torch.ops._c10d_functional.all_reduce.default(
+                ar_1, "sum", group_name
+            )
+
+            ar_0_w = torch.ops.c10d_functional.wait_tensor(ar_0_out)
+            ar_1_w = torch.ops.c10d_functional.wait_tensor(ar_1_out)
+
+            return y, ar_0_w, ar_1_w
+
+        f = func
+
+        x = torch.ones(4, 384, device="cuda", dtype=torch.float32)
+        w = torch.ones(384, 512, device="cuda", dtype=torch.float32)
+        ar_0 = torch.ones(384, 512, device="cuda", dtype=torch.float32)
+        ar_1 = torch.ones(384, 256, device="cuda", dtype=torch.float32)
+        inputs = [x, w, ar_0, ar_1]
+        f(*inputs, **self.get_world_trs())
+
+        def _pass(g):
+            from torch._inductor.fx_passes.bucketing import bucket_all_reduce
+
+            bucket_all_reduce(g.owning_module, lambda _: 2000)
+
+        torch._inductor.config.post_grad_custom_post_pass = _pass
+
+        with torch._inductor.config.patch(
+            {
+                "reorder_for_compute_comm_overlap": False,
+            }
+        ):
+            compiled = torch.compile(f)
+            compiled(*inputs, **self.get_world_trs())
+            code = run_and_get_triton_code(compiled, *inputs, **self.get_world_trs())
+            # NOTE: Both all_reduce calls should be bucketed into a single
+            # in-place all_reduce_ in the generated code.
+            (
+                FileCheck()
+                .check_count(
+                    "torch.ops._c10d_functional.all_reduce_.default(",
+                    count=1,
+                    exactly=True,
+                )
+                .run(code)
+            )
+            out = compiled(*inputs, **self.get_world_trs())
+            correct = f(*inputs, **self.get_world_trs())
+            assert same(out, correct), f"{out} vs {correct}"
+
     @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
     @unittest.skipIf(not SM80OrLater, "bfloat16")
     @parametrize("bucket_mode", ["all", "all_custom_ops"])
diff --git a/torch/_inductor/fx_passes/bucketing.py b/torch/_inductor/fx_passes/bucketing.py
index 965e0654380c..7260a6dc203b 100644
--- a/torch/_inductor/fx_passes/bucketing.py
+++ b/torch/_inductor/fx_passes/bucketing.py
@@ -34,11 +34,21 @@ def _rs_group_key(node: torch.fx.Node) -> tuple[str, str, torch.dtype]:
     return (group_name, reduce_op, dtype)
 
 
+def _ar_group_key(node: torch.fx.Node) -> tuple[str, str, torch.dtype]:
+    _, reduce_op, group_name = node.args
+    dtype = node.meta["val"].dtype
+    assert isinstance(group_name, str)
+    assert isinstance(reduce_op, str)
+    return (group_name, reduce_op, dtype)
+
+
 def bucket_key(node: torch.fx.Node) -> object | None:
     if is_all_gather_into_tensor(node):
         return _ag_group_key(node)
     elif is_reduce_scatter_tensor(node):
         return _rs_group_key(node)
+    elif is_all_reduce_tensor(node):
+        return _ar_group_key(node)
     else:
         return None
 
@@ -111,6 +121,13 @@ def is_wait_tensor(node: torch.fx.Node) -> bool:
     )
 
 
+def is_all_reduce_tensor(node: torch.fx.Node) -> bool:
+    return (
+        node.op == "call_function"
+        and node.target == torch.ops._c10d_functional.all_reduce.default
+    )
+
+
 def is_wait_tensor_from_all_gather_into_tensor(node: torch.fx.Node) -> bool:
     return is_wait_tensor(node) and is_all_gather_into_tensor(node.args[0])  # type: ignore[arg-type]
 
@@ -293,6 +310,38 @@ def bucket_reduce_scatter_by_mb(
     )
 
 
+def bucket_all_reduce_by_mb(
+    gm: torch.fx.GraphModule,
+    bucket_cap_mb_by_bucket_idx: Callable[[int], float],
+    filter_wait_node: Callable[[torch.fx.Node], bool] | None = None,
+) -> list[list[torch.fx.Node]]:
+    return greedy_bucket_collective_by_mb(
+        gm,
+        bucket_cap_mb_by_bucket_idx,
+        is_all_reduce_tensor,
+        _ar_group_key,
+        filter_wait_node,
+    )
+
+
+def bucket_all_reduce(
+    gm: torch.fx.GraphModule,
+    bucket_cap_mb_by_bucket_idx: Callable[[int], float] | None = None,
+    mode: str | None = None,
+) -> None:
+    if bucket_cap_mb_by_bucket_idx is None:
+        from torch._inductor.fx_passes.bucketing import (
+            bucket_cap_mb_by_bucket_idx_default,
+        )
+
+        bucket_cap_mb_by_bucket_idx = bucket_cap_mb_by_bucket_idx_default
+    ar_buckets = bucket_all_reduce_by_mb(gm, bucket_cap_mb_by_bucket_idx)
+    if len(ar_buckets) == 0:
+        return
+    for bucket in ar_buckets:
+        merge_all_reduce_bucket(gm.graph, bucket, mode)
+
+
 @torch.library.custom_op("bucketing::_pre_bucket_reduce_scatter", mutates_args={})
 def _pre_bucket_reduce_scatter(
     rs_ins: list[torch.Tensor],
@@ -364,6 +413,24 @@ def reduce_scatter_merge_fn_to_trace(
     return new_outs
 
 
+def all_reduce_merge_fn_to_trace(
+    ar_ins: list[torch.Tensor],
+    group_name: str,
+    reduce_op: str,
+    reduce_dtype: torch.dtype,  # type: ignore[name-defined]
+    device: torch.device,  # type: ignore[name-defined]
+) -> list[torch.Tensor]:  # type: ignore[no-untyped-def]
+    ar_ins_flattened = [x.view(-1) for x in ar_ins]
+    new_ar_in = torch.cat(ar_ins_flattened)
+    new_ar_out = torch.ops.c10d_functional.wait_tensor(
+        torch.ops._c10d_functional.all_reduce.default(new_ar_in, reduce_op, group_name)
+    )
+    split_sizes = [x.numel() for x in ar_ins]
+    new_outs_flat = new_ar_out.split(split_sizes)
+    new_outs = [x.view(ar_in.shape) for x, ar_in in zip(new_outs_flat, ar_ins)]
+    return new_outs
+
+
 @torch.library.custom_op("bucketing::_pre_bucket_all_gather", mutates_args={})
 def _pre_bucket_all_gather(
     ag_ins: list[torch.Tensor],
@@ -713,6 +780,49 @@ def merge_reduce_scatter_bucket(
     )
 
 
+def merge_all_reduce_bucket(
+    g: torch.fx.Graph,
+    ar_nodes: list[torch.fx.Node],
+    mode: str | None = None,
+    insert_before: torch.fx.Node | None = None,
+    wait_insertion_point: torch.fx.Node | None = None,
+) -> tuple[list[torch.fx.Node], dict[torch.fx.Node, torch.fx.Node]]:
+    ar0 = ar_nodes[0]
+    ar0_val = ar0.meta["val"]
+    _, reduce_op, group_name = ar0.args
+    reduce_dtype = ar0_val.dtype
+    device = ar0_val.device
+
+    for n in ar_nodes:
+        ar_val = n.meta["val"]
+        assert (
+            n.args[1] == reduce_op
+            and n.args[2] == group_name
+            and ar_val.device == device
+            and ar_val.dtype == reduce_dtype
+        )
+
+    ar_merge_fn = all_reduce_merge_fn_to_trace
+
+    def create_trace_args(bucket_ins: list[torch.fx.Node]) -> tuple[Any, ...]:
+        return (
+            pytree.tree_map(lambda node: node.meta["val"], bucket_ins),
+            group_name,
+            reduce_op,
+            reduce_dtype,
+            device,
+        )
+
+    return process_collective_bucket(
+        g,
+        ar_nodes,
+        ar_merge_fn,
+        create_trace_args,
+        insert_before=insert_before,
+        wait_insertion_point=wait_insertion_point,
+    )
+
+
 def merge_all_gather_bucket(
     g: torch.fx.Graph,
     ag_nodes: list[torch.fx.Node],
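
Usage sketch (illustrative, not part of the patch): the test above registers the new
pass through Inductor's post-grad custom pass hook, and the same wiring works outside
the test suite. In the sketch below, my_fn is a hypothetical function that issues
several torch.ops._c10d_functional.all_reduce calls on the same process group, and
the 2000 MB cap is the same arbitrary "bucket everything" value the test uses.

    import torch
    import torch._inductor.config
    from torch._inductor.fx_passes.bucketing import bucket_all_reduce

    def _bucket_all_reduce_pass(g: torch.fx.Graph) -> None:
        # bucket_cap_mb_by_bucket_idx maps bucket index -> capacity in MB.
        # A large constant cap merges every all_reduce that shares
        # (group_name, reduce_op, dtype) into one bucketed all_reduce.
        bucket_all_reduce(g.owning_module, lambda _bucket_idx: 2000)

    # Run the bucketing pass after Inductor's post-grad passes.
    torch._inductor.config.post_grad_custom_post_pass = _bucket_all_reduce_pass

    compiled = torch.compile(my_fn)  # my_fn: placeholder for your collective-heavy fn

With a cap this large, the pass concatenates the flattened all_reduce inputs, issues
a single all_reduce on the concatenated buffer, then splits and reshapes the result
back to the original shapes, as all_reduce_merge_fn_to_trace in the diff does.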