DeepSpeed/deepspeed/compile/passes/zero3_compile.py
Masahiro Tanaka 227a60c0c4 DeepCompile for enhanced compiler integration (#7154)
This PR introduces *DeepCompile*, a new feature that efficiently
integrates compiler optimizations with other DeepSpeed features.
DeepCompile uses PyTorch Dynamo to capture the computation graph and
rewrites it to incorporate DeepSpeed's optimizations seamlessly.

Currently, DeepCompile supports ZeRO-1 and ZeRO-3, with enhancements
such as proactive prefetching and selective unsharding to improve
performance.
(More details will be added later.)
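
To give a rough sense of what that graph rewriting involves, here is a minimal
sketch of the underlying torch.fx pattern: insert a post-processing node after
an existing node and reroute its other users to the new node. `insert_after`,
the `Toy` module, and `torch.relu` are illustrative stand-ins only; the actual
pass uses DeepSpeed's `add_postprocess` helper and the `torch.ops.dc.*` custom
ops.

```python
import torch
from torch.fx import symbolic_trace


def insert_after(graph, node, fn):
    # Insert a call to `fn(node)` right after `node` in the FX graph and make
    # every other consumer of `node` read from the new node instead.
    with graph.inserting_after(node):
        new_node = graph.call_function(fn, args=(node, ))
    node.replace_all_uses_with(new_node, delete_user_cb=lambda user: user is not new_node)
    return new_node


class Toy(torch.nn.Module):

    def forward(self, x):
        return x * 2


gm = symbolic_trace(Toy())
x_node = next(n for n in gm.graph.nodes if n.op == "placeholder")
insert_after(gm.graph, x_node, torch.relu)  # downstream ops now consume relu(x)
gm.graph.lint()
gm.recompile()
print(gm.code)  # forward now computes relu(x) * 2
```

The zero3_compile pass in this PR applies the same pattern to insert
`allgather_param`/`wait_allgather` around each ZeRO-3 parameter,
`release_param` after each use, and `reduce_grad` on gradients in the
backward graph.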

---------

Signed-off-by: Masahiro Tanaka <mtanaka@microsoft.com>
Signed-off-by: Olatunji Ruwase <olruwase@microsoft.com>
Co-authored-by: zafarsadiq <zafarsadiq120@gmail.com>
Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>
Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>
2025-04-16 04:33:53 +00:00

187 lines · 7.4 KiB · Python

# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import gc
from typing import List, Dict

import torch
from torch.fx import Graph, Node, GraphModule

from ..util import get_input_nodes, get_param_nodes, get_index_by_graph_id, get_deepcompile_handle, get_real_uses
from ..fx import add_postprocess, _make_node_meta, get_output_node, move_primals_to_head
from ..profilers.graph_profile import ProfilingInterpreter
from ..list_schedule import fast_free_schedule

import deepspeed.comm as dist
from deepspeed.accelerator import get_accelerator

NAME = "zero3_compile"


def add_allgather(graph_id: int, graph: Graph, node: Node, ds_id: int):
    # Insert an allgather of the sharded parameter right after the param node,
    # followed by a wait so downstream ops see the fully gathered tensor.
    new_ag_node = add_postprocess(graph,
                                  node,
                                  torch.ops.dc.allgather_param.default,
                                  extra_args=[graph_id, ds_id],
                                  name=f"allgather_ds_param_{node.target}_{ds_id}",
                                  meta=_make_node_meta(node, ds_id, True))
    new_ag_node.meta["val"] = node.meta["val"]

    # Set the previous node back to output
    # We don't want to change the output node to allgather
    output_node = get_output_node(graph)
    output_node.replace_input_with(new_ag_node, node)

    # Add wait as well
    new_wait_node = add_postprocess(graph,
                                    new_ag_node,
                                    torch.ops.dc.wait_allgather.default,
                                    extra_args=[graph_id, ds_id],
                                    name=f"wait_allgather_ds_param__{node.target}_{ds_id}",
                                    meta=_make_node_meta(node, ds_id, False))
    new_wait_node.meta["val"] = node.meta["val"]

    return new_ag_node


def add_release(graph_id: int, graph: Graph, node: Node, release_node: Node, ds_id: int, n_users: int):
    # Insert a release of the gathered parameter after `node`, one of its users;
    # n_users is the total number of users recorded for this parameter.
    new_node = add_postprocess(graph,
                               node,
                               torch.ops.dc.release_param.default,
                               extra_args=[graph_id, ds_id, n_users],
                               name=f"release_ds_param_{release_node.target}_{node.name}_{ds_id}",
                               meta=_make_node_meta(node, ds_id, False))
    new_node.meta["val"] = None


def add_reduce(graph_id: int, graph: Graph, grad_node: Node, param_name: str, ds_id: int):
    # Insert a gradient reduction after the node that produces this parameter's gradient.
    new_node = add_postprocess(graph,
                               grad_node,
                               torch.ops.dc.reduce_grad.default,
                               extra_args=[graph_id, ds_id],
                               name=f"reduce_ds_param_{param_name}",
                               meta=_make_node_meta(grad_node, ds_id, True))
    new_node.meta["val"] = None


def add_gather_and_release(graph_id: int, graph: Graph, param_manager, param_nodes: List[Node]) -> Graph:
    # For each ZeRO-3 param: allgather it where it enters the graph and release it after each real use.
    node_to_uses = get_real_uses(graph)
    for pn in param_nodes:
        ds_id = param_manager.ds_ids[pn.name]
        add_allgather(graph_id, graph, pn, ds_id)
        users = node_to_uses[pn]
        for user in users:
            add_release(graph_id, graph, user, pn, ds_id, len(users))

    return move_primals_to_head(graph)


def add_gather_and_reduce(graph_id: int, graph: Graph, param_manager, param_nodes_bw: List[Node],
                          param_name_to_grad: Dict[str, Node]) -> Graph:
    # Backward graph: gather/release params as in forward, then reduce each parameter's gradient.
    add_gather_and_release(graph_id, graph, param_manager, param_nodes_bw)

    for param_name in param_manager.param_names:
        add_reduce(graph_id, graph, param_name_to_grad[param_name], param_name, param_manager.ds_ids[param_name])

    return move_primals_to_head(graph)


def add_z3_gather_release_fw(gm: GraphModule,
                             graph_id: int,
                             graph_order: List[int],
                             profiling_results,
                             create_inputs_fn,
                             param_manager,
                             debug_log=False) -> GraphModule:
    # Forward graph: insert gather/release ops, profile the rewritten graph,
    # then reorder it with fast_free_schedule.
    nz3 = get_deepcompile_handle()

    real_inputs = create_inputs_fn()
    param_indices = profiling_results[graph_id].param_indices

    gm.graph = add_gather_and_release(graph_id, gm.graph, param_manager[graph_id],
                                      get_param_nodes(gm.graph, param_indices))

    nz3.register_graph_z3(graph_id, [v[1] for v in param_indices])  # Need this before profiling

    profiler = ProfilingInterpreter(gm, debug_log=debug_log)
    profiler.run(*real_inputs)
    del profiler
    gc.collect()
    get_accelerator().empty_cache()

    rank = dist.get_rank()
    graph_index = get_index_by_graph_id(graph_order, graph_id)
    if rank == 0 and debug_log:
        print(f"Fwd before scheduling graph {graph_index} graph_id={graph_id} {gm.graph}")

    for n in gm.graph.nodes:
        is_ds_param = n.name in param_manager[graph_id].ds_ids
        if "val" in n.meta and is_ds_param:
            # Used for Inductor's validation
            n.meta["val"] = torch.empty([0], dtype=n.meta['val'].dtype, device=n.meta['val'].device)

    gm.graph = fast_free_schedule(
        gm.graph,
        get_accelerator().available_memory(),
        0,  # unused
        debug_log=debug_log)

    if rank == 0 and debug_log:
        print(f"Fwd after scheduling graph {graph_index} graph_id={graph_id} {gm.graph}")

    return gm


def add_z3_gather_release_bw(gm: GraphModule,
                             graph_id: int,
                             graph_order: List[int],
                             profiling_results,
                             create_inputs_fn,
                             param_manager,
                             debug_log=False) -> GraphModule:
    # Backward graph: insert gather/release and gradient-reduce ops, then profile the result.
    param_nodes_bw, param_name_to_grad = param_manager[graph_id].get_bwd_mapping(gm.graph)
    gm.graph = add_gather_and_reduce(graph_id, gm.graph, param_manager[graph_id], param_nodes_bw, param_name_to_grad)

    input_nodes = get_input_nodes(gm.graph)
    real_inputs = create_inputs_fn()
    assert len(input_nodes) == len(real_inputs), f"Expected {len(real_inputs)} inputs, got {len(input_nodes)}"

    real_outputs = ProfilingInterpreter(gm, debug_log=debug_log).run(*real_inputs)

    del real_outputs
    gc.collect()
    get_accelerator().empty_cache()

    rank = dist.get_rank()
    graph_index = get_index_by_graph_id(graph_order, graph_id)
    if rank == 0 and debug_log:
        print(f"Bwd before scheduling graph {graph_index} graph_id={graph_id} {gm.graph}")

    # gm.graph = fast_free_schedule(gm.graph, get_accelerator().available_memory(), 0, debug_log=debug_log)

    return gm


def add_z3_gather_release(gm: GraphModule, graph_id: int, graph_order: List[int], profiling_results, create_inputs_fn,
                          mem_budget: float, param_manager, bwd: bool) -> GraphModule:
    # Entry point of the pass: dispatch to the forward or backward rewrite.
    if bwd:
        return add_z3_gather_release_bw(gm,
                                        graph_id,
                                        graph_order,
                                        profiling_results,
                                        create_inputs_fn,
                                        param_manager,
                                        debug_log=False)
    return add_z3_gather_release_fw(gm,
                                    graph_id,
                                    graph_order,
                                    profiling_results,
                                    create_inputs_fn,
                                    param_manager,
                                    debug_log=False)