DeepCompile: Fuse allgather and downcast (#7588)

With autocast enabled, a majority of weights are downcasted before being
used in calculations. Today zero3_compile gathers the FP32 weights
before they are downcasted. That is sub-optimal because FP32 weights
consumes more bandwidth to allgather and takes more time to downcast.

To reduce communication and downcast time, fuse allgather and downcast
in the dc ops. The target type is now passed to allgather_param() and
prefetch_params_fused() which will downcast the (partial) weights before
launching allgathers.

This corresponds to issue 1 of #7577.

Tested with
https://gist.github.com/eternalNight/3c2cf8c703f1e9e7742d3b7f9e1edae3
(run with `deepspeed --num_gpus=N this_file.py -c -p -m 23` to collect
torch and memory profiles, and with DINOV2_DEPTH = SIGLIP_DEPTH = 3,
LLAMA2_DEPTH = 4 for faster compileation) on 5090 (which has limited
inter-GPU bandwidth), time per step decreases from 438ms to 337ms and
peak GPU memory usage from 9.5GB to 8.5GB.

Profiles of a single step before this PR:

<img width="1235" height="1029" alt="image"
src="https://github.com/user-attachments/assets/d9fe5296-7731-4542-924b-421ff7415054"
/>

<img width="1466" height="616" alt="image"
src="https://github.com/user-attachments/assets/aa192802-8633-4e36-b2c4-f28b1b432663"
/>

After this PR:

<img width="1218" height="1006" alt="image"
src="https://github.com/user-attachments/assets/18a0e09c-155b-4783-adb5-b4d36c5c3691"
/>

<img width="1537" height="559" alt="image"
src="https://github.com/user-attachments/assets/16a2ca74-8a89-4db9-9b68-81844295c61b"
/>

This PR also reduces peak memory usage because the
`fast_free_schedule()` today always arranges param allgathers and
downcasts at the beginning of the graph. While the original FP32 params
can be freed early, all FP16/BF16-casted params are kept in GPU memory
at the beginning of the backward graph, leading to a higher peak in
memory usage.

P.S. Probably due to organization branch rule settings, I don't find
anywhere to allow reviewers to modify the branch. So I'll update the
branch per reviewers' comments and rebase if needed.

Signed-off-by: Junjie Mao <junjie.mao@linux.alibaba.com>
This commit is contained in:
Junjie Mao
2025-09-29 11:15:33 +08:00
committed by GitHub
parent 6fcccfa2c9
commit 4efd7eca73
9 changed files with 156 additions and 39 deletions

View File

@ -10,8 +10,10 @@
TORCH_LIBRARY(dc, m)
{
m.def("allgather_param(Tensor a, int graph_id, int id) -> Tensor");
m.def("prefetch_params_fused(int graph_id, Tensor[] params, int[] ids) -> ()");
m.def("allgather_param(Tensor a, int graph_id, int id, ScalarType? dtype = None) -> Tensor");
m.def(
"prefetch_params_fused(int graph_id, Tensor[] params, int[] ids,"
" ScalarType[]? dtypes = None) -> ()");
m.def("wait_allgather(Tensor(a) a, int graph_id, int id) -> Tensor(a)");
m.def("release_param(Tensor(a) a, int graph_id, int id, int n_users) -> Tensor(a)");
m.def("reduce_grad(Tensor a, int graph_id, int id) -> Tensor");

View File

@ -68,7 +68,12 @@ public:
c10::intrusive_ptr<c10d::symmetric_memory::SymmetricMemory> symm_mem)
{
const DSParam& param = param_registry_->getParam(ds_id);
const at::Tensor& ds_tensor = param.getDSTensor();
at::Tensor ds_tensor = param.getDSTensor();
if (ds_tensor.scalar_type() != output_buf.scalar_type()) {
at::cuda::CUDAStreamGuard guard(ag_stream_);
ds_tensor = ds_tensor.to(output_buf.scalar_type(), true, true);
}
if (symm_mem == nullptr) {
// Fast path: assume uniform shard sizes (ZeRO-3 partitions are padded to uniform size)
@ -110,6 +115,7 @@ public:
}
at::Tensor allgatherParam(long ds_id,
std::optional<at::ScalarType> dtype,
c10::intrusive_ptr<c10d::symmetric_memory::SymmetricMemory> symm_mem)
{
const DSParam& param = param_registry_->getParam(ds_id);
@ -118,11 +124,16 @@ public:
const int64_t true_numel = static_cast<int64_t>(productDim(param.getShape()));
const int64_t padded_per_rank = (true_numel + world_size - 1) / world_size;
const int64_t padded_numel = static_cast<int64_t>(world_size) * padded_per_rank;
at::ScalarType target_dtype = dtype ? dtype.value() : ds_tensor.scalar_type();
if (param_registry_->isValid(ds_id)) {
// Return a view sliced to the true size with the original shape
//
// Persistent params are gathered in their original dtype which may
// be different from the requested.
auto base = param_registry_->getGatheredParam(ds_id);
return base.flatten()
.to(target_dtype)
.index({torch::indexing::Slice(0, true_numel)})
.view(param.getShape());
}
@ -134,7 +145,7 @@ public:
}
if (!output_buf.defined()) {
at::cuda::CUDAStreamGuard guard(ag_stream_);
output_buf = torch::empty({padded_numel}, ds_tensor.options());
output_buf = torch::empty({padded_numel}, ds_tensor.options().dtype(target_dtype));
}
assert(hasKey(ag_comp_done_events_, ds_id));
@ -150,16 +161,20 @@ public:
.view(param.getShape());
}
void prefetchParamsFused(std::vector<int64_t> ds_ids,
void prefetchParamsFused(const std::vector<long>& ds_ids,
const std::optional<std::vector<at::ScalarType>> dtypes,
c10::intrusive_ptr<c10d::symmetric_memory::SymmetricMemory> symm_mem)
{
std::vector<int64_t> invalid_ds_ids;
for (const auto& ds_id : ds_ids) {
if (!param_registry_->isValid(ds_id)) { invalid_ds_ids.push_back(ds_id); }
std::vector<std::tuple<long, std::optional<at::ScalarType>>> invalid_params;
for (int i = 0; i < ds_ids.size(); i++) {
if (!param_registry_->isValid(ds_ids[i])) {
auto dtype = dtypes ? dtypes.value()[i] : std::optional<at::ScalarType>();
invalid_params.push_back(std::make_tuple(ds_ids[i], dtype));
}
}
std::unordered_map<long, at::Tensor> output_bufs;
for (long ds_id : invalid_ds_ids) {
for (const auto& [ds_id, dtype] : invalid_params) {
const DSParam& param = param_registry_->getParam(ds_id);
const at::Tensor& ds_tensor = param.getDSTensor();
const int world_size = process_group_->getSize();
@ -173,22 +188,26 @@ public:
continue;
}
}
output_bufs[ds_id] = torch::empty({padded_numel}, ds_tensor.options());
auto target_dtype = dtype ? dtype.value() : ds_tensor.scalar_type();
output_bufs[ds_id] =
torch::empty({padded_numel}, ds_tensor.options().dtype(target_dtype));
}
for (long ds_id : invalid_ds_ids) {
for (const auto& [ds_id, _] : invalid_params) {
ag_comp_done_events_[ds_id]->record();
ag_comp_done_events_[ds_id]->block(ag_stream_);
}
ncclGroupStart();
for (long ds_id : invalid_ds_ids) {
for (const auto& [ds_id, _] : invalid_params) {
assert(hasKey(output_bufs, ds_id));
launchAllGather(output_bufs.at(ds_id), ds_id, symm_mem);
}
ncclGroupEnd();
for (long ds_id : invalid_ds_ids) { ag_comm_done_events_[ds_id]->record(ag_stream_); }
for (const auto& [ds_id, _] : invalid_params) {
ag_comm_done_events_[ds_id]->record(ag_stream_);
}
}
void releaseParam(long ds_id, long n_users)
@ -458,12 +477,15 @@ void register_z3_param(long ds_id,
}
}
at::Tensor allgather_param(at::Tensor param_tensor, long graph_id, long ds_id)
at::Tensor allgather_param(at::Tensor param_tensor,
long graph_id,
long ds_id,
std::optional<at::ScalarType> dtype)
{
auto executor = getExecutor<Z3CustomOpExecutor>(graph_id, executors);
if (sync_before_allgather) { c10::cuda::device_synchronize(); }
auto ret = executor->allgatherParam(ds_id, symm_mem);
auto ret = executor->allgatherParam(ds_id, dtype, symm_mem);
if (sync_after_allgather) { c10::cuda::device_synchronize(); }
return ret;
}
@ -477,22 +499,25 @@ void set_persistent(long ds_id)
for (auto& it : executors) {
if (it.second->hasParam(ds_id)) {
auto executor = getExecutor<Z3CustomOpExecutor>(it.first, executors);
executor->allgatherParam(ds_id, symm_mem);
auto dtype = param_registry->getParam(ds_id).getDtype();
executor->allgatherParam(ds_id, dtype, symm_mem);
}
}
}
void prefetch_params_fused(long graph_id,
const std::vector<at::Tensor> params,
const std::vector<long>& ds_ids)
const std::vector<at::Tensor>& params,
const std::vector<long>& ds_ids,
const std::optional<std::vector<at::ScalarType>>& dtypes)
{
auto executor = getExecutor<Z3CustomOpExecutor>(graph_id, executors);
executor->prefetchParamsFused(ds_ids, symm_mem);
executor->prefetchParamsFused(ds_ids, dtypes, symm_mem);
}
void prefetch_params_fused_meta(long graph_id,
const std::vector<at::Tensor> params,
const std::vector<long>& ds_ids)
const std::vector<at::Tensor>& params,
const std::vector<long>& ds_ids,
const std::optional<std::vector<at::ScalarType>>& dtypes)
{
}
@ -518,11 +543,14 @@ void clear_all_gathered_params()
}
}
at::Tensor allgather_param_meta(at::Tensor param_tensor, long graph_id, long ds_id)
at::Tensor allgather_param_meta(at::Tensor param_tensor,
long graph_id,
long ds_id,
std::optional<at::ScalarType> dtype)
{
const DSParam& param = param_registry->getParam(ds_id);
auto options = param.getDSTensor().options().device(c10::kMeta);
at::Tensor output_buf = torch::empty(param.getShape(), options);
at::Tensor output_buf = torch::empty(param.getShape(), options.dtype(dtype));
return output_buf;
}

