Fix bucketing introducing cycles (#160967)

When deciding whether collectives could safely be bucketed together, we were only looking at direct arguments, not transitive dependencies, so fusing could introduce a cycle into the graph.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/160967
Approved by: https://github.com/IvanKobzarev
Committed by: PyTorch MergeBot
Parent: dbef606631
Commit: b708966201
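Why checking only direct arguments is not enough: if one collective's input depends on another collective's output through intermediate ops, fusing the two into a single bucketed node makes that node both an ancestor and a consumer of the intermediate chain, i.e. a cycle. Below is a minimal, self-contained sketch of the difference between the two checks (plain Python with hypothetical node names, not the actual pass):

    # Toy dependency graph: each node maps to its direct inputs. ag0/ag1 are
    # two all_gathers we would like to bucket; ag1's input "add" depends on
    # ag0's output only transitively, via wait0 -> mul.
    edges = {
        "x0": [],
        "x1": [],
        "ag0": ["x0"],
        "wait0": ["ag0"],
        "mul": ["wait0"],
        "add": ["x1", "mul"],
        "ag1": ["add"],
        "wait1": ["ag1"],
    }

    def depends_directly(node, target):
        # The old style of check: only direct arguments are inspected.
        return target in edges[node]

    def depends_transitively(node, target):
        # The fixed check: walk the whole chain of transitive inputs.
        stack, seen = list(edges[node]), set()
        while stack:
            n = stack.pop()
            if n == target:
                return True
            if n not in seen:
                seen.add(n)
                stack.extend(edges[n])
        return False

    print(depends_directly("ag1", "ag0"))      # False -> looks safe to bucket
    print(depends_transitively("ag1", "ag0"))  # True  -> bucketing would fuse a
    # node that both feeds "add" (via wait0 -> mul) and consumes it: a cycle.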
@@ -1580,14 +1580,65 @@ class TestCollectivesInductor(DynamoDistributedSingleProcTestCase):
         # We want to make sure no unnecessary copy is made.
         (
             FileCheck()
-            .check("= torch.ops._c10d_functional.all_gather_into_tensor")
+            .check("torch.ops._c10d_functional.all_gather_into_tensor_out.default(")
             .check("= torch.ops._c10d_functional.all_gather_into_tensor")
+            .check_count(".all_gather_into_tensor_out.default(", 2, exactly=True)
             .run(code)
         )
         out = compiled(*inputs, **self.get_world_trs())
         assert same(out, correct), f"{out} va {correct}"
 
+    @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
+    @unittest.skipIf(not SM80OrLater, "bfloat16")
+    def test_all_gather_bucket_path(self):
+        def func(x, w, ag_0, ag_1, *, tag, ranks, group_size):
+            # do some unrelated matmuls
+            y = torch.mm(x, w)
+
+            # cast the inputs
+            ag_0_cast = ag_0.to(torch.bfloat16)
+            ag_1_cast = ag_1.to(torch.bfloat16)
+
+            # first allgather
+            group_name = (
+                torch.distributed.distributed_c10d._get_default_group().group_name
+            )
+            ag_0_out = torch.ops._c10d_functional.all_gather_into_tensor(
+                ag_0_cast, group_size, group_name
+            )
+            ag_0_out = torch.ops.c10d_functional.wait_tensor(ag_0_out)
+            ag_0_out = ag_0_out * 2
+
+            # Create dependency: second allgather input depends on first allgather output
+            # This prevents fusion of the two allgather operations
+            ag_1_modified = (
+                ag_1_cast + ag_0_out[: ag_1_cast.shape[0]]
+            )  # Use part of ag_0_out
+
+            # second allgather (now depends on the first one)
+            ag_1_out = torch.ops._c10d_functional.all_gather_into_tensor(
+                ag_1_modified, group_size, group_name
+            )
+            ag_1_out = torch.ops.c10d_functional.wait_tensor(ag_1_out)
+
+            return y, ag_0_out, ag_1_out
+
+        x = torch.ones(4, 384, device="cuda", dtype=torch.float32)
+        w = torch.ones(384, 512, device="cuda", dtype=torch.float32)
+        ag_0 = torch.ones(384, 512, device="cuda", dtype=torch.float32)
+        ag_1 = torch.ones(384, 512, device="cuda", dtype=torch.float32)
+        inputs = [x, w, ag_0, ag_1]
+
+        with torch._inductor.config.patch(
+            {
+                "bucket_all_gathers_fx": "all",
+                "reorder_for_compute_comm_overlap": False,
+            }
+        ):
+            compiled = torch.compile(func)
+            code = run_and_get_triton_code(compiled, *inputs, **self.get_world_trs())
+
+        # shouldnt have bucketed
+        FileCheck().check_count("wait_tensor.default(", 2, exactly=True).run(code)
+
     @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
     @unittest.skipIf(not SM80OrLater, "bfloat16")
     def test_reduce_scatter_bucket(self):
@@ -1,5 +1,5 @@
+import collections
 import logging
-from collections import defaultdict
 from typing import Any, Callable, Optional
 
 import torch
@@ -42,6 +42,7 @@ def bucket_all_gather(
     ag_buckets = bucket_all_gather_by_mb(gm, bucket_cap_mb_by_bucket_idx)
     if len(ag_buckets) == 0:
         return
+
     merge_all_gather(gm, ag_buckets)
 
 
@@ -86,6 +87,42 @@ def is_wait_tensor_from_all_gather_into_tensor(node: torch.fx.Node) -> bool:
     return is_wait_tensor(node) and is_all_gather_into_tensor(node.args[0])  # type: ignore[arg-type]
 
 
+def collect_node_descendants(
+    graph: torch.fx.Graph,
+) -> dict[torch.fx.Node, OrderedSet[torch.fx.Node]]:
+    """
+    Collects the descendants of each node in the graph.
+    Args:
+        graph (torch.fx.Graph): The graph to collect descendants from.
+    Returns:
+        dict[torch.fx.Node, OrderedSet[torch.fx.Node]]: A dictionary mapping each node to its descendants.
+    """
+    node_descendants: dict[torch.fx.Node, OrderedSet[torch.fx.Node]] = (
+        collections.defaultdict(OrderedSet)
+    )
+    outdegree = collections.defaultdict(int)
+    queue = []
+
+    for node in graph.nodes:
+        n_outdegree = len(node.users)
+        if n_outdegree == 0:
+            queue.append(node)
+        else:
+            outdegree[node] = len(node.users)
+
+    while queue:
+        node = queue.pop()
+        for input_node in node.all_input_nodes:
+            node_descendants[input_node] |= node_descendants[node]
+            node_descendants[input_node].add(node)
+            outdegree[input_node] -= 1
+
+            if outdegree[input_node] == 0:
+                queue.append(input_node)
+
+    return node_descendants
+
+
 def greedy_bucket_collective_by_mb(
     gm: torch.fx.GraphModule,
     bucket_cap_mb_by_bucket_idx: Callable[[int], float],
@@ -93,59 +130,38 @@ def greedy_bucket_collective_by_mb(
     node_group_key: Callable[[torch.fx.Node], Any],
     filter_wait_node: Optional[Callable[[torch.fx.Node], bool]] = None,
 ) -> list[list[torch.fx.Node]]:
     """
-    Bucketing adjacent collectives with equal node_group_key.
-    We can not bucket non adjacent collectives,
-    as this will effectively change the order of collectives.
-    Reordering can lead to different order on different ranks.
     """
-    g = gm.graph
-    found_candidates = False
-    for node in g.nodes:
-        if filter_node(node):
-            found_candidates = True
-            break
-    if not found_candidates:
+    if not gm.graph.find_nodes(
+        op="call_function", target=torch.ops._c10d_functional.wait_tensor.default
+    ):
         return []
 
-    nodes_successors: dict[torch.fx.Node, OrderedSet[torch.fx.Node]] = defaultdict(
-        OrderedSet
-    )
-    nodes_groups: list[list[torch.fx.Node]] = []
-    cur_group: list[torch.fx.Node] = []
-    cur_group_key = None
+    g = gm.graph
+
+    # TODO: pearce kelly algorithm for detecting cycles
+    node_descendents = collect_node_descendants(gm.graph)
+
+    node_groups: dict[Any, list[torch.fx.Node]] = collections.defaultdict(list)
 
     for node in g.nodes:
-        for n, successors in nodes_successors.items():
-            if any(arg in successors for arg in node.args):
-                successors.add(n)
         if is_wait_tensor(node) and filter_node(node.args[0]):
             if (filter_wait_node is None) or filter_wait_node(node):
                 coll_node = node.args[0]
                 group_key = node_group_key(coll_node)
-                if group_key == cur_group_key:
-                    cur_group.append(coll_node)
-                else:
-                    if len(cur_group) > 1:
-                        nodes_groups.append(cur_group)
-                    cur_group = [coll_node]
-                    cur_group_key = group_key
-
-    if len(cur_group) > 1:
-        nodes_groups.append(cur_group)
+                node_groups[group_key].append(coll_node)
 
     buckets: list[list[torch.fx.Node]] = []
-    for nodes in nodes_groups:
+    for nodes in node_groups.values():
         cur_bucket: list[torch.fx.Node] = []
-        cur_bucket_successors: OrderedSet[torch.fx.Node] = OrderedSet()
+        cur_bucket_descendents: OrderedSet[torch.fx.Node] = OrderedSet()
         cur_bucket_size_bytes: int = 0
         cur_bucket_id: int = 0
         bucket_size_bytes = int(
             bucket_cap_mb_by_bucket_idx(cur_bucket_id) * 1024 * 1024
         )
         for node in nodes:
-            if node in cur_bucket_successors:
-                # We cannot bucket successors with the node
+            if node in cur_bucket_descendents:
+                # if there is a path from node to the current bucket, we cannot horizontally fuse (bucket)
                 continue
             assert "val" in node.meta
             n_val = node.meta["val"]
@@ -160,10 +176,10 @@ def greedy_bucket_collective_by_mb(
                 cur_bucket = []
                 cur_bucket_size_bytes = 0
                 cur_bucket_id += 1
-                cur_bucket_successors = OrderedSet()
+                cur_bucket_descendents = OrderedSet()
             cur_bucket_size_bytes += size_bytes
             cur_bucket.append(node)
-            cur_bucket_successors |= nodes_successors[node]
+            cur_bucket_descendents |= node_descendents[node]
         if len(cur_bucket) > 1:
             buckets.append(cur_bucket)
     return buckets
@@ -259,6 +275,8 @@ def reduce_scatter_merge_fn_to_trace(
 
     new_rs_in = torch.cat(rs_ins_flattened, dim=1).flatten()
 
+    # TODO - either use torch.cat or make sure inductor foreach codegen
+    # fires more reliably
     new_rs_out = torch.ops.c10d_functional.wait_tensor(
         torch.ops._c10d_functional.reduce_scatter_tensor.default(
             new_rs_in, reduce_op, group_size, group_name
@@ -347,7 +365,13 @@ def _trace(fn, inps) -> torch.fx.GraphModule:  # type: ignore[no-untyped-def]
     fake_mode = detect_fake_mode(inps)
     assert fake_mode is not None
     with fake_mode, enable_python_dispatcher():
-        return make_fx(fn)(*inps)
+        out = make_fx(fn)(*inps)
+        for node in out.graph.find_nodes(
+            op="call_function", target=torch.ops.aten.detach.default
+        ):
+            node.replace_all_uses_with(node.args[0])
+            out.graph.erase_node(node)
+        return out
 
 
 def _insert_fn_trace_before_node(  # type: ignore[no-untyped-def]
@@ -488,8 +512,6 @@ def merge_all_gather(
     )
     n_buckets = len(ag_buckets)
 
-    ag_node_to_pre_nodes = defaultdict(list)
-
     ag_ins: list[list[torch.fx.Node]] = [[] for _ in range(n_buckets)]
     ag_waits: list[list[torch.fx.Node]] = [[] for _ in range(n_buckets)]
     for bucket_idx, ag_bucket in enumerate(ag_buckets):
@@ -508,13 +530,6 @@ def merge_all_gather(
                 and ag_node.meta["val"].dtype == dtype
             )
             ag_node_in = ag_node.args[0]
-            if (
-                ag_node_in.op == "call_function"  # type: ignore[union-attr]
-                and ag_node_in.target == torch.ops.prims.convert_element_type.default  # type: ignore[union-attr]
-                and len(ag_node_in.users) == 1  # type: ignore[union-attr]
-            ):
-                ag_node_to_pre_nodes[ag_node].append(ag_node_in)
-                ag_node_in = ag_node_in.args[0]  # type: ignore[union-attr]
 
             ag_ins[bucket_idx].append(ag_node_in)  # type: ignore[union-attr, arg-type]
             ag_waits[bucket_idx].append(wait_node)
@@ -560,5 +575,3 @@ def merge_all_gather(
         for ag_n, wait_n in zip(ag_buckets[bucket_idx], _ag_waits):
             g.erase_node(wait_n)
             g.erase_node(ag_n)
-            for n in reversed(ag_node_to_pre_nodes[ag_n]):
-                g.erase_node(n)  # type: ignore[arg-type]
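The new collect_node_descendants helper computes, for every node, the full set of its transitive users with a single reverse-topological walk; the bucketing loop then consults these sets instead of the old incremental successor bookkeeping. A small usage sketch on a traced graph (the import path is an assumption based on where the pass appears to live; it is not confirmed by this page):

    import torch
    from torch.fx import symbolic_trace

    # Assumed import path, for illustration only.
    from torch._inductor.fx_passes.bucketing import collect_node_descendants

    def f(x, y):
        a = x + y
        b = a * 2
        return b + x

    gm = symbolic_trace(f)
    descendants = collect_node_descendants(gm.graph)
    for node in gm.graph.nodes:
        print(node.name, "->", [d.name for d in descendants[node]])
    # The placeholder "x" lists add, mul, add_1, and output: every transitive
    # user, not just its direct ones.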