With autocast enabled, most weights are downcast before being used in computation. Today zero3_compile gathers the FP32 weights and only then downcasts them. That is sub-optimal: FP32 weights consume more allgather bandwidth and take more time to downcast. To reduce communication and downcast time, this PR fuses the downcast into the allgather in the dc ops. The target type is now passed to `allgather_param()` and `prefetch_params_fused()`, which downcast the (partitioned) weights before launching the allgathers. This corresponds to issue 1 of #7577.

Tested with https://gist.github.com/eternalNight/3c2cf8c703f1e9e7742d3b7f9e1edae3 (run with `deepspeed --num_gpus=N this_file.py -c -p -m 23` to collect torch and memory profiles, and with `DINOV2_DEPTH = SIGLIP_DEPTH = 3`, `LLAMA2_DEPTH = 4` for faster compilation) on 5090s (which have limited inter-GPU bandwidth): time per step decreases from 438 ms to 337 ms, and peak GPU memory usage drops from 9.5 GB to 8.5 GB.

Profiles of a single step before this PR:

<img width="1235" height="1029" alt="image" src="https://github.com/user-attachments/assets/d9fe5296-7731-4542-924b-421ff7415054" />
<img width="1466" height="616" alt="image" src="https://github.com/user-attachments/assets/aa192802-8633-4e36-b2c4-f28b1b432663" />

After this PR:

<img width="1218" height="1006" alt="image" src="https://github.com/user-attachments/assets/18a0e09c-155b-4783-adb5-b4d36c5c3691" />
<img width="1537" height="559" alt="image" src="https://github.com/user-attachments/assets/16a2ca74-8a89-4db9-9b68-81844295c61b" />

This PR also reduces peak memory usage because `fast_free_schedule()` today always places param allgathers and downcasts at the beginning of the graph. While the original FP32 params can be freed early, all of the FP16/BF16-cast params are kept in GPU memory from the beginning of the backward graph, leading to a higher memory peak.

P.S. Probably due to organization branch rule settings, I can't find anywhere to allow reviewers to modify the branch, so I'll update the branch per reviewers' comments and rebase if needed.

Signed-off-by: Junjie Mao <junjie.mao@linux.alibaba.com>
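To make the fused ordering concrete, here is a minimal sketch of the idea, assuming a c10d process group and a synchronous collective; `fused_allgather_sketch` is a hypothetical helper for illustration, not the actual dc op implementation:

```cpp
#include <torch/torch.h>
#include <torch/csrc/distributed/c10d/ProcessGroup.hpp>

// Hypothetical sketch: downcast the local shard to the autocast target type
// *before* the collective, so the allgather moves 16-bit rather than 32-bit
// data over the wire.
at::Tensor fused_allgather_sketch(const at::Tensor& shard,
                                  std::optional<at::ScalarType> dtype,
                                  c10::intrusive_ptr<c10d::ProcessGroup> pg)
{
    // Downcast first; gathering the FP32 shard and casting afterwards would
    // double the communicated bytes.
    at::Tensor send = dtype ? shard.to(*dtype) : shard;
    // Flat output buffer holding every rank's (downcast) shard.
    at::Tensor recv = at::empty({send.numel() * pg->getSize()}, send.options());
    // _allgather_base gathers all shards into one contiguous buffer.
    pg->_allgather_base(recv, send)->wait();
    return recv;
}
```

In the real ops the shard is the `ds_tensor` registered via `register_z3_param()` and the gather is asynchronous, paired with `wait_allgather()`; the sketch only shows that the downcast happens before, not after, the collective.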
// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0

// DeepSpeed Team

#include "deepcompile.h"

#pragma once

namespace dc {

void register_graph_z3(long graph_id, const std::vector<long>& ds_ids);

void register_graph_ops_z3(long graph_id,
                           const std::vector<std::string>& op_names,
                           const std::vector<long>& n_args);

void register_bwd_graph_ops_z3(long graph_id,
                               const std::vector<std::string>& op_names,
                               const std::vector<long>& n_args);

void register_z3_param(long ds_id,
                       const std::vector<int64_t>& ds_shape,
                       at::Tensor ds_tensor,
                       at::Tensor grad_buffer,
                       bool persistent);

// When `dtype` is set, the partitioned weights are downcast to that type
// before the allgather is launched (fused downcast + gather).
at::Tensor allgather_param(at::Tensor param_tensor,
                           long graph_id,
                           long ds_id,
                           std::optional<at::ScalarType> dtype);

void set_persistent(long ds_id);

// `dtypes`, when present, gives the per-param downcast targets applied before
// the prefetching allgathers are launched.
void prefetch_params_fused(long graph_id,
                           const std::vector<at::Tensor>& params,
                           const std::vector<long>& ds_ids,
                           const std::optional<std::vector<at::ScalarType>>& dtypes);

void prefetch_params_fused_meta(long graph_id,
                                const std::vector<at::Tensor>& params,
                                const std::vector<long>& ds_ids,
                                const std::optional<std::vector<at::ScalarType>>& dtypes);

// for profiling
void invalidate_gathered_param(long ds_id);
void clear_all_gathered_params();

at::Tensor allgather_param_meta(at::Tensor param_tensor,
                                long graph_id,
                                long ds_id,
                                std::optional<at::ScalarType> dtype);

at::Tensor release_param(at::Tensor dummy, long graph_id, long ds_id, long n_users);
at::Tensor release_param_meta(at::Tensor dummy, long graph_id, long ds_id, long n_users);

at::Tensor wait_allgather(at::Tensor v, long graph_id, const long ds_id);
at::Tensor wait_allgather_meta(at::Tensor v, long graph_id, long ds_id);

at::Tensor offload_tensor(at::Tensor tensor, long graph_id, long id);
at::Tensor reload_tensor(at::Tensor tensor, long graph_id, long id);
at::Tensor wait_offload(at::Tensor tensor, long graph_id, long id);
at::Tensor wait_reload(at::Tensor tensor, long graph_id, long id);

void reload_parameter(at::Tensor tensor, long graph_id, long id);
void offload_parameter(at::Tensor tensor, long graph_id, long id);
void reload_parameter_meta(at::Tensor tensor, long graph_id, long id);
void offload_parameter_meta(at::Tensor tensor, long graph_id, long id);

void end_backward(long graph_id);

}  // namespace dc
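For context, declarations like these are typically exposed to the compiled graph as torch custom ops. The following is a hedged sketch of such a registration, assuming the `dc` op namespace seen above; the schema string and dispatch keys are illustrative assumptions, not copied from DeepSpeed's registration code:

```cpp
#include <torch/library.h>
// ...plus the header above, which declares dc::allgather_param{,_meta}.

// Hypothetical registration sketch for illustration only.
TORCH_LIBRARY_FRAGMENT(dc, m)
{
    // `ScalarType? dtype` carries the fused-downcast target type; a None
    // value means "gather at the stored precision".
    m.def("allgather_param(Tensor param, int graph_id, int ds_id, ScalarType? dtype) -> Tensor");
}

// The CUDA implementation performs the real downcast + allgather...
TORCH_LIBRARY_IMPL(dc, CUDA, m) { m.impl("allgather_param", &dc::allgather_param); }

// ...while the Meta implementation only propagates shapes and dtypes.
TORCH_LIBRARY_IMPL(dc, Meta, m) { m.impl("allgather_param", &dc::allgather_param_meta); }
```

Registering a Meta implementation alongside the CUDA one is what lets torch.compile trace the op with fake tensors, including the downcast target dtype, without launching any real communication, which is consistent with the `_meta` variants declared in the header.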