Mirror of https://github.com/deepspeedai/DeepSpeed.git, synced 2025-10-20 23:53:48 +08:00
Merge branch 'tohtana/deepcompile' into tohtana/deepcompile_fix_scheduling
@@ -5,6 +5,7 @@
 from typing import Dict, List, Callable
 import time
+import gc

 import torch
 from torch.fx import Graph, GraphModule
@@ -15,6 +16,7 @@ try:
     import torch._inductor.scheduler
     from functorch.compile import make_boxed_func
     from torch._functorch.aot_autograd import aot_module_simplified
+    from torch._subclasses.fake_tensor import unset_fake_temporarily
 except ImportError:
     pass
@@ -108,18 +110,31 @@ def run_opt_passes(opt_passes: List[Callable],
                    bwd: bool,
                    debug_log=False) -> None:

+    with unset_fake_temporarily():
+        get_accelerator().synchronize()
+        gc.collect()
+        get_accelerator().empty_cache()
+
     for i, opt_pass_fn in enumerate(opt_passes):
         log_rank0(f"Running opt pass {i} for graph {graph_id}. bwd={bwd}", enable=debug_log)

-        opt_pass_fn(gm, graph_id, graph_order, profiling_results, create_inputs_fn, mem_budget, param_manager, bwd)
-        gm.graph.lint()
-        gm.recompile()
+        gm_new = opt_pass_fn(gm, graph_id, graph_order, profiling_results, create_inputs_fn, mem_budget, param_manager,
+                             bwd)
+        if gm_new is not None:
+            gm = gm_new
+            gm.graph.lint()
+            gm.recompile()

-        mem_prof = MemoryProfilingInterpreter(gm, debug_log=debug_log)
-        mem_prof.run(*create_inputs_fn())
-        mem = [(name, current_alloc, delta, peak) for name, current_alloc, delta, peak in mem_prof.mem_record]
+            mem_prof = MemoryProfilingInterpreter(gm, debug_log=debug_log)
+            mem_prof.run(*create_inputs_fn())
+            mem = [(name, current_alloc, delta, peak) for name, current_alloc, delta, peak in mem_prof.mem_record]

-    set_time_and_tensor_size(graph_id, gm.graph, mem, bwd, profiling_results)
+            set_time_and_tensor_size(graph_id, gm.graph, mem, bwd, profiling_results)
+
+    with unset_fake_temporarily():
+        get_accelerator().synchronize()
+        gc.collect()
+        get_accelerator().empty_cache()


 def make_backend(backend, compile_kwargs={}, free_activation=False, debug_log=False):
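The hunk above also changes the contract between run_opt_passes and the individual passes: a pass may now return a replacement GraphModule, and the lint/recompile/re-profile steps run only when it does. A minimal sketch of the two pass shapes this implies; noop_pass and rewrite_pass are hypothetical names, not DeepSpeed APIs:

from torch.fx import GraphModule


def noop_pass(gm: GraphModule, graph_id, graph_order, profiling_results, create_inputs_fn, mem_budget, param_manager,
              bwd):
    # Returns None: run_opt_passes keeps the current gm and skips re-profiling.
    return None


def rewrite_pass(gm: GraphModule, graph_id, graph_order, profiling_results, create_inputs_fn, mem_budget,
                 param_manager, bwd):
    # ... mutate gm.graph here ...
    # Returning the module tells run_opt_passes to lint, recompile, and re-profile it.
    return gm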
@@ -66,6 +66,9 @@ def move_key(state, key, key_event=None):
     if offload_buf_key not in state:
         state[offload_buf_key] = get_accelerator().pin_memory(torch.empty_like(state[key], device="cpu"))

+    if key not in state:
+        return
+
     with get_accelerator().stream(copy_stream):
         state[offload_buf_key].copy_(state[key], non_blocking=True)
@@ -76,8 +79,32 @@ def move_key(state, key, key_event=None):


 def move_back_key(state, key, key_event=None):

     with get_accelerator().stream(copy_stream):
-        state[key] = state[_make_offload_state_key(key)].to(device, non_blocking=True)
+        state[key] = torch.empty_like(state[_make_offload_state_key(key)], device=device)
+        state[key].copy_(state[_make_offload_state_key(key)], non_blocking=True)

     if key_event is None:
         reload_event.record(stream=copy_stream)
     else:
         key_event.record(stream=copy_stream)


+def move_hp_param(src_tensor, dest_buf, key_event=None):
+    with get_accelerator().stream(copy_stream):
+        dest_buf.copy_(src_tensor, non_blocking=True)
+        src_tensor.data = dest_buf
+
+    if key_event is None:
+        reload_event.record(stream=copy_stream)
+    else:
+        key_event.record(stream=copy_stream)
+
+
+def move_back_hp_param(src_tensor, dest_buf, key_event=None):
+    with get_accelerator().stream(copy_stream):
+        dest_buf.data = torch.empty_like(src_tensor, device=device)
+        dest_buf.copy_(src_tensor, non_blocking=True)
+
+    if key_event is None:
+        reload_event.record(stream=copy_stream)
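move_key, move_back_key, move_hp_param, and move_back_hp_param all follow one pattern: launch a non-blocking copy on a dedicated copy stream, record an event on that stream, and let the consumer wait on the event instead of stalling the whole device. A standalone sketch of that pattern, written directly against torch.cuda for brevity (the code above goes through get_accelerator()):

import torch

copy_stream = torch.cuda.Stream()
copy_done = torch.cuda.Event()


def async_copy(dst: torch.Tensor, src: torch.Tensor):
    # The copy runs on copy_stream and can overlap with work on the default stream.
    with torch.cuda.stream(copy_stream):
        dst.copy_(src, non_blocking=True)
    copy_done.record(stream=copy_stream)


# Before the default stream may read dst:
#   torch.cuda.current_stream().wait_event(copy_done)   # GPU-side ordering
# or, to block the host instead:
#   copy_done.synchronize()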
@@ -88,7 +115,13 @@ def move_back_key(state, key, key_event=None):
 def offload_adam_states_sync():

     with unset_fake_temporarily():
         # print_r0("Offloading Adam states")
+        if not hasattr(optimizer, "hp_params_pin_buffers"):
+            optimizer.hp_params_pin_buffers = [
+                get_accelerator().pin_memory(torch.empty_like(t, device="cpu"))
+                for t in optimizer.fp32_partitioned_groups_flat
+            ]
+
         for i, (k, state) in enumerate(optimizer.state.items()):
             if "exp_avg" in state:
                 move_key(state, "exp_avg")
@@ -101,6 +134,9 @@ def offload_adam_states_sync():
             if "exp_avg_sq" in state:
                 del state["exp_avg_sq"]

+        for src_tensor, dest_buf in zip(optimizer.fp32_partitioned_groups_flat, optimizer.hp_params_pin_buffers):
+            move_hp_param(src_tensor, dest_buf)
+
     get_accelerator().synchronize()

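The hunk above pre-allocates pinned (page-locked) CPU buffers once and reuses them, because a non_blocking device-to-host copy only overlaps with compute when the host side is pinned; with pageable memory the copy silently degrades to a synchronous one. A small illustration, assuming a CUDA device:

import torch

gpu_t = torch.randn(1 << 20, device="cuda")
cpu_buf = torch.empty_like(gpu_t, device="cpu").pin_memory()  # page-locked host buffer

cpu_buf.copy_(gpu_t, non_blocking=True)  # truly asynchronous D2H copy
torch.cuda.synchronize()                 # wait before the host reads cpu_buf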
@@ -115,6 +151,9 @@ def reload_adam_states_sync():
         if _make_offload_state_key("exp_avg_sq") in state:
             move_back_key(state, "exp_avg_sq")

+    for src, dest in zip(optimizer.hp_params_pin_buffers, optimizer.fp32_partitioned_groups_flat):
+        move_back_hp_param(src, dest)
+
     get_accelerator().synchronize()

@@ -141,12 +180,19 @@ def sync_reload_states(event=None):
 def make_offload_task(task):

     def run_offload_task():
-        if not nz3.is_profiling():
-            # print_r0(f"run_offload_task {task[0]} {task[2]} {task[3]} {task[4]}")
+        # if not nz3.is_profiling():
+        #     print_r0(f"run_offload_task {task[0]} {task[2]} {task[3]} {task[4]}")
+
+        if offload_key_events.get(task[1]) is None:
+            offload_key_events[task[1]] = get_accelerator().Event()
+
+        if task[2] == "hp_param":
+            move_hp_param(task[1][0], task[1][1], offload_key_events[task[1][0]])
+        else:
             assert task[1] in optimizer.state, f"State {task[1]} not found in optimizer"
             state = optimizer.state[task[1]]
-            if offload_key_events.get(task[1]) is None:
-                offload_key_events[task[1]] = get_accelerator().Event()
+            # if offload_key_events.get(task[1]) is None:
+            #     offload_key_events[task[1]] = get_accelerator().Event()
             move_key(state, task[2], offload_key_events[task[1]])

     return run_offload_task
@@ -155,13 +201,16 @@ def make_offload_task(task):
 def make_offload_sync(task):

     def run_offload_sync():
-        if not nz3.is_profiling():
-            event = offload_key_events[task[1]]
-            event.synchronize()
+        # if not nz3.is_profiling():
+        event = offload_key_events[task[1]]
+        event.synchronize()

+        if task[2] != "hp_param":
             state = optimizer.state[task[1]]
             key = task[2]
-            del state[key]
-            # print_r0(f"run_offload_sync {task[0]} {task[2]} alloc_mem={get_accelerator().memory_allocated()}")
+            if key in state:
+                del state[key]
+                # print_r0(f"run_offload_sync {task[0]} {task[2]} alloc_mem={get_accelerator().memory_allocated()}")

     return run_offload_sync

@@ -170,26 +219,33 @@ def make_reload_task(task):

     def run_reload_task():
         if not nz3.is_profiling():
-            state = optimizer.state[task[1]]
             if reload_key_events.get(task[1]) is None:
                 reload_key_events[task[1]] = get_accelerator().Event()
-            # print_r0(f"run_reload_task {task[0]} {task[2]} {task[3]} {task[4]}")
-            move_back_key(state, task[2], reload_key_events[task[1]])
+
+            # alloc_mem = get_accelerator().memory_allocated()
+            # print_r0(f"run_reload_task reload_opt_{task[0]}_{task[2]} alloc_mem={alloc_mem}")
+            if task[2] == "hp_param":
+                move_back_hp_param(task[1][1], task[1][0], reload_key_events[task[1]])
+            else:
+                state = optimizer.state[task[1]]
+                # print_r0(f"run_reload_task {task[0]} {task[2]} {task[3]} {task[4]}")
+                move_back_key(state, task[2], reload_key_events[task[1]])

     return run_reload_task


-def update_max_memory():
+def update_max_memory(name):

     global max_memory
     mem = get_accelerator().max_memory_allocated()
     max_memory = max(max_memory, mem)


 def empty_cache():
     get_accelerator().empty_cache()


 offload_tasks = []
 offload_tasks_remaining = []
+offload_tasks_scheduled = []
 reload_task_remaining = []
 total_reload_mem = 0
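make_offload_task, make_offload_sync, and make_reload_task wrap each task in a zero-argument closure so it can be spliced into the FX graph as a call_function node and executed at an exact point in the compiled module's schedule. A self-contained sketch of that mechanism; the module and names here are illustrative, not DeepSpeed code:

import torch
from torch.fx import symbolic_trace


def side_effect():
    # Stand-in for run_offload_task / run_offload_sync: no inputs, no outputs.
    print("offload point reached")


class M(torch.nn.Module):

    def forward(self, x):
        return x * 2


gm = symbolic_trace(M())
first_compute = next(n for n in gm.graph.nodes if n.op != "placeholder")
with gm.graph.inserting_before(first_compute):
    gm.graph.create_node("call_function", side_effect, (), {}, name="offload_opt_0")
gm.graph.lint()
gm.recompile()
gm(torch.ones(2))  # prints "offload point reached" before computing x * 2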
@@ -197,8 +253,6 @@ total_reload_mem = 0
 def offload_opt_states_inc(graph: Graph, graph_id: int, graph_order: List[int], profiling_results: ProfilingResult,
                            mem_budget: float, param_manager: DSGraphParamManager, bwd: bool) -> Graph:

-    # print_r0(f"offload_opt_states_inc graph {graph_id} bwd={bwd} max_memory={max_memory}")
-
     to_remove = []
     for node in graph.nodes:
         if node.op == 'call_function' and \
@@ -210,6 +264,7 @@ def offload_opt_states_inc(graph: Graph, graph_id: int, graph_order: List[int],

     accelerator = get_accelerator()
     total_mem = accelerator.total_memory() * (1 - MARGIN)
+    print_r0(f"offload_opt_states_inc start graph {graph_id} bwd={bwd} max_memory={max_memory} total_mem={total_mem}")

     mem = profiling_results[graph_id].bwd_mem if bwd else profiling_results[graph_id].fwd_mem
     mem_dict = {name: peak for name, alloc_mem, delta, peak in mem}
@@ -228,65 +283,62 @@ def offload_opt_states_inc(graph: Graph, graph_id: int, graph_order: List[int],
     # bwd_max_mem = max(m[3] for m in prof.bwd_mem) if len(prof.bwd_mem) > 0 else 0
     # peak_mem = max(peak_mem, fwd_max_mem, bwd_max_mem)

-    global offload_tasks_remaining, reload_tasks_remaining
+    global offload_tasks_remaining, reload_tasks_remaining, offload_tasks_scheduled

     # print(f"offload_opt_states_inc bwd={bwd}")
     if not bwd:
         is_first_graph = graph_id == graph_order[0][0]
         # print_r0(
-        #     f"offload_opt_states_inc graph {graph_id} graph_order {graph_order} fwd is_first_graph {is_first_graph}")
+        #     f"offload_opt_states_inc start graph {graph_id} graph_order {graph_order} fwd is_first_graph {is_first_graph}"
+        # )

         # At the beginning of the first graph, we schedule offload tasks to launch all offloading
         if is_first_graph:
-            # print_r0(f"offload_opt_states_inc fwd before reload graph {graph_id} allocated_mem={get_accelerator().memory_allocated()}")
+            # print_r0(
+            #     f"offload_opt_states_inc fwd before reload graph {graph_id} allocated_mem={get_accelerator().memory_allocated()}"
+            # )

             with unset_fake_temporarily():
+                offload_adam_states_sync()
                 reload_adam_states_sync()
                 sync_reload_states()

             reload_size = 0
-            for i, (k, state) in enumerate(optimizer.state.items()):
+            for i, ((k, state), hp_param, hp_param_cpu) in enumerate(
+                    zip(optimizer.state.items(), optimizer.fp32_partitioned_groups_flat,
+                        optimizer.hp_params_pin_buffers)):
                 # print_r0(
                 #     f"Checking key for offloading {i} {k.shape} has_key {_make_offload_state_key('exp_avg') in state}")

                 if _make_offload_state_key("exp_avg") in state:
                     key = _make_offload_state_key("exp_avg")
                     size = state[key].numel() * state[key].element_size()

-                    if total_mem < max_memory + reload_size + size:
-                        offload_tasks.append(
-                            (i, k, "exp_avg", state[key].numel() * state[key].element_size(), state[key].dtype))
-                        # print_r0(f"Offloading task {i} exp_avg reload_size={reload_size} size={size} estimated_mem={max_memory + reload_size + size}")
-                    # else:
-                    #     print_r0(f"Skipping offloading task {i} exp_avg reload_size={reload_size} size={size} estimated_mem={max_memory + reload_size + size}")
                     reload_size += size
+                    # if total_mem < max_memory + reload_size + size:
+                    offload_tasks.append(
+                        (i, k, "exp_avg", state[key].numel() * state[key].element_size(), state[key].dtype))
+                    # print_r0(
+                    #     f"Offloading task {i} exp_avg reload_size={reload_size} size={size} estimated_mem={max_memory + reload_size + size}"
+                    # )

                 if _make_offload_state_key("exp_avg_sq") in state:
                     key = _make_offload_state_key("exp_avg_sq")
                     size = state[key].numel() * state[key].element_size()

-                    if total_mem < max_memory + reload_size + size:
-                        offload_tasks.append(
-                            (i, k, "exp_avg_sq", state[key].numel() * state[key].element_size(), state[key].dtype))
-                        # print_r0(f"Offloading task {i} exp_avg_sq reload_size={reload_size} size={size} estimated_mem={max_memory + reload_size + size}")
-                    # else:
-                    #     print_r0(f"Skipping offloading task {i} exp_avg_sq reload_size={reload_size} size={size} estimated_mem={max_memory + reload_size + size}")
                     reload_size += size
+                    # if total_mem < max_memory + reload_size + size:
+                    offload_tasks.append(
+                        (i, k, "exp_avg_sq", state[key].numel() * state[key].element_size(), state[key].dtype))
+                    # print_r0(
+                    #     f"Offloading task {i} exp_avg_sq reload_size={reload_size} size={size} estimated_mem={max_memory + reload_size + size}"
+                    # )

-            # for t in offload_tasks:
-            #     print_r0(f"Offloading task {t[0]} {t[2]} {t[3]}")
-
-            inserted_offload = False
-            for node in graph.nodes:
-                # print(f"Node: {node.name} mem: {mem_dict[node.name]}")
-                if node.op != 'placeholder' and not inserted_offload:
-                    # print(f"Inserting offload_opt before {node.name}")
-                    for task in offload_tasks:
-                        name = f"offload_opt_{task[0]}_{task[2]}"
-                        with graph.inserting_before(node):
-                            offload_node = graph.create_node('call_function',
-                                                             make_offload_task(task), (), {},
-                                                             name=name)
-                    inserted_offload = True
-
-            offload_tasks_remaining = copy.copy(offload_tasks)
+                hp_param_size = hp_param.numel() * hp_param.element_size()
+                # if total_mem < max_memory + reload_size + hp_param_size:
+                offload_tasks.append((i, (hp_param, hp_param_cpu), "hp_param",
+                                      hp_param.numel() * hp_param.element_size(), hp_param.dtype))
+                # print_r0(
+                #     f"Offloading task {i} hp_param reload_size={reload_size} size={hp_param_size} estimated_mem={max_memory + reload_size + hp_param_size}"
+                # )

         # print_r0(f"offload_opt_states_inc fwd graph {graph_id} allocated_mem={get_accelerator().memory_allocated()}")
@@ -299,47 +351,62 @@ def offload_opt_states_inc(graph: Graph, graph_id: int, graph_order: List[int],
                 continue

             to_offload = []
-            optim_size = sum([task[3] for task in offload_tasks_remaining])
+            optim_size = sum([task[3] for task in offload_tasks])

-            # print_r0(f" optim_size: {optim_size} total_mem: {total_mem} peak_mem: {peak_mem[node.name]} available: {total_mem - peak_mem[node.name] - optim_size} #tasks={len(offload_tasks_remaining)}")
+            # print_r0(
+            #     f" optim_size: {optim_size} total_mem: {total_mem} peak_mem: {peak_mem[node.name]} available: {total_mem - peak_mem[node.name] - optim_size} #tasks={len(offload_tasks)}"
+            # )
             while total_mem - peak_mem[node.name] - optim_size < 0:
-                if len(offload_tasks_remaining) == 0:
+                if len(offload_tasks) == 0:
                     break

-                task = offload_tasks_remaining.pop(0)
+                task = offload_tasks.pop(0)
                 to_offload.append(task)
-                optim_size = sum([task[3] for task in offload_tasks_remaining])
-                # print_r0(f" scheduled task {task[0]} {task[2]} {task[3]} optim_size: {optim_size} peak_mem: {peak_mem[node.name]} available: {total_mem - peak_mem[node.name] - optim_size} #tasks={len(offload_tasks_remaining)}")
+                optim_size = sum([task[3] for task in offload_tasks])
+                # print_r0(
+                #     f" scheduled task {task[0]} {task[2]} {task[3]} optim_size: {optim_size} peak_mem: {peak_mem[node.name]} available: {total_mem - peak_mem[node.name] - optim_size} #tasks={len(offload_tasks)}"
+                # )

             for task in to_offload:
                 with graph.inserting_before(node):
                     graph.create_node('call_function',
                                       make_offload_sync(task), (), {},
                                       name=f"offload_opt_sync_{task[0]}_{task[2]}")
-                # print_r0(f"Inserting fwd offload_opt_sync_{task[0]}_{task[2]}")
+                print_r0(f"Inserting fwd offload_opt_sync_{task[0]}_{task[2]}")
+                offload_tasks_scheduled.append(task)

         # print_r0(f"offload_opt_states_inc graph {graph_id} fwd graph {graph}")
+        for node in graph.nodes:
+            # print(f"Node: {node.name} mem: {mem_dict[node.name]}")
+            if node.op != 'placeholder':
+                print_r0(f"Inserting all offload tasks before {node.name}")
+                for task in offload_tasks_scheduled:
+                    name = f"offload_opt_{task[0]}_{task[2]}"
+                    with graph.inserting_before(node):
+                        offload_node = graph.create_node('call_function', make_offload_task(task), (), {}, name=name)
+                break

-        # print_r0(f"offload_opt_states_inc finish graph {graph_id} fwd graph {graph}")
+        print_r0(f"offload_opt_states_inc finish graph {graph_id}")
     else:

         graph_order_with_backward = [g[0] for g in graph_order if g[1]]
         is_first_graph = graph_id == graph_order_with_backward[-1]
         is_last_graph = graph_id == graph_order_with_backward[0]

-        # print_r0(f"offload_opt_states_inc bwd graph {graph_id} graph_order_with_backward {graph_order_with_backward} is_first_graph {is_first_graph} is_last_graph {is_last_graph}")
+        # print_r0(
+        #     f"offload_opt_states_inc bwd graph {graph_id} graph_order_with_backward {graph_order_with_backward} is_first_graph {is_first_graph} is_last_graph {is_last_graph}"
+        # )

         if is_first_graph:
             inserted_sync = False
             for node in graph.nodes:
                 if node.op != 'placeholder' and not inserted_sync:
                     # print(f"Inserting offload_sync before {node.name}")
                     for task in offload_tasks_remaining:
                         name = f"offload_opt_sync_{task[0]}_{task[2]}"
                         with graph.inserting_before(node):
                             graph.create_node('call_function', make_offload_sync(task), (), {}, name=name)
                         # print_r0(f"Inserting bwd offload_opt_sync_{task[0]}_{task[2]}")
+                    with graph.inserting_before(node):
+                        graph.create_node('call_function', empty_cache, (), {}, name="empty_cache")

                     inserted_sync = True
-        reload_tasks_remaining = copy.copy(offload_tasks)
+        reload_tasks_remaining = copy.copy(offload_tasks_scheduled)

         global total_reload_mem
         for node in graph.nodes:
@@ -356,9 +423,9 @@ def offload_opt_states_inc(graph: Graph, graph_id: int, graph_order: List[int],
                     insert_pos = node
                     while total_mem > peak_mem[node.name] + total_reload_mem + next_reload_mem:
                         expected_mem = peak_mem[node.name] + total_reload_mem
-                        # print_r0(
-                        #     f" Inserting reload_opt reload_opt_{task[0]}_{task[2]} after {insert_pos.name} next_inc={next_reload_mem} peak_mem[{node.name}]={peak_mem[node.name]} inc_total={total_reload_mem} expected_mem={expected_mem}"
-                        # )
+                        print_r0(
+                            f" Inserting reload_opt reload_opt_{task[0]}_{task[2]} after {insert_pos.name} next_inc={next_reload_mem} peak_mem[{node.name}]={peak_mem[node.name]} inc_total={total_reload_mem} expected_mem={expected_mem}"
+                        )

                         with graph.inserting_after(insert_pos):
                             insert_pos = graph.create_node('call_function',
@@ -389,9 +456,9 @@ def offload_opt_states_inc(graph: Graph, graph_id: int, graph_order: List[int],
                 with graph.inserting_before(node):
                     graph.create_node('call_function', sync_fn, (), {}, name="sync_offload_copy_stream")

-        # print_r0(
-        #     f"offload_opt_states_inc graph {graph_id} graph_order {graph_order} bwd is_first_graph {is_first_graph} is_last_graph {is_last_graph} {graph}"
-        # )
+        print_r0(
+            f"offload_opt_states_inc graph {graph_id} graph_order {graph_order} bwd is_first_graph {is_first_graph} is_last_graph {is_last_graph}"
+        )

     return graph
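The forward-pass scheduling loop above pops offload tasks in FIFO order until the optimizer-state bytes still resident fit under the budget at each node's profiled peak. A tiny worked example of the same arithmetic, with illustrative numbers:

GiB = 1 << 30
total_mem = 40 * GiB         # accelerator capacity scaled by (1 - MARGIN)
peak_at_node = 30 * GiB      # profiled peak memory at this graph node
tasks = [("exp_avg", 4 * GiB), ("exp_avg_sq", 4 * GiB), ("hp_param", 4 * GiB)]

to_offload = []
optim_size = sum(size for _, size in tasks)   # 12 GiB resident
while total_mem - peak_at_node - optim_size < 0 and tasks:
    to_offload.append(tasks.pop(0))
    optim_size = sum(size for _, size in tasks)

print([name for name, _ in to_offload])  # ['exp_avg']: the remaining 8 GiB now fits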
@@ -405,7 +472,7 @@ def add_record_max_mem_nodes(graph: Graph):

         with graph.inserting_after(node):
             name = f"update_max_memory_{node.name}"
-            graph.create_node('call_function', update_max_memory, (), {}, name=name)
+            graph.create_node('call_function', update_max_memory, (name, ), {}, name=name)


 def insert_offload_opt_states(graph: Graph, graph_id: int, graph_order: List[int], profiling_results: ProfilingResult,
@@ -415,9 +482,6 @@ def insert_offload_opt_states(graph: Graph, graph_id: int, graph_order: List[int
         graph_order_with_backward = [g[0] for g in graph_order if g[1]]
         is_last_graph = graph_id == graph_order_with_backward[0]

-        if not is_last_graph:
-            return graph
-
         inserted_reload = False
         for node in graph.nodes:
             # print(f"Node: {node.name} mem: {mem_dict[node.name]}")
@@ -426,6 +490,9 @@ def insert_offload_opt_states(graph: Graph, graph_id: int, graph_order: List[int
                 with graph.inserting_before(node):
                     graph.create_node('call_function', reload_adam_states_sync, (), {}, name="reload_opt")
                 inserted_reload = True
+
+        # add_record_max_mem_nodes(graph)
+
     else:
         is_first_graph = graph_id == graph_order[0][0]

@@ -435,7 +502,7 @@ def insert_offload_opt_states(graph: Graph, graph_id: int, graph_order: List[int
         for node in graph.nodes:
             # print(f"Node: {node.name} mem: {mem_dict[node.name]}")
             if node.op != 'placeholder' and not inserted_offload and is_first_graph:
-                # print(f"Inserting offload_opt before {node.name}")
+                print(f"Inserting offload_opt before {node.name}")
                 with graph.inserting_before(node):
                     graph.create_node('call_function', offload_adam_states_sync, (), {}, name="offload_opt")
                 inserted_offload = True
@@ -459,6 +526,15 @@ def move_opt_states_sync(gm: GraphModule, graph_id: int, graph_order: List[int],
     return gm


+def offload_adam_states_for_init(gm: GraphModule, graph_id: int, graph_order: List[int], profiling_results,
+                                 create_inputs_fn, mem_budget: float, param_manager: DSGraphParamManager,
+                                 bwd: bool) -> GraphModule:
+    if not bwd and graph_id == graph_order[0][0]:
+        with unset_fake_temporarily():
+            offload_adam_states_sync()
+    # returns None, and profiling will be skipped
+
+
 def init_offload_opt_states(adam_optimizer, _nz3):
     lazy_init()
@@ -51,6 +51,26 @@ def _node_size(out):
     return sum([v.element_size() * v.numel() for v in tree_leaves(out) if torch.is_tensor(v)])


+def _get_mem_usage_out_of_torch():
+
+    adjust = 0
+    try:
+        import pynvml
+        pynvml.nvmlInit()
+
+        current_dev_id = get_accelerator().current_device()
+        handle = pynvml.nvmlDeviceGetHandleByIndex(current_dev_id)
+        info = pynvml.nvmlDeviceGetMemoryInfo(handle)
+
+        torch_alloc = get_accelerator().memory_allocated()
+        adjust = info.used - torch_alloc
+    except:
+        # pynvml not available
+        pass
+
+    return adjust
+
+
 # https://pytorch.org/tutorials/intermediate/fx_profiling_tutorial.html
 class ProfilingInterpreter(Interpreter):

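_get_mem_usage_out_of_torch estimates device memory that PyTorch's allocator does not account for (the CUDA context, communication buffers, the allocator's cached-but-free pages as NVML sees them, other processes on the same device) by subtracting torch's live allocations from NVML's used bytes. A standalone probe of the same quantity, assuming an NVIDIA device with pynvml installed:

import pynvml
import torch

pynvml.nvmlInit()
handle = pynvml.nvmlDeviceGetHandleByIndex(torch.cuda.current_device())
used = pynvml.nvmlDeviceGetMemoryInfo(handle).used  # all bytes in use on the device
print(f"out-of-torch bytes: {used - torch.cuda.memory_allocated()}")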
@@ -68,6 +88,7 @@ class ProfilingInterpreter(Interpreter):
         self.distributed = dist.is_initialized()
         self.allgather_mem: Dict[int, int] = {}
         self.debug_log = debug_log
+        self.mem_usage_out_of_torch = 0

     def run(self, *args) -> Any:
         """Run the graph with profiling enabled.
@@ -81,6 +102,7 @@ class ProfilingInterpreter(Interpreter):

             with unset_fake_temporarily():
                 with get_accelerator().random().fork_rng(devices=[self.device]):
+                    self.mem_usage_out_of_torch = _get_mem_usage_out_of_torch()
                     return_val = super().run(*args)
         except Exception as e:
             msg = e.msg if "msg" in dir(e) else str(e)
@@ -160,8 +182,8 @@ class ProfilingInterpreter(Interpreter):
         if is_comm_op(n):
             dist.barrier()

-        alloc_mem = get_accelerator().memory_allocated() - alloc_mem_start
-        max_memory = get_accelerator().max_memory_allocated() - max_mem_start
+        alloc_mem = get_accelerator().memory_allocated() - alloc_mem_start + self.mem_usage_out_of_torch
+        max_memory = get_accelerator().max_memory_allocated() - max_mem_start + self.mem_usage_out_of_torch
         tensor_size = _node_size(out)

         def partition_param_if_necessary(v):
@@ -223,7 +245,7 @@ class MemoryProfilingInterpreter(Interpreter):
         try:
             assert _all_real_if_tensor(args), "Inputs must be real tensors"
             self.nz3.enable_profiling(True)
-            self.mem_adjustment = 0
+            self.mem_usage_out_of_torch = _get_mem_usage_out_of_torch()

             with unset_fake_temporarily():
                 with get_accelerator().random().fork_rng(devices=[self.device]):
@@ -248,8 +270,8 @@ class MemoryProfilingInterpreter(Interpreter):

         del args, kwargs

-        current_alloc = get_accelerator().memory_allocated()
-        max_alloc = get_accelerator().max_memory_allocated()
+        current_alloc = get_accelerator().memory_allocated() + self.mem_usage_out_of_torch
+        max_alloc = get_accelerator().max_memory_allocated() + self.mem_usage_out_of_torch
         vals_to_bcast = torch.tensor([current_alloc, max_alloc], device=self.device)
         dist.all_reduce(vals_to_bcast, dist.ReduceOp.MAX)
         current_alloc = vals_to_bcast[0].item()
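With these changes both interpreters fold the out-of-torch adjustment into their reported numbers, and the per-rank values are max-reduced so every process records the worst case. A sketch of consuming the result, assuming gm and create_inputs_fn as in the run_opt_passes hunk above; each mem_record entry is (node_name, current_alloc, delta, peak), in bytes:

mem_prof = MemoryProfilingInterpreter(gm, debug_log=True)
mem_prof.run(*create_inputs_fn())
for name, current_alloc, delta, peak in mem_prof.mem_record:
    print(f"{name}: alloc={current_alloc} delta={delta} peak={peak}")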