Files
DeepSpeed/deepspeed/compile/passes/__init__.py
Masahiro Tanaka 227a60c0c4 DeepCompile for enhanced compiler integration (#7154)
This PR introduces *DeepCompile*, a new feature that efficiently
integrates compiler optimizations with other DeepSpeed features.
DeepCompile uses PyTorch's Dynamo to capture the computation graph and
modifies it to incorporate DeepSpeed's optimizations seamlessly.

Currently, DeepCompile supports ZeRO-1 and ZeRO-3, with enhancements
such as proactive prefetching and selective unsharding to improve
performance.
(More details will be added later.)

---------

Signed-off-by: Masahiro Tanaka <mtanaka@microsoft.com>
Signed-off-by: Olatunji Ruwase <olruwase@microsoft.com>
Co-authored-by: zafarsadiq <zafarsadiq120@gmail.com>
Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>
Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>
2025-04-16 04:33:53 +00:00

49 lines
1.4 KiB
Python

# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
from ..profilers.graph_profile import MemoryProfilingInterpreter
import deepspeed.comm as dist
def run_opt_passes(nz3,
                   graph_index,
                   graph_id,
                   gm,
                   create_inputs_fn,
                   opt_passes,
                   graph_order,
                   profiling_results,
                   param_manager,
                   bwd,
                   debug_log=False):
    """Apply each optimization pass to ``gm`` in turn, re-profiling memory after every pass.

    For every ``(pass_fn, mem_budget)`` pair in ``opt_passes``, the pass rewrites
    ``gm.graph``. The result is linted, installed on ``gm`` and recompiled. The
    module is then executed under ``MemoryProfilingInterpreter`` on fresh inputs.
    The recorded memory trace overwrites ``fwd_mem`` on
    ``profiling_results[graph_id]``, or ``bwd_mem`` when ``bwd`` is truthy.

    That profile object lives inside ``profiling_results``, which is handed to
    every pass. Later passes therefore receive the object updated by earlier ones.

    Args:
        nz3: Forwarded to ``MemoryProfilingInterpreter``; its meaning is not
            visible in this module.
        graph_index: Used only in the debug CSV filename.
        graph_id: Key into ``profiling_results``; also forwarded to each pass.
        gm: Module to transform. It must expose a settable ``graph`` attribute
            and ``recompile()`` (presumably a ``torch.fx.GraphModule`` -- confirm).
            It is mutated in place.
        create_inputs_fn: Zero-argument callable whose return value is unpacked
            as the positional inputs of each profiling run.
        opt_passes: Iterable of ``(pass_fn, mem_budget)`` pairs. Each ``pass_fn``
            is called as ``pass_fn(graph, graph_id, graph_order,
            profiling_results, mem_budget, param_manager, bwd)``. It must return
            a graph object that supports ``lint()``.
        graph_order: Forwarded unchanged to each pass.
        profiling_results: Mapping from graph id to a profile object whose
            ``fwd_mem`` / ``bwd_mem`` attribute gets overwritten. It is indexed
            with ``graph_id`` up front, even when ``opt_passes`` is empty.
        param_manager: Forwarded unchanged to each pass.
        bwd: Truthy for backward-graph bookkeeping, falsy for forward. Also
            forwarded to each pass.
        debug_log: When true, print the rewritten graph after every pass (on
            every rank). On rank 0 only, also dump each pass's memory profile to
            ``mem_prof_r<rank>_<fwd|bwd>_<graph_index>_<graph_id>_pass_<i>.csv``.

    Returns:
        The same ``gm`` object that was passed in.
    """
    profile = profiling_results[graph_id]
    rank = dist.get_rank()
    phase = 'bwd' if bwd else 'fwd'
    mem_attr = 'bwd_mem' if bwd else 'fwd_mem'

    for pass_no, (pass_fn, budget) in enumerate(opt_passes):
        # Rewrite the graph, validate it, and make gm execute the new version.
        rewritten = pass_fn(gm.graph, graph_id, graph_order, profiling_results, budget, param_manager, bwd)
        rewritten.lint()
        gm.graph = rewritten
        gm.recompile()

        if debug_log:
            # NOTE(review): this message says "Prefetching" after every pass,
            # regardless of which pass actually ran.
            print(f"Prefetching enabled for {phase} graph_id={graph_id} {rewritten}")

        # Profile the recompiled module on fresh inputs; mem_record is read below.
        profiler = MemoryProfilingInterpreter(nz3, gm)
        profiler.run(*create_inputs_fn())
        if debug_log and rank == 0:
            profiler.dump(f"mem_prof_r{rank}_{phase}_{graph_index}_{graph_id}_pass_{pass_no}.csv")

        # Normalize each record to a (name, current_alloc, delta, peak) tuple.
        # The unpacking also rejects records that do not have exactly four fields.
        trace = [(label, alloc, change, high) for label, alloc, change, high in profiler.mem_record]
        setattr(profile, mem_attr, trace)

    return gm