Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
[bucketing] Rewrite all_gather, reduce_scatter passes via tracing merge_fn (#158663)
Rewrites the bucketing of all_gather and reduce_scatter so that the "merge graph" is defined by tracing a torch function (`all_gather_merge_fn_to_trace`, `reduce_scatter_merge_fn_to_trace`) instead of creating nodes and running FakeTensor propagation manually. This makes it easier to experiment with the merge function.

The all_gather merge function now uses `foreach_copy_`; an inductor lowering for `foreach_copy_` is added.

A topological sort is added after the bucketing passes (comment in post_grad.py):

```
# Fx collectives bucketing passes require topological sort for the cases:
# when bucketed collectives have users before the last collective in the bucket
# AND when inputs of bucketed collective have ancestors after the first collective in the bucket.
#
# In this case we can not manually pick the place for bucketed collective insertion.
# But we are guaranteed by the bucketing (independent collectives in the bucket),
# that it is possible to reorder nodes to satisfy all ordering requirements.
#
# --- before bucketing ---
# in0 = ...
# wait_ag0 = ag(in0)
# user0(wait_ag0)
# ...
# pre_in1 = ...
# in1 = transform(pre_in1)
# wait_ag1 = ag(in1)
# user1(wait_ag1)
#
# --- after bucketing ---
#
# in0 = ...
# user(wait_ag0) <--- wait_ag0 is defined only after bucketed collective.
#
# pre_in1 = ...
# in1 = transform(pre_in1)
# ag_bucket(in0+in1)
# wait_bucket
# wait_ag0 = wait_bucket[0]
# wait_ag1 = wait_bucket[1]
# user1(wait_ag1)
```

Correctness of the passes is verified by the loss curves for llama3 8b with simple_fsdp and with autoparallel:

<img width="1364" height="495" alt="Screenshot 2025-07-22 at 14 27 28" src="https://github.com/user-attachments/assets/67b2cabb-3206-450b-b529-e23c24292fc6" />
<img width="1355" height="509" alt="Screenshot 2025-07-22 at 14 27 56" src="https://github.com/user-attachments/assets/4d0e6b25-2eb1-47b2-8d68-dcec185239c4" />

Pull Request resolved: https://github.com/pytorch/pytorch/pull/158663
Approved by: https://github.com/wconstab
committed by PyTorch MergeBot
parent bc5dbbbb78
commit 8aebf01287
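For illustration, here is a minimal sketch of the approach the commit message describes: write the merge logic as an ordinary torch function and trace it into an FX graph with `make_fx`, instead of building nodes and propagating FakeTensors by hand. The function name, shapes, and merge semantics below are hypothetical stand-ins, not the PR's actual `all_gather_merge_fn_to_trace`.

```
# Hypothetical merge function: flatten the bucketed inputs and copy them into
# one contiguous buffer with _foreach_copy_ (the op this PR adds an inductor
# lowering for), then trace it with make_fx to obtain the "merge graph".
import torch
from torch.fx.experimental.proxy_tensor import make_fx


def merge_fn_to_trace(inputs: list[torch.Tensor]) -> torch.Tensor:
    flat = [t.reshape(-1) for t in inputs]
    sizes = [t.numel() for t in flat]
    out = torch.empty(sum(sizes), dtype=inputs[0].dtype, device=inputs[0].device)
    # In-place foreach copy into views of the shared buffer.
    torch._foreach_copy_(list(out.split(sizes)), flat)
    return out


example_inputs = [torch.empty(4, 4), torch.empty(8)]
merge_gm = make_fx(merge_fn_to_trace)(example_inputs)
print(merge_gm.graph)  # the traced graph records the reshape/split ops and aten._foreach_copy_
```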
@@ -2,7 +2,7 @@
import datetime
import functools
import unittest
from collections import defaultdict
from collections import Counter
from typing import Optional
from unittest.mock import patch
@@ -666,7 +666,7 @@ class TestCollectivesMultiProc(DynamoDistributedMultiProcTestCase):
class TrackingMode(TorchDispatchMode):
    def __init__(self):
        super().__init__()
        self.ops_counter = defaultdict(int)
        self.ops_counter = Counter()

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        if kwargs is None:
@@ -1539,12 +1539,15 @@ class TestCollectivesInductor(DynamoDistributedSingleProcTestCase):
ag_0_out = torch.ops._c10d_functional.all_gather_into_tensor(
    ag_0_cast, group_size, group_name
)
ag_0_out = torch.ops.c10d_functional.wait_tensor(ag_0_out)
ag_0_out = ag_0_out * 2

ag_1_cast = ag_1_cast * 2
ag_1_out = torch.ops._c10d_functional.all_gather_into_tensor(
    ag_1_cast, group_size, group_name
)

# wait op
ag_0_out = torch.ops.c10d_functional.wait_tensor(ag_0_out)
ag_1_out = torch.ops.c10d_functional.wait_tensor(ag_1_out)

return y, ag_0_out, ag_1_out
@@ -1557,7 +1560,7 @@ class TestCollectivesInductor(DynamoDistributedSingleProcTestCase):

with torch._inductor.config.patch(
    {
        "bucket_all_gathers_fx": "fsdp",
        "bucket_all_gathers_fx": "all",
        "reorder_for_compute_comm_overlap": False,
    }
):
@@ -1601,7 +1604,7 @@ class TestCollectivesInductor(DynamoDistributedSingleProcTestCase):
x = torch.ones(4, 384, device="cuda", dtype=torch.float32)
w = torch.ones(384, 512, device="cuda", dtype=torch.float32)
rs_0 = torch.ones(384, 512, device="cuda", dtype=torch.float32)
rs_1 = torch.ones(384, 512, device="cuda", dtype=torch.float32)
rs_1 = torch.ones(384, 256, device="cuda", dtype=torch.float32)
inputs = [x, w, rs_0, rs_1]
func(*inputs, **self.get_world_trs())
@@ -1786,10 +1789,8 @@ class TestCollectivesInductor(DynamoDistributedSingleProcTestCase):
        self.assertEqual(len(node_stats), 4)
        it = iter(node_stats.values())
        node_stat0 = next(it)
        self.assertTrue(node_stat0.moves > 0)
        self.assertTrue(node_stat0.limiting_factor == "None")
        node_stat1 = next(it)
        self.assertTrue(node_stat1.moves > 0)
        self.assertTrue("collective ordering" in node_stat1.limiting_factor)

    @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
File diff suppressed because it is too large
@@ -1,10 +1,12 @@
import logging
from typing import Callable
from typing import Callable, Optional

import torch
from torch._inductor.fx_passes.bucketing import (
    bucket_all_gather_by_mb,
    bucket_reduce_scatter_by_mb,
    merge_all_gather,
    merge_reduce_scatter,
)
@@ -31,38 +33,68 @@ def is_fsdp_all_gather_wait(wait: torch.fx.Node) -> bool:
    )


def is_graph_output(node: torch.fx.Node) -> bool:
    return all(user.op == "output" for user in node.users)


def is_fsdp_reduce_scatter_wait(wait: torch.fx.Node) -> bool:
    return is_graph_output(wait)


def bucket_fsdp_all_gather(
    gm: torch.fx.GraphModule, all_gather_bucket_cap_mb_callback: Callable[[int], float]
    gm: torch.fx.GraphModule,
    bucket_cap_mb_by_bucket_idx: Optional[Callable[[int], float]] = None,
) -> None:
    """
    Bucketing pass for SimpleFSDP all_gather ops.

    Attributes:
        gm (torch.fx.GraphModule): Graph module of the graph.
        all_gather_bucket_cap_mb_callback (Callable[[int], float]): callback function that
        bucket_cap_mb_by_bucket_idx (Optional[Callable[[int], float]]): callback function that
            takes in bucket id and returns size of a bucket in megabytes.

    Usage:
    ```
    from torch._inductor.fx_passes.bucketing import (
        bucket_all_gather,
        bucket_size_determinator,
    )


    def _bucket_all_gather(graph):
        return bucket_all_gather(graph.owning_module, bucket_size_determinator)


    torch._inductor.config.post_grad_custom_post_pass = _bucket_all_gather
    ```
    """
    if bucket_cap_mb_by_bucket_idx is None:
        from torch._inductor.fx_passes.bucketing import (
            bucket_cap_mb_by_bucket_idx_default,
        )

        bucket_cap_mb_by_bucket_idx = bucket_cap_mb_by_bucket_idx_default
    assert bucket_cap_mb_by_bucket_idx is not None
    ag_buckets = bucket_all_gather_by_mb(
        gm,
        all_gather_bucket_cap_mb_callback,
        bucket_cap_mb_by_bucket_idx,
        filter_wait_node=is_fsdp_all_gather_wait,
    )
    if len(ag_buckets) == 0:
        return
    merge_all_gather(gm, ag_buckets)


def bucket_fsdp_reduce_scatter(
    gm: torch.fx.GraphModule,
    bucket_cap_mb_by_bucket_idx: Optional[Callable[[int], float]] = None,
) -> None:
    """
    Bucketing pass for SimpleFSDP reduce_scatter ops.

    Attributes:
        gm (torch.fx.GraphModule): Graph module of the graph.
        bucket_cap_mb_by_bucket_idx (Optional[Callable[[int], float]]): callback function that
            takes in bucket idx and returns size of a bucket in megabytes. By default
            torch._inductor.fx_passes.bucketing.bucket_cap_mb_by_bucket_idx_default is used.

    """
    if bucket_cap_mb_by_bucket_idx is None:
        from torch._inductor.fx_passes.bucketing import (
            bucket_cap_mb_by_bucket_idx_default,
        )

        bucket_cap_mb_by_bucket_idx = bucket_cap_mb_by_bucket_idx_default
    rs_buckets = bucket_reduce_scatter_by_mb(
        gm,
        bucket_cap_mb_by_bucket_idx,
        filter_wait_node=is_fsdp_reduce_scatter_wait,
    )
    if len(rs_buckets) == 0:
        return
    merge_reduce_scatter(gm, rs_buckets)
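For context, the FSDP passes above are not usually called directly; they are selected through inductor config in post_grad (see the next hunk). Below is a hedged sketch of wiring them up, using only the config keys visible in this diff; the size schedule and chosen values are illustrative assumptions.

```
import torch._inductor.config as inductor_config


def bucket_cap_mb(bucket_idx: int) -> float:
    # Hypothetical schedule: a small first bucket, larger buckets afterwards.
    return 50.0 if bucket_idx == 0 else 200.0


# "fsdp" routes to bucket_fsdp_all_gather / bucket_fsdp_reduce_scatter defined
# above; per the test change earlier in this commit, "all" selects the generic
# bucketing passes, and "none" disables bucketing.
inductor_config.bucket_all_gathers_fx = "fsdp"
inductor_config.bucket_reduce_scatters_fx = "fsdp"
inductor_config.bucket_all_gathers_fx_bucket_size_determinator = bucket_cap_mb
inductor_config.bucket_reduce_scatters_fx_bucket_size_determinator = bucket_cap_mb
```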
@@ -197,7 +197,77 @@ def post_grad_passes(gm: torch.fx.GraphModule, is_inference: bool):
        pass_name = "custom_backend_passes_" + device
        GraphTransformObserver(gm, pass_name).apply_gm_pass(custom_backend_pass)

    # Keep these last, since they introduces mutation. Look at
    collectives_bucketing: bool = False
    if config.bucket_reduce_scatters_fx != "none":
        from torch._inductor.fx_passes.bucketing import bucket_reduce_scatter
        from torch._inductor.fx_passes.fsdp import bucket_fsdp_reduce_scatter

        p = (
            bucket_fsdp_reduce_scatter
            if config.bucket_reduce_scatters_fx == "fsdp"
            else bucket_reduce_scatter
        )
        GraphTransformObserver(gm, "bucket_reduce_scatters").apply_graph_pass(
            lambda graph: p(
                graph.owning_module,
                config.bucket_reduce_scatters_fx_bucket_size_determinator,
            )
        )
        collectives_bucketing = True

    # Fx all_gather bucketing introduces mutation op
    # Keeping it in the end to keep invariant of functional graph for previous passes.
    if config.bucket_all_gathers_fx != "none":
        from torch._inductor.fx_passes.bucketing import bucket_all_gather
        from torch._inductor.fx_passes.fsdp import bucket_fsdp_all_gather

        p = (
            bucket_fsdp_all_gather  # type: ignore[assignment]
            if config.bucket_all_gathers_fx == "fsdp"
            else bucket_all_gather
        )
        GraphTransformObserver(gm, "bucket_all_gathers").apply_graph_pass(
            lambda graph: p(
                graph.owning_module,
                config.bucket_all_gathers_fx_bucket_size_determinator,
            )
        )
        collectives_bucketing = True

    if collectives_bucketing:
        # Fx collectives bucketing passes require topological sort for the cases:
        # when bucketed collectives have users before the last collective in the bucket
        # AND when inputs of bucketed collective have ancestors after the first collective in the bucket.
        #
        # In this case we can not manually pick the place for bucketed collective insertion.
        # But we are guaranteed by the bucketing (independent collectives in the bucket),
        # that it is possible to reorder nodes to satisfy all ordering requirements.
        #
        # --- before bucketing ---
        # in0 = ...
        # wait_ag0 = ag(in0)
        # user0(wait_ag0)
        # ...
        # pre_in1 = ...
        # in1 = transform(pre_in1)
        # wait_ag1 = ag(in1)
        # user1(wait_ag1)
        #
        # --- after bucketing ---
        #
        # in0 = ...
        # user(wait_ag0) <--- wait_ag0 is defined only after bucketed collective.
        #
        # pre_in1 = ...
        # in1 = transform(pre_in1)
        # ag_bucket(in0+in1)
        # wait_bucket
        # wait_ag0 = wait_bucket[0]
        # wait_ag1 = wait_bucket[1]
        # user1(wait_ag1)
        stable_topological_sort(gm.graph)

    # Keep these last, since they introduce mutation. Look at
    # ./fx_passes/README.md for a discussion of mutation invariants.
    GraphTransformObserver(gm, "reinplace_inplaceable_ops").apply_graph_pass(
        functools.partial(reinplace_inplaceable_ops, fake_tensor_updater),
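To make the reordering argument in the comment above concrete, here is a hedged sketch of a stable topological sort over a `torch.fx.Graph`: nodes keep their original relative order wherever dependencies allow, and a node is only moved when one of its inputs has not been placed yet. It illustrates the idea behind the `stable_topological_sort(gm.graph)` call and is not claimed to be Inductor's exact implementation.

```
from collections import defaultdict

import torch.fx as fx


def stable_toposort(graph: fx.Graph) -> None:
    # Visit nodes in their current order; a node is placed once all of its
    # inputs are placed, otherwise it waits on one not-yet-placed input.
    pending = list(reversed(list(graph.nodes)))
    placed: set = set()
    waiting: defaultdict = defaultdict(list)
    cursor = None  # last node appended to the sorted prefix
    while pending:
        node = pending.pop()
        unmet = [inp for inp in node.all_input_nodes if inp not in placed]
        if unmet:
            waiting[unmet[-1]].append(node)
            continue
        placed.add(node)
        if cursor is not None and cursor.next is not node:
            cursor.append(node)  # physically move `node` right after `cursor`
        cursor = node
        # Retry nodes that were blocked on `node`, preserving their order.
        pending.extend(reversed(waiting.pop(node, [])))
    assert not waiting, "graph has a cycle or an unplaced dependency"
```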
@@ -219,42 +289,6 @@ def post_grad_passes(gm: torch.fx.GraphModule, is_inference: bool):
            decompose_map_to_while_loop
        )

    if config.bucket_reduce_scatters_fx != "none":
        from torch._inductor.fx_passes.bucketing import (
            bucket_reduce_scatter,
            bucket_size_determinator,
        )

        d = (
            config.bucket_reduce_scatters_fx_bucket_size_determinator
            or bucket_size_determinator
        )
        GraphTransformObserver(gm, "bucket_reduce_scatters").apply_graph_pass(
            lambda graph: bucket_reduce_scatter(graph.owning_module, d)
        )

    # Fx all_gather bucketing introduces mutation op
    # Keeping it in the end to keep invariant of functional graph for previous passes.
    if config.bucket_all_gathers_fx != "none":
        from torch._inductor.fx_passes.bucketing import (
            bucket_all_gather,
            bucket_size_determinator,
        )
        from torch._inductor.fx_passes.fsdp import bucket_fsdp_all_gather

        p = (
            bucket_fsdp_all_gather
            if config.bucket_all_gathers_fx == "fsdp"
            else bucket_all_gather
        )
        d = (
            config.bucket_all_gathers_fx_bucket_size_determinator
            or bucket_size_determinator
        )
        GraphTransformObserver(gm, "bucket_all_gathers").apply_graph_pass(
            lambda graph: p(graph.owning_module, d)
        )

    gm.recompile()
    gm.graph.lint()
@@ -6731,7 +6731,7 @@ register_foreach_pointwise(aten._foreach_clamp_max.List, minimum)
register_foreach_pointwise(aten._foreach_clamp_max.Scalar, minimum)
register_foreach_pointwise(aten._foreach_reciprocal, reciprocal)
register_foreach_pointwise(aten._foreach_sign, sign)
register_foreach_pointwise(aten._foreach_copy, copy)
foreach_copy = register_foreach_pointwise(aten._foreach_copy, copy)


# these are only encountered as outputs of the graph
@@ -6770,6 +6770,9 @@ register_foreach_inplace(
register_foreach_inplace(
    aten._foreach_div_.Scalar, aten._foreach_div.Scalar, foreach_div_scalar
)
register_foreach_inplace(
    aten._foreach_copy_.default, aten._foreach_copy.default, foreach_copy
)


def register_inplace(aten_op, outplace_op):
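As a quick sanity check of the new lowering, a hedged sketch of compiling a function that uses `_foreach_copy_` with the inductor backend; the shapes and the expected output comment are illustrative assumptions.

```
import torch


def copy_into(dst: list[torch.Tensor], src: list[torch.Tensor]) -> list[torch.Tensor]:
    # aten._foreach_copy_ is the op the registration above lowers in inductor.
    torch._foreach_copy_(dst, src)
    return dst


dst = [torch.zeros(8), torch.zeros(16)]
src = [torch.ones(8), torch.ones(16)]
compiled = torch.compile(copy_into, backend="inductor")
out = compiled(dst, src)
print([t.sum().item() for t in out])  # expected: [8.0, 16.0]
```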