View File

@ -21,18 +21,26 @@ void register_z3_param(long ds_id,
at::Tensor ds_tensor,
at::Tensor grad_buffer,
bool persistent);
at::Tensor allgather_param(at::Tensor param_tensor, long graph_id, long ds_id);
at::Tensor allgather_param(at::Tensor param_tensor,
long graph_id,
long ds_id,
std::optional<at::ScalarType> dtype);
void set_persistent(long ds_id);
void prefetch_params_fused(long graph_id,
const std::vector<at::Tensor> params,
const std::vector<long>& ds_ids);
const std::vector<at::Tensor>& params,
const std::vector<long>& ds_ids,
const std::optional<std::vector<at::ScalarType>>& dtypes);
void prefetch_params_fused_meta(long graph_id,
const std::vector<at::Tensor> params,
const std::vector<long>& ds_ids);
const std::vector<at::Tensor>& params,
const std::vector<long>& ds_ids,
const std::optional<std::vector<at::ScalarType>>& dtypes);
// for profiling
void invalidate_gathered_param(long ds_id);
void clear_all_gathered_params();
at::Tensor allgather_param_meta(at::Tensor param_tensor, long graph_id, long ds_id);
at::Tensor allgather_param_meta(at::Tensor param_tensor,
long graph_id,
long ds_id,
std::optional<at::ScalarType> dtype);
at::Tensor release_param(at::Tensor dummy, long graph_id, long ds_id, long n_users);
at::Tensor release_param_meta(at::Tensor dummy, long graph_id, long ds_id, long n_users);
at::Tensor wait_allgather(at::Tensor v, long graph_id, const long ds_id);

