[bucketing] custom_ops mode to hide inductor copies overhead (#161499)

Adding "_custom_ops" bucketing to temporary fallback to eager execution of for_each,
to workaround too many generated kernels on inductor side.
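
A minimal sketch of opting into the new mode, using the Inductor config keys
exercised by the tests below ("func" and "inputs" are placeholder names, not
part of this PR):

    import torch

    # Patch the same bucketing config keys the tests below parametrize over.
    # Per this PR, "all_custom_ops" routes the bucket copy-in/copy-out through
    # custom ops that fall back to eager for_each execution, instead of having
    # Inductor generate one kernel per copy.
    with torch._inductor.config.patch(
        {
            "bucket_all_gathers_fx": "all_custom_ops",
            "bucket_reduce_scatters_fx": "all_custom_ops",
        }
    ):
        compiled = torch.compile(func)  # placeholder model/inputs
        out = compiled(*inputs)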

This PR also reverts parts of the bucketing changes for cycle detection that caused accuracy problems.

Differential Revision: [D81152293](https://our.internmc.facebook.com/intern/diff/D81152293)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/161499
Approved by: https://github.com/eellison
Author: IvanKobzarev
Date: 2025-09-08 09:03:32 -07:00
Committed by: PyTorch MergeBot
Parent: 9c991b63ff
Commit: 8ec01f34e9
4 changed files with 197 additions and 30 deletions

@@ -1528,7 +1528,8 @@ class TestCollectivesInductor(DynamoDistributedSingleProcTestCase):
     @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
     @unittest.skipIf(not SM80OrLater, "bfloat16")
-    def test_all_gather_bucket(self):
+    @parametrize("bucket_mode", ["all", "all_custom_ops"])
+    def test_all_gather_bucket(self, bucket_mode):
         def func(x, w, ag_0, ag_1, ag_2, ag_3, *, tag, ranks, group_size):
             # do some unrelated matmuls
             y = torch.mm(x, w)
@@ -1576,7 +1577,7 @@ class TestCollectivesInductor(DynamoDistributedSingleProcTestCase):
         with (
             torch._inductor.config.patch(
                 {
-                    "bucket_all_gathers_fx": "all",
+                    "bucket_all_gathers_fx": bucket_mode,
                     "reorder_for_compute_comm_overlap": False,
                     "runtime_estimations_mms_benchmark": True,
                 }
@@ -1595,7 +1596,9 @@ class TestCollectivesInductor(DynamoDistributedSingleProcTestCase):
         # We want to make sure no unnecessary copy is made.
         (
             FileCheck()
-            .check_count(".all_gather_into_tensor_out.default(", 2, exactly=True)
+            .check("= torch.ops._c10d_functional.all_gather_into_tensor")
+            .check("torch.ops._c10d_functional.all_gather_into_tensor_out.default(")
+            .check("= torch.ops._c10d_functional.all_gather_into_tensor")
             .run(code)
         )
         out = compiled(*inputs, **self.get_world_trs())
@@ -1656,7 +1659,8 @@ class TestCollectivesInductor(DynamoDistributedSingleProcTestCase):
     @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
     @unittest.skipIf(not SM80OrLater, "bfloat16")
-    def test_reduce_scatter_bucket(self):
+    @parametrize("bucket_mode", ["all", "all_custom_ops"])
+    def test_reduce_scatter_bucket(self, bucket_mode):
         def func(x, w, rs_0, rs_1, tag, ranks, group_size):
             # do some unrelated matmuls
             y = torch.mm(x, w)
@@ -1697,7 +1701,7 @@ class TestCollectivesInductor(DynamoDistributedSingleProcTestCase):
         with torch._inductor.config.patch(
             {
-                "bucket_reduce_scatters_fx": "fsdp",
+                "bucket_reduce_scatters_fx": bucket_mode,
                 "reorder_for_compute_comm_overlap": False,
             }
         ):
@@ -1723,7 +1727,8 @@ class TestCollectivesInductor(DynamoDistributedSingleProcTestCase):
     @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
     @unittest.skipIf(not SM80OrLater, "bfloat16")
-    def test_reorder_peak_memory_bucketed(self):
+    @parametrize("bucket_mode", ["all", "all_custom_ops"])
+    def test_reorder_peak_memory_bucketed(self, bucket_mode):
         """
         Simulate the case where a bucketing pass ran and grouped several inputs into one bucketed allgather.
         Ensure the whole bucketed group including copy-ops get moved together rather than the copy ops preventing the
@@ -1837,9 +1842,9 @@ class TestCollectivesInductor(DynamoDistributedSingleProcTestCase):
         with (
             torch._inductor.config.patch(
                 {
-                    "bucket_all_gathers_fx": "all",
+                    "bucket_all_gathers_fx": bucket_mode,
                     "bucket_all_gathers_fx_bucket_size_determinator": lambda _: 2,
-                    "bucket_reduce_scatters_fx": "all",
+                    "bucket_reduce_scatters_fx": bucket_mode,
                     "bucket_reduce_scatters_fx_bucket_size_determinator": lambda _: 2,
                     "reorder_for_compute_comm_overlap": True,
                     "reorder_for_compute_comm_overlap_passes": [