Files
DeepSpeed/deepspeed/compile/passes/offload_adam_states.py
Masahiro Tanaka 37134a3d0d remove debugging code
Signed-off-by: Masahiro Tanaka <mtanaka@microsoft.com>
2025-04-11 20:14:29 +00:00

547 lines
20 KiB
Python

# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
import copy
from typing import List
import torch
from torch.fx import Graph, GraphModule
from deepspeed.accelerator import get_accelerator
from deepspeed.runtime.zero.offload_states import _make_offload_state_key
try:
from torch._subclasses.fake_tensor import unset_fake_temporarily
except ImportError:
# Unsupported torch version
pass
from ..profilers import ProfilingResult
from ..graph_param import DSGraphParamManager
from ..fx import move_primals_to_head
import deepspeed.comm as dist
NAME = "offload_adam_states"
def print_r0(msg):
if dist.get_rank() == 0:
print(msg)
MARGIN = 0.2
copy_stream = None
offload_event = None
reload_event = None
offload_key_events = {}
reload_key_events = {}
max_memory = 0
def lazy_init():
global copy_stream
global offload_event
global reload_event
if copy_stream is None:
copy_stream = get_accelerator().Stream()
offload_event = get_accelerator().Event()
reload_event = get_accelerator().Event()
optimizer = None
device = None
nz3 = None
def move_key(state, key, key_event=None):
offload_buf_key = _make_offload_state_key(key)
if offload_buf_key not in state:
state[offload_buf_key] = get_accelerator().pin_memory(torch.empty_like(state[key], device="cpu"))
if key not in state:
return
with get_accelerator().stream(copy_stream):
state[offload_buf_key].copy_(state[key], non_blocking=True)
if key_event is None:
offload_event.record(stream=copy_stream)
else:
key_event.record(stream=copy_stream)
def move_back_key(state, key, key_event=None):
with get_accelerator().stream(copy_stream):
state[key] = torch.empty_like(state[_make_offload_state_key(key)], device=device)
state[key].copy_(state[_make_offload_state_key(key)], non_blocking=True)
if key_event is None:
reload_event.record(stream=copy_stream)
else:
key_event.record(stream=copy_stream)
def move_hp_param(src_tensor, dest_buf, key_event=None):
with get_accelerator().stream(copy_stream):
dest_buf.copy_(src_tensor, non_blocking=True)
src_tensor.data = dest_buf
if key_event is None:
reload_event.record(stream=copy_stream)
else:
key_event.record(stream=copy_stream)
def move_back_hp_param(src_tensor, dest_buf, key_event=None):
with get_accelerator().stream(copy_stream):
dest_buf.data = torch.empty_like(src_tensor, device=device)
dest_buf.copy_(src_tensor, non_blocking=True)
if key_event is None:
reload_event.record(stream=copy_stream)
else:
key_event.record(stream=copy_stream)
def offload_adam_states_sync():
with unset_fake_temporarily():
if not hasattr(optimizer, "hp_params_pin_buffers"):
optimizer.hp_params_pin_buffers = [
get_accelerator().pin_memory(torch.empty_like(t, device="cpu"))
for t in optimizer.fp32_partitioned_groups_flat
]
for i, (k, state) in enumerate(optimizer.state.items()):
if "exp_avg" in state:
move_key(state, "exp_avg")
if "exp_avg_sq" in state:
move_key(state, "exp_avg_sq")
for _, state in optimizer.state.items():
if "exp_avg" in state:
del state["exp_avg"]
if "exp_avg_sq" in state:
del state["exp_avg_sq"]
for src_tensor, dest_buf in zip(optimizer.fp32_partitioned_groups_flat, optimizer.hp_params_pin_buffers):
move_hp_param(src_tensor, dest_buf)
get_accelerator().synchronize()
def reload_adam_states_sync():
with unset_fake_temporarily():
# print_r0("Reloading Adam states")
for _, state in optimizer.state.items():
if _make_offload_state_key("exp_avg") in state:
move_back_key(state, "exp_avg")
if _make_offload_state_key("exp_avg_sq") in state:
move_back_key(state, "exp_avg_sq")
for src, dest in zip(optimizer.hp_params_pin_buffers, optimizer.fp32_partitioned_groups_flat):
move_back_hp_param(src, dest)
get_accelerator().synchronize()
def sync_offload_states(event=None):
if nz3.is_profiling():
offload_adam_states_sync()
else:
if event is None:
offload_event.wait(copy_stream)
else:
event.wait(copy_stream)
def sync_reload_states(event=None):
if nz3.is_profiling():
reload_adam_states_sync()
else:
if event is None:
reload_event.wait(copy_stream)
else:
event.wait(copy_stream)
def make_offload_task(task):
def run_offload_task():
# if not nz3.is_profiling():
# print_r0(f"run_offload_task {task[0]} {task[2]} {task[3]} {task[4]}")
if offload_key_events.get(task[1]) is None:
offload_key_events[task[1]] = get_accelerator().Event()
if task[2] == "hp_param":
move_hp_param(task[1][0], task[1][1], offload_key_events[task[1][0]])
else:
assert task[1] in optimizer.state, f"State {task[1]} not found in optimizer"
state = optimizer.state[task[1]]
# if offload_key_events.get(task[1]) is None:
# offload_key_events[task[1]] = get_accelerator().Event()
move_key(state, task[2], offload_key_events[task[1]])
return run_offload_task
def make_offload_sync(task):
def run_offload_sync():
# if not nz3.is_profiling():
event = offload_key_events[task[1]]
event.synchronize()
if task[2] != "hp_param":
state = optimizer.state[task[1]]
key = task[2]
if key in state:
del state[key]
# print_r0(f"run_offload_sync {task[0]} {task[2]} alloc_mem={get_accelerator().memory_allocated()}")
return run_offload_sync
def make_reload_task(task):
def run_reload_task():
if not nz3.is_profiling():
if reload_key_events.get(task[1]) is None:
reload_key_events[task[1]] = get_accelerator().Event()
if task[2] == "hp_param":
move_back_hp_param(task[1][1], task[1][0], reload_key_events[task[1]])
else:
state = optimizer.state[task[1]]
# print_r0(f"run_reload_task {task[0]} {task[2]} {task[3]} {task[4]}")
move_back_key(state, task[2], reload_key_events[task[1]])
return run_reload_task
def update_max_memory(name):
global max_memory
mem = get_accelerator().max_memory_allocated()
max_memory = max(max_memory, mem)
def empty_cache():
get_accelerator().empty_cache()
offload_tasks = []
offload_tasks_remaining = []
offload_tasks_scheduled = []
reload_task_remaining = []
total_reload_mem = 0
def offload_opt_states_inc(graph: Graph, graph_id: int, graph_order: List[int], profiling_results: ProfilingResult,
mem_budget: float, param_manager: DSGraphParamManager, bwd: bool) -> Graph:
to_remove = []
for node in graph.nodes:
if node.op == 'call_function' and \
node.target in [offload_adam_states_sync, sync_offload_states, reload_adam_states_sync, sync_reload_states, update_max_memory]:
to_remove.append(node)
for node in to_remove:
graph.erase_node(node)
accelerator = get_accelerator()
total_mem = accelerator.total_memory() * (1 - MARGIN)
print_r0(f"offload_opt_states_inc start graph {graph_id} bwd={bwd} max_memory={max_memory} total_mem={total_mem}")
mem = profiling_results[graph_id].bwd_mem if bwd else profiling_results[graph_id].fwd_mem
mem_dict = {name: peak for name, alloc_mem, delta, peak in mem}
current_peak_mem = 0
peak_mem = {}
ordered_node = reversed(graph.nodes) if bwd else graph.nodes
for node in ordered_node:
# print(f"Node: {node.name} mem: {mem_dict[node.name]}")
if mem_dict[node.name] > current_peak_mem:
current_peak_mem = mem_dict[node.name]
peak_mem[node.name] = current_peak_mem
# fwd_max_mem = max(m[3] for m in prof.fwd_mem)
# bwd_max_mem = max(m[3] for m in prof.bwd_mem) if len(prof.bwd_mem) > 0 else 0
# peak_mem = max(peak_mem, fwd_max_mem, bwd_max_mem)
global offload_tasks_remaining, reload_tasks_remaining, offload_tasks_scheduled
if not bwd:
is_first_graph = graph_id == graph_order[0][0]
# print_r0(
# f"offload_opt_states_inc start graph {graph_id} graph_order {graph_order} fwd is_first_graph {is_first_graph}"
# )
# At the beginning of the first graph, we schedule offload tasks to launch all offloading
if is_first_graph:
# print_r0(
# f"offload_opt_states_inc fwd before reload graph {graph_id} allocated_mem={get_accelerator().memory_allocated()}"
# )
with unset_fake_temporarily():
offload_adam_states_sync()
reload_adam_states_sync()
sync_reload_states()
reload_size = 0
for i, ((k, state), hp_param, hp_param_cpu) in enumerate(
zip(optimizer.state.items(), optimizer.fp32_partitioned_groups_flat,
optimizer.hp_params_pin_buffers)):
# print_r0(
# f"Checking key for offloading {i} {k.shape} has_key {_make_offload_state_key('exp_avg') in state}")
if _make_offload_state_key("exp_avg") in state:
key = _make_offload_state_key("exp_avg")
size = state[key].numel() * state[key].element_size()
# if total_mem < max_memory + reload_size + size:
offload_tasks.append(
(i, k, "exp_avg", state[key].numel() * state[key].element_size(), state[key].dtype))
# print_r0(
# f"Offloading task {i} exp_avg reload_size={reload_size} size={size} estimated_mem={max_memory + reload_size + size}"
# )
if _make_offload_state_key("exp_avg_sq") in state:
key = _make_offload_state_key("exp_avg_sq")
size = state[key].numel() * state[key].element_size()
# if total_mem < max_memory + reload_size + size:
offload_tasks.append(
(i, k, "exp_avg_sq", state[key].numel() * state[key].element_size(), state[key].dtype))
# print_r0(
# f"Offloading task {i} exp_avg_sq reload_size={reload_size} size={size} estimated_mem={max_memory + reload_size + size}"
# )
hp_param_size = hp_param.numel() * hp_param.element_size()
# if total_mem < max_memory + reload_size + hp_param_size:
offload_tasks.append((i, (hp_param, hp_param_cpu), "hp_param",
hp_param.numel() * hp_param.element_size(), hp_param.dtype))
# print_r0(
# f"Offloading task {i} hp_param reload_size={reload_size} size={hp_param_size} estimated_mem={max_memory + reload_size + hp_param_size}"
# )
# print_r0(f"offload_opt_states_inc fwd graph {graph_id} allocated_mem={get_accelerator().memory_allocated()}")
for node in graph.nodes:
# print_r0(f"checking sync node insert node: {node.name}")
if node.name not in peak_mem \
or node.op == 'placeholder' \
or "offload_opt_" in node.name:
continue
to_offload = []
optim_size = sum([task[3] for task in offload_tasks])
# print_r0(
# f" optim_size: {optim_size} total_mem: {total_mem} peak_mem: {peak_mem[node.name]} available: {total_mem - peak_mem[node.name] - optim_size} #tasks={len(offload_tasks)}"
# )
while total_mem - peak_mem[node.name] - optim_size < 0:
if len(offload_tasks) == 0:
break
task = offload_tasks.pop(0)
to_offload.append(task)
optim_size = sum([task[3] for task in offload_tasks])
# print_r0(
# f" scheduled task {task[0]} {task[2]} {task[3]} optim_size: {optim_size} peak_mem: {peak_mem[node.name]} available: {total_mem - peak_mem[node.name] - optim_size} #tasks={len(offload_tasks)}"
# )
for task in to_offload:
with graph.inserting_before(node):
graph.create_node('call_function',
make_offload_sync(task), (), {},
name=f"offload_opt_sync_{task[0]}_{task[2]}")
print_r0(f"Inserting fwd offload_opt_sync_{task[0]}_{task[2]}")
offload_tasks_scheduled.append(task)
for node in graph.nodes:
# print(f"Node: {node.name} mem: {mem_dict[node.name]}")
if node.op != 'placeholder':
print_r0(f"Inserting all offload tasks before {node.name}")
for task in offload_tasks_scheduled:
name = f"offload_opt_{task[0]}_{task[2]}"
with graph.inserting_before(node):
offload_node = graph.create_node('call_function', make_offload_task(task), (), {}, name=name)
break
# print_r0(f"offload_opt_states_inc finish graph {graph_id} fwd graph {graph}")
print_r0(f"offload_opt_states_inc finish graph {graph_id}")
else:
graph_order_with_backward = [g[0] for g in graph_order if g[1]]
is_first_graph = graph_id == graph_order_with_backward[-1]
is_last_graph = graph_id == graph_order_with_backward[0]
# print_r0(
# f"offload_opt_states_inc bwd graph {graph_id} graph_order_with_backward {graph_order_with_backward} is_first_graph {is_first_graph} is_last_graph {is_last_graph}"
# )
if is_first_graph:
inserted_sync = False
for node in graph.nodes:
if node.op != 'placeholder' and not inserted_sync:
# print(f"Inserting offload_sync before {node.name}")
with graph.inserting_before(node):
graph.create_node('call_function', empty_cache, (), {}, name="empty_cache")
inserted_sync = True
reload_tasks_remaining = copy.copy(offload_tasks_scheduled)
global total_reload_mem
for node in graph.nodes:
if node.name not in peak_mem \
or node.op == 'placeholder' \
or node.op == 'output' \
or "offload_opt_sync_" in node.name:
continue
if len(reload_tasks_remaining) > 0:
task = reload_tasks_remaining[0]
next_reload_mem = task[3]
insert_pos = node
while total_mem > peak_mem[node.name] + total_reload_mem + next_reload_mem:
expected_mem = peak_mem[node.name] + total_reload_mem
print_r0(
f" Inserting reload_opt reload_opt_{task[0]}_{task[2]} after {insert_pos.name} next_inc={next_reload_mem} peak_mem[{node.name}]={peak_mem[node.name]} inc_total={total_reload_mem} expected_mem={expected_mem}"
)
with graph.inserting_after(insert_pos):
insert_pos = graph.create_node('call_function',
make_reload_task(task), (), {},
name=f"reload_opt_{task[0]}_{task[2]}")
total_reload_mem += next_reload_mem
reload_tasks_remaining.pop(0)
if len(reload_tasks_remaining) == 0:
break
task = reload_tasks_remaining[0]
next_reload_mem = task[3]
# prev_node = node
if is_last_graph:
for node in graph.nodes:
# print(f"Node: {node.name} mem: {mem_dict[node.name]}")
if node.op == 'output':
for task in reload_tasks_remaining:
with graph.inserting_before(node):
graph.create_node('call_function',
make_reload_task(task), (), {},
name=f"reload_opt_{task[0]}_{task[2]}")
sync_fn = lambda: copy_stream.synchronize()
with graph.inserting_before(node):
graph.create_node('call_function', sync_fn, (), {}, name="sync_offload_copy_stream")
print_r0(
f"offload_opt_states_inc graph {graph_id} graph_order {graph_order} bwd is_first_graph {is_first_graph} is_last_graph {is_last_graph}"
)
return graph
def add_record_max_mem_nodes(graph: Graph):
nodes = list(graph.nodes)
for node in nodes:
if node.op == "output" or node.op == "placeholder":
continue
with graph.inserting_after(node):
name = f"update_max_memory_{node.name}"
graph.create_node('call_function', update_max_memory, (name, ), {}, name=name)
def insert_offload_opt_states(graph: Graph, graph_id: int, graph_order: List[int], profiling_results: ProfilingResult,
mem_budget: float, param_manager: DSGraphParamManager, bwd: bool) -> Graph:
if bwd:
graph_order_with_backward = [g[0] for g in graph_order if g[1]]
is_last_graph = graph_id == graph_order_with_backward[0]
inserted_reload = False
for node in graph.nodes:
# print(f"Node: {node.name} mem: {mem_dict[node.name]}")
if node.op == 'output' and not inserted_reload and is_last_graph:
# print(f"Inserting reload_opt before {node.name}")
with graph.inserting_before(node):
graph.create_node('call_function', reload_adam_states_sync, (), {}, name="reload_opt")
inserted_reload = True
# add_record_max_mem_nodes(graph)
else:
is_first_graph = graph_id == graph_order[0][0]
graph = move_primals_to_head(graph)
inserted_offload = False
for node in graph.nodes:
# print(f"Node: {node.name} mem: {mem_dict[node.name]}")
if node.op != 'placeholder' and not inserted_offload and is_first_graph:
print(f"Inserting offload_opt before {node.name}")
with graph.inserting_before(node):
graph.create_node('call_function', offload_adam_states_sync, (), {}, name="offload_opt")
inserted_offload = True
add_record_max_mem_nodes(graph)
return graph
def move_opt_states(gm: GraphModule, graph_id: int, graph_order: List[int], profiling_results, create_inputs_fn,
mem_budget: float, param_manager: DSGraphParamManager, bwd: bool) -> GraphModule:
gm.graph = offload_opt_states_inc(gm.graph, graph_id, graph_order, profiling_results, mem_budget, param_manager,
bwd)
return gm
def move_opt_states_sync(gm: GraphModule, graph_id: int, graph_order: List[int], profiling_results, create_inputs_fn,
mem_budget: float, param_manager: DSGraphParamManager, bwd: bool) -> GraphModule:
gm.graph = insert_offload_opt_states(gm.graph, graph_id, graph_order, profiling_results, mem_budget, param_manager,
bwd)
return gm
def offload_adam_states_for_init(gm: GraphModule, graph_id: int, graph_order: List[int], profiling_results,
create_inputs_fn, mem_budget: float, param_manager: DSGraphParamManager,
bwd: bool) -> GraphModule:
if not bwd and graph_id == graph_order[0][0]:
with unset_fake_temporarily():
offload_adam_states_sync()
# returns None, and profiling will be skipped
def init_offload_opt_states(adam_optimizer, _nz3):
lazy_init()
global optimizer
optimizer = adam_optimizer
global device
device = torch.device(get_accelerator().current_device())
global nz3
nz3 = _nz3