From 4efd7eca7326588c1fab898adcb7da1fad52607d Mon Sep 17 00:00:00 2001
From: Junjie Mao
Date: Mon, 29 Sep 2025 11:15:33 +0800
Subject: [PATCH] DeepCompile: Fuse allgather and downcast (#7588)

With autocast enabled, a majority of weights are downcast before being
used in calculations. Today zero3_compile gathers the FP32 weights
before they are downcast. That is sub-optimal because the FP32 weights
consume more bandwidth to allgather and take more time to downcast.

To reduce communication and downcast time, fuse allgather and downcast
in the dc ops. The target dtype is now passed to allgather_param() and
prefetch_params_fused(), which downcast the (partial) weights before
launching the allgathers.

This corresponds to issue 1 of #7577.

Tested with https://gist.github.com/eternalNight/3c2cf8c703f1e9e7742d3b7f9e1edae3
(run with `deepspeed --num_gpus=N this_file.py -c -p -m 23` to collect
torch and memory profiles, and with DINOV2_DEPTH = SIGLIP_DEPTH = 3,
LLAMA2_DEPTH = 4 for faster compilation) on a 5090 (which has limited
inter-GPU bandwidth): time per step decreases from 438ms to 337ms and
peak GPU memory usage from 9.5GB to 8.5GB.

Profiles of a single step before this PR: [profiler and memory screenshots]

After this PR: [profiler and memory screenshots]

This PR also reduces peak memory usage because `fast_free_schedule()`
today always arranges param allgathers and downcasts at the beginning
of the graph. While the original FP32 params can be freed early, all of
the FP16/BF16-casted params are kept in GPU memory at the beginning of
the backward graph, leading to a higher peak in memory usage.

P.S. Probably due to organization branch rule settings, I can't find
anywhere to allow reviewers to modify the branch, so I'll update the
branch per reviewers' comments and rebase if needed.

Signed-off-by: Junjie Mao
---
 csrc/compile/init.cpp                        |  6 +-
 csrc/compile/z3.cpp                          | 70 ++++++++++++++------
 csrc/compile/z3.h                            | 20 ++++--
 csrc/includes/deepcompile.h                  |  4 ++
 deepspeed/compile/fx.py                      |  7 +-
 deepspeed/compile/passes/zero3_compile.py    | 34 ++++++++--
 deepspeed/compile/profilers/graph_profile.py | 10 +++
 deepspeed/compile/util.py                    | 11 ++-
 tests/unit/v1/compile/test_compile_zero.py   | 33 +++++++++
 9 files changed, 156 insertions(+), 39 deletions(-)

diff --git a/csrc/compile/init.cpp b/csrc/compile/init.cpp
index 7ae6fcd93..0ec4ce768 100644
--- a/csrc/compile/init.cpp
+++ b/csrc/compile/init.cpp
@@ -10,8 +10,10 @@
 TORCH_LIBRARY(dc, m)
 {
-    m.def("allgather_param(Tensor a, int graph_id, int id) -> Tensor");
-    m.def("prefetch_params_fused(int graph_id, Tensor[] params, int[] ids) -> ()");
+    m.def("allgather_param(Tensor a, int graph_id, int id, ScalarType? dtype = None) -> Tensor");
+    m.def(
+        "prefetch_params_fused(int graph_id, Tensor[] params, int[] ids,"
+        " ScalarType[]? 
dtypes = None) -> ()"); m.def("wait_allgather(Tensor(a) a, int graph_id, int id) -> Tensor(a)"); m.def("release_param(Tensor(a) a, int graph_id, int id, int n_users) -> Tensor(a)"); m.def("reduce_grad(Tensor a, int graph_id, int id) -> Tensor"); diff --git a/csrc/compile/z3.cpp b/csrc/compile/z3.cpp index 4896da753..28ab171dd 100644 --- a/csrc/compile/z3.cpp +++ b/csrc/compile/z3.cpp @@ -68,7 +68,12 @@ public: c10::intrusive_ptr symm_mem) { const DSParam& param = param_registry_->getParam(ds_id); - const at::Tensor& ds_tensor = param.getDSTensor(); + at::Tensor ds_tensor = param.getDSTensor(); + + if (ds_tensor.scalar_type() != output_buf.scalar_type()) { + at::cuda::CUDAStreamGuard guard(ag_stream_); + ds_tensor = ds_tensor.to(output_buf.scalar_type(), true, true); + } if (symm_mem == nullptr) { // Fast path: assume uniform shard sizes (ZeRO-3 partitions are padded to uniform size) @@ -110,6 +115,7 @@ public: } at::Tensor allgatherParam(long ds_id, + std::optional dtype, c10::intrusive_ptr symm_mem) { const DSParam& param = param_registry_->getParam(ds_id); @@ -118,11 +124,16 @@ public: const int64_t true_numel = static_cast(productDim(param.getShape())); const int64_t padded_per_rank = (true_numel + world_size - 1) / world_size; const int64_t padded_numel = static_cast(world_size) * padded_per_rank; + at::ScalarType target_dtype = dtype ? dtype.value() : ds_tensor.scalar_type(); if (param_registry_->isValid(ds_id)) { // Return a view sliced to the true size with the original shape + // + // Persistent params are gathered in their original dtype which may + // be different from the requested. auto base = param_registry_->getGatheredParam(ds_id); return base.flatten() + .to(target_dtype) .index({torch::indexing::Slice(0, true_numel)}) .view(param.getShape()); } @@ -134,7 +145,7 @@ public: } if (!output_buf.defined()) { at::cuda::CUDAStreamGuard guard(ag_stream_); - output_buf = torch::empty({padded_numel}, ds_tensor.options()); + output_buf = torch::empty({padded_numel}, ds_tensor.options().dtype(target_dtype)); } assert(hasKey(ag_comp_done_events_, ds_id)); @@ -150,16 +161,20 @@ public: .view(param.getShape()); } - void prefetchParamsFused(std::vector ds_ids, + void prefetchParamsFused(const std::vector& ds_ids, + const std::optional> dtypes, c10::intrusive_ptr symm_mem) { - std::vector invalid_ds_ids; - for (const auto& ds_id : ds_ids) { - if (!param_registry_->isValid(ds_id)) { invalid_ds_ids.push_back(ds_id); } + std::vector>> invalid_params; + for (int i = 0; i < ds_ids.size(); i++) { + if (!param_registry_->isValid(ds_ids[i])) { + auto dtype = dtypes ? dtypes.value()[i] : std::optional(); + invalid_params.push_back(std::make_tuple(ds_ids[i], dtype)); + } } std::unordered_map output_bufs; - for (long ds_id : invalid_ds_ids) { + for (const auto& [ds_id, dtype] : invalid_params) { const DSParam& param = param_registry_->getParam(ds_id); const at::Tensor& ds_tensor = param.getDSTensor(); const int world_size = process_group_->getSize(); @@ -173,22 +188,26 @@ public: continue; } } - output_bufs[ds_id] = torch::empty({padded_numel}, ds_tensor.options()); + auto target_dtype = dtype ? 
dtype.value() : ds_tensor.scalar_type(); + output_bufs[ds_id] = + torch::empty({padded_numel}, ds_tensor.options().dtype(target_dtype)); } - for (long ds_id : invalid_ds_ids) { + for (const auto& [ds_id, _] : invalid_params) { ag_comp_done_events_[ds_id]->record(); ag_comp_done_events_[ds_id]->block(ag_stream_); } ncclGroupStart(); - for (long ds_id : invalid_ds_ids) { + for (const auto& [ds_id, _] : invalid_params) { assert(hasKey(output_bufs, ds_id)); launchAllGather(output_bufs.at(ds_id), ds_id, symm_mem); } ncclGroupEnd(); - for (long ds_id : invalid_ds_ids) { ag_comm_done_events_[ds_id]->record(ag_stream_); } + for (const auto& [ds_id, _] : invalid_params) { + ag_comm_done_events_[ds_id]->record(ag_stream_); + } } void releaseParam(long ds_id, long n_users) @@ -458,12 +477,15 @@ void register_z3_param(long ds_id, } } -at::Tensor allgather_param(at::Tensor param_tensor, long graph_id, long ds_id) +at::Tensor allgather_param(at::Tensor param_tensor, + long graph_id, + long ds_id, + std::optional dtype) { auto executor = getExecutor(graph_id, executors); if (sync_before_allgather) { c10::cuda::device_synchronize(); } - auto ret = executor->allgatherParam(ds_id, symm_mem); + auto ret = executor->allgatherParam(ds_id, dtype, symm_mem); if (sync_after_allgather) { c10::cuda::device_synchronize(); } return ret; } @@ -477,22 +499,25 @@ void set_persistent(long ds_id) for (auto& it : executors) { if (it.second->hasParam(ds_id)) { auto executor = getExecutor(it.first, executors); - executor->allgatherParam(ds_id, symm_mem); + auto dtype = param_registry->getParam(ds_id).getDtype(); + executor->allgatherParam(ds_id, dtype, symm_mem); } } } void prefetch_params_fused(long graph_id, - const std::vector params, - const std::vector& ds_ids) + const std::vector& params, + const std::vector& ds_ids, + const std::optional>& dtypes) { auto executor = getExecutor(graph_id, executors); - executor->prefetchParamsFused(ds_ids, symm_mem); + executor->prefetchParamsFused(ds_ids, dtypes, symm_mem); } void prefetch_params_fused_meta(long graph_id, - const std::vector params, - const std::vector& ds_ids) + const std::vector& params, + const std::vector& ds_ids, + const std::optional>& dtypes) { } @@ -518,11 +543,14 @@ void clear_all_gathered_params() } } -at::Tensor allgather_param_meta(at::Tensor param_tensor, long graph_id, long ds_id) +at::Tensor allgather_param_meta(at::Tensor param_tensor, + long graph_id, + long ds_id, + std::optional dtype) { const DSParam& param = param_registry->getParam(ds_id); auto options = param.getDSTensor().options().device(c10::kMeta); - at::Tensor output_buf = torch::empty(param.getShape(), options); + at::Tensor output_buf = torch::empty(param.getShape(), options.dtype(dtype)); return output_buf; } diff --git a/csrc/compile/z3.h b/csrc/compile/z3.h index 15e65504b..764db6ee9 100644 --- a/csrc/compile/z3.h +++ b/csrc/compile/z3.h @@ -21,18 +21,26 @@ void register_z3_param(long ds_id, at::Tensor ds_tensor, at::Tensor grad_buffer, bool persistent); -at::Tensor allgather_param(at::Tensor param_tensor, long graph_id, long ds_id); +at::Tensor allgather_param(at::Tensor param_tensor, + long graph_id, + long ds_id, + std::optional dtype); void set_persistent(long ds_id); void prefetch_params_fused(long graph_id, - const std::vector params, - const std::vector& ds_ids); + const std::vector& params, + const std::vector& ds_ids, + const std::optional>& dtypes); void prefetch_params_fused_meta(long graph_id, - const std::vector params, - const std::vector& ds_ids); + const std::vector& 
params, + const std::vector& ds_ids, + const std::optional>& dtypes); // for profiling void invalidate_gathered_param(long ds_id); void clear_all_gathered_params(); -at::Tensor allgather_param_meta(at::Tensor param_tensor, long graph_id, long ds_id); +at::Tensor allgather_param_meta(at::Tensor param_tensor, + long graph_id, + long ds_id, + std::optional dtype); at::Tensor release_param(at::Tensor dummy, long graph_id, long ds_id, long n_users); at::Tensor release_param_meta(at::Tensor dummy, long graph_id, long ds_id, long n_users); at::Tensor wait_allgather(at::Tensor v, long graph_id, const long ds_id); diff --git a/csrc/includes/deepcompile.h b/csrc/includes/deepcompile.h index 5eefe69af..ee3d96597 100644 --- a/csrc/includes/deepcompile.h +++ b/csrc/includes/deepcompile.h @@ -18,6 +18,7 @@ #include #include #include +#include #include #if __has_include() @@ -261,6 +262,7 @@ public: : id_(id), shape_(std::move(ds_shape)), ds_tensor_(ds_tensor), + ds_dtype_(ds_tensor.scalar_type()), grad_buffer_(grad_buffer), partitioned_(partitioned), offset_(offset), @@ -272,6 +274,7 @@ public: long getId() const { return id_; } std::vector getShape() const { return shape_; } + at::ScalarType getDtype() const { return ds_dtype_; } at::Tensor getDSTensor() const { // If the reload event exists and is complete, return the reloaded tensor (if defined) @@ -343,6 +346,7 @@ public: private: long id_; std::vector shape_; + at::ScalarType ds_dtype_; at::Tensor ds_tensor_; at::Tensor ds_reload_tensor_; at::Tensor grad_buffer_; diff --git a/deepspeed/compile/fx.py b/deepspeed/compile/fx.py index 7770b80f3..7b3408b56 100644 --- a/deepspeed/compile/fx.py +++ b/deepspeed/compile/fx.py @@ -3,7 +3,7 @@ # DeepSpeed Team -from typing import Callable, Any, List +from typing import Callable, Any, List, Dict from collections import defaultdict import torch @@ -60,7 +60,8 @@ def add_args_process(graph: Graph, def add_postprocess(graph: Graph, node: Node, fn: Callable[..., Any], - extra_args: List[int] = [], + extra_args: List[Any] = [], + extra_kwargs: Dict[str, Any] = {}, name=None, meta={}) -> Node: # https://github.com/pytorch/examples/blob/main/fx/wrap_output_dynamically.py @@ -70,7 +71,7 @@ def add_postprocess(graph: Graph, args += (a, ) node_users = node.users.keys() - new_node = graph.create_node('call_function', fn, args, {}, name=name) + new_node = graph.create_node('call_function', fn, args, extra_kwargs, name=name) users = {} for u in node_users: if u != new_node: diff --git a/deepspeed/compile/passes/zero3_compile.py b/deepspeed/compile/passes/zero3_compile.py index 046e85c81..9fdb3946a 100644 --- a/deepspeed/compile/passes/zero3_compile.py +++ b/deepspeed/compile/passes/zero3_compile.py @@ -10,7 +10,7 @@ import _operator import torch from torch.fx import Graph, Node, GraphModule -from ..util import get_input_nodes, get_param_nodes, get_index_by_graph_id, get_deepcompile_handle, get_real_uses +from ..util import get_input_nodes, get_param_nodes, get_index_by_graph_id, get_deepcompile_handle, get_real_uses, is_cast_op from ..fx import add_postprocess, _make_node_meta, get_output_node, move_primals_to_head from ..profilers.graph_profile import ProfilingInterpreter from ..list_schedule import fast_free_schedule @@ -21,14 +21,15 @@ from deepspeed.accelerator import get_accelerator NAME = "zero3_compile" -def add_allgather(graph_id: int, graph: Graph, node: Node, ds_id: int): +def add_allgather(graph_id: int, graph: Graph, node: Node, ds_id: int, dtype: torch.dtype): new_ag_node = add_postprocess(graph, node, 
torch.ops.dc.allgather_param.default, extra_args=[graph_id, ds_id], + extra_kwargs={"dtype": dtype}, name=f"allgather_ds_param_{node.target}_{ds_id}", meta=_make_node_meta(node, ds_id, True)) - new_ag_node.meta["val"] = node.meta["val"] + new_ag_node.meta["val"] = node.meta["val"].to(dtype) # Set the previous node back to output # We don't want to change the output node to allgather @@ -42,7 +43,7 @@ def add_allgather(graph_id: int, graph: Graph, node: Node, ds_id: int): extra_args=[graph_id, ds_id], name=f"wait_allgather_ds_param__{node.target}_{ds_id}", meta=_make_node_meta(node, ds_id, False)) - new_wait_node.meta["val"] = node.meta["val"] + new_wait_node.meta["val"] = new_ag_node.meta["val"] return new_ag_node @@ -74,9 +75,30 @@ def add_gather_and_release(graph_id: int, graph: Graph, param_manager, param_nod if len(pn.users) == 0: continue - add_allgather(graph_id, graph, pn, param_manager.ds_ids[pn.name]) + # If the only use of the parameter is a type-cast to a smaller type, fuse it with all-gather. + fuse_typecast = False + target_dtype = param_manager.params[pn.name].dtype + if len([user for user in pn.users if user.op != "output"]) == 1: + typecast_node = next(iter(pn.users)) + + is_cast, casted_dtype = is_cast_op(typecast_node) + if is_cast and casted_dtype.itemsize < target_dtype.itemsize: + fuse_typecast = True + target_dtype = casted_dtype + + add_allgather(graph_id, graph, pn, param_manager.ds_ids[pn.name], target_dtype) + if fuse_typecast: + users = node_to_uses[typecast_node] + wait_node = typecast_node.args[0] + for user in list(typecast_node.users.keys()): + if user.op == "output": + wait_node.meta["original_output_name"] = typecast_node.name + user.replace_input_with(typecast_node, wait_node) + graph.erase_node(typecast_node) + else: + users = node_to_uses[pn] + ds_id = param_manager.ds_ids[pn.name] - users = node_to_uses[pn] for user in users: # release_param() only accepts tensors as its first argument. 
If # `user` is a tuple, we should release the param after any of diff --git a/deepspeed/compile/profilers/graph_profile.py b/deepspeed/compile/profilers/graph_profile.py index ac43b05c2..e81f4e453 100644 --- a/deepspeed/compile/profilers/graph_profile.py +++ b/deepspeed/compile/profilers/graph_profile.py @@ -130,9 +130,15 @@ class ProfilingInterpreter(Interpreter): assert isinstance(args, tuple) assert isinstance(kwargs, dict) + partitioned_params = {} + def rebuild_param_if_necessary(v): if hasattr(v, "ds_id"): v.all_gather(param_list=[v]) + if hasattr(v, "ds_target_dtype"): + casted = v.to(v.ds_target_dtype) + partitioned_params[id(casted)] = v + return casted return v args = map_aggregate(args, lambda x: rebuild_param_if_necessary(x)) @@ -191,6 +197,8 @@ class ProfilingInterpreter(Interpreter): tensor_size = _node_size(out) def partition_param_if_necessary(v): + if id(v) in partitioned_params: + v = partitioned_params[id(v)] if hasattr(v, "ds_id") and not v.ds_persist: v.partition(param_list=[v], has_been_updated=False) return v @@ -227,6 +235,8 @@ class ProfilingInterpreter(Interpreter): assert hasattr(out, "ds_id") if not out.ds_persist: self.nz3.invalidate_gathered_param(args[2]) + if "dtype" in n.kwargs: + setattr(out, "ds_target_dtype", n.kwargs["dtype"]) self.allgather_mem[out.ds_id] = n.meta["alloc_mem"] return out diff --git a/deepspeed/compile/util.py b/deepspeed/compile/util.py index 9ac6ba904..7b8dc83cd 100644 --- a/deepspeed/compile/util.py +++ b/deepspeed/compile/util.py @@ -5,7 +5,7 @@ import functools import operator -from typing import List, Tuple, Dict +from typing import List, Tuple, Dict, Optional from collections import defaultdict import torch @@ -131,6 +131,15 @@ def is_comm_op(node: Node) -> bool: return "comm" in node.meta and node.meta["comm"] +def is_cast_op(node: Node) -> Tuple[bool, Optional[torch.dtype]]: + if node.op == "call_function": + if node.target == torch.ops.prims.convert_element_type.default: + return (True, node.args[1]) + elif node.target == torch.ops.aten._to_copy.default and set(node.kwargs.keys()) == {"dtype"}: + return (True, node.kwargs["dtype"]) + return (False, None) + + def exclude_from_act_offload(node: Node) -> bool: return node.target in sym_size_ops diff --git a/tests/unit/v1/compile/test_compile_zero.py b/tests/unit/v1/compile/test_compile_zero.py index 02202aeda..16ad12d30 100644 --- a/tests/unit/v1/compile/test_compile_zero.py +++ b/tests/unit/v1/compile/test_compile_zero.py @@ -180,3 +180,36 @@ class TestDeepCompile(DistributedTest): } compare_loss(self, config_dict, dtype) + + @pytest.mark.parametrize('dtype', ["bfloat16", "float16"]) + @pytest.mark.parametrize('zero_stage', [3]) + def test_fusing_allgather_and_autocast(self, zero_stage, dtype): + """Test that allgather and autocast can be correctly fused with DeepCompile""" + if not required_torch_version(min_version=2.6): + pytest.skip("DeepCompile requires PyTorch >= v2.6") + + if get_accelerator().device_name() == "cpu": + pytest.skip("CPU does not support this test yet") + + config_dict = { + "train_micro_batch_size_per_gpu": 1, + "steps_per_print": 1, + "optimizer": { + "type": "Adam", + "params": { + "lr": 0.00015 + } + }, + "torch_autocast": { + "enable": True, + "dtype": dtype, + }, + "zero_optimization": { + "stage": zero_stage, + }, + "compile": { + "deepcompile": True + } + } + + compare_loss(self, config_dict, torch.float32)
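
Note for reviewers: the sketch below is illustrative only and not part of
this patch. It shows, with plain torch ops on a single device, why casting
the shard before the allgather is both cheaper and numerically equivalent
to gathering FP32 and casting afterwards: the gather is a concatenation and
the cast is elementwise, so the two orderings commute. WORLD_SIZE and the
shard size are made-up values.

# sketch_fused_allgather_downcast.py -- illustrative, not part of this patch
import torch

WORLD_SIZE = 4  # hypothetical number of ranks
shards = [torch.randn(1024, dtype=torch.float32) for _ in range(WORLD_SIZE)]

# Before this PR: allgather the FP32 shards, then downcast the full tensor.
# Communication volume: WORLD_SIZE * 1024 * 4 bytes.
gathered_fp32 = torch.cat(shards)
param_before = gathered_fp32.to(torch.bfloat16)

# After this PR: downcast each local shard first, then allgather BF16.
# Communication volume: WORLD_SIZE * 1024 * 2 bytes, and no full-size FP32
# gather buffer is ever allocated.
casted_shards = [s.to(torch.bfloat16) for s in shards]
param_after = torch.cat(casted_shards)

# Elementwise casts commute with concatenation, so the results match bitwise.
assert torch.equal(param_before, param_after)

On the graph side, zero3_compile only applies the fusion when a parameter's
sole non-output user is such a downcast (see is_cast_op()); that cast node is
erased and its dtype is passed as the `dtype` kwarg of
torch.ops.dc.allgather_param, which allocates the gather buffer in the target
dtype and casts the local shard on the allgather stream before launching the
collective.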