With autocast enabled, most weights are downcast before being used in calculations. Today zero3_compile gathers the FP32 weights and only then downcasts them. That is sub-optimal because the FP32 weights consume more allgather bandwidth and take longer to downcast. To reduce communication and downcast time, this PR fuses allgather and downcast in the dc ops: the target dtype is now passed to allgather_param() and prefetch_params_fused(), which downcast the (partial) weights before launching the allgathers. This corresponds to issue 1 of #7577.

Tested with https://gist.github.com/eternalNight/3c2cf8c703f1e9e7742d3b7f9e1edae3 (run with `deepspeed --num_gpus=N this_file.py -c -p -m 23` to collect torch and memory profiles, and with DINOV2_DEPTH = SIGLIP_DEPTH = 3, LLAMA2_DEPTH = 4 for faster compilation) on 5090 GPUs (which have limited inter-GPU bandwidth): time per step decreases from 438ms to 337ms, and peak GPU memory usage from 9.5GB to 8.5GB.

Profiles of a single step before this PR:

<img width="1235" height="1029" alt="image" src="https://github.com/user-attachments/assets/d9fe5296-7731-4542-924b-421ff7415054" />
<img width="1466" height="616" alt="image" src="https://github.com/user-attachments/assets/aa192802-8633-4e36-b2c4-f28b1b432663" />

After this PR:

<img width="1218" height="1006" alt="image" src="https://github.com/user-attachments/assets/18a0e09c-155b-4783-adb5-b4d36c5c3691" />
<img width="1537" height="559" alt="image" src="https://github.com/user-attachments/assets/16a2ca74-8a89-4db9-9b68-81844295c61b" />

This PR also reduces peak memory usage because `fast_free_schedule()` today always places param allgathers and downcasts at the beginning of the graph. While the original FP32 params can be freed early, all of the FP16/BF16-casted params are kept in GPU memory from the beginning of the backward graph, leading to a higher memory peak.

P.S. Probably due to organization branch rule settings, I cannot find anywhere to allow reviewers to modify the branch, so I will update the branch per reviewers' comments and rebase if needed.
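For reference, the fused path amounts to downcasting each local shard before the collective instead of casting the gathered full tensor afterwards. A minimal sketch of the idea (`allgather_param_fused` is an illustrative stand-in, not the actual dc op signature):

```python
import torch
import torch.distributed as dist


def allgather_param_fused(shard: torch.Tensor, world_size: int,
                          dtype: torch.dtype = torch.bfloat16) -> torch.Tensor:
    # Old path: allgather the FP32 shards, then downcast the gathered tensor.
    # Fused path: downcast the local shard first, so the collective moves half
    # as many bytes and the cast touches only 1/world_size of the elements.
    shard_lp = shard.to(dtype)
    out = torch.empty(shard_lp.numel() * world_size, dtype=dtype, device=shard_lp.device)
    dist.all_gather_into_tensor(out, shard_lp)
    return out
```

Signed-off-by: Junjie Mao <junjie.mao@linux.alibaba.com>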
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

from typing import Callable, Any, List, Dict
from collections import defaultdict

import torch
from torch.fx import Node, Graph

from .util import get_last_uses


def get_output_node(graph: Graph):
    # Return the graph's output node; an fx graph has exactly one.
    for v in graph.nodes:
        if v.target == "output":
            return v
    raise ValueError("No output node found")


def move_primals_to_head(graph: Graph):
    # Move primals to the head of the graph
    primals = [n for n in graph.nodes if n.op == "placeholder"]
    non_primals = [n for n in graph.nodes if n.op != "placeholder"]
    all_nodes = primals + non_primals

    new_graph = Graph()
    env = {}
    for node in all_nodes:
        new_node = new_graph.node_copy(node, lambda n: env[n.name])
        env[node.name] = new_node
    new_graph.lint()

    return new_graph

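
# Example (illustrative only): a typical call site would normalize the
# placeholder order of a traced module before further rewrites. `MyModule`
# is a stand-in name, not part of this file.
#
#   gm = torch.fx.symbolic_trace(MyModule())
#   gm.graph = move_primals_to_head(gm.graph)
#   gm.recompile()

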
def add_args_process(graph: Graph,
                     node: Node,
                     fn: Callable[..., Any],
                     extra_args: List[int] = [],
                     name=None,
                     meta={}) -> List[Node]:
    # Apply fn to all args of node
    new_nodes = []
    with graph.inserting_before(node):
        target_args = [arg for arg in node.args if isinstance(arg, Node)]

        for arg in target_args:
            new_node = graph.create_node('call_function', fn, (arg, ) + tuple(extra_args), name=name)
            for k, v in meta.items():
                new_node.meta[k] = v
            node.replace_input_with(arg, new_node)
            new_nodes.append(new_node)

    return new_nodes

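
# Example (illustrative only): insert a per-argument downcast in front of a
# node. `torch.ops.prims.convert_element_type` is one possible cast op; the
# op actually used by the compiler may differ.
#
#   add_args_process(graph, node,
#                    torch.ops.prims.convert_element_type.default,
#                    extra_args=[torch.bfloat16])

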
def add_postprocess(graph: Graph,
                    node: Node,
                    fn: Callable[..., Any],
                    extra_args: List[Any] = [],
                    extra_kwargs: Dict[str, Any] = {},
                    name=None,
                    meta={}) -> Node:
    # https://github.com/pytorch/examples/blob/main/fx/wrap_output_dynamically.py
    with graph.inserting_after(node):
        args = (node, )
        for a in extra_args:  # To add ds_id
            args += (a, )

        node_users = node.users.keys()
        new_node = graph.create_node('call_function', fn, args, extra_kwargs, name=name)
        users = {}
        for u in node_users:
            if u != new_node:
                users[u] = (node, new_node)
        for u, (old_in, new_in) in users.items():
            u.replace_input_with(old_in, new_in)

    for k, v in meta.items():
        new_node.meta[k] = v

    return new_node

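
# Example (illustrative only): attach an op after an existing node and
# reroute that node's users to the new one, e.g. recording a downcast right
# after a gathered param. `allgather_node` and `ds_id` are stand-ins from
# the caller; the cast op shown is one possibility.
#
#   cast_node = add_postprocess(graph, allgather_node,
#                               torch.ops.prims.convert_element_type.default,
#                               extra_args=[torch.bfloat16],
#                               name="downcast_param",
#                               meta=_make_node_meta(allgather_node, ds_id, comm=False))

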
def _make_node_meta(node: Node, ds_id: int, comm: bool):
    meta = {"param_name": node.name, "ds_id": ds_id, "comm": comm}
    if "tensor_meta" in node.meta:
        meta["tensor_meta"] = node.meta["tensor_meta"]
    return meta


def add_free_activations(graph_id: int, graph: Graph, activation_node_names: List[str]):
    node_to_last_use, _ = get_last_uses(graph)
    activation_nodes_set = set([n for n in graph.nodes if n.op == "placeholder" and n.name in activation_node_names])

    # Map each offloaded activation to its wait_reload node so that frees are
    # inserted after the reloaded tensor rather than the original placeholder.
    offload_id_to_node = {}
    node_to_wait_reload = {}
    for node in graph.nodes:
        if node.target == torch.ops.dc.reload_tensor.default:
            offload_act = node.args[0]
            # node_to_offload_id[offload_act] = node.args[2]
            offload_id_to_node[node.args[2]] = offload_act
        elif node.target == torch.ops.dc.wait_reload.default:
            offload_id = node.args[2]
            node_to_wait_reload[offload_id_to_node[offload_id]] = node

    activation_nodes_set = set(node_to_wait_reload[n] if n in node_to_wait_reload else n for n in activation_nodes_set)

    last_user_to_uses = defaultdict(list)
    for node, last_user in node_to_last_use.items():
        last_user_to_uses[last_user].append(node)

    def _should_free(node: Node) -> bool:
        if not hasattr(node, "meta"):
            return False
        if "tensor_meta" not in node.meta:
            return False
        return True

    def free_tensors(tensors: List[torch.Tensor]):
        for a in tensors:
            if a.numel() > 10_000_000:
                a.data = torch.empty([0], device=a.device, dtype=a.dtype)

    # Insert a dc.free_tensors call right after the last use of each
    # activation so its memory can be reclaimed as early as possible.
    for last_user, used_nodes in last_user_to_uses.items():
        activation_args = [an for an in used_nodes if an in activation_nodes_set and _should_free(an)]

        if len(activation_args) == 0:
            continue

        node_name = f"free_activations_{[n.name for n in used_nodes]}"
        with graph.inserting_after(last_user):
            args = (activation_args, )
            graph.create_node('call_function', torch.ops.dc.free_tensors.default, args, {}, name=node_name)

            # Python version for debugging
            # graph.create_node('call_function', free_tensors, args, {}, name=node_name)
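
# Example (illustrative only): the scheduling pass would call this on the
# backward graph with the names of saved activations passed into it; the
# names below are stand-ins.
#
#   add_free_activations(graph_id, gm.graph,
#                        activation_node_names=["primals_2", "primals_5"])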