Files
DeepSpeed/deepspeed/compile/graph_param.py
Masahiro Tanaka 227a60c0c4 DeepCompile for enhanced compiler integration (#7154)
This PR introduces *DeepCompile*, a new feature that efficiently
integrates compiler optimizations with other DeepSpeed features.
DeepCompile utilizes torch's dynamo to capture the computation graph and
modifies it to incorporate DeepSpeed’s optimizations seamlessly.

Currently, DeepCompile supports ZeRO-1 and ZeRO-3, with enhancements
such as proactive prefetching and selective unsharding to improve
performance.
(More details will be added later.)

---------

Signed-off-by: Masahiro Tanaka <mtanaka@microsoft.com>
Signed-off-by: Olatunji Ruwase <olruwase@microsoft.com>
Co-authored-by: zafarsadiq <zafarsadiq120@gmail.com>
Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>
Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>
2025-04-16 04:33:53 +00:00

85 lines
2.9 KiB
Python

# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
from dataclasses import dataclass, field
from typing import Any, Dict, List, Tuple
from functools import reduce
import torch
from torch.fx import Graph, Node
from .fx import get_output_node
from .util import get_param_nodes
@dataclass
class DSGraphParam:
    """Record describing one parameter tracked in a captured FX graph.

    ``numel`` is not a constructor argument; it is derived from ``shape``
    once the other fields have been set.
    """

    name: str
    shape: torch.Size
    dtype: torch.dtype
    device: torch.device
    node: Node
    # DSGraphParamManager passes None for the next two fields when it first
    # builds these records, so they may be unset despite the annotation.
    allgather_node: Node
    release_node: Node
    param: torch.Tensor
    numel: int = field(init=False)

    def __post_init__(self):
        """Compute ``numel`` as the product of the dimensions in ``shape``."""
        # Seed the product with 1: without an initial value, reduce() raises
        # TypeError on an empty shape (a 0-dim tensor), which holds 1 element.
        self.numel = reduce(lambda x, y: x * y, self.shape, 1)
class DSGraphParamManager:
    """Tracks the parameters of a forward graph and their gradients in the backward graph.

    The forward-side bookkeeping is built in ``__init__``; the backward-side
    mapping is filled in by ``get_bwd_mapping``.
    """

    def __init__(self, fw_graph: Graph, sample_inputs: Any, index_to_ds_ids: List[Tuple[int, int, int]]):
        """Build one ``DSGraphParam`` per entry of ``index_to_ds_ids``.

        Args:
            fw_graph: Forward graph, handed to ``get_param_nodes`` to locate
                the parameter nodes.
            sample_inputs: Indexable collection; the element at each listed
                index is used as the parameter tensor (its ``dtype`` and
                ``device`` are recorded).
            index_to_ds_ids: Triples unpacked as ``(input index, ds id, ds shape)``.
                NOTE(review): the annotation declares the third element as
                ``int`` but it is used as a shape -- confirm and tighten.
        """
        self._fw_graph = fw_graph
        self._bw_graph = None
        self._params: Dict[str, DSGraphParam] = {}
        # Stays empty until get_bwd_mapping() has been called.
        self._param_name_to_grad: Dict[str, Node] = {}
        self._ds_ids: Dict[str, int] = {}

        param_nodes = get_param_nodes(fw_graph, index_to_ds_ids)
        self._param_names = [pn.name for pn in param_nodes]
        self._param_indices = [i for i, _, _ in index_to_ds_ids]

        # Assumes get_param_nodes() yields nodes in the same order as
        # index_to_ds_ids, since the two are paired positionally below.
        for pn, (input_index, ds_id, ds_shape) in zip(param_nodes, index_to_ds_ids):
            tensor = sample_inputs[input_index]
            self._params[pn.name] = DSGraphParam(name=pn.name,
                                                 shape=ds_shape,
                                                 dtype=tensor.dtype,
                                                 device=tensor.device,
                                                 node=pn,
                                                 allgather_node=None,
                                                 release_node=None,
                                                 param=tensor)
            self._ds_ids[pn.name] = ds_id

    def get_bwd_mapping(self, bw_graph: Graph):
        """Register the backward graph and map each parameter name to its gradient output.

        Args:
            bw_graph: Backward graph whose output node's first argument is
                indexed by the stored parameter indices.

        Returns:
            A pair ``(param_nodes_bw, param_name_to_grad)``: the backward-graph
            nodes whose names match a tracked parameter, and a dict from
            parameter name to the corresponding output entry.
        """
        self._bw_graph = bw_graph
        output_node = get_output_node(bw_graph)

        # Set lookup keeps the node scan linear in the number of graph nodes.
        tracked_names = set(self._param_names)
        param_nodes_bw = [n for n in bw_graph.nodes if n.name in tracked_names]

        grad_outputs = [output_node.args[0][i] for i in self._param_indices]
        param_name_to_grad = dict(zip(self._param_names, grad_outputs))

        # Keep a private copy so get_grad_name() can answer lookups; previously
        # the mapping was only returned and the stored dict stayed empty.
        self._param_name_to_grad = dict(param_name_to_grad)
        return param_nodes_bw, param_name_to_grad

    @property
    def param_names(self) -> List[str]:
        """Names of the parameter nodes, in the order they were found."""
        return self._param_names

    @property
    def params(self) -> "Dict[str, DSGraphParam]":
        """Mapping from parameter node name to its ``DSGraphParam`` record."""
        return self._params

    @property
    def ds_ids(self) -> Dict[str, int]:
        """Mapping from parameter node name to its ds id."""
        return self._ds_ids

    def get_grad_name(self, param_name) -> str:
        """Return the gradient entry recorded for ``param_name``.

        Raises:
            AssertionError: if ``get_bwd_mapping`` has not been called yet.
            KeyError: if ``param_name`` is not a tracked parameter.

        NOTE(review): the stored values are the backward graph's output
        entries, not strings as the return annotation suggests -- confirm
        which one callers expect.
        """
        # The mapping dict is never None, so check the backward graph itself
        # to detect "not registered yet".
        assert self._bw_graph is not None, "Backward graph is not added yet"
        return self._param_name_to_grad[param_name]