diff --git a/test/distributed/test_inductor_collectives.py b/test/distributed/test_inductor_collectives.py
index f7cf7764df56..8073b36f9ca3 100644
--- a/test/distributed/test_inductor_collectives.py
+++ b/test/distributed/test_inductor_collectives.py
@@ -1580,14 +1580,65 @@ class TestCollectivesInductor(DynamoDistributedSingleProcTestCase):
         # We want to make sure no unnecessary copy is made.
         (
             FileCheck()
-            .check("= torch.ops._c10d_functional.all_gather_into_tensor")
-            .check("torch.ops._c10d_functional.all_gather_into_tensor_out.default(")
-            .check("= torch.ops._c10d_functional.all_gather_into_tensor")
+            .check_count(".all_gather_into_tensor_out.default(", 2, exactly=True)
             .run(code)
         )
         out = compiled(*inputs, **self.get_world_trs())
         assert same(out, correct), f"{out} va {correct}"
 
+    @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
+    @unittest.skipIf(not SM80OrLater, "bfloat16")
+    def test_all_gather_bucket_path(self):
+        def func(x, w, ag_0, ag_1, *, tag, ranks, group_size):
+            # do an unrelated matmul
+            y = torch.mm(x, w)
+
+            # cast the inputs
+            ag_0_cast = ag_0.to(torch.bfloat16)
+            ag_1_cast = ag_1.to(torch.bfloat16)
+
+            # first all_gather
+            group_name = (
+                torch.distributed.distributed_c10d._get_default_group().group_name
+            )
+            ag_0_out = torch.ops._c10d_functional.all_gather_into_tensor(
+                ag_0_cast, group_size, group_name
+            )
+            ag_0_out = torch.ops.c10d_functional.wait_tensor(ag_0_out)
+            ag_0_out = ag_0_out * 2
+
+            # Create a dependency: the second all_gather's input depends on the
+            # first all_gather's output, which prevents bucketing the two together.
+            ag_1_modified = (
+                ag_1_cast + ag_0_out[: ag_1_cast.shape[0]]
+            )  # use part of ag_0_out
+
+            # second all_gather (now depends on the first one)
+            ag_1_out = torch.ops._c10d_functional.all_gather_into_tensor(
+                ag_1_modified, group_size, group_name
+            )
+            ag_1_out = torch.ops.c10d_functional.wait_tensor(ag_1_out)
+
+            return y, ag_0_out, ag_1_out
+
+        x = torch.ones(4, 384, device="cuda", dtype=torch.float32)
+        w = torch.ones(384, 512, device="cuda", dtype=torch.float32)
+        ag_0 = torch.ones(384, 512, device="cuda", dtype=torch.float32)
+        ag_1 = torch.ones(384, 512, device="cuda", dtype=torch.float32)
+        inputs = [x, w, ag_0, ag_1]
+
+        with torch._inductor.config.patch(
+            {
+                "bucket_all_gathers_fx": "all",
+                "reorder_for_compute_comm_overlap": False,
+            }
+        ):
+            compiled = torch.compile(func)
+            code = run_and_get_triton_code(compiled, *inputs, **self.get_world_trs())
+
+        # shouldn't have bucketed: the second all_gather depends on the first
+        FileCheck().check_count("wait_tensor.default(", 2, exactly=True).run(code)
+
     @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
     @unittest.skipIf(not SM80OrLater, "bfloat16")
     def test_reduce_scatter_bucket(self):
diff --git a/torch/_inductor/fx_passes/bucketing.py b/torch/_inductor/fx_passes/bucketing.py
index 3bf1ff9dab86..1b35cf324f5f 100644
--- a/torch/_inductor/fx_passes/bucketing.py
+++ b/torch/_inductor/fx_passes/bucketing.py
@@ -1,5 +1,5 @@
+import collections
 import logging
-from collections import defaultdict
 from typing import Any, Callable, Optional
 
 import torch
@@ -42,6 +42,7 @@ def bucket_all_gather(
     ag_buckets = bucket_all_gather_by_mb(gm, bucket_cap_mb_by_bucket_idx)
     if len(ag_buckets) == 0:
         return
+
     merge_all_gather(gm, ag_buckets)
 
 
@@ -86,6 +87,42 @@ def is_wait_tensor_from_all_gather_into_tensor(node: torch.fx.Node) -> bool:
     return is_wait_tensor(node) and is_all_gather_into_tensor(node.args[0])  # type: ignore[arg-type]
 
 
+def collect_node_descendants(
+    graph: torch.fx.Graph,
+) -> dict[torch.fx.Node, OrderedSet[torch.fx.Node]]:
+    """
+    Collects the descendants of each node in the graph.
+    Args:
+        graph (torch.fx.Graph): The graph to collect descendants from.
+    Returns:
+        dict[torch.fx.Node, OrderedSet[torch.fx.Node]]: A dictionary mapping each node to its descendants.
+    """
+    node_descendants: dict[torch.fx.Node, OrderedSet[torch.fx.Node]] = (
+        collections.defaultdict(OrderedSet)
+    )
+    outdegree = collections.defaultdict(int)
+    queue: list[torch.fx.Node] = []
+
+    for node in graph.nodes:
+        n_outdegree = len(node.users)
+        if n_outdegree == 0:
+            queue.append(node)
+        else:
+            outdegree[node] = n_outdegree
+
+    while queue:
+        node = queue.pop()
+        for input_node in node.all_input_nodes:
+            node_descendants[input_node] |= node_descendants[node]
+            node_descendants[input_node].add(node)
+            outdegree[input_node] -= 1
+
+            if outdegree[input_node] == 0:
+                queue.append(input_node)
+
+    return node_descendants
+
+
 def greedy_bucket_collective_by_mb(
     gm: torch.fx.GraphModule,
     bucket_cap_mb_by_bucket_idx: Callable[[int], float],
@@ -93,59 +130,38 @@ def greedy_bucket_collective_by_mb(
     node_group_key: Callable[[torch.fx.Node], Any],
     filter_wait_node: Optional[Callable[[torch.fx.Node], bool]] = None,
 ) -> list[list[torch.fx.Node]]:
-    """
-    Bucketing adjacent collectives with equal node_group_key.
-    We can not bucket non adjacent collectives,
-    as this will effectively change the order of collectives.
-    Reordering can lead to different order on different ranks.
-    """
-    g = gm.graph
-    found_candidates = False
-    for node in g.nodes:
-        if filter_node(node):
-            found_candidates = True
-            break
-    if not found_candidates:
+    if not gm.graph.find_nodes(
+        op="call_function", target=torch.ops._c10d_functional.wait_tensor.default
+    ):
         return []
 
-    nodes_successors: dict[torch.fx.Node, OrderedSet[torch.fx.Node]] = defaultdict(
-        OrderedSet
-    )
-    nodes_groups: list[list[torch.fx.Node]] = []
-    cur_group: list[torch.fx.Node] = []
-    cur_group_key = None
+    g = gm.graph
+
+    # TODO: use the Pearce-Kelly algorithm for incremental cycle detection
+    node_descendants = collect_node_descendants(gm.graph)
+
+    node_groups: dict[Any, list[torch.fx.Node]] = collections.defaultdict(list)
 
     for node in g.nodes:
-        for n, successors in nodes_successors.items():
-            if any(arg in successors for arg in node.args):
-                successors.add(n)
         if is_wait_tensor(node) and filter_node(node.args[0]):
             if (filter_wait_node is None) or filter_wait_node(node):
                 coll_node = node.args[0]
                 group_key = node_group_key(coll_node)
-                if group_key == cur_group_key:
-                    cur_group.append(coll_node)
-                else:
-                    if len(cur_group) > 1:
-                        nodes_groups.append(cur_group)
-                    cur_group = [coll_node]
-                    cur_group_key = group_key
-
-    if len(cur_group) > 1:
-        nodes_groups.append(cur_group)
+                node_groups[group_key].append(coll_node)
 
     buckets: list[list[torch.fx.Node]] = []
-    for nodes in nodes_groups:
+
+    for nodes in node_groups.values():
         cur_bucket: list[torch.fx.Node] = []
-        cur_bucket_successors: OrderedSet[torch.fx.Node] = OrderedSet()
+        cur_bucket_descendants: OrderedSet[torch.fx.Node] = OrderedSet()
         cur_bucket_size_bytes: int = 0
         cur_bucket_id: int = 0
         bucket_size_bytes = int(
             bucket_cap_mb_by_bucket_idx(cur_bucket_id) * 1024 * 1024
         )
         for node in nodes:
-            if node in cur_bucket_successors:
-                # We cannot bucket successors with the node
+            if node in cur_bucket_descendants:
+                # node is a descendant of the current bucket; bucketing (horizontal fusion) would create a cycle
                 continue
             assert "val" in node.meta
             n_val = node.meta["val"]
@@ -160,10 +176,10 @@ def greedy_bucket_collective_by_mb(
                 cur_bucket = []
                 cur_bucket_size_bytes = 0
                 cur_bucket_id += 1
-                cur_bucket_successors = OrderedSet()
+                cur_bucket_descendants = OrderedSet()
             cur_bucket_size_bytes += size_bytes
             cur_bucket.append(node)
-            cur_bucket_successors |= nodes_successors[node]
+            cur_bucket_descendants |= node_descendants[node]
         if len(cur_bucket) > 1:
             buckets.append(cur_bucket)
     return buckets
@@ -259,6 +275,8 @@ def reduce_scatter_merge_fn_to_trace(
 
     new_rs_in = torch.cat(rs_ins_flattened, dim=1).flatten()
 
+    # TODO: either use torch.cat or make sure Inductor's foreach codegen
+    # fires more reliably
     new_rs_out = torch.ops.c10d_functional.wait_tensor(
         torch.ops._c10d_functional.reduce_scatter_tensor.default(
             new_rs_in, reduce_op, group_size, group_name
@@ -347,7 +365,13 @@ def _trace(fn, inps) -> torch.fx.GraphModule:  # type: ignore[no-untyped-def]
     fake_mode = detect_fake_mode(inps)
     assert fake_mode is not None
    with fake_mode, enable_python_dispatcher():
-        return make_fx(fn)(*inps)
+        out = make_fx(fn)(*inps)
+        for node in out.graph.find_nodes(
+            op="call_function", target=torch.ops.aten.detach.default
+        ):
+            node.replace_all_uses_with(node.args[0])
+            out.graph.erase_node(node)
+        return out
 
 
 def _insert_fn_trace_before_node(  # type: ignore[no-untyped-def]
@@ -488,8 +512,6 @@ def merge_all_gather(
     )
     n_buckets = len(ag_buckets)
 
-    ag_node_to_pre_nodes = defaultdict(list)
-
     ag_ins: list[list[torch.fx.Node]] = [[] for _ in range(n_buckets)]
     ag_waits: list[list[torch.fx.Node]] = [[] for _ in range(n_buckets)]
     for bucket_idx, ag_bucket in enumerate(ag_buckets):
@@ -508,13 +530,6 @@ def merge_all_gather(
                 and ag_node.meta["val"].dtype == dtype
             )
             ag_node_in = ag_node.args[0]
-            if (
-                ag_node_in.op == "call_function"  # type: ignore[union-attr]
-                and ag_node_in.target == torch.ops.prims.convert_element_type.default  # type: ignore[union-attr]
-                and len(ag_node_in.users) == 1  # type: ignore[union-attr]
-            ):
-                ag_node_to_pre_nodes[ag_node].append(ag_node_in)
-                ag_node_in = ag_node_in.args[0]  # type: ignore[union-attr]
             ag_ins[bucket_idx].append(ag_node_in)  # type: ignore[union-attr, arg-type]
             ag_waits[bucket_idx].append(wait_node)
 
@@ -560,5 +575,3 @@ def merge_all_gather(
         for ag_n, wait_n in zip(ag_buckets[bucket_idx], _ag_waits):
             g.erase_node(wait_n)
             g.erase_node(ag_n)
-            for n in reversed(ag_node_to_pre_nodes[ag_n]):
-                g.erase_node(n)  # type: ignore[arg-type]
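
Note on the approach: `collect_node_descendants` computes, for every FX node, the set of nodes reachable from it, using a reverse Kahn-style topological sweep that starts at sink nodes (no users) and propagates descendant sets back along input edges. `greedy_bucket_collective_by_mb` then skips any collective that is already a descendant of the current bucket, since merging it with its own ancestor would create a cycle. A minimal standalone sketch of the helper on a toy graph (illustration only, not part of the diff; it assumes the import path this diff adds to `torch/_inductor/fx_passes/bucketing.py`):

    import torch
    import torch.fx
    from torch._inductor.fx_passes.bucketing import collect_node_descendants

    def f(x):
        a = x + 1  # traced as "add"
        b = a * 2  # traced as "mul", a descendant of "add"
        c = b - 3  # traced as "sub", a descendant of "add" and "mul"
        return c

    gm = torch.fx.symbolic_trace(f)
    descendants = collect_node_descendants(gm.graph)
    for node in gm.graph.nodes:
        print(node.name, "->", [d.name for d in descendants[node]])
    # "add" maps to {mul, sub, output}: bucketing "add" together with any node
    # in that set would create a cycle, so the pass skips such candidates.

This is exactly the situation the new `test_all_gather_bucket_path` exercises: the second all_gather's input is derived from the first all_gather's output, so the pass must leave the two unbucketed, and the generated code keeps exactly two `wait_tensor.default` calls.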
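
The `_trace` change strips `aten.detach.default` nodes from the `make_fx` output before the traced merge function is spliced into the main graph. A standalone sketch of the same cleanup, using the same FX calls as the new `_trace` body (the toy `fn` is an assumption for illustration; whether `make_fx` emits a detach node for a given function can depend on the tracing mode):

    import torch
    from torch.fx.experimental.proxy_tensor import make_fx

    def fn(x):
        return x.detach() + 1

    gm = make_fx(fn)(torch.randn(4))
    # Re-point every user of a detach node at the detach's input,
    # then erase the now-dead node.
    for node in gm.graph.find_nodes(
        op="call_function", target=torch.ops.aten.detach.default
    ):
        node.replace_all_uses_with(node.args[0])
        gm.graph.erase_node(node)
    gm.recompile()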