From 23c7d2d7f9d6359c7668bb4833c91b8aaf5663bb Mon Sep 17 00:00:00 2001 From: Masahiro Tanaka Date: Sun, 6 Apr 2025 08:56:53 +0000 Subject: [PATCH 1/2] fix offload Signed-off-by: Masahiro Tanaka --- .../compile/passes/offload_adam_states.py | 285 +++++++++++++----- 1 file changed, 212 insertions(+), 73 deletions(-) diff --git a/deepspeed/compile/passes/offload_adam_states.py b/deepspeed/compile/passes/offload_adam_states.py index 7a49800ca..ee5381566 100644 --- a/deepspeed/compile/passes/offload_adam_states.py +++ b/deepspeed/compile/passes/offload_adam_states.py @@ -24,6 +24,10 @@ from ..fx import move_primals_to_head import deepspeed.comm as dist + +from deepspeed.runtime.utils import see_memory_usage + + NAME = "offload_adam_states" @@ -66,6 +70,9 @@ def move_key(state, key, key_event=None): if offload_buf_key not in state: state[offload_buf_key] = get_accelerator().pin_memory(torch.empty_like(state[key], device="cpu")) + if key not in state: + return + with get_accelerator().stream(copy_stream): state[offload_buf_key].copy_(state[key], non_blocking=True) @@ -85,7 +92,34 @@ def move_back_key(state, key, key_event=None): key_event.record(stream=copy_stream) +def move_hp_param(src_tensor, dest_buf, key_event=None): + dest_buf.copy_(src_tensor, non_blocking=True) + src_tensor.data = dest_buf + + if key_event is None: + reload_event.record(stream=copy_stream) + else: + key_event.record(stream=copy_stream) + + +def move_back_hp_param(src_tensor, dest_buf, key_event=None): + # for src, dest in zip(self.hp_params_pin_buffers, self.fp32_partitioned_groups_flat): + dest_buf.data = src_tensor.to(device, non_blocking=True) + + if key_event is None: + reload_event.record(stream=copy_stream) + else: + key_event.record(stream=copy_stream) + + def offload_adam_states_sync(): + see_memory_usage("before offload_adam_states_sync", force=True) + + if not hasattr(optimizer, "hp_params_pin_buffers"): + optimizer.hp_params_pin_buffers = [ + get_accelerator().pin_memory(torch.empty_like(t, device="cpu")) + for t in optimizer.fp32_partitioned_groups_flat + ] with unset_fake_temporarily(): # print_r0("Offloading Adam states") @@ -101,10 +135,19 @@ def offload_adam_states_sync(): if "exp_avg_sq" in state: del state["exp_avg_sq"] + for src_tensor, dest_buf in zip(optimizer.fp32_partitioned_groups_flat, optimizer.hp_params_pin_buffers): + move_hp_param(src_tensor, dest_buf) + get_accelerator().synchronize() + see_memory_usage("after offload_adam_states_sync", force=True) + def reload_adam_states_sync(): + memory_stats = get_accelerator().memory_stats() + alloc_retries = memory_stats.get("num_alloc_retries") + + see_memory_usage(f"before reload_adam_states_sync", force=True) with unset_fake_temporarily(): # print_r0("Reloading Adam states") @@ -115,8 +158,15 @@ def reload_adam_states_sync(): if _make_offload_state_key("exp_avg_sq") in state: move_back_key(state, "exp_avg_sq") + for src, dest in zip(optimizer.hp_params_pin_buffers, optimizer.fp32_partitioned_groups_flat): + move_back_hp_param(src, dest) + get_accelerator().synchronize() + memory_stats = get_accelerator().memory_stats() + alloc_retries = memory_stats.get("num_alloc_retries") + see_memory_usage(f"after reload_adam_states_sync", force=True) + def sync_offload_states(event=None): if nz3.is_profiling(): @@ -141,27 +191,43 @@ def sync_reload_states(event=None): def make_offload_task(task): def run_offload_task(): - if not nz3.is_profiling(): - # print_r0(f"run_offload_task {task[0]} {task[2]} {task[3]} {task[4]}") + # if not nz3.is_profiling(): + # 
print_r0(f"run_offload_task {task[0]} {task[2]} {task[3]} {task[4]}") + + if offload_key_events.get(task[1]) is None: + offload_key_events[task[1]] = get_accelerator().Event() + + if task[2] == "hp_param": + move_hp_param(task[1][0], task[1][1], offload_key_events[task[1][0]]) + else: assert task[1] in optimizer.state, f"State {task[1]} not found in optimizer" state = optimizer.state[task[1]] - if offload_key_events.get(task[1]) is None: - offload_key_events[task[1]] = get_accelerator().Event() + # if offload_key_events.get(task[1]) is None: + # offload_key_events[task[1]] = get_accelerator().Event() move_key(state, task[2], offload_key_events[task[1]]) + from deepspeed.runtime.utils import see_memory_usage + see_memory_usage(f"run_offload_task offload_opt_{task[0]}_{task[2]} alloc_mem={get_accelerator().memory_allocated()}", force=True) + return run_offload_task def make_offload_sync(task): + from deepspeed.runtime.utils import see_memory_usage def run_offload_sync(): - if not nz3.is_profiling(): - event = offload_key_events[task[1]] - event.synchronize() + # if not nz3.is_profiling(): + see_memory_usage(f"run_offload_sync start {task[0]} {task[2]}", force=True) + event = offload_key_events[task[1]] + event.synchronize() + + if task[2] != "hp_param": state = optimizer.state[task[1]] key = task[2] - del state[key] - # print_r0(f"run_offload_sync {task[0]} {task[2]} alloc_mem={get_accelerator().memory_allocated()}") + if key in state: + del state[key] + # print_r0(f"run_offload_sync {task[0]} {task[2]} alloc_mem={get_accelerator().memory_allocated()}") + see_memory_usage(f"run_offload_sync finish {task[0]} {task[2]}", force=True) return run_offload_sync @@ -170,26 +236,42 @@ def make_reload_task(task): def run_reload_task(): if not nz3.is_profiling(): - state = optimizer.state[task[1]] if reload_key_events.get(task[1]) is None: reload_key_events[task[1]] = get_accelerator().Event() - # print_r0(f"run_reload_task {task[0]} {task[2]} {task[3]} {task[4]}") - move_back_key(state, task[2], reload_key_events[task[1]]) - # alloc_mem = get_accelerator().memory_allocated() - # print_r0(f"run_reload_task reload_opt_{task[0]}_{task[2]} alloc_mem={alloc_mem}") + if task[2] == "hp_param": + move_back_hp_param(task[1][1], task[1][0], reload_key_events[task[1]]) + else: + state = optimizer.state[task[1]] + # print_r0(f"run_reload_task {task[0]} {task[2]} {task[3]} {task[4]}") + move_back_key(state, task[2], reload_key_events[task[1]]) + + # alloc_mem = get_accelerator().memory_allocated() + # print_r0(f"run_reload_task reload_opt_{task[0]}_{task[2]} alloc_mem={alloc_mem}") + from deepspeed.runtime.utils import see_memory_usage + see_memory_usage(f"run_reload_task reload_opt_{task[0]}_{task[2]} profiling={nz3.is_profiling()}", force=True) return run_reload_task -def update_max_memory(): +def update_max_memory(name): + global max_memory mem = get_accelerator().max_memory_allocated() max_memory = max(max_memory, mem) + # see_memory_usage(f"update_max_memory {name}", force=True) + + +def empty_cache(): + see_memory_usage(f"empty_cache start", force=True) + get_accelerator().empty_cache() + see_memory_usage(f"empty_cache end", force=True) + offload_tasks = [] offload_tasks_remaining = [] +offload_tasks_scheduled = [] reload_task_remaining = [] total_reload_mem = 0 @@ -197,8 +279,6 @@ total_reload_mem = 0 def offload_opt_states_inc(graph: Graph, graph_id: int, graph_order: List[int], profiling_results: ProfilingResult, mem_budget: float, param_manager: DSGraphParamManager, bwd: bool) -> Graph: - # 
print_r0(f"offload_opt_states_inc graph {graph_id} bwd={bwd} max_memory={max_memory}") - to_remove = [] for node in graph.nodes: if node.op == 'call_function' and \ @@ -210,6 +290,7 @@ def offload_opt_states_inc(graph: Graph, graph_id: int, graph_order: List[int], accelerator = get_accelerator() total_mem = accelerator.total_memory() * (1 - MARGIN) + print_r0(f"offload_opt_states_inc start graph {graph_id} bwd={bwd} max_memory={max_memory} total_mem={total_mem}") mem = profiling_results[graph_id].bwd_mem if bwd else profiling_results[graph_id].fwd_mem mem_dict = {name: peak for name, alloc_mem, delta, peak in mem} @@ -228,70 +309,87 @@ def offload_opt_states_inc(graph: Graph, graph_id: int, graph_order: List[int], # bwd_max_mem = max(m[3] for m in prof.bwd_mem) if len(prof.bwd_mem) > 0 else 0 # peak_mem = max(peak_mem, fwd_max_mem, bwd_max_mem) - global offload_tasks_remaining, reload_tasks_remaining + global offload_tasks_remaining, reload_tasks_remaining, offload_tasks_scheduled - # print(f"offload_opt_states_inc bwd={bwd}") if not bwd: is_first_graph = graph_id == graph_order[0][0] - # print_r0( - # f"offload_opt_states_inc graph {graph_id} graph_order {graph_order} fwd is_first_graph {is_first_graph}") + print_r0( + f"offload_opt_states_inc start graph {graph_id} graph_order {graph_order} fwd is_first_graph {is_first_graph}") # At the beginning of the first graph, we schedule offload tasks to launch all offloading if is_first_graph: - # print_r0(f"offload_opt_states_inc fwd before reload graph {graph_id} allocated_mem={get_accelerator().memory_allocated()}") + print_r0(f"offload_opt_states_inc fwd before reload graph {graph_id} allocated_mem={get_accelerator().memory_allocated()}") with unset_fake_temporarily(): + offload_adam_states_sync() reload_adam_states_sync() sync_reload_states() reload_size = 0 - for i, (k, state) in enumerate(optimizer.state.items()): + + # for src_tensor, dest_buf in zip(optimizer.fp32_partitioned_groups_flat, optimizer.hp_params_pin_buffers): + # move_hp_param(src_tensor, dest_buf) + + + for i, ((k, state), hp_param, hp_param_cpu) in enumerate(zip(optimizer.state.items(), optimizer.fp32_partitioned_groups_flat, optimizer.hp_params_pin_buffers)): + print_r0(f"Checking key for offloading {i} {k.shape} has_key {_make_offload_state_key('exp_avg') in state}") + if _make_offload_state_key("exp_avg") in state: key = _make_offload_state_key("exp_avg") size = state[key].numel() * state[key].element_size() - if total_mem < max_memory + reload_size + size: - offload_tasks.append( - (i, k, "exp_avg", state[key].numel() * state[key].element_size(), state[key].dtype)) - # print_r0(f"Offloading task {i} exp_avg reload_size={reload_size} size={size} estimated_mem={max_memory + reload_size + size}") + # if total_mem < max_memory + reload_size + size: + offload_tasks.append( + (i, k, "exp_avg", state[key].numel() * state[key].element_size(), state[key].dtype)) + print_r0(f"Offloading task {i} exp_avg reload_size={reload_size} size={size} estimated_mem={max_memory + reload_size + size}") # else: # print_r0(f"Skipping offloading task {i} exp_avg reload_size={reload_size} size={size} estimated_mem={max_memory + reload_size + size}") - reload_size += size + # reload_size += size if _make_offload_state_key("exp_avg_sq") in state: key = _make_offload_state_key("exp_avg_sq") size = state[key].numel() * state[key].element_size() - if total_mem < max_memory + reload_size + size: - offload_tasks.append( - (i, k, "exp_avg_sq", state[key].numel() * state[key].element_size(), 
state[key].dtype)) - # print_r0(f"Offloading task {i} exp_avg_sq reload_size={reload_size} size={size} estimated_mem={max_memory + reload_size + size}") + # if total_mem < max_memory + reload_size + size: + offload_tasks.append( + (i, k, "exp_avg_sq", state[key].numel() * state[key].element_size(), state[key].dtype)) + print_r0(f"Offloading task {i} exp_avg_sq reload_size={reload_size} size={size} estimated_mem={max_memory + reload_size + size}") # else: # print_r0(f"Skipping offloading task {i} exp_avg_sq reload_size={reload_size} size={size} estimated_mem={max_memory + reload_size + size}") - reload_size += size + # reload_size += size + + hp_param_size = hp_param.numel() * hp_param.element_size() + # if total_mem < max_memory + reload_size + hp_param_size: + offload_tasks.append( + (i, (hp_param, hp_param_cpu), "hp_param", hp_param.numel() * hp_param.element_size(), hp_param.dtype)) + print_r0(f"Offloading task {i} hp_param reload_size={reload_size} size={hp_param_size} estimated_mem={max_memory + reload_size + hp_param_size}") + # else: + # print_r0(f"Skipping offloading task {i} hp_param reload_size={reload_size} size={hp_param_size} estimated_mem={max_memory + reload_size + hp_param_size}") + # reload_size += hp_param_size # for t in offload_tasks: # print_r0(f"Offloading task {t[0]} {t[2]} {t[3]}") - inserted_offload = False - for node in graph.nodes: - # print(f"Node: {node.name} mem: {mem_dict[node.name]}") - if node.op != 'placeholder' and not inserted_offload: - # print(f"Inserting offload_opt before {node.name}") - for task in offload_tasks: - name = f"offload_opt_{task[0]}_{task[2]}" - with graph.inserting_before(node): - offload_node = graph.create_node('call_function', - make_offload_task(task), (), {}, - name=name) - inserted_offload = True + # inserted_offload = False + # for node in graph.nodes: + # # print(f"Node: {node.name} mem: {mem_dict[node.name]}") + # if node.op != 'placeholder' and not inserted_offload: + # print_r0(f"Inserting all offload tasks before {node.name}") + # for task in offload_tasks: + # name = f"offload_opt_{task[0]}_{task[2]}" + # with graph.inserting_before(node): + # offload_node = graph.create_node('call_function', + # make_offload_task(task), (), {}, + # name=name) + # inserted_offload = True - offload_tasks_remaining = copy.copy(offload_tasks) + # offload_tasks_remaining = copy.copy(offload_tasks) + + print_r0(f"offload_opt_states_inc fwd graph {graph_id} allocated_mem={get_accelerator().memory_allocated()}") - # print_r0(f"offload_opt_states_inc fwd graph {graph_id} allocated_mem={get_accelerator().memory_allocated()}") for node in graph.nodes: - # print_r0(f"checking sync node insert node: {node.name}") + print_r0(f"checking sync node insert node: {node.name}") if node.name not in peak_mem \ or node.op == 'placeholder' \ @@ -299,47 +397,64 @@ def offload_opt_states_inc(graph: Graph, graph_id: int, graph_order: List[int], continue to_offload = [] - optim_size = sum([task[3] for task in offload_tasks_remaining]) + optim_size = sum([task[3] for task in offload_tasks]) - # print_r0(f" optim_size: {optim_size} total_mem: {total_mem} peak_mem: {peak_mem[node.name]} available: {total_mem - peak_mem[node.name] - optim_size} #tasks={len(offload_tasks_remaining)}") + print_r0(f" optim_size: {optim_size} total_mem: {total_mem} peak_mem: {peak_mem[node.name]} available: {total_mem - peak_mem[node.name] - optim_size} #tasks={len(offload_tasks)}") while total_mem - peak_mem[node.name] - optim_size < 0: - if len(offload_tasks_remaining) == 0: + if 
len(offload_tasks) == 0: break - task = offload_tasks_remaining.pop(0) + task = offload_tasks.pop(0) to_offload.append(task) - optim_size = sum([task[3] for task in offload_tasks_remaining]) - # print_r0(f" scheduled task {task[0]} {task[2]} {task[3]} optim_size: {optim_size} peak_mem: {peak_mem[node.name]} available: {total_mem - peak_mem[node.name] - optim_size} #tasks={len(offload_tasks_remaining)}") + optim_size = sum([task[3] for task in offload_tasks]) + print_r0(f" scheduled task {task[0]} {task[2]} {task[3]} optim_size: {optim_size} peak_mem: {peak_mem[node.name]} available: {total_mem - peak_mem[node.name] - optim_size} #tasks={len(offload_tasks)}") for task in to_offload: with graph.inserting_before(node): graph.create_node('call_function', make_offload_sync(task), (), {}, name=f"offload_opt_sync_{task[0]}_{task[2]}") - # print_r0(f"Inserting fwd offload_opt_sync_{task[0]}_{task[2]}") + print_r0(f"Inserting fwd offload_opt_sync_{task[0]}_{task[2]}") + offload_tasks_scheduled.append(task) - # print_r0(f"offload_opt_states_inc graph {graph_id} fwd graph {graph}") + for node in graph.nodes: + # print(f"Node: {node.name} mem: {mem_dict[node.name]}") + if node.op != 'placeholder': + print_r0(f"Inserting all offload tasks before {node.name}") + for task in offload_tasks_scheduled: + name = f"offload_opt_{task[0]}_{task[2]}" + with graph.inserting_before(node): + offload_node = graph.create_node('call_function', + make_offload_task(task), (), {}, + name=name) + break + # print_r0(f"offload_opt_states_inc finish graph {graph_id} fwd graph {graph}") + print_r0(f"offload_opt_states_inc finish graph {graph_id}") else: graph_order_with_backward = [g[0] for g in graph_order if g[1]] is_first_graph = graph_id == graph_order_with_backward[-1] is_last_graph = graph_id == graph_order_with_backward[0] - # print_r0(f"offload_opt_states_inc bwd graph {graph_id} graph_order_with_backward {graph_order_with_backward} is_first_graph {is_first_graph} is_last_graph {is_last_graph}") + print_r0(f"offload_opt_states_inc bwd graph {graph_id} graph_order_with_backward {graph_order_with_backward} is_first_graph {is_first_graph} is_last_graph {is_last_graph}") if is_first_graph: inserted_sync = False for node in graph.nodes: if node.op != 'placeholder' and not inserted_sync: # print(f"Inserting offload_sync before {node.name}") - for task in offload_tasks_remaining: - name = f"offload_opt_sync_{task[0]}_{task[2]}" - with graph.inserting_before(node): - graph.create_node('call_function', make_offload_sync(task), (), {}, name=name) - # print_r0(f"Inserting bwd offload_opt_sync_{task[0]}_{task[2]}") + with graph.inserting_before(node): + graph.create_node('call_function', empty_cache, (), {}, name="empty_cache") + + # for task in offload_tasks_remaining: + # name = f"offload_opt_sync_{task[0]}_{task[2]}" + # with graph.inserting_before(node): + # graph.create_node('call_function', make_offload_sync(task), (), {}, name=name) + # print_r0(f"Inserting bwd offload_opt_sync_{task[0]}_{task[2]}") + inserted_sync = True - reload_tasks_remaining = copy.copy(offload_tasks) + reload_tasks_remaining = copy.copy(offload_tasks_scheduled) global total_reload_mem for node in graph.nodes: @@ -356,9 +471,9 @@ def offload_opt_states_inc(graph: Graph, graph_id: int, graph_order: List[int], insert_pos = node while total_mem > peak_mem[node.name] + total_reload_mem + next_reload_mem: expected_mem = peak_mem[node.name] + total_reload_mem - # print_r0( - # f" Inserting reload_opt reload_opt_{task[0]}_{task[2]} after {insert_pos.name} 
next_inc={next_reload_mem} peak_mem[{node.name}]={peak_mem[node.name]} inc_total={total_reload_mem} expected_mem={expected_mem}" - # ) + print_r0( + f" Inserting reload_opt reload_opt_{task[0]}_{task[2]} after {insert_pos.name} next_inc={next_reload_mem} peak_mem[{node.name}]={peak_mem[node.name]} inc_total={total_reload_mem} expected_mem={expected_mem}" + ) with graph.inserting_after(insert_pos): insert_pos = graph.create_node('call_function', @@ -389,9 +504,9 @@ def offload_opt_states_inc(graph: Graph, graph_id: int, graph_order: List[int], with graph.inserting_before(node): graph.create_node('call_function', sync_fn, (), {}, name="sync_offload_copy_stream") - # print_r0( - # f"offload_opt_states_inc graph {graph_id} graph_order {graph_order} bwd is_first_graph {is_first_graph} is_last_graph {is_last_graph} {graph}" - # ) + print_r0( + f"offload_opt_states_inc graph {graph_id} graph_order {graph_order} bwd is_first_graph {is_first_graph} is_last_graph {is_last_graph} {graph}" + ) return graph @@ -405,18 +520,25 @@ def add_record_max_mem_nodes(graph: Graph): with graph.inserting_after(node): name = f"update_max_memory_{node.name}" - graph.create_node('call_function', update_max_memory, (), {}, name=name) + graph.create_node('call_function', update_max_memory, (name, ), {}, name=name) def insert_offload_opt_states(graph: Graph, graph_id: int, graph_order: List[int], profiling_results: ProfilingResult, mem_budget: float, param_manager: DSGraphParamManager, bwd: bool) -> Graph: + from deepspeed.runtime.utils import see_memory_usage + if bwd: + graph_order_with_backward = [g[0] for g in graph_order if g[1]] is_last_graph = graph_id == graph_order_with_backward[0] - if not is_last_graph: - return graph + see_memory_usage( + f"insert_offload_opt_states bwd={bwd} graph_id={graph_id} graph_order={graph_order} is_last_graph={is_last_graph} starting", + force=True) + + # if not is_last_graph: + # return graph inserted_reload = False for node in graph.nodes: @@ -426,8 +548,14 @@ def insert_offload_opt_states(graph: Graph, graph_id: int, graph_order: List[int with graph.inserting_before(node): graph.create_node('call_function', reload_adam_states_sync, (), {}, name="reload_opt") inserted_reload = True + + # add_record_max_mem_nodes(graph) + else: is_first_graph = graph_id == graph_order[0][0] + see_memory_usage( + f"insert_offload_opt_states bwd={bwd} graph_id={graph_id} graph_order={graph_order} is_first_graph {is_first_graph} starting", + force=True) graph = move_primals_to_head(graph) @@ -435,13 +563,17 @@ def insert_offload_opt_states(graph: Graph, graph_id: int, graph_order: List[int for node in graph.nodes: # print(f"Node: {node.name} mem: {mem_dict[node.name]}") if node.op != 'placeholder' and not inserted_offload and is_first_graph: - # print(f"Inserting offload_opt before {node.name}") + print(f"Inserting offload_opt before {node.name}") with graph.inserting_before(node): graph.create_node('call_function', offload_adam_states_sync, (), {}, name="offload_opt") inserted_offload = True add_record_max_mem_nodes(graph) + # see_memory_usage( + # f"insert_offload_opt_states bwd={bwd} graph_id={graph_id} graph_order={graph_order} finished {graph}", + # force=True) + return graph @@ -459,6 +591,13 @@ def move_opt_states_sync(gm: GraphModule, graph_id: int, graph_order: List[int], return gm +def offload_adam_states_for_init(gm: GraphModule, graph_id: int, graph_order: List[int], profiling_results, create_inputs_fn, + mem_budget: float, param_manager: DSGraphParamManager, bwd: bool) -> GraphModule: + 
if not bwd: + offload_adam_states_sync() + # returns None, and profiling will be skipped + + def init_offload_opt_states(adam_optimizer, _nz3): lazy_init() From 37134a3d0d98b288931397d66906ec21d295c954 Mon Sep 17 00:00:00 2001 From: Masahiro Tanaka Date: Fri, 11 Apr 2025 20:14:29 +0000 Subject: [PATCH 2/2] remove debugging code Signed-off-by: Masahiro Tanaka --- deepspeed/compile/backend.py | 31 ++- .../compile/passes/offload_adam_states.py | 181 ++++++------------ deepspeed/compile/profilers/graph_profile.py | 32 +++- 3 files changed, 109 insertions(+), 135 deletions(-) diff --git a/deepspeed/compile/backend.py b/deepspeed/compile/backend.py index a38154afd..ee33447aa 100644 --- a/deepspeed/compile/backend.py +++ b/deepspeed/compile/backend.py @@ -5,6 +5,7 @@ from typing import Dict, List, Callable import time +import gc import torch from torch.fx import Graph, GraphModule @@ -15,6 +16,7 @@ try: import torch._inductor.scheduler from functorch.compile import make_boxed_func from torch._functorch.aot_autograd import aot_module_simplified + from torch._subclasses.fake_tensor import unset_fake_temporarily except ImportError: pass @@ -108,18 +110,31 @@ def run_opt_passes(opt_passes: List[Callable], bwd: bool, debug_log=False) -> None: + with unset_fake_temporarily(): + get_accelerator().synchronize() + gc.collect() + get_accelerator().empty_cache() + for i, opt_pass_fn in enumerate(opt_passes): log_rank0(f"Running opt pass {i} for graph {graph_id}. bwd={bwd}", enable=debug_log) - opt_pass_fn(gm, graph_id, graph_order, profiling_results, create_inputs_fn, mem_budget, param_manager, bwd) - gm.graph.lint() - gm.recompile() + gm_new = opt_pass_fn(gm, graph_id, graph_order, profiling_results, create_inputs_fn, mem_budget, param_manager, + bwd) + if gm_new is not None: + gm = gm_new + gm.graph.lint() + gm.recompile() - mem_prof = MemoryProfilingInterpreter(gm, debug_log=debug_log) - mem_prof.run(*create_inputs_fn()) - mem = [(name, current_alloc, delta, peak) for name, current_alloc, delta, peak in mem_prof.mem_record] + mem_prof = MemoryProfilingInterpreter(gm, debug_log=debug_log) + mem_prof.run(*create_inputs_fn()) + mem = [(name, current_alloc, delta, peak) for name, current_alloc, delta, peak in mem_prof.mem_record] - set_time_and_tensor_size(graph_id, gm.graph, mem, bwd, profiling_results) + set_time_and_tensor_size(graph_id, gm.graph, mem, bwd, profiling_results) + + with unset_fake_temporarily(): + get_accelerator().synchronize() + gc.collect() + get_accelerator().empty_cache() def make_backend(backend, compile_kwargs={}, free_activation=False, debug_log=False): @@ -142,7 +157,7 @@ def make_backend(backend, compile_kwargs={}, free_activation=False, debug_log=Fa if isinstance(v, torch.nn.Parameter)), "All param inputs should have param_id" param_indices = [(i, input_val.param_id, input_val.shape) for i, input_val in enumerate(real_inputs) if isinstance(input_val, torch.nn.Parameter)] - + global fwd_real_inputs fwd_real_inputs.append(real_inputs) diff --git a/deepspeed/compile/passes/offload_adam_states.py b/deepspeed/compile/passes/offload_adam_states.py index ee5381566..458d07f39 100644 --- a/deepspeed/compile/passes/offload_adam_states.py +++ b/deepspeed/compile/passes/offload_adam_states.py @@ -24,10 +24,6 @@ from ..fx import move_primals_to_head import deepspeed.comm as dist - -from deepspeed.runtime.utils import see_memory_usage - - NAME = "offload_adam_states" @@ -83,8 +79,10 @@ def move_key(state, key, key_event=None): def move_back_key(state, key, key_event=None): + with 
get_accelerator().stream(copy_stream): - state[key] = state[_make_offload_state_key(key)].to(device, non_blocking=True) + state[key] = torch.empty_like(state[_make_offload_state_key(key)], device=device) + state[key].copy_(state[_make_offload_state_key(key)], non_blocking=True) if key_event is None: reload_event.record(stream=copy_stream) @@ -93,9 +91,10 @@ def move_back_key(state, key, key_event=None): def move_hp_param(src_tensor, dest_buf, key_event=None): - dest_buf.copy_(src_tensor, non_blocking=True) - src_tensor.data = dest_buf - + with get_accelerator().stream(copy_stream): + dest_buf.copy_(src_tensor, non_blocking=True) + src_tensor.data = dest_buf + if key_event is None: reload_event.record(stream=copy_stream) else: @@ -103,9 +102,10 @@ def move_hp_param(src_tensor, dest_buf, key_event=None): def move_back_hp_param(src_tensor, dest_buf, key_event=None): - # for src, dest in zip(self.hp_params_pin_buffers, self.fp32_partitioned_groups_flat): - dest_buf.data = src_tensor.to(device, non_blocking=True) - + with get_accelerator().stream(copy_stream): + dest_buf.data = torch.empty_like(src_tensor, device=device) + dest_buf.copy_(src_tensor, non_blocking=True) + if key_event is None: reload_event.record(stream=copy_stream) else: @@ -113,16 +113,15 @@ def move_back_hp_param(src_tensor, dest_buf, key_event=None): def offload_adam_states_sync(): - see_memory_usage("before offload_adam_states_sync", force=True) - - if not hasattr(optimizer, "hp_params_pin_buffers"): - optimizer.hp_params_pin_buffers = [ - get_accelerator().pin_memory(torch.empty_like(t, device="cpu")) - for t in optimizer.fp32_partitioned_groups_flat - ] with unset_fake_temporarily(): - # print_r0("Offloading Adam states") + + if not hasattr(optimizer, "hp_params_pin_buffers"): + optimizer.hp_params_pin_buffers = [ + get_accelerator().pin_memory(torch.empty_like(t, device="cpu")) + for t in optimizer.fp32_partitioned_groups_flat + ] + for i, (k, state) in enumerate(optimizer.state.items()): if "exp_avg" in state: move_key(state, "exp_avg") @@ -140,14 +139,8 @@ def offload_adam_states_sync(): get_accelerator().synchronize() - see_memory_usage("after offload_adam_states_sync", force=True) - def reload_adam_states_sync(): - memory_stats = get_accelerator().memory_stats() - alloc_retries = memory_stats.get("num_alloc_retries") - - see_memory_usage(f"before reload_adam_states_sync", force=True) with unset_fake_temporarily(): # print_r0("Reloading Adam states") @@ -163,10 +156,6 @@ def reload_adam_states_sync(): get_accelerator().synchronize() - memory_stats = get_accelerator().memory_stats() - alloc_retries = memory_stats.get("num_alloc_retries") - see_memory_usage(f"after reload_adam_states_sync", force=True) - def sync_offload_states(event=None): if nz3.is_profiling(): @@ -206,18 +195,13 @@ def make_offload_task(task): # offload_key_events[task[1]] = get_accelerator().Event() move_key(state, task[2], offload_key_events[task[1]]) - from deepspeed.runtime.utils import see_memory_usage - see_memory_usage(f"run_offload_task offload_opt_{task[0]}_{task[2]} alloc_mem={get_accelerator().memory_allocated()}", force=True) - return run_offload_task def make_offload_sync(task): - from deepspeed.runtime.utils import see_memory_usage def run_offload_sync(): # if not nz3.is_profiling(): - see_memory_usage(f"run_offload_sync start {task[0]} {task[2]}", force=True) event = offload_key_events[task[1]] event.synchronize() @@ -227,7 +211,6 @@ def make_offload_sync(task): if key in state: del state[key] # print_r0(f"run_offload_sync {task[0]} 
{task[2]} alloc_mem={get_accelerator().memory_allocated()}") - see_memory_usage(f"run_offload_sync finish {task[0]} {task[2]}", force=True) return run_offload_sync @@ -246,11 +229,6 @@ def make_reload_task(task): # print_r0(f"run_reload_task {task[0]} {task[2]} {task[3]} {task[4]}") move_back_key(state, task[2], reload_key_events[task[1]]) - # alloc_mem = get_accelerator().memory_allocated() - # print_r0(f"run_reload_task reload_opt_{task[0]}_{task[2]} alloc_mem={alloc_mem}") - from deepspeed.runtime.utils import see_memory_usage - see_memory_usage(f"run_reload_task reload_opt_{task[0]}_{task[2]} profiling={nz3.is_profiling()}", force=True) - return run_reload_task @@ -259,14 +237,10 @@ def update_max_memory(name): global max_memory mem = get_accelerator().max_memory_allocated() max_memory = max(max_memory, mem) - # see_memory_usage(f"update_max_memory {name}", force=True) def empty_cache(): - see_memory_usage(f"empty_cache start", force=True) get_accelerator().empty_cache() - see_memory_usage(f"empty_cache end", force=True) - offload_tasks = [] @@ -313,12 +287,15 @@ def offload_opt_states_inc(graph: Graph, graph_id: int, graph_order: List[int], if not bwd: is_first_graph = graph_id == graph_order[0][0] - print_r0( - f"offload_opt_states_inc start graph {graph_id} graph_order {graph_order} fwd is_first_graph {is_first_graph}") + # print_r0( + # f"offload_opt_states_inc start graph {graph_id} graph_order {graph_order} fwd is_first_graph {is_first_graph}" + # ) # At the beginning of the first graph, we schedule offload tasks to launch all offloading if is_first_graph: - print_r0(f"offload_opt_states_inc fwd before reload graph {graph_id} allocated_mem={get_accelerator().memory_allocated()}") + # print_r0( + # f"offload_opt_states_inc fwd before reload graph {graph_id} allocated_mem={get_accelerator().memory_allocated()}" + # ) with unset_fake_temporarily(): offload_adam_states_sync() @@ -327,12 +304,11 @@ def offload_opt_states_inc(graph: Graph, graph_id: int, graph_order: List[int], reload_size = 0 - # for src_tensor, dest_buf in zip(optimizer.fp32_partitioned_groups_flat, optimizer.hp_params_pin_buffers): - # move_hp_param(src_tensor, dest_buf) - - - for i, ((k, state), hp_param, hp_param_cpu) in enumerate(zip(optimizer.state.items(), optimizer.fp32_partitioned_groups_flat, optimizer.hp_params_pin_buffers)): - print_r0(f"Checking key for offloading {i} {k.shape} has_key {_make_offload_state_key('exp_avg') in state}") + for i, ((k, state), hp_param, hp_param_cpu) in enumerate( + zip(optimizer.state.items(), optimizer.fp32_partitioned_groups_flat, + optimizer.hp_params_pin_buffers)): + # print_r0( + # f"Checking key for offloading {i} {k.shape} has_key {_make_offload_state_key('exp_avg') in state}") if _make_offload_state_key("exp_avg") in state: key = _make_offload_state_key("exp_avg") @@ -341,10 +317,9 @@ def offload_opt_states_inc(graph: Graph, graph_id: int, graph_order: List[int], # if total_mem < max_memory + reload_size + size: offload_tasks.append( (i, k, "exp_avg", state[key].numel() * state[key].element_size(), state[key].dtype)) - print_r0(f"Offloading task {i} exp_avg reload_size={reload_size} size={size} estimated_mem={max_memory + reload_size + size}") - # else: - # print_r0(f"Skipping offloading task {i} exp_avg reload_size={reload_size} size={size} estimated_mem={max_memory + reload_size + size}") - # reload_size += size + # print_r0( + # f"Offloading task {i} exp_avg reload_size={reload_size} size={size} estimated_mem={max_memory + reload_size + size}" + # ) if 
_make_offload_state_key("exp_avg_sq") in state: key = _make_offload_state_key("exp_avg_sq") @@ -353,43 +328,22 @@ def offload_opt_states_inc(graph: Graph, graph_id: int, graph_order: List[int], # if total_mem < max_memory + reload_size + size: offload_tasks.append( (i, k, "exp_avg_sq", state[key].numel() * state[key].element_size(), state[key].dtype)) - print_r0(f"Offloading task {i} exp_avg_sq reload_size={reload_size} size={size} estimated_mem={max_memory + reload_size + size}") - # else: - # print_r0(f"Skipping offloading task {i} exp_avg_sq reload_size={reload_size} size={size} estimated_mem={max_memory + reload_size + size}") - # reload_size += size + # print_r0( + # f"Offloading task {i} exp_avg_sq reload_size={reload_size} size={size} estimated_mem={max_memory + reload_size + size}" + # ) hp_param_size = hp_param.numel() * hp_param.element_size() # if total_mem < max_memory + reload_size + hp_param_size: - offload_tasks.append( - (i, (hp_param, hp_param_cpu), "hp_param", hp_param.numel() * hp_param.element_size(), hp_param.dtype)) - print_r0(f"Offloading task {i} hp_param reload_size={reload_size} size={hp_param_size} estimated_mem={max_memory + reload_size + hp_param_size}") - # else: - # print_r0(f"Skipping offloading task {i} hp_param reload_size={reload_size} size={hp_param_size} estimated_mem={max_memory + reload_size + hp_param_size}") - # reload_size += hp_param_size - - # for t in offload_tasks: - # print_r0(f"Offloading task {t[0]} {t[2]} {t[3]}") - - # inserted_offload = False - # for node in graph.nodes: - # # print(f"Node: {node.name} mem: {mem_dict[node.name]}") - # if node.op != 'placeholder' and not inserted_offload: - # print_r0(f"Inserting all offload tasks before {node.name}") - # for task in offload_tasks: - # name = f"offload_opt_{task[0]}_{task[2]}" - # with graph.inserting_before(node): - # offload_node = graph.create_node('call_function', - # make_offload_task(task), (), {}, - # name=name) - # inserted_offload = True - - # offload_tasks_remaining = copy.copy(offload_tasks) - - print_r0(f"offload_opt_states_inc fwd graph {graph_id} allocated_mem={get_accelerator().memory_allocated()}") + offload_tasks.append((i, (hp_param, hp_param_cpu), "hp_param", + hp_param.numel() * hp_param.element_size(), hp_param.dtype)) + # print_r0( + # f"Offloading task {i} hp_param reload_size={reload_size} size={hp_param_size} estimated_mem={max_memory + reload_size + hp_param_size}" + # ) + # print_r0(f"offload_opt_states_inc fwd graph {graph_id} allocated_mem={get_accelerator().memory_allocated()}") for node in graph.nodes: - print_r0(f"checking sync node insert node: {node.name}") + # print_r0(f"checking sync node insert node: {node.name}") if node.name not in peak_mem \ or node.op == 'placeholder' \ @@ -399,7 +353,9 @@ def offload_opt_states_inc(graph: Graph, graph_id: int, graph_order: List[int], to_offload = [] optim_size = sum([task[3] for task in offload_tasks]) - print_r0(f" optim_size: {optim_size} total_mem: {total_mem} peak_mem: {peak_mem[node.name]} available: {total_mem - peak_mem[node.name] - optim_size} #tasks={len(offload_tasks)}") + # print_r0( + # f" optim_size: {optim_size} total_mem: {total_mem} peak_mem: {peak_mem[node.name]} available: {total_mem - peak_mem[node.name] - optim_size} #tasks={len(offload_tasks)}" + # ) while total_mem - peak_mem[node.name] - optim_size < 0: if len(offload_tasks) == 0: break @@ -407,7 +363,9 @@ def offload_opt_states_inc(graph: Graph, graph_id: int, graph_order: List[int], task = offload_tasks.pop(0) to_offload.append(task) 
optim_size = sum([task[3] for task in offload_tasks]) - print_r0(f" scheduled task {task[0]} {task[2]} {task[3]} optim_size: {optim_size} peak_mem: {peak_mem[node.name]} available: {total_mem - peak_mem[node.name] - optim_size} #tasks={len(offload_tasks)}") + # print_r0( + # f" scheduled task {task[0]} {task[2]} {task[3]} optim_size: {optim_size} peak_mem: {peak_mem[node.name]} available: {total_mem - peak_mem[node.name] - optim_size} #tasks={len(offload_tasks)}" + # ) for task in to_offload: with graph.inserting_before(node): @@ -424,9 +382,7 @@ def offload_opt_states_inc(graph: Graph, graph_id: int, graph_order: List[int], for task in offload_tasks_scheduled: name = f"offload_opt_{task[0]}_{task[2]}" with graph.inserting_before(node): - offload_node = graph.create_node('call_function', - make_offload_task(task), (), {}, - name=name) + offload_node = graph.create_node('call_function', make_offload_task(task), (), {}, name=name) break # print_r0(f"offload_opt_states_inc finish graph {graph_id} fwd graph {graph}") @@ -437,7 +393,9 @@ def offload_opt_states_inc(graph: Graph, graph_id: int, graph_order: List[int], is_first_graph = graph_id == graph_order_with_backward[-1] is_last_graph = graph_id == graph_order_with_backward[0] - print_r0(f"offload_opt_states_inc bwd graph {graph_id} graph_order_with_backward {graph_order_with_backward} is_first_graph {is_first_graph} is_last_graph {is_last_graph}") + # print_r0( + # f"offload_opt_states_inc bwd graph {graph_id} graph_order_with_backward {graph_order_with_backward} is_first_graph {is_first_graph} is_last_graph {is_last_graph}" + # ) if is_first_graph: inserted_sync = False @@ -447,12 +405,6 @@ def offload_opt_states_inc(graph: Graph, graph_id: int, graph_order: List[int], with graph.inserting_before(node): graph.create_node('call_function', empty_cache, (), {}, name="empty_cache") - # for task in offload_tasks_remaining: - # name = f"offload_opt_sync_{task[0]}_{task[2]}" - # with graph.inserting_before(node): - # graph.create_node('call_function', make_offload_sync(task), (), {}, name=name) - # print_r0(f"Inserting bwd offload_opt_sync_{task[0]}_{task[2]}") - inserted_sync = True reload_tasks_remaining = copy.copy(offload_tasks_scheduled) @@ -505,7 +457,7 @@ def offload_opt_states_inc(graph: Graph, graph_id: int, graph_order: List[int], graph.create_node('call_function', sync_fn, (), {}, name="sync_offload_copy_stream") print_r0( - f"offload_opt_states_inc graph {graph_id} graph_order {graph_order} bwd is_first_graph {is_first_graph} is_last_graph {is_last_graph} {graph}" + f"offload_opt_states_inc graph {graph_id} graph_order {graph_order} bwd is_first_graph {is_first_graph} is_last_graph {is_last_graph}" ) return graph @@ -526,20 +478,10 @@ def add_record_max_mem_nodes(graph: Graph): def insert_offload_opt_states(graph: Graph, graph_id: int, graph_order: List[int], profiling_results: ProfilingResult, mem_budget: float, param_manager: DSGraphParamManager, bwd: bool) -> Graph: - from deepspeed.runtime.utils import see_memory_usage - if bwd: - graph_order_with_backward = [g[0] for g in graph_order if g[1]] is_last_graph = graph_id == graph_order_with_backward[0] - see_memory_usage( - f"insert_offload_opt_states bwd={bwd} graph_id={graph_id} graph_order={graph_order} is_last_graph={is_last_graph} starting", - force=True) - - # if not is_last_graph: - # return graph - inserted_reload = False for node in graph.nodes: # print(f"Node: {node.name} mem: {mem_dict[node.name]}") @@ -553,9 +495,6 @@ def insert_offload_opt_states(graph: Graph, graph_id: 
int, graph_order: List[int else: is_first_graph = graph_id == graph_order[0][0] - see_memory_usage( - f"insert_offload_opt_states bwd={bwd} graph_id={graph_id} graph_order={graph_order} is_first_graph {is_first_graph} starting", - force=True) graph = move_primals_to_head(graph) @@ -570,10 +509,6 @@ def insert_offload_opt_states(graph: Graph, graph_id: int, graph_order: List[int add_record_max_mem_nodes(graph) - # see_memory_usage( - # f"insert_offload_opt_states bwd={bwd} graph_id={graph_id} graph_order={graph_order} finished {graph}", - # force=True) - return graph @@ -591,10 +526,12 @@ def move_opt_states_sync(gm: GraphModule, graph_id: int, graph_order: List[int], return gm -def offload_adam_states_for_init(gm: GraphModule, graph_id: int, graph_order: List[int], profiling_results, create_inputs_fn, - mem_budget: float, param_manager: DSGraphParamManager, bwd: bool) -> GraphModule: - if not bwd: - offload_adam_states_sync() +def offload_adam_states_for_init(gm: GraphModule, graph_id: int, graph_order: List[int], profiling_results, + create_inputs_fn, mem_budget: float, param_manager: DSGraphParamManager, + bwd: bool) -> GraphModule: + if not bwd and graph_id == graph_order[0][0]: + with unset_fake_temporarily(): + offload_adam_states_sync() # returns None, and profiling will be skipped diff --git a/deepspeed/compile/profilers/graph_profile.py b/deepspeed/compile/profilers/graph_profile.py index 4ca552150..1a9b12c89 100644 --- a/deepspeed/compile/profilers/graph_profile.py +++ b/deepspeed/compile/profilers/graph_profile.py @@ -51,6 +51,26 @@ def _node_size(out): return sum([v.element_size() * v.numel() for v in tree_leaves(out) if torch.is_tensor(v)]) +def _get_mem_usage_out_of_torch(): + + adjust = 0 + try: + import pynvml + pynvml.nvmlInit() + + current_dev_id = get_accelerator().current_device() + handle = pynvml.nvmlDeviceGetHandleByIndex(current_dev_id) + info = pynvml.nvmlDeviceGetMemoryInfo(handle) + + torch_alloc = get_accelerator().memory_allocated() + adjust = info.used - torch_alloc + except: + # pynvml not available + pass + + return adjust + + # https://pytorch.org/tutorials/intermediate/fx_profiling_tutorial.html class ProfilingInterpreter(Interpreter): @@ -68,6 +88,7 @@ class ProfilingInterpreter(Interpreter): self.distributed = dist.is_initialized() self.allgather_mem: Dict[int, int] = {} self.debug_log = debug_log + self.mem_usage_out_of_torch = 0 def run(self, *args) -> Any: """Run the graph with profiling enabled. 
@@ -81,6 +102,7 @@ class ProfilingInterpreter(Interpreter): with unset_fake_temporarily(): with get_accelerator().random().fork_rng(devices=[self.device]): + self.mem_usage_out_of_torch = _get_mem_usage_out_of_torch() return_val = super().run(*args) except Exception as e: msg = e.msg if "msg" in dir(e) else str(e) @@ -160,8 +182,8 @@ class ProfilingInterpreter(Interpreter): if is_comm_op(n): dist.barrier() - alloc_mem = get_accelerator().memory_allocated() - alloc_mem_start - max_memory = get_accelerator().max_memory_allocated() - max_mem_start + alloc_mem = get_accelerator().memory_allocated() - alloc_mem_start + self.mem_usage_out_of_torch + max_memory = get_accelerator().max_memory_allocated() - max_mem_start + self.mem_usage_out_of_torch tensor_size = _node_size(out) def partition_param_if_necessary(v): @@ -223,7 +245,7 @@ class MemoryProfilingInterpreter(Interpreter): try: assert _all_real_if_tensor(args), "Inputs must be real tensors" self.nz3.enable_profiling(True) - self.mem_adjustment = 0 + self.mem_usage_out_of_torch = _get_mem_usage_out_of_torch() with unset_fake_temporarily(): with get_accelerator().random().fork_rng(devices=[self.device]): @@ -248,8 +270,8 @@ class MemoryProfilingInterpreter(Interpreter): del args, kwargs - current_alloc = get_accelerator().memory_allocated() - max_alloc = get_accelerator().max_memory_allocated() + current_alloc = get_accelerator().memory_allocated() + self.mem_usage_out_of_torch + max_alloc = get_accelerator().max_memory_allocated() + self.mem_usage_out_of_torch vals_to_bcast = torch.tensor([current_alloc, max_alloc], device=self.device) dist.all_reduce(vals_to_bcast, dist.ReduceOp.MAX) current_alloc = vals_to_bcast[0].item()
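
Note (illustration only, not part of either patch): move_hp_param / move_back_hp_param and the move_key / move_back_key helpers above all follow the same pattern: a pinned CPU buffer, an asynchronous copy issued on the dedicated copy_stream, and an event recorded on that stream to fence the transfer. A minimal standalone sketch of that round trip, using plain torch.cuda in place of DeepSpeed's get_accelerator() wrapper (the helper names below are hypothetical, not from the patch):

    import torch

    copy_stream = torch.cuda.Stream()
    done = torch.cuda.Event()

    def offload_to_pinned(gpu_t: torch.Tensor) -> torch.Tensor:
        # D2H copy into a pinned CPU buffer on the side stream; record an event to fence it.
        cpu_buf = torch.empty_like(gpu_t, device="cpu").pin_memory()
        with torch.cuda.stream(copy_stream):
            # Make sure the producer of gpu_t (default stream) has finished before copying.
            copy_stream.wait_stream(torch.cuda.current_stream(gpu_t.device))
            cpu_buf.copy_(gpu_t, non_blocking=True)
            done.record(copy_stream)
        return cpu_buf

    def reload_from_pinned(t: torch.Tensor, cpu_buf: torch.Tensor) -> None:
        # Re-allocate GPU storage and copy the pinned data back asynchronously.
        with torch.cuda.stream(copy_stream):
            t.data = torch.empty_like(cpu_buf, device="cuda")
            t.data.copy_(cpu_buf, non_blocking=True)
            done.record(copy_stream)

    if torch.cuda.is_available():
        t = torch.randn(1 << 20, device="cuda")
        buf = offload_to_pinned(t)
        done.synchronize()   # wait for the D2H copy before dropping the GPU storage
        t.data = buf         # the GPU block is now free for the caching allocator
        reload_from_pinned(t, buf)
        done.synchronize()   # wait for the H2D copy before using t on the default stream

The done.synchronize() calls play the role of offload_key_events / reload_key_events in the pass; offload_adam_states_sync() and reload_adam_states_sync() obtain the same guarantee with a blanket get_accelerator().synchronize().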