View File

@ -18,6 +18,7 @@
#include <c10/cuda/CUDAStream.h>
#include <torch/csrc/cuda/nccl.h>
#include <torch/csrc/distributed/c10d/NCCLUtils.hpp>
#include <torch/csrc/distributed/c10d/ParamCommsUtils.hpp>
#include <torch/csrc/distributed/c10d/ProcessGroup.hpp>
#if __has_include(<torch/csrc/distributed/c10d/symm_mem/SymmetricMemory.hpp>)
@ -261,6 +262,7 @@ public:
: id_(id),
shape_(std::move(ds_shape)),
ds_tensor_(ds_tensor),
ds_dtype_(ds_tensor.scalar_type()),
grad_buffer_(grad_buffer),
partitioned_(partitioned),
offset_(offset),
@ -272,6 +274,7 @@ public:
long getId() const { return id_; }
std::vector<int64_t> getShape() const { return shape_; }
at::ScalarType getDtype() const { return ds_dtype_; }
at::Tensor getDSTensor() const
{
// If the reload event exists and is complete, return the reloaded tensor (if defined)
@ -343,6 +346,7 @@ public:
private:
long id_;
std::vector<int64_t> shape_;
at::ScalarType ds_dtype_;
at::Tensor ds_tensor_;
at::Tensor ds_reload_tensor_;
at::Tensor grad_buffer_;

View File

@ -3,7 +3,7 @@
# DeepSpeed Team
from typing import Callable, Any, List
from typing import Callable, Any, List, Dict
from collections import defaultdict
import torch
@ -60,7 +60,8 @@ def add_args_process(graph: Graph,
def add_postprocess(graph: Graph,
node: Node,
fn: Callable[..., Any],
extra_args: List[int] = [],
extra_args: List[Any] = [],
extra_kwargs: Dict[str, Any] = {},
name=None,
meta={}) -> Node:
# https://github.com/pytorch/examples/blob/main/fx/wrap_output_dynamically.py
@ -70,7 +71,7 @@ def add_postprocess(graph: Graph,
args += (a, )
node_users = node.users.keys()
new_node = graph.create_node('call_function', fn, args, {}, name=name)
new_node = graph.create_node('call_function', fn, args, extra_kwargs, name=name)
users = {}
for u in node_users:
if u != new_node:

View File

@ -10,7 +10,7 @@ import _operator
import torch
from torch.fx import Graph, Node, GraphModule
from ..util import get_input_nodes, get_param_nodes, get_index_by_graph_id, get_deepcompile_handle, get_real_uses
from ..util import get_input_nodes, get_param_nodes, get_index_by_graph_id, get_deepcompile_handle, get_real_uses, is_cast_op
from ..fx import add_postprocess, _make_node_meta, get_output_node, move_primals_to_head
from ..profilers.graph_profile import ProfilingInterpreter
from ..list_schedule import fast_free_schedule
@ -21,14 +21,15 @@ from deepspeed.accelerator import get_accelerator
NAME = "zero3_compile"
def add_allgather(graph_id: int, graph: Graph, node: Node, ds_id: int):
def add_allgather(graph_id: int, graph: Graph, node: Node, ds_id: int, dtype: torch.dtype):
new_ag_node = add_postprocess(graph,
node,
torch.ops.dc.allgather_param.default,
extra_args=[graph_id, ds_id],
extra_kwargs={"dtype": dtype},
name=f"allgather_ds_param_{node.target}_{ds_id}",
meta=_make_node_meta(node, ds_id, True))
new_ag_node.meta["val"] = node.meta["val"]
new_ag_node.meta["val"] = node.meta["val"].to(dtype)
# Set the previous node back to output
# We don't want to change the output node to allgather
@ -42,7 +43,7 @@ def add_allgather(graph_id: int, graph: Graph, node: Node, ds_id: int):
extra_args=[graph_id, ds_id],
name=f"wait_allgather_ds_param__{node.target}_{ds_id}",
meta=_make_node_meta(node, ds_id, False))
new_wait_node.meta["val"] = node.meta["val"]
new_wait_node.meta["val"] = new_ag_node.meta["val"]
return new_ag_node
@ -74,9 +75,30 @@ def add_gather_and_release(graph_id: int, graph: Graph, param_manager, param_nod
if len(pn.users) == 0:
continue
add_allgather(graph_id, graph, pn, param_manager.ds_ids[pn.name])
# If the only use of the parameter is a type-cast to a smaller type, fuse it with all-gather.
fuse_typecast = False
target_dtype = param_manager.params[pn.name].dtype
if len([user for user in pn.users if user.op != "output"]) == 1:
typecast_node = next(iter(pn.users))
is_cast, casted_dtype = is_cast_op(typecast_node)
if is_cast and casted_dtype.itemsize < target_dtype.itemsize:
fuse_typecast = True
target_dtype = casted_dtype
add_allgather(graph_id, graph, pn, param_manager.ds_ids[pn.name], target_dtype)
if fuse_typecast:
users = node_to_uses[typecast_node]
wait_node = typecast_node.args[0]
for user in list(typecast_node.users.keys()):
if user.op == "output":
wait_node.meta["original_output_name"] = typecast_node.name
user.replace_input_with(typecast_node, wait_node)
graph.erase_node(typecast_node)
else:
users = node_to_uses[pn]
ds_id = param_manager.ds_ids[pn.name]
users = node_to_uses[pn]
for user in users:
# release_param() only accepts tensors as its first argument. If
# `user` is a tuple, we should release the param after any of

View File

@ -130,9 +130,15 @@ class ProfilingInterpreter(Interpreter):
assert isinstance(args, tuple)
assert isinstance(kwargs, dict)
partitioned_params = {}
def rebuild_param_if_necessary(v):
if hasattr(v, "ds_id"):
v.all_gather(param_list=[v])
if hasattr(v, "ds_target_dtype"):
casted = v.to(v.ds_target_dtype)
partitioned_params[id(casted)] = v
return casted
return v
args = map_aggregate(args, lambda x: rebuild_param_if_necessary(x))
@ -191,6 +197,8 @@ class ProfilingInterpreter(Interpreter):
tensor_size = _node_size(out)
def partition_param_if_necessary(v):
if id(v) in partitioned_params:
v = partitioned_params[id(v)]
if hasattr(v, "ds_id") and not v.ds_persist:
v.partition(param_list=[v], has_been_updated=False)
return v
@ -227,6 +235,8 @@ class ProfilingInterpreter(Interpreter):
assert hasattr(out, "ds_id")
if not out.ds_persist:
self.nz3.invalidate_gathered_param(args[2])
if "dtype" in n.kwargs:
setattr(out, "ds_target_dtype", n.kwargs["dtype"])
self.allgather_mem[out.ds_id] = n.meta["alloc_mem"]
return out

View File

@ -5,7 +5,7 @@
import functools
import operator
from typing import List, Tuple, Dict
from typing import List, Tuple, Dict, Optional
from collections import defaultdict
import torch
@ -131,6 +131,15 @@ def is_comm_op(node: Node) -> bool:
return "comm" in node.meta and node.meta["comm"]
def is_cast_op(node: Node) -> Tuple[bool, Optional[torch.dtype]]:
if node.op == "call_function":
if node.target == torch.ops.prims.convert_element_type.default:
return (True, node.args[1])
elif node.target == torch.ops.aten._to_copy.default and set(node.kwargs.keys()) == {"dtype"}:
return (True, node.kwargs["dtype"])
return (False, None)
def exclude_from_act_offload(node: Node) -> bool:
return node.target in sym_size_ops

View File

@ -180,3 +180,36 @@ class TestDeepCompile(DistributedTest):
}
compare_loss(self, config_dict, dtype)
@pytest.mark.parametrize('dtype', ["bfloat16", "float16"])
@pytest.mark.parametrize('zero_stage', [3])
def test_fusing_allgather_and_autocast(self, zero_stage, dtype):
"""Test that allgather and autocast can be correctly fused with DeepCompile"""
if not required_torch_version(min_version=2.6):
pytest.skip("DeepCompile requires PyTorch >= v2.6")
if get_accelerator().device_name() == "cpu":
pytest.skip("CPU does not support this test yet")
config_dict = {
"train_micro_batch_size_per_gpu": 1,
"steps_per_print": 1,
"optimizer": {
"type": "Adam",
"params": {
"lr": 0.00015
}
},
"torch_autocast": {
"enable": True,
"dtype": dtype,
},
"zero_optimization": {
"stage": zero_stage,
},
"compile": {
"deepcompile": True
}
}
compare_loss(self, config_dict, torch.float32)