Mirror of https://github.com/deepspeedai/DeepSpeed.git
Synced 2025-10-20 15:33:51 +08:00
This PR introduces *DeepCompile*, a new feature that efficiently integrates compiler optimizations with other DeepSpeed features. DeepCompile uses PyTorch Dynamo to capture the computation graph and then rewrites it to incorporate DeepSpeed's optimizations seamlessly. Currently, DeepCompile supports ZeRO-1 and ZeRO-3, with enhancements such as proactive prefetching and selective unsharding to improve performance. (More details will be added later.)

---------

Signed-off-by: Masahiro Tanaka <mtanaka@microsoft.com>
Signed-off-by: Olatunji Ruwase <olruwase@microsoft.com>
Co-authored-by: zafarsadiq <zafarsadiq120@gmail.com>
Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>
Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>
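As a rough usage sketch (the `"compile": {"deepcompile": true}` config key and the `engine.compile()` entry point shown here are assumptions for illustration, not necessarily the finalized API; see the DeepSpeed docs for the final names):

import torch
import deepspeed

model = torch.nn.Linear(8, 8)

ds_config = {
    "train_batch_size": 8,
    "optimizer": {"type": "Adam", "params": {"lr": 1e-3}},
    "zero_optimization": {"stage": 3},
    # Assumed switch for DeepCompile; the final key name may differ.
    "compile": {"deepcompile": True},
}

engine, _, _, _ = deepspeed.initialize(model=model,
                                       model_parameters=model.parameters(),
                                       config=ds_config)
engine.compile()  # assumed entry point: captures the graph via Dynamo and applies passes like the one below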
187 lines
7.4 KiB
Python
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import gc
from typing import List, Dict

import torch
from torch.fx import Graph, Node, GraphModule

from ..util import get_input_nodes, get_param_nodes, get_index_by_graph_id, get_deepcompile_handle, get_real_uses
from ..fx import add_postprocess, _make_node_meta, get_output_node, move_primals_to_head
from ..profilers.graph_profile import ProfilingInterpreter
from ..list_schedule import fast_free_schedule

import deepspeed.comm as dist
from deepspeed.accelerator import get_accelerator

NAME = "zero3_compile"


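# add_allgather: insert an allgather_param op right after a sharded parameter
# node and a wait_allgather op after the gather, so consumers read the
# unsharded parameter. The graph output is rewired back to the original node
# so the gather result never leaks out of the graph.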
def add_allgather(graph_id: int, graph: Graph, node: Node, ds_id: int):
    new_ag_node = add_postprocess(graph,
                                  node,
                                  torch.ops.dc.allgather_param.default,
                                  extra_args=[graph_id, ds_id],
                                  name=f"allgather_ds_param_{node.target}_{ds_id}",
                                  meta=_make_node_meta(node, ds_id, True))
    new_ag_node.meta["val"] = node.meta["val"]

    # Set the previous node back to output
    # We don't want to change the output node to allgather
    output_node = get_output_node(graph)
    output_node.replace_input_with(new_ag_node, node)

    # Add wait as well
    new_wait_node = add_postprocess(graph,
                                    new_ag_node,
                                    torch.ops.dc.wait_allgather.default,
                                    extra_args=[graph_id, ds_id],
                                    name=f"wait_allgather_ds_param__{node.target}_{ds_id}",
                                    meta=_make_node_meta(node, ds_id, False))
    new_wait_node.meta["val"] = node.meta["val"]

    return new_ag_node


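# add_release: insert a release_param op after one user of a gathered
# parameter; n_users tells the runtime how many such releases to expect, so
# the memory can presumably be freed once the last user has run.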
def add_release(graph_id: int, graph: Graph, node: Node, release_node: Node, ds_id: int, n_users: int):
    new_node = add_postprocess(graph,
                               node,
                               torch.ops.dc.release_param.default,
                               extra_args=[graph_id, ds_id, n_users],
                               name=f"release_ds_param_{release_node.target}_{node.name}_{ds_id}",
                               meta=_make_node_meta(node, ds_id, False))
    new_node.meta["val"] = None


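# add_reduce: hook a reduce_grad op onto a produced gradient so it can be
# reduced across data-parallel ranks as soon as it becomes available.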
def add_reduce(graph_id: int, graph: Graph, grad_node: Node, param_name: str, ds_id: int):
    new_node = add_postprocess(graph,
                               grad_node,
                               torch.ops.dc.reduce_grad.default,
                               extra_args=[graph_id, ds_id],
                               name=f"reduce_ds_param_{param_name}",
                               meta=_make_node_meta(grad_node, ds_id, True))
    new_node.meta["val"] = None


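# Apply the gather/release pair to every parameter node in the graph, with
# one release per real use of each parameter.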
def add_gather_and_release(graph_id: int, graph: Graph, param_manager, param_nodes: List[Node]) -> Graph:

    node_to_uses = get_real_uses(graph)
    for pn in param_nodes:
        add_allgather(graph_id, graph, pn, param_manager.ds_ids[pn.name])
        ds_id = param_manager.ds_ids[pn.name]
        users = node_to_uses[pn]
        for user in users:
            add_release(graph_id, graph, user, pn, ds_id, len(users))

    return move_primals_to_head(graph)


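# Backward counterpart: gather/release for the parameters used in the
# backward graph, plus one reduce_grad per produced gradient.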
def add_gather_and_reduce(graph_id: int, graph: Graph, param_manager, param_nodes_bw: List[Node],
                          param_name_to_grad: Dict[str, Node]) -> Graph:

    add_gather_and_release(graph_id, graph, param_manager, param_nodes_bw)

    for param_name in param_manager.param_names:
        add_reduce(graph_id, graph, param_name_to_grad[param_name], param_name, param_manager.ds_ids[param_name])

    return move_primals_to_head(graph)


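# Forward-graph driver: instrument the graph with gather/release ops, register
# it with the native DeepCompile handle, profile it on real inputs, then let
# fast_free_schedule reorder ops against the available device memory.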
def add_z3_gather_release_fw(gm: GraphModule,
                             graph_id: int,
                             graph_order: List[int],
                             profiling_results,
                             create_inputs_fn,
                             param_manager,
                             debug_log=False) -> GraphModule:

    nz3 = get_deepcompile_handle()

    real_inputs = create_inputs_fn()
    param_indices = profiling_results[graph_id].param_indices

    gm.graph = add_gather_and_release(graph_id, gm.graph, param_manager[graph_id],
                                      get_param_nodes(gm.graph, param_indices))

    nz3.register_graph_z3(graph_id, [v[1] for v in param_indices])  # Need this before profiling

    profiler = ProfilingInterpreter(gm, debug_log=debug_log)
    profiler.run(*real_inputs)
    del profiler
    gc.collect()
    get_accelerator().empty_cache()

    rank = dist.get_rank()
    graph_index = get_index_by_graph_id(graph_order, graph_id)
    if rank == 0 and debug_log:
        print(f"Fwd before scheduling graph {graph_index} graph_id={graph_id} {gm.graph}")

    for n in gm.graph.nodes:
        is_ds_param = n.name in param_manager[graph_id].ds_ids
        if "val" in n.meta and is_ds_param:
            # Used for Inductor's validation
            n.meta["val"] = torch.empty([0], dtype=n.meta['val'].dtype, device=n.meta['val'].device)

    gm.graph = fast_free_schedule(
        gm.graph,
        get_accelerator().available_memory(),
        0,  # unused
        debug_log=debug_log)

    if rank == 0 and debug_log:
        print(f"Fwd after scheduling graph {graph_index} graph_id={graph_id} {gm.graph}")

    return gm


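# Backward-graph driver: instrument with gather/release/reduce, sanity-check
# the input count, and profile. Rescheduling is currently disabled here (the
# fast_free_schedule call below is commented out).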
def add_z3_gather_release_bw(gm: GraphModule,
                             graph_id: int,
                             graph_order: List[int],
                             profiling_results,
                             create_inputs_fn,
                             param_manager,
                             debug_log=False) -> GraphModule:

    param_nodes_bw, param_name_to_grad = param_manager[graph_id].get_bwd_mapping(gm.graph)
    gm.graph = add_gather_and_reduce(graph_id, gm.graph, param_manager[graph_id], param_nodes_bw, param_name_to_grad)

    input_nodes = get_input_nodes(gm.graph)
    real_inputs = create_inputs_fn()
    assert len(input_nodes) == len(real_inputs), f"Expected {len(real_inputs)} inputs, got {len(input_nodes)}"

    real_outputs = ProfilingInterpreter(gm, debug_log=debug_log).run(*real_inputs)

    del real_outputs
    gc.collect()
    get_accelerator().empty_cache()

    rank = dist.get_rank()
    graph_index = get_index_by_graph_id(graph_order, graph_id)
    if rank == 0 and debug_log:
        print(f"Bwd before scheduling graph {graph_index} graph_id={graph_id} {gm.graph}")

    # gm.graph = fast_free_schedule(gm.graph, get_accelerator().available_memory(), 0, debug_log=debug_log)
    return gm


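# Pass entry point: dispatch to the forward or backward rewriter. Note that
# mem_budget is accepted but currently unused by both paths.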
def add_z3_gather_release(gm: GraphModule, graph_id: int, graph_order: List[int], profiling_results, create_inputs_fn,
                          mem_budget: float, param_manager, bwd: bool) -> GraphModule:
    if bwd:
        return add_z3_gather_release_bw(gm,
                                        graph_id,
                                        graph_order,
                                        profiling_results,
                                        create_inputs_fn,
                                        param_manager,
                                        debug_log=False)
    return add_z3_gather_release_fw(gm,
                                    graph_id,
                                    graph_order,
                                    profiling_results,
                                    create_inputs_fn,
                                    param_manager,
                                    debug_log=False)
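The passes above are ordinary torch.fx graph surgery: insert an op after a producer node, then repoint that node's users at the new value. A self-contained sketch of the same pattern with vanilla torch.fx (noop_gather is a hypothetical stand-in for torch.ops.dc.allgather_param; nothing here touches DeepSpeed internals):

import torch
from torch import fx


def noop_gather(t: torch.Tensor) -> torch.Tensor:
    # Identity stand-in for an allgather op; it only shows the rewrite shape.
    return t


class Tiny(torch.nn.Module):

    def __init__(self):
        super().__init__()
        self.w = torch.nn.Parameter(torch.randn(4, 4))

    def forward(self, x):
        return x @ self.w


gm = fx.symbolic_trace(Tiny())
for node in list(gm.graph.nodes):
    if node.op == "get_attr" and node.target == "w":
        with gm.graph.inserting_after(node):
            gathered = gm.graph.call_function(noop_gather, args=(node, ))
        # Repoint every consumer of the parameter at the gathered value, but
        # keep the new node itself reading from the original parameter.
        node.replace_all_uses_with(gathered, delete_user_cb=lambda user: user is not gathered)

gm.recompile()
print(gm(torch.randn(2, 4)).shape)  # torch.Size([2, 4])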