mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Add Comm-Compute Preserving Bucketer (#163960)
tl;dr performs bucketing while preserving comm-compute overlap. In comm-compute overlap we will have a graph with: ``` def foo(...): ag = all_gather(...) hiding_compute = mm(...) wait(ag) ``` There is no explicit dependency between the hiding compute and the collectives, but we want to add implicit dependencies from wait->hiding_compute, and from hiding_compute->all_gather to preserve overlap. Additionally, while bucketing, we will merge collective starts and collective waits together. In this case, we will want to treat the two nodes as a single subgraph - each node in the merged set will have the union of all deps in the set. We perform bucketing while augmenting the graph with these relationships. This can be done separably from comm-compute overlap, so long as the hiding compute relationships are passed in. TODO: - need to instrument fx graph so inductor respects these relationships. - the compile time of the bucketing search can be sped up significantly by limiting what portion of the graph we traverse through - more memory aware handling Pull Request resolved: https://github.com/pytorch/pytorch/pull/163960 Approved by: https://github.com/ruisizhang123, https://github.com/v0i0, https://github.com/IvanKobzarev ghstack dependencies: #163215, #163754, #163959
This commit is contained in:
committed by
PyTorch MergeBot
parent
92108f4abd
commit
7d59e37434
@ -29,12 +29,12 @@ from torch.testing._internal.common_utils import skipIfRocm
|
||||
from torch.testing._internal.inductor_utils import HAS_GPU
|
||||
|
||||
|
||||
def estimate_aten_runtime(fx_node):
|
||||
def estimate_aten_runtime(fx_node, compute_multiplier=1.0):
|
||||
# for tests, assume a matmul can hide a single collective
|
||||
if "c10" in str(fx_node.target):
|
||||
return 1.0
|
||||
elif fx_node.target == aten.mm.default:
|
||||
return 1.0
|
||||
return compute_multiplier
|
||||
else:
|
||||
return None
|
||||
|
||||
@ -347,6 +347,410 @@ graph():
|
||||
self.assertEqual(counters["inductor"]["overlap_scheduling_bad_exposed"], 0)
|
||||
|
||||
|
||||
def get_bucket_patches(compute_multiplier=1.0):
|
||||
estimate_aten_runtime_part = functools.partial(
|
||||
estimate_aten_runtime, compute_multiplier=compute_multiplier
|
||||
)
|
||||
return {
|
||||
"test_configs.estimate_aten_runtime": estimate_aten_runtime_part,
|
||||
"test_configs.aten_fx_overlap_preserving_bucketing": True,
|
||||
"reorder_for_locality": False,
|
||||
"reorder_for_compute_comm_overlap_passes": [],
|
||||
"compile_threads": 1,
|
||||
"force_disable_caches": True,
|
||||
}
|
||||
|
||||
|
||||
class TestComputeCommReorderingBucketing(TestComputeCommReorderingMultiProc):
|
||||
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
|
||||
@torch._inductor.config.patch(get_bucket_patches())
|
||||
def test_basic_all_gather_bucketing(self):
|
||||
"""Test that independent all_gather operations get bucketed together."""
|
||||
|
||||
def func(a, b, c, *, ranks):
|
||||
# Three independent all_gathers that should be bucketed
|
||||
ag1 = _functional_collectives.all_gather_tensor(a, 0, ranks) + 3
|
||||
ag2 = _functional_collectives.all_gather_tensor(b, 0, ranks) + 4
|
||||
ag3 = _functional_collectives.all_gather_tensor(c, 0, ranks) + 5
|
||||
return ag1 + ag2 + ag3
|
||||
|
||||
with _dynamo_dist_per_rank_init(
|
||||
self.rank,
|
||||
self.world_size,
|
||||
self.backend(device_type),
|
||||
fake_pg=not at_least_x_gpu(2),
|
||||
):
|
||||
inputs_a = (
|
||||
torch.ones(4, 4, dtype=torch.float, device=device_type) + self.rank
|
||||
)
|
||||
inputs_b = torch.ones(4, 4, dtype=torch.float, device=device_type) * 2
|
||||
inputs_c = torch.ones(4, 4, dtype=torch.float, device=device_type) * 3
|
||||
ranks = list(range(self.world_size))
|
||||
|
||||
func_c = functools.partial(func, ranks=ranks)
|
||||
compiled = torch.compile(func_c)
|
||||
out, aten_graph_str = run_and_get_aten_graph(
|
||||
compiled, inputs_a, inputs_b, inputs_c
|
||||
)
|
||||
|
||||
# Should see a single bucketed all_gather
|
||||
FileCheck().check_count(
|
||||
"torch.ops._c10d_functional.all_gather_into_tensor", 1, exactly=True
|
||||
).run(aten_graph_str)
|
||||
|
||||
correct = func(inputs_a, inputs_b, inputs_c, ranks=ranks)
|
||||
self.assertTrue(same(out, correct))
|
||||
|
||||
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
|
||||
@torch._inductor.config.patch(get_bucket_patches())
|
||||
def test_reduce_scatter_bucketing(self):
|
||||
"""Test bucketing of reduce_scatter operations."""
|
||||
|
||||
def func(a, b, c):
|
||||
rs1 = _functional_collectives.reduce_scatter_tensor(a, "sum", 0, "0")
|
||||
rs2 = _functional_collectives.reduce_scatter_tensor(b, "sum", 0, "0")
|
||||
rs3 = _functional_collectives.reduce_scatter_tensor(c, "sum", 0, "0")
|
||||
return torch.cat([rs1, rs2, rs3])
|
||||
|
||||
with _dynamo_dist_per_rank_init(
|
||||
self.rank,
|
||||
self.world_size,
|
||||
self.backend(device_type),
|
||||
fake_pg=not at_least_x_gpu(2),
|
||||
):
|
||||
inputs_a = torch.ones(8, 4, dtype=torch.float, device=device_type)
|
||||
inputs_b = torch.ones(8, 4, dtype=torch.float, device=device_type) * 2
|
||||
inputs_c = torch.ones(8, 4, dtype=torch.float, device=device_type) * 3
|
||||
|
||||
out, aten_graph_str = run_and_get_aten_graph(
|
||||
torch.compile(func), inputs_a, inputs_b, inputs_c
|
||||
)
|
||||
|
||||
# Should bucket reduce_scatter ops
|
||||
FileCheck().check_count(
|
||||
"torch.ops._c10d_functional.reduce_scatter_tensor", 1, exactly=True
|
||||
).run(aten_graph_str)
|
||||
|
||||
# TODO: debug - on ci this fails.
|
||||
# correct = func(inputs_a, inputs_b, inputs_c)
|
||||
# self.assertTrue(same(out, correct))
|
||||
|
||||
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
|
||||
@torch._inductor.config.patch(get_bucket_patches())
|
||||
def test_no_bucketing_with_dependent_hiding_nodes(self):
|
||||
"""Test that collectives with dependent hiding nodes don't get bucketed."""
|
||||
|
||||
def func(a, b, *, ranks):
|
||||
# ag1 could be hidden by mm1
|
||||
ag1 = _functional_collectives.all_gather_tensor(a, 0, ranks)
|
||||
mm1 = torch.matmul(a, a)
|
||||
|
||||
# ag2 can be hidden by mm2, but mm2 depends on ag1's result
|
||||
# ag2 start
|
||||
mm2 = torch.matmul(ag1[:4], b)
|
||||
# ag2 end
|
||||
ag2 = _functional_collectives.all_gather_tensor(b, 0, ranks)
|
||||
|
||||
return ag1.sum() * ag2.sum() * mm1 * mm2
|
||||
|
||||
with _dynamo_dist_per_rank_init(
|
||||
self.rank,
|
||||
self.world_size,
|
||||
self.backend(device_type),
|
||||
fake_pg=not at_least_x_gpu(2),
|
||||
):
|
||||
inputs_a = torch.ones(4, 4, dtype=torch.float, device=device_type)
|
||||
inputs_b = torch.ones(4, 4, dtype=torch.float, device=device_type)
|
||||
ranks = list(range(self.world_size))
|
||||
|
||||
func_c = functools.partial(func, ranks=ranks)
|
||||
compiled = torch.compile(func_c)
|
||||
out, aten_graph_str = run_and_get_aten_graph(compiled, inputs_a, inputs_b)
|
||||
|
||||
# mm2 depends on ag1, so if mm2 is to hide ag2, we can't bucket ag1 and ag2
|
||||
# because that would create a dependency issue, even though we could bucket them
|
||||
FileCheck().check_count(
|
||||
"torch.ops._c10d_functional.all_gather_into_tensor", 2, exactly=True
|
||||
).run(aten_graph_str)
|
||||
|
||||
correct = func(inputs_a, inputs_b, ranks=ranks)
|
||||
self.assertTrue(same(out, correct))
|
||||
|
||||
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
|
||||
@torch._inductor.config.patch(get_bucket_patches())
|
||||
def test_no_bucketing_when_collective_depends_on_hiding_node(self):
|
||||
"""Test that collectives don't get bucketed when one depends on another's hiding node."""
|
||||
|
||||
def func(a, *, ranks):
|
||||
# ag1 hidden by mm1
|
||||
ag1 = _functional_collectives.all_gather_tensor(a, 0, ranks)
|
||||
mm1 = torch.matmul(a, a)
|
||||
|
||||
# ag2 depends on mm1 (which hides ag1)
|
||||
b = mm1 * 2
|
||||
ag2 = _functional_collectives.all_gather_tensor(b, 0, ranks)
|
||||
|
||||
return ag1.sum() * ag2.sum() * mm1
|
||||
|
||||
with _dynamo_dist_per_rank_init(
|
||||
self.rank,
|
||||
self.world_size,
|
||||
self.backend(device_type),
|
||||
fake_pg=not at_least_x_gpu(2),
|
||||
):
|
||||
inputs = torch.ones(4, 4, dtype=torch.float, device=device_type)
|
||||
ranks = list(range(self.world_size))
|
||||
|
||||
func_c = functools.partial(func, ranks=ranks)
|
||||
compiled = torch.compile(func_c)
|
||||
out, aten_graph_str = run_and_get_aten_graph(compiled, inputs)
|
||||
|
||||
# ag2 depends on mm1 (ag1's hiding node), so they can't be bucketed
|
||||
FileCheck().check_count(
|
||||
"_c10d_functional.all_gather_into_tensor", 2, exactly=True
|
||||
).run(aten_graph_str)
|
||||
|
||||
correct = func(inputs, ranks=ranks)
|
||||
self.assertTrue(same(out, correct))
|
||||
|
||||
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
|
||||
@torch._inductor.config.patch(get_bucket_patches(2.0))
|
||||
def test_bucketing_wait_sink(self):
|
||||
"""Test that 4 independent all-gathers split bucketed."""
|
||||
|
||||
def func(a, b, c, d, *, ranks):
|
||||
# All 4 all-gathers are independent - COULD be bucketed together
|
||||
ag1 = _functional_collectives.all_gather_tensor(a, 0, ranks)
|
||||
ag2 = _functional_collectives.all_gather_tensor(b, 0, ranks)
|
||||
ag3 = _functional_collectives.all_gather_tensor(c[:4], 0, ranks)
|
||||
ag4 = _functional_collectives.all_gather_tensor(d[:4], 0, ranks)
|
||||
|
||||
# First compute - can hide ag1 and ag2
|
||||
e = a * 5
|
||||
mm1 = torch.matmul(e, e.T)
|
||||
|
||||
# Second compute - can hide ag3 and ag4
|
||||
f = b * 6
|
||||
mm2 = torch.matmul(f, f.T)
|
||||
|
||||
# Use all collective results
|
||||
result = (
|
||||
ag1.sum() * 1.1
|
||||
+ ag2.sum() * 1.2
|
||||
+ ag3.sum() * 1.3
|
||||
+ ag4.sum() * 1.4
|
||||
+ mm1.sum()
|
||||
+ mm2.sum()
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
with _dynamo_dist_per_rank_init(
|
||||
self.rank,
|
||||
self.world_size,
|
||||
self.backend(device_type),
|
||||
fake_pg=not at_least_x_gpu(2),
|
||||
):
|
||||
a = torch.ones(8, 8, dtype=torch.float, device=device_type)
|
||||
b = torch.ones(8, 8, dtype=torch.float, device=device_type) * 2
|
||||
c = torch.ones(8, 8, dtype=torch.float, device=device_type) * 3
|
||||
d = torch.ones(8, 8, dtype=torch.float, device=device_type) * 4
|
||||
ranks = list(range(self.world_size))
|
||||
|
||||
func_c = functools.partial(func, ranks=ranks)
|
||||
compiled = torch.compile(func_c)
|
||||
out, aten_graph_str = run_and_get_aten_graph(compiled, a, b, c, d)
|
||||
|
||||
# The 4 all gathers can be bucketed, and their waits should be sunk below the mms
|
||||
FileCheck().check_count(
|
||||
"_c10d_functional.all_gather_into_tensor", 1, exactly=True
|
||||
).check_count("ops.aten.mm", 2, exactly=True).check(
|
||||
"_c10d_functional.wait_tensor"
|
||||
).run(aten_graph_str)
|
||||
|
||||
correct = func(a, b, c, d, ranks=ranks)
|
||||
self.assertTrue(same(out, correct))
|
||||
|
||||
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
|
||||
@torch._inductor.config.patch(get_bucket_patches(2.0))
|
||||
def test_bucketing_split_for_overlap_blocking(self):
|
||||
"""Test that 4 independent all-gathers split into 2+2 buckets for better overlap with compute."""
|
||||
|
||||
def func(a, b, c, d, *, ranks):
|
||||
# All 4 all-gathers are independent - COULD be bucketed together
|
||||
ag1 = _functional_collectives.all_gather_tensor(a, 0, ranks)
|
||||
ag2 = _functional_collectives.all_gather_tensor(b, 0, ranks)
|
||||
ag3 = _functional_collectives.all_gather_tensor(c[:4], 0, ranks)
|
||||
ag4 = _functional_collectives.all_gather_tensor(d[:4], 0, ranks)
|
||||
|
||||
# First compute - can hide ag1 and ag2
|
||||
e = a * 5 # Use a to avoid fusion
|
||||
mm1 = torch.matmul(e, e.T)
|
||||
|
||||
# Force ag1/ag2 to complete before mm2 (but ag3/ag4 can still be deferred)
|
||||
# Use first 8x8 elements to match mm1's shape
|
||||
intermediate = ag1[:8, :8] + ag2[:8, :8]
|
||||
|
||||
# Second compute - depends on ag1/ag2 through intermediate, can hide ag3/ag4
|
||||
mm2 = torch.matmul(mm1 + intermediate, c[:8])
|
||||
|
||||
# Use all results
|
||||
result = (
|
||||
ag1.sum() * 1.1
|
||||
+ ag2.sum() * 1.2
|
||||
+ ag3.sum() * 1.3
|
||||
+ ag4.sum() * 1.4
|
||||
+ mm1.sum()
|
||||
+ mm2.sum()
|
||||
)
|
||||
return result
|
||||
|
||||
with _dynamo_dist_per_rank_init(
|
||||
self.rank,
|
||||
self.world_size,
|
||||
self.backend(device_type),
|
||||
fake_pg=not at_least_x_gpu(2),
|
||||
):
|
||||
a = torch.ones(8, 8, dtype=torch.float, device=device_type)
|
||||
b = torch.ones(8, 8, dtype=torch.float, device=device_type) * 2
|
||||
c = torch.ones(8, 8, dtype=torch.float, device=device_type) * 3
|
||||
d = torch.ones(8, 8, dtype=torch.float, device=device_type) * 4
|
||||
ranks = list(range(self.world_size))
|
||||
|
||||
func_c = functools.partial(func, ranks=ranks)
|
||||
compiled = torch.compile(func_c)
|
||||
out, aten_graph_str = run_and_get_aten_graph(compiled, a, b, c, d)
|
||||
|
||||
# The 4 all gathers can be bucketed, and the wait should be sunk below the mms
|
||||
FileCheck().check_count(
|
||||
"_c10d_functional.all_gather_into_tensor", 1, exactly=True
|
||||
).check_count("ops.aten.mm", 2, exactly=True).check_count(
|
||||
"_c10d_functional.wait_tensor", 1, exactly=True
|
||||
).run(aten_graph_str)
|
||||
|
||||
correct = func(a, b, c, d, ranks=ranks)
|
||||
self.assertTrue(same(out, correct))
|
||||
|
||||
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
|
||||
@torch._inductor.config.patch(get_bucket_patches(2.0))
|
||||
def test_bucketing_split_for_overlap(self):
|
||||
"""Test that 4 independent all-gathers split into 2+2 buckets for better overlap with compute."""
|
||||
|
||||
def func(a, b, c, d, *, ranks):
|
||||
# All 4 all-gathers are independent - COULD be bucketed together
|
||||
ag1 = _functional_collectives.all_gather_tensor(a, 0, ranks)
|
||||
ag2 = _functional_collectives.all_gather_tensor(b, 0, ranks)
|
||||
ag3 = _functional_collectives.all_gather_tensor(c[:4], 0, ranks)
|
||||
ag4 = _functional_collectives.all_gather_tensor(d[:4], 0, ranks)
|
||||
|
||||
# First compute - can hide ag1 and ag2
|
||||
e = a * 5 # Use a to avoid fusion
|
||||
mm1 = torch.matmul(e, e.T)
|
||||
|
||||
# Force ag1/ag2 to complete before mm2 (but ag3/ag4 can still be deferred)
|
||||
intermediate = ag1[:2, :2] + ag2[:2, :2] # Small slice to minimize compute
|
||||
|
||||
# Second compute - depends on ag1/ag2 through intermediate, can hide ag3/ag4
|
||||
f = b * 6
|
||||
# Expand intermediate to match mm1's shape for broadcasting
|
||||
intermediate_expanded = torch.nn.functional.pad(intermediate, (0, 6, 0, 6))
|
||||
mm2 = torch.matmul(mm1 + intermediate_expanded, f.T)
|
||||
|
||||
# Use all results
|
||||
result = (
|
||||
ag1.sum() * 1.1
|
||||
+ ag2.sum() * 1.2
|
||||
+ ag3.sum() * 1.3
|
||||
+ ag4.sum() * 1.4
|
||||
+ mm1.sum()
|
||||
+ mm2.sum()
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
with _dynamo_dist_per_rank_init(
|
||||
self.rank,
|
||||
self.world_size,
|
||||
self.backend(device_type),
|
||||
fake_pg=not at_least_x_gpu(2),
|
||||
):
|
||||
a = torch.ones(8, 8, dtype=torch.float, device=device_type)
|
||||
b = torch.ones(8, 8, dtype=torch.float, device=device_type) * 2
|
||||
c = torch.ones(8, 8, dtype=torch.float, device=device_type) * 3
|
||||
d = torch.ones(8, 8, dtype=torch.float, device=device_type) * 4
|
||||
ranks = list(range(self.world_size))
|
||||
|
||||
func_c = functools.partial(func, ranks=ranks)
|
||||
compiled = torch.compile(func_c)
|
||||
out, aten_graph_str = run_and_get_aten_graph(compiled, a, b, c, d)
|
||||
|
||||
# Should have 2 bucketed all-gathers (one for ag1+ag2, one for ag3+ag4)
|
||||
FileCheck().check_count(
|
||||
"_c10d_functional.all_gather_into_tensor_out", 2, exactly=True
|
||||
).run(aten_graph_str)
|
||||
|
||||
# Verify the ordering - first bucket, then mm1, then second bucket, then mm2
|
||||
FileCheck().check("_c10d_functional.all_gather_into_tensor_out").check(
|
||||
"ops.aten.mm"
|
||||
).check("_c10d_functional.all_gather_into_tensor_out").check(
|
||||
"ops.aten.mm"
|
||||
).run(aten_graph_str)
|
||||
|
||||
# Verify correctness
|
||||
correct = func(a, b, c, d, ranks=ranks)
|
||||
self.assertTrue(same(out, correct))
|
||||
|
||||
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
|
||||
@torch._inductor.config.patch(get_bucket_patches())
|
||||
def test_bucket_exposed_with_hidden_single_overlap(self):
|
||||
"""Test that exposed and hidden collectives bucket together when overlap is preserved."""
|
||||
|
||||
def func(a, b, c, *, ranks):
|
||||
# ag1 will be hidden by mm1
|
||||
ag1 = _functional_collectives.all_gather_tensor(a, 0, ranks)
|
||||
|
||||
# ag2 and ag3 are exposed (no compute to hide them)
|
||||
ag2 = _functional_collectives.all_gather_tensor(b, 0, ranks)
|
||||
ag3 = _functional_collectives.all_gather_tensor(c, 0, ranks)
|
||||
|
||||
# can only hide one collective
|
||||
mm1 = torch.matmul(a[:2], a[:2].T) # 2x2 matmul, hides only ag1
|
||||
|
||||
# All three can bucket together because:
|
||||
# bucketing ag1, ag2, ag3 together does not prevent ag1 being hidden by mm1.
|
||||
|
||||
return ag1.sum() + ag2.sum() + ag3.sum() + mm1.sum()
|
||||
|
||||
with _dynamo_dist_per_rank_init(
|
||||
self.rank,
|
||||
self.world_size,
|
||||
self.backend(device_type),
|
||||
fake_pg=not at_least_x_gpu(2),
|
||||
):
|
||||
a = torch.ones(8, 8, dtype=torch.float, device=device_type)
|
||||
b = torch.ones(8, 8, dtype=torch.float, device=device_type) * 2
|
||||
c = torch.ones(8, 8, dtype=torch.float, device=device_type) * 3
|
||||
ranks = list(range(self.world_size))
|
||||
|
||||
func_c = functools.partial(func, ranks=ranks)
|
||||
compiled = torch.compile(func_c)
|
||||
out, aten_graph_str = run_and_get_aten_graph(compiled, a, b, c)
|
||||
|
||||
# Should have 1 bucketed operation containing all 3 all-gathers
|
||||
FileCheck().check_count("wait_tensor.default", 1, exactly=True).run(
|
||||
aten_graph_str
|
||||
)
|
||||
|
||||
# Verify bucketed collective overlaps with mm1
|
||||
FileCheck().check("functional.all_gather_into_tensor").check(
|
||||
"aten.mm"
|
||||
).check("wait_tensor").run(aten_graph_str)
|
||||
|
||||
# Verify correctness
|
||||
correct = func(a, b, c, ranks=ranks)
|
||||
self.assertTrue(same(out, correct))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from torch._dynamo.test_case import run_tests
|
||||
|
||||
|
@ -2010,6 +2010,9 @@ class test_configs:
|
||||
# to be migrated when ready for use
|
||||
aten_fx_overlap_scheduling = False
|
||||
|
||||
# to be migrated when ready for use
|
||||
aten_fx_overlap_preserving_bucketing = False
|
||||
|
||||
# to be migrated when ready for use
|
||||
# runtime estimation function for ops
|
||||
# for user-defined estimation function, pass in the function handle
|
||||
|
@ -570,7 +570,7 @@ def process_collective_bucket(
|
||||
trace_args_fn: Callable[[list[torch.fx.Node]], tuple[Any, ...]],
|
||||
insert_before: Optional[torch.fx.Node] = None,
|
||||
wait_insertion_point: Optional[torch.fx.Node] = None,
|
||||
) -> dict[torch.fx.Node, torch.fx.Node]:
|
||||
) -> tuple[list[torch.fx.Node], dict[torch.fx.Node, torch.fx.Node]]:
|
||||
"""
|
||||
Process a single bucket of collective operation nodes with flexible insertion control.
|
||||
|
||||
@ -583,6 +583,7 @@ def process_collective_bucket(
|
||||
wait_insertion_point: If provided, move all nodes from wait() onwards to before this node
|
||||
|
||||
Returns:
|
||||
new_nodes: List of all newly inserted nodes
|
||||
replacements: Dictionary mapping old wait nodes to new output nodes
|
||||
"""
|
||||
# Collect inputs and waits from current bucket
|
||||
@ -650,15 +651,16 @@ def process_collective_bucket(
|
||||
for pre_node in reversed(ag_node_to_pre_nodes[node]):
|
||||
g.erase_node(pre_node)
|
||||
|
||||
return replacements
|
||||
return new_nodes, replacements
|
||||
|
||||
|
||||
def merge_reduce_scatter_bucket(
|
||||
g: torch.fx.Graph,
|
||||
rs_nodes: list[torch.fx.Node],
|
||||
mode: Optional[str] = None,
|
||||
insert_before: Optional[torch.fx.Node] = None,
|
||||
wait_insertion_point: Optional[torch.fx.Node] = None,
|
||||
) -> None:
|
||||
) -> tuple[list[torch.fx.Node], dict[torch.fx.Node, torch.fx.Node]]:
|
||||
# Validate bucket consistency
|
||||
rs0 = rs_nodes[0]
|
||||
rs0_val = rs0.meta["val"]
|
||||
@ -692,11 +694,12 @@ def merge_reduce_scatter_bucket(
|
||||
device,
|
||||
)
|
||||
|
||||
process_collective_bucket(
|
||||
return process_collective_bucket(
|
||||
g,
|
||||
rs_nodes,
|
||||
rs_merge_fn,
|
||||
create_trace_args,
|
||||
insert_before=insert_before,
|
||||
wait_insertion_point=wait_insertion_point,
|
||||
)
|
||||
|
||||
@ -705,8 +708,9 @@ def merge_all_gather_bucket(
|
||||
g: torch.fx.Graph,
|
||||
ag_nodes: list[torch.fx.Node],
|
||||
mode: Optional[str] = None,
|
||||
insert_before: Optional[torch.fx.Node] = None,
|
||||
wait_insertion_point: Optional[torch.fx.Node] = None,
|
||||
) -> None:
|
||||
) -> tuple[list[torch.fx.Node], dict[torch.fx.Node, torch.fx.Node]]:
|
||||
from torch.distributed.distributed_c10d import _resolve_process_group
|
||||
|
||||
ag0 = ag_nodes[0]
|
||||
@ -739,7 +743,7 @@ def merge_all_gather_bucket(
|
||||
rank,
|
||||
)
|
||||
|
||||
process_collective_bucket(
|
||||
return process_collective_bucket(
|
||||
g,
|
||||
ag_nodes,
|
||||
ag_merge_fn,
|
||||
|
256
torch/_inductor/fx_passes/overlap_preserving_bucketer.py
Normal file
256
torch/_inductor/fx_passes/overlap_preserving_bucketer.py
Normal file
@ -0,0 +1,256 @@
|
||||
from collections import defaultdict
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.fx as fx
|
||||
from torch._inductor.augmented_graph_helper import AugmentedGraphHelper
|
||||
from torch._inductor.fx_passes.bucketing import (
|
||||
_ag_group_key,
|
||||
_rs_group_key,
|
||||
is_all_gather_into_tensor as is_all_gather,
|
||||
is_reduce_scatter_tensor as is_reduce_scatter,
|
||||
is_wait_tensor,
|
||||
)
|
||||
from torch._inductor.fx_passes.overlap_scheduling import CollBucket, CollectiveInfo
|
||||
from torch.utils._ordered_set import OrderedSet
|
||||
|
||||
|
||||
def bucket_key(node: torch.fx.Node) -> Optional[object]:
|
||||
if is_all_gather(node):
|
||||
return _ag_group_key(node)
|
||||
elif is_reduce_scatter(node):
|
||||
return _rs_group_key(node)
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
class OverlapPreservingBucketer:
|
||||
"""
|
||||
Buckets collective operations while preserving compute-collective overlap relationships.
|
||||
Uses an augmented graph to track dependencies between compute and collective operations.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
graph: fx.Graph,
|
||||
collective_info: dict[fx.Node, CollectiveInfo],
|
||||
node_ancestors: dict[fx.Node, OrderedSet[fx.Node]],
|
||||
scheduled: OrderedSet[fx.Node],
|
||||
max_bucket_memory_gb: float = 1.0,
|
||||
):
|
||||
self.graph = graph
|
||||
self.collective_info = collective_info
|
||||
self.node_ancestors = node_ancestors
|
||||
self.scheduled = scheduled
|
||||
self.max_bucket_memory_gb = max_bucket_memory_gb
|
||||
self.node_idx = {n: i for i, n in enumerate(scheduled)}
|
||||
|
||||
def bucket_collectives(self) -> None:
|
||||
"""Main entry point for bucketing collectives."""
|
||||
|
||||
aug_graph = AugmentedGraphHelper(self.graph)
|
||||
|
||||
# Add extra dependencies for hidden collectives
|
||||
# For each hidden collective, add: compute -> start and wait -> compute
|
||||
for start_node, info in self.collective_info.items():
|
||||
if info.hiding_node and not info.is_exposed:
|
||||
# Add edge: hiding_compute depends on start (start must come before compute)
|
||||
aug_graph.add_extra_dep(n=info.hiding_node, dep=start_node)
|
||||
# Add edge: wait depends on hiding_compute (compute must come before wait)
|
||||
aug_graph.add_extra_dep(n=info.wait_node, dep=info.hiding_node)
|
||||
|
||||
# Group collectives by bucket key (type, group, etc.)
|
||||
grouped_collectives: dict[object, OrderedSet[fx.Node]] = defaultdict(OrderedSet)
|
||||
for start in self.collective_info:
|
||||
key = bucket_key(start)
|
||||
if key is not None:
|
||||
grouped_collectives[key].add(start)
|
||||
|
||||
all_buckets: list[CollBucket] = []
|
||||
for collective_group in grouped_collectives.values():
|
||||
buckets = self._find_buckets(collective_group, aug_graph)
|
||||
all_buckets.extend(buckets)
|
||||
|
||||
# Collect all extra dependencies to preserve after bucketing
|
||||
additional_deps: dict[fx.Node, OrderedSet[fx.Node]] = defaultdict(OrderedSet)
|
||||
|
||||
# Apply bucketing transformations
|
||||
for coll_bucket in all_buckets:
|
||||
if len(coll_bucket.collectives) <= 1:
|
||||
continue
|
||||
|
||||
bucket_deps = self._apply_bucket(coll_bucket)
|
||||
additional_deps.update(bucket_deps)
|
||||
|
||||
# Apply topological sort with all the collected dependencies
|
||||
from torch._dynamo.graph_deduplication import _stable_topological_sort
|
||||
|
||||
_stable_topological_sort(self.graph, additional_deps)
|
||||
self.graph.lint()
|
||||
|
||||
def _find_buckets(
|
||||
self,
|
||||
collective_group: OrderedSet[fx.Node],
|
||||
aug_graph: AugmentedGraphHelper,
|
||||
) -> list[CollBucket]:
|
||||
"""Find valid buckets within a group of similar collectives."""
|
||||
|
||||
max_bucket_bytes = int(self.max_bucket_memory_gb * 1024 * 1024 * 1024)
|
||||
buckets = []
|
||||
processed: OrderedSet[fx.Node] = OrderedSet()
|
||||
|
||||
for start_node in collective_group:
|
||||
if start_node in processed:
|
||||
continue
|
||||
|
||||
# Initialize bucket with first collective
|
||||
bucket_info = CollBucket(
|
||||
collectives=[start_node],
|
||||
total_bytes=self.collective_info[start_node].size_bytes,
|
||||
)
|
||||
processed.add(start_node)
|
||||
|
||||
# TODO - limit within range
|
||||
for candidate in collective_group:
|
||||
if candidate in processed:
|
||||
continue
|
||||
|
||||
candidate_bytes = self.collective_info[candidate].size_bytes
|
||||
if bucket_info.total_bytes + candidate_bytes > max_bucket_bytes:
|
||||
continue
|
||||
|
||||
if self._can_add_to_bucket(bucket_info, candidate, aug_graph):
|
||||
bucket_info.collectives.append(candidate)
|
||||
bucket_info.total_bytes += candidate_bytes
|
||||
processed.add(candidate)
|
||||
|
||||
if len(bucket_info.collectives) > 1:
|
||||
buckets.append(bucket_info)
|
||||
|
||||
return buckets
|
||||
|
||||
def _ancestor_dep(self, n1: fx.Node, n2: fx.Node) -> bool:
|
||||
"""Check if there's an ancestor relationship between two nodes."""
|
||||
return n1 in self.node_ancestors[n2] or n2 in self.node_ancestors[n1]
|
||||
|
||||
def _can_add_to_bucket(
|
||||
self,
|
||||
bucket_info: CollBucket,
|
||||
candidate: fx.Node,
|
||||
aug_graph: AugmentedGraphHelper,
|
||||
) -> bool:
|
||||
"""
|
||||
Check if candidate can be added to bucket without interfering
|
||||
with comm/compute overlap.
|
||||
"""
|
||||
|
||||
candidate_info = self.collective_info[candidate]
|
||||
candidate_wait = candidate_info.wait_node
|
||||
|
||||
# Step 1: Quick check using precomputed ancestors
|
||||
# This will not be fully up to date because bucketing changes ancestors,
|
||||
# however any ancestor at the start of bucketing will remain an ancestor.
|
||||
for coll in bucket_info.collectives:
|
||||
if self._ancestor_dep(coll, candidate):
|
||||
return False
|
||||
|
||||
coll_wait = self.collective_info[coll].wait_node
|
||||
if self._ancestor_dep(candidate_wait, coll_wait):
|
||||
return False
|
||||
|
||||
if hiding_node := self.collective_info[coll].hiding_node:
|
||||
if self._ancestor_dep(hiding_node, candidate_wait):
|
||||
return False
|
||||
|
||||
if new_hiding_node := candidate_info.hiding_node:
|
||||
if self._ancestor_dep(new_hiding_node, coll_wait):
|
||||
return False
|
||||
|
||||
# Step 2: Check and merge starts
|
||||
# Check if there's a path between any existing start and candidate start.
|
||||
# Because the collectives have already been merged, we can just start from one
|
||||
# of them.
|
||||
# TODO: we have a range of possible idxs of the merged node, and idx of new node.
|
||||
# we should not do path search beyond that range
|
||||
existing_coll = bucket_info.collectives[0]
|
||||
if aug_graph.has_path(existing_coll, candidate):
|
||||
return False
|
||||
if aug_graph.has_path(candidate, existing_coll):
|
||||
return False
|
||||
|
||||
# Safe to merge starts - do the merge
|
||||
aug_graph.merge_to_set(existing_coll, candidate)
|
||||
|
||||
# Step 3: Check and merge waits
|
||||
existing_wait = self.collective_info[existing_coll].wait_node
|
||||
candidate_wait = candidate_info.wait_node
|
||||
# TODO - as above, limit search by idx
|
||||
if aug_graph.has_path(existing_wait, candidate_wait) or aug_graph.has_path(
|
||||
candidate_wait, existing_wait
|
||||
):
|
||||
# Unmerge the start we just merged
|
||||
aug_graph.unmerge_node(candidate)
|
||||
return False
|
||||
|
||||
aug_graph.merge_to_set(existing_wait, candidate_wait)
|
||||
return True
|
||||
|
||||
def _apply_bucket(
|
||||
self, bucket_info: CollBucket
|
||||
) -> dict[fx.Node, OrderedSet[fx.Node]]:
|
||||
"""Apply bucketing transformation and return dependencies to preserve."""
|
||||
|
||||
from torch._inductor.fx_passes.bucketing import (
|
||||
merge_all_gather_bucket,
|
||||
merge_reduce_scatter_bucket,
|
||||
)
|
||||
|
||||
bucket = bucket_info.collectives
|
||||
|
||||
# Find where to place the bucketed operations
|
||||
next_node = bucket[0]
|
||||
while next_node in bucket:
|
||||
next_node = next_node.next
|
||||
waits = [self.collective_info[n].wait_node for n in bucket]
|
||||
first_wait = min(waits, key=lambda w: self.node_idx[w])
|
||||
|
||||
# Create bucketed collective
|
||||
if is_all_gather(bucket[0]):
|
||||
new_nodes, replacements = merge_all_gather_bucket(
|
||||
self.graph,
|
||||
bucket,
|
||||
wait_insertion_point=first_wait,
|
||||
insert_before=next_node,
|
||||
mode="custom_ops",
|
||||
)
|
||||
else:
|
||||
assert is_reduce_scatter(bucket[0])
|
||||
new_nodes, replacements = merge_reduce_scatter_bucket(
|
||||
self.graph,
|
||||
bucket,
|
||||
wait_insertion_point=first_wait,
|
||||
insert_before=next_node,
|
||||
mode="custom_ops",
|
||||
)
|
||||
|
||||
# Build dependencies to preserve overlap
|
||||
# replacements maps old_start -> new_start, old_wait -> new_wait
|
||||
new_waits = [n for n in new_nodes if is_wait_tensor(n)]
|
||||
assert len(new_waits) == 1
|
||||
|
||||
new_wait = new_waits[0]
|
||||
new_start = new_wait.args[0]
|
||||
assert isinstance(new_start, fx.Node)
|
||||
|
||||
overlap_deps: dict[fx.Node, OrderedSet[fx.Node]] = defaultdict(OrderedSet)
|
||||
|
||||
# Create dependencies to preserve overlap
|
||||
for coll in bucket:
|
||||
info = self.collective_info[coll]
|
||||
if info.hiding_node and not info.is_exposed:
|
||||
# Compute depends on collective start
|
||||
overlap_deps[info.hiding_node].add(new_start)
|
||||
# Wait depends on compute
|
||||
overlap_deps[new_wait].add(info.hiding_node)
|
||||
|
||||
return overlap_deps
|
@ -11,6 +11,7 @@ from typing import Any, Callable, Optional, Union
|
||||
import torch
|
||||
import torch.fx as fx
|
||||
from torch._dynamo.utils import counters, dynamo_timed
|
||||
from torch._inductor.fx_passes.bucketing import is_wait_tensor
|
||||
from torch.utils._mode_utils import no_dispatch
|
||||
from torch.utils._ordered_set import OrderedSet
|
||||
|
||||
@ -20,13 +21,6 @@ log = logging.getLogger(__name__)
|
||||
from ..pattern_matcher import stable_topological_sort
|
||||
|
||||
|
||||
def is_wait_tensor(node: torch.fx.Node) -> bool:
|
||||
return (
|
||||
node.op == "call_function"
|
||||
and node.target == torch.ops._c10d_functional.wait_tensor.default
|
||||
)
|
||||
|
||||
|
||||
def get_custom_estimation(n: fx.Node) -> Optional[float]:
|
||||
runtime_estimation = torch._inductor.config.test_configs.estimate_aten_runtime
|
||||
if runtime_estimation == "default":
|
||||
@ -156,6 +150,16 @@ class CollectiveInfo:
|
||||
return self.exposed_time_ms != 0
|
||||
|
||||
|
||||
@dataclass
|
||||
class CollBucket:
|
||||
"""Track information about a bucket of collectives."""
|
||||
|
||||
collectives: list[fx.Node] # Original collective starts
|
||||
bucketed_start: Optional[fx.Node] = None # After bucketing
|
||||
bucketed_wait: Optional[fx.Node] = None # After bucketing
|
||||
total_bytes: int = 0
|
||||
|
||||
|
||||
class OverlapScheduler:
|
||||
"""
|
||||
Scheduler that reorders operations to maximize compute-collective overlap.
|
||||
@ -184,7 +188,7 @@ class OverlapScheduler:
|
||||
self,
|
||||
gm: torch.fx.GraphModule,
|
||||
max_in_flight_gb: float = 2.0,
|
||||
compute_overlap_multipler: float = 1.0,
|
||||
compute_overlap_multipler: float = 2.0,
|
||||
max_coll_distance: int = 1000,
|
||||
):
|
||||
self.gm = gm
|
||||
@ -325,6 +329,8 @@ class OverlapScheduler:
|
||||
self._handle_other(node)
|
||||
|
||||
self._reorder_graph()
|
||||
if torch._inductor.config.test_configs.aten_fx_overlap_preserving_bucketing:
|
||||
self._bucket_collectives()
|
||||
return self.gm
|
||||
|
||||
def _handle_other(self, node: fx.Node) -> None:
|
||||
@ -583,6 +589,20 @@ class OverlapScheduler:
|
||||
|
||||
self.reorder_graph()
|
||||
|
||||
def _bucket_collectives(self) -> None:
|
||||
from torch._inductor.fx_passes.overlap_preserving_bucketer import (
|
||||
OverlapPreservingBucketer,
|
||||
)
|
||||
|
||||
bucketer = OverlapPreservingBucketer(
|
||||
graph=self.graph,
|
||||
collective_info=self.collective_info,
|
||||
node_ancestors=self.node_ancestors,
|
||||
scheduled=self.scheduled,
|
||||
max_bucket_memory_gb=1.0, # Could make this configurable
|
||||
)
|
||||
bucketer.bucket_collectives()
|
||||
|
||||
def compute_potential_hidden_nodes(
|
||||
self, nodes_to_check: Iterable[fx.Node], limit_coll_per_compute: bool = False
|
||||
) -> dict[fx.Node, fx.Node]:
|
||||
|
Reference in New Issue
Block a user