mirror of
https://github.com/deepspeedai/DeepSpeed.git
synced 2025-10-20 15:33:51 +08:00
deepcompile: Record graph order using OrderedDict (#7563)
On clear, GraphOrder does not clears ordered_frames. That may confuses subsequent passes after the first iteration. Use an OrderedDict to record the mapping from frame IDs to other graph-related information. Also fix the type annotation of graph_order which is a list of (int , bool) tuples instead of a list of int. Signed-off-by: Junjie Mao <junjie.mao@linux.alibaba.com>
This commit is contained in:
@ -3,9 +3,10 @@
|
||||
|
||||
# DeepSpeed Team
|
||||
|
||||
from typing import Dict, List, Callable
|
||||
from typing import Dict, List, Callable, Tuple
|
||||
import time
|
||||
import gc
|
||||
from collections import OrderedDict
|
||||
|
||||
import torch
|
||||
from torch.fx import Graph, GraphModule
|
||||
@ -43,17 +44,14 @@ param_manager: Dict[int, DSGraphParamManager] = {}
|
||||
class GraphOrder:
|
||||
|
||||
def __init__(self):
|
||||
self.ordered_frames = []
|
||||
self.frames = {}
|
||||
self.frames = OrderedDict()
|
||||
|
||||
def add_graph(self, graph_id, frame_id, needs_backward):
|
||||
if frame_id not in self.ordered_frames:
|
||||
self.ordered_frames.append(frame_id)
|
||||
def add_graph(self, graph_id: int, frame_id: int, needs_backward: bool):
|
||||
if frame_id not in self.frames:
|
||||
self.frames[frame_id] = (graph_id, needs_backward)
|
||||
|
||||
self.frames[frame_id] = (graph_id, needs_backward)
|
||||
|
||||
def get_graph_order(self):
|
||||
return [self.frames[frame_id] for frame_id in self.ordered_frames]
|
||||
def get_graph_order(self) -> List[Tuple[int, bool]]:
|
||||
return list(self.frames.values())
|
||||
|
||||
def clear(self):
|
||||
self.frames.clear()
|
||||
@ -180,7 +178,7 @@ def set_example_values_to_symints(real_inputs, param_indices=None):
|
||||
def run_opt_passes(opt_passes: List[Callable],
|
||||
gm: GraphModule,
|
||||
graph_id: int,
|
||||
graph_order: List[int],
|
||||
graph_order: List[Tuple[int, bool]],
|
||||
profiling_results,
|
||||
create_inputs_fn,
|
||||
mem_budget: float,
|
||||
|
@ -41,7 +41,8 @@ def _should_offload(node: Node) -> bool:
|
||||
|
||||
|
||||
def offload_activation_fwd(graph: Graph, graph_id: int, nodes_to_offload_with_names: List[Tuple[str, Node]],
|
||||
graph_order: List[int], mem_budget: float, param_manager: DSGraphParamManager) -> Graph:
|
||||
graph_order: List[Tuple[int, bool]], mem_budget: float,
|
||||
param_manager: DSGraphParamManager) -> Graph:
|
||||
param_names = set(param_manager.param_names)
|
||||
|
||||
import copy
|
||||
@ -77,7 +78,7 @@ def offload_activation_fwd(graph: Graph, graph_id: int, nodes_to_offload_with_na
|
||||
return graph
|
||||
|
||||
|
||||
def reload_activation_bwd(graph: Graph, graph_id: int, graph_order: List[int], mem_budget: float,
|
||||
def reload_activation_bwd(graph: Graph, graph_id: int, graph_order: List[Tuple[int, bool]], mem_budget: float,
|
||||
param_manager: DSGraphParamManager) -> Graph:
|
||||
|
||||
graph_value_to_id = value_to_id[graph_id]
|
||||
|
@ -4,7 +4,7 @@
|
||||
# DeepSpeed Team
|
||||
|
||||
import copy
|
||||
from typing import List
|
||||
from typing import List, Tuple
|
||||
|
||||
import torch
|
||||
from torch.fx import Graph, GraphModule
|
||||
@ -250,8 +250,9 @@ reload_task_remaining = []
|
||||
total_reload_mem = 0
|
||||
|
||||
|
||||
def offload_opt_states_inc(graph: Graph, graph_id: int, graph_order: List[int], profiling_results: ProfilingResult,
|
||||
mem_budget: float, param_manager: DSGraphParamManager, bwd: bool) -> Graph:
|
||||
def offload_opt_states_inc(graph: Graph, graph_id: int, graph_order: List[Tuple[int, bool]],
|
||||
profiling_results: ProfilingResult, mem_budget: float, param_manager: DSGraphParamManager,
|
||||
bwd: bool) -> Graph:
|
||||
|
||||
to_remove = []
|
||||
for node in graph.nodes:
|
||||
@ -475,8 +476,9 @@ def add_record_max_mem_nodes(graph: Graph):
|
||||
graph.create_node('call_function', update_max_memory, (name, ), {}, name=name)
|
||||
|
||||
|
||||
def insert_offload_opt_states(graph: Graph, graph_id: int, graph_order: List[int], profiling_results: ProfilingResult,
|
||||
mem_budget: float, param_manager: DSGraphParamManager, bwd: bool) -> Graph:
|
||||
def insert_offload_opt_states(graph: Graph, graph_id: int, graph_order: List[Tuple[int, bool]],
|
||||
profiling_results: ProfilingResult, mem_budget: float,
|
||||
param_manager: DSGraphParamManager, bwd: bool) -> Graph:
|
||||
|
||||
if bwd:
|
||||
graph_order_with_backward = [g[0] for g in graph_order if g[1]]
|
||||
@ -512,23 +514,24 @@ def insert_offload_opt_states(graph: Graph, graph_id: int, graph_order: List[int
|
||||
return graph
|
||||
|
||||
|
||||
def move_opt_states(gm: GraphModule, graph_id: int, graph_order: List[int], profiling_results, create_inputs_fn,
|
||||
mem_budget: float, param_manager: DSGraphParamManager, bwd: bool) -> GraphModule:
|
||||
def move_opt_states(gm: GraphModule, graph_id: int, graph_order: List[Tuple[int, bool]], profiling_results,
|
||||
create_inputs_fn, mem_budget: float, param_manager: DSGraphParamManager, bwd: bool) -> GraphModule:
|
||||
gm.graph = offload_opt_states_inc(gm.graph, graph_id, graph_order, profiling_results, mem_budget, param_manager,
|
||||
bwd)
|
||||
return gm
|
||||
|
||||
|
||||
def move_opt_states_sync(gm: GraphModule, graph_id: int, graph_order: List[int], profiling_results, create_inputs_fn,
|
||||
mem_budget: float, param_manager: DSGraphParamManager, bwd: bool) -> GraphModule:
|
||||
def move_opt_states_sync(gm: GraphModule, graph_id: int, graph_order: List[Tuple[int, bool]], profiling_results,
|
||||
create_inputs_fn, mem_budget: float, param_manager: DSGraphParamManager,
|
||||
bwd: bool) -> GraphModule:
|
||||
gm.graph = insert_offload_opt_states(gm.graph, graph_id, graph_order, profiling_results, mem_budget, param_manager,
|
||||
bwd)
|
||||
return gm
|
||||
|
||||
|
||||
def offload_adam_states_for_init(gm: GraphModule, graph_id: int, graph_order: List[int], profiling_results,
|
||||
create_inputs_fn, mem_budget: float, param_manager: DSGraphParamManager,
|
||||
bwd: bool) -> GraphModule:
|
||||
def offload_adam_states_for_init(gm: GraphModule, graph_id: int, graph_order: List[Tuple[int, bool]],
|
||||
profiling_results, create_inputs_fn, mem_budget: float,
|
||||
param_manager: DSGraphParamManager, bwd: bool) -> GraphModule:
|
||||
if not bwd and graph_id == graph_order[0][0]:
|
||||
with unset_fake_temporarily():
|
||||
offload_adam_states_sync()
|
||||
|
@ -3,7 +3,7 @@
|
||||
|
||||
# DeepSpeed Team
|
||||
|
||||
from typing import List
|
||||
from typing import List, Tuple
|
||||
|
||||
import torch
|
||||
from torch.fx import Node, GraphModule
|
||||
@ -43,8 +43,9 @@ def get_ds_id(node: Node):
|
||||
return node.args[2]
|
||||
|
||||
|
||||
def offload_parameter_fwd(gm: GraphModule, graph_id: int, graph_order: List[int], profiling_results, create_inputs_fn,
|
||||
mem_budget: float, param_manager: DSGraphParamManager, bwd: bool) -> GraphModule:
|
||||
def offload_parameter_fwd(gm: GraphModule, graph_id: int, graph_order: List[Tuple[int, bool]], profiling_results,
|
||||
create_inputs_fn, mem_budget: float, param_manager: DSGraphParamManager,
|
||||
bwd: bool) -> GraphModule:
|
||||
node_to_last_use, user_to_last_uses = get_last_uses(gm.graph)
|
||||
for node in gm.graph.nodes:
|
||||
if (isinstance(node, Node) and node.target == torch.ops.dc.allgather_param.default):
|
||||
|
@ -3,7 +3,7 @@
|
||||
|
||||
# DeepSpeed Team
|
||||
|
||||
from typing import List
|
||||
from typing import List, Tuple
|
||||
|
||||
import torch
|
||||
from torch.fx import Graph, Node, GraphModule
|
||||
@ -34,8 +34,9 @@ def get_ds_id(node: Node):
|
||||
return node.args[2]
|
||||
|
||||
|
||||
def schedule_prefetch(gm: GraphModule, graph_id: int, graph_order: List[int], profiling_results, create_inputs_fn,
|
||||
mem_budget: float, param_manager: DSGraphParamManager, bwd: bool) -> GraphModule:
|
||||
def schedule_prefetch(gm: GraphModule, graph_id: int, graph_order: List[Tuple[int, bool]], profiling_results,
|
||||
create_inputs_fn, mem_budget: float, param_manager: DSGraphParamManager,
|
||||
bwd: bool) -> GraphModule:
|
||||
|
||||
max_mem = get_accelerator().total_memory() * (1 - MARGIN)
|
||||
vals_to_bcast = torch.tensor([max_mem], device=torch.device(get_accelerator().current_device()))
|
||||
|
@ -4,7 +4,7 @@
|
||||
# DeepSpeed Team
|
||||
|
||||
from collections import defaultdict
|
||||
from typing import List
|
||||
from typing import List, Tuple
|
||||
|
||||
import torch
|
||||
from torch.fx import GraphModule
|
||||
@ -21,8 +21,9 @@ max_alloc_mem = 0
|
||||
last_optimize_step = 0
|
||||
|
||||
|
||||
def selective_gather(gm: GraphModule, graph_id: int, graph_order: List[int], profiling_results, create_inputs_fn,
|
||||
mem_budget: float, param_manager: DSGraphParamManager, bwd: bool) -> GraphModule:
|
||||
def selective_gather(gm: GraphModule, graph_id: int, graph_order: List[Tuple[int, bool]], profiling_results,
|
||||
create_inputs_fn, mem_budget: float, param_manager: DSGraphParamManager,
|
||||
bwd: bool) -> GraphModule:
|
||||
|
||||
if not bwd:
|
||||
return gm
|
||||
@ -138,7 +139,7 @@ def selective_gather(gm: GraphModule, graph_id: int, graph_order: List[int], pro
|
||||
|
||||
# def make_selective_gather(z3_optimizer, nz3):
|
||||
|
||||
# def selective_gather_wrapper(graph: Graph, graph_id: int, graph_order: List[int], profiling_results,
|
||||
# def selective_gather_wrapper(graph: Graph, graph_id: int, graph_order: List[Tuple[int, bool]], profiling_results,
|
||||
# mem_budget: float, param_manager, bwd: bool) -> Graph:
|
||||
# return selective_gather(graph, graph_id, graph_order, profiling_results, mem_budget, param_manager, bwd,
|
||||
# z3_optimizer, nz3)
|
||||
|
@ -3,7 +3,7 @@
|
||||
|
||||
# DeepSpeed Team
|
||||
|
||||
from typing import List
|
||||
from typing import List, Tuple
|
||||
|
||||
import torch
|
||||
from torch.fx import GraphModule
|
||||
@ -52,15 +52,15 @@ def add_z1_reduce_bw(gm: GraphModule, graph_id: int, param_manager) -> GraphModu
|
||||
return gm
|
||||
|
||||
|
||||
def add_z1_reduce(gm: GraphModule, graph_id: int, graph_order: List[int], profiling_results, create_inputs_fn,
|
||||
mem_budget: float, param_manager, bwd: bool) -> GraphModule:
|
||||
def add_z1_reduce(gm: GraphModule, graph_id: int, graph_order: List[Tuple[int, bool]], profiling_results,
|
||||
create_inputs_fn, mem_budget: float, param_manager, bwd: bool) -> GraphModule:
|
||||
if bwd:
|
||||
return add_z1_reduce_bw(gm, graph_id, param_manager)
|
||||
return add_z1_reduce_fw(gm, graph_id, profiling_results, param_manager, use_z2=False)
|
||||
|
||||
|
||||
def add_z2_reduce(gm: GraphModule, graph_id: int, graph_order: List[int], profiling_results, create_inputs_fn,
|
||||
mem_budget: float, param_manager, bwd: bool) -> GraphModule:
|
||||
def add_z2_reduce(gm: GraphModule, graph_id: int, graph_order: List[Tuple[int, bool]], profiling_results,
|
||||
create_inputs_fn, mem_budget: float, param_manager, bwd: bool) -> GraphModule:
|
||||
if bwd:
|
||||
return add_z1_reduce_bw(gm, graph_id, param_manager)
|
||||
return add_z1_reduce_fw(gm, graph_id, profiling_results, param_manager, use_z2=True)
|
||||
|
@ -4,7 +4,7 @@
|
||||
# DeepSpeed Team
|
||||
|
||||
import gc
|
||||
from typing import List, Dict
|
||||
from typing import List, Dict, Tuple
|
||||
|
||||
import torch
|
||||
from torch.fx import Graph, Node, GraphModule
|
||||
@ -92,7 +92,7 @@ def add_gather_and_reduce(graph_id: int, graph: Graph, param_manager, param_node
|
||||
|
||||
def add_z3_gather_release_fw(gm: GraphModule,
|
||||
graph_id: int,
|
||||
graph_order: List[int],
|
||||
graph_order: List[Tuple[int, bool]],
|
||||
profiling_results,
|
||||
create_inputs_fn,
|
||||
param_manager,
|
||||
@ -139,7 +139,7 @@ def add_z3_gather_release_fw(gm: GraphModule,
|
||||
|
||||
def add_z3_gather_release_bw(gm: GraphModule,
|
||||
graph_id: int,
|
||||
graph_order: List[int],
|
||||
graph_order: List[Tuple[int, bool]],
|
||||
profiling_results,
|
||||
create_inputs_fn,
|
||||
param_manager,
|
||||
@ -172,8 +172,8 @@ def add_z3_gather_release_bw(gm: GraphModule,
|
||||
return gm
|
||||
|
||||
|
||||
def add_z3_gather_release(gm: GraphModule, graph_id: int, graph_order: List[int], profiling_results, create_inputs_fn,
|
||||
mem_budget: float, param_manager, bwd: bool) -> GraphModule:
|
||||
def add_z3_gather_release(gm: GraphModule, graph_id: int, graph_order: List[Tuple[int, bool]], profiling_results,
|
||||
create_inputs_fn, mem_budget: float, param_manager, bwd: bool) -> GraphModule:
|
||||
if bwd:
|
||||
return add_z3_gather_release_bw(gm,
|
||||
graph_id,
|
||||
|
Reference in New Issue
Block a user