From 0b2fdc30a26ce838de7fe7fcbb7b3f50bacf0440 Mon Sep 17 00:00:00 2001
From: eellison
Date: Mon, 29 Sep 2025 15:36:43 -0700
Subject: [PATCH] refactor bucketing (#163754)

Preparatory refactor

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163754
Approved by: https://github.com/IvanKobzarev
ghstack dependencies: #163215
---
 torch/_inductor/fx_passes/bucketing.py | 384 ++++++++++++++-----------
 1 file changed, 219 insertions(+), 165 deletions(-)

diff --git a/torch/_inductor/fx_passes/bucketing.py b/torch/_inductor/fx_passes/bucketing.py
index bf16454157b3..190fad35af7b 100644
--- a/torch/_inductor/fx_passes/bucketing.py
+++ b/torch/_inductor/fx_passes/bucketing.py
@@ -18,6 +18,22 @@ logger: logging.Logger = logging.getLogger(__name__)
 logger.setLevel(logging.INFO)
 
 
+# Helper functions moved to top for better organization
+def _ag_group_key(node: torch.fx.Node) -> tuple[str, torch.dtype]:
+    _, group_size, group_name = node.args
+    dtype = node.meta["val"].dtype
+    assert isinstance(group_name, str)
+    return (group_name, dtype)
+
+
+def _rs_group_key(node: torch.fx.Node) -> tuple[str, str, torch.dtype]:
+    _, reduce_op, group_size, group_name = node.args
+    dtype = node.meta["val"].dtype
+    assert isinstance(group_name, str)
+    assert isinstance(reduce_op, str)
+    return (group_name, reduce_op, dtype)
+
+
 def bucket_cap_mb_by_bucket_idx_default(bucket_id: int) -> float:
     """
     Determine the size of a bucket based on its ID.
@@ -229,12 +245,6 @@ def bucket_all_gather_by_mb(
         list[list[torch.fx.Node]]: List of buckets, where each bucket is a list of all_gather nodes.
     """
 
-    def _ag_group_key(node: torch.fx.Node) -> tuple[str, torch.dtype]:
-        _, group_size, group_name = node.args
-        dtype = node.meta["val"].dtype
-        assert isinstance(group_name, str)
-        return (group_name, dtype)
-
     return greedy_bucket_collective_by_mb(
         gm,
         bucket_cap_mb_by_bucket_idx,
@@ -265,13 +275,6 @@ def bucket_reduce_scatter_by_mb(
         list[list[torch.fx.Node]]: List of buckets, where each bucket is a list of reduce_scatter nodes.
     """
 
-    def _rs_group_key(node: torch.fx.Node) -> tuple[str, str, torch.dtype]:
-        _, reduce_op, group_size, group_name = node.args
-        dtype = node.meta["val"].dtype
-        assert isinstance(group_name, str)
-        assert isinstance(reduce_op, str)
-        return (group_name, reduce_op, dtype)
-
     return greedy_bucket_collective_by_mb(
         gm,
         bucket_cap_mb_by_bucket_idx,
@@ -515,13 +518,16 @@ def _insert_fn_trace_before_node(  # type: ignore[no-untyped-def]
     insert_before_node: torch.fx.Node,
     g_fn_inps: list[torch.fx.Node],
     g_fn_outs: list[torch.fx.Node],
-) -> dict[torch.fx.Node, torch.fx.Node]:  # type: ignore[no-untyped-def]
+) -> tuple[dict[torch.fx.Node, torch.fx.Node], list[torch.fx.Node]]:  # type: ignore[no-untyped-def]
     """
     Helper function that traces :attr:`fn_to_trace` with inputs
     :attr:`inps`.
     The result function graph will be inserted before :attr:`insert_before_node`,
-    using :attr:`g_fn_inps` nodes of original graphas inputs of function graph,
+    using :attr:`g_fn_inps` nodes of original graph as inputs of function graph,
     function graph outputs will replace :attr:`g_fn_outs` in original graph.
+
+    Returns:
+        (replacements, new_nodes): Dictionary mapping old to new nodes, and list of all newly inserted nodes
     """
     with dynamo_timed(
         "fx.bucketing._insert_fn_trace_before_node", log_pt2_compile_event=True
@@ -534,6 +540,8 @@ def _insert_fn_trace_before_node(  # type: ignore[no-untyped-def]
     fn_g_ins = fn_g.find_nodes(op="placeholder")
     env = {fn_g_ins[idx]: g_fn_inps[idx] for idx in range(len(g_fn_inps))}
     g_fn_new_outs: list[torch.fx.Node] = []
+    new_nodes: list[torch.fx.Node] = []  # Track all newly inserted nodes
+
     with g.inserting_before(insert_before_node):
         for _n in fn_g.nodes:
             if _n.op == "placeholder":
@@ -543,12 +551,201 @@ def _insert_fn_trace_before_node(  # type: ignore[no-untyped-def]
             if _n.op == "output":
                 g_fn_new_outs = _new_n.args[0]  # type: ignore[assignment]
                 g.erase_node(_new_n)
+            else:
+                new_nodes.append(_new_n)  # Track non-output nodes
+
     replacements = {  # noqa: C416
         orig_out: new_out for orig_out, new_out in zip(g_fn_outs, g_fn_new_outs)
     }
     for orig_out, new_out in zip(g_fn_outs, g_fn_new_outs):
         orig_out.replace_all_uses_with(new_out)
-    return replacements
+
+    return replacements, new_nodes
+
+
+def process_collective_bucket(
+    g: torch.fx.Graph,
+    bucket_nodes: list[torch.fx.Node],
+    fn_to_trace: Callable[..., list[torch.Tensor]],
+    trace_args_fn: Callable[[list[torch.fx.Node]], tuple[Any, ...]],
+    insert_before: Optional[torch.fx.Node] = None,
+    wait_insertion_point: Optional[torch.fx.Node] = None,
+) -> dict[torch.fx.Node, torch.fx.Node]:
+    """
+    Process a single bucket of collective operation nodes with flexible insertion control.
+
+    Args:
+        g: The graph to modify
+        bucket_nodes: Nodes in the current bucket to process
+        fn_to_trace: Function to trace and insert
+        trace_args_fn: Function to create trace arguments from inputs
+        insert_before: Where to insert the traced function (default: after last bucket node)
+        wait_insertion_point: If provided, move all nodes from wait() onwards to before this node
+
+    Returns:
+        replacements: Dictionary mapping old wait nodes to new output nodes
+    """
+    # Collect inputs and waits from current bucket
+    bucket_ins: list[torch.fx.Node] = []
+    bucket_waits: list[torch.fx.Node] = []
+    ag_node_to_pre_nodes: dict[torch.fx.Node, list[torch.fx.Node]] = defaultdict(list)
+
+    for n in bucket_nodes:
+        assert len(n.users) == 1, f"Expected single user for {n}, got {n.users}"
+        wait_n = next(iter(n.users))
+
+        # Handle convert_element_type operations (for all_gather)
+        node_in = n.args[0]
+        if (
+            is_all_gather_into_tensor(n)
+            and isinstance(node_in, torch.fx.Node)  # Add type check
+            and node_in.op == "call_function"
+            and node_in.target == torch.ops.prims.convert_element_type.default
+            and len(node_in.users) == 1
+        ):
+            ag_node_to_pre_nodes[n].append(node_in)
+            node_in = node_in.args[0]
+
+        assert isinstance(node_in, torch.fx.Node)  # Ensure node_in is a Node
+        bucket_ins.append(node_in)
+        bucket_waits.append(wait_n)
+
+    # Create trace arguments
+    trace_args = trace_args_fn(bucket_ins)
+
+    # Determine insertion point
+    if insert_before is None:
+        insert_before = bucket_nodes[-1].next
+
+    # Insert traced function and get replacements + new nodes
+    replacements, new_nodes = _insert_fn_trace_before_node(
+        g,
+        fn_to_trace,
+        trace_args,
+        insert_before,
+        bucket_ins,
+        bucket_waits,
+    )
+
+    # If requested, move wait nodes and everything after to specified location
+    if wait_insertion_point is not None:
+        # Find the first wait node in new_nodes
+        wait_start_idx = None
+        for i, node in enumerate(new_nodes):
+            if is_wait_tensor(node):
+                wait_start_idx = i
+                break
+
+        # Move all nodes from wait onwards (including the wait)
+        if wait_start_idx is not None:
+            nodes_to_move = new_nodes[wait_start_idx:]
+            for node in nodes_to_move:
+                wait_insertion_point.prepend(node)
+
+    # Erase old nodes
+    for node, wait_n in zip(bucket_nodes, bucket_waits):
+        g.erase_node(wait_n)
+        g.erase_node(node)
+        # Erase any convert_element_type nodes we tracked
+        for pre_node in reversed(ag_node_to_pre_nodes[node]):
+            g.erase_node(pre_node)
+
+    return replacements
+
+
+def merge_reduce_scatter_bucket(
+    g: torch.fx.Graph,
+    rs_nodes: list[torch.fx.Node],
+    mode: Optional[str] = None,
+    wait_insertion_point: Optional[torch.fx.Node] = None,
+) -> None:
+    # Validate bucket consistency
+    rs0 = rs_nodes[0]
+    rs0_val = rs0.meta["val"]
+    _, reduce_op, group_size, group_name = rs0.args
+    reduce_dtype = rs0_val.dtype
+    device = rs0_val.device
+
+    for n in rs_nodes:
+        rs_val = n.meta["val"]
+        assert (
+            n.args[1] == reduce_op
+            and n.args[2] == group_size
+            and n.args[3] == group_name
+            and rs_val.device == device
+            and rs_val.dtype == reduce_dtype
+        )
+
+    # Choose merge function based on mode
+    rs_merge_fn = reduce_scatter_merge_fn_to_trace
+    if mode and "custom_ops" in mode:
+        rs_merge_fn = reduce_scatter_merge_fn_to_trace_custom_ops
+
+    # Process bucket with lazy input collection
+    def create_trace_args(bucket_ins: list[torch.fx.Node]) -> tuple[Any, ...]:
+        return (
+            pytree.tree_map(lambda node: node.meta["val"], bucket_ins),
+            group_size,
+            group_name,
+            reduce_op,
+            reduce_dtype,
+            device,
+        )
+
+    process_collective_bucket(
+        g,
+        rs_nodes,
+        rs_merge_fn,
+        create_trace_args,
+        wait_insertion_point=wait_insertion_point,
+    )
+
+
+def merge_all_gather_bucket(
+    g: torch.fx.Graph,
+    ag_nodes: list[torch.fx.Node],
+    mode: Optional[str] = None,
+    wait_insertion_point: Optional[torch.fx.Node] = None,
+) -> None:
+    from torch.distributed.distributed_c10d import _resolve_process_group
+
+    ag0 = ag_nodes[0]
+    ag0_val = ag0.meta["val"]
+    _, group_size, group_name = ag0.args
+    dtype = ag0_val.dtype
+    assert isinstance(group_name, str)
+
+    for n in ag_nodes:
+        assert (
+            n.args[1] == group_size
+            and n.args[2] == group_name
+            and n.meta["val"].dtype == dtype
+        )
+
+    # Choose merge function based on mode
+    ag_merge_fn = all_gather_merge_fn_to_trace
+    if mode and "custom_ops" in mode:
+        ag_merge_fn = all_gather_merge_fn_to_trace_custom_ops
+
+    # Process bucket with lazy input collection
+    rank: int = dist.get_rank(_resolve_process_group(group_name))
+
+    def create_trace_args(bucket_ins: list[torch.fx.Node]) -> tuple[Any, ...]:
+        return (
+            pytree.tree_map(lambda node: node.meta["val"], bucket_ins),
+            group_size,
+            group_name,
+            dtype,
+            rank,
+        )
+
+    process_collective_bucket(
+        g,
+        ag_nodes,
+        ag_merge_fn,
+        create_trace_args,
+        wait_insertion_point=wait_insertion_point,
+    )
 
 
 def merge_reduce_scatter(
@@ -560,9 +757,6 @@ def merge_reduce_scatter(
     Merges specified buckets of reduce_scatter to joint reduce_scatter.
""" with dynamo_timed("fx.bucketing.merge_reduce_scatter", log_pt2_compile_event=True): - rs_merge_fn = reduce_scatter_merge_fn_to_trace - if mode and "custom_ops" in mode: - rs_merge_fn = reduce_scatter_merge_fn_to_trace_custom_ops trace_structured( "artifact", metadata_fn=lambda: { @@ -571,90 +765,22 @@ def merge_reduce_scatter( }, payload_fn=lambda: str(rs_buckets), ) - n_buckets = len(rs_buckets) + g = gm.graph - rs_ins: list[list[torch.fx.Node]] = [[] for _ in range(n_buckets)] - rs_waits: list[list[torch.fx.Node]] = [[] for _ in range(n_buckets)] - for bucket_idx, rs_nodes in enumerate(rs_buckets): - rs0 = rs_nodes[0] - rs0_val = rs0.meta["val"] - _, reduce_op, group_size, group_name = rs0.args - reduce_dtype = rs0_val.dtype - device = rs0_val.device - for n in rs_nodes: - rs_val = n.meta["val"] - assert ( - n.args[1] == reduce_op - and n.args[2] == group_size - and n.args[3] == group_name - and rs_val.device == device - and rs_val.dtype == reduce_dtype - ) - assert len(n.users) == 1 - wait_n = next(iter(n.users)) - rs_ins[bucket_idx].append(n.args[0]) # type: ignore[arg-type] - rs_waits[bucket_idx].append(wait_n) - - for bucket_idx in range(n_buckets): - _rs_ins = rs_ins[bucket_idx] - _rs_waits = rs_waits[bucket_idx] - _rs_ns = rs_buckets[bucket_idx] - - rs0 = _rs_ns[0] - rs0_val = rs0.meta["val"] - _, reduce_op, group_size, group_name = rs0.args - reduce_dtype = rs0_val.dtype - device = rs0_val.device - - replacements = _insert_fn_trace_before_node( - g, - rs_merge_fn, - ( - pytree.tree_map(lambda node: node.meta["val"], _rs_ins), - group_size, - group_name, - reduce_op, - reduce_dtype, - device, - ), - _rs_ns[-1].next, - _rs_ins, - _rs_waits, - ) - # [Note: Replacement in bucketing passes] - # After bucketing _rs_waits will be replaced with output nodes of - # fn_to_trace graph that will be inserted in the graph g. - # By this time we already prepared rs_ins, rs_waits. - # rs_ins for following buckets can be replaced _rs_waits with new nodes. - # We apply replacements to rs_ins. - - def _replace(x: torch.fx.Node) -> torch.fx.Node: - return replacements.get(x, x) - - for j in range(bucket_idx + 1, n_buckets): - rs_ins[j] = pytree.tree_map(_replace, rs_ins[j]) - - for rs_n, wait_n in zip(_rs_ns, _rs_waits): - g.erase_node(wait_n) - g.erase_node(rs_n) + for rs_nodes in rs_buckets: + merge_reduce_scatter_bucket(g, rs_nodes, mode) def merge_all_gather( gm: torch.fx.GraphModule, ag_buckets: list[list[torch.fx.Node]], mode: Optional[str] = None, -) -> None: # type: ignore[union-attr] +) -> None: """ Merges specified buckets of all_gather to joint all_gather. 
""" with dynamo_timed("fx.bucketing.merge_all_gather", log_pt2_compile_event=True): - from torch.distributed.distributed_c10d import _resolve_process_group - - ag_merge_fn = all_gather_merge_fn_to_trace - if mode and "custom_ops" in mode: - ag_merge_fn = all_gather_merge_fn_to_trace_custom_ops - trace_structured( "artifact", metadata_fn=lambda: { @@ -663,80 +789,8 @@ def merge_all_gather( }, payload_fn=lambda: str(ag_buckets), ) - n_buckets = len(ag_buckets) - - ag_node_to_pre_nodes = defaultdict(list) - - ag_ins: list[list[torch.fx.Node]] = [[] for _ in range(n_buckets)] - ag_waits: list[list[torch.fx.Node]] = [[] for _ in range(n_buckets)] - for bucket_idx, ag_bucket in enumerate(ag_buckets): - _, group_size, group_name = ag_bucket[0].args - assert isinstance(group_name, str) - dtype = ag_bucket[0].meta["val"].dtype - - for ag_node in ag_bucket: - assert len(ag_node.users) == 1, ( - f"Expect only one user for {ag_node}, but got {ag_node.users}" - ) - wait_node = next(iter(ag_node.users)) - assert ( - ag_node.args[1] == group_size - and ag_node.args[2] == group_name - and ag_node.meta["val"].dtype == dtype - ) - ag_node_in = ag_node.args[0] - if ( - ag_node_in.op == "call_function" # type: ignore[union-attr] - and ag_node_in.target # type: ignore[union-attr] - == torch.ops.prims.convert_element_type.default # type: ignore[union-attr] - and len(ag_node_in.users) == 1 # type: ignore[union-attr] - ): - ag_node_to_pre_nodes[ag_node].append(ag_node_in) - ag_node_in = ag_node_in.args[0] # type: ignore[union-attr] - - ag_ins[bucket_idx].append(ag_node_in) # type: ignore[union-attr, arg-type] - ag_waits[bucket_idx].append(wait_node) g = gm.graph - for bucket_idx in range(n_buckets): - _ag_ins = ag_ins[bucket_idx] - _ag_waits = ag_waits[bucket_idx] - _ag_ns = ag_buckets[bucket_idx] - - ag0 = _ag_ns[0] - ag0_val = ag0.meta["val"] - _, group_size, group_name = ag0.args - dtype = ag0_val.dtype - assert isinstance(group_name, str) - - rank: int = dist.get_rank(_resolve_process_group(group_name)) - - replacements = _insert_fn_trace_before_node( - g, - ag_merge_fn, - ( - pytree.tree_map(lambda node: node.meta["val"], _ag_ins), - group_size, - group_name, - dtype, - rank, - ), - ag0.next, - _ag_ins, - _ag_waits, - ) - - # See Note: [Replacement in bucketing passes] - def _replace(x: torch.fx.Node) -> torch.fx.Node: - return replacements.get(x, x) - - for j in range(bucket_idx + 1, n_buckets): - ag_ins[j] = pytree.tree_map(_replace, ag_ins[j]) - - # Erasing old nodes in reverse order - for ag_n, wait_n in zip(ag_buckets[bucket_idx], _ag_waits): - g.erase_node(wait_n) - g.erase_node(ag_n) - for n in reversed(ag_node_to_pre_nodes[ag_n]): - g.erase_node(n) # type: ignore[arg-type] + for ag_nodes in ag_buckets: + merge_all_gather_bucket(g, ag_nodes, mode)