Files
DeepSpeed/deepspeed/compile/passes/zero1_compile.py
Masahiro Tanaka 227a60c0c4 DeepCompile for enhanced compiler integration (#7154)
This PR introduces *DeepCompile*, a new feature that efficiently
integrates compiler optimizations with other DeepSpeed features.
DeepCompile uses PyTorch's Dynamo (the graph-capture frontend of torch.compile) to capture the
computation graph, and modifies it to incorporate DeepSpeed's optimizations seamlessly.

Currently, DeepCompile supports ZeRO-1 and ZeRO-3, with enhancements
such as proactive prefetching and selective unsharding to improve
performance.
(More details will be added later.)

---------

Signed-off-by: Masahiro Tanaka <mtanaka@microsoft.com>
Signed-off-by: Olatunji Ruwase <olruwase@microsoft.com>
Co-authored-by: zafarsadiq <zafarsadiq120@gmail.com>
Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>
Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>
2025-04-16 04:33:53 +00:00

56 lines
1.8 KiB
Python

# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
from typing import List
import torch
from torch.fx import GraphModule
from ..util import get_deepcompile_handle
from ..fx import add_postprocess, move_primals_to_head, _make_node_meta
# Name of this compiler pass.
# NOTE(review): presumably the key under which the pass is registered or selected
# elsewhere in DeepCompile — confirm against callers.
NAME = "zero1_compile"
def add_z1_reduce_fw(gm: GraphModule, graph_id: int, profiling_results, param_manager) -> GraphModule:
    """Forward-graph half of the ZeRO-1 pass.

    The forward graph itself is not rewritten. The only action is registering
    this graph's parameters with the DeepCompile runtime handle via
    ``register_graph_z1``.

    Args:
        gm: Forward GraphModule; returned unchanged.
        graph_id: Key into ``profiling_results``; also passed to the handle.
        profiling_results: Mapping from graph id to an object exposing
            ``param_indices``. The second element of each entry is what gets
            registered (presumably a DeepSpeed parameter id — TODO confirm).
        param_manager: Unused here (presumably kept so the signature parallels
            the backward variant).

    Returns:
        ``gm``, unmodified.
    """
    handle = get_deepcompile_handle()
    entries = profiling_results[graph_id].param_indices
    registered_ids = [entry[1] for entry in entries]
    # Registration must happen before profiling runs.
    handle.register_graph_z1(graph_id, registered_ids)
    return gm
def add_z1_reduce_bw(gm: GraphModule, graph_id: int, param_manager) -> GraphModule:
    """Backward-graph half of the ZeRO-1 pass: attach a gradient-reduce op per parameter.

    For every parameter name tracked by ``param_manager[graph_id]``:

    * look up the node that produces that parameter's gradient;
    * attach a ``torch.ops.dc.reduce_grad`` call to it with ``add_postprocess``,
      passing ``graph_id`` and the parameter's ``ds_id`` as extra arguments.

    Afterwards the graph is run through ``move_primals_to_head`` and the result
    is assigned back to ``gm.graph``.

    Args:
        gm: Backward GraphModule; its graph is mutated and then reassigned.
        graph_id: Selects the per-graph parameter manager.
        param_manager: Mapping from graph id to a manager exposing
            ``get_bwd_mapping``, ``param_names`` and ``ds_ids``.

    Returns:
        ``gm`` carrying the rewritten graph.

    Raises:
        AssertionError: A tracked parameter name has no entry in ``ds_ids``.
        KeyError: Presumably raised if ``get_bwd_mapping`` yields no gradient
            node for a tracked parameter — TODO confirm.
    """
    fx_graph = gm.graph
    manager = param_manager[graph_id]
    _unused_mapping, grad_by_name = manager.get_bwd_mapping(fx_graph)

    for name in manager.param_names:
        producer = grad_by_name[name]

        assert name in manager.ds_ids, f"param_name={name} not in ds_ids"
        param_ds_id = manager.ds_ids[name]

        reduce_node = add_postprocess(
            fx_graph,
            producer,
            torch.ops.dc.reduce_grad.default,
            extra_args=[graph_id, param_ds_id],
            name=f"reduce_param_{name}",
            meta=_make_node_meta(producer, name, True),
        )
        # NOTE(review): clears the node's recorded value, presumably because
        # reduce_grad produces no tensor that later nodes consume — confirm.
        reduce_node.meta["val"] = None

    gm.graph = move_primals_to_head(fx_graph)
    return gm
def add_z1_reduce(gm: GraphModule, graph_id: int, graph_order: List[int], profiling_results, create_inputs_fn,
                  mem_budget: float, param_manager, bwd: bool) -> GraphModule:
    """Entry point of the ZeRO-1 compile pass; route to the forward or backward variant.

    Args:
        gm: GraphModule to process.
        graph_id: Identifier of the graph being processed.
        graph_order: Unused by this pass.
        profiling_results: Forwarded to the forward variant only.
        create_inputs_fn: Unused by this pass.
        mem_budget: Unused by this pass.
        param_manager: Forwarded to both variants.
        bwd: True selects the backward variant, False the forward one.

    Returns:
        The GraphModule returned by the selected variant.

    NOTE(review): the unused parameters are presumably kept so all DeepCompile
    passes share one call signature — confirm against the pass driver.
    """
    if not bwd:
        return add_z1_reduce_fw(gm, graph_id, profiling_results, param_manager)
    return add_z1_reduce_bw(gm, graph_id, param_manager)