[bucketing] custom_ops mode to hide inductor copies overhead (#161499)
Adding "_custom_ops" bucketing to temporary fallback to eager execution of for_each, to workaround too many generated kernels on inductor side. This PR also reverts parts of bucketing changes for cycles detection that resulted in accuracy problems. Differential Revision: [D81152293](https://our.internmc.facebook.com/intern/diff/D81152293) Pull Request resolved: https://github.com/pytorch/pytorch/pull/161499 Approved by: https://github.com/eellison
committed by PyTorch MergeBot
parent 9c991b63ff
commit 8ec01f34e9
@@ -1528,7 +1528,8 @@ class TestCollectivesInductor(DynamoDistributedSingleProcTestCase):
 
     @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
     @unittest.skipIf(not SM80OrLater, "bfloat16")
-    def test_all_gather_bucket(self):
+    @parametrize("bucket_mode", ["all", "all_custom_ops"])
+    def test_all_gather_bucket(self, bucket_mode):
         def func(x, w, ag_0, ag_1, ag_2, ag_3, *, tag, ranks, group_size):
             # do some unrelated matmuls
             y = torch.mm(x, w)
@@ -1576,7 +1577,7 @@ class TestCollectivesInductor(DynamoDistributedSingleProcTestCase):
         with (
             torch._inductor.config.patch(
                 {
-                    "bucket_all_gathers_fx": "all",
+                    "bucket_all_gathers_fx": bucket_mode,
                     "reorder_for_compute_comm_overlap": False,
                     "runtime_estimations_mms_benchmark": True,
                 }
@@ -1595,7 +1596,9 @@ class TestCollectivesInductor(DynamoDistributedSingleProcTestCase):
         # We want to make sure no unnecessary copy is made.
         (
             FileCheck()
-            .check_count(".all_gather_into_tensor_out.default(", 2, exactly=True)
+            .check("= torch.ops._c10d_functional.all_gather_into_tensor")
+            .check("torch.ops._c10d_functional.all_gather_into_tensor_out.default(")
+            .check("= torch.ops._c10d_functional.all_gather_into_tensor")
             .run(code)
         )
         out = compiled(*inputs, **self.get_world_trs())
@@ -1656,7 +1659,8 @@ class TestCollectivesInductor(DynamoDistributedSingleProcTestCase):
 
     @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
     @unittest.skipIf(not SM80OrLater, "bfloat16")
-    def test_reduce_scatter_bucket(self):
+    @parametrize("bucket_mode", ["all", "all_custom_ops"])
+    def test_reduce_scatter_bucket(self, bucket_mode):
         def func(x, w, rs_0, rs_1, tag, ranks, group_size):
             # do some unrelated matmuls
             y = torch.mm(x, w)
@@ -1697,7 +1701,7 @@ class TestCollectivesInductor(DynamoDistributedSingleProcTestCase):
 
         with torch._inductor.config.patch(
             {
-                "bucket_reduce_scatters_fx": "fsdp",
+                "bucket_reduce_scatters_fx": bucket_mode,
                 "reorder_for_compute_comm_overlap": False,
             }
         ):
@@ -1723,7 +1727,8 @@ class TestCollectivesInductor(DynamoDistributedSingleProcTestCase):
 
     @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
     @unittest.skipIf(not SM80OrLater, "bfloat16")
-    def test_reorder_peak_memory_bucketed(self):
+    @parametrize("bucket_mode", ["all", "all_custom_ops"])
+    def test_reorder_peak_memory_bucketed(self, bucket_mode):
         """
         Simulate the case where a bucketing pass ran and grouped several inputs into one bucketed allgather.
         Ensure the whole bucketed group including copy-ops get moved together rather than the copy ops preventing the
@@ -1837,9 +1842,9 @@ class TestCollectivesInductor(DynamoDistributedSingleProcTestCase):
         with (
            torch._inductor.config.patch(
                 {
-                    "bucket_all_gathers_fx": "all",
+                    "bucket_all_gathers_fx": bucket_mode,
                     "bucket_all_gathers_fx_bucket_size_determinator": lambda _: 2,
-                    "bucket_reduce_scatters_fx": "all",
+                    "bucket_reduce_scatters_fx": bucket_mode,
                     "bucket_reduce_scatters_fx_bucket_size_determinator": lambda _: 2,
                     "reorder_for_compute_comm_overlap": True,
                     "reorder_for_compute_comm_overlap_passes": [