With autocast enabled, a majority of weights are downcast before being used in calculations. Today zero3_compile gathers the FP32 weights before they are downcast. That is sub-optimal because FP32 weights consume more bandwidth to allgather and take more time to downcast. To reduce communication and downcast time, this PR fuses allgather and downcast in the dc ops: the target dtype is now passed to `allgather_param()` and `prefetch_params_fused()`, which downcast the (partial) weights before launching the allgathers. This corresponds to issue 1 of #7577.

Tested with https://gist.github.com/eternalNight/3c2cf8c703f1e9e7742d3b7f9e1edae3 (run with `deepspeed --num_gpus=N this_file.py -c -p -m 23` to collect torch and memory profiles, and with DINOV2_DEPTH = SIGLIP_DEPTH = 3, LLAMA2_DEPTH = 4 for faster compilation) on 5090 GPUs, which have limited inter-GPU bandwidth: time per step decreases from 438ms to 337ms, and peak GPU memory usage drops from 9.5GB to 8.5GB.

Profiles of a single step before this PR:

![torch profile before](https://github.com/user-attachments/assets/d9fe5296-7731-4542-924b-421ff7415054)
![memory profile before](https://github.com/user-attachments/assets/aa192802-8633-4e36-b2c4-f28b1b432663)

After this PR:

![torch profile after](https://github.com/user-attachments/assets/18a0e09c-155b-4783-adb5-b4d36c5c3691)
![memory profile after](https://github.com/user-attachments/assets/16a2ca74-8a89-4db9-9b68-81844295c61b)

This PR also reduces peak memory usage because `fast_free_schedule()` today always places param allgathers and downcasts at the beginning of the graph. While the original FP32 params can be freed early, all of the FP16/BF16-cast params are kept in GPU memory from the beginning of the backward graph, leading to a higher memory peak.

P.S. Probably due to organization branch rule settings, I can't find anywhere to allow reviewers to modify the branch, so I'll update the branch per reviewers' comments and rebase if needed.

Signed-off-by: Junjie Mao <junjie.mao@linux.alibaba.com>
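For a back-of-the-envelope view of the saving (my own illustration, not part of the PR): the fused op moves `target_dtype.itemsize / torch.float32.itemsize` of the original bytes over the wire, and the downcast kernel now runs on each rank's 1/world_size shard instead of on the full gathered tensor.

```python
import torch

# Bytes on the wire for one param allgather, unfused vs. fused.
# Pure arithmetic; the real op is torch.ops.dc.allgather_param.
def allgather_bytes(numel: int, dtype: torch.dtype) -> int:
    return numel * dtype.itemsize

numel = 7_000_000_000  # e.g. all params of a 7B model gathered over one step
print(f"FP32 gather, cast after: {allgather_bytes(numel, torch.float32) / 2**30:.1f} GiB")
print(f"BF16 fused gather:       {allgather_bytes(numel, torch.bfloat16) / 2**30:.1f} GiB")
```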
235 lines · 9.5 KiB · Python
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import gc
from typing import List, Dict, Tuple
import _operator

import torch
from torch.fx import Graph, Node, GraphModule

from ..util import get_input_nodes, get_param_nodes, get_index_by_graph_id, get_deepcompile_handle, get_real_uses, is_cast_op
from ..fx import add_postprocess, _make_node_meta, get_output_node, move_primals_to_head
from ..profilers.graph_profile import ProfilingInterpreter
from ..list_schedule import fast_free_schedule

import deepspeed.comm as dist
from deepspeed.accelerator import get_accelerator

NAME = "zero3_compile"

def add_allgather(graph_id: int, graph: Graph, node: Node, ds_id: int, dtype: torch.dtype):
    new_ag_node = add_postprocess(graph,
                                  node,
                                  torch.ops.dc.allgather_param.default,
                                  extra_args=[graph_id, ds_id],
                                  extra_kwargs={"dtype": dtype},
                                  name=f"allgather_ds_param_{node.target}_{ds_id}",
                                  meta=_make_node_meta(node, ds_id, True))
    new_ag_node.meta["val"] = node.meta["val"].to(dtype)

    # Set the previous node back to output
    # We don't want to change the output node to allgather
    output_node = get_output_node(graph)
    output_node.replace_input_with(new_ag_node, node)

    # Add wait as well
    new_wait_node = add_postprocess(graph,
                                    new_ag_node,
                                    torch.ops.dc.wait_allgather.default,
                                    extra_args=[graph_id, ds_id],
                                    name=f"wait_allgather_ds_param__{node.target}_{ds_id}",
                                    meta=_make_node_meta(node, ds_id, False))
    new_wait_node.meta["val"] = new_ag_node.meta["val"]

    return new_ag_node

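# Resulting rewrite around a ZeRO-3 param (illustrative sketch with
# hypothetical node names, not an actual FX dump):
#
#   before:  out = mm(primal_1, x)
#   after:   ag  = dc.allgather_param(primal_1, graph_id, ds_id, dtype=dtype)
#            w   = dc.wait_allgather(ag, graph_id, ds_id)
#            out = mm(w, x)
#
# The graph output, if it returns the param itself, keeps reading the original
# placeholder; that is what replace_input_with() above restores.
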
def add_release(graph_id: int, graph: Graph, node: Node, release_node: Node, ds_id: int, n_users: int):
    new_node = add_postprocess(graph,
                               node,
                               torch.ops.dc.release_param.default,
                               extra_args=[graph_id, ds_id, n_users],
                               name=f"release_ds_param_{release_node.target}_{node.name}_{ds_id}",
                               meta=_make_node_meta(node, ds_id, False))
    new_node.meta["val"] = None

def add_reduce(graph_id: int, graph: Graph, grad_node: Node, param_name: str, ds_id: int):
    new_node = add_postprocess(graph,
                               grad_node,
                               torch.ops.dc.reduce_grad.default,
                               extra_args=[graph_id, ds_id],
                               name=f"reduce_ds_param_{param_name}",
                               meta=_make_node_meta(grad_node, ds_id, True))
    new_node.meta["val"] = None

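# add_gather_and_release() below rewrites every ZeRO-3 param node into a
# gather/use/release sequence. Illustrative sketch with hypothetical node
# names (gid = graph_id):
#
#   w = dc.wait_allgather(dc.allgather_param(p, gid, ds_id, dtype), gid, ds_id)
#   y = mm(w, x)                              # a real use of the param
#   dc.release_param(y, gid, ds_id, n_users)  # inserted after each real use
#
# n_users = len(users) is passed along so the op knows how many releases to
# expect for this param.
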
def add_gather_and_release(graph_id: int, graph: Graph, param_manager, param_nodes: List[Node]) -> Graph:

    node_to_uses = get_real_uses(graph)
    for pn in param_nodes:
        if len(pn.users) == 0:
            continue

        # If the only use of the parameter is a type-cast to a smaller type, fuse it with all-gather.
        fuse_typecast = False
        target_dtype = param_manager.params[pn.name].dtype
        if len([user for user in pn.users if user.op != "output"]) == 1:
            typecast_node = next(iter(pn.users))

            is_cast, casted_dtype = is_cast_op(typecast_node)
            if is_cast and casted_dtype.itemsize < target_dtype.itemsize:
                fuse_typecast = True
                target_dtype = casted_dtype

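        # Illustrative fused pattern (hypothetical node names): a graph like
        #     half = primal_1.to(torch.bfloat16); y = mm(half, x)
        # becomes, once the cast node is erased below,
        #     ag = dc.allgather_param(primal_1, graph_id, ds_id, dtype=torch.bfloat16)
        #     w  = dc.wait_allgather(ag, graph_id, ds_id)
        #     y  = mm(w, x)
        # so no full-size FP32 copy of the weight is ever materialized.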
        add_allgather(graph_id, graph, pn, param_manager.ds_ids[pn.name], target_dtype)
        if fuse_typecast:
            users = node_to_uses[typecast_node]
            wait_node = typecast_node.args[0]
            for user in list(typecast_node.users.keys()):
                if user.op == "output":
                    wait_node.meta["original_output_name"] = typecast_node.name
                user.replace_input_with(typecast_node, wait_node)
            graph.erase_node(typecast_node)
        else:
            users = node_to_uses[pn]

        ds_id = param_manager.ds_ids[pn.name]
        for user in users:
            # release_param() only accepts tensors as its first argument. If
            # `user` returns a tuple, we should release the param after one of
            # the operator.getitem calls on that tuple.
            #
            # Since no torch op takes a tuple as an input, we simply walk
            # through the users of `user` and check whether there is any call
            # to operator.getitem.
            for secondary_user in user.users:
                if secondary_user.op == "call_function" and secondary_user.target == _operator.getitem:
                    add_release(graph_id, graph, secondary_user, pn, ds_id, len(users))
                    break
            else:
                add_release(graph_id, graph, user, pn, ds_id, len(users))

    return move_primals_to_head(graph)

def add_gather_and_reduce(graph_id: int, graph: Graph, param_manager, param_nodes_bw: List[Node],
                          param_name_to_grad: Dict[str, Node]) -> Graph:

    add_gather_and_release(graph_id, graph, param_manager, param_nodes_bw)

    for param_name in param_manager.param_names:
        if param_name_to_grad[param_name] is None:
            continue
        add_reduce(graph_id, graph, param_name_to_grad[param_name], param_name, param_manager.ds_ids[param_name])

    return move_primals_to_head(graph)

def add_z3_gather_release_fw(gm: GraphModule,
                             graph_id: int,
                             graph_order: List[Tuple[int, bool]],
                             profiling_results,
                             create_inputs_fn,
                             param_manager,
                             debug_log=False) -> GraphModule:

    nz3 = get_deepcompile_handle()

    real_inputs = create_inputs_fn()
    param_indices = profiling_results[graph_id].param_indices

    gm.graph = add_gather_and_release(graph_id, gm.graph, param_manager[graph_id],
                                      get_param_nodes(gm.graph, param_indices))

    nz3.register_graph_z3(graph_id, [v[1] for v in param_indices])  # Need this before profiling

    profiler = ProfilingInterpreter(gm, debug_log=debug_log)
    profiler.run(*real_inputs)
    del profiler
    gc.collect()
    get_accelerator().empty_cache()

    rank = dist.get_rank()
    graph_index = get_index_by_graph_id(graph_order, graph_id)
    if rank == 0 and debug_log:
        print(f"Fwd before scheduling graph {graph_index} graph_id={graph_id} {gm.graph}")

    for n in gm.graph.nodes:
        is_ds_param = n.name in param_manager[graph_id].ds_ids
        if "val" in n.meta and is_ds_param:
            # Used for Inductor's validation
            n.meta["val"] = torch.empty([0], dtype=n.meta['val'].dtype, device=n.meta['val'].device)

    gm.graph = fast_free_schedule(
        gm.graph,
        get_accelerator().available_memory(),
        0,  # unused
        debug_log=debug_log)

    if rank == 0 and debug_log:
        print(f"Fwd after scheduling graph {graph_index} graph_id={graph_id} {gm.graph}")

    return gm

def add_z3_gather_release_bw(gm: GraphModule,
                             graph_id: int,
                             graph_order: List[Tuple[int, bool]],
                             profiling_results,
                             create_inputs_fn,
                             param_manager,
                             debug_log=False) -> GraphModule:

    param_nodes_bw, param_name_to_grad = param_manager[graph_id].get_bwd_mapping(gm.graph)
    gm.graph = add_gather_and_reduce(graph_id, gm.graph, param_manager[graph_id], param_nodes_bw, param_name_to_grad)

    input_nodes = get_input_nodes(gm.graph)
    real_inputs = create_inputs_fn()
    assert len(input_nodes) == len(real_inputs), f"Expected {len(real_inputs)} inputs, got {len(input_nodes)}"

    real_outputs = ProfilingInterpreter(gm, debug_log=debug_log).run(*real_inputs)

    del real_outputs
    gc.collect()
    get_accelerator().empty_cache()

    rank = dist.get_rank()
    graph_index = get_index_by_graph_id(graph_order, graph_id)
    if rank == 0 and debug_log:
        print(f"Bwd before scheduling graph {graph_index} graph_id={graph_id} {gm.graph}")

    gm.graph = fast_free_schedule(
        gm.graph,
        get_accelerator().available_memory(),
        0,  # unused
        debug_log=debug_log)

    # Signal the end of this graph's backward pass to the deepcompile runtime.
    with gm.graph.inserting_before(get_output_node(gm.graph)):
        gm.graph.create_node("call_function", torch.ops.dc.end_backward.default, (graph_id, ))

    return gm

def add_z3_gather_release(gm: GraphModule, graph_id: int, graph_order: List[Tuple[int, bool]], profiling_results,
                          create_inputs_fn, mem_budget: float, param_manager, bwd: bool) -> GraphModule:
    # mem_budget is accepted for signature compatibility and is not used here.
    if bwd:
        return add_z3_gather_release_bw(gm,
                                        graph_id,
                                        graph_order,
                                        profiling_results,
                                        create_inputs_fn,
                                        param_manager,
                                        debug_log=False)
    return add_z3_gather_release_fw(gm,
                                    graph_id,
                                    graph_order,
                                    profiling_results,
                                    create_inputs_fn,
                                    param_manager,
                                    debug_log=False)
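# Hypothetical invocation of this pass (illustrative only; the actual driver
# is DeepSpeed's deepcompile pass pipeline):
#
#   gm = add_z3_gather_release(gm, graph_id, graph_order, profiling_results,
#                              create_inputs_fn, mem_budget=0.0,
#                              param_manager=param_managers, bwd=is_backward)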