Mirror of https://github.com/deepspeedai/DeepSpeed.git
This PR introduces *DeepCompile*, a new feature that efficiently integrates compiler optimizations with other DeepSpeed features. DeepCompile uses torch's dynamo to capture the computation graph and modifies it to incorporate DeepSpeed's optimizations seamlessly. Currently, DeepCompile supports ZeRO-1 and ZeRO-3, with enhancements such as proactive prefetching and selective unsharding to improve performance. (More details will be added later.)

Signed-off-by: Masahiro Tanaka <mtanaka@microsoft.com>
Signed-off-by: Olatunji Ruwase <olruwase@microsoft.com>
Co-authored-by: zafarsadiq <zafarsadiq120@gmail.com>
Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>
Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>
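For context, the capture step DeepCompile builds on is torch.compile's custom-backend hook: Dynamo traces the model into a torch.fx GraphModule and hands it to the backend, which may rewrite the graph before returning a callable. The sketch below uses only generic PyTorch APIs and a toy model; it is not DeepSpeed's actual entry point, just an illustration of where graph-rewriting passes such as the one in this file slot in.

import torch
from torch.fx import GraphModule

def inspect_backend(gm: GraphModule, example_inputs):
    # At this point Dynamo has captured the model as an FX graph; a pass like
    # add_z1_reduce (below) would rewrite gm.graph here before execution.
    gm.graph.print_tabular()
    return gm.forward  # run the (possibly modified) graph as-is

model = torch.nn.Linear(8, 8)
compiled = torch.compile(model, backend=inspect_backend)
compiled(torch.randn(2, 8))  # the first call triggers capture and compilation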
56 lines
1.8 KiB
Python
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

from typing import List

import torch
from torch.fx import GraphModule

from ..util import get_deepcompile_handle
from ..fx import add_postprocess, move_primals_to_head, _make_node_meta

NAME = "zero1_compile"

def add_z1_reduce_fw(gm: GraphModule, graph_id: int, profiling_results, param_manager) -> GraphModule:

    # For ZeRO-1, the forward graph only needs to be registered with the DeepCompile
    # handle; the FX graph itself is returned unmodified.
    dc = get_deepcompile_handle()
    param_indices = profiling_results[graph_id].param_indices
    dc.register_graph_z1(graph_id, [v[1] for v in param_indices])  # Need this before profiling

    return gm

def add_z1_reduce_bw(gm: GraphModule, graph_id: int, param_manager) -> GraphModule:

    graph = gm.graph
    pm = param_manager[graph_id]
    _, param_name_to_grad = pm.get_bwd_mapping(graph)

    for param_name in pm.param_names:

        grad_node = param_name_to_grad[param_name]

        assert param_name in pm.ds_ids, f"param_name={param_name} not in ds_ids"
        ds_id = pm.ds_ids[param_name]

        # Insert a reduce_grad call right after each gradient node, so the ZeRO-1
        # gradient reduction is triggered as soon as that gradient is produced.
        new_node = add_postprocess(graph,
                                   grad_node,
                                   torch.ops.dc.reduce_grad.default,
                                   extra_args=[graph_id, ds_id],
                                   name=f"reduce_param_{param_name}",
                                   meta=_make_node_meta(grad_node, param_name, True))
        new_node.meta["val"] = None

    gm.graph = move_primals_to_head(graph)

    return gm

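# Note: add_postprocess is assumed to wrap the standard torch.fx insertion pattern,
# i.e. placing a new call node immediately after the gradient node. With the public
# FX API alone, the equivalent insertion would look roughly like:
#
#   with graph.inserting_after(grad_node):
#       node = graph.call_function(torch.ops.dc.reduce_grad.default,
#                                  args=(grad_node, graph_id, ds_id))
#
# move_primals_to_head presumably reorders placeholder ("primal") nodes to the front
# of the graph afterwards, keeping graph inputs ahead of any nodes that use them.
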
def add_z1_reduce(gm: GraphModule, graph_id: int, graph_order: List[int], profiling_results, create_inputs_fn,
                  mem_budget: float, param_manager, bwd: bool) -> GraphModule:
    if bwd:
        return add_z1_reduce_bw(gm, graph_id, param_manager)
    return add_z1_reduce_fw(gm, graph_id, profiling_results, param_manager)