mirror of
https://github.com/deepspeedai/DeepSpeed.git
synced 2025-10-20 15:33:51 +08:00
This PR introduces *DeepCompile*, a new feature that efficiently integrates compiler optimizations with other DeepSpeed features. DeepCompile uses PyTorch's Dynamo to capture the computation graph and modifies it to incorporate DeepSpeed's optimizations seamlessly. Currently, DeepCompile supports ZeRO-1 and ZeRO-3, with enhancements such as proactive prefetching and selective unsharding to improve performance. (More details will be added later.)
---------
Signed-off-by: Masahiro Tanaka <mtanaka@microsoft.com>
Signed-off-by: Olatunji Ruwase <olruwase@microsoft.com>
Co-authored-by: zafarsadiq <zafarsadiq120@gmail.com>
Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>
Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>
85 lines
2.9 KiB
Python
85 lines
2.9 KiB
Python
# Copyright (c) Microsoft Corporation.
|
|
# SPDX-License-Identifier: Apache-2.0
|
|
|
|
# DeepSpeed Team
|
|
|
|
from dataclasses import dataclass, field
from functools import reduce
from typing import Any, Dict, List, Optional, Tuple

import torch
from torch.fx import Graph, Node

from .fx import get_output_node
from .util import get_param_nodes
|
|
|
|
|
|
@dataclass
class DSGraphParam:
    """Metadata for one DeepSpeed-managed parameter that appears as an input of an FX graph.

    ``numel`` is not a constructor argument; it is derived from ``shape`` in ``__post_init__``.
    """

    # Name of the FX node that represents this parameter in the graph.
    name: str
    # Shape used to compute ``numel``. NOTE(review): DSGraphParamManager passes the DeepSpeed-provided
    # shape here, which presumably may differ from ``param.shape`` (e.g. partitioned params) - confirm.
    shape: torch.Size
    # dtype / device of the parameter tensor.
    dtype: torch.dtype
    device: torch.device
    # FX node for this parameter.
    node: Node
    # Nodes inserted later by graph passes; DSGraphParamManager constructs them as None.
    allgather_node: Optional[Node]
    release_node: Optional[Node]
    # Tensor backing this parameter (DSGraphParamManager passes the matching sample input).
    param: torch.Tensor
    # Product of all dimensions of ``shape``; computed, not passed to __init__.
    numel: int = field(init=False)

    def __post_init__(self):
        # Start from 1 so a 0-d (scalar) shape yields 1 instead of reduce() raising
        # "reduce() of empty iterable with no initial value".
        self.numel = reduce(lambda x, y: x * y, self.shape, 1)
|
|
|
|
|
|
class DSGraphParamManager:
|
|
|
|
def __init__(self, fw_graph: Graph, sample_inputs: Any, index_to_ds_ids: List[Tuple[int, int, int]]):
|
|
self._fw_graph = fw_graph
|
|
self._bw_graph = None
|
|
self._params: Dict[str, DSGraphParam] = {}
|
|
self._param_name_to_grad: Dict[str, Node] = {}
|
|
self._ds_ids: Dict[str, int] = {}
|
|
|
|
param_nodes = get_param_nodes(fw_graph, index_to_ds_ids)
|
|
self._param_names = [pn.name for pn in param_nodes]
|
|
self._param_indices = [i for i, _, _ in index_to_ds_ids]
|
|
|
|
param_inputs = [sample_inputs[i] for i, _, _ in index_to_ds_ids]
|
|
ds_ids = [ds_id for _, ds_id, _ in index_to_ds_ids]
|
|
ds_shapes = [ds_shape for _, _, ds_shape in index_to_ds_ids]
|
|
|
|
for pn, pi, ds_id, ds_shape in zip(param_nodes, param_inputs, ds_ids, ds_shapes):
|
|
self._params[pn.name] = DSGraphParam(name=pn.name,
|
|
shape=ds_shape,
|
|
dtype=pi.dtype,
|
|
device=pi.device,
|
|
node=pn,
|
|
allgather_node=None,
|
|
release_node=None,
|
|
param=pi)
|
|
self._ds_ids[pn.name] = ds_id
|
|
|
|
def get_bwd_mapping(self, bw_graph: Graph):
|
|
self._bw_graph = bw_graph
|
|
|
|
output_node = get_output_node(bw_graph)
|
|
param_nodes_bw = [n for n in self._bw_graph.nodes if n.name in self.param_names]
|
|
grad_outputs = [output_node.args[0][i] for i in self._param_indices]
|
|
param_name_to_grad = {param_name: grad for param_name, grad in zip(self.param_names, grad_outputs)}
|
|
return param_nodes_bw, param_name_to_grad
|
|
|
|
@property
|
|
def param_names(self) -> List[str]:
|
|
return self._param_names
|
|
|
|
@property
|
|
def params(self) -> Dict[str, DSGraphParam]:
|
|
return self._params
|
|
|
|
@property
|
|
def ds_ids(self) -> Dict[str, int]:
|
|
return self._ds_ids
|
|
|
|
def get_grad_name(self, param_name) -> str:
|
|
assert self._param_name_to_grad is not None, "Backward graph is not added yet"
|
|
return self._param_name_to_grad[param_name]
|