DeepCompile for enhanced compiler integration (#7154)

This PR introduces *DeepCompile*, a new feature that efficiently
integrates compiler optimizations with other DeepSpeed features.
DeepCompile uses PyTorch's Dynamo to capture the computation graph and
modifies it to seamlessly incorporate DeepSpeed's optimizations.

Currently, DeepCompile supports ZeRO-1 and ZeRO-3, with enhancements
such as proactive prefetching and selective unsharding to improve
performance.
(More details will be added later.)

---------

Signed-off-by: Masahiro Tanaka <mtanaka@microsoft.com>
Signed-off-by: Olatunji Ruwase <olruwase@microsoft.com>
Co-authored-by: zafarsadiq <zafarsadiq120@gmail.com>
Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>
Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>
Masahiro Tanaka authored 2025-04-15 21:33:53 -07:00 · committed by GitHub
parent a21e5b9db6 · commit 227a60c0c4
56 changed files with 5989 additions and 15 deletions


@ -36,7 +36,7 @@ jobs:
#python -c "import torch; print('CUDA available:', torch.cuda.is_available())"
- name: Compile DeepSpeed Ops
run: |
DS_ACCELERATOR=cuda DS_ENABLE_NINJA=1 TORCH_CUDA_ARCH_LIST="7.0;7.5;8.0" DS_BUILD_OPS=1 DS_BUILD_SPARSE_ATTN=0 DS_BUILD_FP_QUANTIZER=0 DS_BUILD_CUTLASS_OPS=0 DS_BUILD_GDS=0 DS_BUILD_RAGGED_DEVICE_OPS=0 DS_BUILD_EVOFORMER_ATTN=0 pip3 install .
DS_ACCELERATOR=cuda DS_ENABLE_NINJA=1 TORCH_CUDA_ARCH_LIST="7.0;7.5;8.0" DS_BUILD_OPS=1 DS_BUILD_SPARSE_ATTN=0 DS_BUILD_FP_QUANTIZER=0 DS_BUILD_CUTLASS_OPS=0 DS_BUILD_GDS=0 DS_BUILD_RAGGED_DEVICE_OPS=0 DS_BUILD_EVOFORMER_ATTN=0 DS_BUILD_DEEP_COMPILE=0 pip3 install .
- name: DS Report
run: |
ds_report


@ -44,7 +44,7 @@ jobs:
- name: Install deepspeed
run: |
pip install .[dev,1bit,autotuning]
pip install .[dev,1bit,autotuning,deepcompile]
ds_report
- name: Python environment


@ -15,6 +15,7 @@
## Latest News
<b> <span style="color:orange" > DeepSpeed empowers ChatGPT-like model training with a single click, offering 15x speedup over SOTA RLHF systems with unprecedented cost reduction at all scales; [learn how](https://github.com/deepspeedai/DeepSpeed/tree/master/blogs/deepspeed-chat)</span>.</b>
* [2025/04] [DeepCompile: Unlocking Compiler Optimization for Distributed Training](https://github.com/deepspeedai/DeepSpeed/blob/master/blogs/deepcompile/README.md)
* [2025/03] [DeepSpeed-AutoTP: Automatic Tensor Parallel Training of Hugging Face models](https://github.com/deepspeedai/DeepSpeed/blob/master/blogs/huggingface-tp/README.md)
* [2024/12] [Ulysses-Offload: Democratizing Long Context LLM Training ](https://github.com/deepspeedai/DeepSpeed/blob/master/blogs/ulysses-offload/README.md)
* [2024/12] [DeepSpeed-Domino: Communication-Free LLM Training Engine](https://github.com/deepspeedai/DeepSpeed/blob/master/blogs/deepspeed-domino/README.md)

blogs/deepcompile/README.md (new file, 168 lines)

@ -0,0 +1,168 @@
<div align="center">
# DeepCompile: Unlocking Compiler Optimization for Distributed Training
</div>
# Introduction
<div align="center">
<img src="media/perf_summary.png" width="1000">
</div>
Distributed training has become essential for scaling today's massive deep learning models. While deep learning compilers such as the PyTorch compiler have dramatically improved single-GPU training performance through optimizations like kernel fusion and operator scheduling, they fall short when it comes to distributed workloads.
Existing distributed training frameworks such as DeepSpeed and FSDP have made large-scale model training feasible through advanced parallelization strategies. While powerful, their optimizations are implemented at the PyTorch framework level, which limits the ability to apply compiler-style techniques like dependency analysis or operator scheduling.
DeepCompile addresses this gap by enabling compiler-level optimizations for distributed training. It takes a standard single-GPU model implementation and transforms it into an optimized multi-GPU training graph without requiring changes to the model code. Unlike existing approaches, DeepCompile automatically applies parameter sharding, communication scheduling, and memory-aware execution at the compiler IR level, enabling global analysis and optimization that are difficult to express in traditional frameworks. Furthermore, during training, DeepCompile employs profile-guided optimization techniques to dynamically tune these parallelization strategies and improve training performance.
Our evaluation demonstrates that DeepCompile improves training performance over ZeRO-3 baselines, achieving up to 1.5x speedup when sufficient GPU resources are available, and up to 7x speedup in GPU-constrained settings that require offloading. DeepCompile is available in DeepSpeed versions >= [0.16.6](https://github.com/deepspeedai/DeepSpeed/releases/tag/v0.16.6).
# Design Overview
DeepCompile extends the capabilities of deep learning compilers to support distributed training. It starts from a standard single-GPU model implementation, such as those available on the Hugging Face model hub, and automatically transforms it by inserting necessary distributed training operations such as parameter sharding and communication primitives. Users are not required to embed any distributed logic into the model code.
The process begins by compiling the model into an intermediate representation (IR), which forms a computation graph. DeepCompile then applies a sequence of *optimization passes*, each responsible for a specific transformation of the computation graph or a targeted performance improvement, to incrementally introduce distributed behavior and optimize the graph. These include operations such as all-gather for sharded parameters or offloading of optimizer states, all while preserving the original computation semantics (Fig. 1).
<div align="center">
<img src="media/workflow.png" width="400">
*Figure 1: Workflow of compilation and optimization with DeepCompile.*
</div>
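
The capture step can be pictured with a few lines of plain PyTorch. The sketch below is not DeepCompile's code; it only shows how `torch.compile` (Dynamo) hands a backend the FX graph that optimization passes would then rewrite. The backend name `inspect_backend` is made up for illustration.

```python
import torch

def inspect_backend(gm: torch.fx.GraphModule, example_inputs):
    # The backend receives the captured computation graph (the IR that
    # graph-level passes operate on) and returns a callable to execute it.
    for node in gm.graph.nodes:
        print(node.op, node.target)
    return gm.forward  # run the captured graph unchanged

model = torch.nn.Sequential(torch.nn.Linear(16, 16), torch.nn.ReLU())
compiled = torch.compile(model, backend=inspect_backend)
compiled(torch.randn(4, 16))
```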
At its core, DeepCompile builds on two key capabilities:
- **Automatic parallelization**: DeepCompile allows optimization passes to rewrite the single-GPU computation graph into a distributed multi-GPU version, incorporating strategies such as ZeRO, FSDP, and more. This eliminates the need for manual implementation of distributed training logic, drastically reducing engineering effort.
- **Profile-guided performance tuning**: At runtime, DeepCompile collects profiling data such as operator-level memory usage and execution latency. It uses this information to dynamically schedule computation and communication operators, enabling better overlap between communication and computation and helping avoid memory bottlenecks. Fine-grained tuning through these optimization passes often outperforms even manually engineered implementations.
Figure 2 illustrates the optimization cycle employed by DeepCompile. After the initial computation graph is generated by the compiler, DeepCompile profiles its behavior by measuring operator execution time, communication overhead, and memory usage throughout the forward and backward passes.
<div align="center">
<img src="media/opt_loop.png" width="600">
*Figure 2. Optimization cycle.*
</div>
Based on the collected profiling data, DeepCompile applies a sequence of optimization passes. These passes modify the computation graph by inserting, removing, or reordering operators to improve overall efficiency. The modified graph is then re-profiled, and this cycle of profiling and optimization is repeated.
Once a stable set of optimizations has been applied, the graph is deployed for the remaining training iterations. During execution, memory usage and other runtime characteristics may change. In such cases, DeepCompile can resume the profiling and optimization cycle according to the predefined schedule of passes, allowing the graph to adapt and maintain high performance.
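
The per-operator measurement part of this cycle is ordinary graph instrumentation. The self-contained sketch below uses `torch.fx.Interpreter` to time each node of a captured graph; the real DeepCompile profiler also records memory and communication costs, which this toy version does not.

```python
import time
import torch
import torch.fx as fx

class TimingInterpreter(fx.Interpreter):
    def run_node(self, n):
        # Time each node of the FX graph as it executes.
        t0 = time.perf_counter()
        out = super().run_node(n)
        dt_us = (time.perf_counter() - t0) * 1e6
        print(f"{n.op:14s} {str(n.target):30s} {dt_us:8.1f} us")
        return out

model = torch.nn.Sequential(torch.nn.Linear(64, 64), torch.nn.GELU())
TimingInterpreter(fx.symbolic_trace(model)).run(torch.randn(8, 64))
```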
# Optimizations
DeepCompile is designed as a general compiler framework for applying and optimizing a wide range of parallelization strategies. In the following, we describe several optimizations that have been implemented as optimization passes within DeepCompile.
## ZeRO-3
As an initial step, we have used DeepCompile to implement and enhance ZeRO-3-style optimizations at the compiler level. ZeRO-3 partitions model parameters, gradients, and optimizer states across devices, reducing memory usage and enabling large-scale training.
In conventional ZeRO-3 implementations, operations such as all-gather, reduce-scatter, and buffer release are typically inserted using Python hooks at runtime. DeepCompile replaces this approach by injecting these operations directly into the computation graph during compilation. This allows the compiler to determine their placement precisely, guided by both the static structure of the graph and runtime profiling information.
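
As a toy illustration of graph-level insertion (not DeepSpeed's actual op or pass), the FX rewrite below places a stand-in `gather_param` call in front of every `Linear`, roughly where a ZeRO-3 pass would schedule the all-gather for that layer's sharded weights.

```python
import torch
import torch.fx as fx

def gather_param(x):
    # Stand-in for an all-gather custom op; a real pass would insert a
    # communication op tied to the layer's sharded parameters.
    return x

class Block(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.lin = torch.nn.Linear(8, 8)

    def forward(self, x):
        return torch.relu(self.lin(x))

gm = fx.symbolic_trace(Block())
for node in list(gm.graph.nodes):
    if node.op == "call_module" and isinstance(gm.get_submodule(node.target), torch.nn.Linear):
        with gm.graph.inserting_before(node):
            gathered = gm.graph.call_function(gather_param, args=(node.args[0],))
        node.args = (gathered,)
gm.recompile()
print(gm.code)  # the rewritten graph now calls gather_param before the linear
```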
One of the key optimizations is **proactive prefetching**, which launches all-gather operations earlier in the computation based on memory usage profiling. This reordering increases the overlap between communication and computation, improving throughput while avoiding out-of-memory errors. In addition, small communication operations are often fused to reduce launch latency and improve efficiency.
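
The scheduling decision can be made concrete with a small, purely illustrative helper (hypothetical, not DeepCompile's actual algorithm): given a profiled per-step memory trace, it finds the earliest step at which an all-gather can be issued without exceeding the memory budget.

```python
def earliest_prefetch_step(mem_per_step, param_bytes, use_step, budget):
    """Return the earliest step at which issuing the all-gather keeps
    projected memory under `budget` for every step up to `use_step`."""
    for start in range(use_step + 1):
        if all(mem_per_step[s] + param_bytes <= budget
               for s in range(start, use_step + 1)):
            return start
    return use_step  # fall back to gathering right before use

mem = [10, 12, 30, 14, 11]  # GiB per step, from profiling
print(earliest_prefetch_step(mem, param_bytes=4, use_step=4, budget=32))  # -> 3
```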
Another optimization is **selective unsharding**, which keeps certain parameters in an unsharded form during the forward and backward passes when memory conditions permit. This reduces the frequency of all-gather operations and avoids redundant communication, particularly in scenarios where gradient accumulation is enabled.
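
A selection policy of this kind can be sketched as a simple greedy heuristic (hypothetical, shown only to make the trade-off concrete): keep the parameters that are re-gathered most often per byte unsharded, as long as they fit in the spare memory budget.

```python
def select_unsharded(params, free_bytes):
    """params: list of (name, size_bytes, gathers_per_iteration)."""
    chosen = []
    # Prefer parameters that save the most communication per byte kept resident.
    for name, size, freq in sorted(params, key=lambda p: p[2] / p[1], reverse=True):
        if size <= free_bytes:
            chosen.append(name)
            free_bytes -= size
    return chosen

print(select_unsharded([("embed", 8, 2), ("mlp.w1", 4, 4), ("lm_head", 8, 1)],
                       free_bytes=10))  # -> ['mlp.w1']
```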
## Offloading
DeepCompile also supports **adaptive offloading**, which offloads optimizer states to reduce GPU memory pressure. Unlike approaches that offload all the optimizer states, adaptive offloading identifies only the portions that exceed the memory limit—such as momentum and variance used by the Adam optimizer—and schedules data transfers to overlap with computation. This selective and asynchronous strategy minimizes overhead and enables efficient training even in memory-constrained environments.
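
The overlap mechanism itself is ordinary CUDA-stream programming. The sketch below (plain PyTorch, not DeepCompile's pass) copies an optimizer-state shard to a pinned host buffer on a side stream so the transfer runs concurrently with compute, synchronizing only when the data is actually needed.

```python
import torch

if torch.cuda.is_available():
    state = torch.randn(1 << 20, device="cuda")  # e.g. an Adam variance shard
    host_buf = torch.empty(state.shape, dtype=state.dtype,
                           device="cpu", pin_memory=True)
    side = torch.cuda.Stream()
    done = torch.cuda.Event()

    side.wait_stream(torch.cuda.current_stream())  # state must be ready first
    with torch.cuda.stream(side):
        host_buf.copy_(state, non_blocking=True)   # D2H overlaps with compute
        done.record(side)

    # ... forward/backward computation continues on the default stream ...
    done.synchronize()                             # block only when the data is needed
```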
## ZeRO-1
ZeRO-1 differs from ZeRO-3 in that it shards only the optimizer states across devices, while keeping parameters and gradients fully replicated. This approach reduces memory usage with minimal changes to computation flow, making it a lightweight alternative for certain training scenarios.
DeepCompile implements ZeRO-1-style optimization by inserting reduce-scatter operations directly into the computation graph. By avoiding Python-level hooks, this graph-level integration reduces overhead and improves execution efficiency.
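
The primitive that the graph pass inserts corresponds to a standard reduce-scatter over flattened gradients. The sketch below uses plain `torch.distributed` and assumes an initialized NCCL process group and a gradient length divisible by the world size; it is an illustration, not the op DeepSpeed registers.

```python
import torch
import torch.distributed as dist

def reduce_scatter_grad(flat_grad: torch.Tensor) -> torch.Tensor:
    """Average the replicated gradient across ranks and return this
    rank's shard (ZeRO-1 keeps only the shard it owns)."""
    world = dist.get_world_size()
    shard = torch.empty(flat_grad.numel() // world,
                        dtype=flat_grad.dtype, device=flat_grad.device)
    dist.reduce_scatter_tensor(shard, flat_grad, op=dist.ReduceOp.AVG)
    return shard
```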
# Performance Improvements
## ZeRO-3
We evaluated DeepCompile on Llama-3-70B and Mixtral 8x7B using parameter sharding on top of Hugging Face model implementations.
Figure 3 shows training throughput (TFLOPs/GPU) across different gradient accumulation steps, using 32 H100 GPUs with a sequence length of 1024.
We compare DeepCompile against two DeepSpeed ZeRO-3 baselines: (i) an eager-mode version without compiler support (labelled ZeRO3+Eager), and (ii) a compiled version using PyTorch compiler (labelled ZeRO3+Compile). For DeepCompile, we enabled both proactive prefetching and selective unsharding to demonstrate the combined effect of these optimization passes.
<div align="center"> <img src="media/perf_zero3.png" width="800">
*Figure 3. Achieved throughputs for ZeRO3 training of Llama-3 70B and Mixtral 8x7B models.*
</div>
Across both models, DeepCompile consistently delivers higher throughput. The benefit becomes more pronounced at higher accumulation steps, where the reduced frequency of parameter updates makes selective unsharding more effective. DeepCompile with proactive prefetching and selective unsharding achieves up to 1.28× speedup over ZeRO-3 on Llama-3-70B and 1.54× on Mixtral 8x7B.
Meanwhile, enabling the PyTorch compiler with ZeRO-3, i.e., ZeRO3+Compile, introduces minor overheads in some settings. This is because ZeRO-3 includes many conditional branches for runtime features such as prefetching. When the compiler encounters branches that cannot be statically resolved, it splits the computation into multiple graph segments. These fragmented segments reduce optimization opportunities and introduce additional overhead during execution.
## Offloading
Training models as large as Llama-3 70B with ZeRO-3 typically requires 32 GPUs with 80GB of memory.
DeepSpeed addresses this challenge by offering offloading capabilities, which transfer optimizer states and optionally model parameters to CPU memory to reduce GPU memory usage. DeepCompile also supports offloading through a dedicated optimization pass, but with a few key differences in design.
Unlike the traditional approach of offloading both optimizer computation and memory, DeepCompile offloads only optimizer memory (e.g., momentum, variance, and master weights of Adam optimizer) while the optimizer computation remains on GPU. DeepCompile profiles memory usage during both forward and backward passes to identify when offloading is necessary, and transfers only the required data. This fine-grained approach avoids unnecessary overhead and helps maintain high computational throughput.
Furthermore, DeepCompile overlaps data transfers with computation whenever possible, dynamically adjusting the timing based on observed memory usage patterns. This asynchronous behavior is a crucial aspect of DeepCompile's offloading strategy, allowing it to reduce GPU memory pressure without stalling execution.
We evaluated DeepCompile's offloading using Llama-3 70B on 16 H100-80GB GPUs (half the GPU count normally required) and present the results in Figure 4.
<div align="center">
<img src="media/perf_offload.png" width="400">
*Figure 4. Achieved throughput of optimizer offloading for Llama-3 70B on 16x80GB GPUs*
</div>
We compare against two ZeRO-3 offloading baselines: (i) an eager-mode version without compiler support (ZeRO3+Eager), and (ii) a compiled version using PyTorch compiler (ZeRO3+Compile). As shown by the results, DeepCompile significantly improves offloading efficiency and provides up to 7× speedup over ZeRO3+Eager. In contrast, we see that ZeRO3+Compile achieves similar performance as ZeRO3+Eager.
## ZeRO-1
We also evaluated DeepCompile with ZeRO-1 using the Llama-3-8B model. We compare DeepCompile against two ZeRO-1 baselines: (i) an eager-mode version without compiler support (ZeRO1+Eager), and (ii) a compiled version using PyTorch compiler (ZeRO1+Compile). In our experiment with 8 GPUs and a batch size of 2, DeepCompile achieved consistent throughput improvements across different sequence lengths, as shown in Figure 5.
<div align="center">
<img src="media/perf_zero1.png" width="800">
*Figure 5. Achieved throughput of ZeRO-1 training of Llama-3 8B*
</div>
The most significant speedup was observed with batch size 1 and sequence length 512, where DeepCompile outperformed ZeRO1+Eager by up to 1.9×, and ZeRO1+Compile by up to 2.5×.
While compiler-based approaches can be effective for large batch sizes and long sequences by replacing suboptimal operations with more efficient kernels, they may also introduce overheads in ZeRO-1-style training in the form of *graph breaks* around the communication operations. These overheads become more pronounced with smaller batch sizes and sequence lengths, thus hurting performance compared to the non-compiled execution. In contrast, DeepCompile inserts communication operators directly into the computation graph during compilation, avoiding graph fragmentation and minimizing associated overhead. This makes DeepCompile more robust to small-scale workloads, while still benefiting from compiler-level optimizations.
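
When comparing against the +Compile baselines, it can be useful to see where Dynamo breaks the graph. One quick way to check (the API shape varies slightly across PyTorch 2.x releases, so treat this as a sketch) is `torch._dynamo.explain`:

```python
import torch

def f(x):
    y = x.sin()
    print("logging...")   # a typical source of a graph break
    return y.cos()

report = torch._dynamo.explain(f)(torch.randn(4))
print(report)             # includes the number of graphs and break reasons
```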
## Additional Results and Analysis
Please refer to our [arXiv paper](https://arxiv.org/abs/2504.09983) for additional results, such as detailed comparisons across different batch sizes, sequence lengths, and memory usage.
# Looking Ahead
DeepCompile brings the power of compiler-based optimizations to distributed deep learning. By transforming computation graphs and applying profile-guided optimization passes, it enables more efficient training without requiring changes to model code.
This release is just the beginning. We're actively working on expanding the set of optimization passes and improving integration with a broader range of distributed training strategies. Future directions include automated parallelization (sequence and tensor parallelism), smarter memory management, and dynamic adaptation to runtime behavior.
We invite the community to try DeepCompile, explore its capabilities, and contribute to its evolution. Let's build the next generation of scalable deep learning together.
# Contributors
This project is the result of a close collaboration between Microsoft and the University of Virginia. The contributors are: Masahiro Tanaka, Du Li, Umesh Chand, and Olatunji Ruwase (Microsoft); and Ali Zafar and Haiying Shen (University of Virginia).
# Appendix
## Examples and Benchmarks
Our DeepSpeedExamples repository provides [example code](https://github.com/deepspeedai/DeepSpeedExamples/tree/master/benchmarks/deepcompile) to enable DeepCompile.
## Optimization Passes
The following optimization passes are currently available in DeepCompile:
- All-gather & reduce-scatter insertion (ZeRO3)
- Proactive prefetching (ZeRO3)
- Selective unsharding (ZeRO3)
- Reduce-scatter insertion (ZeRO1)
- Adaptive offloading
We used the following combinations of passes in the experiments presented above:
- Improved communication scheduling for ZeRO-3: All-gather & reduce-scatter insertion → Proactive prefetching → Selective unsharding
- Offloading optimizer states for ZeRO-3: All-gather & reduce-scatter insertion → Adaptive offloading
- Reduced overhead and improved overlap for ZeRO-1: Reduce-scatter insertion

(Six binary image files added under blogs/deepcompile/media/ — 355 KiB, 63 KiB, 193 KiB, 117 KiB, 129 KiB, and 85 KiB; not shown.)


@ -10,6 +10,7 @@ set DS_BUILD_FP_QUANTIZER=0
set DS_BUILD_GDS=0
set DS_BUILD_RAGGED_DEVICE_OPS=0
set DS_BUILD_SPARSE_ATTN=0
set DS_BUILD_DEEP_COMPILE=0
python -m build --wheel --no-isolation


@ -0,0 +1,188 @@
// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0
// DeepSpeed Team
#include "deepcompile.h"
#define USE_C10D_NCCL
namespace dc {
std::shared_ptr<DSParamRegistry> param_registry;
std::unordered_map<long, std::shared_ptr<CustomOpExecutor>> executors;
std::shared_ptr<DoubleBufferedReduceBucket> reduce_buckets = nullptr;
c10::intrusive_ptr<c10d::ProcessGroup> process_group = nullptr;
c10::intrusive_ptr<c10d::symmetric_memory::SymmetricMemory> symm_mem = nullptr;
ncclComm_t nccl_comm;
bool use_symm_mem;
bool clone_custom_op_output;
bool profile = false;
bool pre_div_reduce = true;
bool sync_before_reduce; // for debugging
bool sync_after_reduce; // for debugging
bool sync_before_allgather; // for debugging
bool sync_after_allgather; // for debugging
std::vector<int64_t> sizes_to_int_vector(at::IntArrayRef sizes)
{
std::vector<int64_t> result;
for (int i = 0; i < sizes.size(); i++) { result.push_back(sizes[i]); }
return result;
}
void enable_profiling(bool enable) { profile = enable; }
bool is_profiling() { return profile; }
c10::intrusive_ptr<c10d::symmetric_memory::SymmetricMemory> getSymmMemWorkspace(int64_t size)
{
c10::Device device = c10::Device(c10::kCUDA, c10::cuda::current_device());
std::vector<int64_t> sizes = {size};
std::vector<int64_t> strides = {1};
at::Tensor sym_mem_ws = c10d::symmetric_memory::empty_strided_p2p(
{size}, {1}, c10::ScalarType::Byte, device, process_group->getGroupName(), std::nullopt);
return c10d::symmetric_memory::rendezvous(sym_mem_ws);
}
void lazy_init_symm_memory()
{
if (use_symm_mem && !symm_mem) {
int64_t max_param_size = 0;
for (const auto& it : param_registry->getParams()) {
int64_t size = it.second.getDSTensor().numel() * it.second.getDSTensor().element_size();
if (size > max_param_size) { max_param_size = size; }
}
symm_mem = getSymmMemWorkspace(max_param_size);
}
}
ncclDataType_t get_nccl_data_type(at::ScalarType scalar_type)
{
switch (scalar_type) {
case at::kFloat: return ncclFloat;
case at::kHalf: return ncclHalf;
case at::kDouble: return ncclDouble;
case at::kBFloat16: return ncclBfloat16;
case at::kLong: return ncclInt64;
case at::kInt: return ncclInt;
case at::kChar: return ncclInt8;
default: throw std::runtime_error("Unsupported scalar type");
}
}
void reset()
{
executors.clear();
// We keep the buckets for memory estimation
// reduce_buckets->clear();
}
void cleanup()
{
reset();
ncclCommDestroy(nccl_comm);
process_group = nullptr;
symm_mem = nullptr;
}
at::Tensor reduce_grad(at::Tensor grad_tensor, long graph_id, long ds_id)
{
if (sync_before_reduce) { c10::cuda::device_synchronize(); }
assert(hasKey(executors, graph_id));
if (!profile) { executors[graph_id]->reduceGrad(grad_tensor, ds_id); }
if (sync_after_reduce) { c10::cuda::device_synchronize(); }
return at::Tensor();
}
at::Tensor reduce_grad_meta(at::Tensor grad_tensor, long graph_id, long ds_id)
{
return at::Tensor();
}
void free_tensors(std::vector<at::Tensor> tensors)
{
int64_t THRESHOLD = 10 * 1024 * 1024;
if (!profile) {
for (auto& tensor : tensors) {
if (tensor.is_cuda() && tensor.numel() > THRESHOLD) {
tensor.record_stream(at::cuda::getCurrentCUDAStream());
tensor.set_data(torch::empty({0}, tensor.options()));
}
}
}
}
void free_tensors_meta(std::vector<at::Tensor> tensors) {}
void init(c10::intrusive_ptr<c10d::ProcessGroup> pg,
int64_t initial_reduce_bucket_size,
bool enable_double_buffer,
bool _use_symm_mem,
bool _clone_custom_op_output,
bool _sync_before_reduce,
bool _sync_after_reduce,
bool _sync_before_allgather,
bool _sync_after_allgather)
{
process_group = pg;
ncclUniqueId ncclID;
ncclGetUniqueId(&ncclID);
// ProcessGroup doesn't have an API to get the CUDA stream for comm calls.
// So we create a NCCL communicator and call NCCL APIs directly.
auto vec = std::vector<uint8_t>(reinterpret_cast<uint8_t*>(&ncclID),
reinterpret_cast<uint8_t*>(&ncclID) + NCCL_UNIQUE_ID_BYTES);
auto device = torch::Device(torch::kCUDA);
at::Tensor tensor = torch::from_blob(vec.data(), {static_cast<long>(vec.size())}, torch::kUInt8)
.to(torch::Device(torch::kCUDA));
std::vector<at::Tensor> bcast_input = {tensor};
process_group->broadcast(bcast_input, c10d::BroadcastOptions())->wait();
// create a new nccl communicator
std::memcpy(&ncclID, tensor.to(torch::Device(torch::kCPU)).data_ptr(), NCCL_UNIQUE_ID_BYTES);
ncclCommInitRank(&nccl_comm, process_group->getSize(), ncclID, process_group->getRank());
param_registry = std::make_shared<DSParamRegistry>();
reduce_buckets = std::make_shared<DoubleBufferedReduceBucket>(initial_reduce_bucket_size,
enable_double_buffer);
use_symm_mem = _use_symm_mem;
clone_custom_op_output = _clone_custom_op_output;
sync_before_reduce = _sync_before_reduce;
sync_after_reduce = _sync_after_reduce;
sync_before_allgather = _sync_before_allgather;
sync_after_allgather = _sync_after_allgather;
}
void start_forward()
{
lazy_init_symm_memory();
for (auto& it : executors) { it.second->startForward(); }
}
void end_forward()
{
for (auto& it : executors) { it.second->endForward(); }
}
void start_backward(bool update)
{
for (auto& it : executors) { it.second->startBackward(update); }
}
// We don't call this
// void end_backward(bool update)
// {
// }
} // namespace dc

csrc/compile/init.cpp (new file, 99 lines)

@ -0,0 +1,99 @@
// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0
// DeepSpeed Team
#include "deepcompile.h"
#include "z1.h"
#include "z3.h"
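// These are the custom ops that DeepCompile's optimization passes insert into
// the captured graph. Each op carries the graph id and a DeepSpeed param/tensor
// id so the executor registered for that graph can look up its state.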
TORCH_LIBRARY(dc, m)
{
m.def("allgather_param(Tensor a, int graph_id, int id) -> Tensor");
m.def("prefetch_params_fused(int graph_id, Tensor[] params, int[] ids) -> ()");
m.def("wait_allgather(Tensor a, int graph_id, int id) -> Tensor");
m.def("release_param(Tensor a, int graph_id, int id, int n_users) -> Tensor");
m.def("reduce_grad(Tensor a, int graph_id, int id) -> Tensor");
m.def("free_tensors(Tensor[] a) -> ()");
m.def("offload_tensor(Tensor a, int id, int id) -> Tensor");
m.def("reload_tensor(Tensor a, int id, int id) -> Tensor");
m.def("wait_offload(Tensor a, int id, int id) -> Tensor");
m.def("wait_reload(Tensor a, int id, int id) -> Tensor");
m.def("offload_parameter(Tensor a, int id, int id) -> ()");
m.def("reload_parameter(Tensor a, int id, int id) -> ()");
m.def("test_call(Tensor a) -> Tensor");
}
TORCH_LIBRARY_IMPL(dc, CPU, m)
{
m.impl("allgather_param", &dc::allgather_param);
m.impl("prefetch_params_fused", &dc::prefetch_params_fused);
m.impl("wait_allgather", &dc::wait_allgather);
m.impl("release_param", &dc::release_param);
m.impl("reduce_grad", &dc::reduce_grad);
m.impl("free_tensors", &dc::free_tensors);
m.impl("offload_tensor", &dc::offload_tensor);
m.impl("reload_tensor", &dc::reload_tensor);
m.impl("wait_offload", &dc::wait_offload);
m.impl("wait_reload", &dc::wait_reload);
m.impl("offload_parameter", &dc::offload_parameter);
m.impl("reload_parameter", &dc::reload_parameter);
m.impl("test_call", &dc::test_call);
}
TORCH_LIBRARY_IMPL(dc, CUDA, m)
{
m.impl("allgather_param", &dc::allgather_param);
m.impl("prefetch_params_fused", &dc::prefetch_params_fused);
m.impl("wait_allgather", &dc::wait_allgather);
m.impl("release_param", &dc::release_param);
m.impl("reduce_grad", &dc::reduce_grad);
m.impl("free_tensors", &dc::free_tensors);
m.impl("offload_tensor", &dc::offload_tensor);
m.impl("reload_tensor", &dc::reload_tensor);
m.impl("wait_offload", &dc::wait_offload);
m.impl("wait_reload", &dc::wait_reload);
m.impl("offload_parameter", &dc::offload_parameter);
m.impl("reload_parameter", &dc::reload_parameter);
m.impl("test_call", &dc::test_call);
}
TORCH_LIBRARY_IMPL(dc, Meta, m)
{
m.impl("allgather_param", &dc::allgather_param_meta);
m.impl("prefetch_params_fused", &dc::prefetch_params_fused_meta);
m.impl("release_param", &dc::release_param_meta);
m.impl("wait_allgather", &dc::wait_allgather_meta);
m.impl("reduce_grad", &dc::reduce_grad_meta);
m.impl("free_tensors", &dc::free_tensors_meta);
m.impl("reload_parameter", &dc::reload_parameter_meta);
m.impl("offload_parameter", &dc::offload_parameter_meta);
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
m.def("set_persistent", &dc::set_persistent, "Set persistent flag for a parameter");
m.def("enable_profiling", &dc::enable_profiling, "Enable profiling");
m.def("is_profiling", &dc::is_profiling, "Check if profiling is enabled");
m.def("init", &dc::init, "Set the process group");
m.def("cleanup", &dc::cleanup, "Cleanup the process group");
m.def("register_z1_param", &dc::register_z1_param, "Register a parameter");
m.def("register_graph_z1",
&dc::register_graph_z1,
"Register graph with a list of ds parameter ids");
m.def("register_z3_param", &dc::register_z3_param, "Register a parameter");
m.def("register_graph_z3",
&dc::register_graph_z3,
"Register graph with a list of ds parameter ids");
m.def("start_forward", &dc::start_forward, "Start forward pass");
m.def("end_forward", &dc::end_forward, "End forward pass");
m.def("start_backward", &dc::start_backward, "Start backward pass");
// m.def("end_backward", &dc::end_backward, "End backward pass");
m.def("cleanup", &dc::cleanup, "Clean up DeepCompile");
m.def("reset", &dc::reset, "Reset the state");
m.def("invalidate_gathered_param", &dc::invalidate_gathered_param, "Invalidate gathered param");
m.def("clear_all_gathered_params", &dc::clear_all_gathered_params, "Clear all gathered params");
}

csrc/compile/util.cpp (new file, 89 lines)

@ -0,0 +1,89 @@
// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0
// DeepSpeed Team
#include "deepcompile.h"
#include <ATen/ATen.h>
namespace dc {
std::string tensorToString(const at::Tensor& t, size_t max_elem, size_t max_str_len)
{
auto t_cpu = t.flatten()
.slice(0, 0, std::min((int64_t)max_elem, t.numel()))
.to(c10::Device(c10::kCPU), false, true);
size_t size = std::min(max_elem, productDim(t.sizes()));
if (t.scalar_type() == c10::ScalarType::Half || t.scalar_type() == c10::ScalarType::BFloat16) {
auto float_ten = t_cpu.to(c10::ScalarType::Float, false, true).contiguous();
return tensorPtrToString((float*)float_ten.data_ptr(), size, max_str_len);
} else if (t.scalar_type() == c10::ScalarType::Float) {
return tensorPtrToString((float*)t_cpu.data_ptr(), size, max_str_len);
} else if (t.scalar_type() == c10::ScalarType::Double) {
return tensorPtrToString((double*)t_cpu.data_ptr(), size, max_str_len);
} else if (t.scalar_type() == c10::ScalarType::Int) {
int* ptr = static_cast<int*>(t_cpu.data_ptr());
return tensorPtrToString(ptr, size, max_str_len);
} else if (t.scalar_type() == c10::ScalarType::Long) {
long* ptr = static_cast<long*>(t_cpu.data_ptr());
return tensorPtrToString(ptr, size, max_str_len);
} else if (t.scalar_type() == c10::ScalarType::Byte) {
unsigned char* ptr = static_cast<unsigned char*>(t_cpu.data_ptr());
std::vector<unsigned short> vec;
vec.reserve(size);
for (size_t i = 0; i < size; i++) {
vec.push_back(*ptr);
ptr++;
}
return tensorPtrToString(&vec[0], size, max_str_len);
} else if (t.scalar_type() == c10::ScalarType::Bool) {
bool* ptr = static_cast<bool*>(t_cpu.data_ptr());
std::vector<int> vec;
vec.reserve(size);
for (size_t i = 0; i < size; i++) {
vec.push_back(*ptr);
ptr++;
}
return tensorPtrToString(&vec[0], size, max_str_len);
}
std::stringstream ss;
ss << "Failed to convert tensor to string. Invalid type of tensor: "
<< toString(t.scalar_type());
throw std::invalid_argument(ss.str());
}
std::string tensorPtrToString(void* ptr,
size_t size,
c10::ScalarType datatype,
size_t max_elem,
size_t max_str_len)
{
int64_t elem_size = std::min((size_t)max_elem, size);
if (datatype == c10::ScalarType::Long) {
return tensorPtrToString(static_cast<long*>(ptr), elem_size, max_str_len);
} else if (datatype == c10::ScalarType::Int) {
return tensorPtrToString(static_cast<int*>(ptr), elem_size, max_str_len);
} else if (datatype == c10::ScalarType::Double) {
return tensorPtrToString(static_cast<double*>(ptr), elem_size, max_str_len);
} else if (datatype == c10::ScalarType::Float) {
return tensorPtrToString(static_cast<float*>(ptr), elem_size, max_str_len);
} else if (datatype == c10::ScalarType::Half || datatype == c10::ScalarType::BFloat16) {
const auto ten = torch::from_blob(ptr, {(int64_t)elem_size}, datatype);
auto float_ten = ten.to(c10::ScalarType::Float, false, true).contiguous();
return tensorPtrToString((float*)float_ten.data_ptr(), elem_size, max_str_len);
}
std::stringstream ss;
ss << "Failed to convert tensor ptr to string. Invalid type of tensor: " << toString(datatype);
throw std::invalid_argument(ss.str());
}
std::string tensorDimToString(const at::Tensor& t)
{
const auto dim = t.sizes();
return join_as_str(dim);
}
} // namespace dc

csrc/compile/z1.cpp (new file, 141 lines)

@ -0,0 +1,141 @@
// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0
// DeepSpeed Team
#include "z1.h"
#include "deepcompile.h"
#define USE_C10D_NCCL
#include <ATen/cuda/CUDAEvent.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>
#include <torch/csrc/cuda/nccl.h>
#include <torch/csrc/distributed/c10d/NCCLUtils.hpp>
#include <torch/csrc/distributed/c10d/ProcessGroup.hpp>
#include <torch/csrc/distributed/c10d/SymmetricMemory.hpp>
namespace dc {
class Z1CustomOpExecutor : public CustomOpExecutor {
public:
Z1CustomOpExecutor(c10::intrusive_ptr<c10d::ProcessGroup> process_group,
std::shared_ptr<DSParamRegistry> param_registry,
std::shared_ptr<DoubleBufferedReduceBucket> reduce_buckets,
std::vector<long> ds_ids,
ncclComm_t nccl_comm,
at::cuda::CUDAStream rs_stream,
at::cuda::CUDAStream copy_stream,
bool pre_div_reduce)
: CustomOpExecutor(process_group,
param_registry,
reduce_buckets,
ds_ids,
nccl_comm,
rs_stream,
copy_stream,
pre_div_reduce)
{
}
~Z1CustomOpExecutor() {}
void endBackward() override
{
if (param_updated_) {
for (auto& it : has_acc_grad_) { it.second = false; }
}
}
void flushReduceBucket(at::ScalarType scalar_type) override
{
int rank = process_group_->getRank();
if (!hasKey(reduce_tasks_, scalar_type)) { return; }
int64_t tmp_recv_numel = 0;
for (const ReduceTask& t : reduce_tasks_.at(scalar_type)) {
auto copy_done_event = rs_copy_done_events_.at(t.getDSId());
copy_done_event->block(rs_stream_);
}
ncclGroupStart();
for (const ReduceTask& t : reduce_tasks_.at(scalar_type)) {
ncclRedOp_t op = pre_div_reduce_ ? ncclSum : ncclAvg;
if (pre_div_reduce_) {
at::cuda::CUDAStreamGuard guard(rs_stream_);
t.getSendBuf().div_(process_group_->getSize());
}
// inplace
ncclResult_t result = ncclAllReduce(t.getSendBuf().data_ptr(),
t.getSendBuf().data_ptr(),
t.getSendBuf().numel(),
get_nccl_data_type(scalar_type),
op,
nccl_comm_,
rs_stream_);
if (result != ncclSuccess) { throw std::runtime_error("NCCL AllReduce failed"); }
}
ncclGroupEnd();
{
at::cuda::CUDAStreamGuard guard(rs_stream_);
for (const ReduceTask& t : reduce_tasks_.at(scalar_type)) {
bool acc_grad = has_acc_grad_.at(t.getDSId());
auto param = param_registry_->getParam(t.getDSId());
auto grad_buf = param.getGradBuffer().flatten();
if (grad_buf.numel() == 0) { continue; }
int64_t offset = param.getOffset();
auto recv_buf = t.getSendBuf().flatten().index(
{torch::indexing::Slice(offset, offset + grad_buf.numel())});
if (acc_grad) {
grad_buf.add_(recv_buf);
} else {
grad_buf.copy_(recv_buf);
}
has_acc_grad_[t.getDSId()] = true;
}
}
reduce_buckets_->swap(scalar_type, rs_stream_, copy_stream_);
// Not very sure if this is necessary
// Want to prevent grad tensor from being released before the copy is done
auto comp_stream = at::cuda::getCurrentCUDAStream();
for (const ReduceTask& t : reduce_tasks_.at(scalar_type)) {
auto copy_done_event = rs_copy_done_events_.at(t.getDSId());
copy_done_event->block(comp_stream);
}
reduce_tasks_[scalar_type].clear();
}
};
static at::cuda::CUDAStream rs_stream = at::cuda::getStreamFromPool(true);
static at::cuda::CUDAStream copy_stream = at::cuda::getStreamFromPool(true);
void register_graph_z1(long graph_id, const std::vector<long>& ds_ids)
{
executors[graph_id] = std::make_shared<Z1CustomOpExecutor>(process_group,
param_registry,
reduce_buckets,
ds_ids,
nccl_comm,
rs_stream,
copy_stream,
pre_div_reduce);
}
void register_z1_param(long ds_id,
const std::vector<int64_t>& ds_shape,
at::Tensor ds_tensor,
at::Tensor grad_buffer,
int64_t offset)
{
param_registry->registerParam(ds_id, ds_shape, ds_tensor, grad_buffer, false, offset, false);
}
} // namespace dc

csrc/compile/z1.h (new file, 18 lines)

@ -0,0 +1,18 @@
// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0
// DeepSpeed Team
#include "deepcompile.h"
#pragma once
namespace dc {
void register_graph_z1(long graph_id, const std::vector<long>& ds_ids);
void register_z1_param(long ds_id,
const std::vector<int64_t>& ds_shape,
at::Tensor ds_tensor,
at::Tensor grad_buffer,
int64_t offset);
} // namespace dc

csrc/compile/z3.cpp (new file, 544 lines)

@ -0,0 +1,544 @@
// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0
// DeepSpeed Team
#include "z3.h"
#include "deepcompile.h"
#define USE_C10D_NCCL
#include <ATen/cuda/CUDAEvent.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>
#include <torch/csrc/cuda/nccl.h>
#include <torch/csrc/distributed/c10d/NCCLUtils.hpp>
#include <torch/csrc/distributed/c10d/ProcessGroup.hpp>
#include <torch/csrc/distributed/c10d/SymmetricMemory.hpp>
namespace dc {
const size_t TIMEOUT_SYMMETRIC_MEMORY_BARRIER = 60000;
class Z3CustomOpExecutor : public CustomOpExecutor {
public:
Z3CustomOpExecutor(c10::intrusive_ptr<c10d::ProcessGroup> process_group,
std::shared_ptr<DSParamRegistry> param_registry,
std::shared_ptr<DoubleBufferedReduceBucket> reduce_buckets,
std::vector<long> ds_ids,
ncclComm_t nccl_comm,
at::cuda::CUDAStream ag_stream,
at::cuda::CUDAStream rs_stream,
at::cuda::CUDAStream copy_stream,
at::cuda::CUDAStream offload_stream,
at::cuda::CUDAStream reload_stream,
bool pre_div_reduce)
: CustomOpExecutor(process_group,
param_registry,
reduce_buckets,
ds_ids,
nccl_comm,
rs_stream,
copy_stream,
pre_div_reduce),
ag_stream_(ag_stream),
offload_stream_(offload_stream),
reload_stream_(reload_stream)
{
for (long ds_id : ds_ids_) {
ag_comm_done_events_[ds_id] =
std::make_shared<at::cuda::CUDAEvent>(cudaEventDisableTiming);
ag_comp_done_events_[ds_id] =
std::make_shared<at::cuda::CUDAEvent>(cudaEventDisableTiming);
param_use_count_[ds_id] = 0;
}
}
~Z3CustomOpExecutor() {}
void endBackward() override
{
if (param_updated_) {
for (auto& it : has_acc_grad_) {
it.second = false;
param_registry_->setValid(it.first, false);
}
}
for (auto& it : reload_buffers_) {
it.second.record_stream(at::cuda::getCurrentCUDAStream());
}
reload_buffers_.clear();
}
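// Gather a sharded parameter into output_buf. By default this issues an NCCL
// all-gather on the all-gather stream; when symmetric memory is enabled, each
// rank copies its shard into the shared workspace and pulls the remote shards
// directly, with barriers around the exchange.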
void launchAllGather(at::Tensor output_buf,
long ds_id,
c10::intrusive_ptr<c10d::symmetric_memory::SymmetricMemory> symm_mem)
{
const DSParam& param = param_registry_->getParam(ds_id);
const at::Tensor& ds_tensor = param.getDSTensor();
if (symm_mem == nullptr) {
ncclResult_t result = ncclAllGather(ds_tensor.contiguous().data_ptr(),
output_buf.data_ptr(),
ds_tensor.numel(),
get_nccl_data_type(ds_tensor.scalar_type()),
nccl_comm_,
ag_stream_);
if (result != ncclSuccess) { throw std::runtime_error("NCCL AllGather failed"); }
} else {
at::cuda::CUDAStreamGuard guard(ag_stream_);
int world_size = process_group_->getSize();
int rank = process_group_->getRank();
at::Tensor local_buf =
symm_mem->get_buffer(rank, ds_tensor.sizes(), ds_tensor.scalar_type(), 0);
local_buf.copy_(ds_tensor, true);
symm_mem->barrier(0, TIMEOUT_SYMMETRIC_MEMORY_BARRIER);
auto chunks = output_buf.flatten().chunk(world_size);
for (int step = 0; step < world_size; step++) {
int remote_rank = (rank - step + world_size) % world_size;
auto src_buf = symm_mem->get_buffer(
remote_rank, ds_tensor.sizes(), ds_tensor.scalar_type(), 0);
chunks[remote_rank].copy_(src_buf.flatten(), true);
}
symm_mem->barrier(0, TIMEOUT_SYMMETRIC_MEMORY_BARRIER);
}
param_registry_->registerGatheredParam(ds_id, output_buf);
param_registry_->setValid(ds_id, true);
}
at::Tensor allgatherParam(long ds_id,
c10::intrusive_ptr<c10d::symmetric_memory::SymmetricMemory> symm_mem)
{
if (param_registry_->isValid(ds_id)) { return param_registry_->getGatheredParam(ds_id); }
const DSParam& param = param_registry_->getParam(ds_id);
const at::Tensor& ds_tensor = param.getDSTensor();
at::Tensor output_buf = param_registry_->hasGatheredParam(ds_id)
? param_registry_->getGatheredParam(ds_id)
: torch::empty(param.getShape(), ds_tensor.options());
assert(hasKey(ag_comp_done_events_, ds_id));
ag_comp_done_events_[ds_id]->record();
ag_comp_done_events_[ds_id]->block(ag_stream_);
launchAllGather(output_buf, ds_id, symm_mem);
ag_comm_done_events_[ds_id]->record(ag_stream_);
return output_buf;
}
void prefetchParamsFused(std::vector<int64_t> ds_ids,
c10::intrusive_ptr<c10d::symmetric_memory::SymmetricMemory> symm_mem)
{
std::vector<int64_t> invalid_ds_ids;
for (const auto& ds_id : ds_ids) {
if (!param_registry_->isValid(ds_id)) { invalid_ds_ids.push_back(ds_id); }
}
std::unordered_map<long, at::Tensor> output_bufs;
for (long ds_id : invalid_ds_ids) {
const DSParam& param = param_registry_->getParam(ds_id);
if (param_registry_->hasGatheredParam(ds_id)) {
output_bufs[ds_id] = param_registry_->getGatheredParam(ds_id);
} else {
output_bufs[ds_id] = torch::empty(param.getShape(), param.getDSTensor().options());
}
}
for (long ds_id : invalid_ds_ids) {
ag_comp_done_events_[ds_id]->record();
ag_comp_done_events_[ds_id]->block(ag_stream_);
}
ncclGroupStart();
for (long ds_id : invalid_ds_ids) {
assert(hasKey(output_bufs, ds_id));
launchAllGather(output_bufs.at(ds_id), ds_id, symm_mem);
}
ncclGroupEnd();
for (long ds_id : invalid_ds_ids) { ag_comm_done_events_[ds_id]->record(ag_stream_); }
}
void releaseParam(long ds_id, long n_users)
{
const DSParam& param = param_registry_->getParam(ds_id);
assert(hasKey(param_use_count_, ds_id));
if (param_use_count_[ds_id] == 0) { param_use_count_[ds_id] = n_users; }
param_use_count_[ds_id]--;
if (param_use_count_[ds_id] == 0 && !param.isPersistent()) {
at::Tensor gathered_param = param_registry_->getGatheredParam(ds_id);
if (gathered_param.defined()) { // gathered param is undefined while profiling
const auto options = gathered_param.options();
at::Tensor empty_buffer = torch::empty({0}, options);
gathered_param.set_data(empty_buffer);
}
param_registry_->unregisterGatheredParam(ds_id);
}
}
at::Tensor waitAllgather(at::Tensor v, long ds_id)
{
assert(hasKey(ag_comm_done_events_, ds_id));
ag_comm_done_events_[ds_id]->block(at::cuda::getCurrentCUDAStream());
return v;
}
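// Flush the pending reduce-scatter tasks for one dtype. Gradients that already
// hold accumulated values are first received into a temporary buffer and then
// added to the existing gradient shard; others are written in place.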
void flushReduceBucket(at::ScalarType scalar_type) override
{
if (!hasKey(reduce_tasks_, scalar_type)) { return; }
int64_t tmp_recv_numel = 0;
for (const ReduceTask& t : reduce_tasks_.at(scalar_type)) {
auto copy_done_event = rs_copy_done_events_.at(t.getDSId());
copy_done_event->block(rs_stream_);
if (has_acc_grad_.at(t.getDSId())) {
tmp_recv_numel += param_registry_->getParam(t.getDSId()).getGradBuffer().numel();
}
}
at::Tensor tmp_recv_buf = at::Tensor();
if (tmp_recv_numel > 0) {
at::cuda::CUDAStreamGuard guard(rs_stream_);
tmp_recv_buf = torch::empty({tmp_recv_numel},
at::TensorOptions().dtype(scalar_type).device(at::kCUDA));
}
ncclGroupStart();
int64_t offset = 0;
for (const ReduceTask& t : reduce_tasks_.at(scalar_type)) {
auto recv_buf = param_registry_->getParam(t.getDSId()).getGradBuffer();
bool acc_grad = has_acc_grad_.at(t.getDSId());
if (acc_grad) {
recv_buf =
tmp_recv_buf.index({torch::indexing::Slice(offset, offset + recv_buf.numel())});
}
ncclRedOp_t op = pre_div_reduce_ ? ncclSum : ncclAvg;
if (pre_div_reduce_) {
at::cuda::CUDAStreamGuard guard(rs_stream_);
t.getSendBuf().div_(process_group_->getSize());
}
ncclResult_t result = ncclReduceScatter(t.getSendBuf().data_ptr(),
recv_buf.data_ptr(),
recv_buf.numel(),
get_nccl_data_type(scalar_type),
op,
nccl_comm_,
rs_stream_);
if (result != ncclSuccess) { throw std::runtime_error("NCCL ReduceScatter failed"); }
if (acc_grad) { offset += recv_buf.numel(); }
}
ncclGroupEnd();
{
at::cuda::CUDAStreamGuard guard(rs_stream_);
int64_t offset = 0;
for (const ReduceTask& t : reduce_tasks_.at(scalar_type)) {
bool acc_grad = has_acc_grad_.at(t.getDSId());
if (acc_grad) {
auto recv_buf = param_registry_->getParam(t.getDSId()).getGradBuffer();
recv_buf.add_(tmp_recv_buf.index(
{torch::indexing::Slice(offset, offset + recv_buf.numel())}));
offset += recv_buf.numel();
}
has_acc_grad_[t.getDSId()] = true;
}
}
reduce_buckets_->swap(scalar_type, rs_stream_, copy_stream_);
// Not very sure if this is necessary
// Want to prevent grad tensor from being released before the copy is done
auto comp_stream = at::cuda::getCurrentCUDAStream();
for (const ReduceTask& t : reduce_tasks_.at(scalar_type)) {
auto copy_done_event = rs_copy_done_events_.at(t.getDSId());
copy_done_event->block(comp_stream);
}
reduce_tasks_[scalar_type].clear();
if (tmp_recv_numel > 0) { tmp_recv_buf.record_stream(rs_stream_); }
}
at::Tensor offloadTensor(at::Tensor tensor, long id)
{
if (!hasKey(offload_events_, id)) {
offload_events_[id] = std::make_shared<at::cuda::CUDAEvent>(cudaEventDisableTiming);
offload_comp_done_events_[id] =
std::make_shared<at::cuda::CUDAEvent>(cudaEventDisableTiming);
const auto options = at::TensorOptions().pinned_memory(true).device(torch::kCPU);
offload_buffers_[id] = at::empty_like(tensor, options);
}
offload_comp_done_events_[id]->record();
offload_comp_done_events_[id]->block(offload_stream_);
{
at::cuda::CUDAStreamGuard guard(offload_stream_);
offload_buffers_.at(id).copy_(tensor, true);
}
tensor.record_stream(offload_stream_);
offload_events_[id]->record(offload_stream_);
assert(hasKey(offload_buffers_, id));
return offload_buffers_.at(id);
}
at::Tensor reloadTensor(at::Tensor tensor, long id)
{
if (!hasKey(reload_events_, id)) {
reload_events_[id] = std::make_shared<at::cuda::CUDAEvent>(cudaEventDisableTiming);
}
assert(hasKey(offload_buffers_, id));
offload_events_[id]->block(reload_stream_);
at::Tensor ten;
{
at::cuda::CUDAStreamGuard guard(reload_stream_);
assert(hasKey(offload_buffers_, id));
at::Tensor buf = offload_buffers_.at(id);
const auto options = at::TensorOptions().device(torch::kCUDA);
ten = at::empty_like(buf, options);
ten.copy_(buf, true);
reload_buffers_[id] = ten;
}
reload_events_[id]->record(reload_stream_);
return ten;
}
at::Tensor waitOffload(at::Tensor tensor, long id)
{
assert(hasKey(offload_events_, id));
offload_events_[id]->block(at::cuda::getCurrentCUDAStream());
assert(hasKey(offload_buffers_, id));
return offload_buffers_.at(id);
}
at::Tensor waitReload(at::Tensor tensor, long id)
{
assert(hasKey(reload_events_, id));
reload_events_[id]->block(at::cuda::getCurrentCUDAStream());
assert(hasKey(reload_buffers_, id));
auto ten = reload_buffers_.at(id);
// We can't release here because the tensor is still being used
// We will need "freeReloadedTensor" after the last user of the tensor to call
// ".record_stream". As it is a bit complicated, we clear the buffer and do at the end of
// the backward pass for now. reload_buffers_.erase(id);
return ten;
}
void offloadParameter(at::Tensor tensor, long ds_id) { param_registry_->offload(ds_id); }
void reloadParameter(at::Tensor tensor, long ds_id) { param_registry_->reload(ds_id); }
bool hasReloadBuffer(long id) { return hasKey(reload_buffers_, id); }
bool hasParam(long ds_id) const { return hasKey(has_acc_grad_, ds_id); }
private:
at::cuda::CUDAStream ag_stream_;
at::cuda::CUDAStream offload_stream_;
at::cuda::CUDAStream reload_stream_;
std::unordered_map<long, std::shared_ptr<at::cuda::CUDAEvent>> ag_comp_done_events_;
std::unordered_map<long, std::shared_ptr<at::cuda::CUDAEvent>> ag_comm_done_events_;
std::unordered_map<long, std::shared_ptr<at::cuda::CUDAEvent>> offload_events_;
std::unordered_map<long, std::shared_ptr<at::cuda::CUDAEvent>> offload_comp_done_events_;
std::unordered_map<long, std::shared_ptr<at::cuda::CUDAEvent>> reload_events_;
std::unordered_map<long, at::Tensor> offload_buffers_;
std::unordered_map<long, at::Tensor> reload_buffers_;
std::unordered_map<long, long> param_use_count_;
};
static at::cuda::CUDAStream ag_stream = at::cuda::getStreamFromPool(true);
static at::cuda::CUDAStream rs_stream = at::cuda::getStreamFromPool(true);
static at::cuda::CUDAStream copy_stream = at::cuda::getStreamFromPool(true);
static at::cuda::CUDAStream offload_stream = at::cuda::getStreamFromPool(true);
static at::cuda::CUDAStream reload_stream = at::cuda::getStreamFromPool(true);
void register_graph_z3(long graph_id, const std::vector<long>& ds_ids)
{
executors[graph_id] = std::make_shared<Z3CustomOpExecutor>(process_group,
param_registry,
reduce_buckets,
ds_ids,
nccl_comm,
ag_stream,
rs_stream,
copy_stream,
offload_stream,
reload_stream,
pre_div_reduce);
}
void register_z3_param(long ds_id,
const std::vector<int64_t>& ds_shape,
at::Tensor ds_tensor,
at::Tensor grad_buffer,
bool persistent)
{
param_registry->registerParam(ds_id, ds_shape, ds_tensor, grad_buffer, true, 0, persistent);
if (persistent) { param_registry->registerGatheredParam(ds_id, ds_tensor); }
}
at::Tensor allgather_param(at::Tensor param_tensor, long graph_id, long ds_id)
{
auto executor = getExecutor<Z3CustomOpExecutor>(graph_id, executors);
if (sync_before_allgather) { c10::cuda::device_synchronize(); }
auto ret = executor->allgatherParam(ds_id, symm_mem);
if (sync_after_allgather) { c10::cuda::device_synchronize(); }
return ret;
}
void set_persistent(long ds_id)
{
param_registry->setPersistent(ds_id, true);
// Allocate buffer here
// Memory fragmentation will be more severe if we allocate in forward/backward
for (auto& it : executors) {
if (it.second->hasParam(ds_id)) {
auto executor = getExecutor<Z3CustomOpExecutor>(it.first, executors);
executor->allgatherParam(ds_id, symm_mem);
}
}
}
void prefetch_params_fused(long graph_id,
const std::vector<at::Tensor> params,
const std::vector<long>& ds_ids)
{
auto executor = getExecutor<Z3CustomOpExecutor>(graph_id, executors);
executor->prefetchParamsFused(ds_ids, symm_mem);
}
void prefetch_params_fused_meta(long graph_id,
const std::vector<at::Tensor> params,
const std::vector<long>& ds_ids)
{
}
// for profiling
void invalidate_gathered_param(long ds_id)
{
const DSParam& param = param_registry->getParam(ds_id);
if (param.isPersistent()) { return; }
param_registry->unregisterGatheredParam(ds_id);
param_registry->registerGatheredParam(ds_id, at::Tensor());
}
void clear_all_gathered_params()
{
for (const auto& it : param_registry->getParams()) {
long ds_id = it.first;
const DSParam& param = param_registry->getParam(ds_id);
if (param.isPersistent()) { continue; }
if (param_registry->hasGatheredParam(ds_id)) {
param_registry->unregisterGatheredParam(ds_id);
}
}
}
at::Tensor allgather_param_meta(at::Tensor param_tensor, long graph_id, long ds_id)
{
const DSParam& param = param_registry->getParam(ds_id);
auto options = param.getDSTensor().options().device(c10::kMeta);
at::Tensor output_buf = torch::empty(param.getShape(), options);
return output_buf;
}
at::Tensor release_param(at::Tensor dummy, long graph_id, long ds_id, long n_users)
{
auto executor = getExecutor<Z3CustomOpExecutor>(graph_id, executors);
executor->releaseParam(ds_id, n_users);
if (clone_custom_op_output) { return dummy.clone(); }
return dummy;
}
at::Tensor release_param_meta(at::Tensor dummy, long graph_id, long ds_id, long n_users)
{
return dummy;
}
at::Tensor wait_allgather(at::Tensor v, long graph_id, long ds_id)
{
auto executor = getExecutor<Z3CustomOpExecutor>(graph_id, executors);
executor->waitAllgather(v, ds_id);
return v;
}
at::Tensor wait_allgather_meta(at::Tensor v, long graph_id, long ds_id) { return v; }
at::Tensor offload_tensor(at::Tensor tensor, long graph_id, long id)
{
auto executor = getExecutor<Z3CustomOpExecutor>(graph_id, executors);
return executor->offloadTensor(tensor, id);
}
at::Tensor reload_tensor(at::Tensor tensor, long graph_id, long id)
{
auto executor = getExecutor<Z3CustomOpExecutor>(graph_id, executors);
return executor->reloadTensor(tensor, id);
}
at::Tensor wait_offload(at::Tensor tensor, long graph_id, long id)
{
auto executor = getExecutor<Z3CustomOpExecutor>(graph_id, executors);
return executor->waitOffload(tensor, id);
}
at::Tensor wait_reload(at::Tensor tensor, long graph_id, long id)
{
auto executor = getExecutor<Z3CustomOpExecutor>(graph_id, executors);
if (profile && !executor->hasReloadBuffer(id)) { return tensor; }
return executor->waitReload(tensor, id);
}
at::Tensor test_call(at::Tensor a)
{
std::cout << "test_call" << std::endl;
return a;
}
void reload_parameter(at::Tensor tensor, long graph_id, long ds_id)
{
auto executor = getExecutor<Z3CustomOpExecutor>(graph_id, executors);
executor->reloadParameter(tensor, ds_id);
}
void offload_parameter(at::Tensor tensor, long graph_id, long ds_id)
{
auto executor = getExecutor<Z3CustomOpExecutor>(graph_id, executors);
executor->offloadParameter(tensor, ds_id);
}
void reload_parameter_meta(at::Tensor param_tensor, long graph_id, long ds_id) {}
void offload_parameter_meta(at::Tensor tensor, long graph_id, long ds_id) {}
} // namespace dc

csrc/compile/z3.h (new file, 48 lines)

@ -0,0 +1,48 @@
// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0
// DeepSpeed Team
#include "deepcompile.h"
#pragma once
namespace dc {
void register_graph_z3(long graph_id, const std::vector<long>& ds_ids);
void register_graph_ops_z3(long graph_id,
const std::vector<std::string>& op_names,
const std::vector<long>& n_args);
void register_bwd_graph_ops_z3(long graph_id,
const std::vector<std::string>& op_names,
const std::vector<long>& n_args);
void register_z3_param(long ds_id,
const std::vector<int64_t>& ds_shape,
at::Tensor ds_tensor,
at::Tensor grad_buffer,
bool persistent);
at::Tensor allgather_param(at::Tensor param_tensor, long graph_id, long ds_id);
void set_persistent(long ds_id);
void prefetch_params_fused(long graph_id,
const std::vector<at::Tensor> params,
const std::vector<long>& ds_ids);
void prefetch_params_fused_meta(long graph_id,
const std::vector<at::Tensor> params,
const std::vector<long>& ds_ids);
// for profiling
void invalidate_gathered_param(long ds_id);
void clear_all_gathered_params();
at::Tensor allgather_param_meta(at::Tensor param_tensor, long graph_id, long ds_id);
at::Tensor release_param(at::Tensor dummy, long graph_id, long ds_id, long n_users);
at::Tensor release_param_meta(at::Tensor dummy, long graph_id, long ds_id, long n_users);
at::Tensor wait_allgather(at::Tensor v, long graph_id, const long ds_id);
at::Tensor wait_allgather_meta(at::Tensor v, long graph_id, long ds_id);
at::Tensor offload_tensor(at::Tensor tensor, long graph_id, long id);
at::Tensor reload_tensor(at::Tensor tensor, long graph_id, long id);
at::Tensor wait_offload(at::Tensor tensor, long graph_id, long id);
at::Tensor wait_reload(at::Tensor tensor, long graph_id, long id);
void reload_parameter(at::Tensor tensor, long graph_id, long id);
void offload_parameter(at::Tensor tensor, long graph_id, long id);
void reload_parameter_meta(at::Tensor tensor, long graph_id, long id);
void offload_parameter_meta(at::Tensor tensor, long graph_id, long id);
} // namespace dc

csrc/includes/deepcompile.h (new file, 576 lines)

@ -0,0 +1,576 @@
// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0
// DeepSpeed Team
#pragma once
#define NOMINMAX // Windows idiosyncrasy
// https://stackoverflow.com/questions/4913922/possible-problems-with-nominmax-on-visual-c
#define USE_C10D_NCCL
#include <stdio.h>
#include <torch/extension.h>
#include <ATen/cuda/CUDAEvent.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>
#include <torch/csrc/cuda/nccl.h>
#include <torch/csrc/distributed/c10d/NCCLUtils.hpp>
#include <torch/csrc/distributed/c10d/ProcessGroup.hpp>
#include <torch/csrc/distributed/c10d/SymmetricMemory.hpp>
namespace dc {
template <typename K, typename V>
static bool hasKey(const std::unordered_map<K, V>& map, const K& key)
{
return map.find(key) != map.end();
}
template <typename T>
inline std::string to_string(const T& v)
{
std::stringstream ss;
ss << v;
return ss.str();
}
template <typename L>
size_t productDim(const L& dim)
{
size_t prod = 1;
for (auto d : dim) { prod *= d; }
return prod;
}
template <typename T>
std::string join_as_str(const T& v, const char* delim = ",", const size_t maxlen = 0)
{
std::stringstream ss;
if (!v.empty()) {
auto it = v.begin();
ss << to_string(*it);
it++;
for (; it != v.end(); ++it) {
if (delim) ss << delim;
ss << to_string(*it);
}
}
std::string s = ss.str();
if (maxlen > 0 && s.length() > maxlen) { s = s.substr(0, maxlen) + " ..."; }
return "[" + s + "]";
}
template <typename T>
std::string tensorPtrToString(T* ptr, size_t size, size_t str_len = 100)
{
std::vector<T> vals;
for (size_t i = 0; i < size; i++) {
vals.push_back(*ptr);
ptr++;
}
return join_as_str(vals, ",", str_len);
}
std::string tensorPtrToString(void* ptr,
size_t size,
c10::ScalarType datatype,
size_t max_elem = 20,
size_t max_str_len = 100);
std::string tensorToString(const at::Tensor& t, size_t max_elem = 20, size_t max_str_len = 100);
std::string tensorDimToString(const at::Tensor& t);
at::Tensor test_call(at::Tensor param);
extern c10::intrusive_ptr<c10d::ProcessGroup> process_group;
extern c10::intrusive_ptr<c10d::symmetric_memory::SymmetricMemory> symm_mem;
extern ncclComm_t nccl_comm;
extern bool use_symm_mem;
extern bool clone_custom_op_output;
extern bool profile;
extern bool pre_div_reduce;
extern bool sync_before_reduce; // for debugging
extern bool sync_after_reduce; // for debugging
extern bool sync_before_allgather; // for debugging
extern bool sync_after_allgather; // for debugging
std::vector<int64_t> sizes_to_int_vector(at::IntArrayRef sizes);
void enable_profiling(bool enable);
bool is_profiling();
c10::intrusive_ptr<c10d::symmetric_memory::SymmetricMemory> getSymmMemWorkspace(int64_t size);
void lazy_init_symm_memory();
ncclDataType_t get_nccl_data_type(at::ScalarType scalar_type);
void cleanup();
class ReduceTask {
public:
ReduceTask(long ds_id, at::Tensor grad, at::Tensor send_buf)
: ds_id_(ds_id), grad_(std::move(grad)), send_buf_(std::move(send_buf))
{
}
long getDSId() const { return ds_id_; }
at::Tensor getSendBuf() const { return send_buf_; }
private:
long ds_id_;
at::Tensor grad_;
at::Tensor send_buf_;
};
class ReduceBucket {
public:
ReduceBucket(int64_t size, at::ScalarType scalar_type) : size_(size), scalar_type_(scalar_type)
{
buffer_ = torch::empty({size}, at::TensorOptions().dtype(scalar_type).device(at::kCUDA));
offset_ = 0;
}
int64_t getSize() const { return size_; }
int64_t getOffset() const { return offset_; }
at::Tensor getBuffer() const { return buffer_; }
at::ScalarType getScalarType() const { return scalar_type_; }
void reserve(int64_t size)
{
if (size > size_) {
buffer_ =
torch::empty({size}, at::TensorOptions().dtype(scalar_type_).device(at::kCUDA));
size_ = size;
}
}
at::Tensor allocate(int64_t numel)
{
if (offset_ + numel > size_) {
throw std::runtime_error("Buffer size exceeds the reduce bucket size");
}
at::Tensor result = buffer_.index({torch::indexing::Slice(offset_, offset_ + numel)});
offset_ += numel;
return result;
}
bool shouldFlush(int64_t numel) { return offset_ > 0 && offset_ + numel > size_; }
void reset() { offset_ = 0; }
private:
int64_t size_;
int64_t offset_;
at::Tensor buffer_;
at::ScalarType scalar_type_;
};
class DoubleBufferedReduceBucket {
public:
DoubleBufferedReduceBucket(int64_t initial_bucket_size, bool enable_double_buffer)
: initial_bucket_size_(initial_bucket_size), enable_double_buffer_(enable_double_buffer)
{
}
void swap(at::ScalarType scalar_type,
at::cuda::CUDAStream rs_stream,
at::cuda::CUDAStream copy_stream)
{
assert(hasKey(current_buffer_, scalar_type));
assert(hasKey(current_buffer_events_, scalar_type));
current_buffer_.at(scalar_type)->reset();
current_buffer_events_.at(scalar_type)->record(rs_stream);
if (enable_double_buffer_) {
assert(hasKey(shadow_buffer_, scalar_type));
assert(hasKey(shadow_buffer_events_, scalar_type));
auto tmp = current_buffer_.at(scalar_type);
current_buffer_[scalar_type] = shadow_buffer_.at(scalar_type);
shadow_buffer_[scalar_type] = tmp;
auto tmp_event = current_buffer_events_.at(scalar_type);
current_buffer_events_[scalar_type] = shadow_buffer_events_.at(scalar_type);
shadow_buffer_events_[scalar_type] = tmp_event;
}
}
std::shared_ptr<ReduceBucket> getBuffer(at::ScalarType scalar_type)
{
if (!hasKey(current_buffer_, scalar_type)) {
current_buffer_[scalar_type] =
std::make_shared<ReduceBucket>(initial_bucket_size_, scalar_type);
current_buffer_events_[scalar_type] =
std::make_shared<at::cuda::CUDAEvent>(cudaEventDisableTiming);
if (enable_double_buffer_) {
shadow_buffer_[scalar_type] =
std::make_shared<ReduceBucket>(initial_bucket_size_, scalar_type);
shadow_buffer_events_[scalar_type] =
std::make_shared<at::cuda::CUDAEvent>(cudaEventDisableTiming);
}
}
return current_buffer_.at(scalar_type);
}
std::shared_ptr<at::cuda::CUDAEvent> getEvent(at::ScalarType scalar_type)
{
assert(hasKey(current_buffer_events_, scalar_type));
return current_buffer_events_.at(scalar_type);
}
void clear()
{
current_buffer_.clear();
shadow_buffer_.clear();
current_buffer_events_.clear();
shadow_buffer_events_.clear();
}
private:
int64_t initial_bucket_size_;
bool enable_double_buffer_;
std::unordered_map<at::ScalarType, std::shared_ptr<ReduceBucket>> current_buffer_;
std::unordered_map<at::ScalarType, std::shared_ptr<ReduceBucket>> shadow_buffer_;
std::unordered_map<at::ScalarType, std::shared_ptr<at::cuda::CUDAEvent>> current_buffer_events_;
std::unordered_map<at::ScalarType, std::shared_ptr<at::cuda::CUDAEvent>> shadow_buffer_events_;
};
class DSParam {
public:
DSParam(long id,
std::vector<int64_t> ds_shape,
at::Tensor ds_tensor,
at::Tensor grad_buffer,
bool partitioned,
int64_t offset, // for Z1
bool persistent // for Z3
)
: id_(id),
shape_(std::move(ds_shape)),
ds_tensor_(ds_tensor),
grad_buffer_(grad_buffer),
partitioned_(partitioned),
offset_(offset),
persistent_(persistent),
offload_stream_(at::cuda::getStreamFromPool()),
reload_stream_(at::cuda::getStreamFromPool())
{
}
long getId() const { return id_; }
std::vector<int64_t> getShape() const { return shape_; }
at::Tensor getDSTensor() const
{
// If a reload event exists, wait for it to complete and return the reloaded tensor (if defined)
if (reload_done_event_) {
if (!reload_done_event_->query()) {
reload_done_event_->block(at::cuda::getCurrentCUDAStream());
}
if (ds_reload_tensor_.defined()) { return ds_reload_tensor_; }
}
// Otherwise, if an offload event exists, wait for it to complete
if (offload_done_event_) {
if (!offload_done_event_->query()) {
offload_done_event_->block(at::cuda::getCurrentCUDAStream());
}
}
return ds_tensor_;
}
at::Tensor getGradBuffer() const { return grad_buffer_; }
bool isPartitioned() const { return partitioned_; }
int64_t getOffset() const { return offset_; }
void setPersistent(bool persistent) { persistent_ = persistent; }
bool isPersistent() const { return persistent_; }
void offload()
{
// If a reloaded tensor exists, offload its data back to ds_tensor_
if (ds_reload_tensor_.defined()) {
auto comp_stream = at::cuda::getCurrentCUDAStream();
comp_done_event_ = std::make_shared<at::cuda::CUDAEvent>(cudaEventDisableTiming);
// Record completion and wait on the offload stream
comp_done_event_->record(comp_stream);
comp_done_event_->block(offload_stream_);
offload_done_event_ = std::make_shared<at::cuda::CUDAEvent>(cudaEventDisableTiming);
{
at::cuda::CUDAStreamGuard guard(offload_stream_);
ds_tensor_.copy_(ds_reload_tensor_, /*non_blocking=*/true);
ds_reload_tensor_.reset(); // Clear the reloaded tensor
offload_done_event_->record(offload_stream_);
}
// Reset the reload event to indicate that no valid reload is present.
if (reload_done_event_) { reload_done_event_.reset(); }
}
}
void reload()
{
// Reload only if the current ds_tensor_ is on CPU
if (ds_tensor_.device().is_cpu()) {
auto comp_stream = at::cuda::getCurrentCUDAStream();
comp_done_event_ = std::make_shared<at::cuda::CUDAEvent>(cudaEventDisableTiming);
// Record and wait on the reload stream
comp_done_event_->record(comp_stream);
comp_done_event_->block(reload_stream_);
reload_done_event_ = std::make_shared<at::cuda::CUDAEvent>(cudaEventDisableTiming);
{
at::cuda::CUDAStreamGuard guard(reload_stream_);
ds_reload_tensor_ =
at::empty_like(ds_tensor_, ds_tensor_.options().device(torch::kCUDA));
ds_reload_tensor_.copy_(ds_tensor_, /*non_blocking=*/true);
reload_done_event_->record(reload_stream_);
}
// Reset offload_done_event if it exists to clear any stale offload state.
if (offload_done_event_) { offload_done_event_.reset(); }
}
}
private:
long id_;
std::vector<int64_t> shape_;
at::Tensor ds_tensor_;
at::Tensor ds_reload_tensor_;
at::Tensor grad_buffer_;
bool partitioned_;
int64_t offset_; // for Z1
bool persistent_; // for Z3
mutable bool is_reloaded = false;
at::cuda::CUDAStream offload_stream_;
at::cuda::CUDAStream reload_stream_;
std::shared_ptr<at::cuda::CUDAEvent> comp_done_event_;
std::shared_ptr<at::cuda::CUDAEvent> offload_done_event_;
std::shared_ptr<at::cuda::CUDAEvent> reload_done_event_;
};
class DSParamRegistry {
public:
DSParamRegistry() {}
~DSParamRegistry() {}
void registerParam(long ds_id,
const std::vector<int64_t>& ds_shape,
at::Tensor ds_tensor,
at::Tensor grad_buffer,
bool partitioned,
int64_t offset, // for Z1
bool persistent // for Z3
)
{
grad_buffer.zero_();
params_.emplace(
ds_id,
DSParam(ds_id, ds_shape, ds_tensor, grad_buffer, partitioned, offset, persistent));
valid_[ds_id] = false;
}
void registerGatheredParam(long ds_id, at::Tensor ds_tensor)
{
gathered_params_.emplace(ds_id, ds_tensor);
}
void unregisterGatheredParam(long ds_id)
{
assert(hasKey(gathered_params_, ds_id));
gathered_params_.erase(ds_id);
valid_[ds_id] = false;
}
const std::unordered_map<long, DSParam>& getParams() const { return params_; }
const DSParam& getParam(long ds_id) const { return params_.at(ds_id); }
const size_t getNumParams() const { return params_.size(); }
const at::Tensor& getGatheredParam(long ds_id) const
{
assert(hasKey(gathered_params_, ds_id));
return gathered_params_.at(ds_id);
}
bool hasGatheredParam(long ds_id) const { return hasKey(gathered_params_, ds_id); }
void setPersistent(long ds_id, bool persistent) { params_.at(ds_id).setPersistent(persistent); }
void offload(long ds_id) { params_.at(ds_id).offload(); }
void reload(long ds_id) { params_.at(ds_id).reload(); }
void setValid(long ds_id, bool valid) { valid_[ds_id] = valid; }
bool isValid(long ds_id) const
{
assert(hasKey(valid_, ds_id));
return valid_.at(ds_id);
}
private:
std::unordered_map<long, DSParam> params_;
std::unordered_map<long, at::Tensor> gathered_params_;
std::unordered_map<long, bool> valid_;
};
class CustomOpExecutor {
public:
CustomOpExecutor(c10::intrusive_ptr<c10d::ProcessGroup> process_group,
std::shared_ptr<DSParamRegistry> param_registry,
std::shared_ptr<DoubleBufferedReduceBucket> reduce_buckets,
std::vector<long> ds_ids,
ncclComm_t nccl_comm,
at::cuda::CUDAStream rs_stream,
at::cuda::CUDAStream copy_stream,
bool pre_div_reduce)
: process_group_(process_group),
param_registry_(std::move(param_registry)),
reduce_buckets_(std::move(reduce_buckets)),
ds_ids_(std::move(ds_ids)),
nccl_comm_(nccl_comm),
rs_stream_(rs_stream),
copy_stream_(copy_stream),
pre_div_reduce_(pre_div_reduce)
{
for (long ds_id : ds_ids_) {
has_acc_grad_[ds_id] = false;
rs_comp_done_events_[ds_id] =
std::make_shared<at::cuda::CUDAEvent>(cudaEventDisableTiming);
rs_copy_done_events_[ds_id] =
std::make_shared<at::cuda::CUDAEvent>(cudaEventDisableTiming);
}
reduce_counter_ = ds_ids_.size();
}
~CustomOpExecutor() {}
virtual void startForward() {}
virtual void endForward() {}
virtual void startBackward(bool update) { param_updated_ = update; }
virtual void endBackward() {}
at::Tensor reduceGrad(at::Tensor grad_tensor, long ds_id)
{
int world_size = process_group_->getSize();
const DSParam& param = param_registry_->getParam(ds_id);
const auto scalar_type = grad_tensor.scalar_type();
std::shared_ptr<ReduceBucket> reduce_bucket = reduce_buckets_->getBuffer(scalar_type);
auto comp_stream = at::cuda::getCurrentCUDAStream();
if (reduce_bucket->shouldFlush(grad_tensor.numel())) {
int rank = process_group_->getRank();
flushReduceBucket(scalar_type);
// reduce_bucket is swapped in flushReduceBucket if double buffering is enabled
reduce_bucket = reduce_buckets_->getBuffer(scalar_type);
}
if (grad_tensor.numel() > reduce_bucket->getSize()) {
// extend buckets
at::cuda::stream_synchronize(rs_stream_);
reduce_bucket->reserve(grad_tensor.numel());
}
at::Tensor reduce_in_buffer = reduce_bucket->allocate(grad_tensor.numel());
// This ensures the order of reduce_scatter -> copy
// Without this block, copy may start while reduce_scatter is still running
reduce_buckets_->getEvent(scalar_type)->block(comp_stream);
auto copy_src = grad_tensor.contiguous().view({-1}).detach();
// keep references to copy src
reduce_tasks_[scalar_type].emplace_back(ds_id, copy_src, reduce_in_buffer);
// computation must be done before copy
rs_comp_done_events_[ds_id]->record(comp_stream);
rs_comp_done_events_[ds_id]->block(copy_stream_);
{
at::cuda::CUDAStreamGuard guard(copy_stream_);
reduce_in_buffer.copy_(copy_src, true);
rs_copy_done_events_[ds_id]->record(copy_stream_);
}
reduce_counter_--;
if (reduce_counter_ == 0) {
flushAllReduceBuckets();
reduce_counter_ = ds_ids_.size();
// This synchronization ensures that all reduce calls are done before the optimizer's step.
at::cuda::stream_synchronize(rs_stream_);
endBackward();
}
return at::Tensor();
}
bool hasParam(long ds_id) const { return hasKey(has_acc_grad_, ds_id); }
protected:
c10::intrusive_ptr<c10d::ProcessGroup> process_group_;
std::shared_ptr<DSParamRegistry> param_registry_;
std::shared_ptr<DoubleBufferedReduceBucket> reduce_buckets_;
std::vector<long> ds_ids_;
ncclComm_t nccl_comm_;
at::cuda::CUDAStream rs_stream_;
at::cuda::CUDAStream copy_stream_;
std::unordered_map<long, std::shared_ptr<at::cuda::CUDAEvent>> rs_comp_done_events_;
std::unordered_map<long, std::shared_ptr<at::cuda::CUDAEvent>> rs_copy_done_events_;
size_t reduce_counter_ = 0;
bool param_updated_ = false;
std::unordered_map<at::ScalarType, std::vector<ReduceTask>> reduce_tasks_;
std::unordered_map<long, bool> has_acc_grad_;
bool pre_div_reduce_;
virtual void flushReduceBucket(at::ScalarType scalar_type) = 0;
void flushAllReduceBuckets()
{
for (const auto& it : reduce_tasks_) { flushReduceBucket(it.first); }
}
};
template <typename T, typename U>
std::shared_ptr<T> getExecutor(long graph_id,
const std::unordered_map<long, std::shared_ptr<U>>& executors)
{
assert(hasKey(executors, graph_id));
if (auto executor = std::dynamic_pointer_cast<T>(executors.at(graph_id))) { return executor; }
throw std::runtime_error("Invalid executor type");
}
extern std::shared_ptr<DSParamRegistry> param_registry;
extern std::unordered_map<long, std::shared_ptr<CustomOpExecutor>> executors;
extern std::shared_ptr<DoubleBufferedReduceBucket> reduce_buckets;
at::Tensor reduce_grad(at::Tensor grad_tensor, long graph_id, long ds_id);
at::Tensor reduce_grad_meta(at::Tensor grad_tensor, long graph_id, long ds_id);
void free_tensors(std::vector<at::Tensor> tensors);
void free_tensors_meta(std::vector<at::Tensor> tensors);
void init(c10::intrusive_ptr<c10d::ProcessGroup> pg,
int64_t initial_reduce_bucket_size,
bool enable_double_buffer,
bool _use_symm_mem,
bool _clone_custom_op_output,
bool _sync_before_reduce,
bool _sync_after_reduce,
bool _sync_before_allgather,
bool _sync_after_allgather);
void reset();
void cleanup();
void start_forward();
void end_forward();
void start_backward(bool update);
} // namespace dc
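
The reduce path in this header stages gradients into a fixed-size bucket (ReduceBucket) and flushes it whenever the next gradient would overflow it; DoubleBufferedReduceBucket optionally swaps between two such buckets so staging for the next flush can start while the previous reduce-scatter is still in flight. Below is a minimal, dependency-free Python sketch of just that bookkeeping, with made-up names and sizes; the real implementation additionally coordinates CUDA streams, events and NCCL.

# A dependency-free sketch of the bucketing bookkeeping above, for illustration only.
# Hypothetical names; the real ReduceBucket holds a CUDA tensor and the flush issues
# an NCCL reduce-scatter on a dedicated stream.
class ToyReduceBucket:
    def __init__(self, size):
        self.size = size          # capacity in elements (initial reduce bucket size)
        self.offset = 0           # elements currently staged
        self.pending = []         # grads staged for the next flush

    def should_flush(self, numel):
        # Same rule as ReduceBucket::shouldFlush: flush only when the bucket is
        # non-empty and the incoming grad would overflow it.
        return self.offset > 0 and self.offset + numel > self.size

    def add(self, name, numel):
        if self.should_flush(numel):
            self.flush()
        if numel > self.size:     # mirror ReduceBucket::reserve for oversized grads
            self.size = numel
        self.pending.append(name)
        self.offset += numel

    def flush(self):
        print(f"reduce-scatter {self.pending} ({self.offset} elements)")
        self.pending.clear()
        self.offset = 0


bucket = ToyReduceBucket(size=8)
for name, numel in [("g0", 3), ("g1", 4), ("g2", 5), ("g3", 10)]:
    bucket.add(name, numel)
bucket.flush()   # flush the remainder, as reduceGrad does once all grads have arrived
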


@ -621,6 +621,17 @@ def initialize_mesh_device(mesh_shape, mesh_dim_names):
return mesh_device
def enable_symm_mem_for_group(group_name: str):
global cdb
assert cdb is not None and cdb.is_initialized(
), 'DeepSpeed backend not set, please initialize it using init_process_group()'
if hasattr(cdb, 'enable_symm_mem_for_group'):
cdb.enable_symm_mem_for_group(group_name)
else:
raise RuntimeError(f"Backend {cdb.name} does not support symmetric memory initialization")
# Main DeepSpeed Comms. public API.
def init_distributed(dist_backend=None,
auto_mpi_discovery=True,


@ -409,6 +409,13 @@ class TorchBackend(Backend):
mesh_shape,
mesh_dim_names=mesh_dim_names)
def enable_symm_mem_for_group(self, group_name):
if not required_torch_version(min_version=2.5):
raise RuntimeError(f"Torch version must be 2.5 or higher to use symmetric memory. "
f"Current version: {torch.__version__}")
from torch.distributed._symmetric_memory import enable_symm_mem_for_group
return enable_symm_mem_for_group(group_name)
# This will become a light-weight wrapper around torch.distributed functions
# TODO: create some example to show how this wrapper can help profile communication


@ -0,0 +1,4 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team


@ -0,0 +1,279 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
from typing import Dict, List, Callable
import time
import gc
import torch
from torch.fx import Graph, GraphModule
try:
import torch.utils._pytree as pytree
import torch._dynamo
import torch._inductor.scheduler
from functorch.compile import make_boxed_func
from torch._functorch.aot_autograd import aot_module_simplified
from torch._subclasses.fake_tensor import unset_fake_temporarily
except ImportError:
pass
from deepspeed.accelerator import get_accelerator
from .fx import add_free_activations
from .graph_param import DSGraphParamManager
from .profilers import ProfilingResult
from .profilers.graph_profile import MemoryProfilingInterpreter
from .patch_compiled_func import patch_compiled_func, unpatch_compiled_func, get_backward_inputs
from .util import get_input_nodes, get_activation_node_names, get_index_by_graph_id, get_deepcompile_handle, log_rank0
from .partitioner import get_wrapped_partitioner
from .inductor import register_custom_ops, patch_create_aot_dispatcher_function
remaining_schedule = None
next_pass_step = -1
next_passes = None
current_passes = None
param_manager: Dict[int, DSGraphParamManager] = {}
graph_order = []
profiling_results: Dict[int, ProfilingResult] = {}
opt_pass_times = []
opt_passes = {}
fwd_real_inputs = []
remaining_bwd_compile_count = 0
def register_compile_pass(name: str, opt_pass_fn):
opt_passes[name] = opt_pass_fn
def init_schedule(schedule):
assert isinstance(schedule, list), f"schedule should be a list, but got {type(schedule)}"
for step, passes in schedule:
assert isinstance(step, int), f"Each step in schedule should be an integer, but got {type(step)}"
assert isinstance(passes, list), f"Passes at a certain step should be a list, but got {type(passes)}"
global remaining_schedule
remaining_schedule = schedule
def launch_compile_passes(global_steps: int):
global next_pass_step, next_passes
if len(remaining_schedule) > 0 and global_steps == remaining_schedule[0][0]:
_, next_passes = remaining_schedule.pop(0)
log_rank0(f"Launching compile passes: global_steps={global_steps} passes={next_passes}", True)
torch._dynamo.reset()
get_deepcompile_handle().reset()
patch_compiled_func()
graph_order.clear()
profiling_results.clear()
param_manager.clear()
def set_time_and_tensor_size(graph_id, graph: Graph, mem, bwd, profiling_results):
node_time = []
tensor_sizes = []
for n in graph.nodes:
node_time.append((n.name, n.meta["device_time"] if "device_time" in n.meta else 0.0,
n.meta["wall_time"] if "wall_time" in n.meta else 0.0))
tensor_sizes.append((n.name, n.meta["tensor_size"] if "tensor_size" in n.meta else 0))
if bwd:
profiling_results[graph_id].bwd_graph = graph
profiling_results[graph_id].bwd_time = node_time
profiling_results[graph_id].bwd_tensor_sizes = tensor_sizes
profiling_results[graph_id].bwd_mem = mem
else:
profiling_results[graph_id].fwd_graph = graph
profiling_results[graph_id].fwd_time = node_time
profiling_results[graph_id].fwd_tensor_sizes = tensor_sizes
profiling_results[graph_id].fwd_mem = mem
def run_opt_passes(opt_passes: List[Callable],
gm: GraphModule,
graph_id: int,
graph_order: List[int],
profiling_results,
create_inputs_fn,
mem_budget: float,
param_manager,
bwd: bool,
debug_log=False) -> None:
with unset_fake_temporarily():
get_accelerator().synchronize()
gc.collect()
get_accelerator().empty_cache()
for i, opt_pass_fn in enumerate(opt_passes):
log_rank0(f"Running opt pass {i} for graph {graph_id}. bwd={bwd}", enable=debug_log)
gm_new = opt_pass_fn(gm, graph_id, graph_order, profiling_results, create_inputs_fn, mem_budget, param_manager,
bwd)
if gm_new is not None:
gm = gm_new
gm.graph.lint()
gm.recompile()
mem_prof = MemoryProfilingInterpreter(gm, debug_log=debug_log)
mem_prof.run(*create_inputs_fn())
mem = [(name, current_alloc, delta, peak) for name, current_alloc, delta, peak in mem_prof.mem_record]
set_time_and_tensor_size(graph_id, gm.graph, mem, bwd, profiling_results)
with unset_fake_temporarily():
get_accelerator().synchronize()
gc.collect()
get_accelerator().empty_cache()
def make_backend(backend, compile_kwargs={}, free_activation=False, debug_log=False):
register_custom_ops()
def backend_fn(gm: GraphModule, real_inputs):
graph_id = id(gm.graph)
needs_backward = pytree.tree_any(lambda x: x.requires_grad if torch.is_tensor(x) else False, real_inputs)
global graph_order
graph_order.append((graph_id, needs_backward))
z3_partition = any(hasattr(v, "ds_id") for v in real_inputs)
if z3_partition:
param_indices = [(i, input_val.ds_id, input_val.ds_shape) for i, input_val in enumerate(real_inputs)
if isinstance(input_val, torch.nn.Parameter)]
else:
assert all(hasattr(v, "param_id") for v in real_inputs
if isinstance(v, torch.nn.Parameter)), "All param inputs should have param_id"
param_indices = [(i, input_val.param_id, input_val.shape) for i, input_val in enumerate(real_inputs)
if isinstance(input_val, torch.nn.Parameter)]
global fwd_real_inputs
fwd_real_inputs.append(real_inputs)
global profiling_results
if graph_id not in profiling_results:
profiling_results[graph_id] = ProfilingResult()
profiling_results[graph_id].param_indices = param_indices
profiling_results[graph_id].needs_backward = needs_backward
def make_fw_graph(gm, sample_inputs):
time_start = time.time()
graph_index = len(graph_order) - 1
real_inputs = fwd_real_inputs.pop(0)
param_manager[graph_id] = DSGraphParamManager(gm.graph, real_inputs, param_indices)
real_inputs_with_rng = real_inputs + sample_inputs[len(real_inputs):]
run_opt_passes(
opt_passes=next_passes,
gm=gm,
graph_id=graph_id,
graph_order=graph_order,
profiling_results=profiling_results,
create_inputs_fn=lambda: real_inputs_with_rng,
mem_budget=.0, # unused
param_manager=param_manager,
bwd=False,
debug_log=debug_log)
if needs_backward:
global remaining_bwd_compile_count
remaining_bwd_compile_count += 1
opt_pass_times.append(("fwd", graph_index, graph_id, time.time() - time_start))
log_rank0(
f"Fwd end {graph_index} graph_id={graph_id} alloc_mem={get_accelerator().memory_allocated()} graph={gm.graph}",
enable=debug_log)
return gm.graph
def make_bw_graph(gm, sample_inputs):
time_start = time.time()
graph_index = get_index_by_graph_id(graph_order, graph_id)
log_rank0(
f"Bwd start {graph_index} graph_id={graph_id} alloc_mem={get_accelerator().memory_allocated()} graph={gm.graph}",
enable=debug_log)
bwd_inputs_stack = get_backward_inputs()
if len(bwd_inputs_stack) == 0:
# dynamo calls bw compiler ahead of time when symints are saved for backward. See the details for aot_dispatch_autograd in jit_compile_runtime_wrappers.
# As we currently use the actual bwd input values in the bw compiler, we return None to skip the compilation there.
# This would need to be handled properly in the future.
return None
bwd_real_inputs = bwd_inputs_stack.pop()
run_opt_passes(
opt_passes=next_passes,
gm=gm,
graph_id=graph_id,
graph_order=graph_order,
profiling_results=profiling_results,
create_inputs_fn=lambda: tuple(bwd_real_inputs),
mem_budget=.0, # unused
param_manager=param_manager,
bwd=True,
debug_log=debug_log)
# assert graph_id in param_manager, f"Graph {graph_id} not found in param_manager"
if free_activation:
param_nodes_bw, _ = param_manager[graph_id].get_bwd_mapping(gm.graph)
param_names = [n.name for n in param_nodes_bw]
non_param_input_names = [n.name for n in get_input_nodes(gm.graph) if n.name not in param_names]
add_free_activations(graph_id, gm.graph,
get_activation_node_names(gm.graph, param_nodes_bw, non_param_input_names))
global remaining_bwd_compile_count
remaining_bwd_compile_count -= 1
if remaining_bwd_compile_count == 0:
unpatch_compiled_func()
log_rank0(
f"Bwd end {graph_index} graph_id={graph_id} alloc_mem={get_accelerator().memory_allocated()} graph={gm.graph}",
enable=debug_log)
gm.recompile()
opt_pass_times.append(("bwd", graph_index, graph_id, time.time() - time_start))
return gm.graph
if backend == "eager":
def make_compiler_fn(make_graph_fn):
def compiler_fn(gm, sample_inputs):
return None if make_graph_fn(gm, sample_inputs) is None else make_boxed_func(gm.forward)
return compiler_fn
aot_mod = aot_module_simplified(gm,
real_inputs,
fw_compiler=make_compiler_fn(make_fw_graph),
bw_compiler=make_compiler_fn(make_bw_graph),
partition_fn=get_wrapped_partitioner(param_indices))
return torch._dynamo.optimize(**compile_kwargs)(aot_mod)
elif backend == "inductor":
patch_create_aot_dispatcher_function(graph_id, z3_partition, make_fw_graph, make_bw_graph, real_inputs,
param_indices, param_manager)
from .partitioner import get_wrapped_choose_saved_values_set
torch._functorch.partitioners.choose_saved_values_set = get_wrapped_choose_saved_values_set(param_indices)
return torch._inductor.compile(gm, real_inputs)
raise ValueError(f"Unsupported backend {backend}")
return backend_fn
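
init_schedule() above expects the optimization schedule as a list of (global_step, [passes]) pairs; launch_compile_passes() then resets dynamo and activates the next entry once training reaches that step. A hedged sketch of such a schedule is shown below, using the ZeRO-3 passes that the initialization code further down imports; the module paths deepspeed.compile.backend and deepspeed.compile.passes are assumptions inferred from the relative imports in these files.

from deepspeed.compile.backend import init_schedule
from deepspeed.compile.passes import zero3_compile, prefetch, selective_gather

schedule = [
    # From step 0: only insert the ZeRO-3 allgather/release ops into the graph.
    (0, [zero3_compile.add_z3_gather_release]),
    # After a few warmup steps, once profiling data is available: add proactive
    # prefetching and selective unsharding on top of the gather/release pass.
    (5, [zero3_compile.add_z3_gather_release, prefetch.schedule_prefetch,
         selective_gather.selective_gather]),
]
init_schedule(schedule)
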


@ -0,0 +1,46 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
from deepspeed.runtime.config_utils import DeepSpeedConfigModel
class CompileConfig(DeepSpeedConfigModel):
""" Configure compile settings """
deepcompile: bool = False
""" Turn on/off the DeepCompile mode """
free_activation: bool = False
""" Turn on/off the free activation mode """
offload_activation: bool = False
""" Turn on/off the activation offloading """
offload_opt_states: bool = False
""" Turn on/off the optimizer states offloading """
double_buffer: bool = True
""" Turn on/off the double buffering """
symmetric_memory: bool = False
""" Turn on/off the symmetric memory """
debug_log: bool = False
""" Turn on/off the graph dumping """
offload_parameters: bool = False
""" Turn on/off the parameter offloading """
sync_before_reduce: bool = False
""" Turn on/off the sync before reduce """
sync_after_reduce: bool = False
""" Turn on/off the sync after reduce """
sync_before_allgather: bool = False
""" Turn on/off the sync before allgather """
sync_after_allgather: bool = False
""" Turn on/off the sync after allgather """

deepspeed/compile/fx.py
@ -0,0 +1,139 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
from typing import Callable, Any, List
from collections import defaultdict
import torch
from torch.fx import Node, Graph
from .util import get_last_uses
def get_output_node(graph: Graph):
for v in graph.nodes:
if v.target == "output":
return v
raise ValueError("No output node found")
def move_primals_to_head(graph: Graph):
# Move primals to the head of the graph
primals = [n for n in graph.nodes if n.op == "placeholder"]
non_primals = [n for n in graph.nodes if n.op != "placeholder"]
all_nodes = primals + non_primals
new_graph = Graph()
env = {}
for node in all_nodes:
new_node = new_graph.node_copy(node, lambda n: env[n.name])
env[node.name] = new_node
new_graph.lint()
return new_graph
def add_args_process(graph: Graph,
node: Node,
fn: Callable[..., Any],
extra_args: List[int] = [],
name=None,
meta={}) -> List[Node]:
# Apply fn to all args of node
new_nodes = []
with graph.inserting_before(node):
target_args = [arg for arg in node.args if isinstance(arg, Node)]
for arg in target_args:
new_node = graph.create_node('call_function', fn, (arg, ) + tuple(extra_args), name=name)
for k, v in meta.items():
new_node.meta[k] = v
node.replace_input_with(arg, new_node)
new_nodes.append(new_node)
return new_nodes
def add_postprocess(graph: Graph,
node: Node,
fn: Callable[..., Any],
extra_args: List[int] = [],
name=None,
meta={}) -> Node:
# https://github.com/pytorch/examples/blob/main/fx/wrap_output_dynamically.py
with graph.inserting_after(node):
args = (node, )
for a in extra_args: # To add ds_id
args += (a, )
node_users = node.users.keys()
new_node = graph.create_node('call_function', fn, args, {}, name=name)
users = {}
for u in node_users:
if u != new_node:
users[u] = (node, new_node)
for u, (old_in, new_in) in users.items():
u.replace_input_with(old_in, new_in)
for k, v in meta.items():
new_node.meta[k] = v
return new_node
def _make_node_meta(node: Node, ds_id: int, comm: bool):
meta = {"param_name": node.name, "ds_id": ds_id, "comm": comm}
if "tensor_meta" in node.meta:
meta["tensor_meta"] = node.meta["tensor_meta"]
return meta
def add_free_activations(graph_id: int, graph: Graph, activation_node_names: List[str]):
node_to_last_use, _ = get_last_uses(graph)
activation_nodes_set = set([n for n in graph.nodes if n.op == "placeholder" and n.name in activation_node_names])
offload_id_to_node = {}
node_to_wait_reload = {}
for node in graph.nodes:
if node.target == torch.ops.dc.reload_tensor.default:
offload_act = node.args[0]
# node_to_offload_id[offload_act] = node.args[2]
offload_id_to_node[node.args[2]] = offload_act
elif node.target == torch.ops.dc.wait_reload.default:
offload_id = node.args[2]
node_to_wait_reload[offload_id_to_node[offload_id]] = node
activation_nodes_set = set(node_to_wait_reload[n] if n in node_to_wait_reload else n for n in activation_nodes_set)
last_user_to_uses = defaultdict(list)
for node, last_user in node_to_last_use.items():
last_user_to_uses[last_user].append(node)
def _should_free(node: Node) -> bool:
if not hasattr(node, "meta"):
return False
if not "tensor_meta" in node.meta:
return False
return True
def free_tensors(tensors: List[torch.Tensor]):
for a in tensors:
if a.numel() > 10_000_000:
a.data = torch.empty([0], device=a.device, dtype=a.dtype)
for last_user, used_nodes in last_user_to_uses.items():
activation_args = [an for an in used_nodes if an in activation_nodes_set and _should_free(an)]
if len(activation_args) == 0:
continue
node_name = f"free_activations_{[n.name for n in used_nodes]}"
with graph.inserting_after(last_user):
args = (activation_args, )
graph.create_node('call_function', torch.ops.dc.free_tensors.default, args, {}, name=node_name)
# Python version for debugging
# graph.create_node('call_function', free_tensors, args, {}, name=node_name)
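
add_postprocess() above is the building block the compile passes use to splice an op in right after an existing node and reroute all of that node's users to the new result. The snippet below is a small, self-contained sketch; the tiny traced module and the choice of aten.relu as the inserted op are purely illustrative, and the import path follows the file header deepspeed/compile/fx.py.

import torch
from torch.fx import symbolic_trace
from deepspeed.compile.fx import add_postprocess

class Tiny(torch.nn.Module):
    def forward(self, x):
        return x + 1

gm = symbolic_trace(Tiny())
add_node = next(n for n in gm.graph.nodes if n.op == "call_function")

# Insert relu right after the add; every former user of the add (here, the output)
# now reads from the new node instead.
add_postprocess(gm.graph, add_node, torch.ops.aten.relu.default)
gm.graph.lint()
gm.recompile()
print(gm.code)  # forward now computes relu(x + 1)
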


@ -0,0 +1,84 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
from dataclasses import dataclass, field
from typing import Any, Dict, List, Tuple
from functools import reduce
import torch
from torch.fx import Graph, Node
from .fx import get_output_node
from .util import get_param_nodes
@dataclass
class DSGraphParam:
name: str
shape: torch.Size
dtype: torch.dtype
device: torch.device
node: Node
allgather_node: Node
release_node: Node
param: torch.Tensor
numel: int = field(init=False)
def __post_init__(self):
self.numel = reduce(lambda x, y: x * y, self.shape)
class DSGraphParamManager:
def __init__(self, fw_graph: Graph, sample_inputs: Any, index_to_ds_ids: List[Tuple[int, int, int]]):
self._fw_graph = fw_graph
self._bw_graph = None
self._params: Dict[str, DSGraphParam] = {}
self._param_name_to_grad: Dict[str, Node] = {}
self._ds_ids: Dict[str, int] = {}
param_nodes = get_param_nodes(fw_graph, index_to_ds_ids)
self._param_names = [pn.name for pn in param_nodes]
self._param_indices = [i for i, _, _ in index_to_ds_ids]
param_inputs = [sample_inputs[i] for i, _, _ in index_to_ds_ids]
ds_ids = [ds_id for _, ds_id, _ in index_to_ds_ids]
ds_shapes = [ds_shape for _, _, ds_shape in index_to_ds_ids]
for pn, pi, ds_id, ds_shape in zip(param_nodes, param_inputs, ds_ids, ds_shapes):
self._params[pn.name] = DSGraphParam(name=pn.name,
shape=ds_shape,
dtype=pi.dtype,
device=pi.device,
node=pn,
allgather_node=None,
release_node=None,
param=pi)
self._ds_ids[pn.name] = ds_id
def get_bwd_mapping(self, bw_graph: Graph):
self._bw_graph = bw_graph
output_node = get_output_node(bw_graph)
param_nodes_bw = [n for n in self._bw_graph.nodes if n.name in self.param_names]
grad_outputs = [output_node.args[0][i] for i in self._param_indices]
param_name_to_grad = {param_name: grad for param_name, grad in zip(self.param_names, grad_outputs)}
return param_nodes_bw, param_name_to_grad
@property
def param_names(self) -> List[str]:
return self._param_names
@property
def params(self) -> Dict[str, DSGraphParam]:
return self._params
@property
def ds_ids(self) -> Dict[str, int]:
return self._ds_ids
def get_grad_name(self, param_name) -> str:
assert self._param_name_to_grad is not None, "Backward graph is not added yet"
return self._param_name_to_grad[param_name]


@ -0,0 +1,214 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
import torch
try:
import torch.utils._pytree as pytree
from torch._functorch.aot_autograd import create_aot_dispatcher_function
from torch._inductor.lowering import register_lowering, fallbacks, add_needs_realized_inputs
from torch._inductor.ir import TensorBox, FallbackKernel, Layout, IRNode
from torch._inductor.virtualized import V
from torch._inductor.scheduler import Scheduler
original_create_aot_dispatcher_function = create_aot_dispatcher_function
except ImportError:
pass
from .util import get_input_nodes
from .graph_param import DSGraphParamManager
def patch_compiler(original_compiler, dc_compiler, z3_partition: bool, graph_id, graph_param_manager, bwd: bool):
def wrapped_compiler(gm, fake_inputs):
mod_graph = dc_compiler(gm, fake_inputs)
# For symint case
if mod_graph is None:
return None
if z3_partition:
# Inductor validates the input sizes estimated from the first trace, where the ds tensors are materialized.
# We need to patch the input tensors to avoid the validation error.
patched_inputs = []
if bwd:
param_nodes_bw, _ = graph_param_manager[graph_id].get_bwd_mapping(gm.graph)
param_names = [n.name for n in param_nodes_bw]
else:
param_names = graph_param_manager[graph_id].param_names
input_nodes = get_input_nodes(gm.graph)
for in_node, in_v in zip(input_nodes, fake_inputs):
ds_param = in_node.name in param_names
if ds_param:
from torch._subclasses.fake_tensor import is_fake
from torch._dynamo.utils import to_fake_tensor
assert is_fake(in_v), f"Input {in_v} should be fake tensor"
patched_inputs.append(
to_fake_tensor(torch.empty([0], dtype=in_v.dtype, device=in_v.device), in_v.fake_mode))
else:
patched_inputs.append(in_v)
patched_inputs = tuple(patched_inputs)
else:
patched_inputs = fake_inputs
return original_compiler(gm, patched_inputs)
return wrapped_compiler
def wrap_partition_fn(partition_fn, real_inputs, param_indices):
def wrapped_partition_fn(*args, **kwargs):
fw_module, bw_module = partition_fn(*args, **kwargs)
# get parameter names
pm = DSGraphParamManager(fw_module.graph, real_inputs, param_indices)
def fix_placeholder_meta(graph):
for n in graph.nodes:
if n.op == "placeholder" and n.name in pm.param_names:
n.meta["val"] = torch.empty([0], dtype=n.meta["val"].dtype, device=n.meta["val"].device)
fix_placeholder_meta(fw_module.graph)
fix_placeholder_meta(bw_module.graph)
return fw_module, bw_module
return wrapped_partition_fn
def patch_create_aot_dispatcher_function(graph_id: int, z3_partition: bool, make_fw_graph, make_bw_graph, real_inputs,
param_indices, param_manager):
from torch._dynamo.backends.common import AotAutograd
import functools
def patch_aotautograd():
# Unpatch if it was already patched
if hasattr(AotAutograd, "__original_init"):
AotAutograd.__init__ = AotAutograd.__original_init
original_init = AotAutograd.__init__
@functools.wraps(original_init)
def patched_init(self, **kwargs):
kwargs["fw_compiler"] = patch_compiler(kwargs["fw_compiler"],
make_fw_graph,
z3_partition,
graph_id,
param_manager,
bwd=False)
kwargs["bw_compiler"] = patch_compiler(kwargs["bw_compiler"],
make_bw_graph,
z3_partition,
graph_id,
param_manager,
bwd=True)
kwargs["inference_compiler"] = kwargs["fw_compiler"]
if z3_partition:
kwargs["partition_fn"] = wrap_partition_fn(kwargs["partition_fn"], real_inputs, param_indices)
original_init(self, **kwargs)
AotAutograd.__original_init = original_init
AotAutograd.__init__ = patched_init
patch_aotautograd()
def register_custom_ops():
def fallback_handler_no_reuse(kernel,
never_reuse_input,
never_reuse_output,
force_free_input,
add_to_fallback_set=True):
if add_to_fallback_set:
fallbacks.add(kernel)
def handler(*args, **kwargs):
def wrap_tensors(x):
out = TensorBox.create(x) if isinstance(x, torch._inductor.ir.IRNode) else x
if out is not None and never_reuse_output:
V.graph.never_reuse_buffers.add(out.get_name())
return out
class CustomDCKernel(FallbackKernel):
def __init__(self, op, *args, **kwargs):
super().__init__(op, *args, **kwargs)
def add_to_never_reuse(x):
if isinstance(x, IRNode):
assert hasattr(x, "get_name"), f"x doesn't have get_name {x.__class__}"
V.graph.never_reuse_buffers.add(x.get_name())
if never_reuse_input:
pytree.tree_map(add_to_never_reuse, args)
def get_var_name_for_arg(self, arg: str):
if arg.isidentifier():
return arg
import re
match = re.match(r"reinterpret_tensor\((\w+),", arg)
if match:
return match.group(1)
return None
def codegen(self, wrapper):
if not force_free_input:
return super().codegen(wrapper)
kernel = self.op_overload
self.codegen_comment(wrapper)
args = [*self.codegen_args(), *self.codegen_kwargs()]
V.graph.wrapper_code.generate_fallback_kernel(self, args)
if isinstance(self.layout, Layout):
self.codegen_size_asserts(wrapper)
var_name = self.get_var_name_for_arg(args[0])
if var_name:
wrapper.writeline(f"{var_name} = None")
self.codegen_unbacked_symbol_defs(wrapper)
kernel_cls = CustomDCKernel if force_free_input else FallbackKernel
return pytree.tree_map(wrap_tensors, kernel_cls.create(kernel, *args, **kwargs))
return handler
def register_fallback_no_reuse(op_overload,
never_reuse_input=False,
never_reuse_output=False,
force_free_input=False):
add_needs_realized_inputs(op_overload)
return register_lowering(op_overload, type_promotion_kind=None)(fallback_handler_no_reuse(
op_overload,
never_reuse_input=never_reuse_input,
never_reuse_output=never_reuse_output,
force_free_input=force_free_input))
# Inductor tries to reuse output buffers when possible. We need to disable this behavior for some custom ops.
# -> The memory region still seems to be reused in some cases, so we clone the inputs for some ops.
register_fallback_no_reuse(torch.ops.dc.allgather_param.default, never_reuse_input=False, never_reuse_output=True)
register_fallback_no_reuse(torch.ops.dc.wait_allgather.default, never_reuse_input=True, never_reuse_output=True)
register_fallback_no_reuse(torch.ops.dc.release_param.default, never_reuse_input=True, never_reuse_output=False)
register_fallback_no_reuse(torch.ops.dc.reduce_grad.default,
never_reuse_input=True,
never_reuse_output=True,
force_free_input=True)
register_fallback_no_reuse(torch.ops.dc.free_tensors.default, never_reuse_input=True, never_reuse_output=True)
if not hasattr(Scheduler, "is_dc_patched") or not Scheduler.is_dc_patched:
Scheduler.is_dc_patched = True
Scheduler.dead_node_elimination = lambda _: None


@ -0,0 +1,82 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
import copy
import torch
from deepspeed.accelerator import get_accelerator
from .passes import zero1_compile, zero3_compile
from .backend import make_backend, launch_compile_passes, init_schedule
from .util import get_deepcompile_handle, add_pre_backward_hook, is_backend_inductor
WARMUP = 5
def init_z1(engine, backend, compile_config, compile_kwargs, schedule=None):
optimizer = engine.optimizer
optimizer.contiguous_gradients = False # Avoid creating unnecessary buffer
for hook in optimizer._grad_acc_hooks:
hook.remove()
optimizer._grad_acc_hooks.clear()
dc = get_deepcompile_handle()
dc.init(engine.data_parallel_group,
engine.zero_reduce_bucket_size(), compile_config.double_buffer, compile_config.symmetric_memory,
is_backend_inductor(backend), compile_config.sync_before_reduce, compile_config.sync_after_reduce, False,
False)
grad_buffer = {}
for i, group in enumerate(optimizer.bit16_groups):
grad_buffer[i] = optimizer.get_flat_partition(optimizer.params_in_partition[i],
optimizer.first_offset[i],
optimizer.partition_size[i],
dtype=optimizer.gradient_accumulation_dtype,
device=get_accelerator().current_device_name(),
return_tensor_list=True)
grad_buffer[i] = [p.clone().detach() for p in grad_buffer[i]] # Maybe not necessary
index_in_partition = 0
first_in_partition = True
for p in group:
param_id = optimizer.get_param_id(p)
p.param_id = param_id
in_partition = optimizer.is_param_in_current_partition[param_id]
if in_partition:
buf = grad_buffer[i][index_in_partition]
offset = optimizer.first_offset[i] if first_in_partition else 0
# print(f"[r{dist.get_rank()}] Registering group {i} param {param_id} in_partition={in_partition} p={p.shape} buf={buf.shape} partition_offset={offset}")
dc.register_z1_param(p.param_id, p.shape, p, buf, int(offset))
index_in_partition += 1
first_in_partition = False
else:
# print(f"[r{dist.get_rank()}] Registering group {i} param {param_id} in_partition={in_partition} p={p.shape} buf=None")
dc.register_z1_param(p.param_id, p.shape, p, torch.empty([0], dtype=p.dtype, device=p.device), 0)
def set_grad_buffer():
optimizer.averaged_gradients = copy.copy(grad_buffer)
add_pre_backward_hook(set_grad_buffer)
if schedule is None:
schedule = []
schedule.append((0, [zero1_compile.add_z1_reduce]))
else:
for opt in schedule:
# avoid typical misconfiguration
if zero3_compile.add_z3_gather_release in opt[1]:
raise ValueError("A pass for ZeRO3 is not specified though ZeRO1 is enabled")
init_schedule(schedule)
engine.launch_compile_passes = launch_compile_passes
return make_backend(backend,
compile_kwargs=compile_kwargs,
free_activation=False,
debug_log=compile_config.debug_log)


@ -0,0 +1,94 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
import torch
from deepspeed import comm as dist
from deepspeed.accelerator import get_accelerator
from deepspeed.runtime.zero.partition_parameters import InsertPostInitMethodToModuleSubClasses
from .passes import zero3_compile, prefetch, selective_gather, offload_parameters
from .backend import make_backend, launch_compile_passes, init_schedule
from .patch_fake_tensor import patch_fake_tensor
from .util import get_deepcompile_handle, add_pre_backward_hook, is_backend_inductor
WARMUP = 5
def init_z3(engine, backend, compile_config, compile_kwargs, schedule=None):
optimizer = engine.optimizer
if optimizer is not None and hasattr(optimizer, '_DeepSpeedZeroOptimizer_Stage3__ipg_bucket_flat_buffer'):
optimizer._DeepSpeedZeroOptimizer_Stage3__ipg_bucket_flat_buffer = None
get_accelerator().empty_cache()
dc = get_deepcompile_handle()
dc.init(engine.data_parallel_group,
engine.zero_reduce_bucket_size(), compile_config.double_buffer, compile_config.symmetric_memory,
is_backend_inductor(backend), compile_config.sync_before_reduce, compile_config.sync_after_reduce,
compile_config.sync_before_allgather, compile_config.sync_after_allgather)
# Unset hooks
for m in engine.module.modules():
m._parameters = m._original_parameters
optimizer.parameter_offload._remove_module_hooks()
for hook in optimizer._grad_acc_hooks:
hook.remove()
optimizer._grad_acc_hooks.clear()
# Unpatch linear
if hasattr(InsertPostInitMethodToModuleSubClasses, "linear_bk"):
torch.nn.functional.linear = InsertPostInitMethodToModuleSubClasses.linear_bk
if compile_config.symmetric_memory:
group_name = engine.data_parallel_group.group_name
dist.enable_symm_mem_for_group(group_name)
for p in engine.module.parameters():
grad_buffer = optimizer._DeepSpeedZeroOptimizer_Stage3__param_id_to_grad_partition[p.ds_id]
# Disable persistent param
p.ds_persist = False
dc.register_z3_param(p.ds_id, p.ds_shape, p.ds_tensor, grad_buffer, p.ds_persist)
def set_grad_buffer():
for i, sub_group in enumerate(optimizer.fp16_groups):
optimizer.averaged_gradients[i] = [
optimizer._DeepSpeedZeroOptimizer_Stage3__param_id_to_grad_partition[param.ds_id]
if param.requires_grad else torch.zeros_like(param.ds_tensor) for param in sub_group
]
add_pre_backward_hook(set_grad_buffer)
if schedule is None:
schedule = []
if (compile_config.offload_parameters):
schedule.append((0, [zero3_compile.add_z3_gather_release, offload_parameters.offload_parameter_fwd]))
else:
schedule.append((0, [zero3_compile.add_z3_gather_release]))
schedule.append(
(WARMUP,
[zero3_compile.add_z3_gather_release, prefetch.schedule_prefetch, selective_gather.selective_gather]))
init_schedule(schedule)
# offloading opt states needs additional setup
from .passes.offload_adam_states import move_opt_states, move_opt_states_sync, init_offload_opt_states
for _, passes in schedule:
if move_opt_states in passes or move_opt_states_sync in passes:
init_offload_opt_states(optimizer, dc)
engine.launch_compile_passes = launch_compile_passes
patch_fake_tensor()
free_activation = compile_config.free_activation and not is_backend_inductor(backend)
torch._inductor.config.size_asserts = False
return make_backend(backend,
compile_kwargs=compile_kwargs,
free_activation=free_activation,
debug_log=compile_config.debug_log)


@ -0,0 +1,431 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
from collections import defaultdict
from typing import List, Dict
from copy import copy
from dataclasses import dataclass
import torch
from torch.fx import Graph, Node
from torch.fx.node import map_arg
try:
from torch.utils._pytree import tree_iter
except ImportError:
pass
from .util import get_last_uses, is_release_node
from .fx import get_output_node
def make_graph_from_schedule(scheduled: List[Node]):
new_graph = Graph()
env = {}
for node in scheduled:
new_node = new_graph.node_copy(node, lambda n: env[n.name])
env[node.name] = new_node
return new_graph
def get_original_args_num(node: Node):
if node.name.startswith("allgather_ds_param") \
or node.name.startswith("release_ds_param") \
or node.name.startswith("wait_allgather_ds_param") \
or node.name.startswith("reduce_ds_param"):
return 1
return len(node.args)
def flat_nodes_in_args(args: List[Node]):
return [a for a in tree_iter(args) if isinstance(a, Node)]
def filter_args(node: Node):
args = node.args[:get_original_args_num(node)]
return flat_nodes_in_args(args)
def init_schedule(graph: Graph):
mem_table = create_mem_table(graph)
remaining_users = defaultdict(set)
user_to_producer = {}
scheduled = []
unscheduled = []
edges = defaultdict(list)
for node in graph.nodes:
filtered_args = filter_args(node)
# print(f"Node: {node} args: {node.args}")
if len(filtered_args) == 0:
scheduled.append(node)
remaining_users[node] = set(node.users.keys())
for user in node.users.keys():
user_to_producer[user] = node
else:
unscheduled.append(node)
for a in filtered_args:
for elem_a in tree_iter(a):
if isinstance(elem_a, Node):
if node not in edges[elem_a]:
edges[elem_a].append(node)
return scheduled, unscheduled, edges, mem_table, remaining_users, user_to_producer
def get_runnable_nodes(scheduled: List[Node], unscheduled: List[Node]):
scheduled = set(scheduled)
return [node for node in unscheduled if all(arg in scheduled for arg in filter_args(node))]
def choose_next_node(scheduled: List[Node], unscheduled: List[Node], mem_table: Dict[str, int]):
runnable_nodes = get_runnable_nodes(scheduled, unscheduled)
# sort by memory usage
runnable_nodes = sorted(runnable_nodes, key=lambda n: mem_table[n.name])
return runnable_nodes[0]
def create_mem_table(graph: Graph) -> Dict[str, int]:
mem_table = {}
for node in graph.nodes:
if node.name.startswith("allgather_ds_param"):
mem_table[node.name] = node.meta["tensor_size"]
elif node.name.startswith("release_ds_param") or node.name.startswith("reduce_ds_param"):
mem_table[node.name] = -node.meta["tensor_size"]
else:
mem_table[node.name] = 0
return mem_table
def list_schedule(graph: Graph) -> Graph:
scheduled, unscheduled, mem_table = init_schedule(graph)
while len(unscheduled) > 0:
next_node = choose_next_node(scheduled, unscheduled, mem_table)
scheduled.append(next_node)
unscheduled.remove(next_node)
return make_graph_from_schedule(scheduled)
###############################
def get_new_runnable_nodes_with(scheduled: List[Node], edges: Dict[Node, List[Node]], new_scheduled: Node):
scheduled = set(scheduled)
new_runnables = []
for node in edges[new_scheduled]:
if all(arg in scheduled for arg in filter_args(node) if arg != new_scheduled):
new_runnables.append(node)
return new_runnables
def _do_schedule_without_allgather(scheduled: List[Node], unscheduled: List[Node], edges: Dict[Node, List[Node]],
non_ag_runnable: List[Node]):
while len(non_ag_runnable) > 0:
next_node = non_ag_runnable.pop()
new_runnables = get_new_runnable_nodes_with(scheduled, edges, next_node)
non_ag_runnable += [n for n in new_runnables if not n.name.startswith("allgather_ds_param")]
scheduled.append(next_node)
unscheduled.remove(next_node)
return scheduled, unscheduled
def schedule_without_allgather(scheduled: List[Node], unscheduled: List[Node], edges: Dict[Node, List[Node]]):
runnable = get_runnable_nodes(scheduled, unscheduled)
non_ag_runnable = [n for n in runnable if not n.name.startswith("allgather_ds_param")]
tmp_scheduled = copy(scheduled)
tmp_unscheduled = copy(unscheduled)
return _do_schedule_without_allgather(tmp_scheduled, tmp_unscheduled, edges, non_ag_runnable)
def try_schedule_with_new_allgather(scheduled: List[Node], unscheduled: List[Node], edges: Dict[Node, List[Node]],
new_scheduled: Node):
new_runnables = get_new_runnable_nodes_with(scheduled, edges, new_scheduled)
non_ag_runnable = [n for n in new_runnables if not n.name.startswith("allgather_ds_param")]
tmp_scheduled = copy(scheduled)
tmp_unscheduled = copy(unscheduled)
tmp_scheduled.append(new_scheduled)
tmp_unscheduled.remove(new_scheduled)
return _do_schedule_without_allgather(tmp_scheduled, tmp_unscheduled, edges, non_ag_runnable)
def simple_prefetch(graph: Graph, available_mem: int, output_size: int, debug_log: bool) -> Graph:
scheduled, unscheduled, edges, mem_table, remaining_users, user_to_producer = init_schedule(graph)
tmp_scheduled, tmp_unscheduled = schedule_without_allgather(scheduled, unscheduled, edges)
while len(tmp_unscheduled) > 0:
runnable = get_runnable_nodes(tmp_scheduled, tmp_unscheduled)
ag_with_unblock_time = []
for ag_node in runnable:
ag_scheduled, ag_unscheduled = try_schedule_with_new_allgather(tmp_scheduled, tmp_unscheduled, edges,
ag_node)
unblock_time = sum(n.meta["device_time"] for n in ag_scheduled[len(tmp_scheduled) + 1:])
ag_with_unblock_time.append((ag_node, unblock_time, ag_scheduled, ag_unscheduled))
ag_with_unblock_time = sorted(ag_with_unblock_time, key=lambda x: x[1], reverse=True)
best_ag_node = ag_with_unblock_time[0][0]
best_ag_scheduled = ag_with_unblock_time[0][2]
no_ag_runnables = tmp_scheduled[len(scheduled):]
after_ag_runnables = best_ag_scheduled[len(tmp_scheduled) + 1:]
scheduled.append(best_ag_node)
unscheduled.remove(best_ag_node)
for n in no_ag_runnables:
scheduled.append(n)
unscheduled.remove(n)
tmp_scheduled = copy(scheduled)
tmp_unscheduled = copy(unscheduled)
for n in after_ag_runnables:
tmp_scheduled.append(n)
tmp_unscheduled.remove(n)
return make_graph_from_schedule(tmp_scheduled)
###############################
def init_schedule_with_placeholders(graph: Graph):
mem_table = create_mem_table(graph)
remaining_users = defaultdict(set)
user_to_producer = {}
scheduled = []
unscheduled = []
edges = defaultdict(list)
for node in graph.nodes:
if node.op == 'placeholder':
scheduled.append(node)
remaining_users[node] = set(node.users.keys())
for user in node.users.keys():
user_to_producer[user] = node
else:
unscheduled.append(node)
return scheduled, unscheduled, edges, mem_table, remaining_users, user_to_producer
def get_node_requirements(target_node: Node, scheduled: List[Node]):
scheduled = set(scheduled)
visited = set()
ordered_nodes = []
def dfs(node: Node):
if node in scheduled:
return
if node in visited:
return
visited.add(node)
args = []
def register_arg(n: Node):
args.append(n)
map_arg(node.args, register_arg)
for arg in args:
dfs(arg)
ordered_nodes.append(node)
dfs(target_node)
return ordered_nodes
@dataclass
class AllgatherTask:
node: Node
allgather_cost: float
free_cost: float
allgathered_mem: int
allgather_acc_mem: int
free_acc_mem: int
last_use: Node
n_scheduled_ags: int
schedule_until_ag: List[Node]
schedule_until_free: List[Node]
def fast_free_schedule(graph: Graph, available_mem: int, output_size: int, debug_log: bool) -> Graph:
node_to_last_use, user_to_last_uses = get_last_uses(graph)
# check tensor size
for node in graph.nodes:
if "tensor_size" not in node.meta:
# Our profiler may not visit all nodes because of the control flow.
node.meta["tensor_size"] = 0
scheduled, unscheduled, edges, mem_table, remaining_users, user_to_producer = init_schedule_with_placeholders(
graph)
unscheduled_ags = [n for n in unscheduled if n.target == torch.ops.dc.allgather_param.default]
release_nodes = defaultdict(list)
for n in unscheduled:
if is_release_node(n):
release_nodes[n.args[2]].append(n)
ag_nodes_in_path = {}
for ag_node in unscheduled_ags:
last_use = node_to_last_use[ag_node]
required_nodes = get_node_requirements(last_use, scheduled)
ag_nodes_in_path[ag_node] = set(n for n in required_nodes if n.target == torch.ops.dc.allgather_param.default)
reduce_nodes = [n for n in unscheduled if n.target == torch.ops.dc.reduce_grad.default]
ag_nodes_in_path_to_reduce_nodes = {}
for reduce_node in reduce_nodes:
ag_nodes_in_path_to_reduce_nodes[reduce_node] = set(n for n in get_node_requirements(reduce_node, scheduled)
if n.target == torch.ops.dc.allgather_param.default)
output_nodes = [
n for n in get_output_node(graph).args[0]
if isinstance(n, Node) and n.target != torch.ops.dc.reduce_grad.default
]
ag_nodes_in_path_to_output_nodes = {}
for output_node in output_nodes:
ag_nodes_in_path_to_output_nodes[output_node] = set(n for n in get_node_requirements(output_node, scheduled)
if n.target == torch.ops.dc.allgather_param.default)
while len(unscheduled_ags) > 0:
ag_nodes_count = {ag_node: len(nodes) for ag_node, nodes in ag_nodes_in_path.items()}
count_list = sorted(set(ag_nodes_count.values()))
runnable_ags = []
for ag_count in count_list:
target_unscheduled_ags = [ag for ag in unscheduled_ags if ag_nodes_count[ag] == ag_count]
for node in target_unscheduled_ags:
ds_id = node.args[2]
schedule_until_ag = get_node_requirements(node, scheduled)
if schedule_until_ag is None:
continue
last_use = node_to_last_use[node]
diff_required_nodes = get_node_requirements(last_use, scheduled + schedule_until_ag)
allgather_cost = sum(n.meta["device_time"] for n in schedule_until_ag)
free_cost = sum(n.meta["device_time"] for n in diff_required_nodes)
allgathered_mem = node.meta["tensor_size"]
allgather_acc_mem = sum(n.meta["tensor_size"] for n in schedule_until_ag
if n.target == torch.ops.dc.allgather_param.default)
free_acc_mem = sum(n.meta["tensor_size"] for n in diff_required_nodes
if n.target == torch.ops.dc.allgather_param.default)
schedule_until_free = schedule_until_ag + diff_required_nodes
for release_node in release_nodes[ds_id]:
if release_node not in schedule_until_free:
schedule_until_free.append(release_node)
n_scheduled_ags = len(
[n for n in schedule_until_free if n.target == torch.ops.dc.allgather_param.default])
task = AllgatherTask(node, allgather_cost, free_cost, allgathered_mem, allgather_acc_mem, free_acc_mem,
last_use, n_scheduled_ags, schedule_until_ag, schedule_until_free)
# print(f" ag_count {ag_count} allgather runnable {i}: {node} last_use: {node_to_last_use[node]} t: {t2-t1:.2f}")
runnable_ags.append(task)
if len(runnable_ags) > 0:
break
assert len(runnable_ags) > 0, "No runnable allgather nodes"
# Criteria for the choice:
# We prefer an allgather that does not require additional allgathers before the gathered param is released.
# When such a node exists, its free_acc_mem is zero. In that case, we choose the one with the smallest cost until free (free_cost) to minimize how long the gathered param occupies memory.
# If there is no such node, we choose the one with the smallest free_acc_mem, i.e. the one that pulls in the least additional allgathered memory before the param can be released.
ags_with_no_additional_ag = [ag for ag in runnable_ags if ag.free_acc_mem == 0]
if len(ags_with_no_additional_ag) > 0:
sorted_ags = sorted(runnable_ags, key=lambda x: x.free_cost)
next_ag = sorted_ags[0]
nodes_to_schedule = next_ag.schedule_until_free
else:
# sorted_ags = sorted(runnable_ags, key=lambda x: x.allgathered_mem)
sorted_ags = sorted(runnable_ags, key=lambda x: x.free_acc_mem)
next_ag = sorted_ags[0]
nodes_to_schedule = next_ag.schedule_until_ag
# print(f" next_ag {next_ag}")
for n in nodes_to_schedule:
scheduled.append(n)
unscheduled.remove(n)
unscheduled_ags.remove(next_ag.node)
ag_nodes_in_path.pop(next_ag.node)
for ag_node, nodes in ag_nodes_in_path.items():
if next_ag.node in nodes:
nodes.remove(next_ag.node)
# Schedule reduce nodes when possible to free memory earlier
reduces_to_schedule = []
for reduce_node in reduce_nodes:
if next_ag.node in ag_nodes_in_path_to_reduce_nodes[reduce_node]:
ag_nodes_in_path_to_reduce_nodes[reduce_node].remove(next_ag.node)
if len(ag_nodes_in_path_to_reduce_nodes[reduce_node]) == 0:
reduces_to_schedule.append(reduce_node)
for n in reduces_to_schedule:
need_to_schedule = get_node_requirements(n, scheduled)
for nn in need_to_schedule:
scheduled.append(nn)
unscheduled.remove(nn)
# Do the same for output nodes
outputs_to_schedule = []
for output_node in output_nodes:
if next_ag.node in ag_nodes_in_path_to_output_nodes[output_node]:
ag_nodes_in_path_to_output_nodes[output_node].remove(next_ag.node)
if len(ag_nodes_in_path_to_output_nodes[output_node]) == 0:
outputs_to_schedule.append(output_node)
for n in outputs_to_schedule:
need_to_schedule = get_node_requirements(n, scheduled)
for nn in need_to_schedule:
scheduled.append(nn)
unscheduled.remove(nn)
# print(f"After ag scheduled: scheduled: {scheduled}")
scheduled_set = set(scheduled)
for node in graph.nodes:
if node in scheduled_set:
continue
scheduled.append(node)
unscheduled.remove(node)
assert len(unscheduled) == 0, f"There are unscheduled nodes: {unscheduled}"
ret_graph = make_graph_from_schedule(scheduled)
ret_graph.lint()
return ret_graph


@ -0,0 +1,158 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
# This file was copied from PyTorch and modified for DeepSpeed.
from typing import Tuple, List
import operator
import torch
from torch.fx import GraphModule, Graph, Node
try:
from torch._functorch.partitioners import is_sym_node, _is_primal, _is_fwd_seed_offset, _extract_fwd_bwd_outputs, _extract_graph_with_inputs_outputs, _extract_fwd_bwd_modules, has_recomputable_ops, min_cut_rematerialization_partition, choose_saved_values_set
except ImportError:
pass
from .util import get_no_copy_ops
_recompute_ops = {torch.ops.aten.t.default}
def _find_recompute_nodes(graph: Graph, ds_param_node: Node) -> List[Node]:
"""
Given a graph and a node that represents a parameter that was allgathered,
find all nodes that use the parameter and require recomputation.
"""
no_copy_ops = get_no_copy_ops()
recompute_nodes = set()
for node in graph.nodes:
if node.target in no_copy_ops:
if ds_param_node in node.args:
recompute_nodes.add(node)
if any(a in recompute_nodes for a in node.args):
recompute_nodes.add(node)
return recompute_nodes
def _get_values_from_ds_params(joint_graph, param_indices):
primal_inputs = list(filter(_is_primal, joint_graph.nodes))
ds_param_inputs = [primal_inputs[arg_idx] for arg_idx, _, _ in param_indices]
no_copy_ops = get_no_copy_ops()
ds_param_inputs = set(ds_param_inputs)
ds_param_users = {}
for node in joint_graph.nodes:
if node.target in no_copy_ops and any((a in ds_param_inputs or a in ds_param_users) for a in node.args):
for a in node.args:
if a in ds_param_inputs:
ds_param_users[node] = a
elif a in ds_param_users:
ds_param_users[node] = ds_param_users[a]
return ds_param_users
def get_wrapped_choose_saved_values_set(param_indices: List[Tuple[int, int, torch.Size]]):
def ds_choose_saved_values_set(joint_graph: torch.fx.Graph, node_info, memory_budget=1) -> List[Node]:
saved_values = choose_saved_values_set(joint_graph, node_info, memory_budget)
ds_param_users = _get_values_from_ds_params(joint_graph, param_indices)
new_saved_values = []
for v in saved_values:
if v in ds_param_users:
ds_val = ds_param_users[v]
if ds_val not in new_saved_values:
new_saved_values.append(ds_val)
else:
new_saved_values.append(v)
return new_saved_values
return ds_choose_saved_values_set
def get_wrapped_partitioner(param_indices: List[Tuple[int, int, torch.Size]]):
def partition_recompute_ds_params(joint_module: GraphModule, _joint_inputs, *,
num_fwd_outputs) -> Tuple[GraphModule, GraphModule]:
"""
This is basically the same as the default_partition function, but
it doesn't save the gathered params and values computed from them.
"""
if has_recomputable_ops(joint_module):
return min_cut_rematerialization_partition(joint_module, _joint_inputs, num_fwd_outputs=num_fwd_outputs)
primal_inputs = list(filter(_is_primal, joint_module.graph.nodes))
fwd_seed_offset_inputs = list(filter(_is_fwd_seed_offset, joint_module.graph.nodes))
inputs = primal_inputs + fwd_seed_offset_inputs
fwd_outputs, bwd_outputs = _extract_fwd_bwd_outputs(joint_module, num_fwd_outputs=num_fwd_outputs)
forward_only_graph = _extract_graph_with_inputs_outputs(joint_module.graph, inputs, fwd_outputs, "forward")
forward_node_names = {node.name for node in forward_only_graph.nodes if node.op != "output"}
saved_values = []
saved_sym_nodes = []
fwd_inputs = list(filter(_is_primal, forward_only_graph.nodes))
ds_param_inputs = [fwd_inputs[arg_idx] for arg_idx, _, _ in param_indices]
ds_param_input_names = {node.name for node in ds_param_inputs}
ds_param_recompute_nodes = set()
for node in joint_module.graph.nodes:
if node.name not in forward_node_names:
continue
if is_sym_node(node):
# Symints must be kept separate from tensors so that PythonFunction only calls
# save_for_backward on tensors and stashes symints in autograd .ctx
saved_sym_nodes.append(node)
elif "tensor_meta" not in node.meta and node.op == "call_function":
# Since we can't save tuple of tensor values, we need to flatten out what we're saving
users = node.users
assert all(user.target == operator.getitem for user in users)
saved_values.extend(users)
else:
backward_usages = [n for n in node.users if n.name not in forward_node_names]
if "tensor_meta" in node.meta and all(is_sym_node(n) for n in backward_usages):
# If we have a tensor in the forward, where only its sizes/strides are needed in the backward,
# and not the actual tensor data,
# then it will be a lot cheaper to save only the sizes/strides, and not the actual tensor.
#
# Note that saving the tensor could also cause compilation problems:
# If the user mutated an input in the forward and uses its sizes/strides in the backward,
# then we would be obligated to clone the input before saving it to appease autograd.
# (This is how we originally found this bug).
saved_sym_nodes.extend(backward_usages)
if node.name in ds_param_input_names:
saved_values.append(node)
recompute_nodes = _find_recompute_nodes(joint_module.graph, node)
recompute_nodes = [n for n in recompute_nodes if n.name in forward_node_names]
for recompute_node in recompute_nodes:
ds_param_recompute_nodes.add(recompute_node)
if len(recompute_nodes) > 0:
saved_values.append(node)
else:
if node not in ds_param_recompute_nodes:
saved_values.append(node)
saved_values = list(dict.fromkeys(saved_values).keys())
saved_sym_nodes = list(dict.fromkeys(saved_sym_nodes).keys())
f_gm, b_gm = _extract_fwd_bwd_modules(
joint_module,
saved_values,
saved_sym_nodes=saved_sym_nodes,
num_fwd_outputs=num_fwd_outputs,
)
return f_gm, b_gm
return partition_recompute_ds_params
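For orientation, a rough sketch (illustrative, not from this PR) of how a partition function such as the one returned by get_wrapped_partitioner can be wired into AOTAutograd. It assumes the private torch._dynamo.backends.common.aot_autograd helper, which may change across torch versions, and uses an empty param_indices list as a stand-in for the real (index, ds_id, ds_shape) entries produced by profiling.

import torch
from torch._dynamo.backends.common import aot_autograd

def identity_compiler(gm, example_inputs):  # run captured graphs unchanged
    return gm

param_indices = []  # placeholder for the real profiling output
backend = aot_autograd(
    fw_compiler=identity_compiler,
    bw_compiler=identity_compiler,
    partition_fn=get_wrapped_partitioner(param_indices),
)

model = torch.nn.Linear(8, 8)
compiled = torch.compile(model, backend=backend)
compiled(torch.randn(2, 8)).sum().backward()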

View File

@ -0,0 +1,48 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
from ..profilers.graph_profile import MemoryProfilingInterpreter
import deepspeed.comm as dist
def run_opt_passes(nz3,
graph_index,
graph_id,
gm,
create_inputs_fn,
opt_passes,
graph_order,
profiling_results,
param_manager,
bwd,
debug_log=False):
profile = profiling_results[graph_id]
rank = dist.get_rank()
for i, opt_pass in enumerate(opt_passes):
opt_pass_fn, mem_budget = opt_pass
graph = opt_pass_fn(gm.graph, graph_id, graph_order, profiling_results, mem_budget, param_manager, bwd)
graph.lint()
gm.graph = graph
gm.recompile()
if debug_log:
print(f"Prefetching enabled for {'bwd' if bwd else 'fwd'} graph_id={graph_id} {graph}")
mem_prof = MemoryProfilingInterpreter(nz3, gm)
mem_prof.run(*create_inputs_fn())
if debug_log and rank == 0:
mem_prof.dump(f"mem_prof_r{rank}_{'bwd' if bwd else 'fwd'}_{graph_index}_{graph_id}_pass_{i}.csv")
mem = [(name, current_alloc, delta, peak) for name, current_alloc, delta, peak in mem_prof.mem_record]
if bwd:
profile.bwd_mem = mem
else:
profile.fwd_mem = mem
return gm
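For reference, a minimal sketch (illustrative, not from this PR) of the pass interface that run_opt_passes expects: each entry in opt_passes is a (pass_fn, mem_budget) pair, and the pass receives the FX graph plus profiling context and must return a (possibly rewritten) fx.Graph. noop_pass and the 0.9 budget are hypothetical.

from torch.fx import Graph

def noop_pass(graph: Graph, graph_id, graph_order, profiling_results, mem_budget, param_manager, bwd) -> Graph:
    # A real pass would insert or reorder nodes here; this one returns the graph unchanged.
    return graph

opt_passes = [(noop_pass, 0.9)]  # applied in order by run_opt_passes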

View File

@ -0,0 +1,116 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
from typing import List, Dict, Set, Tuple
import random
from collections import defaultdict
import torch
from torch.fx import Graph, Node
from ..fx import get_output_node, move_primals_to_head
from ..graph_param import DSGraphParamManager
value_to_id: Dict[int, Dict[str, int]] = defaultdict(dict)
used_ids: Set[int] = set()
def get_random_id() -> int:
def _gen():
# generate random int
return random.randint(10000, 2**31)
global used_ids
v = _gen()
while v in used_ids:
v = _gen()
used_ids.add(v)
return v
def _should_offload(node: Node) -> bool:
if not hasattr(node, "meta"):
return False
if not "tensor_meta" in node.meta:
return False
return True
def offload_activation_fwd(graph: Graph, graph_id: int, nodes_to_offload_with_names: List[Tuple[str, Node]],
graph_order: List[int], mem_budget: float, param_manager: DSGraphParamManager) -> Graph:
param_names = set(param_manager.param_names)
import copy
cl_graph = copy.deepcopy(graph)
cl_graph.erase_node(get_output_node(cl_graph))
global value_to_id
for name, node in nodes_to_offload_with_names:
if node.name in param_names:
continue
if not _should_offload(node):
continue
val_id = get_random_id()
with graph.inserting_after(node):
offload_node = graph.create_node('call_function',
torch.ops.dc.offload_tensor.default, (node, graph_id, val_id), {},
name=f"offload_{node.name}_{val_id}")
with graph.inserting_after(offload_node):
wait_node = graph.create_node('call_function',
torch.ops.dc.wait_offload.default, (offload_node, graph_id, val_id), {},
name=f"wait_copy_{node.name}_{val_id}")
output_node = get_output_node(graph)
output_node.replace_input_with(node, wait_node)
value_to_id[graph_id][name] = val_id
graph = move_primals_to_head(graph)
graph.lint()
return graph
def reload_activation_bwd(graph: Graph, graph_id: int, graph_order: List[int], mem_budget: float,
param_manager: DSGraphParamManager) -> Graph:
graph_value_to_id = value_to_id[graph_id]
name_to_node = {n.name: n for n in graph.nodes}
act_nodes = [name_to_node[n] for n in graph_value_to_id.keys()]
node_to_first_user = {}
for act in act_nodes:
for node in graph.nodes:
if act in node.args:
node_to_first_user[act] = node
break
for node in act_nodes:
val_id = graph_value_to_id[node.name]
with graph.inserting_before(node_to_first_user[node]):
reload_node = graph.create_node('call_function',
torch.ops.dc.reload_tensor.default, (node, graph_id, val_id), {},
name=f"reload_{node.name}_{val_id}")
with graph.inserting_after(reload_node):
wait_node = graph.create_node('call_function',
torch.ops.dc.wait_reload.default, (reload_node, graph_id, val_id), {},
name=f"wait_copy_{node.name}_{val_id}")
# replace all uses of node with wait_node
users = {}
for u in node.users.keys():
if u != reload_node:
users[u] = (node, wait_node)
for u, (old_in, new_in) in users.items():
u.replace_input_with(old_in, new_in)
graph = move_primals_to_head(graph)
graph.lint()
return graph
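The offload/reload passes above rely on a common FX rewiring pattern: insert a new node right after the producer, then redirect downstream consumers to the inserted node. A self-contained sketch (illustrative, not from this PR) of that pattern, with a plain Python no-op standing in for the torch.ops.dc.* custom ops so it runs without DeepCompile:

import torch
from torch.fx import symbolic_trace

def _identity(t, tag):  # stand-in for offload_tensor / wait_offload
    return t

class Toy(torch.nn.Module):
    def forward(self, x):
        y = torch.relu(x)
        return y * 2

gm = symbolic_trace(Toy())
act = next(n for n in gm.graph.nodes if n.target is torch.relu)
with gm.graph.inserting_after(act):
    offload = gm.graph.create_node("call_function", _identity, (act, "offload"), {}, name="offload_y")
with gm.graph.inserting_after(offload):
    wait = gm.graph.create_node("call_function", _identity, (offload, "wait"), {}, name="wait_y")
# Redirect downstream uses of the activation to the wait node, as the passes above do.
act.replace_all_uses_with(wait, delete_user_cb=lambda user: user is not offload)
gm.graph.lint()
gm.recompile()
print(gm(torch.randn(2, 2)))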

View File

@ -0,0 +1,546 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
import copy
from typing import List
import torch
from torch.fx import Graph, GraphModule
from deepspeed.accelerator import get_accelerator
from deepspeed.runtime.zero.offload_states import _make_offload_state_key
try:
from torch._subclasses.fake_tensor import unset_fake_temporarily
except ImportError:
# Unsupported torch version
pass
from ..profilers import ProfilingResult
from ..graph_param import DSGraphParamManager
from ..fx import move_primals_to_head
import deepspeed.comm as dist
NAME = "offload_adam_states"
def print_r0(msg):
if dist.get_rank() == 0:
print(msg)
MARGIN = 0.2
copy_stream = None
offload_event = None
reload_event = None
offload_key_events = {}
reload_key_events = {}
max_memory = 0
def lazy_init():
global copy_stream
global offload_event
global reload_event
if copy_stream is None:
copy_stream = get_accelerator().Stream()
offload_event = get_accelerator().Event()
reload_event = get_accelerator().Event()
optimizer = None
device = None
nz3 = None
def move_key(state, key, key_event=None):
offload_buf_key = _make_offload_state_key(key)
if offload_buf_key not in state:
state[offload_buf_key] = get_accelerator().pin_memory(torch.empty_like(state[key], device="cpu"))
if key not in state:
return
with get_accelerator().stream(copy_stream):
state[offload_buf_key].copy_(state[key], non_blocking=True)
if key_event is None:
offload_event.record(stream=copy_stream)
else:
key_event.record(stream=copy_stream)
def move_back_key(state, key, key_event=None):
with get_accelerator().stream(copy_stream):
state[key] = torch.empty_like(state[_make_offload_state_key(key)], device=device)
state[key].copy_(state[_make_offload_state_key(key)], non_blocking=True)
if key_event is None:
reload_event.record(stream=copy_stream)
else:
key_event.record(stream=copy_stream)
def move_hp_param(src_tensor, dest_buf, key_event=None):
with get_accelerator().stream(copy_stream):
dest_buf.copy_(src_tensor, non_blocking=True)
src_tensor.data = dest_buf
if key_event is None:
reload_event.record(stream=copy_stream)
else:
key_event.record(stream=copy_stream)
def move_back_hp_param(src_tensor, dest_buf, key_event=None):
with get_accelerator().stream(copy_stream):
dest_buf.data = torch.empty_like(src_tensor, device=device)
dest_buf.copy_(src_tensor, non_blocking=True)
if key_event is None:
reload_event.record(stream=copy_stream)
else:
key_event.record(stream=copy_stream)
def offload_adam_states_sync():
with unset_fake_temporarily():
if not hasattr(optimizer, "hp_params_pin_buffers"):
optimizer.hp_params_pin_buffers = [
get_accelerator().pin_memory(torch.empty_like(t, device="cpu"))
for t in optimizer.fp32_partitioned_groups_flat
]
for i, (k, state) in enumerate(optimizer.state.items()):
if "exp_avg" in state:
move_key(state, "exp_avg")
if "exp_avg_sq" in state:
move_key(state, "exp_avg_sq")
for _, state in optimizer.state.items():
if "exp_avg" in state:
del state["exp_avg"]
if "exp_avg_sq" in state:
del state["exp_avg_sq"]
for src_tensor, dest_buf in zip(optimizer.fp32_partitioned_groups_flat, optimizer.hp_params_pin_buffers):
move_hp_param(src_tensor, dest_buf)
get_accelerator().synchronize()
def reload_adam_states_sync():
with unset_fake_temporarily():
# print_r0("Reloading Adam states")
for _, state in optimizer.state.items():
if _make_offload_state_key("exp_avg") in state:
move_back_key(state, "exp_avg")
if _make_offload_state_key("exp_avg_sq") in state:
move_back_key(state, "exp_avg_sq")
for src, dest in zip(optimizer.hp_params_pin_buffers, optimizer.fp32_partitioned_groups_flat):
move_back_hp_param(src, dest)
get_accelerator().synchronize()
def sync_offload_states(event=None):
if nz3.is_profiling():
offload_adam_states_sync()
else:
if event is None:
offload_event.wait(copy_stream)
else:
event.wait(copy_stream)
def sync_reload_states(event=None):
if nz3.is_profiling():
reload_adam_states_sync()
else:
if event is None:
reload_event.wait(copy_stream)
else:
event.wait(copy_stream)
def make_offload_task(task):
def run_offload_task():
# if not nz3.is_profiling():
# print_r0(f"run_offload_task {task[0]} {task[2]} {task[3]} {task[4]}")
if offload_key_events.get(task[1]) is None:
offload_key_events[task[1]] = get_accelerator().Event()
if task[2] == "hp_param":
move_hp_param(task[1][0], task[1][1], offload_key_events[task[1]])
else:
assert task[1] in optimizer.state, f"State {task[1]} not found in optimizer"
state = optimizer.state[task[1]]
# if offload_key_events.get(task[1]) is None:
# offload_key_events[task[1]] = get_accelerator().Event()
move_key(state, task[2], offload_key_events[task[1]])
return run_offload_task
def make_offload_sync(task):
def run_offload_sync():
# if not nz3.is_profiling():
event = offload_key_events[task[1]]
event.synchronize()
if task[2] != "hp_param":
state = optimizer.state[task[1]]
key = task[2]
if key in state:
del state[key]
# print_r0(f"run_offload_sync {task[0]} {task[2]} alloc_mem={get_accelerator().memory_allocated()}")
return run_offload_sync
def make_reload_task(task):
def run_reload_task():
if not nz3.is_profiling():
if reload_key_events.get(task[1]) is None:
reload_key_events[task[1]] = get_accelerator().Event()
if task[2] == "hp_param":
move_back_hp_param(task[1][1], task[1][0], reload_key_events[task[1]])
else:
state = optimizer.state[task[1]]
# print_r0(f"run_reload_task {task[0]} {task[2]} {task[3]} {task[4]}")
move_back_key(state, task[2], reload_key_events[task[1]])
return run_reload_task
def update_max_memory(name):
global max_memory
mem = get_accelerator().max_memory_allocated()
max_memory = max(max_memory, mem)
def empty_cache():
get_accelerator().empty_cache()
offload_tasks = []
offload_tasks_remaining = []
offload_tasks_scheduled = []
reload_tasks_remaining = []
total_reload_mem = 0
def offload_opt_states_inc(graph: Graph, graph_id: int, graph_order: List[int], profiling_results: ProfilingResult,
mem_budget: float, param_manager: DSGraphParamManager, bwd: bool) -> Graph:
to_remove = []
for node in graph.nodes:
if node.op == 'call_function' and \
node.target in [offload_adam_states_sync, sync_offload_states, reload_adam_states_sync, sync_reload_states, update_max_memory]:
to_remove.append(node)
for node in to_remove:
graph.erase_node(node)
accelerator = get_accelerator()
total_mem = accelerator.total_memory() * (1 - MARGIN)
print_r0(f"offload_opt_states_inc start graph {graph_id} bwd={bwd} max_memory={max_memory} total_mem={total_mem}")
mem = profiling_results[graph_id].bwd_mem if bwd else profiling_results[graph_id].fwd_mem
mem_dict = {name: peak for name, alloc_mem, delta, peak in mem}
current_peak_mem = 0
peak_mem = {}
ordered_node = reversed(graph.nodes) if bwd else graph.nodes
for node in ordered_node:
# print(f"Node: {node.name} mem: {mem_dict[node.name]}")
if mem_dict[node.name] > current_peak_mem:
current_peak_mem = mem_dict[node.name]
peak_mem[node.name] = current_peak_mem
# fwd_max_mem = max(m[3] for m in prof.fwd_mem)
# bwd_max_mem = max(m[3] for m in prof.bwd_mem) if len(prof.bwd_mem) > 0 else 0
# peak_mem = max(peak_mem, fwd_max_mem, bwd_max_mem)
global offload_tasks_remaining, reload_tasks_remaining, offload_tasks_scheduled
if not bwd:
is_first_graph = graph_id == graph_order[0][0]
# print_r0(
# f"offload_opt_states_inc start graph {graph_id} graph_order {graph_order} fwd is_first_graph {is_first_graph}"
# )
# At the beginning of the first graph, we schedule offload tasks to launch all offloading
if is_first_graph:
# print_r0(
# f"offload_opt_states_inc fwd before reload graph {graph_id} allocated_mem={get_accelerator().memory_allocated()}"
# )
with unset_fake_temporarily():
offload_adam_states_sync()
reload_adam_states_sync()
sync_reload_states()
reload_size = 0
for i, ((k, state), hp_param, hp_param_cpu) in enumerate(
zip(optimizer.state.items(), optimizer.fp32_partitioned_groups_flat,
optimizer.hp_params_pin_buffers)):
# print_r0(
# f"Checking key for offloading {i} {k.shape} has_key {_make_offload_state_key('exp_avg') in state}")
if _make_offload_state_key("exp_avg") in state:
key = _make_offload_state_key("exp_avg")
size = state[key].numel() * state[key].element_size()
# if total_mem < max_memory + reload_size + size:
offload_tasks.append(
(i, k, "exp_avg", state[key].numel() * state[key].element_size(), state[key].dtype))
# print_r0(
# f"Offloading task {i} exp_avg reload_size={reload_size} size={size} estimated_mem={max_memory + reload_size + size}"
# )
if _make_offload_state_key("exp_avg_sq") in state:
key = _make_offload_state_key("exp_avg_sq")
size = state[key].numel() * state[key].element_size()
# if total_mem < max_memory + reload_size + size:
offload_tasks.append(
(i, k, "exp_avg_sq", state[key].numel() * state[key].element_size(), state[key].dtype))
# print_r0(
# f"Offloading task {i} exp_avg_sq reload_size={reload_size} size={size} estimated_mem={max_memory + reload_size + size}"
# )
hp_param_size = hp_param.numel() * hp_param.element_size()
# if total_mem < max_memory + reload_size + hp_param_size:
offload_tasks.append((i, (hp_param, hp_param_cpu), "hp_param",
hp_param.numel() * hp_param.element_size(), hp_param.dtype))
# print_r0(
# f"Offloading task {i} hp_param reload_size={reload_size} size={hp_param_size} estimated_mem={max_memory + reload_size + hp_param_size}"
# )
# print_r0(f"offload_opt_states_inc fwd graph {graph_id} allocated_mem={get_accelerator().memory_allocated()}")
for node in graph.nodes:
# print_r0(f"checking sync node insert node: {node.name}")
if node.name not in peak_mem \
or node.op == 'placeholder' \
or "offload_opt_" in node.name:
continue
to_offload = []
optim_size = sum([task[3] for task in offload_tasks])
# print_r0(
# f" optim_size: {optim_size} total_mem: {total_mem} peak_mem: {peak_mem[node.name]} available: {total_mem - peak_mem[node.name] - optim_size} #tasks={len(offload_tasks)}"
# )
while total_mem - peak_mem[node.name] - optim_size < 0:
if len(offload_tasks) == 0:
break
task = offload_tasks.pop(0)
to_offload.append(task)
optim_size = sum([task[3] for task in offload_tasks])
# print_r0(
# f" scheduled task {task[0]} {task[2]} {task[3]} optim_size: {optim_size} peak_mem: {peak_mem[node.name]} available: {total_mem - peak_mem[node.name] - optim_size} #tasks={len(offload_tasks)}"
# )
for task in to_offload:
with graph.inserting_before(node):
graph.create_node('call_function',
make_offload_sync(task), (), {},
name=f"offload_opt_sync_{task[0]}_{task[2]}")
print_r0(f"Inserting fwd offload_opt_sync_{task[0]}_{task[2]}")
offload_tasks_scheduled.append(task)
for node in graph.nodes:
# print(f"Node: {node.name} mem: {mem_dict[node.name]}")
if node.op != 'placeholder':
print_r0(f"Inserting all offload tasks before {node.name}")
for task in offload_tasks_scheduled:
name = f"offload_opt_{task[0]}_{task[2]}"
with graph.inserting_before(node):
offload_node = graph.create_node('call_function', make_offload_task(task), (), {}, name=name)
break
# print_r0(f"offload_opt_states_inc finish graph {graph_id} fwd graph {graph}")
print_r0(f"offload_opt_states_inc finish graph {graph_id}")
else:
graph_order_with_backward = [g[0] for g in graph_order if g[1]]
is_first_graph = graph_id == graph_order_with_backward[-1]
is_last_graph = graph_id == graph_order_with_backward[0]
# print_r0(
# f"offload_opt_states_inc bwd graph {graph_id} graph_order_with_backward {graph_order_with_backward} is_first_graph {is_first_graph} is_last_graph {is_last_graph}"
# )
if is_first_graph:
inserted_sync = False
for node in graph.nodes:
if node.op != 'placeholder' and not inserted_sync:
# print(f"Inserting offload_sync before {node.name}")
with graph.inserting_before(node):
graph.create_node('call_function', empty_cache, (), {}, name="empty_cache")
inserted_sync = True
reload_tasks_remaining = copy.copy(offload_tasks_scheduled)
global total_reload_mem
for node in graph.nodes:
if node.name not in peak_mem \
or node.op == 'placeholder' \
or node.op == 'output' \
or "offload_opt_sync_" in node.name:
continue
if len(reload_tasks_remaining) > 0:
task = reload_tasks_remaining[0]
next_reload_mem = task[3]
insert_pos = node
while total_mem > peak_mem[node.name] + total_reload_mem + next_reload_mem:
expected_mem = peak_mem[node.name] + total_reload_mem
print_r0(
f" Inserting reload_opt reload_opt_{task[0]}_{task[2]} after {insert_pos.name} next_inc={next_reload_mem} peak_mem[{node.name}]={peak_mem[node.name]} inc_total={total_reload_mem} expected_mem={expected_mem}"
)
with graph.inserting_after(insert_pos):
insert_pos = graph.create_node('call_function',
make_reload_task(task), (), {},
name=f"reload_opt_{task[0]}_{task[2]}")
total_reload_mem += next_reload_mem
reload_tasks_remaining.pop(0)
if len(reload_tasks_remaining) == 0:
break
task = reload_tasks_remaining[0]
next_reload_mem = task[3]
# prev_node = node
if is_last_graph:
for node in graph.nodes:
# print(f"Node: {node.name} mem: {mem_dict[node.name]}")
if node.op == 'output':
for task in reload_tasks_remaining:
with graph.inserting_before(node):
graph.create_node('call_function',
make_reload_task(task), (), {},
name=f"reload_opt_{task[0]}_{task[2]}")
sync_fn = lambda: copy_stream.synchronize()
with graph.inserting_before(node):
graph.create_node('call_function', sync_fn, (), {}, name="sync_offload_copy_stream")
print_r0(
f"offload_opt_states_inc graph {graph_id} graph_order {graph_order} bwd is_first_graph {is_first_graph} is_last_graph {is_last_graph}"
)
return graph
def add_record_max_mem_nodes(graph: Graph):
nodes = list(graph.nodes)
for node in nodes:
if node.op == "output" or node.op == "placeholder":
continue
with graph.inserting_after(node):
name = f"update_max_memory_{node.name}"
graph.create_node('call_function', update_max_memory, (name, ), {}, name=name)
def insert_offload_opt_states(graph: Graph, graph_id: int, graph_order: List[int], profiling_results: ProfilingResult,
mem_budget: float, param_manager: DSGraphParamManager, bwd: bool) -> Graph:
if bwd:
graph_order_with_backward = [g[0] for g in graph_order if g[1]]
is_last_graph = graph_id == graph_order_with_backward[0]
inserted_reload = False
for node in graph.nodes:
# print(f"Node: {node.name} mem: {mem_dict[node.name]}")
if node.op == 'output' and not inserted_reload and is_last_graph:
# print(f"Inserting reload_opt before {node.name}")
with graph.inserting_before(node):
graph.create_node('call_function', reload_adam_states_sync, (), {}, name="reload_opt")
inserted_reload = True
# add_record_max_mem_nodes(graph)
else:
is_first_graph = graph_id == graph_order[0][0]
graph = move_primals_to_head(graph)
inserted_offload = False
for node in graph.nodes:
# print(f"Node: {node.name} mem: {mem_dict[node.name]}")
if node.op != 'placeholder' and not inserted_offload and is_first_graph:
print(f"Inserting offload_opt before {node.name}")
with graph.inserting_before(node):
graph.create_node('call_function', offload_adam_states_sync, (), {}, name="offload_opt")
inserted_offload = True
add_record_max_mem_nodes(graph)
return graph
def move_opt_states(gm: GraphModule, graph_id: int, graph_order: List[int], profiling_results, create_inputs_fn,
mem_budget: float, param_manager: DSGraphParamManager, bwd: bool) -> GraphModule:
gm.graph = offload_opt_states_inc(gm.graph, graph_id, graph_order, profiling_results, mem_budget, param_manager,
bwd)
return gm
def move_opt_states_sync(gm: GraphModule, graph_id: int, graph_order: List[int], profiling_results, create_inputs_fn,
mem_budget: float, param_manager: DSGraphParamManager, bwd: bool) -> GraphModule:
gm.graph = insert_offload_opt_states(gm.graph, graph_id, graph_order, profiling_results, mem_budget, param_manager,
bwd)
return gm
def offload_adam_states_for_init(gm: GraphModule, graph_id: int, graph_order: List[int], profiling_results,
create_inputs_fn, mem_budget: float, param_manager: DSGraphParamManager,
bwd: bool) -> GraphModule:
if not bwd and graph_id == graph_order[0][0]:
with unset_fake_temporarily():
offload_adam_states_sync()
# returns None, and profiling will be skipped
def init_offload_opt_states(adam_optimizer, _nz3):
lazy_init()
global optimizer
optimizer = adam_optimizer
global device
device = torch.device(get_accelerator().current_device())
global nz3
nz3 = _nz3
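The moves above follow the usual pinned-buffer offload pattern: copy the optimizer state to pinned host memory on a side stream, record an event, and synchronize on that event before dropping or reusing the device tensor. A standalone sketch (illustrative, not from this PR; CUDA-only and with made-up tensor names, whereas the real code goes through get_accelerator()):

import torch

if torch.cuda.is_available():
    state = {"exp_avg": torch.randn(1 << 20, device="cuda")}
    pin_buf = torch.empty_like(state["exp_avg"], device="cpu").pin_memory()
    copy_stream = torch.cuda.Stream()
    event = torch.cuda.Event()

    with torch.cuda.stream(copy_stream):
        pin_buf.copy_(state["exp_avg"], non_blocking=True)  # offload to pinned host memory
        event.record(copy_stream)
    event.synchronize()                                     # analogous to make_offload_sync
    del state["exp_avg"]                                    # device memory can now be reclaimed

    with torch.cuda.stream(copy_stream):
        state["exp_avg"] = torch.empty_like(pin_buf, device="cuda")
        state["exp_avg"].copy_(pin_buf, non_blocking=True)  # reload
        event.record(copy_stream)
    event.synchronize()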

View File

@ -0,0 +1,54 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
from typing import List
import torch
from torch.fx import Node, GraphModule
from deepspeed.compile.util import get_last_uses
from ..graph_param import DSGraphParamManager
def add_offload_parameter(graph_id: int, gm: GraphModule, node: Node, ds_id: int):
new_node = None
with gm.graph.inserting_after(node):
args = (node, )
for a in [graph_id, ds_id]: # To add ds_id
args += (a, )
new_node = gm.graph.create_node('call_function',
torch.ops.dc.offload_parameter.default,
args, {},
name="offload_parameter")
return new_node
def add_reload_parameter(graph_id: int, gm: GraphModule, node: Node, ds_id: int):
new_node = None
with gm.graph.inserting_after(node):
args = (node, )
for a in [graph_id, ds_id]: # To add ds_id
args += (a, )
new_node = gm.graph.create_node('call_function',
torch.ops.dc.reload_parameter.default,
args, {},
name=f"reload_parameter")
return new_node
def get_ds_id(node: Node):
assert node.target == torch.ops.dc.allgather_param.default
return node.args[2]
def offload_parameter_fwd(gm: GraphModule, graph_id: int, graph_order: List[int], profiling_results, create_inputs_fn,
mem_budget: float, param_manager: DSGraphParamManager, bwd: bool) -> GraphModule:
node_to_last_use, user_to_last_uses = get_last_uses(gm.graph)
for node in gm.graph.nodes:
if (isinstance(node, Node) and node.target == torch.ops.dc.allgather_param.default):
add_reload_parameter(graph_id, gm, node.args[0], get_ds_id(node))
add_offload_parameter(graph_id, gm, node_to_last_use[node], get_ds_id(node))
gm.graph.lint()
return gm

View File

@ -0,0 +1,174 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
from typing import List
import torch
from torch.fx import Graph, Node, GraphModule
from deepspeed.accelerator import get_accelerator
import deepspeed.comm as dist
from ..profilers.comm_profile import create_predictor
from ..graph_param import DSGraphParamManager
NAME = "prefetch"
FUSE_FACTOR = 0.8
MARGIN = 0.1
MAX_FUSE_SIZE = 1e9
MAX_BUFFERED_SIZE = 4e9
run_prefetch_pass = False
def print_rank_0(message):
if dist.get_rank() == 0:
print(message)
def get_ds_id(node: Node):
assert node.target == torch.ops.dc.allgather_param.default
return node.args[2]
def schedule_prefetch(gm: GraphModule, graph_id: int, graph_order: List[int], profiling_results, create_inputs_fn,
mem_budget: float, param_manager: DSGraphParamManager, bwd: bool) -> GraphModule:
max_mem = get_accelerator().total_memory() * (1 - MARGIN)
vals_to_bcast = torch.tensor([max_mem], device=torch.device(get_accelerator().current_device()))
dist.all_reduce(vals_to_bcast, dist.ReduceOp.MIN)
max_mem = vals_to_bcast[0].item()
mem = profiling_results[graph_id].bwd_mem if bwd else profiling_results[graph_id].fwd_mem
op_time = profiling_results[graph_id].bwd_time if bwd else profiling_results[graph_id].fwd_time
tensor_sizes = profiling_results[graph_id].bwd_tensor_sizes if bwd else profiling_results[graph_id].fwd_tensor_sizes
mem_dict = {name: (alloc_mem, peak) for name, alloc_mem, delta, peak in mem}
time_dict = {name: (device_time, wall_time) for name, device_time, wall_time in op_time}
tensor_size_dict = {name: size for name, size in tensor_sizes}
graph = gm.graph
total_param_size = sum(
[tensor_size_dict[n.name] for n in graph.nodes if n.target == torch.ops.dc.allgather_param.default])
print_rank_0(
f"schedule_prefetch graph_id={graph_id} max_mem={max_mem} available_memory={get_accelerator().available_memory()} memory_allocated={get_accelerator().memory_allocated()} max_allocated={get_accelerator().max_memory_allocated()} total_param_size={total_param_size} margin={MARGIN}"
)
# Fill missing values
prev_mem = 0
prev_peak = 0
for node in graph.nodes:
if node.name in mem_dict:
prev_mem = mem_dict[node.name][0]
prev_peak = mem_dict[node.name][1]
else:
print_rank_0(f"node {node.name} not in mem_dict")
mem_dict[node.name] = (prev_mem, prev_peak)
comm_predictor = create_predictor()
order_rev = list(reversed(graph.nodes))
new_order_rev = []
prefetch_ags = []
prefetch_ag_groups = []
ag_tensor_size_sum = 0
for i, node in enumerate(order_rev):
# print_rank_0(
# f"Checking node reverse order {node.name} {node.target} ag_tensor_size_sum={ag_tensor_size_sum} max_mem={max_mem}"
# )
if node.op != "placeholder":
assert i < len(order_rev) - 1
assert node.name in mem_dict
next_node = order_rev[i + 1]
next_alloc_mem, next_peak = mem_dict[next_node.name]
# Free up memory
while next_peak + ag_tensor_size_sum > max_mem or ag_tensor_size_sum > MAX_BUFFERED_SIZE:
if len(prefetch_ag_groups) > 0:
# launch prefetch
fused_ag_nodes = prefetch_ag_groups.pop(0)
total_ag_tensor_size = sum([tensor_size_dict[ag_node.name] for ag_node in fused_ag_nodes])
ag_tensor_size_sum -= total_ag_tensor_size
new_order_rev.append(fused_ag_nodes)
assert len(fused_ag_nodes) > 0
# print_rank_0(
# f"Free up memory fused_ag_nodes={fused_ag_nodes} next_alloc_mem={next_alloc_mem} total_ag_tensor_size={total_ag_tensor_size} ag_tensor_size_sum={ag_tensor_size_sum} max_mem={max_mem}"
# )
elif len(prefetch_ags) > 0:
prefetch_ag_groups.append(prefetch_ags)
prefetch_ags = []
# print_rank_0(
# f"Free up memory prefetch_ags={prefetch_ag_groups} next_alloc_mem={next_alloc_mem} ag_tensor_size_sum={ag_tensor_size_sum} max_mem={max_mem}"
# )
else:
break
if node.target == torch.ops.dc.allgather_param.default:
current_ag_size = sum([tensor_size_dict[ag_node.name] for ag_node in prefetch_ags])
pred_time_current = comm_predictor(current_ag_size)
pred_time_next = comm_predictor(tensor_size_dict[node.name])
pred_time_fused = comm_predictor(current_ag_size + tensor_size_dict[node.name])
do_fuse = max(pred_time_current, pred_time_next) * 1.2 > pred_time_fused and (
current_ag_size + tensor_size_dict[node.name]) < MAX_FUSE_SIZE
# print_rank_0(
# f"found allgather_param do_fuse={do_fuse} current_ag_size={current_ag_size} tensor_size_dict[node.name]={tensor_size_dict[node.name]} pred_time_current={pred_time_current} pred_time_next={pred_time_next} pred_time_fused={pred_time_fused}"
# )
if len(prefetch_ags) > 0 and not do_fuse:
# stop fusing here
prefetch_ag_groups.append(prefetch_ags)
prefetch_ags = []
# print_rank_0(
# f"stop fusing prefetch_ags={prefetch_ag_groups} ag_tensor_size_sum={ag_tensor_size_sum}")
# else:
# print_rank_0(
# f"continue fusing ag_tensor_size_sum={ag_tensor_size_sum} ag_size={tensor_size_dict[node.name]} prefetch_ags={prefetch_ags} prefetch_ag_groups={prefetch_ag_groups}"
# )
prefetch_ags.append(node)
ag_tensor_size_sum += tensor_size_dict[node.name]
new_order_rev.append(node)
if (node.op != "placeholder"
and node.target != torch.ops.dc.reload_parameter.default) and order_rev[i + 1].op == "placeholder":
for ag_group in prefetch_ag_groups:
assert len(ag_group) > 0
new_order_rev.append(ag_group)
total_ag_tensor_size = sum([tensor_size_dict[ag_node.name] for ag_node in ag_group])
ag_tensor_size_sum -= total_ag_tensor_size
if len(prefetch_ags) > 0:
new_order_rev.append(prefetch_ags)
ag_tensor_size_sum -= sum([tensor_size_dict[ag_node.name] for ag_node in prefetch_ags])
assert ag_tensor_size_sum == 0
# print_rank_0(
# f"node={node} next_alloc_mem={next_alloc_mem} pending_ags={len(prefetch_ags)} ag_tensor_size_sum={ag_tensor_size_sum}"
# )
assert ag_tensor_size_sum >= 0
new_graph = Graph()
env = {}
for node in reversed(new_order_rev):
if isinstance(node, Node):
#print(f"reconstruct {node.name} {node.target}")
new_node = new_graph.node_copy(node, lambda n: env[n.name])
env[node.name] = new_node
else:
param_nodes = [ag_node.args[0] for ag_node in node]
param_nodes_copy = [env[param_node.name] for param_node in param_nodes]
ds_ids = [get_ds_id(ag_node) for ag_node in node]
new_graph.call_function(torch.ops.dc.prefetch_params_fused.default,
args=(graph_id, param_nodes_copy, ds_ids))
new_graph.lint()
gm.graph = new_graph
return gm
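A small sketch (illustrative, not from this PR) of the fusion heuristic used above, with a made-up linear latency model in place of create_predictor: two pending allgathers are fused when the predicted fused latency is within 20% of the larger individual latency and the fused size stays under the cap. The 1.2 slack factor and the size cap mirror the constants above; the latency numbers are invented.

def toy_predictor(size_bytes):            # stand-in for create_predictor()
    return 30e-6 + size_bytes / 100e9     # fixed launch latency + ~100 GB/s bandwidth

def should_fuse(current_size, next_size, max_fuse_size=1e9):
    t_cur, t_next = toy_predictor(current_size), toy_predictor(next_size)
    t_fused = toy_predictor(current_size + next_size)
    return max(t_cur, t_next) * 1.2 > t_fused and (current_size + next_size) < max_fuse_size

print(should_fuse(1e5, 1e5))  # tiny messages: launch latency dominates -> True
print(should_fuse(4e6, 4e6))  # bandwidth-bound: fusing gains little -> False
print(should_fuse(8e8, 8e8))  # would exceed the fuse size cap -> False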

View File

@ -0,0 +1,146 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
from collections import defaultdict
from typing import List
import torch
from torch.fx import GraphModule
import deepspeed.comm as dist
from deepspeed.accelerator import get_accelerator
from ..util import get_deepcompile_handle
from ..graph_param import DSGraphParamManager
NAME = "selective_gather"
max_alloc_mem = 0
last_optimize_step = 0
def selective_gather(gm: GraphModule, graph_id: int, graph_order: List[int], profiling_results, create_inputs_fn,
mem_budget: float, param_manager: DSGraphParamManager, bwd: bool) -> GraphModule:
if not bwd:
return gm
last_backward_graph_id = None
for g_id, needs_bwd in graph_order:
if needs_bwd:
last_backward_graph_id = g_id
break
# Run only on the last backward graph
if last_backward_graph_id is None or graph_id != last_backward_graph_id:
return gm
peak_mem = 0
for graph_id, prof in profiling_results.items():
# Use peak memory
fwd_max_mem = max(m[3] for m in prof.fwd_mem)
bwd_max_mem = max(m[3] for m in prof.bwd_mem) if len(prof.bwd_mem) > 0 else 0
peak_mem = max(peak_mem, fwd_max_mem, bwd_max_mem)
if dist.get_rank() == 0:
print(
f"selective_gather graph_id={graph_id} max_mem={peak_mem} fwd_max_mem={fwd_max_mem} bwd_max_mem={bwd_max_mem}"
)
persistent_ds_ids = set()
for graph_id, pm in param_manager.items():
for name, ds_param in pm.params.items():
if ds_param.param.ds_persist:
persistent_ds_ids.add(pm.ds_ids[name])
ds_id_to_size = {}
ds_id_to_time = defaultdict(float)
ds_id_to_prof_dtime = defaultdict(float)
ds_id_to_prof_wtime = defaultdict(float)
for graph_id, pm in param_manager.items():
params = pm.params
for param_name, param in params.items():
ds_id = pm.ds_ids[param_name]
ds_id_to_size[ds_id] = param.numel * param.dtype.itemsize
profile = profiling_results[graph_id]
for n in profile.fwd_graph.nodes:
if n.target == torch.ops.dc.allgather_param.default:
assert "tensor_size" in n.meta
ds_id_to_size[n.args[2]] = n.meta["tensor_size"]
assert "device_time" in n.meta
ds_id_to_time[n.args[2]] += n.meta["device_time"]
ds_id_to_prof_dtime[n.args[2]] = n.meta["device_time"]
ds_id_to_prof_wtime[n.args[2]] = n.meta["wall_time"]
if profile.bwd_graph is not None:
for n in profile.bwd_graph.nodes:
if n.target == torch.ops.dc.allgather_param.default:
assert "tensor_size" in n.meta
ds_id_to_size[n.args[2]] = n.meta["tensor_size"]
assert "device_time" in n.meta
ds_id_to_time[n.args[2]] += n.meta["device_time"]
ds_ids = [ds_id for ds_id in ds_id_to_size if ds_id not in persistent_ds_ids]
ds_ids.sort(key=lambda ds_id: ds_id_to_time[ds_id] / ds_id_to_size[ds_id], reverse=True)
# print(f"ds_id_to_size={ds_id_to_size}")
# print(f"ds_id_to_time={ds_id_to_time}")
# if dist.get_rank() == 0:
# for ds_id in ds_ids:
# dtime_in_sec = ds_id_to_prof_dtime[ds_id]
# wtime_in_sec = ds_id_to_prof_wtime[ds_id]
# size_in_mb = ds_id_to_size[ds_id] / 1024 / 1024
# print(
# f"ds_id={ds_id} time_per_size={ds_id_to_time[ds_id] / ds_id_to_size[ds_id]:.5f} dtime={dtime_in_sec:.3f} wtime={wtime_in_sec:.3f} size={size_in_mb:.2f}MB bw={size_in_mb/dtime_in_sec:.2f}MB/s"
# )
sorted_ds_ids = {ds_id: ds_id_to_size[ds_id] for ds_id in ds_ids}
accelerator = get_accelerator()
total_mem = accelerator.total_memory()
vals_to_bcast = torch.tensor([total_mem], device=torch.device(get_accelerator().current_device()))
dist.all_reduce(vals_to_bcast, dist.ReduceOp.MIN)
total_mem = vals_to_bcast[0].item()
MEM_MARGIN = 0.1
available_mem = total_mem * (1 - MEM_MARGIN) - peak_mem
if dist.get_rank() == 0:
print(
f"selective_gather max_mem={peak_mem} total_mem={total_mem} MEM_MARGIN={MEM_MARGIN} available_mem={available_mem}"
)
ds_id_to_param = {}
for g_id, g_pm in param_manager.items():
for name, ds_param in g_pm.params.items():
ds_id_to_param[g_pm.ds_ids[name]] = ds_param.param
persistent_mem = 0
nz3 = get_deepcompile_handle()
for ds_id, size in sorted_ds_ids.items():
if persistent_mem + size > available_mem:
break
persistent_mem += size
param_obj = ds_id_to_param[ds_id]
nz3.set_persistent(ds_id)
if dist.get_rank() == 0:
print(f"Set persistent: {ds_id} size: {size} persistent_mem: {persistent_mem} shape: {param_obj.ds_shape}")
return gm
# def make_selective_gather(z3_optimizer, nz3):
# def selective_gather_wrapper(graph: Graph, graph_id: int, graph_order: List[int], profiling_results,
# mem_budget: float, param_manager, bwd: bool) -> Graph:
# return selective_gather(graph, graph_id, graph_order, profiling_results, mem_budget, param_manager, bwd,
# z3_optimizer, nz3)
# return selective_gather_wrapper
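The selection above is a greedy budget fill over parameters ranked by allgather time per byte. A toy sketch (illustrative, not from this PR) with made-up sizes and timings:

ds_id_to_size = {0: 4e6, 1: 1e6, 2: 16e6}       # bytes (made up)
ds_id_to_time = {0: 200.0, 1: 120.0, 2: 300.0}  # accumulated allgather time (made up)
available_mem = 6e6

# Rank by allgather time per byte, then keep params unsharded until the budget is used up.
order = sorted(ds_id_to_size, key=lambda i: ds_id_to_time[i] / ds_id_to_size[i], reverse=True)
persistent, used = [], 0
for ds_id in order:
    if used + ds_id_to_size[ds_id] > available_mem:
        break
    used += ds_id_to_size[ds_id]
    persistent.append(ds_id)
print(order, persistent)  # [1, 0, 2] [1, 0]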

View File

@ -0,0 +1,55 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
from typing import List
import torch
from torch.fx import GraphModule
from ..util import get_deepcompile_handle
from ..fx import add_postprocess, move_primals_to_head, _make_node_meta
NAME = "zero1_compile"
def add_z1_reduce_fw(gm: GraphModule, graph_id: int, profiling_results, param_manager) -> GraphModule:
dc = get_deepcompile_handle()
param_indices = profiling_results[graph_id].param_indices
dc.register_graph_z1(graph_id, [v[1] for v in param_indices]) # Need this before profiling
return gm
def add_z1_reduce_bw(gm: GraphModule, graph_id: int, param_manager) -> GraphModule:
graph = gm.graph
pm = param_manager[graph_id]
_, param_name_to_grad = pm.get_bwd_mapping(graph)
for param_name in pm.param_names:
grad_node = param_name_to_grad[param_name]
assert param_name in pm.ds_ids, f"param_name={param_name} not in ds_ids"
ds_id = pm.ds_ids[param_name]
new_node = add_postprocess(graph,
grad_node,
torch.ops.dc.reduce_grad.default,
extra_args=[graph_id, ds_id],
name=f"reduce_param_{param_name}",
meta=_make_node_meta(grad_node, param_name, True))
new_node.meta["val"] = None
gm.graph = move_primals_to_head(graph)
return gm
def add_z1_reduce(gm: GraphModule, graph_id: int, graph_order: List[int], profiling_results, create_inputs_fn,
mem_budget: float, param_manager, bwd: bool) -> GraphModule:
if bwd:
return add_z1_reduce_bw(gm, graph_id, param_manager)
return add_z1_reduce_fw(gm, graph_id, profiling_results, param_manager)

View File

@ -0,0 +1,186 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
import gc
from typing import List, Dict
import torch
from torch.fx import Graph, Node, GraphModule
from ..util import get_input_nodes, get_param_nodes, get_index_by_graph_id, get_deepcompile_handle, get_real_uses
from ..fx import add_postprocess, _make_node_meta, get_output_node, move_primals_to_head
from ..profilers.graph_profile import ProfilingInterpreter
from ..list_schedule import fast_free_schedule
import deepspeed.comm as dist
from deepspeed.accelerator import get_accelerator
NAME = "zero3_compile"
def add_allgather(graph_id: int, graph: Graph, node: Node, ds_id: int):
new_ag_node = add_postprocess(graph,
node,
torch.ops.dc.allgather_param.default,
extra_args=[graph_id, ds_id],
name=f"allgather_ds_param_{node.target}_{ds_id}",
meta=_make_node_meta(node, ds_id, True))
new_ag_node.meta["val"] = node.meta["val"]
# Keep the original node as the graph output:
# we don't want the graph output to be rewired to the allgather node.
output_node = get_output_node(graph)
output_node.replace_input_with(new_ag_node, node)
# Add wait as well
new_wait_node = add_postprocess(graph,
new_ag_node,
torch.ops.dc.wait_allgather.default,
extra_args=[graph_id, ds_id],
name=f"wait_allgather_ds_param__{node.target}_{ds_id}",
meta=_make_node_meta(node, ds_id, False))
new_wait_node.meta["val"] = node.meta["val"]
return new_ag_node
def add_release(graph_id: int, graph: Graph, node: Node, release_node: Node, ds_id: int, n_users: int):
new_node = add_postprocess(graph,
node,
torch.ops.dc.release_param.default,
extra_args=[graph_id, ds_id, n_users],
name=f"release_ds_param_{release_node.target}_{node.name}_{ds_id}",
meta=_make_node_meta(node, ds_id, False))
new_node.meta["val"] = None
def add_reduce(graph_id: int, graph: Graph, grad_node: Node, param_name: str, ds_id: int):
new_node = add_postprocess(graph,
grad_node,
torch.ops.dc.reduce_grad.default,
extra_args=[graph_id, ds_id],
name=f"reduce_ds_param_{param_name}",
meta=_make_node_meta(grad_node, ds_id, True))
new_node.meta["val"] = None
def add_gather_and_release(graph_id: int, graph: Graph, param_manager, param_nodes: List[Node]) -> Graph:
node_to_uses = get_real_uses(graph)
for pn in param_nodes:
add_allgather(graph_id, graph, pn, param_manager.ds_ids[pn.name])
ds_id = param_manager.ds_ids[pn.name]
users = node_to_uses[pn]
for user in users:
add_release(graph_id, graph, user, pn, ds_id, len(users))
return move_primals_to_head(graph)
def add_gather_and_reduce(graph_id: int, graph: Graph, param_manager, param_nodes_bw: List[Node],
param_name_to_grad: Dict[str, Node]) -> Graph:
add_gather_and_release(graph_id, graph, param_manager, param_nodes_bw)
for param_name in param_manager.param_names:
add_reduce(graph_id, graph, param_name_to_grad[param_name], param_name, param_manager.ds_ids[param_name])
return move_primals_to_head(graph)
def add_z3_gather_release_fw(gm: GraphModule,
graph_id: int,
graph_order: List[int],
profiling_results,
create_inputs_fn,
param_manager,
debug_log=False) -> GraphModule:
nz3 = get_deepcompile_handle()
real_inputs = create_inputs_fn()
param_indices = profiling_results[graph_id].param_indices
gm.graph = add_gather_and_release(graph_id, gm.graph, param_manager[graph_id],
get_param_nodes(gm.graph, param_indices))
nz3.register_graph_z3(graph_id, [v[1] for v in param_indices]) # Need this before profiling
profiler = ProfilingInterpreter(gm, debug_log=debug_log)
profiler.run(*real_inputs)
del profiler
gc.collect()
get_accelerator().empty_cache()
rank = dist.get_rank()
graph_index = get_index_by_graph_id(graph_order, graph_id)
if rank == 0 and debug_log:
print(f"Fwd before scheduling graph {graph_index} graph_id={graph_id} {gm.graph}")
for n in gm.graph.nodes:
is_ds_param = n.name in param_manager[graph_id].ds_ids
if "val" in n.meta and is_ds_param:
# Used for Inductor's validation
n.meta["val"] = torch.empty([0], dtype=n.meta['val'].dtype, device=n.meta['val'].device)
gm.graph = fast_free_schedule(
gm.graph,
get_accelerator().available_memory(),
0, # unused
debug_log=debug_log)
if rank == 0 and debug_log:
print(f"Fwd after scheduling graph {graph_index} graph_id={graph_id} {gm.graph}")
return gm
def add_z3_gather_release_bw(gm: GraphModule,
graph_id: int,
graph_order: List[int],
profiling_results,
create_inputs_fn,
param_manager,
debug_log=False) -> GraphModule:
param_nodes_bw, param_name_to_grad = param_manager[graph_id].get_bwd_mapping(gm.graph)
gm.graph = add_gather_and_reduce(graph_id, gm.graph, param_manager[graph_id], param_nodes_bw, param_name_to_grad)
input_nodes = get_input_nodes(gm.graph)
real_inputs = create_inputs_fn()
assert len(input_nodes) == len(real_inputs), f"Expected {len(real_inputs)} inputs, got {len(input_nodes)}"
real_outputs = ProfilingInterpreter(gm, debug_log=debug_log).run(*real_inputs)
del real_outputs
gc.collect()
get_accelerator().empty_cache()
rank = dist.get_rank()
graph_index = get_index_by_graph_id(graph_order, graph_id)
if rank == 0 and debug_log:
print(f"Bwd before scheduling graph {graph_index} graph_id={graph_id} {gm.graph}")
# gm.graph = fast_free_schedule(gm.graph, get_accelerator().available_memory(), 0, debug_log=debug_log)
return gm
def add_z3_gather_release(gm: GraphModule, graph_id: int, graph_order: List[int], profiling_results, create_inputs_fn,
mem_budget: float, param_manager, bwd: bool) -> GraphModule:
if bwd:
return add_z3_gather_release_bw(gm,
graph_id,
graph_order,
profiling_results,
create_inputs_fn,
param_manager,
debug_log=False)
return add_z3_gather_release_fw(gm,
graph_id,
graph_order,
profiling_results,
create_inputs_fn,
param_manager,
debug_log=False)
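A self-contained sketch (illustrative, not from this PR) of the gather/release insertion pattern used by add_allgather and add_release, with no-op Python functions standing in for the dc.* custom ops so it runs without DeepCompile:

import torch
from torch.fx import symbolic_trace

def _gather(p):        # stand-in for allgather_param + wait_allgather
    return p

def _release(out, p):  # stand-in for release_param; the data flow keeps it ordered after the use
    return out

class Toy(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.w = torch.nn.Parameter(torch.randn(4, 4))
    def forward(self, x):
        return x @ self.w

gm = symbolic_trace(Toy())
param_node = next(n for n in gm.graph.nodes if n.op == "get_attr" and n.target == "w")
user = next(iter(param_node.users))            # the matmul that consumes the param
with gm.graph.inserting_after(param_node):
    gathered = gm.graph.call_function(_gather, (param_node,))
user.replace_input_with(param_node, gathered)  # consume the gathered param instead
with gm.graph.inserting_after(user):
    gm.graph.call_function(_release, (user, gathered))
gm.graph.lint()
gm.recompile()
print(gm(torch.randn(2, 4)).shape)             # torch.Size([2, 4])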

View File

@ -0,0 +1,93 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
import torch
from deepspeed.utils.torch import required_torch_version
backward_inputs = []
enabled_patched_func = False
original_grad_fn = None
base_meta = type(torch.autograd.Function)
if required_torch_version(min_version=2.7):
class FunctionMeta(base_meta):
def __new__(cls, name, bases, dct):
if name == "CompiledFunction":
original_backward_impl = dct.get("_backward_impl")
def wrapped_backward_impl(ctx, all_args):
assert original_backward_impl is not None
if enabled_patched_func:
backward_inputs.append(all_args)
wrapped_backward_impl.owner_class.compiled_bw = None
return original_backward_impl(ctx, all_args)
wrapped_backward_impl.owner_class = None
dct["_backward_impl"] = staticmethod(wrapped_backward_impl)
new_class = super().__new__(cls, name, bases, dct)
wrapped_backward_impl.owner_class = new_class
return new_class
return super().__new__(cls, name, bases, dct)
elif required_torch_version(min_version=2.6):
class FunctionMeta(base_meta):
def __new__(cls, name, bases, dct):
if name == "CompiledFunction":
original_backward_prologue = dct.get("_backward_prologue")
def wrapped_backward_prologue(ctx, *grad_outputs):
assert original_backward_prologue is not None
all_args = original_backward_prologue(ctx, *grad_outputs)
if enabled_patched_func:
backward_inputs.append(all_args)
wrapped_backward_prologue.owner_class.compiled_bw = None
return all_args
wrapped_backward_prologue.owner_class = None
dct["_backward_prologue"] = staticmethod(wrapped_backward_prologue)
new_class = super().__new__(cls, name, bases, dct)
wrapped_backward_prologue.owner_class = new_class
return new_class
return super().__new__(cls, name, bases, dct)
def patch_compiled_func():
global enabled_patched_func
enabled_patched_func = True
class PatchedFunction(torch.autograd.Function, metaclass=FunctionMeta):
pass
global original_grad_fn
original_grad_fn = torch.autograd.Function
torch.autograd.Function = PatchedFunction
return backward_inputs
def unpatch_compiled_func():
global enabled_patched_func
enabled_patched_func = False
global original_grad_fn
torch.autograd.Function = original_grad_fn
def get_backward_inputs():
return backward_inputs
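The patching above works by intercepting the definition of AOTAutograd's generated CompiledFunction class through a metaclass and wrapping one of its methods. A generic, torch-free sketch (illustrative, not from this PR) of that mechanism:

captured = []

class InterceptMeta(type):
    def __new__(cls, name, bases, dct):
        if name == "CompiledFunction" and "run" in dct:
            original = dct["run"]
            def wrapped(*args, **kwargs):
                captured.append(args)      # record the call arguments
                return original(*args, **kwargs)
            dct["run"] = wrapped
        return super().__new__(cls, name, bases, dct)

class Base(metaclass=InterceptMeta):
    pass

class CompiledFunction(Base):              # this definition is intercepted by the metaclass
    def run(self, x):
        return x + 1

print(CompiledFunction().run(41), len(captured))  # 42 1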

View File

@ -0,0 +1,53 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
import torch
try:
from torch._subclasses import FakeTensorMode
from torch._subclasses.fake_tensor import unset_fake_temporarily
from torch._dynamo.variables.builder import wrap_to_fake_tensor_and_record
except ImportError:
# Unsupported torch version
pass
def wrap_if_ds_param(t):
if hasattr(t, 'ds_id'):
data = torch.rand(t.ds_shape,
dtype=t.dtype,
layout=t.layout,
device=t.device,
pin_memory=t.is_pinned(),
requires_grad=t.requires_grad)
if isinstance(t, torch.nn.Parameter):
t = torch.nn.Parameter(data, requires_grad=t.requires_grad)
else:
t = data
return t
def patch_fake_tensor():
# dynamo tracer uses wrap_to_fake_tensor_and_record
# Wrapping FakeTensorMode.from_tensor is not sufficient as dynamo generates SymbolicContext before calling from_tensor
original_wrap_to_fake_tensor_and_record = wrap_to_fake_tensor_and_record
def wrap_to_fake_tensor_and_record_wrapper(t, *args, **kwargs):
dummy_tensor = wrap_if_ds_param(t)
ret = original_wrap_to_fake_tensor_and_record(dummy_tensor, *args, **kwargs)
if tracing_context := torch._guards.TracingContext.try_get():
tracing_context.tensor_to_context[t] = tracing_context.tensor_to_context.pop(dummy_tensor)
return ret
torch._dynamo.variables.builder.wrap_to_fake_tensor_and_record = wrap_to_fake_tensor_and_record_wrapper
# aot_module_simplified uses fake_mode.from_tensor to process inputs
original_from_tensor = FakeTensorMode.from_tensor
def from_tensor_wrapper(self, t, *args, **kwargs):
with unset_fake_temporarily():
return original_from_tensor(self, wrap_if_ds_param(t), *args, **kwargs)
FakeTensorMode.from_tensor = from_tensor_wrapper

View File

@ -0,0 +1,23 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
from typing import List, Tuple
from dataclasses import dataclass, field
from torch.fx import Graph
@dataclass
class ProfilingResult:
fwd_graph: Graph = None
bwd_graph: Graph = None
needs_backward: bool = False
fwd_mem: List[Tuple[str, int, int, int]] = field(default_factory=list) # name, current_alloc, delta, peak
bwd_mem: List[Tuple[str, int, int, int]] = field(default_factory=list)
fwd_time: List[Tuple[str, int, int]] = field(default_factory=list) # name, device_time, wall_time
bwd_time: List[Tuple[str, int, int]] = field(default_factory=list)
fwd_tensor_sizes: List[Tuple[str, int]] = field(default_factory=list) # name, size
bwd_tensor_sizes: List[Tuple[str, int]] = field(default_factory=list)
param_indices: List[Tuple[int, int, Tuple[int, ...]]] = field(default_factory=list) # index, ds_id, ds_shape

View File

@ -0,0 +1,171 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
import os
import torch
try:
from torch._subclasses.fake_tensor import unset_fake_temporarily
except ImportError:
# Unsupported torch version
pass
import deepspeed
import deepspeed.comm as dist
from deepspeed.accelerator import get_accelerator
def sync_all():
get_accelerator().synchronize()
dist.barrier()
def get_bw(comm_op, size, duration):
n = dist.get_world_size()
tput = 0
busbw = 0
if duration == 0:
raise ValueError("Error. Duration is 0.")
if comm_op == "all_to_all":
tput = (size / duration)
busbw = (size / duration) * ((n - 1) / n)
elif comm_op == "all_gather":
size *= n
tput = (size / duration)
busbw = (size / duration) * ((n - 1) / n)
elif comm_op == "all_reduce":
tput = (size * 2 / duration)
busbw = (size / duration) * (2 * (n - 1) / n)
elif comm_op == "pt2pt" or comm_op == "broadcast":
tput = (size / duration)
busbw = tput
else:
raise ValueError("wrong comm_op specified")
return tput, busbw
# Run all_gather and print metrics
def timed_all_gather(device, input, output, start_event, end_event, warmup, trials, async_op):
sync_all()
# Warmups, establish connections, etc.
for i in range(warmup):
dist.all_gather_into_tensor(output, input, async_op=async_op)
sync_all()
# time the actual comm op trials times and average it
start_event.record()
for i in range(trials):
dist.all_gather_into_tensor(output, input, async_op=async_op)
end_event.record()
sync_all()
duration = start_event.elapsed_time(end_event) / 1000
# maintain and clean performance data
avg_duration = duration / trials
size = input.element_size() * input.nelement() * dist.get_world_size()
# tput, busbw = get_bw('all_gather', size, avg_duration)
avg_duration_ten = torch.tensor([avg_duration], device=device)
if dist.get_world_size() > 1:
dist.all_reduce(avg_duration_ten, dist.ReduceOp.AVG)
return size, avg_duration_ten.item()
def run_all_gather(device, dtype, maxsize, warmup=5, trials=10, async_op=False):
# Prepare benchmark header
global_rank = dist.get_rank()
world_size = dist.get_world_size()
start_event = get_accelerator().Event(enable_timing=True)
end_event = get_accelerator().Event(enable_timing=True)
# Create list of message sizes
M_LIST = []
for x in (2**p for p in range(1, maxsize)):
m = x // world_size
if m > 0:
M_LIST.append(m)
results = [(0, 0)]
sync_all()
# loop over various tensor sizes
for M in M_LIST:
global_rank = dist.get_rank()
try:
mat = torch.ones(M, dtype=dtype, device=device)
sync_all()
input = ((mat.mul_(float(global_rank))).view(-1))
# Delete original mat to avoid OOM
del mat
get_accelerator().empty_cache()
output = torch.zeros(input.nelement() * world_size, dtype=dtype, device=device)
except RuntimeError as e:
if 'out of memory' in str(e):
if dist.get_rank() == 0:
print('WARNING: Ran out of GPU memory. Exiting comm op.')
sync_all()
break
else:
raise e
sync_all()
results.append(timed_all_gather(device, input, output, start_event, end_event, warmup, trials, async_op))
return results
profile_results = None
def create_predictor():
global profile_results
if profile_results is None:
with unset_fake_temporarily():
device = get_accelerator().current_device()
profile_results = run_all_gather(device, torch.bfloat16, 31)
if dist.get_rank() == 0:
for size, avg_duration in profile_results:
print(f"size: {size}, avg_duration: {avg_duration}")
# Extract size and avg_duration from results
sizes = [result[0] for result in profile_results]
durations = [result[1] for result in profile_results]
try:
from scipy.interpolate import interp1d
except ImportError:
raise RuntimeError("Please install scipy to use communication profiler in DeepCompile")
predictor = interp1d(sizes, durations, kind='linear', fill_value="extrapolate")
def f(size):
if size == 0:
return 0
return predictor(size)
# Create an interpolation function
return f
if __name__ == "__main__":
local_rank = int(os.environ['LOCAL_RANK'])
get_accelerator().set_device(local_rank)
print(f"local_rank={local_rank}")
deepspeed.init_distributed(dist_backend='nccl')
# Create predictor function
predictor = create_predictor()
# Predict time for a specific data size
example_size = 1e9
predicted_time = predictor(example_size)
print(f"Predicted time for size {example_size}: {predicted_time:.6f} seconds")
dist.destroy_process_group()
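A reduced sketch (illustrative, not from this PR) of the interpolation-based latency model that create_predictor builds, using made-up (size, duration) measurements so it can be tried without a distributed setup; it assumes scipy is installed.

from scipy.interpolate import interp1d

samples = [(0, 0.0), (1 << 20, 0.00012), (16 << 20, 0.0011), (256 << 20, 0.016)]
sizes = [s for s, _ in samples]
durations = [d for _, d in samples]
predict = interp1d(sizes, durations, kind="linear", fill_value="extrapolate")

print(float(predict(8 << 20)))   # interpolated between measured points
print(float(predict(1 << 30)))   # extrapolated beyond the largest measured size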

View File

@ -0,0 +1,295 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
import time
from typing import Any, Tuple, Dict
import statistics
import torch
from torch.fx import GraphModule, Interpreter
from torch.fx.node import map_aggregate
try:
from torch.utils._pytree import tree_all, tree_leaves
from torch._subclasses.fake_tensor import unset_fake_temporarily, is_fake
except ImportError:
# Unsupported torch version
pass
import deepspeed.comm as dist
from deepspeed.accelerator import get_accelerator
from ..util import is_comm_op, is_release_node, get_deepcompile_handle
def _all_real_if_tensor(args):
return tree_all(lambda x: not torch.is_tensor(x) or not is_fake(x), args)
def _to(v, device):
if torch.is_tensor(v):
with unset_fake_temporarily():
return v.to(device)
return v
def _args_to_key(v):
def _tensor_to_key(v) -> str:
if torch.is_tensor(v):
if v.numel() == 1:
return f"{v.dtype}{v.device}{v.item()}"
else:
return f"{v.dtype}{v.device}{v.shape}"
return str(v)
return map_aggregate(v, _tensor_to_key)
def _node_size(out):
return sum([v.element_size() * v.numel() for v in tree_leaves(out) if torch.is_tensor(v)])
def _get_mem_usage_out_of_torch():
adjust = 0
try:
import pynvml
pynvml.nvmlInit()
current_dev_id = get_accelerator().current_device()
handle = pynvml.nvmlDeviceGetHandleByIndex(current_dev_id)
info = pynvml.nvmlDeviceGetMemoryInfo(handle)
torch_alloc = get_accelerator().memory_allocated()
adjust = info.used - torch_alloc
except:
# pynvml not available
pass
return adjust
# https://pytorch.org/tutorials/intermediate/fx_profiling_tutorial.html
class ProfilingInterpreter(Interpreter):
def __init__(self, gm: GraphModule, iteration: int = 10, warmup: int = 5, debug_log=False):
super().__init__(gm)
self.nz3 = get_deepcompile_handle()
assert iteration > 0
assert warmup >= 0
self.iteration = iteration
self.warmup = warmup
self.device = torch.device(get_accelerator().current_device())
self.cache: Dict[Tuple, Any] = {}
self.distributed = dist.is_initialized()
self.allgather_mem: Dict[int, int] = {}
self.debug_log = debug_log
self.mem_usage_out_of_torch = 0
def run(self, *args) -> Any:
"""Run the graph with profiling enabled.
args: inputs to the graph. Tensors in the inputs must be real tensors, not fake tensors. args can contain ds parameters.
returns: The output of the graph. Tensors in the output are real tensors.
"""
return_val = None  # ensure a defined return value even if profiling fails
try:
assert _all_real_if_tensor(args), "Inputs must be real tensors"
self.nz3.enable_profiling(True)
with unset_fake_temporarily():
with get_accelerator().random().fork_rng(devices=[self.device]):
self.mem_usage_out_of_torch = _get_mem_usage_out_of_torch()
return_val = super().run(*args)
except Exception as e:
            msg = getattr(e, "msg", str(e))
print(f"Profiling error {msg}")
finally:
self.nz3.clear_all_gathered_params()
self.nz3.enable_profiling(False)
return return_val
def run_node(self, n: torch.fx.Node) -> Any:
if n.op in {"placeholder", "output"}:
n.meta["device_time"] = 0.0
n.meta["wall_time"] = 0.0
n.meta["memory"] = 0
n.meta["max_memory"] = 0
n.meta["tensor_size"] = _node_size(n)
return super().run_node(n)
args, kwargs = self.fetch_args_kwargs_from_env(n)
assert isinstance(args, tuple)
assert isinstance(kwargs, dict)
def rebuild_param_if_necessary(v):
if hasattr(v, "ds_id"):
v.all_gather(param_list=[v])
return v
args = map_aggregate(args, lambda x: rebuild_param_if_necessary(x))
args = map_aggregate(args, lambda x: _to(x, self.device))
kwargs = map_aggregate(kwargs, lambda x: _to(x, self.device))
cache_key = (n.target, _args_to_key(args), _args_to_key(kwargs))
cache_hit = cache_key in self.cache
cache_hit_flag = torch.tensor([0 if cache_hit else 1], device=self.device, dtype=torch.int)
if self.distributed:
dist.all_reduce(cache_hit_flag, dist.ReduceOp.SUM)
cache_hit = cache_hit_flag.item() == 0
        if cache_hit:
            device_time, wall_time, alloc_mem, max_mem, tensor_size = self.cache[cache_key]
            n.meta["device_time"] = device_time
            n.meta["wall_time"] = wall_time
            n.meta["alloc_mem"] = alloc_mem
            n.meta["max_mem"] = max_mem
            n.meta["tensor_size"] = tensor_size
is_release_op = is_release_node(n)
run_only_once = cache_hit or is_release_op
iteration = 1 if run_only_once else self.iteration
accelerator = get_accelerator()
start_events = [accelerator.Event(enable_timing=True) for _ in range(iteration)]
end_events = [accelerator.Event(enable_timing=True) for _ in range(iteration)]
get_accelerator().reset_peak_memory_stats()
alloc_mem_start = get_accelerator().memory_allocated()
max_mem_start = get_accelerator().max_memory_allocated()
if not run_only_once:
for i in range(self.warmup):
out = getattr(self, n.op)(n.target, args, kwargs)
if is_comm_op(n):
assert self.distributed, f"Distributed environment is not initialized but comm operator {n.name} {n.target} is used."
dist.barrier()
start = time.time()
for i in range(iteration):
start_events[i].record()
out = getattr(self, n.op)(n.target, args, kwargs)
end_events[i].record()
accelerator.synchronize()
walltime_sum = time.time() - start
if is_comm_op(n):
dist.barrier()
alloc_mem = get_accelerator().memory_allocated() - alloc_mem_start + self.mem_usage_out_of_torch
max_memory = get_accelerator().max_memory_allocated() - max_mem_start + self.mem_usage_out_of_torch
tensor_size = _node_size(out)
def partition_param_if_necessary(v):
if hasattr(v, "ds_id") and not v.ds_persist:
v.partition(param_list=[v], has_been_updated=False)
return v
args = map_aggregate(args, lambda x: partition_param_if_necessary(x))
if not cache_hit:
device_time = statistics.mean([s.elapsed_time(e) for s, e in zip(start_events, end_events)])
wall_time = walltime_sum / iteration * 1000
with unset_fake_temporarily():
vals_to_bcast = torch.tensor([device_time, wall_time, alloc_mem, max_memory, tensor_size],
device=self.device)
if self.distributed:
dist.all_reduce(vals_to_bcast, dist.ReduceOp.AVG)
n.meta["device_time"] = vals_to_bcast[0].item()
n.meta["wall_time"] = vals_to_bcast[1].item()
n.meta["alloc_mem"] = int(vals_to_bcast[2].item())
n.meta["max_mem"] = int(vals_to_bcast[3].item())
n.meta["tensor_size"] = int(vals_to_bcast[4].item())
self.cache[cache_key] = (n.meta["device_time"], n.meta["wall_time"], n.meta["alloc_mem"],
n.meta["max_mem"], n.meta["tensor_size"])
if is_release_op:
n.meta["alloc_mem"] = -self.allgather_mem.get(args[2], 0)
if dist.get_rank() == 0 and self.debug_log:
print(
f"{n.target} {n.meta['device_time']:.2f}ms {n.meta['wall_time']:.2f}ms alloc_mem={n.meta['alloc_mem'] / 1024 / 1024:.2f}MB max_mem={n.meta['max_mem'] / 1024 / 1024:.2f}MB tensor_size={n.meta['tensor_size']}"
)
if n.target == torch.ops.dc.allgather_param.default:
out = args[0]
assert hasattr(out, "ds_id")
if not out.ds_persist:
self.nz3.invalidate_gathered_param(args[2])
self.allgather_mem[out.ds_id] = n.meta["alloc_mem"]
return out
class MemoryProfilingInterpreter(Interpreter):
def __init__(self, gm: GraphModule, debug_log=False):
super().__init__(gm)
self.nz3 = get_deepcompile_handle()
self.device = torch.device(get_accelerator().current_device())
self.mem_record = []
self.last_alloc = get_accelerator().memory_allocated()
self.node_counter = 0
self.node_num = len(gm.graph.nodes)
self.debug_log = debug_log
def run(self, *args) -> Any:
        return_val = None
        try:
assert _all_real_if_tensor(args), "Inputs must be real tensors"
self.nz3.enable_profiling(True)
self.mem_usage_out_of_torch = _get_mem_usage_out_of_torch()
with unset_fake_temporarily():
with get_accelerator().random().fork_rng(devices=[self.device]):
return_val = super().run(*args)
except Exception as e:
print(f"MemoryProfiling error {e}")
finally:
self.nz3.enable_profiling(False)
return return_val
def run_node(self, n: torch.fx.Node) -> Any:
get_accelerator().reset_peak_memory_stats()
if n.op in {"placeholder", "output"}:
ret = super().run_node(n)
else:
args, kwargs = self.fetch_args_kwargs_from_env(n)
args = map_aggregate(args, lambda x: _to(x, self.device))
kwargs = map_aggregate(kwargs, lambda x: _to(x, self.device))
ret = getattr(self, n.op)(n.target, args, kwargs)
del args, kwargs
current_alloc = get_accelerator().memory_allocated() + self.mem_usage_out_of_torch
max_alloc = get_accelerator().max_memory_allocated() + self.mem_usage_out_of_torch
vals_to_bcast = torch.tensor([current_alloc, max_alloc], device=self.device)
dist.all_reduce(vals_to_bcast, dist.ReduceOp.MAX)
current_alloc = vals_to_bcast[0].item()
max_alloc = vals_to_bcast[1].item()
self.mem_record.append((n.name, current_alloc, current_alloc - self.last_alloc, max_alloc))
self.node_counter += 1
if self.debug_log and dist.get_rank() == 0:
print(
f"Mem prof Node {self.node_counter}/{self.node_num} {n.name} memory {current_alloc / 1024 / 1024:.2f}MB delta {(current_alloc - self.last_alloc) / 1024 / 1024:.2f}MB"
)
self.last_alloc = current_alloc
return ret
def dump(self, path):
import pandas as pd
df = pd.DataFrame(self.mem_record, columns=["node", "memory", "delta", "max_mem"])
df.to_csv(path, index=False)
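Both interpreters follow the pattern from the torch.fx profiling tutorial referenced above: subclass `Interpreter`, override `run_node`, and stash per-node measurements in `node.meta`. A minimal standalone sketch of that pattern (wall-clock only, without the CUDA events, parameter gathering, caching, or cross-rank averaging that DeepCompile's profilers add):

```python
import time
import torch
from torch.fx import symbolic_trace, Interpreter

class WallTimeInterpreter(Interpreter):
    """Time each FX node and record it in node.meta (a toy version of ProfilingInterpreter)."""

    def run_node(self, n):
        start = time.time()
        out = super().run_node(n)
        n.meta["wall_time"] = (time.time() - start) * 1000  # milliseconds
        return out

gm = symbolic_trace(torch.nn.Sequential(torch.nn.Linear(16, 16), torch.nn.ReLU()))
WallTimeInterpreter(gm).run(torch.randn(4, 16))
for node in gm.graph.nodes:
    print(f"{node.name}: {node.meta.get('wall_time', 0.0):.3f} ms")
```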

deepspeed/compile/util.py (new file, 429 lines)

@ -0,0 +1,429 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
import functools
import operator
from typing import List, Tuple, Dict
from collections import defaultdict
import torch
from torch.fx import Node, Graph
from torch.fx.node import map_aggregate, Argument, map_arg
try:
from torch._subclasses.fake_tensor import unset_fake_temporarily
except ImportError:
# Unsupported torch version
pass
import deepspeed.comm as dist
from deepspeed.accelerator import get_accelerator
from deepspeed.utils.torch import required_torch_version
from deepspeed.ops.op_builder.dc import DeepCompileBuilder
def is_deepcompile_supported() -> bool:
return required_torch_version(min_version=2.6, max_version=2.7) and get_accelerator().device_name() == "cuda"
dc_handle = None
if is_deepcompile_supported():
sym_size_ops = {
operator.ge,
operator.le,
operator.eq,
operator.ne,
operator.gt,
operator.lt,
torch.ops.aten.sym_size.int,
operator.getitem,
}
def get_deepcompile_handle():
global dc_handle
if dc_handle is None:
dc_handle = DeepCompileBuilder().load()
return dc_handle
def is_backend_inductor(backend):
return backend == "inductor"
backward_started = False
pre_backward_hooks = []
def add_pre_backward_hook(hook):
pre_backward_hooks.append(hook)
def deepcompile_backward_prologue(is_gradient_accumulation_boundary):
for hook in pre_backward_hooks:
hook()
dc = get_deepcompile_handle()
dc.start_backward(is_gradient_accumulation_boundary)
def log_rank0(msg: str, enable: bool = False):
if dist.get_rank() == 0 and enable:
print(msg)
def get_no_copy_ops():
    # Ensure the native extension is loaded so that torch.ops.dc.* is registered
get_deepcompile_handle()
return {
torch.ops.aten.t.default, torch.ops.aten.view.default, torch.ops.aten.detach.default,
torch.ops.aten.permute.default, torch.ops.dc.wait_allgather.default
}
def get_input_nodes(graph: Graph) -> List[Node]:
return [n for n in graph.nodes if n.op == "placeholder"]
def get_param_nodes(graph: Graph, index_to_ds_ids: List[Tuple[int, int]]) -> List[Node]:
all_input_nodes = get_input_nodes(graph)
return [all_input_nodes[i] for i, _, _ in index_to_ds_ids]
def is_comm_op(node: Node) -> bool:
return "comm" in node.meta and node.meta["comm"]
def exclude_from_act_offload(node: Node) -> bool:
return node.target in sym_size_ops
def dtype_to_elem_size(dtype: torch.dtype) -> int:
if dtype == torch.float32:
elem_size = 4
elif dtype == torch.float64:
elem_size = 8
elif dtype == torch.float16:
elem_size = 2
else:
raise ValueError(f"Unsupported dtype: {dtype}")
return elem_size
def tensor_meta_size(tensor_meta) -> int:
numel = 1 if len(tensor_meta.shape) == 0 else functools.reduce(operator.mul, tensor_meta.shape)
dtype = tensor_meta.dtype
if dtype == torch.float32:
elem_size = 4
elif dtype == torch.float64 or dtype == torch.int64:
elem_size = 8
elif dtype == torch.float16 or dtype == torch.bfloat16:
elem_size = 2
elif dtype == torch.bool:
elem_size = 1
else:
raise ValueError(f"Unsupported dtype: {dtype}")
return numel * elem_size
class NodeValueOffloadHelper:
def __init__(self, device):
self.device = device
self.env_values: Dict[str, Argument] = {}
self.original_device: Dict[torch.Tensor, torch.device] = {}
def _to_cpu(self, v):
if torch.is_tensor(v):
with unset_fake_temporarily():
device = v.device
offloaded = v.to('cpu').detach()
self.original_device[offloaded] = device
return offloaded
return v
def _from_cpu(self, v):
if torch.is_tensor(v) and v in self.original_device:
return v.to(self.original_device[v])
return v
def save(self, name: str, v: Argument, offload) -> None:
self.env_values[name] = map_aggregate(v, lambda x: self._to_cpu(x) if offload else x)
def load(self, name: str) -> Argument:
return map_aggregate(self.env_values[name], lambda x: self._from_cpu(x))
def get_offloaded_value(self, name: str) -> Argument:
return self.env_values[name]
def has_value(self, name: str) -> bool:
return name in self.env_values
def clear(self) -> None:
self.env_values.clear()
self.original_device.clear()
def materialize_fake(v, device=None):
from torch._subclasses.fake_tensor import is_fake
def convert(t):
if is_fake(t):
with unset_fake_temporarily():
if t.is_floating_point():
return torch.randn(t.shape,
dtype=t.dtype,
device=t.device if device is None else device,
layout=t.layout,
requires_grad=t.requires_grad,
pin_memory=t.is_pinned())
else:
return torch.zeros(t.shape,
dtype=t.dtype,
device=t.device if device is None else device,
requires_grad=t.requires_grad)
return t
return map_aggregate(v, lambda x: convert(x))
def get_last_uses(graph: Graph):
position = {node: i for i, node in enumerate(graph.nodes)}
node_to_last_use: Dict[Node, Node] = {}
user_to_last_uses: Dict[Node, List[Node]] = {}
no_copy_ops = get_no_copy_ops()
def register_last_uses(n: Node, user: Node):
update = False
known_last_use = None
if user.target in no_copy_ops and n in node_to_last_use:
last_user = node_to_last_use[user]
last_use_position = position[last_user]
known_last_use = node_to_last_use[n]
known_last_use_position = position[known_last_use]
update = last_use_position > known_last_use_position
if n not in node_to_last_use or update:
if user.target in no_copy_ops:
user = node_to_last_use[user]
node_to_last_use[n] = user
user_to_last_uses.setdefault(user, []).append(n)
if known_last_use:
user_to_last_uses[known_last_use].remove(n)
for node in reversed(graph.nodes):
map_arg(node.args, lambda n: register_last_uses(n, node))
map_arg(node.kwargs, lambda n: register_last_uses(n, node))
return node_to_last_use, user_to_last_uses
def get_real_uses(graph: Graph):
node_to_uses: Dict[Node, List[Node]] = defaultdict(list)
no_copy_ops = get_no_copy_ops()
def register_last_uses(n: Node, user: Node):
if user.target == "output":
return
if user.target in no_copy_ops:
users = node_to_uses[user]
node_to_uses[n].extend(users)
else:
node_to_uses[n].append(user)
for node in reversed(graph.nodes):
map_arg(node.args, lambda n: register_last_uses(n, node))
map_arg(node.kwargs, lambda n: register_last_uses(n, node))
return node_to_uses
def count_inflight_values(graph: Graph, file_path: str):
position = {node: i for i, node in enumerate(graph.nodes)}
node_to_last_use, user_to_last_uses = get_last_uses(graph)
max_inflight_size = 0
inflight_values = set()
# Output csv.
csv_filename = file_path
csv_data = []
header = [
'Node', 'tensor_size', 'inflight_size', 'inflight_size_in_output', 'args', 'users', 'node_to_last_use',
'lifetime', 'user_to_last_uses', 'inflight_values'
]
csv_data.append(header)
from .fx import get_output_node
output_node = get_output_node(graph)
values_in_output = set([n for n in output_node.args[0] if isinstance(n, Node)])
for node in graph.nodes:
inflight_values.add(node)
if node in user_to_last_uses:
for to_delete in user_to_last_uses[node]:
inflight_values.remove(to_delete)
assert "tensor_size" in node.meta, f"Node {node} does not have tensor_size"
inflight_size = sum(n.meta["tensor_size"] for n in inflight_values)
inflight_size_in_output = sum(n.meta["tensor_size"] for n in inflight_values if n in values_in_output)
lifetime = position[node_to_last_use[node]] - position[node] if node in node_to_last_use else 0
row = [
node.name, node.meta["tensor_size"], inflight_size, inflight_size_in_output,
[a.name for a in node.args if isinstance(a, Node)],
list(node.users.keys()), node_to_last_use[node] if node in node_to_last_use else 'NA', lifetime,
user_to_last_uses[node] if node in user_to_last_uses else 'NA',
list(inflight_values)
]
csv_data.append(row)
# print(
# f"Node: {node.name} users: {list(node.users.keys())} node_to_last_use: {node_to_last_use[node] if node in node_to_last_use else 'NA'} user_to_last_uses: {user_to_last_uses[node] if node in user_to_last_uses else 'NA'} inflight_values: {inflight_values} inflight_size: {inflight_size}"
# )
max_inflight_size = max(max_inflight_size, inflight_size)
import csv
with open(csv_filename, mode='w', newline='') as file:
writer = csv.writer(file)
writer.writerows(csv_data)
print(f"Max inflight size: {max_inflight_size}")
print(f"Data successfully written to {csv_filename}")
def get_activation_node_names(graph: Graph, param_nodes_bw: List[Node], fwd_output_names: List[str]):
input_nodes = get_input_nodes(graph)
param_node_names = set([n.name for n in param_nodes_bw])
activation_node_names = []
for in_node in input_nodes:
if in_node.name in fwd_output_names:
if in_node.name not in param_node_names:
activation_node_names.append(in_node.name)
return activation_node_names
class TensorOffloadHelper():
def __init__(self):
self.devices = {}
self.base_tensors = {}
self.views = {}
self.arg_list = []
self.offloaded = {}
self.non_tensor = {}
def offload(self, argument):
def is_base_tensor(tensor):
            return torch.is_tensor(tensor) and not tensor._is_view() and not hasattr(tensor, "ds_id")
base_tensor_ids = set()
for a in argument:
if is_base_tensor(a):
base_tensor_ids.add(id(a))
for a in argument:
a_id = id(a)
if is_base_tensor(a):
# Base tensor
self.devices[a_id] = a.device
self.base_tensors[a_id] = a
# elif torch.is_tensor(a) and not hasattr(a, "ds_id") and id(a._base) in base_tensor_ids:
# # View
# self.views[a_id] = {
# "base_id": id(a._base),
# "size": a.size(),
# "stride": a.stride(),
# "offset": a.storage_offset(),
# }
else:
# other types or ds tensor
self.non_tensor[a_id] = a
self.arg_list.append(a_id)
for a in argument:
if is_base_tensor(a):
a.data = a.data.to("cpu")
def reload(self, in_place):
loaded_base_tensors = {}
for a_id in self.arg_list:
if a_id in self.base_tensors:
device = self.devices[a_id]
if in_place:
self.base_tensors[a_id].data = self.base_tensors[a_id].to(device)
loaded_base_tensors[a_id] = self.base_tensors[a_id]
else:
loaded_base_tensors[a_id] = self.base_tensors[a_id].to(device)
results = []
for a_id in self.arg_list:
if a_id in self.base_tensors:
results.append(loaded_base_tensors[a_id])
# elif a_id in self.views:
# view_info = self.views[a_id]
# # print(f"load_args loading view {a_id} base_id={view_info['base_id']} size={view_info['size']} stride={view_info['stride']} offset={view_info['offset']}")
# base_tensor = loaded_base_tensors[view_info["base_id"]]
# view_tensor = base_tensor.as_strided(
# view_info["size"], view_info["stride"], view_info["offset"]
# )
# results.append(view_tensor)
elif a_id in self.non_tensor:
results.append(self.non_tensor[a_id])
return results
def add_mem_profile_nodes(graph: Graph, prefix: str):
def show_memory(label: str):
if dist.get_rank() == 0:
print(
f"{prefix} {label} alloc_mem={get_accelerator().memory_allocated()} max_mem={get_accelerator().max_memory_allocated()}"
)
nodes = list(graph.nodes)
for node in nodes:
if node.op == "output":
continue
with graph.inserting_after(node):
msg = f"Mem {node.name}"
name = f"show_memory_{node.name}"
graph.create_node('call_function', show_memory, (msg, ), {}, name=name)
def is_release_node(n: Node) -> bool:
return n.target == torch.ops.dc.release_param.default
def get_index_by_graph_id(graph_order, target_graph_id):
for index, (graph_id, _) in enumerate(graph_order):
if graph_id == target_graph_id:
return index
return -1
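The liveness helpers above (`get_last_uses` / `get_real_uses`) walk the FX graph in reverse and record, for each value, the last node that consumes it, which the compile passes can use to decide when a gathered value is no longer needed. A minimal standalone sketch of the same bookkeeping, ignoring the no-copy-op indirection handled in the real implementation:

```python
import torch
from torch.fx import symbolic_trace

class Tiny(torch.nn.Module):
    def forward(self, x):
        y = torch.relu(x)
        return (y + y).sum()

gm = symbolic_trace(Tiny())
# Iterate in reverse so the first consumer we see for a value is its last use.
last_use = {}
for node in reversed(list(gm.graph.nodes)):
    for inp in node.all_input_nodes:
        last_use.setdefault(inp, node)
for value, user in last_use.items():
    print(f"{value.name} is last used by {user.name}")
```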


@ -0,0 +1,6 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
from ..op_builder import DeepCompileBuilder


@ -31,6 +31,7 @@ from .activation_checkpointing.config import DeepSpeedActivationCheckpointingCon
from ..comm.config import DeepSpeedCommsConfig
from ..monitor.config import get_monitor_config
from ..inference.config import WeightQuantConfig
from ..compile.config import CompileConfig
from deepspeed import comm as dist
from deepspeed.runtime.config_utils import DeepSpeedConfigModel
@ -912,6 +913,8 @@ class DeepSpeedConfig(object):
self.weight_quantization_config = WeightQuantConfig(
**param_dict['weight_quantization']) if 'weight_quantization' in param_dict else None
self.compile_config = CompileConfig(**param_dict.get('compile', {}))
self.timers_config = get_timers_config(param_dict)
self.tensor_parallel_config = get_tensor_parallel_config(param_dict)


@ -107,6 +107,12 @@ from deepspeed.accelerator import get_accelerator
from deepspeed.runtime.config import DtypeEnum
from deepspeed.compile.util import is_deepcompile_supported, get_deepcompile_handle, deepcompile_backward_prologue
from deepspeed.compile.backend import register_compile_pass, opt_passes
from deepspeed.compile.passes import zero3_compile, prefetch, selective_gather, offload_adam_states
from deepspeed.compile.init_z1 import init_z1
from deepspeed.compile.init_z3 import init_z3
MEMORY_OPT_ALLREDUCE_SIZE = 500000000
DeepSpeedOptimizerCallable = \
@ -271,6 +277,7 @@ class DeepSpeedEngine(Module):
# Configure distributed model
self._configure_distributed_model(model)
if not self.is_deepcompile_enabled():
self.module_forward_pre_hook = self._create_module_forward_pre_hook()
self.module_forward_post_hook = self._create_module_forward_post_hook()
@ -377,6 +384,12 @@ class DeepSpeedEngine(Module):
self.unflatten = _unflatten_dense_tensors
self._is_compiled = False
if is_deepcompile_supported():
# Predefined compile passes
self.register_compile_pass(zero3_compile.NAME, zero3_compile.add_z3_gather_release)
self.register_compile_pass(prefetch.NAME, prefetch.schedule_prefetch)
self.register_compile_pass(selective_gather.NAME, selective_gather.selective_gather)
self.register_compile_pass(offload_adam_states.NAME, offload_adam_states.move_opt_states)
def _optimized_linear_offload_setup(self):
self.optimized_linear_base_weight_sharding = False
@ -486,6 +499,8 @@ class DeepSpeedEngine(Module):
def destroy(self):
if self.optimizer is not None and hasattr(self.optimizer, 'destroy'):
self.optimizer.destroy()
if self.is_deepcompile_enabled():
get_deepcompile_handle().cleanup()
debug_clear_module_and_param_names()
def _get_model_parameters(self):
@ -2032,6 +2047,10 @@ class DeepSpeedEngine(Module):
if self.autotuning_profile_model_info():
ma = get_ma_status()
if self.is_deepcompile_enabled() and hasattr(self, "launch_compile_passes"):
# We can't have this in forward prologue as the compiler compiles hooks including the forward prologue.
self.launch_compile_passes(self.global_steps)
loss = self.module(*inputs, **kwargs)
if self.autotuning_profile_model_info():
@ -2104,7 +2123,8 @@ class DeepSpeedEngine(Module):
scale_wrt_gas = self.scale_wrt_gas
# scale loss w.r.t. gradient accumulation if reduction is not disabled
do_gradient_reduction = self.enable_backward_allreduce and not self.inside_no_sync_ctxt
do_gradient_reduction = self.enable_backward_allreduce and not self.inside_no_sync_ctxt and not self.is_deepcompile_enabled(
)
if do_gradient_reduction and self.gradient_accumulation_steps() > 1 and scale_wrt_gas:
loss = self._scale_loss_by_gas(loss.float())
@ -2121,6 +2141,9 @@ class DeepSpeedEngine(Module):
)]
self.monitor.write_events(self.summary_events)
if self.is_deepcompile_enabled():
deepcompile_backward_prologue(self.is_gradient_accumulation_boundary())
return loss
def _backward_epilogue(self):
@ -2128,6 +2151,7 @@ class DeepSpeedEngine(Module):
if self.enable_backward_allreduce and not self.inside_no_sync_ctxt:
# Traditional code path that allreduces the module parameter grads
self.allreduce_gradients()
self._stop_timers(self.engine_timers.backward_reduce_timers)
see_memory_usage("Engine after backward", force=self.memory_breakdown())
@ -3849,7 +3873,7 @@ class DeepSpeedEngine(Module):
gc.collect()
get_accelerator().empty_cache()
def compile(self, backend=get_accelerator().get_compile_backend(), compile_kwargs={}) -> None:
def compile(self, backend=get_accelerator().get_compile_backend(), compile_kwargs={}, schedule=None) -> None:
"""Compile the module using the specified backend and kwargs.
If a compiler_fn is set, it will be used instead of torch.compile().
"""
@ -3865,10 +3889,50 @@ class DeepSpeedEngine(Module):
if 'backend' in compile_kwargs:
logger.warning("The `backend` in `compile_kwargs` will be overridden. Use the `backend` argument instead.")
print(f"Compiling deepcompile={self.is_deepcompile_enabled()} backend={backend}")
if self.is_deepcompile_enabled():
assert self.zero_optimization_stage() == ZeroStageEnum.optimizer_states \
or self.zero_optimization_stage() == ZeroStageEnum.weights \
, "Currently DeepCompile supports stage 1 or 3 only."
if schedule is not None:
def passes_name_to_fn(passes):
for p in passes:
assert callable(p) or p in opt_passes, f"Unknown pass {p}"
return [p if callable(p) else opt_passes[p] for p in passes]
schedule = [(step, passes_name_to_fn(passes)) for step, passes in schedule]
assert backend in ['inductor', 'eager'], f"Backend {backend} is not supported for DeepCompile."
compile_config = self._config.compile_config
if (("zero_optimization" in self.config and "offload_optimizer" in self.config["zero_optimization"]
and "offload_param" in self.config["zero_optimization"])
and self._config.zero_config.offload_param.device == "cpu"
and self._config.zero_config.offload_optimizer.device == "cpu"):
compile_config.offload_parameters = True
if self.zero_optimization_stage() == ZeroStageEnum.optimizer_states:
backend = init_z1(self, backend, compile_config, compile_kwargs, schedule)
elif self.zero_optimization_stage() == ZeroStageEnum.weights:
backend = init_z3(self, backend, compile_config, compile_kwargs, schedule)
# create new dict to avoid modifying original dict
self.module.compile(**{**compile_kwargs, 'backend': backend})
self._is_compiled = True
def get_compile_time(self):
from deepspeed.compile.backend import opt_pass_times
return opt_pass_times
def register_compile_pass(self, pass_name: str, pass_fn: Callable) -> None:
register_compile_pass(pass_name, pass_fn)
def is_deepcompile_enabled(self):
return self._config.compile_config.deepcompile
@property
def is_compiled(self) -> bool:
return self._is_compiled
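With these changes, DeepCompile is driven through the existing `engine.compile()` entry point; the optional `schedule` argument maps training steps to the optimization passes (registered names or callables) to apply at that step. A hypothetical invocation, given an initialized DeepSpeedEngine `engine`; the pass-name strings below are placeholders standing in for the registered `NAME` constants, whose exact values live in the pass modules:

```python
# Sketch only: "zero3_gather_release" and "prefetch" stand in for the NAME
# constants registered via register_compile_pass (zero3_compile.NAME, prefetch.NAME, ...).
engine.compile(
    backend="eager",                                # DeepCompile accepts "inductor" or "eager"
    schedule=[
        (0, ["zero3_gather_release"]),              # from step 0: insert gather/release nodes
        (5, ["zero3_gather_release", "prefetch"]),  # from step 5: also schedule prefetching
    ],
)
```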


@ -73,6 +73,8 @@ class ZeROOrderedDict(OrderedDict):
def _inject_parameters(module, cls):
for module in module.modules():
module._original_parameters = module._parameters
if cls == ZeROOrderedDict:
new_param = cls(parent_module=module)
else:
@ -80,6 +82,7 @@ def _inject_parameters(module, cls):
for key, param in module._parameters.items():
new_param[key] = param
module._parameters = new_param
@ -232,6 +235,8 @@ class DeepSpeedZeRoOffload(object):
for hook in self.backward_hooks:
hook.remove()
self.fwd_pre_hook.remove()
print_rank_0(f'Deleted module hooks: forward = {num_forward_hooks}, backward = {num_backward_hooks}',
force=False)
@ -244,7 +249,7 @@ class DeepSpeedZeRoOffload(object):
self.get_param_coordinator().reset_step()
self.module.register_forward_pre_hook(_start_of_forward_hook)
self.fwd_pre_hook = self.module.register_forward_pre_hook(_start_of_forward_hook)
        # Likely one of these would be enough, but keep both to be safe
self._register_deepspeed_module(self.module)
@ -287,7 +292,7 @@ class DeepSpeedZeRoOffload(object):
count[0] = count[0] + 1
self._register_deepspeed_module(child, count=count)
@instrument_w_nvtx
@torch.compiler.disable
def _pre_forward_module_hook(module, *args):
self.pre_sub_module_forward_function(module)
@ -365,6 +370,7 @@ class DeepSpeedZeRoOffload(object):
return _apply_forward_and_backward_to_tensors_only(module, _run_before_forward_function,
_run_after_backward_hook, inputs)
@torch.compiler.disable
def _post_backward_module_hook(module, inputs):
if not hasattr(module, "ds_grads_remaining"):
module.ds_grads_remaining = 0


@ -556,7 +556,8 @@ class InsertPostInitMethodToModuleSubClasses(object):
print_rank_0(
"nn.functional.linear has been overridden with a more memory efficient version. This will persist unless manually reset.",
force=False)
self.linear_bk = torch.nn.functional.linear
if not hasattr(InsertPostInitMethodToModuleSubClasses, "linear_bk"):
InsertPostInitMethodToModuleSubClasses.linear_bk = torch.nn.functional.linear
torch.nn.functional.linear = zero3_linear_wrap
if self.quantized_initialization:


@ -1511,6 +1511,7 @@ class DeepSpeedZeroOptimizer_Stage3(ZeROOptimizer):
# free the gradient
if not get_accelerator().is_synchronized_device():
if param.grad is not None:
param.grad.record_stream(get_accelerator().current_stream())
param.grad = None

op_builder/dc.py (new file, 32 lines)

@ -0,0 +1,32 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
from .builder import TorchCPUOpBuilder
class DeepCompileBuilder(TorchCPUOpBuilder):
BUILD_VAR = "DS_BUILD_DEEP_COMPILE"
NAME = "dc"
def __init__(self):
super().__init__(name=self.NAME)
def absolute_name(self):
return f'deepspeed.ops.{self.NAME}_op'
def sources(self):
return [
'csrc/compile/deepcompile.cpp', 'csrc/compile/init.cpp', 'csrc/compile/z1.cpp', 'csrc/compile/z3.cpp',
'csrc/compile/util.cpp'
]
def libraries_args(self):
args = super().libraries_args()
return args
def include_paths(self):
import os
import torch
return ['csrc/includes', os.path.join(torch.utils.cpp_extension.CUDA_HOME, "include")]
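Like other DeepSpeed ops, the DeepCompile extension can presumably be prebuilt by setting `DS_BUILD_DEEP_COMPILE=1` at install time or JIT-built on first use; `get_deepcompile_handle()` in deepspeed/compile/util.py takes the JIT route by loading the builder directly. A minimal sketch of that call:

```python
# JIT-build and load the DeepCompile native extension
# (this is essentially what get_deepcompile_handle() does).
from deepspeed.ops.op_builder.dc import DeepCompileBuilder

dc = DeepCompileBuilder().load()  # compiles csrc/compile/*.cpp on first call, then reuses the cached build
```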


@ -0,0 +1 @@
scipy


@ -91,6 +91,7 @@ extras_require = {
'inf': fetch_requirements('requirements/requirements-inf.txt'),
'sd': fetch_requirements('requirements/requirements-sd.txt'),
'triton': fetch_requirements('requirements/requirements-triton.txt'),
'deepcompile': fetch_requirements('requirements/requirements-deepcompile.txt'),
}
# Only install pynvml on nvidia gpus.


@ -66,3 +66,49 @@ class TestZeRO(DistributedTest):
config_dict["bf16"] = {"enabled": True}
compare_loss(self, config_dict, dtype)
class TestDeepCompile(DistributedTest):
world_size = 2
non_daemonic_procs = True
@pytest.mark.parametrize('dtype', [torch.bfloat16, torch.float16, torch.float32])
@pytest.mark.parametrize('zero_stage', [1, 3])
@pytest.mark.parametrize('deepcompile', [True]) # deepcompile==False is included in test_compile_zero
def test(self, zero_stage, dtype, deepcompile):
if not required_torch_version(min_version=2.6):
pytest.skip("DeepCompile requires PyTorch >= v2.6")
if dtype == torch.bfloat16:
skip_on_arch(min_arch=8)
if dtype == torch.bfloat16 and not bf16_required_version_check():
pytest.skip(
"DeepSpeed BFloat16 tests need NCCL >= 2.10.3, CUDA >=11.0, and HW support for BFloat16 to run correctly"
)
if get_accelerator().device_name() == "cpu":
pytest.skip("CPU does not support this test yet")
config_dict = {
"train_micro_batch_size_per_gpu": 1,
"steps_per_print": 1,
"optimizer": {
"type": "Adam",
"params": {
"lr": 0.00015
}
},
"zero_optimization": {
"stage": zero_stage,
},
"compile": {
"deepcompile": deepcompile
}
}
if dtype == torch.float16:
config_dict["fp16"] = {"enabled": True, "initial_scale_power": 8}
elif dtype == torch.bfloat16:
config_dict["bf16"] = {"enabled": True}
# Need warmup steps
compare_loss(self, config_dict, dtype, iteration=10)
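Outside the test harness, the same configuration could be wired up roughly as follows. This is a sketch only; the model and input are placeholders, and a distributed launcher (or single-process env setup) is assumed for `deepspeed.initialize`.

```python
import torch
import deepspeed

ds_config = {
    "train_micro_batch_size_per_gpu": 1,
    "optimizer": {"type": "Adam", "params": {"lr": 0.00015}},
    "zero_optimization": {"stage": 3},
    "bf16": {"enabled": True},
    "compile": {"deepcompile": True},          # turn DeepCompile on
}

model = torch.nn.Linear(10, 10)                # placeholder model
engine, _, _, _ = deepspeed.initialize(model=model,
                                       model_parameters=model.parameters(),
                                       config=ds_config)
engine.compile()                               # install the registered DeepCompile passes

x = torch.randn(1, 10, device=engine.device, dtype=torch.bfloat16)
loss = engine(x).square().mean()
engine.backward(loss)
engine.step()
```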


@ -70,8 +70,7 @@ def enable_determinism(seed: int):
@enable_determinism(123)
def compare_loss(self, config, dtype):
iteration = 5
def compare_loss(self, config, dtype, iteration=5):
hidden_dim = 10
RTOL = 5e-1
ATOL = 1e-2
@ -116,9 +115,12 @@ def compare_loss(self, config, dtype):
baseline_engine.backward(baseline_loss)
target_engine.backward(target_loss)
baseline_optimizer.step()
target_optimizer.step()
baseline_engine.step()
target_engine.step()
with GatheredParameters(target_engine.parameters()):
for p1, p2 in zip(baseline_engine.parameters(), target_engine.parameters()):
assert torch.allclose(p1.to(dtype), p2, rtol=RTOL, atol=ATOL)
baseline_engine.destroy()
target_engine.destroy()