deepcompile: Record graph order using OrderedDict (#7563)

On clear, GraphOrder does not clear ordered_frames. That may confuse
subsequent passes after the first iteration.

Use an OrderedDict to record the mapping from frame IDs to other
graph-related information.

Also fix the type annotation of graph_order, which is a list of
(int, bool) tuples, not a list of int.
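
For illustration, a minimal sketch of the stale state (hypothetical
driver code, not part of the diff): with the old implementation,
clear() empties frames but leaves ordered_frames populated, so the
next get_graph_order() hits a KeyError on the stale frame ID; the
OrderedDict version returns [(1, False)] as expected.

    order = GraphOrder()
    order.add_graph(0, 10, True)    # graph 0 captured for frame 10
    order.clear()                   # old code: ordered_frames still [10]
    order.add_graph(1, 11, False)   # first graph of the next iteration
    order.get_graph_order()         # old: KeyError: 10; new: [(1, False)]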

Signed-off-by: Junjie Mao <junjie.mao@linux.alibaba.com>
Author: Junjie Mao
Date: 2025-09-16 13:25:32 +08:00
Committed by: GitHub
Parent: 660ee89529
Commit: e9d5d416cc

8 changed files with 50 additions and 45 deletions

View File

@@ -3,9 +3,10 @@
 # DeepSpeed Team
-from typing import Dict, List, Callable
+from typing import Dict, List, Callable, Tuple
 import time
 import gc
+from collections import OrderedDict
 import torch
 from torch.fx import Graph, GraphModule
@@ -43,17 +44,14 @@ param_manager: Dict[int, DSGraphParamManager] = {}
 class GraphOrder:
 
     def __init__(self):
-        self.ordered_frames = []
-        self.frames = {}
+        self.frames = OrderedDict()
 
-    def add_graph(self, graph_id, frame_id, needs_backward):
-        if frame_id not in self.ordered_frames:
-            self.ordered_frames.append(frame_id)
-        self.frames[frame_id] = (graph_id, needs_backward)
+    def add_graph(self, graph_id: int, frame_id: int, needs_backward: bool):
+        if frame_id not in self.frames:
+            self.frames[frame_id] = (graph_id, needs_backward)
 
-    def get_graph_order(self):
-        return [self.frames[frame_id] for frame_id in self.ordered_frames]
+    def get_graph_order(self) -> List[Tuple[int, bool]]:
+        return list(self.frames.values())
 
     def clear(self):
         self.frames.clear()
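
For context, get_graph_order() now leans on the insertion order that
OrderedDict preserves (plain dict also preserves it since Python 3.7;
OrderedDict makes the intent explicit), so a single clear() resets both
the mapping and the order. A standalone sketch:

    from collections import OrderedDict

    frames = OrderedDict()
    frames[10] = (0, True)
    frames[11] = (1, False)
    assert list(frames.values()) == [(0, True), (1, False)]
    frames.clear()               # drops the entries and their order together
    frames[12] = (2, True)
    assert list(frames.values()) == [(2, True)]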
@@ -180,7 +178,7 @@ def set_example_values_to_symints(real_inputs, param_indices=None):
 def run_opt_passes(opt_passes: List[Callable],
                    gm: GraphModule,
                    graph_id: int,
-                   graph_order: List[int],
+                   graph_order: List[Tuple[int, bool]],
                    profiling_results,
                    create_inputs_fn,
                    mem_budget: float,

View File

@@ -41,7 +41,8 @@ def _should_offload(node: Node) -> bool:
 def offload_activation_fwd(graph: Graph, graph_id: int, nodes_to_offload_with_names: List[Tuple[str, Node]],
-                           graph_order: List[int], mem_budget: float, param_manager: DSGraphParamManager) -> Graph:
+                           graph_order: List[Tuple[int, bool]], mem_budget: float,
+                           param_manager: DSGraphParamManager) -> Graph:
     param_names = set(param_manager.param_names)
     import copy
@@ -77,7 +78,7 @@ def offload_activation_fwd(graph: Graph, graph_id: int, nodes_to_offload_with_na
     return graph
-def reload_activation_bwd(graph: Graph, graph_id: int, graph_order: List[int], mem_budget: float,
+def reload_activation_bwd(graph: Graph, graph_id: int, graph_order: List[Tuple[int, bool]], mem_budget: float,
                           param_manager: DSGraphParamManager) -> Graph:
     graph_value_to_id = value_to_id[graph_id]

View File

@@ -4,7 +4,7 @@
 # DeepSpeed Team
 import copy
-from typing import List
+from typing import List, Tuple
 import torch
 from torch.fx import Graph, GraphModule
@@ -250,8 +250,9 @@ reload_task_remaining = []
 total_reload_mem = 0
-def offload_opt_states_inc(graph: Graph, graph_id: int, graph_order: List[int], profiling_results: ProfilingResult,
-                           mem_budget: float, param_manager: DSGraphParamManager, bwd: bool) -> Graph:
+def offload_opt_states_inc(graph: Graph, graph_id: int, graph_order: List[Tuple[int, bool]],
+                           profiling_results: ProfilingResult, mem_budget: float, param_manager: DSGraphParamManager,
+                           bwd: bool) -> Graph:
     to_remove = []
     for node in graph.nodes:
@@ -475,8 +476,9 @@ def add_record_max_mem_nodes(graph: Graph):
         graph.create_node('call_function', update_max_memory, (name, ), {}, name=name)
-def insert_offload_opt_states(graph: Graph, graph_id: int, graph_order: List[int], profiling_results: ProfilingResult,
-                              mem_budget: float, param_manager: DSGraphParamManager, bwd: bool) -> Graph:
+def insert_offload_opt_states(graph: Graph, graph_id: int, graph_order: List[Tuple[int, bool]],
+                              profiling_results: ProfilingResult, mem_budget: float,
+                              param_manager: DSGraphParamManager, bwd: bool) -> Graph:
     if bwd:
         graph_order_with_backward = [g[0] for g in graph_order if g[1]]
@@ -512,23 +514,24 @@ def insert_offload_opt_states(graph: Graph, graph_id: int, graph_order: List[int
     return graph
-def move_opt_states(gm: GraphModule, graph_id: int, graph_order: List[int], profiling_results, create_inputs_fn,
-                    mem_budget: float, param_manager: DSGraphParamManager, bwd: bool) -> GraphModule:
+def move_opt_states(gm: GraphModule, graph_id: int, graph_order: List[Tuple[int, bool]], profiling_results,
+                    create_inputs_fn, mem_budget: float, param_manager: DSGraphParamManager, bwd: bool) -> GraphModule:
     gm.graph = offload_opt_states_inc(gm.graph, graph_id, graph_order, profiling_results, mem_budget, param_manager,
                                       bwd)
     return gm
-def move_opt_states_sync(gm: GraphModule, graph_id: int, graph_order: List[int], profiling_results, create_inputs_fn,
-                         mem_budget: float, param_manager: DSGraphParamManager, bwd: bool) -> GraphModule:
+def move_opt_states_sync(gm: GraphModule, graph_id: int, graph_order: List[Tuple[int, bool]], profiling_results,
+                         create_inputs_fn, mem_budget: float, param_manager: DSGraphParamManager,
+                         bwd: bool) -> GraphModule:
     gm.graph = insert_offload_opt_states(gm.graph, graph_id, graph_order, profiling_results, mem_budget, param_manager,
                                          bwd)
     return gm
-def offload_adam_states_for_init(gm: GraphModule, graph_id: int, graph_order: List[int], profiling_results,
-                                 create_inputs_fn, mem_budget: float, param_manager: DSGraphParamManager,
-                                 bwd: bool) -> GraphModule:
+def offload_adam_states_for_init(gm: GraphModule, graph_id: int, graph_order: List[Tuple[int, bool]],
+                                 profiling_results, create_inputs_fn, mem_budget: float,
+                                 param_manager: DSGraphParamManager, bwd: bool) -> GraphModule:
     if not bwd and graph_id == graph_order[0][0]:
         with unset_fake_temporarily():
             offload_adam_states_sync()
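
As the hunks above show, passes now destructure each graph_order entry
instead of treating it as a flat list of IDs: graph_order[0][0] is the
first graph's ID, and the graphs needing a backward pass are selected
with [g[0] for g in graph_order if g[1]]. A standalone sketch of that
consumption pattern, with made-up values:

    graph_order = [(0, True), (1, False), (2, True)]    # (graph_id, needs_backward)

    first_graph_id = graph_order[0][0]                  # 0
    backward_ids = [g[0] for g in graph_order if g[1]]  # [0, 2]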

View File

@@ -3,7 +3,7 @@
 # DeepSpeed Team
-from typing import List
+from typing import List, Tuple
 import torch
 from torch.fx import Node, GraphModule
@@ -43,8 +43,9 @@ def get_ds_id(node: Node):
     return node.args[2]
-def offload_parameter_fwd(gm: GraphModule, graph_id: int, graph_order: List[int], profiling_results, create_inputs_fn,
-                          mem_budget: float, param_manager: DSGraphParamManager, bwd: bool) -> GraphModule:
+def offload_parameter_fwd(gm: GraphModule, graph_id: int, graph_order: List[Tuple[int, bool]], profiling_results,
+                          create_inputs_fn, mem_budget: float, param_manager: DSGraphParamManager,
+                          bwd: bool) -> GraphModule:
     node_to_last_use, user_to_last_uses = get_last_uses(gm.graph)
     for node in gm.graph.nodes:
         if (isinstance(node, Node) and node.target == torch.ops.dc.allgather_param.default):

View File

@@ -3,7 +3,7 @@
 # DeepSpeed Team
-from typing import List
+from typing import List, Tuple
 import torch
 from torch.fx import Graph, Node, GraphModule
@@ -34,8 +34,9 @@ def get_ds_id(node: Node):
     return node.args[2]
-def schedule_prefetch(gm: GraphModule, graph_id: int, graph_order: List[int], profiling_results, create_inputs_fn,
-                      mem_budget: float, param_manager: DSGraphParamManager, bwd: bool) -> GraphModule:
+def schedule_prefetch(gm: GraphModule, graph_id: int, graph_order: List[Tuple[int, bool]], profiling_results,
+                      create_inputs_fn, mem_budget: float, param_manager: DSGraphParamManager,
+                      bwd: bool) -> GraphModule:
     max_mem = get_accelerator().total_memory() * (1 - MARGIN)
     vals_to_bcast = torch.tensor([max_mem], device=torch.device(get_accelerator().current_device()))

View File

@@ -4,7 +4,7 @@
 # DeepSpeed Team
 from collections import defaultdict
-from typing import List
+from typing import List, Tuple
 import torch
 from torch.fx import GraphModule
@@ -21,8 +21,9 @@ max_alloc_mem = 0
 last_optimize_step = 0
-def selective_gather(gm: GraphModule, graph_id: int, graph_order: List[int], profiling_results, create_inputs_fn,
-                     mem_budget: float, param_manager: DSGraphParamManager, bwd: bool) -> GraphModule:
+def selective_gather(gm: GraphModule, graph_id: int, graph_order: List[Tuple[int, bool]], profiling_results,
+                     create_inputs_fn, mem_budget: float, param_manager: DSGraphParamManager,
+                     bwd: bool) -> GraphModule:
     if not bwd:
         return gm
@@ -138,7 +139,7 @@ def selective_gather(gm: GraphModule, graph_id: int, graph_order: List[int], pro
 # def make_selective_gather(z3_optimizer, nz3):
-#     def selective_gather_wrapper(graph: Graph, graph_id: int, graph_order: List[int], profiling_results,
+#     def selective_gather_wrapper(graph: Graph, graph_id: int, graph_order: List[Tuple[int, bool]], profiling_results,
 #                                  mem_budget: float, param_manager, bwd: bool) -> Graph:
 #         return selective_gather(graph, graph_id, graph_order, profiling_results, mem_budget, param_manager, bwd,
 #                                 z3_optimizer, nz3)

View File

@@ -3,7 +3,7 @@
 # DeepSpeed Team
-from typing import List
+from typing import List, Tuple
 import torch
 from torch.fx import GraphModule
@@ -52,15 +52,15 @@ def add_z1_reduce_bw(gm: GraphModule, graph_id: int, param_manager) -> GraphModu
     return gm
-def add_z1_reduce(gm: GraphModule, graph_id: int, graph_order: List[int], profiling_results, create_inputs_fn,
-                  mem_budget: float, param_manager, bwd: bool) -> GraphModule:
+def add_z1_reduce(gm: GraphModule, graph_id: int, graph_order: List[Tuple[int, bool]], profiling_results,
+                  create_inputs_fn, mem_budget: float, param_manager, bwd: bool) -> GraphModule:
     if bwd:
         return add_z1_reduce_bw(gm, graph_id, param_manager)
     return add_z1_reduce_fw(gm, graph_id, profiling_results, param_manager, use_z2=False)
-def add_z2_reduce(gm: GraphModule, graph_id: int, graph_order: List[int], profiling_results, create_inputs_fn,
-                  mem_budget: float, param_manager, bwd: bool) -> GraphModule:
+def add_z2_reduce(gm: GraphModule, graph_id: int, graph_order: List[Tuple[int, bool]], profiling_results,
+                  create_inputs_fn, mem_budget: float, param_manager, bwd: bool) -> GraphModule:
     if bwd:
         return add_z1_reduce_bw(gm, graph_id, param_manager)
     return add_z1_reduce_fw(gm, graph_id, profiling_results, param_manager, use_z2=True)

View File

@@ -4,7 +4,7 @@
 # DeepSpeed Team
 import gc
-from typing import List, Dict
+from typing import List, Dict, Tuple
 import torch
 from torch.fx import Graph, Node, GraphModule
@@ -92,7 +92,7 @@ def add_gather_and_reduce(graph_id: int, graph: Graph, param_manager, param_node
 def add_z3_gather_release_fw(gm: GraphModule,
                              graph_id: int,
-                             graph_order: List[int],
+                             graph_order: List[Tuple[int, bool]],
                              profiling_results,
                              create_inputs_fn,
                              param_manager,
@@ -139,7 +139,7 @@ def add_z3_gather_release_fw(gm: GraphModule,
 def add_z3_gather_release_bw(gm: GraphModule,
                              graph_id: int,
-                             graph_order: List[int],
+                             graph_order: List[Tuple[int, bool]],
                              profiling_results,
                              create_inputs_fn,
                              param_manager,
@@ -172,8 +172,8 @@ def add_z3_gather_release_bw(gm: GraphModule,
     return gm
-def add_z3_gather_release(gm: GraphModule, graph_id: int, graph_order: List[int], profiling_results, create_inputs_fn,
-                          mem_budget: float, param_manager, bwd: bool) -> GraphModule:
+def add_z3_gather_release(gm: GraphModule, graph_id: int, graph_order: List[Tuple[int, bool]], profiling_results,
+                          create_inputs_fn, mem_budget: float, param_manager, bwd: bool) -> GraphModule:
     if bwd:
         return add_z3_gather_release_bw(gm,
                                         graph_id,