diff --git a/test/distributed/test_inductor_collectives.py b/test/distributed/test_inductor_collectives.py index a69628354e84..ca729fd50b0a 100644 --- a/test/distributed/test_inductor_collectives.py +++ b/test/distributed/test_inductor_collectives.py @@ -1528,7 +1528,8 @@ class TestCollectivesInductor(DynamoDistributedSingleProcTestCase): @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") @unittest.skipIf(not SM80OrLater, "bfloat16") - def test_all_gather_bucket(self): + @parametrize("bucket_mode", ["all", "all_custom_ops"]) + def test_all_gather_bucket(self, bucket_mode): def func(x, w, ag_0, ag_1, ag_2, ag_3, *, tag, ranks, group_size): # do some unrelated matmuls y = torch.mm(x, w) @@ -1576,7 +1577,7 @@ class TestCollectivesInductor(DynamoDistributedSingleProcTestCase): with ( torch._inductor.config.patch( { - "bucket_all_gathers_fx": "all", + "bucket_all_gathers_fx": bucket_mode, "reorder_for_compute_comm_overlap": False, "runtime_estimations_mms_benchmark": True, } @@ -1595,7 +1596,9 @@ class TestCollectivesInductor(DynamoDistributedSingleProcTestCase): # We want to make sure no unnecessary copy is made. ( FileCheck() - .check_count(".all_gather_into_tensor_out.default(", 2, exactly=True) + .check("= torch.ops._c10d_functional.all_gather_into_tensor") + .check("torch.ops._c10d_functional.all_gather_into_tensor_out.default(") + .check("= torch.ops._c10d_functional.all_gather_into_tensor") .run(code) ) out = compiled(*inputs, **self.get_world_trs()) @@ -1656,7 +1659,8 @@ class TestCollectivesInductor(DynamoDistributedSingleProcTestCase): @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") @unittest.skipIf(not SM80OrLater, "bfloat16") - def test_reduce_scatter_bucket(self): + @parametrize("bucket_mode", ["all", "all_custom_ops"]) + def test_reduce_scatter_bucket(self, bucket_mode): def func(x, w, rs_0, rs_1, tag, ranks, group_size): # do some unrelated matmuls y = torch.mm(x, w) @@ -1697,7 +1701,7 @@ class TestCollectivesInductor(DynamoDistributedSingleProcTestCase): with torch._inductor.config.patch( { - "bucket_reduce_scatters_fx": "fsdp", + "bucket_reduce_scatters_fx": bucket_mode, "reorder_for_compute_comm_overlap": False, } ): @@ -1723,7 +1727,8 @@ class TestCollectivesInductor(DynamoDistributedSingleProcTestCase): @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") @unittest.skipIf(not SM80OrLater, "bfloat16") - def test_reorder_peak_memory_bucketed(self): + @parametrize("bucket_mode", ["all", "all_custom_ops"]) + def test_reorder_peak_memory_bucketed(self, bucket_mode): """ Simulate the case where a bucketing pass ran and grouped several inputs into one bucketed allgather. Ensure the whole bucketed group including copy-ops get moved together rather than the copy ops preventing the @@ -1837,9 +1842,9 @@ class TestCollectivesInductor(DynamoDistributedSingleProcTestCase): with ( torch._inductor.config.patch( { - "bucket_all_gathers_fx": "all", + "bucket_all_gathers_fx": bucket_mode, "bucket_all_gathers_fx_bucket_size_determinator": lambda _: 2, - "bucket_reduce_scatters_fx": "all", + "bucket_reduce_scatters_fx": bucket_mode, "bucket_reduce_scatters_fx_bucket_size_determinator": lambda _: 2, "reorder_for_compute_comm_overlap": True, "reorder_for_compute_comm_overlap_passes": [ diff --git a/torch/_inductor/fx_passes/bucketing.py b/torch/_inductor/fx_passes/bucketing.py index 1c4c5f6c3f73..bf16454157b3 100644 --- a/torch/_inductor/fx_passes/bucketing.py +++ b/torch/_inductor/fx_passes/bucketing.py @@ -1,5 +1,6 @@ import collections import logging +from collections import defaultdict from typing import Any, Callable, Optional import torch @@ -33,6 +34,7 @@ def bucket_cap_mb_by_bucket_idx_default(bucket_id: int) -> float: def bucket_all_gather( gm: torch.fx.GraphModule, bucket_cap_mb_by_bucket_idx: Optional[Callable[[int], float]] = None, + mode: Optional[str] = None, ) -> None: if bucket_cap_mb_by_bucket_idx is None: from torch._inductor.fx_passes.bucketing import ( @@ -43,13 +45,13 @@ def bucket_all_gather( ag_buckets = bucket_all_gather_by_mb(gm, bucket_cap_mb_by_bucket_idx) if len(ag_buckets) == 0: return - - merge_all_gather(gm, ag_buckets) + merge_all_gather(gm, ag_buckets, mode) def bucket_reduce_scatter( gm: torch.fx.GraphModule, bucket_cap_mb_by_bucket_idx: Optional[Callable[[int], float]] = None, + mode: Optional[str] = None, ) -> None: if bucket_cap_mb_by_bucket_idx is None: from torch._inductor.fx_passes.bucketing import ( @@ -60,7 +62,7 @@ def bucket_reduce_scatter( rs_buckets = bucket_reduce_scatter_by_mb(gm, bucket_cap_mb_by_bucket_idx) if len(rs_buckets) == 0: return - merge_reduce_scatter(gm, rs_buckets) + merge_reduce_scatter(gm, rs_buckets, mode) def is_all_gather_into_tensor(node: torch.fx.Node) -> bool: # type: ignore[arg-type] @@ -131,28 +133,46 @@ def greedy_bucket_collective_by_mb( node_group_key: Callable[[torch.fx.Node], Any], filter_wait_node: Optional[Callable[[torch.fx.Node], bool]] = None, ) -> list[list[torch.fx.Node]]: - if not gm.graph.find_nodes( - op="call_function", target=torch.ops._c10d_functional.wait_tensor.default - ): - return [] - + """ + Bucketing adjacent collectives with equal node_group_key. + We can not bucket non adjacent collectives, + as this will effectively change the order of collectives. + Reordering can lead to different order on different ranks. + """ g = gm.graph + found_candidates = False + for node in g.nodes: + if filter_node(node): + found_candidates = True + break + if not found_candidates: + return [] # TODO: pearce kelly algorithm for detecting cycles node_descendents = collect_node_descendants(gm.graph) - node_groups: dict[Any, list[torch.fx.Node]] = collections.defaultdict(list) + nodes_groups: list[list[torch.fx.Node]] = [] + cur_group: list[torch.fx.Node] = [] + cur_group_key = None for node in g.nodes: if is_wait_tensor(node) and filter_node(node.args[0]): if (filter_wait_node is None) or filter_wait_node(node): coll_node = node.args[0] group_key = node_group_key(coll_node) - node_groups[group_key].append(coll_node) + if group_key == cur_group_key: + cur_group.append(coll_node) + else: + if len(cur_group) > 1: + nodes_groups.append(cur_group) + cur_group = [coll_node] + cur_group_key = group_key + + if len(cur_group) > 1: + nodes_groups.append(cur_group) buckets: list[list[torch.fx.Node]] = [] - - for nodes in node_groups.values(): + for nodes in nodes_groups: cur_bucket: list[torch.fx.Node] = [] cur_bucket_descendents: OrderedSet[torch.fx.Node] = OrderedSet() cur_bucket_size_bytes: int = 0 @@ -261,6 +281,52 @@ def bucket_reduce_scatter_by_mb( ) +@torch.library.custom_op("bucketing::_pre_bucket_reduce_scatter", mutates_args={}) +def _pre_bucket_reduce_scatter( + rs_ins: list[torch.Tensor], + group_size: int, +) -> torch.Tensor: + rs_ins_flattened = [x.view(group_size, -1) for x in rs_ins] + new_rs_in = torch.cat(rs_ins_flattened, dim=1).flatten() + return new_rs_in + + +def _pre_bucket_reduce_scatter_fake( + rs_ins: list[torch.Tensor], + group_size: int, +) -> torch.Tensor: + out_numel = sum(rs_in.numel() for rs_in in rs_ins) + return torch.empty((out_numel,), device=rs_ins[0].device, dtype=rs_ins[0].dtype) + + +_pre_bucket_reduce_scatter.register_fake(_pre_bucket_reduce_scatter_fake) + + +def reduce_scatter_merge_fn_to_trace_custom_ops( + rs_ins: list[torch.Tensor], + group_size: int, + group_name: str, + reduce_op: str, + reduce_dtype: torch.dtype, # type: ignore[name-defined] + device: torch.device, # type: ignore[name-defined] +) -> list[torch.Tensor]: # type: ignore[no-untyped-def] + new_out_sizes = [(x.shape[0] // group_size,) + x.shape[1:] for x in rs_ins] + new_out_numels = [x.numel() // group_size for x in rs_ins] + + new_rs_in = torch.ops.bucketing._pre_bucket_reduce_scatter(rs_ins, group_size) + + # TODO - either use torch.cat or make sure inductor foreach codegen + # fires more reliably + new_rs_out = torch.ops.c10d_functional.wait_tensor( + torch.ops._c10d_functional.reduce_scatter_tensor.default( + new_rs_in, reduce_op, group_size, group_name + ) + ) + new_out_flat = new_rs_out.split(new_out_numels, 0) + new_outs = [x.view(s) for x, s in zip(new_out_flat, new_out_sizes)] + return new_outs + + def reduce_scatter_merge_fn_to_trace( rs_ins: list[torch.Tensor], group_size: int, @@ -276,8 +342,6 @@ def reduce_scatter_merge_fn_to_trace( new_rs_in = torch.cat(rs_ins_flattened, dim=1).flatten() - # TODO - either use torch.cat or make sure inductor foreach codegen - # fires more reliably new_rs_out = torch.ops.c10d_functional.wait_tensor( torch.ops._c10d_functional.reduce_scatter_tensor.default( new_rs_in, reduce_op, group_size, group_name @@ -288,6 +352,74 @@ def reduce_scatter_merge_fn_to_trace( return new_outs +@torch.library.custom_op("bucketing::_pre_bucket_all_gather", mutates_args={}) +def _pre_bucket_all_gather( + ag_ins: list[torch.Tensor], + group_size: int, + group_name: str, + dtype: torch.dtype, # type: ignore[name-defined] + rank: int, +) -> torch.Tensor: + ins_split_sizes = [ag_in.numel() for ag_in in ag_ins] + ag_input_numel = sum(ins_split_sizes) + device = ag_ins[0].device + new_ag_out = torch.empty(ag_input_numel * group_size, dtype=dtype, device=device) + new_ag_in = new_ag_out.narrow(0, ag_input_numel * rank, ag_input_numel) + foreach_copy_dsts = torch.split(new_ag_in, ins_split_sizes) + ag_ins_flattened = [ag_in.reshape(-1) for ag_in in ag_ins] + torch._foreach_copy_(foreach_copy_dsts, ag_ins_flattened) + return new_ag_out + + +def _pre_bucket_all_gather_fake( + ag_ins: list[torch.Tensor], + group_size: int, + group_name: str, + dtype: torch.dtype, # type: ignore[name-defined] + rank: int, +) -> torch.Tensor: + ins_split_sizes = [ag_in.numel() for ag_in in ag_ins] + ag_input_numel = sum(ins_split_sizes) + device = ag_ins[0].device + new_ag_out = torch.empty(ag_input_numel * group_size, dtype=dtype, device=device) + return new_ag_out + + +_pre_bucket_all_gather.register_fake(_pre_bucket_all_gather_fake) + + +def all_gather_merge_fn_to_trace_custom_ops( + ag_ins: list[torch.Tensor], + group_size: int, + group_name: str, + dtype: torch.dtype, # type: ignore[name-defined] + rank: int, +) -> list[torch.Tensor]: + ins_sizes = [ag_in.shape for ag_in in ag_ins] + ins_split_sizes = [ag_in.numel() for ag_in in ag_ins] + ag_input_numel = sum(ins_split_sizes) + new_ag_out = torch.ops.bucketing._pre_bucket_all_gather( + ag_ins, group_size, group_name, dtype, rank + ) + new_ag_in = new_ag_out.narrow(0, ag_input_numel * rank, ag_input_numel) + wait_tensor = torch.ops.c10d_functional.wait_tensor( + torch.ops._c10d_functional.all_gather_into_tensor_out.default( + new_ag_in, group_size, group_name, out=new_ag_out + ) + ) + new_ag_out_reshaped = wait_tensor.reshape(group_size, -1) + outs = torch.split_with_sizes( + new_ag_out_reshaped, + ins_split_sizes, + dim=1, + ) + outs_reshaped = [ + o.reshape((shape[0] * group_size,) + shape[1:]) + for o, shape in zip(outs, ins_sizes) + ] + return outs_reshaped + + def all_gather_merge_fn_to_trace( ag_ins: list[torch.Tensor], group_size: int, @@ -420,9 +552,17 @@ def _insert_fn_trace_before_node( # type: ignore[no-untyped-def] def merge_reduce_scatter( - gm: torch.fx.GraphModule, rs_buckets: list[list[torch.fx.Node]] + gm: torch.fx.GraphModule, + rs_buckets: list[list[torch.fx.Node]], + mode: Optional[str] = None, ) -> None: + """ + Merges specified buckets of reduce_scatter to joint reduce_scatter. + """ with dynamo_timed("fx.bucketing.merge_reduce_scatter", log_pt2_compile_event=True): + rs_merge_fn = reduce_scatter_merge_fn_to_trace + if mode and "custom_ops" in mode: + rs_merge_fn = reduce_scatter_merge_fn_to_trace_custom_ops trace_structured( "artifact", metadata_fn=lambda: { @@ -469,7 +609,7 @@ def merge_reduce_scatter( replacements = _insert_fn_trace_before_node( g, - reduce_scatter_merge_fn_to_trace, + rs_merge_fn, ( pytree.tree_map(lambda node: node.meta["val"], _rs_ins), group_size, @@ -501,7 +641,9 @@ def merge_reduce_scatter( def merge_all_gather( - gm: torch.fx.GraphModule, ag_buckets: list[list[torch.fx.Node]] + gm: torch.fx.GraphModule, + ag_buckets: list[list[torch.fx.Node]], + mode: Optional[str] = None, ) -> None: # type: ignore[union-attr] """ Merges specified buckets of all_gather to joint all_gather. @@ -509,6 +651,10 @@ def merge_all_gather( with dynamo_timed("fx.bucketing.merge_all_gather", log_pt2_compile_event=True): from torch.distributed.distributed_c10d import _resolve_process_group + ag_merge_fn = all_gather_merge_fn_to_trace + if mode and "custom_ops" in mode: + ag_merge_fn = all_gather_merge_fn_to_trace_custom_ops + trace_structured( "artifact", metadata_fn=lambda: { @@ -519,6 +665,8 @@ def merge_all_gather( ) n_buckets = len(ag_buckets) + ag_node_to_pre_nodes = defaultdict(list) + ag_ins: list[list[torch.fx.Node]] = [[] for _ in range(n_buckets)] ag_waits: list[list[torch.fx.Node]] = [[] for _ in range(n_buckets)] for bucket_idx, ag_bucket in enumerate(ag_buckets): @@ -537,6 +685,14 @@ def merge_all_gather( and ag_node.meta["val"].dtype == dtype ) ag_node_in = ag_node.args[0] + if ( + ag_node_in.op == "call_function" # type: ignore[union-attr] + and ag_node_in.target # type: ignore[union-attr] + == torch.ops.prims.convert_element_type.default # type: ignore[union-attr] + and len(ag_node_in.users) == 1 # type: ignore[union-attr] + ): + ag_node_to_pre_nodes[ag_node].append(ag_node_in) + ag_node_in = ag_node_in.args[0] # type: ignore[union-attr] ag_ins[bucket_idx].append(ag_node_in) # type: ignore[union-attr, arg-type] ag_waits[bucket_idx].append(wait_node) @@ -558,7 +714,7 @@ def merge_all_gather( replacements = _insert_fn_trace_before_node( g, - all_gather_merge_fn_to_trace, + ag_merge_fn, ( pytree.tree_map(lambda node: node.meta["val"], _ag_ins), group_size, @@ -582,3 +738,5 @@ def merge_all_gather( for ag_n, wait_n in zip(ag_buckets[bucket_idx], _ag_waits): g.erase_node(wait_n) g.erase_node(ag_n) + for n in reversed(ag_node_to_pre_nodes[ag_n]): + g.erase_node(n) # type: ignore[arg-type] diff --git a/torch/_inductor/fx_passes/fsdp.py b/torch/_inductor/fx_passes/fsdp.py index e24ebe4037e7..e7e574ae4934 100644 --- a/torch/_inductor/fx_passes/fsdp.py +++ b/torch/_inductor/fx_passes/fsdp.py @@ -56,6 +56,7 @@ def is_fsdp_reduce_scatter_wait(wait: torch.fx.Node) -> bool: def bucket_fsdp_all_gather( gm: torch.fx.GraphModule, bucket_cap_mb_by_bucket_idx: Optional[Callable[[int], float]] = None, + mode: Optional[str] = None, ) -> None: """ Bucketing pass for SimpleFSDP all_gather ops. @@ -79,12 +80,13 @@ def bucket_fsdp_all_gather( ) if len(ag_buckets) == 0: return - merge_all_gather(gm, ag_buckets) + merge_all_gather(gm, ag_buckets, mode) def bucket_fsdp_reduce_scatter( gm: torch.fx.GraphModule, bucket_cap_mb_by_bucket_idx: Optional[Callable[[int], float]] = None, + mode: Optional[str] = None, ) -> None: """ Bucketing pass for SimpleFSDP reduce_scatter ops. @@ -109,4 +111,4 @@ def bucket_fsdp_reduce_scatter( ) if len(rs_buckets) == 0: return - merge_reduce_scatter(gm, rs_buckets) + merge_reduce_scatter(gm, rs_buckets, mode) diff --git a/torch/_inductor/fx_passes/post_grad.py b/torch/_inductor/fx_passes/post_grad.py index db273b06c8e6..ba6953c09118 100644 --- a/torch/_inductor/fx_passes/post_grad.py +++ b/torch/_inductor/fx_passes/post_grad.py @@ -204,13 +204,14 @@ def post_grad_passes(gm: torch.fx.GraphModule, is_inference: bool): p = ( bucket_fsdp_reduce_scatter - if config.bucket_reduce_scatters_fx == "fsdp" + if "fsdp" in config.bucket_reduce_scatters_fx else bucket_reduce_scatter ) GraphTransformObserver(gm, "bucket_reduce_scatters").apply_graph_pass( lambda graph: p( graph.owning_module, config.bucket_reduce_scatters_fx_bucket_size_determinator, + config.bucket_reduce_scatters_fx, ) ) collectives_bucketing = True @@ -223,13 +224,14 @@ def post_grad_passes(gm: torch.fx.GraphModule, is_inference: bool): p = ( bucket_fsdp_all_gather # type: ignore[assignment] - if config.bucket_all_gathers_fx == "fsdp" + if "fsdp" in config.bucket_all_gathers_fx else bucket_all_gather ) GraphTransformObserver(gm, "bucket_all_gathers").apply_graph_pass( lambda graph: p( graph.owning_module, config.bucket_all_gathers_fx_bucket_size_determinator, + config.bucket_all_gathers_fx, ) ) collectives_bucketing = True