From 371ffaf415baf6251b9d98466c8ee970b3556282 Mon Sep 17 00:00:00 2001 From: IvanKobzarev Date: Tue, 22 Jul 2025 04:51:46 -0700 Subject: [PATCH] [bucketing] Support case of several pgs in graph (#158632) Main changes: - bucketing collectives only from the same process_group by group_name - Support of groups like [0,2,4,6], [0,1,3,5] using `rank_idx_dict` for in pass operations for slice idxs etc. Pull Request resolved: https://github.com/pytorch/pytorch/pull/158632 Approved by: https://github.com/wconstab --- torch/_inductor/fx_passes/bucketing.py | 491 +++++++++++++------------ 1 file changed, 254 insertions(+), 237 deletions(-) diff --git a/torch/_inductor/fx_passes/bucketing.py b/torch/_inductor/fx_passes/bucketing.py index 8f5bb5ffd324..1794ce3a2a29 100644 --- a/torch/_inductor/fx_passes/bucketing.py +++ b/torch/_inductor/fx_passes/bucketing.py @@ -1,6 +1,7 @@ import logging import math import operator +from collections import defaultdict from typing import Any, Callable, Optional, Union import torch @@ -77,13 +78,9 @@ def bucket_all_gather_by_mb( ) -> list[list[torch.fx.Node]]: """ Identifies all all_gather nodes and groups them into buckets based on size limit `all_gather_bucket_cap_mb_callback`. - - Returns a list of buckets, where each bucket is a list of all_gather nodes. """ - node_list = gm.graph.nodes - # Prerequisite: Check if there is any all_gather node found_all_gather = False for node in node_list: @@ -92,48 +89,53 @@ def bucket_all_gather_by_mb( break if not found_all_gather: return [] - - ag_nodes: list[torch.fx.Node] = [] - + group_name_ag_nodes: dict[tuple[str, torch.dtype], list[torch.fx.Node]] = ( # type: ignore[name-defined] + defaultdict(list) + ) # Step 1: Find all all_gather nodes for node in node_list: if is_wait_tensor(node) and is_all_gather_into_tensor(node.args[0]): if (filter_wait_node is None) or filter_wait_node(node): ag_node = node.args[0] - ag_nodes.append(ag_node) - + _, group_size, group_name = ag_node.args + dtype = ag_node.meta["val"].dtype + assert isinstance(group_name, str) + group_name_ag_nodes[(group_name, dtype)].append(ag_node) # Step 2: Put all_gather nodes into buckets ag_buckets: list[list[torch.fx.Node]] = [] - cur_bucket: list[torch.fx.Node] = [] - cur_bucket_size_bytes: int = 0 - cur_bucket_id: int = 0 - # Convert MiB to bytes - all_gather_bucket_size_bytes = int( - all_gather_bucket_cap_mb_callback(cur_bucket_id) * 1024 * 1024 - ) - for ag_node in ag_nodes: - assert is_all_gather_into_tensor(ag_node) - assert "val" in ag_node.meta - ag_output_size_bytes = ( - ag_node.meta["val"].numel() - * torch.finfo(ag_node.meta["val"].dtype).bits - // 8 + for (group_name, dtype), ag_nodes in group_name_ag_nodes.items(): + cur_bucket: list[torch.fx.Node] = [] + cur_bucket_recursive_users: OrderedSet[torch.fx.Node] = OrderedSet() + cur_bucket_size_bytes: int = 0 + cur_bucket_id: int = 0 + all_gather_bucket_size_bytes = int( + all_gather_bucket_cap_mb_callback(cur_bucket_id) * 1024 * 1024 ) - if ( - cur_bucket_size_bytes + ag_output_size_bytes > all_gather_bucket_size_bytes - and cur_bucket - ): - # Current bucket is full, create new bucket + for ag_node in ag_nodes: + assert is_all_gather_into_tensor(ag_node) + if ag_node in cur_bucket_recursive_users: + # We can not bucket successors with the node + continue + assert "val" in ag_node.meta + ag_n_val = ag_node.meta["val"] + ag_output_size_bytes = ag_n_val.numel() * ag_n_val.element_size() + if ( + cur_bucket_size_bytes + ag_output_size_bytes + > all_gather_bucket_size_bytes + and cur_bucket + ): + # 
Current bucket is full, create new bucket + if len(cur_bucket) > 1: + ag_buckets.append(cur_bucket) + cur_bucket = [] + cur_bucket_size_bytes = 0 + cur_bucket_id += 1 + cur_bucket_size_bytes += ag_output_size_bytes + cur_bucket.append(ag_node) + find_recursive_users_of_fx_node(ag_node, cur_bucket_recursive_users) + if len(cur_bucket) > 1: + # add remaining nodes in the last bucket ag_buckets.append(cur_bucket) - cur_bucket = [] - cur_bucket_size_bytes = 0 - cur_bucket_id += 1 - cur_bucket_size_bytes += ag_output_size_bytes - cur_bucket.append(ag_node) - if cur_bucket: - # add remaining nodes in the last bucket - ag_buckets.append(cur_bucket) - return ag_buckets @@ -143,13 +145,9 @@ def bucket_reduce_scatter_by_mb( ) -> list[list[torch.fx.Node]]: """ Identifies all reduce_scatter nodes and groups them into buckets based on size limit `reduce_scatter_bucket_cap_mb_callback`. - - Returns a list of buckets, where each bucket is a list of reduce_scatter nodes. """ - node_list = list(gm.graph.nodes) - # Prerequisite: Check if there is any reduce_scatter node found_reduce_scatter = False for node in node_list: @@ -158,64 +156,71 @@ def bucket_reduce_scatter_by_mb( break if not found_reduce_scatter: return [] - - rs_nodes: list[torch.fx.Node] = [] - + group_name_rs_nodes: dict[tuple[str, str, torch.dtype], list[torch.fx.Node]] = ( # type: ignore[name-defined] + defaultdict(list) + ) # Step 1: Find all reduce_scatter nodes for node in node_list: if is_wait_tensor(node) and is_reduce_scatter_tensor(node.args[0]): rs_node = node.args[0] - rs_nodes.append(rs_node) - + _, reduce_op, group_size, group_name = rs_node.args + dtype = rs_node.meta["val"].dtype + assert isinstance(group_name, str) + assert isinstance(reduce_op, str) + group_name_rs_nodes[(group_name, reduce_op, dtype)].append(rs_node) # Step 2: Put reduce_scatter nodes into buckets rs_buckets: list[list[torch.fx.Node]] = [] - cur_bucket: list[torch.fx.Node] = [] - cur_bucket_size_bytes: int = 0 - cur_bucket_id: int = 0 - # Convert MiB to bytes - reduce_scatter_bucket_size_bytes = int( - reduce_scatter_bucket_cap_mb_callback(cur_bucket_id) * 1024 * 1024 - ) - for rs_node in rs_nodes: - assert is_reduce_scatter_tensor(rs_node) - rs_input = rs_node.args[0] - assert "val" in rs_input.meta # type: ignore[union-attr] - rs_input_size_bytes = ( - rs_input.meta["val"].numel() # type: ignore[union-attr] - * torch.finfo(rs_input.meta["val"].dtype).bits # type: ignore[union-attr] - // 8 + for (group_name, reduce_op, dtype), rs_nodes in group_name_rs_nodes.items(): + cur_bucket: list[torch.fx.Node] = [] + cur_bucket_recursive_users: OrderedSet[torch.fx.Node] = OrderedSet() + cur_bucket_size_bytes: int = 0 + cur_bucket_id: int = 0 + # Convert MiB to bytes + reduce_scatter_bucket_size_bytes = int( + reduce_scatter_bucket_cap_mb_callback(cur_bucket_id) * 1024 * 1024 ) - if ( - cur_bucket_size_bytes + rs_input_size_bytes - > reduce_scatter_bucket_size_bytes - and cur_bucket - ): - # Current bucket is full, create new bucket - total_size = cur_bucket_size_bytes + rs_input_size_bytes + for rs_node in rs_nodes: + assert is_reduce_scatter_tensor(rs_node) + if rs_node in cur_bucket_recursive_users: + # We can not bucket successors with the node + continue + rs_input = rs_node.args[0] + assert "val" in rs_input.meta # type: ignore[union-attr] + rs_in_val = rs_input.meta["val"] # type: ignore[union-attr] + rs_input_size_bytes = rs_in_val.numel() * rs_in_val.element_size() + if ( + cur_bucket_size_bytes + rs_input_size_bytes + > reduce_scatter_bucket_size_bytes + 
and cur_bucket + ): + # Current bucket is full, create new bucket + total_size = cur_bucket_size_bytes + rs_input_size_bytes + logger.info( + f"Reduce scatter bucket {cur_bucket_id} full: " # noqa: G004 + f"total_size = {total_size} = cur_bucket_size_bytes + rs_input_size_bytes = " + f"{cur_bucket_size_bytes} + {rs_input_size_bytes}," + f"bucket_cap = {reduce_scatter_bucket_size_bytes}" + ) + if len(cur_bucket) > 1: + rs_buckets.append(cur_bucket) + cur_bucket = [] + cur_bucket_size_bytes = 0 + cur_bucket_id += 1 + reduce_scatter_bucket_size_bytes = int( + reduce_scatter_bucket_cap_mb_callback(cur_bucket_id) * 1024 * 1024 + ) + cur_bucket_size_bytes += rs_input_size_bytes + cur_bucket.append(rs_node) + find_recursive_users_of_fx_node(rs_node, cur_bucket_recursive_users) + if cur_bucket: + # add remaining nodes in the last bucket logger.info( - f"Reduce scatter bucket {cur_bucket_id} full: " # noqa: G004 - f"total_size = {total_size} = cur_bucket_size_bytes + rs_input_size_bytes = " - f"{cur_bucket_size_bytes} + {rs_input_size_bytes}," + f"Reduce scatter last bucket {cur_bucket_id}: " # noqa: G004 + f"total_size = {cur_bucket_size_bytes}, " f"bucket_cap = {reduce_scatter_bucket_size_bytes}" ) - rs_buckets.append(cur_bucket) - cur_bucket = [] - cur_bucket_size_bytes = 0 - cur_bucket_id += 1 - reduce_scatter_bucket_size_bytes = int( - reduce_scatter_bucket_cap_mb_callback(cur_bucket_id) * 1024 * 1024 - ) - cur_bucket_size_bytes += rs_input_size_bytes - cur_bucket.append(rs_node) - if cur_bucket: - # add remaining nodes in the last bucket - logger.info( - f"Reduce scatter last bucket {cur_bucket_id}: " # noqa: G004 - f"total_size = {cur_bucket_size_bytes}, " - f"bucket_cap = {reduce_scatter_bucket_size_bytes}" - ) - rs_buckets.append(cur_bucket) - + if len(cur_bucket) > 1: + rs_buckets.append(cur_bucket) return rs_buckets @@ -260,6 +265,18 @@ def env_lookup( # type: ignore[no-untyped-def] return env[x] +def _rank_idx_dict(group_name: str) -> dict[int, int]: + from torch.distributed.distributed_c10d import ( + _resolve_process_group, + get_process_group_ranks, + ) + + pg = _resolve_process_group(group_name) + ranks = get_process_group_ranks(pg) + rank_idx_dict: dict[int, int] = {rank: idx for idx, rank in enumerate(ranks)} + return rank_idx_dict + + def merge_all_gather( gm: torch.fx.GraphModule, ag_buckets: list[list[torch.fx.Node]] ) -> None: @@ -297,15 +314,13 @@ def merge_all_gather( bucket_id_is_scheduled = {} cast_bucket_id_is_scheduled = {} _, group_size, group_name = next(iter(ag_node_to_wait_node.keys())).args + + group_name_to_rank_idx_dict: dict[str, dict[int, int]] = {} + for bucket_id, ag_bucket in enumerate(ag_buckets): ag_input_nodes = [] wait_nodes = [] for ag_node in ag_bucket: - assert ( - ag_node in ag_node_to_wait_node - and ag_node.args[1] == group_size - and ag_node.args[2] == group_name - ) ag_input_nodes.append(ag_node.args[0]) wait_nodes.append(ag_node_to_wait_node[ag_node]) bucket_id_to_bucketed_op_info[bucket_id] = ( @@ -314,6 +329,8 @@ def merge_all_gather( group_name, wait_nodes, ) + if group_name not in group_name_to_rank_idx_dict: + group_name_to_rank_idx_dict[group_name] = _rank_idx_dict(group_name) # type: ignore[arg-type, index] ag_wait_nodes = list(ag_node_to_wait_node.values()) ag_and_wait_nodes = OrderedSet(ag_nodes + ag_wait_nodes) @@ -334,9 +351,6 @@ def merge_all_gather( ag_input_nodes, group_size, group_name, orig_wait_nodes = ( bucket_id_to_bucketed_op_info[bucket_id] ) - # device = ag_input_nodes[0].meta["val"].device - # rank = device.index - # dtype 
= ag_input_nodes[0].meta["val"].dtype if all( n.op == "call_function" # type: ignore[union-attr] and n.target == torch.ops.prims.convert_element_type.default # type: ignore[union-attr] @@ -398,6 +412,7 @@ def merge_all_gather( ag_input_nodes, group_size, group_name, orig_wait_nodes = ( bucket_id_to_bucketed_op_info[bucket_id] ) + rank_idx_dict = group_name_to_rank_idx_dict[group_name] # type: ignore[index] device = ag_input_nodes[0].meta["val"].device # type: ignore[union-attr] rank = device.index dtype = ag_input_nodes[0].meta["val"].dtype # type: ignore[union-attr] @@ -468,7 +483,7 @@ def merge_all_gather( all_gather_output, inp_split_sizes, all_gather_input_numel, - rank, + rank_idx_dict[rank], ), {}, ) @@ -585,6 +600,7 @@ def merge_reduce_scatter( # Prepare bucketed operation info bucket_id_to_bucketed_op_info = {} bucket_id_is_scheduled = {} + group_name_to_rank_idx_dict: dict[str, dict[int, int]] = {} for bucket_id, rs_bucket in enumerate(rs_buckets): _, reduce_op, group_size, group_name = next( iter(rs_node_to_wait_node.keys()) @@ -612,6 +628,8 @@ def merge_reduce_scatter( wait_nodes, wait_node_recursive_users, ) + if group_name not in group_name_to_rank_idx_dict: + group_name_to_rank_idx_dict[group_name] = _rank_idx_dict(group_name) # type: ignore[arg-type, index] new_graph: torch.fx.Graph = torch.fx.Graph() env: dict[torch.fx.Node, torch.fx.Node] = {} @@ -624,155 +642,154 @@ def merge_reduce_scatter( elif node in rs_node_to_wait_node: assert node in rs_node_to_bucket_id bucket_id = rs_node_to_bucket_id[node] - if ( + if not ( bucket_id not in bucket_id_is_scheduled and rs_buckets[bucket_id][-1] == node ): - # If we are at the last node in the bucket, we can start to schedule the bucketed reduce_scatter node - ( - rs_input_nodes, - reduce_op, - group_size, - group_name, - orig_wait_nodes, - orig_wait_node_recursive_users, - ) = bucket_id_to_bucketed_op_info[bucket_id] - # parents of rs have been scheduled, so we can directly use the env - unsharded_grads = [env[x] for x in rs_input_nodes] # type: ignore[index] - reduce_dtype = unsharded_grads[0].meta["val"].dtype - # Only float32 and bfloat16 are supported for now. - # To support fp16, please see FSDP2 `_get_gradient_divide_factors`. - assert reduce_dtype in ( - torch.float32, - torch.bfloat16, - ), f"reduce_dtype {reduce_dtype} is not supported" - assert all( - grad.meta["val"].dtype == reduce_dtype for grad in unsharded_grads - ) - device = unsharded_grads[0].meta["val"].device - rank = device.index - shard_dim = 0 + continue - def _get_dim0_padded_size( - tensor_size: torch.Size, dim0_factor: int - ) -> torch.Size: - padded_dim0 = math.ceil(tensor_size[0] / dim0_factor) * dim0_factor - return torch.Size([padded_dim0]) + tensor_size[1:] + # If we are at the last node in the bucket, we can start to schedule the bucketed reduce_scatter node + ( + rs_input_nodes, + reduce_op, + group_size, + group_name, + orig_wait_nodes, + orig_wait_node_recursive_users, + ) = bucket_id_to_bucketed_op_info[bucket_id] + rank_idx_dict = group_name_to_rank_idx_dict[group_name] # type: ignore[index] + # parents of rs have been scheduled, so we can directly use the env + unsharded_grads = [env[x] for x in rs_input_nodes] # type: ignore[index] + reduce_dtype = unsharded_grads[0].meta["val"].dtype + # Only float32 and bfloat16 are supported for now. + # To support fp16, please see FSDP2 `_get_gradient_divide_factors`. 
+ assert reduce_dtype in ( + torch.float32, # type: ignore[attr-defined] + torch.bfloat16, # type: ignore[attr-defined] + ), f"reduce_dtype {reduce_dtype} is not supported" + assert all( + grad.meta["val"].dtype == reduce_dtype for grad in unsharded_grads + ) + device = unsharded_grads[0].meta["val"].device + rank = device.index + rank_idx = rank_idx_dict[rank] + shard_dim = 0 - padded_unsharded_sizes = tuple( - _get_dim0_padded_size(grad.meta["val"].size(), group_size) # type: ignore[arg-type] - for grad in unsharded_grads - ) - reduce_scatter_input_numel = sum( - s.numel() for s in padded_unsharded_sizes - ) + def _get_dim0_padded_size( + tensor_size: torch.Size, + dim0_factor: int, # type: ignore[name-defined] + ) -> torch.Size: # type: ignore[name-defined] + padded_dim0 = math.ceil(tensor_size[0] / dim0_factor) * dim0_factor # type: ignore[attr-defined] + return torch.Size([padded_dim0]) + tensor_size[1:] - """ - NOTE: the relationship between the next few nodes is tricky: - - reduce_scatter_input_reshaped is a view of reduce_scatter_input - (same storage, same # elems, different shape). - - chunk_cat writes into reduce_scatter_input_reshaped, - which indirectly writes into reduce_scatter_input - (since they share the same storage). - - reduce_scatter_tensor reads from reduce_scatter_input. - """ - reduce_scatter_input = new_graph_call_function( + padded_unsharded_sizes = tuple( + _get_dim0_padded_size(grad.meta["val"].size(), group_size) # type: ignore[arg-type] + for grad in unsharded_grads + ) + reduce_scatter_input_numel = sum(s.numel() for s in padded_unsharded_sizes) + + """ + NOTE: the relationship between the next few nodes is tricky: + - reduce_scatter_input_reshaped is a view of reduce_scatter_input + (same storage, same # elems, different shape). + - chunk_cat writes into reduce_scatter_input_reshaped, + which indirectly writes into reduce_scatter_input + (since they share the same storage). + - reduce_scatter_tensor reads from reduce_scatter_input. 
+ """ + reduce_scatter_input = new_graph_call_function( + new_graph, + torch.ops.aten.empty.memory_format, + ([reduce_scatter_input_numel],), + { + "dtype": reduce_dtype, + "device": device, + "pin_memory": False, + }, + ) + reduce_scatter_input_reshaped = new_graph_call_function( + new_graph, + torch.ops.aten.reshape.default, + (reduce_scatter_input, [group_size, -1]), + {}, + ) + new_graph_call_function( + new_graph, + torch.ops.fsdp.chunk_cat.default, + (unsharded_grads,), + { + "dim": 0, + "num_chunks": group_size, + "out": reduce_scatter_input_reshaped, + }, + ) + reduce_scatter_tensor = new_graph_call_function( + new_graph, + torch.ops._c10d_functional.reduce_scatter_tensor.default, + (reduce_scatter_input, reduce_op, group_size, group_name), + {}, + ) + + wait_tensor = new_graph_call_function( + new_graph, + torch.ops._c10d_functional.wait_tensor.default, + (reduce_scatter_tensor,), + {}, + ) + + def _chunk_with_empty( + tensor: torch.Tensor, num_chunks: int, dim: int + ) -> list[torch.Tensor]: + chunks = list(torch.chunk(tensor, num_chunks, dim=dim)) + while len(chunks) < num_chunks: + chunks.append(chunks[0].new_empty(0)) + return chunks + + reduce_output = wait_tensor + # View out and accumulate sharded gradients + new_sharded_grads = [] + flat_grad_offset = 0 # [0, reduce_scatter_output_numel - 1] + for padded_unsharded_size, unsharded_grad in zip( + padded_unsharded_sizes, unsharded_grads + ): + # NOTE: we only care about the shape of tensors in `chunks`, so using meta tensor here + chunks = _chunk_with_empty( + torch.empty_like(unsharded_grad.meta["val"], device="meta"), + group_size, # type: ignore[arg-type] + dim=shard_dim, + ) + sharded_param = chunks[rank_idx] + sharded_size = sharded_param.size() + contiguous_sharded_stride = ( + torch._prims_common.make_contiguous_strides_for(sharded_size) + ) + # Assume even sharding for Shard(i), i > 0; otherwise would require + # copy-out for contiguous strides + new_sharded_grad = new_graph_call_function( new_graph, - torch.ops.aten.empty.memory_format, - ([reduce_scatter_input_numel],), + torch.ops.aten.as_strided.default, + (reduce_output,), { - "dtype": reduce_dtype, - "device": device, - "pin_memory": False, + "size": sharded_size, + "stride": contiguous_sharded_stride, + "storage_offset": flat_grad_offset, }, ) - reduce_scatter_input_reshaped = new_graph_call_function( - new_graph, - torch.ops.aten.reshape.default, - (reduce_scatter_input, [group_size, -1]), - {}, - ) - new_graph_call_function( - new_graph, - torch.ops.fsdp.chunk_cat.default, - (unsharded_grads,), - { - "dim": 0, - "num_chunks": group_size, - "out": reduce_scatter_input_reshaped, - }, - ) - reduce_scatter_tensor = new_graph_call_function( - new_graph, - torch.ops._c10d_functional.reduce_scatter_tensor.default, - (reduce_scatter_input, reduce_op, group_size, group_name), - {}, - ) - - wait_tensor = new_graph_call_function( - new_graph, - torch.ops._c10d_functional.wait_tensor.default, - (reduce_scatter_tensor,), - {}, - ) - - def _chunk_with_empty( - tensor: torch.Tensor, num_chunks: int, dim: int - ) -> list[torch.Tensor]: - chunks = list(torch.chunk(tensor, num_chunks, dim=dim)) - while len(chunks) < num_chunks: - chunks.append(chunks[0].new_empty(0)) - return chunks - - reduce_output = wait_tensor - # View out and accumulate sharded gradients - new_sharded_grads = [] - flat_grad_offset = 0 # [0, reduce_scatter_output_numel - 1] - for padded_unsharded_size, unsharded_grad in zip( - padded_unsharded_sizes, unsharded_grads - ): - # NOTE: we only care about the 
shape of tensors in `chunks`, so using meta tensor here - chunks = _chunk_with_empty( - torch.empty_like(unsharded_grad.meta["val"], device="meta"), - group_size, # type: ignore[arg-type] - dim=shard_dim, - ) - sharded_param = chunks[rank] - sharded_size = sharded_param.size() - contiguous_sharded_stride = ( - torch._prims_common.make_contiguous_strides_for(sharded_size) - ) - # Assume even sharding for Shard(i), i > 0; otherwise would require - # copy-out for contiguous strides - new_sharded_grad = new_graph_call_function( - new_graph, - torch.ops.aten.as_strided.default, - (reduce_output,), - { - "size": sharded_size, - "stride": contiguous_sharded_stride, - "storage_offset": flat_grad_offset, - }, - ) - new_sharded_grads.append(new_sharded_grad) - padded_sharded_numel = padded_unsharded_size.numel() // group_size # type: ignore[operator] - flat_grad_offset += padded_sharded_numel # type: ignore[assignment] - assert len(orig_wait_nodes) == len(new_sharded_grads) - assert len(orig_wait_nodes) > 0 - for new_sharded_grad, orig_wait_node in zip( - new_sharded_grads, orig_wait_nodes - ): - env[orig_wait_node] = new_sharded_grad # noqa: PERF403 - for user in sorted( - orig_wait_node_recursive_users, key=lambda x: order[x] - ): - # We skip output node here, because output node will be inserted (later) - # as the last node in the new graph. - if user.op != "output": - node_copy( - env, new_graph, user, lambda x: env_lookup(env, x, user) - ) - bucket_id_is_scheduled[bucket_id] = True + new_sharded_grads.append(new_sharded_grad) + padded_sharded_numel = padded_unsharded_size.numel() // group_size # type: ignore[operator] + flat_grad_offset += padded_sharded_numel # type: ignore[assignment] + assert len(orig_wait_nodes) == len(new_sharded_grads) + assert len(orig_wait_nodes) > 0 + for new_sharded_grad, orig_wait_node in zip( + new_sharded_grads, orig_wait_nodes + ): + env[orig_wait_node] = new_sharded_grad # noqa: PERF403 + for user in sorted(orig_wait_node_recursive_users, key=lambda x: order[x]): + # We skip output node here, because output node will be inserted (later) + # as the last node in the new graph. + if user.op != "output": + node_copy(env, new_graph, user, lambda x: env_lookup(env, x, user)) + bucket_id_is_scheduled[bucket_id] = True else: continue assert node_list[-1].op == "output"
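
Below is a minimal, eager-mode sketch (not part of the patch) of why the all_gather copy-out above switches from `rank` to `rank_idx_dict[rank]`. For a process group whose members are not contiguous global ranks, e.g. [0, 2, 4, 6], the all_gather output is laid out by position within the group, so the slice belonging to global rank 4 starts at local index 2, not 4. The rank list and sizes here are hypothetical; `make_rank_idx_dict` mirrors `_rank_idx_dict` but takes an explicit rank list instead of resolving a process-group name.

    import torch

    def make_rank_idx_dict(ranks: list[int]) -> dict[int, int]:
        # Same mapping _rank_idx_dict builds, from an explicit rank list.
        return {rank: idx for idx, rank in enumerate(ranks)}

    ranks = [0, 2, 4, 6]                  # hypothetical non-contiguous group
    rank_idx = make_rank_idx_dict(ranks)
    group_size = len(ranks)

    all_gather_input_numel = 16           # flat per-rank input numel
    all_gather_output = torch.arange(group_size * all_gather_input_numel)

    my_rank = 4
    # Local index (2) picks this rank's segment; using the global rank (4)
    # would index past the end of the gathered buffer.
    start = rank_idx[my_rank] * all_gather_input_numel
    my_shard = all_gather_output[start : start + all_gather_input_numel]
    assert my_shard[0].item() == 32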
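
The bucketing strategy itself can be summarized with a simplified sketch, assuming a hypothetical `FakeCollective` stand-in for fx nodes: collectives are first grouped by `(group_name, dtype)` so only calls on the same process group are fused, each group is then split into buckets by a byte-size cap, and single-member buckets are dropped because there is nothing to fuse. It deliberately omits the recursive-user check and the per-bucket cap callback used by the real pass.

    from collections import defaultdict
    from dataclasses import dataclass

    @dataclass
    class FakeCollective:
        name: str
        group_name: str
        dtype: str
        size_bytes: int

    def bucket_by_group_and_size(
        nodes: list[FakeCollective], cap_bytes: int
    ) -> list[list[FakeCollective]]:
        # Step 1: group by (process group, dtype) so buckets never mix groups.
        grouped: dict[tuple[str, str], list[FakeCollective]] = defaultdict(list)
        for n in nodes:
            grouped[(n.group_name, n.dtype)].append(n)
        # Step 2: size-capped bucketing within each group.
        buckets: list[list[FakeCollective]] = []
        for group_nodes in grouped.values():
            cur: list[FakeCollective] = []
            cur_bytes = 0
            for n in group_nodes:
                if cur and cur_bytes + n.size_bytes > cap_bytes:
                    if len(cur) > 1:      # singleton buckets are not fused
                        buckets.append(cur)
                    cur, cur_bytes = [], 0
                cur.append(n)
                cur_bytes += n.size_bytes
            if len(cur) > 1:
                buckets.append(cur)
        return buckets

    nodes = [
        FakeCollective("ag0", "pg_even", "bf16", 4),
        FakeCollective("ag1", "pg_all", "bf16", 4),
        FakeCollective("ag2", "pg_even", "bf16", 4),
        FakeCollective("ag3", "pg_even", "bf16", 4),
    ]
    buckets = bucket_by_group_and_size(nodes, cap_bytes=16)
    # ag0/ag2/ag3 share a process group and form one bucket; ag1 stays alone
    # on pg_all and is dropped.
    assert [n.name for n in buckets[0]] == ["ag0", "ag2", "ag3"]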
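
Finally, a small eager-mode sketch of how the bucketed reduce_scatter output is viewed back into per-parameter shards with `as_strided`, using the local rank index for a non-contiguous group. The group, gradient shapes, and random reduce output are hypothetical; dim 0 divides evenly by the group size, matching the "assume even sharding" note in the pass.

    import math
    import torch
    import torch._prims_common as prims_common

    ranks = [0, 1, 3, 5]                              # hypothetical group
    rank_idx = {r: i for i, r in enumerate(ranks)}    # global rank -> local idx
    group_size = len(ranks)
    my_rank = 3
    my_idx = rank_idx[my_rank]                        # 2, not 3

    grads = [torch.randn(8, 4), torch.randn(12)]      # hypothetical grads

    def dim0_padded_size(size: torch.Size, factor: int) -> torch.Size:
        padded = math.ceil(size[0] / factor) * factor
        return torch.Size([padded]) + size[1:]

    padded_sizes = [dim0_padded_size(g.size(), group_size) for g in grads]

    # Each rank holds a flat buffer of 1/group_size of the total padded numel;
    # stand in random data for the real reduce_scatter result.
    flat_numel = sum(s.numel() for s in padded_sizes) // group_size
    reduce_output = torch.randn(flat_numel)

    sharded_grads = []
    offset = 0
    for padded_size, grad in zip(padded_sizes, grads):
        # Only the chunk shapes matter, so chunk a meta tensor, as the pass does.
        chunk = torch.chunk(
            torch.empty_like(grad, device="meta"), group_size, dim=0
        )[my_idx]
        stride = prims_common.make_contiguous_strides_for(chunk.size())
        sharded_grads.append(
            reduce_output.as_strided(chunk.size(), stride, offset)
        )
        offset += padded_size.numel() // group_size

    assert sharded_grads[0].shape == (2, 4) and sharded_grads[1].shape == (3,)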