all_gather_bucketing fx pass (#157396)

Porting passes to bucket all_gathers

The main logic of the pass:
1. Searching for all all_gathers to form the buckets

Tests are copied from @wconstab's PR to check compatibility with reordering.
The tests check only compatibility because, due to (3), the bucketed all_gather is already scheduled as early as possible, leaving no room for reordering.

Pass changes:
Mutation ops are used to match the performance of FSDP. In the future, the ideal scenario is a purely functional graph, with Inductor doing all memory optimizations on its own, without mutable ops.
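
For context, a minimal sketch of enabling the pass through the Inductor config (the option name is the one added by this PR; the wrapper function is illustrative):

```python
import torch

def run_bucketed(fn, *inputs):
    # "all" buckets every all_gather in the graph; "fsdp" restricts bucketing
    # to FSDP-style all_gathers (inputs are graph inputs or dtype casts of them).
    with torch._inductor.config.patch({"bucket_all_gathers_fx": "all"}):
        compiled = torch.compile(fn)
        return compiled(*inputs)
```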

Inductor changes:
Adding foreach_copy_ lowering
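
For reference, the lowered op copies, casting if needed, each source tensor into the matching destination tensor in place; a minimal eager-mode sketch (not taken from this PR):

```python
import torch

# Destinations in bf16, sources in fp32: copy_ semantics cast per element.
dsts = [torch.empty(4, dtype=torch.bfloat16), torch.empty(2, 3, dtype=torch.bfloat16)]
srcs = [torch.randn(4), torch.randn(2, 3)]

torch._foreach_copy_(dsts, srcs)  # each dsts[i] now holds srcs[i] cast to bf16
```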

Pull Request resolved: https://github.com/pytorch/pytorch/pull/157396
Approved by: https://github.com/wconstab
Author: IvanKobzarev
Date: 2025-07-03 06:48:05 -07:00
Committed by: PyTorch MergeBot
Parent: 19ae5afdaa
Commit: 7b392bac13
5 changed files with 700 additions and 0 deletions

test/distributed/test_inductor_collectives.py

@@ -25,6 +25,7 @@ from torch._inductor.scheduler import BaseSchedulerNode
from torch._inductor.utils import run_and_get_triton_code
from torch.distributed.distributed_c10d import GroupMember
from torch.fx.experimental.proxy_tensor import make_fx
from torch.testing._internal.common_cuda import SM80OrLater
from torch.testing._internal.common_distributed import (
_dynamo_dist_per_rank_init,
DynamoDistributedMultiProcTestCase,
@@ -1503,6 +1504,179 @@ class TestCollectivesInductor(DynamoDistributedSingleProcTestCase):
self.assertEqual(stats.limiting_factor, "data dependency")
self.assertEqual(stats.moves, 0)
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@unittest.skipIf(not SM80OrLater, "bfloat16")
def test_all_gather_bucket(self):
def func(x, w, ag_0, ag_1, *, tag, ranks, group_size):
# do some unrelated matmuls
y = torch.mm(x, w)
# cast the inputs
ag_0_cast = ag_0.to(torch.bfloat16)
ag_1_cast = ag_1.to(torch.bfloat16)
# allgather
group_name = (
torch.distributed.distributed_c10d._get_default_group().group_name
)
ag_0_out = torch.ops._c10d_functional.all_gather_into_tensor(
ag_0_cast, group_size, group_name
)
ag_1_out = torch.ops._c10d_functional.all_gather_into_tensor(
ag_1_cast, group_size, group_name
)
# wait op
ag_0_out = torch.ops.c10d_functional.wait_tensor(ag_0_out)
ag_1_out = torch.ops.c10d_functional.wait_tensor(ag_1_out)
return y, ag_0_out, ag_1_out
x = torch.ones(4, 384, device="cuda", dtype=torch.float32)
w = torch.ones(384, 512, device="cuda", dtype=torch.float32)
ag_0 = torch.ones(384, 512, device="cuda", dtype=torch.float32)
ag_1 = torch.ones(384, 512, device="cuda", dtype=torch.float32)
inputs = [x, w, ag_0, ag_1]
with torch._inductor.config.patch(
{
"bucket_all_gathers_fx": "fsdp",
"reorder_for_compute_comm_overlap": False,
}
):
compiled = torch.compile(func)
code = run_and_get_triton_code(compiled, *inputs, **self.get_world_trs())
# NOTE: The first return value should be the output of the first wait_tensor.
# We want to make sure no unnecessary copy is made.
(FileCheck().check("all_gather_into_tensor_out").run(code))
out = compiled(*inputs, **self.get_world_trs())
correct = func(*inputs, **self.get_world_trs())
assert same(out, correct), f"{out} vs {correct}"
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@unittest.skipIf(not SM80OrLater, "bfloat16")
def test_reorder_peak_memory_bucketed(self):
"""
Simulate the case where a bucketing pass ran and grouped several inputs into one bucketed allgather.
Ensure the whole bucketed group including copy-ops get moved together rather than the copy ops preventing the
comm from moving due to data dependency.
"""
def func(x, w, ag_0, ag_1, *, tag, ranks, group_size):
# do some unrelated matmuls
y = torch.mm(x, w)
# cast the inputs
ag_0_cast = ag_0.to(torch.bfloat16)
ag_1_cast = ag_1.to(torch.bfloat16)
# allgather
group_name = (
torch.distributed.distributed_c10d._get_default_group().group_name
)
ag_0_out = torch.ops._c10d_functional.all_gather_into_tensor(
ag_0_cast, group_size, group_name
)
ag_1_out = torch.ops._c10d_functional.all_gather_into_tensor(
ag_1_cast, group_size, group_name
)
# wait op
ag_0_out = torch.ops.c10d_functional.wait_tensor(ag_0_out)
ag_1_out = torch.ops.c10d_functional.wait_tensor(ag_1_out)
return y, ag_0_out, ag_1_out
x = torch.ones(4, 384, device="cuda", dtype=torch.float32)
w = torch.ones(384, 512, device="cuda", dtype=torch.float32)
ag_0 = torch.ones(384, 512, device="cuda", dtype=torch.float32)
ag_1 = torch.ones(512, device="cuda", dtype=torch.float32)
inputs = [x, w, ag_0, ag_1]
# get stats directly from the internal helper without affecting the real pass's signature
node_stats: Optional[dict[BaseSchedulerNode, ReorderInfo]] = None
def _reorder_communication_preserving_peak_memory(
snodes: list[BaseSchedulerNode],
) -> list[BaseSchedulerNode]:
nonlocal node_stats
(
reordered_snodes,
node_stats,
) = _reorder_communication_preserving_peak_memory_internal(snodes)
return reordered_snodes
with torch._inductor.config.patch(
{
"bucket_all_gathers_fx": "all",
"reorder_for_compute_comm_overlap": True,
"reorder_for_compute_comm_overlap_passes": [
"sink_waits",
# same as reorder_communication_preserving_peak_memory but returns debug info structures directly
_reorder_communication_preserving_peak_memory,
],
}
):
compiled = torch.compile(func)
code = run_and_get_triton_code(compiled, *inputs, **self.get_world_trs())
# NOTE: The first return value should be the output of the first wait_tensor.
# We want to make sure no unnecessary copy is made.
(FileCheck().check("all_gather_into_tensor_out").run(code))
out = compiled(*inputs, **self.get_world_trs())
correct = func(*inputs, **self.get_world_trs())
assert same(out, correct), f"{out} vs {correct}"
assert node_stats is not None
self.assertTrue(isinstance(node_stats, dict))
self.assertEqual(len(node_stats), 1)
# TODO: Debug why reordering does not move collective after bucketing
# for stats in node_stats.values():
# self.assertEqual(stats.initial_exposed, 0)
def _reorder_communication_preserving_peak_memory(
snodes: list[BaseSchedulerNode],
) -> list[BaseSchedulerNode]:
nonlocal node_stats
(
reordered_snodes,
node_stats,
) = _reorder_communication_preserving_peak_memory_internal(snodes)
return reordered_snodes
with torch._inductor.config.patch(
{
"reorder_for_compute_comm_overlap": True,
"reorder_for_compute_comm_overlap_passes": [
"sink_waits",
# same as reorder_communication_preserving_peak_memory but returns debug info structures directly
_reorder_communication_preserving_peak_memory,
],
}
):
compiled = torch.compile(func)
code = run_and_get_triton_code(compiled, *inputs, **self.get_world_trs())
# NOTE: The first return value should be the output of the first wait_tensor.
# We want to make sure no unnecessary copy is made.
(
FileCheck()
.check("all_gather")
.check("wait")
.check("all_gather")
.check("wait")
.run(code)
)
out = compiled(*inputs, **self.get_world_trs())
correct = func(*inputs, **self.get_world_trs())
assert same(out, correct), f"{out} vs {correct}"
# TODO make the test case more interesting and validate the actual desired behavior
assert node_stats is not None
self.assertTrue(isinstance(node_stats, dict))
self.assertEqual(len(node_stats), 2)
# for stats in node_stats.values():
# self.assertEqual(stats.moves, 0)
# self.assertEqual(stats.limiting_factor, "data dependency")
# self.assertEqual(stats.moves, 3)
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
def test_reorder_respects_wait_dep(self):
"""

torch/_inductor/config.py

@@ -384,6 +384,10 @@ reorder_prefetch_limit: Optional[int] = None
# enable operator reordering for peak memory optimization
reorder_for_peak_memory = True
bucket_all_gathers_fx: Literal["none", "all", "fsdp"] = "none"
# By default torch._inductor.fx_passes.bucketing.bucket_size_determinator is used
bucket_all_gathers_fx_bucket_size_determinator: Optional[Callable[[int], int]] = None
# runtime estimation function for ops
# for built-in estimation function, pass in "default"; for user-defined estimation function, pass in the function handle
estimate_op_runtime = "default"
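
A sketch of plugging in a custom determinator (the function name and sizes are made up; the callback maps a bucket id to a bucket size in MiB, as consumed by the bucketing pass below):

```python
import torch

def my_bucket_size_mb(bucket_id: int) -> float:
    # Hypothetical policy: keep the first bucket small so its all_gather
    # can start sooner, then switch to larger buckets.
    return 100.0 if bucket_id == 0 else 1000.0

torch._inductor.config.bucket_all_gathers_fx = "fsdp"
torch._inductor.config.bucket_all_gathers_fx_bucket_size_determinator = my_bucket_size_mb
```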

torch/_inductor/fx_passes/bucketing.py (new file)

@@ -0,0 +1,432 @@
import logging
import operator
from typing import Any, Callable, Optional, Union
import torch
from torch._dispatch.python import enable_python_dispatcher
from torch._inductor.virtualized import V
from torch.utils._ordered_set import OrderedSet
logger: logging.Logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
def bucket_size_determinator(bucket_id: int) -> float:
"""
Determine the size of a bucket based on its ID.
Args:
bucket_id (int): The ID of the bucket.
Returns:
float: The size of the bucket.
"""
return 2000.0
def bucket_all_gather(
gm: torch.fx.GraphModule, all_gather_bucket_cap_mb_callback: Callable[[int], float]
) -> None:
ag_buckets = bucket_all_gather_by_mb(gm, all_gather_bucket_cap_mb_callback)
if len(ag_buckets) == 0:
return
merge_all_gather(gm, ag_buckets)
def is_all_gather_into_tensor(node: torch.fx.Node) -> bool: # type: ignore[arg-type]
return (
node.op == "call_function"
and node.target == torch.ops._c10d_functional.all_gather_into_tensor.default
)
def is_wait_tensor(node: torch.fx.Node) -> bool:
return (
node.op == "call_function"
and node.target == torch.ops._c10d_functional.wait_tensor.default
)
def is_wait_tensor_from_all_gather_into_tensor(node: torch.fx.Node) -> bool:
return is_wait_tensor(node) and is_all_gather_into_tensor(node.args[0]) # type: ignore[arg-type]
def bucket_all_gather_by_mb(
gm: torch.fx.GraphModule,
all_gather_bucket_cap_mb_callback: Callable[[int], float],
filter_wait_node: Optional[Callable[[torch.fx.Node], bool]] = None,
) -> list[list[torch.fx.Node]]:
"""
Identifies all all_gather nodes and groups them into buckets based on size limit `all_gather_bucket_cap_mb_callback`.
Returns a list of buckets, where each bucket is a list of all_gather nodes.
"""
node_list = gm.graph.nodes
# Prerequisite: Check if there is any all_gather node
found_all_gather = False
for node in node_list:
if is_all_gather_into_tensor(node):
found_all_gather = True
break
if not found_all_gather:
return []
ag_nodes: list[torch.fx.Node] = []
# Step 1: Find all all_gather nodes
for node in node_list:
if is_wait_tensor(node) and is_all_gather_into_tensor(node.args[0]):
if (filter_wait_node is None) or filter_wait_node(node):
ag_node = node.args[0]
ag_nodes.append(ag_node)
# Step 2: Put all_gather nodes into buckets
ag_buckets: list[list[torch.fx.Node]] = []
cur_bucket: list[torch.fx.Node] = []
cur_bucket_size_bytes: int = 0
cur_bucket_id: int = 0
# Convert MiB to bytes
all_gather_bucket_size_bytes = int(
all_gather_bucket_cap_mb_callback(cur_bucket_id) * 1024 * 1024
)
for ag_node in ag_nodes:
assert is_all_gather_into_tensor(ag_node)
assert "val" in ag_node.meta
ag_output_size_bytes = (
ag_node.meta["val"].numel()
* torch.finfo(ag_node.meta["val"].dtype).bits
// 8
)
if (
cur_bucket_size_bytes + ag_output_size_bytes > all_gather_bucket_size_bytes
and cur_bucket
):
# Current bucket is full, create new bucket
ag_buckets.append(cur_bucket)
cur_bucket = []
cur_bucket_size_bytes = 0
cur_bucket_id += 1
cur_bucket_size_bytes += ag_output_size_bytes
cur_bucket.append(ag_node)
if cur_bucket:
# add remaining nodes in the last bucket
ag_buckets.append(cur_bucket)
return ag_buckets
def node_copy( # type: ignore[no-untyped-def]
env,
new_graph,
node: torch.fx.Node,
arg_transform: Callable[[torch.fx.Node], torch.fx.node.Argument],
) -> torch.fx.Node:
if node not in env:
new_node = new_graph.node_copy(node, arg_transform=arg_transform)
env[node] = new_node
else:
new_node = env[node]
return new_node
def new_graph_call_function( # type: ignore[no-untyped-def]
new_graph,
target: Callable[..., Any],
args: Optional[tuple[torch.fx.node.Argument, ...]] = None,
kwargs: Optional[dict[str, torch.fx.node.Argument]] = None,
type_expr: Optional[Any] = None,
) -> torch.fx.Node:
from torch.utils._pytree import tree_map_only
new_node = new_graph.call_function(target, args, kwargs)
args_val = tree_map_only(torch.fx.Node, lambda x: x.meta["val"], args)
kwargs_val = tree_map_only(torch.fx.Node, lambda x: x.meta["val"], kwargs)
with V.fake_mode, enable_python_dispatcher():
new_fake_tensor = target(*args_val, **kwargs_val)
new_node.meta["val"] = new_fake_tensor
return new_node
def env_lookup( # type: ignore[no-untyped-def]
env, x: torch.fx.Node, node_user: Union[torch.fx.Node, str]
) -> torch.fx.Node:
assert x in env, (
f"Dependent node {x} not in env when creating downstream node {node_user}"
)
return env[x]
def merge_all_gather(
gm: torch.fx.GraphModule, ag_buckets: list[list[torch.fx.Node]]
) -> None:
"""
Transforms the graph to use bucketed all_gather operations based on `ag_buckets`.
"""
assert len(ag_buckets) > 0
ag_nodes: list[torch.fx.Node] = []
cast_nodes: list[torch.fx.Node] = []
ag_node_to_wait_node: dict[torch.fx.Node, torch.fx.Node] = {}
ag_node_to_bucket_id = {}
cast_node_to_bucket_id = {}
# Map nodes to buckets and identify wait nodes
for bucket_id, bucket in enumerate(ag_buckets):
for ag_node in bucket:
assert len(ag_node.users) == 1, (
f"Expect only one user for {ag_node}, but got {ag_node.users}"
)
wait_node = next(iter(ag_node.users))
ag_node_to_wait_node[ag_node] = wait_node
ag_nodes.append(ag_node)
ag_node_to_bucket_id[ag_node] = bucket_id
if (
ag_node.args[0].op == "call_function" # type: ignore[union-attr]
and ag_node.args[0].target # type: ignore[union-attr]
== torch.ops.prims.convert_element_type.default
):
cast_nodes.append(ag_node.args[0]) # type: ignore[arg-type]
cast_node_to_bucket_id[ag_node.args[0]] = bucket_id # type: ignore[arg-type]
# Step 3: Create new (bucketed) all_gather nodes
bucket_id_to_bucketed_op_info = {}
bucket_id_is_scheduled = {}
cast_bucket_id_is_scheduled = {}
_, group_size, group_name = next(iter(ag_node_to_wait_node.keys())).args
for bucket_id, ag_bucket in enumerate(ag_buckets):
ag_input_nodes = []
wait_nodes = []
for ag_node in ag_bucket:
assert (
ag_node in ag_node_to_wait_node
and ag_node.args[1] == group_size
and ag_node.args[2] == group_name
)
ag_input_nodes.append(ag_node.args[0])
wait_nodes.append(ag_node_to_wait_node[ag_node])
bucket_id_to_bucketed_op_info[bucket_id] = (
ag_input_nodes,
group_size,
group_name,
wait_nodes,
)
ag_wait_nodes = list(ag_node_to_wait_node.values())
ag_and_wait_nodes = OrderedSet(ag_nodes + ag_wait_nodes)
cast_nodes = OrderedSet(cast_nodes)
new_graph: torch.fx.Graph = torch.fx.Graph()
env: dict[torch.fx.Node, torch.fx.Node] = {}
node_list = gm.graph.nodes
for node in node_list:
if node not in ag_and_wait_nodes and node not in cast_nodes:
# not cast-before-all_gather, all_gather or its wait_tensor - schedule it normally
node_copy(env, new_graph, node, lambda x: env_lookup(env, x, node))
elif node in cast_nodes:
# batch cast nodes together into one foreach_copy node
assert node in cast_node_to_bucket_id
bucket_id = cast_node_to_bucket_id[node]
if bucket_id not in cast_bucket_id_is_scheduled:
ag_input_nodes, group_size, group_name, orig_wait_nodes = (
bucket_id_to_bucketed_op_info[bucket_id]
)
# device = ag_input_nodes[0].meta["val"].device
# rank = device.index
# dtype = ag_input_nodes[0].meta["val"].dtype
if all(
n.op == "call_function" # type: ignore[union-attr]
and n.target == torch.ops.prims.convert_element_type.default # type: ignore[union-attr]
for n in ag_input_nodes
):
param_all_gather_inputs = [
new_graph_call_function(
new_graph,
torch.ops.aten.empty.memory_format,
(n.meta["val"].shape,), # type: ignore[union-attr]
{
"dtype": n.args[1], # type: ignore[union-attr]
"device": n.meta["val"].device, # type: ignore[union-attr]
"pin_memory": False,
},
)
for n in ag_input_nodes
]
for pp, n in zip(param_all_gather_inputs, ag_input_nodes):
pp.meta = n.meta.copy() # type: ignore[union-attr]
cast_input_nodes = [env[n.args[0]] for n in ag_input_nodes] # type: ignore[union-attr, index]
foreach_copy = new_graph_call_function(
new_graph,
torch.ops.aten._foreach_copy.default,
(param_all_gather_inputs, cast_input_nodes),
{},
)
foreach_copy.meta["val"] = [n.meta["val"] for n in ag_input_nodes] # type: ignore[union-attr]
getitems = [
new_graph_call_function(
new_graph,
operator.getitem,
(foreach_copy, i),
{},
)
for i in range(len(ag_input_nodes))
]
for new_n, old_n in zip(getitems, ag_input_nodes):
env[old_n] = new_n # type: ignore[index] # noqa: PERF403
else:
param_all_gather_inputs_orig = [
node_copy(
env,
new_graph,
ag_input_node, # type: ignore[arg-type]
lambda x: env_lookup(env, x, ag_input_node), # type: ignore[arg-type]
)
for ag_input_node in ag_input_nodes
]
cast_bucket_id_is_scheduled[bucket_id] = True
else:
continue
elif node in ag_node_to_wait_node:
assert node in ag_node_to_bucket_id
bucket_id = ag_node_to_bucket_id[node]
if bucket_id not in bucket_id_is_scheduled:
ag_input_nodes, group_size, group_name, orig_wait_nodes = (
bucket_id_to_bucketed_op_info[bucket_id]
)
device = ag_input_nodes[0].meta["val"].device # type: ignore[union-attr]
rank = device.index
dtype = ag_input_nodes[0].meta["val"].dtype # type: ignore[union-attr]
# TODO: if we want to support mixed dtype in the same bucket,
# we need to first view all all_gather inputs as uint8 (common denominator),
# then do the all_gather, then view the output back to the original dtype.
# Look at FSDP2 to see how to do this.
assert all(n.meta["val"].dtype == dtype for n in ag_input_nodes), ( # type: ignore[union-attr]
"All all_gather inputs in the same bucket must have the same dtype"
)
# must schedule all the all_gather input nodes first, before the bucketed all_gather node
param_all_gather_inputs_orig = [
node_copy(
env,
new_graph,
ag_input_node, # type: ignore[arg-type]
lambda x: env_lookup(env, x, ag_input_node), # type: ignore[arg-type]
)
for ag_input_node in ag_input_nodes
]
# schedule the bucketed all_gather node
param_all_gather_inputs_flattened = [
new_graph_call_function(
new_graph, torch.ops.aten.reshape.default, (n, [-1]), {}
)
for n in param_all_gather_inputs_orig
]
inp_split_sizes = [
n.meta["val"].numel() for n in param_all_gather_inputs_orig
]
param_all_gather_outputs = [
new_graph_call_function(
new_graph,
torch.ops.aten.empty.memory_format,
([n.meta["val"].numel() * group_size],),
{
"dtype": n.meta["val"].dtype,
"device": n.meta["val"].device,
"pin_memory": False,
},
)
for n in param_all_gather_inputs_orig
]
# TODO: This assumes dim-0 sharding.
# If we need to support sharding on another dim, we should look at how FSDP2 does it
# (e.g. search for `shard_dim` in FSDP2 codebase)
param_all_gather_outputs_shape_orig = [
(n.meta["val"].shape[0] * group_size,) + n.meta["val"].shape[1:]
for n in param_all_gather_inputs_orig
]
all_gather_input_numel = sum(inp_split_sizes)
all_gather_output = new_graph_call_function(
new_graph,
torch.ops.aten.empty.memory_format,
([all_gather_input_numel * group_size],),
{
"dtype": dtype,
"device": device,
"pin_memory": False,
},
)
all_gather_copy_in = new_graph_call_function(
new_graph,
torch.ops.fsdp.all_gather_copy_in.default,
(
param_all_gather_inputs_flattened,
all_gather_output,
inp_split_sizes,
all_gather_input_numel,
rank,
),
{},
)
all_gather_input = new_graph_call_function(
new_graph,
operator.getitem,
(all_gather_copy_in, 0),
{},
)
all_gather_into_tensor_out = new_graph_call_function(
new_graph,
torch.ops._c10d_functional.all_gather_into_tensor_out.default,
(all_gather_input, group_size, group_name),
{"out": all_gather_output},
)
wait_tensor = new_graph_call_function(
new_graph,
torch.ops._c10d_functional.wait_tensor.default,
(all_gather_into_tensor_out,),
{},
)
all_gather_output_reshaped = new_graph_call_function(
new_graph,
torch.ops.aten.reshape.default,
(wait_tensor, [group_size, -1]),
{},
)
outs_flattened = [
new_graph_call_function(
new_graph,
torch.ops.aten.reshape.default,
(n, [group_size, -1]),
{},
)
for n in param_all_gather_outputs
]
split_with_sizes_copy = new_graph_call_function( # noqa: F841
new_graph,
torch.ops.fsdp.split_with_sizes_copy.default,
(all_gather_output_reshaped, inp_split_sizes),
{"dim": 1, "out": outs_flattened},
)
outs = [
new_graph_call_function(
new_graph,
torch.ops.aten.reshape.default,
(n, orig_shape),
{},
)
for n, orig_shape in zip(
outs_flattened, param_all_gather_outputs_shape_orig
)
]
assert len(orig_wait_nodes) == len(outs)
assert len(orig_wait_nodes) > 0
for out, orig_wait_node in zip(outs, orig_wait_nodes):
env[orig_wait_node] = out # noqa: PERF403
bucket_id_is_scheduled[bucket_id] = True
else:
continue
gm.graph = new_graph
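
To summarize the data movement that merge_all_gather emits for one bucket, here is a single-process sketch (the collective is replaced by a stand-in that stacks group_size copies, and all names are illustrative): inputs are flattened and packed into one buffer, gathered as a single tensor, then split back and reshaped into dim-0-sharded outputs.

```python
import torch

group_size = 2
inputs = [torch.randn(3, 4), torch.randn(5)]  # one bucket of all_gather inputs

# copy-in: flatten every input and pack them into one contiguous buffer
split_sizes = [t.numel() for t in inputs]
all_gather_input = torch.cat([t.reshape(-1) for t in inputs])

# stand-in for all_gather_into_tensor_out: pretend every rank contributed
# the same data, so the output is group_size stacked copies of the input
all_gather_output = all_gather_input.repeat(group_size)

# copy-out: view as [group_size, -1], split per parameter along dim 1,
# then restore each parameter's dim-0-sharded shape
outs_flat = all_gather_output.reshape(group_size, -1).split(split_sizes, dim=1)
outs = [
    o.reshape((t.shape[0] * group_size,) + t.shape[1:])
    for o, t in zip(outs_flat, inputs)
]
print([tuple(o.shape) for o in outs])  # [(6, 4), (10,)]
```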

torch/_inductor/fx_passes/fsdp.py (new file)

@@ -0,0 +1,68 @@
import logging
from typing import Callable
import torch
from torch._inductor.fx_passes.bucketing import (
bucket_all_gather_by_mb,
merge_all_gather,
)
logger: logging.Logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
def is_graph_input(node: torch.fx.Node) -> bool:
return node.op == "placeholder"
def is_fsdp_all_gather_wait(wait: torch.fx.Node) -> bool:
# Assume all_gather_into_tensor input is either graph input
# or dtype conversion of graph input
ag_node = wait.args[0] # type: ignore[arg-type, union-attr]
return (
is_graph_input(ag_node.args[0]) # type: ignore[arg-type, union-attr]
or ( # type: ignore[arg-type, union-attr]
ag_node.args[0].op == "call_function" # type: ignore[arg-type, union-attr]
and ag_node.args[0].target # type: ignore[arg-type, union-attr]
== torch.ops.prims.convert_element_type.default # type: ignore[arg-type, union-attr]
and is_graph_input(ag_node.args[0].args[0]) # type: ignore[arg-type, union-attr]
)
)
def bucket_fsdp_all_gather(
gm: torch.fx.GraphModule, all_gather_bucket_cap_mb_callback: Callable[[int], float]
) -> None:
"""
Bucketing pass for SimpleFSDP all_gather ops.
Args:
gm (torch.fx.GraphModule): Graph module of the graph.
all_gather_bucket_cap_mb_callback (Callable[[int], float]): callback function that
takes in bucket id and returns size of a bucket in megabytes.
Usage:
```
from torch._inductor.fx_passes.bucketing import (
bucket_all_gather,
bucket_size_determinator,
)
def _bucket_all_gather(graph):
return bucket_all_gather(graph.owning_module, bucket_size_determinator)
torch._inductor.config.post_grad_custom_post_pass = _bucket_all_gather
```
"""
ag_buckets = bucket_all_gather_by_mb(
gm,
all_gather_bucket_cap_mb_callback,
filter_wait_node=is_fsdp_all_gather_wait,
)
if len(ag_buckets) == 0:
return
merge_all_gather(gm, ag_buckets)
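
To make the pattern that is_fsdp_all_gather_wait matches concrete, a small sketch (not part of the PR) that builds a toy FX graph by hand, with the all_gather input being a dtype cast of a placeholder:

```python
import torch
from torch._inductor.fx_passes.fsdp import is_fsdp_all_gather_wait

g = torch.fx.Graph()
param = g.placeholder("param")  # graph input
cast = g.call_function(
    torch.ops.prims.convert_element_type.default, (param, torch.bfloat16)
)
ag = g.call_function(
    torch.ops._c10d_functional.all_gather_into_tensor.default, (cast, 2, "0")
)
wait = g.call_function(torch.ops._c10d_functional.wait_tensor.default, (ag,))

# True: the all_gather input is a dtype cast of a graph input
print(is_fsdp_all_gather_wait(wait))
```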

torch/_inductor/fx_passes/post_grad.py

@@ -219,6 +219,28 @@ def post_grad_passes(gm: torch.fx.GraphModule, is_inference: bool):
decompose_map_to_while_loop
)
# Fx all_gather bucketing introduces mutation op
# Keeping it in the end to keep invariant of functional graph for previous passes.
if config.bucket_all_gathers_fx != "none":
from torch._inductor.fx_passes.bucketing import (
bucket_all_gather,
bucket_size_determinator,
)
from torch._inductor.fx_passes.fsdp import bucket_fsdp_all_gather
p = (
bucket_fsdp_all_gather
if config.bucket_all_gathers_fx == "fsdp"
else bucket_all_gather
)
d = (
config.bucket_all_gathers_fx_bucket_size_determinator
or bucket_size_determinator
)
GraphTransformObserver(gm, "bucket_all_gathers").apply_graph_pass(
lambda graph: p(graph.owning_module, d)
)
gm.recompile()
gm.graph.lint()
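
Equivalently, the FSDP variant can be registered as a custom post-grad pass instead of going through the config flag, mirroring the usage sketch in the fsdp.py docstring above (the wrapper name is illustrative):

```python
import torch
from torch._inductor.fx_passes.bucketing import bucket_size_determinator
from torch._inductor.fx_passes.fsdp import bucket_fsdp_all_gather

def _bucket_fsdp_all_gather(graph: torch.fx.Graph) -> None:
    # graph.owning_module is the GraphModule handed to post-grad passes
    bucket_fsdp_all_gather(graph.owning_module, bucket_size_determinator)

torch._inductor.config.post_grad_custom_post_pass = _bucket_fsdp_all_gather
```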