[bucketing] Rewrite all_gather, reduce_scatter passes via tracing merge_fn (#158663)

Rewrites the all_gather and reduce_scatter bucketing passes so that the "merge graph" is defined by tracing a torch function:
`all_gather_merge_fn_to_trace`
`reduce_scatter_merge_fn_to_trace`

Instead of creating fx nodes and doing FakeTensor propagation manually, the merged collective is written as ordinary PyTorch code and traced. This makes it easy to experiment with the merge function.
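For illustration, a minimal sketch of what such a traceable merge function could look like (the signature and the `cat`-based flattening are assumptions, not the exact PR code):

```
import torch

# Illustrative sketch of a traceable all_gather merge function.
# Signature and flattening strategy are assumed; the PR's version
# uses foreach_copy_ into a preallocated buffer instead of cat.
def all_gather_merge_fn_to_trace(
    inputs: list[torch.Tensor], group_size: int, group_name: str
) -> list[torch.Tensor]:
    numels = [t.numel() for t in inputs]
    # Flatten the whole bucket into one contiguous buffer.
    buf = torch.cat([t.reshape(-1) for t in inputs])
    # One fused all_gather for the bucket instead of one per tensor.
    out = torch.ops._c10d_functional.all_gather_into_tensor(
        buf, group_size, group_name
    )
    out = torch.ops.c10d_functional.wait_tensor(out)
    # Rank r's slice of `out` holds rank r's shard of every input:
    # split per rank, then regroup per original input.
    per_rank = out.view(group_size, -1)
    splits = per_rank.split(numels, dim=1)
    return [
        s.reshape(group_size * t.shape[0], *t.shape[1:])
        for t, s in zip(inputs, splits)
    ]
```

Because the merge function is traced like any other graph, changing the merge strategy only means editing this function, not rewriting a node-construction pass.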

The all_gather merge function uses `foreach_copy_`, so an Inductor lowering for `foreach_copy_` was added.
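For illustration, the pattern that motivates the lowering looks roughly like this (hypothetical shapes and buffer layout, not the PR's exact code):

```
import torch

# Copy each bucket input into its slice of a flat staging buffer with a
# single foreach op instead of N separate copy_ calls.
srcs = [torch.randn(4, 8), torch.randn(16), torch.randn(2, 3)]
flat = torch.empty(sum(t.numel() for t in srcs))

offsets = [0]
for t in srcs:
    offsets.append(offsets[-1] + t.numel())
dsts = [
    flat[a:b].view(s.shape)
    for a, b, s in zip(offsets, offsets[1:], srcs)
]
torch._foreach_copy_(dsts, srcs)  # the op that now has an Inductor lowering
```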

A topological sort is added after the bucketing passes (from the comment in post_grad.py):
```
        # Fx collectives bucketing passes require topological sort for the cases:
        # when bucketed collectives have users before the last collective in the bucket
        # AND when inputs of bucketed collective have ancestors after the first collective in the bucket.
        #
        # In this case we can not manually pick the place for bucketed collective insertion.
        # But we are guaranteed by the bucketing (independent collectives in the bucket),
        # that it is possible to reorder nodes to satisfy all ordering requirements.
        #
        # --- before bucketing ---
        # in0 = ...
        # wait_ag0 = ag(in0)
        # user0(wait_ag0)
        # ...
        # pre_in1 = ...
        # in1 = transform(pre_in1)
        # wait_ag1 = ag(in1)
        # user1(wait_ag1)
        #
        # --- after bucketing ---
        #
        # in0 = ...
        # user(wait_ag0) <--- wait_ag0 is defined only after bucketed collective.
        #
        # pre_in1 = ...
        # in1 = transform(pre_in1)
        # ag_bucket(in0+in1)
        # wait_bucket
        # wait_ag0 = wait_bucket[0]
        # wait_ag1 = wait_bucket[1]
        # user1(wait_ag1)
```
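For reference, a stable topological sort over a torch.fx graph can be implemented along these lines (an illustrative sketch; the pass calls a `stable_topological_sort` helper whose exact implementation may differ):

```
from collections import defaultdict

import torch

def stable_topological_sort(graph: torch.fx.Graph) -> None:
    # Walk nodes in their current order; emit a node as soon as all of
    # its inputs have been emitted, otherwise park it on its last
    # not-yet-ready input. Already-sorted graphs come out unchanged.
    pending = list(graph.nodes)[::-1]
    ready: set[torch.fx.Node] = set()
    waiting = defaultdict(list)  # blocking input -> nodes parked on it
    cursor = None  # last emitted node; ready nodes are moved after it
    while pending:
        node = pending.pop()
        blockers = [n for n in node.all_input_nodes if n not in ready]
        if blockers:
            waiting[blockers[-1]].append(node)
            continue
        ready.add(node)
        if cursor is not None and cursor.next is not node:
            cursor.append(node)  # splice node right after cursor
        cursor = node
        # Re-check everything that was parked on this node.
        pending.extend(reversed(waiting.pop(node, ())))
    assert not waiting and len(ready) == len(graph.nodes)
```

The guarantee mentioned in the comment is what makes this safe: collectives in one bucket are independent, so a valid ordering always exists and the sort only has to find it.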

Correctness of the passes was verified by comparing loss curves for Llama3 8B, for simple_fsdp and for autoparallel:

(Screenshots: loss curves for the simple_fsdp and autoparallel runs, captured 2025-07-22.)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/158663
Approved by: https://github.com/wconstab
Author: IvanKobzarev
Date: 2025-07-25 10:40:47 -07:00
Committed by: PyTorch MergeBot
Parent: bc5dbbbb78
Commit: 8aebf01287
5 changed files with 573 additions and 739 deletions


@@ -2,7 +2,7 @@
 import datetime
 import functools
 import unittest
-from collections import defaultdict
+from collections import Counter
 from typing import Optional
 from unittest.mock import patch
@@ -666,7 +666,7 @@ class TestCollectivesMultiProc(DynamoDistributedMultiProcTestCase):
         class TrackingMode(TorchDispatchMode):
             def __init__(self):
                 super().__init__()
-                self.ops_counter = defaultdict(int)
+                self.ops_counter = Counter()

             def __torch_dispatch__(self, func, types, args=(), kwargs=None):
                 if kwargs is None:
@@ -1539,12 +1539,15 @@ class TestCollectivesInductor(DynamoDistributedSingleProcTestCase):
         ag_0_out = torch.ops._c10d_functional.all_gather_into_tensor(
             ag_0_cast, group_size, group_name
         )
-        ag_0_out = torch.ops.c10d_functional.wait_tensor(ag_0_out)
+        ag_0_out = ag_0_out * 2
+        ag_1_cast = ag_1_cast * 2
         ag_1_out = torch.ops._c10d_functional.all_gather_into_tensor(
             ag_1_cast, group_size, group_name
         )
+        # wait op
+        ag_0_out = torch.ops.c10d_functional.wait_tensor(ag_0_out)
         ag_1_out = torch.ops.c10d_functional.wait_tensor(ag_1_out)
         return y, ag_0_out, ag_1_out
@@ -1557,7 +1560,7 @@ class TestCollectivesInductor(DynamoDistributedSingleProcTestCase):
         with torch._inductor.config.patch(
             {
-                "bucket_all_gathers_fx": "fsdp",
+                "bucket_all_gathers_fx": "all",
                 "reorder_for_compute_comm_overlap": False,
             }
         ):
@@ -1601,7 +1604,7 @@ class TestCollectivesInductor(DynamoDistributedSingleProcTestCase):
         x = torch.ones(4, 384, device="cuda", dtype=torch.float32)
         w = torch.ones(384, 512, device="cuda", dtype=torch.float32)
         rs_0 = torch.ones(384, 512, device="cuda", dtype=torch.float32)
-        rs_1 = torch.ones(384, 512, device="cuda", dtype=torch.float32)
+        rs_1 = torch.ones(384, 256, device="cuda", dtype=torch.float32)

         inputs = [x, w, rs_0, rs_1]
         func(*inputs, **self.get_world_trs())
@@ -1786,10 +1789,8 @@ class TestCollectivesInductor(DynamoDistributedSingleProcTestCase):
         self.assertEqual(len(node_stats), 4)
         it = iter(node_stats.values())
         node_stat0 = next(it)
-        self.assertTrue(node_stat0.moves > 0)
         self.assertTrue(node_stat0.limiting_factor == "None")
         node_stat1 = next(it)
-        self.assertTrue(node_stat1.moves > 0)
         self.assertTrue("collective ordering" in node_stat1.limiting_factor)

     @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")

(File diff suppressed because it is too large.)


@@ -1,10 +1,12 @@
 import logging
-from typing import Callable
+from typing import Callable, Optional

 import torch

 from torch._inductor.fx_passes.bucketing import (
     bucket_all_gather_by_mb,
+    bucket_reduce_scatter_by_mb,
     merge_all_gather,
+    merge_reduce_scatter,
 )
@@ -31,38 +33,68 @@ def is_fsdp_all_gather_wait(wait: torch.fx.Node) -> bool:
     )


+def is_graph_output(node: torch.fx.Node) -> bool:
+    return all(user.op == "output" for user in node.users)
+
+
+def is_fsdp_reduce_scatter_wait(wait: torch.fx.Node) -> bool:
+    return is_graph_output(wait)
+
+
 def bucket_fsdp_all_gather(
-    gm: torch.fx.GraphModule, all_gather_bucket_cap_mb_callback: Callable[[int], float]
+    gm: torch.fx.GraphModule,
+    bucket_cap_mb_by_bucket_idx: Optional[Callable[[int], float]] = None,
 ) -> None:
     """
     Bucketing pass for SimpleFSDP all_gather ops.

     Attributes:
         gm (torch.fx.GraphModule): Graph module of the graph.
-        all_gather_bucket_cap_mb_callback (Callable[[int], float]): callback function that
+        bucket_cap_mb_by_bucket_idx (Optional[Callable[[int], float]]): callback function that
             takes in bucket id and returns size of a bucket in megabytes.

     Usage:
     ```
     from torch._inductor.fx_passes.bucketing import (
         bucket_all_gather,
         bucket_size_determinator,
     )

     def _bucket_all_gather(graph):
         return bucket_all_gather(graph.owning_module, bucket_size_determinator)

     torch._inductor.config.post_grad_custom_post_pass = _bucket_all_gather
     ```
     """
+    if bucket_cap_mb_by_bucket_idx is None:
+        from torch._inductor.fx_passes.bucketing import (
+            bucket_cap_mb_by_bucket_idx_default,
+        )
+
+        bucket_cap_mb_by_bucket_idx = bucket_cap_mb_by_bucket_idx_default
+    assert bucket_cap_mb_by_bucket_idx is not None
     ag_buckets = bucket_all_gather_by_mb(
         gm,
-        all_gather_bucket_cap_mb_callback,
+        bucket_cap_mb_by_bucket_idx,
         filter_wait_node=is_fsdp_all_gather_wait,
     )
     if len(ag_buckets) == 0:
         return
     merge_all_gather(gm, ag_buckets)
+
+
+def bucket_fsdp_reduce_scatter(
+    gm: torch.fx.GraphModule,
+    bucket_cap_mb_by_bucket_idx: Optional[Callable[[int], float]] = None,
+) -> None:
+    """
+    Bucketing pass for SimpleFSDP reduce_scatter ops.
+
+    Attributes:
+        gm (torch.fx.GraphModule): Graph module of the graph.
+        bucket_cap_mb_by_bucket_idx (Optional[Callable[[int], float]]): callback function that
+            takes in bucket idx and returns size of a bucket in megabytes. By default
+            torch._inductor.fx_passes.bucketing.bucket_cap_mb_by_bucket_idx_default is used.
+    """
+    if bucket_cap_mb_by_bucket_idx is None:
+        from torch._inductor.fx_passes.bucketing import (
+            bucket_cap_mb_by_bucket_idx_default,
+        )
+
+        bucket_cap_mb_by_bucket_idx = bucket_cap_mb_by_bucket_idx_default
+    rs_buckets = bucket_reduce_scatter_by_mb(
+        gm,
+        bucket_cap_mb_by_bucket_idx,
+        filter_wait_node=is_fsdp_reduce_scatter_wait,
+    )
+    if len(rs_buckets) == 0:
+        return
+    merge_reduce_scatter(gm, rs_buckets)
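Mirroring the usage shown in the docstring above, the FSDP variant can be wired in as a custom post-grad pass; a sketch under the new signature (the helper name `_bucket_fsdp_all_gather` is illustrative):

```
import torch
from torch._inductor.fx_passes.fsdp import bucket_fsdp_all_gather

def _bucket_fsdp_all_gather(graph: torch.fx.Graph) -> None:
    # bucket_cap_mb_by_bucket_idx falls back to
    # bucket_cap_mb_by_bucket_idx_default when omitted.
    bucket_fsdp_all_gather(graph.owning_module)

torch._inductor.config.post_grad_custom_post_pass = _bucket_fsdp_all_gather
```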


@@ -197,7 +197,77 @@ def post_grad_passes(gm: torch.fx.GraphModule, is_inference: bool):
         pass_name = "custom_backend_passes_" + device
         GraphTransformObserver(gm, pass_name).apply_gm_pass(custom_backend_pass)

-    # Keep these last, since they introduces mutation. Look at
+    collectives_bucketing: bool = False
+    if config.bucket_reduce_scatters_fx != "none":
+        from torch._inductor.fx_passes.bucketing import bucket_reduce_scatter
+        from torch._inductor.fx_passes.fsdp import bucket_fsdp_reduce_scatter
+
+        p = (
+            bucket_fsdp_reduce_scatter
+            if config.bucket_reduce_scatters_fx == "fsdp"
+            else bucket_reduce_scatter
+        )
+        GraphTransformObserver(gm, "bucket_reduce_scatters").apply_graph_pass(
+            lambda graph: p(
+                graph.owning_module,
+                config.bucket_reduce_scatters_fx_bucket_size_determinator,
+            )
+        )
+        collectives_bucketing = True
+
+    # Fx all_gather bucketing introduces mutation op
+    # Keeping it in the end to keep invariant of functional graph for previous passes.
+    if config.bucket_all_gathers_fx != "none":
+        from torch._inductor.fx_passes.bucketing import bucket_all_gather
+        from torch._inductor.fx_passes.fsdp import bucket_fsdp_all_gather
+
+        p = (
+            bucket_fsdp_all_gather  # type: ignore[assignment]
+            if config.bucket_all_gathers_fx == "fsdp"
+            else bucket_all_gather
+        )
+        GraphTransformObserver(gm, "bucket_all_gathers").apply_graph_pass(
+            lambda graph: p(
+                graph.owning_module,
+                config.bucket_all_gathers_fx_bucket_size_determinator,
+            )
+        )
+        collectives_bucketing = True
+
+    if collectives_bucketing:
+        # Fx collectives bucketing passes require topological sort for the cases:
+        # when bucketed collectives have users before the last collective in the bucket
+        # AND when inputs of bucketed collective have ancestors after the first collective in the bucket.
+        #
+        # In this case we can not manually pick the place for bucketed collective insertion.
+        # But we are guaranteed by the bucketing (independent collectives in the bucket),
+        # that it is possible to reorder nodes to satisfy all ordering requirements.
+        #
+        # --- before bucketing ---
+        # in0 = ...
+        # wait_ag0 = ag(in0)
+        # user0(wait_ag0)
+        # ...
+        # pre_in1 = ...
+        # in1 = transform(pre_in1)
+        # wait_ag1 = ag(in1)
+        # user1(wait_ag1)
+        #
+        # --- after bucketing ---
+        #
+        # in0 = ...
+        # user(wait_ag0) <--- wait_ag0 is defined only after bucketed collective.
+        #
+        # pre_in1 = ...
+        # in1 = transform(pre_in1)
+        # ag_bucket(in0+in1)
+        # wait_bucket
+        # wait_ag0 = wait_bucket[0]
+        # wait_ag1 = wait_bucket[1]
+        # user1(wait_ag1)
+        stable_topological_sort(gm.graph)
+
+    # Keep these last, since they introduce mutation. Look at
     # ./fx_passes/README.md for a discussion of mutation invariants.
     GraphTransformObserver(gm, "reinplace_inplaceable_ops").apply_graph_pass(
         functools.partial(reinplace_inplaceable_ops, fake_tensor_updater),
@@ -219,42 +289,6 @@ def post_grad_passes(gm: torch.fx.GraphModule, is_inference: bool):
             decompose_map_to_while_loop
         )

-    if config.bucket_reduce_scatters_fx != "none":
-        from torch._inductor.fx_passes.bucketing import (
-            bucket_reduce_scatter,
-            bucket_size_determinator,
-        )
-
-        d = (
-            config.bucket_reduce_scatters_fx_bucket_size_determinator
-            or bucket_size_determinator
-        )
-        GraphTransformObserver(gm, "bucket_reduce_scatters").apply_graph_pass(
-            lambda graph: bucket_reduce_scatter(graph.owning_module, d)
-        )
-
-    # Fx all_gather bucketing introduces mutation op
-    # Keeping it in the end to keep invariant of functional graph for previous passes.
-    if config.bucket_all_gathers_fx != "none":
-        from torch._inductor.fx_passes.bucketing import (
-            bucket_all_gather,
-            bucket_size_determinator,
-        )
-        from torch._inductor.fx_passes.fsdp import bucket_fsdp_all_gather
-
-        p = (
-            bucket_fsdp_all_gather
-            if config.bucket_all_gathers_fx == "fsdp"
-            else bucket_all_gather
-        )
-        d = (
-            config.bucket_all_gathers_fx_bucket_size_determinator
-            or bucket_size_determinator
-        )
-        GraphTransformObserver(gm, "bucket_all_gathers").apply_graph_pass(
-            lambda graph: p(graph.owning_module, d)
-        )
-
     gm.recompile()
     gm.graph.lint()
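For the built-in wiring, the passes are driven by inductor config; a sketch using the flags and determinator hooks from this diff (the cap value is arbitrary):

```
import torch

# "fsdp" selects the SimpleFSDP-specific passes; other non-"none"
# values (e.g. "all", as used in the test above) select the generic
# bucketing passes.
torch._inductor.config.bucket_all_gathers_fx = "all"
torch._inductor.config.bucket_reduce_scatters_fx = "all"

# Optional: cap bucket size (in MB) per bucket index; left unset, the
# passes fall back to bucket_cap_mb_by_bucket_idx_default.
torch._inductor.config.bucket_all_gathers_fx_bucket_size_determinator = (
    lambda bucket_idx: 512.0
)
```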


@@ -6731,7 +6731,7 @@ register_foreach_pointwise(aten._foreach_clamp_max.List, minimum)
 register_foreach_pointwise(aten._foreach_clamp_max.Scalar, minimum)
 register_foreach_pointwise(aten._foreach_reciprocal, reciprocal)
 register_foreach_pointwise(aten._foreach_sign, sign)
-register_foreach_pointwise(aten._foreach_copy, copy)
+foreach_copy = register_foreach_pointwise(aten._foreach_copy, copy)


 # these are only encountered as outputs of the graph
@@ -6770,6 +6770,9 @@ register_foreach_inplace(
 register_foreach_inplace(
     aten._foreach_div_.Scalar, aten._foreach_div.Scalar, foreach_div_scalar
 )
+register_foreach_inplace(
+    aten._foreach_copy_.default, aten._foreach_copy.default, foreach_copy
+)


 def register_inplace(aten_op, outplace_op):