[bucketing] Bucket only adjacent collectives to prevent reordering (#159983)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159983
Approved by: https://github.com/wconstab, https://github.com/eellison
Author: IvanKobzarev
Date: 2025-08-07 03:15:48 -07:00
Committed by: PyTorch MergeBot
Commit: f33ce40bc0 (parent 4d5b3f2d5a)
2 changed files with 49 additions and 17 deletions


@@ -1524,39 +1524,49 @@ class TestCollectivesInductor(DynamoDistributedSingleProcTestCase):
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@unittest.skipIf(not SM80OrLater, "bfloat16")
def test_all_gather_bucket(self):
def func(x, w, ag_0, ag_1, *, tag, ranks, group_size):
def func(x, w, ag_0, ag_1, ag_2, ag_3, *, tag, ranks, group_size):
# do some unrelated matmuls
y = torch.mm(x, w)
# cast the inputs
ag_0_cast = ag_0.to(torch.bfloat16)
ag_1_cast = ag_1.to(torch.bfloat16)
# allgather
group_name = (
torch.distributed.distributed_c10d._get_default_group().group_name
)
ag_2_out = torch.ops._c10d_functional.all_gather_into_tensor(
ag_2, group_size, group_name
)
ag_2_out = torch.ops.c10d_functional.wait_tensor(ag_2_out)
ag_0 = ag_2_out + ag_0
ag_0_cast = ag_0.to(torch.bfloat16)
ag_0_out = torch.ops._c10d_functional.all_gather_into_tensor(
ag_0_cast, group_size, group_name
)
ag_0_out = torch.ops.c10d_functional.wait_tensor(ag_0_out)
ag_0_out = ag_0_out * 2
ag_1_cast = ag_1_cast * 2
ag_1_out = torch.ops._c10d_functional.all_gather_into_tensor(
ag_1_cast, group_size, group_name
)
# wait op
ag_1_out = torch.ops.c10d_functional.wait_tensor(ag_1_out)
return y, ag_0_out, ag_1_out
ag_3_out = torch.ops._c10d_functional.all_gather_into_tensor(
ag_3, group_size, group_name
)
ag_3_out = torch.ops.c10d_functional.wait_tensor(ag_3_out)
return y, ag_0_out, ag_1_out, ag_2_out, ag_3_out
x = torch.ones(4, 384, device="cuda", dtype=torch.float32)
w = torch.ones(384, 512, device="cuda", dtype=torch.float32)
ag_0 = torch.ones(384, 512, device="cuda", dtype=torch.float32)
ag_1 = torch.ones(384, 512, device="cuda", dtype=torch.float32)
inputs = [x, w, ag_0, ag_1]
ag_2 = torch.ones(384, 512, device="cuda", dtype=torch.float32)
ag_3 = torch.ones(384, 512, device="cuda", dtype=torch.float32)
inputs = [x, w, ag_0, ag_1, ag_2, ag_3]
correct = func(*inputs, **self.get_world_trs())
with torch._inductor.config.patch(
{
@@ -1568,9 +1578,14 @@ class TestCollectivesInductor(DynamoDistributedSingleProcTestCase):
code = run_and_get_triton_code(compiled, *inputs, **self.get_world_trs())
# NOTE: The first return value should be the output of the first wait_tensor.
# We want to make sure no unnecessary copy is made.
(FileCheck().check("all_gather_into_tensor_out").run(code))
(
FileCheck()
.check("= torch.ops._c10d_functional.all_gather_into_tensor")
.check("torch.ops._c10d_functional.all_gather_into_tensor_out.default(")
.check("= torch.ops._c10d_functional.all_gather_into_tensor")
.run(code)
)
out = compiled(*inputs, **self.get_world_trs())
correct = func(*inputs, **self.get_world_trs())
assert same(out, correct), f"{out} vs {correct}"
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")


@@ -93,6 +93,12 @@ def greedy_bucket_collective_by_mb(
node_group_key: Callable[[torch.fx.Node], Any],
filter_wait_node: Optional[Callable[[torch.fx.Node], bool]] = None,
) -> list[list[torch.fx.Node]]:
"""
Buckets only adjacent collectives that share the same node_group_key.
Non-adjacent collectives cannot be bucketed, as that would effectively
change the order of the collectives; reordering can lead to a different
collective order on different ranks.
"""
g = gm.graph
found_candidates = False
for node in g.nodes:
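As a reading aid, here is a minimal, self-contained sketch of the adjacency rule the new docstring describes. It uses plain strings and a stand-in key function instead of torch.fx nodes, so every name in it is illustrative rather than part of the pass; the real pass additionally tracks each node's successors, so a collective is never bucketed with a node that depends on it.

# Sketch of adjacency-constrained grouping (illustrative, not the actual pass):
# collectives are scanned in program order and merged only while the group key
# stays the same; a key change closes the current group, so two collectives
# separated by a differently-keyed one are never bucketed together and the
# collective order is preserved on every rank.
from typing import Any, Callable

def group_adjacent(collectives: list[Any], key: Callable[[Any], Any]) -> list[list[Any]]:
    groups: list[list[Any]] = []
    cur_group: list[Any] = []
    cur_key = None
    for node in collectives:
        k = key(node)
        if k == cur_key:
            cur_group.append(node)
        else:
            if len(cur_group) > 1:  # a single collective is not worth bucketing
                groups.append(cur_group)
            cur_group = [node]
            cur_key = k
    if len(cur_group) > 1:
        groups.append(cur_group)
    return groups

# Keys in program order: fp32, bf16, bf16, fp32 -> only the adjacent bf16 pair
# is grouped; the two fp32 collectives stay separate because merging them would
# move one of them across the bf16 pair.
print(group_adjacent(["ag2_fp32", "ag0_bf16", "ag1_bf16", "ag3_fp32"],
                     key=lambda n: n.split("_")[1]))
# -> [['ag0_bf16', 'ag1_bf16']]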
@@ -102,10 +108,12 @@ def greedy_bucket_collective_by_mb(
if not found_candidates:
return []
nodes_groups: dict[Any, list[torch.fx.Node]] = defaultdict(list)
nodes_successors: dict[torch.fx.Node, OrderedSet[torch.fx.Node]] = defaultdict(
OrderedSet
)
nodes_groups: list[list[torch.fx.Node]] = []
cur_group: list[torch.fx.Node] = []
cur_group_key = None
for node in g.nodes:
for n, successors in nodes_successors.items():
@@ -115,10 +123,19 @@ def greedy_bucket_collective_by_mb(
if (filter_wait_node is None) or filter_wait_node(node):
coll_node = node.args[0]
group_key = node_group_key(coll_node)
nodes_groups[group_key].append(coll_node)
if group_key == cur_group_key:
cur_group.append(coll_node)
else:
if len(cur_group) > 1:
nodes_groups.append(cur_group)
cur_group = [coll_node]
cur_group_key = group_key
if len(cur_group) > 1:
nodes_groups.append(cur_group)
buckets: list[list[torch.fx.Node]] = []
for nodes in nodes_groups.values():
for nodes in nodes_groups:
cur_bucket: list[torch.fx.Node] = []
cur_bucket_successors: OrderedSet[torch.fx.Node] = OrderedSet()
cur_bucket_size_bytes: int = 0
@@ -128,7 +145,7 @@ def greedy_bucket_collective_by_mb(
)
for node in nodes:
if node in cur_bucket_successors:
# We can not bucket successors with the node
# We cannot bucket successors with the node
continue
assert "val" in node.meta
n_val = node.meta["val"]
@@ -163,7 +180,7 @@ def bucket_all_gather_by_mb(
Args:
gm (torch.fx.GraphModule): GraphModule where to bucket all_gathers.
bucket_cap_mb_bucket_idx (Callable[[int], float]): Callable to specify cap of the bucket
bucket_cap_mb_by_bucket_idx (Callable[[int], float]): Callable to specify cap of the bucket
in megabytes by bucket idx. The idea of `bucket_cap_mb_by_bucket_idx` is to allow
to specify different sizes of the buckets at the start,
as first all_gather is usually exposed. Interface of bucket_cap_mb_by_bucket_idx
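For illustration, a hypothetical cap schedule of the kind this docstring describes; the specific sizes are made up, but the shape matches the Callable[[int], float] interface, with a smaller first bucket since the first all_gather is usually exposed.

def example_bucket_cap_mb_by_bucket_idx(bucket_idx: int) -> float:
    # Hypothetical schedule: keep bucket 0 small (the first all_gather is
    # usually exposed) and allow larger buckets afterwards. Values are made up.
    return 25.0 if bucket_idx == 0 else 250.0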
@@ -201,14 +218,14 @@ def bucket_reduce_scatter_by_mb(
Args:
gm (torch.fx.GraphModule): GraphModule where to bucket reduce_scatters.
bucket_cap_mb_bucket_idx (Callable[[int], float]): Callable to specify cap of the bucket
bucket_cap_mb_by_bucket_idx (Callable[[int], float]): Callable to specify cap of the bucket
in megabytes by bucket idx. The idea of `bucket_cap_mb_by_bucket_idx` is to allow
to specify different sizes of the buckets.
filter_wait_node (Optional[Callable[[torch.fx.Node], bool]]): If specified,
only reduce_scatter nodes with wait_node that satisfy `filter_wait_node` will be bucketed.
Returns:
list[list[torch.fx.Node]]: List of buckets, where each bucket is a list of all_gather nodes.
list[list[torch.fx.Node]]: List of buckets, where each bucket is a list of reduce_scatter nodes.
"""
def _rs_group_key(node: torch.fx.Node) -> tuple[str, str, torch.dtype]:
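Finally, a hedged usage sketch built only from the argument lists and return types shown in these docstrings; the import path and the constant 100 MB cap are assumptions, not something this commit specifies.

import torch
# Assumed import path for the helpers shown above.
from torch._inductor.fx_passes.bucketing import (
    bucket_all_gather_by_mb,
    bucket_reduce_scatter_by_mb,
)

def constant_cap(bucket_idx: int) -> float:
    # Same cap for every bucket; see the per-index schedule example above.
    return 100.0

def find_collective_buckets(gm: torch.fx.GraphModule) -> None:
    ag_buckets = bucket_all_gather_by_mb(gm, constant_cap)
    rs_buckets = bucket_reduce_scatter_by_mb(gm, constant_cap)
    # Each entry is a list of adjacent collective nodes with the same group key
    # that fit under the cap and can be fused into one bucketed collective.
    print(f"{len(ag_buckets)} all_gather buckets, {len(rs_buckets)} reduce_scatter buckets")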