DeepCompile for enhanced compiler integration (#7154)
This PR introduces *DeepCompile*, a new feature that efficiently integrates compiler optimizations with other DeepSpeed features. DeepCompile uses PyTorch Dynamo to capture the computation graph and modifies it to incorporate DeepSpeed’s optimizations seamlessly. Currently, DeepCompile supports ZeRO-1 and ZeRO-3, with enhancements such as proactive prefetching and selective unsharding to improve performance. (More details will be added later.)

---------

Signed-off-by: Masahiro Tanaka <mtanaka@microsoft.com>
Signed-off-by: Olatunji Ruwase <olruwase@microsoft.com>
Co-authored-by: zafarsadiq <zafarsadiq120@gmail.com>
Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>
Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>
2 .github/workflows/nv-pre-compile-ops.yml vendored
@@ -36,7 +36,7 @@ jobs:
           #python -c "import torch; print('CUDA available:', torch.cuda.is_available())"
       - name: Compile DeepSpeed Ops
         run: |
-          DS_ACCELERATOR=cuda DS_ENABLE_NINJA=1 TORCH_CUDA_ARCH_LIST="7.0;7.5;8.0" DS_BUILD_OPS=1 DS_BUILD_SPARSE_ATTN=0 DS_BUILD_FP_QUANTIZER=0 DS_BUILD_CUTLASS_OPS=0 DS_BUILD_GDS=0 DS_BUILD_RAGGED_DEVICE_OPS=0 DS_BUILD_EVOFORMER_ATTN=0 pip3 install .
+          DS_ACCELERATOR=cuda DS_ENABLE_NINJA=1 TORCH_CUDA_ARCH_LIST="7.0;7.5;8.0" DS_BUILD_OPS=1 DS_BUILD_SPARSE_ATTN=0 DS_BUILD_FP_QUANTIZER=0 DS_BUILD_CUTLASS_OPS=0 DS_BUILD_GDS=0 DS_BUILD_RAGGED_DEVICE_OPS=0 DS_BUILD_EVOFORMER_ATTN=0 DS_BUILD_DEEP_COMPILE=0 pip3 install .
       - name: DS Report
         run: |
           ds_report
2 .github/workflows/nv-torch-latest-v100.yml vendored
@@ -44,7 +44,7 @@ jobs:
 
       - name: Install deepspeed
         run: |
-          pip install .[dev,1bit,autotuning]
+          pip install .[dev,1bit,autotuning,deepcompile]
           ds_report
 
       - name: Python environment
@@ -15,6 +15,7 @@
 
 ## Latest News
 <b> <span style="color:orange" > DeepSpeed empowers ChatGPT-like model training with a single click, offering 15x speedup over SOTA RLHF systems with unprecedented cost reduction at all scales; [learn how](https://github.com/deepspeedai/DeepSpeed/tree/master/blogs/deepspeed-chat)</span>.</b>
+* [2025/04] [DeepCompile: Unlocking Compiler Optimization for Distributed Training](https://github.com/deepspeedai/DeepSpeed/blob/master/blogs/deepcompile/README.md)
 * [2025/03] [DeepSpeed-AutoTP: Automatic Tensor Parallel Training of Hugging Face models](https://github.com/deepspeedai/DeepSpeed/blob/master/blogs/huggingface-tp/README.md)
 * [2024/12] [Ulysses-Offload: Democratizing Long Context LLM Training](https://github.com/deepspeedai/DeepSpeed/blob/master/blogs/ulysses-offload/README.md)
 * [2024/12] [DeepSpeed-Domino: Communication-Free LLM Training Engine](https://github.com/deepspeedai/DeepSpeed/blob/master/blogs/deepspeed-domino/README.md)
168 blogs/deepcompile/README.md Normal file
@@ -0,0 +1,168 @@
<div align="center">

# DeepCompile: Unlocking Compiler Optimization for Distributed Training

</div>

# Introduction

<div align="center">

<img src="media/perf_summary.png" width="1000">

</div>

Distributed training has become essential for scaling today’s massive deep learning models. While deep learning compilers such as the PyTorch compiler have dramatically improved single-GPU training performance through optimizations like kernel fusion and operator scheduling, they fall short when it comes to distributed workloads.

Existing distributed training frameworks such as DeepSpeed and FSDP have made large-scale model training feasible through advanced parallelization strategies. While powerful, their optimizations are implemented at the PyTorch framework level, which limits the ability to apply compiler-style techniques like dependency analysis or operator scheduling.

DeepCompile addresses this gap by enabling compiler-level optimizations for distributed training. It takes a standard single-GPU model implementation and transforms it into an optimized multi-GPU training graph without requiring changes to the model code. Unlike existing approaches, DeepCompile automatically applies parameter sharding, communication scheduling, and memory-aware execution at the compiler IR level, enabling global analyses and optimizations that are difficult to express in traditional frameworks. Furthermore, during training, DeepCompile employs profile-guided optimization techniques to dynamically tune these parallelization strategies and improve training performance.

Our evaluation demonstrates that DeepCompile improves training performance over ZeRO-3 baselines, achieving up to 1.5x speedup when sufficient GPU resources are available, and up to 7x speedup in GPU-constrained settings that require offloading. DeepCompile is available in DeepSpeed versions >= [0.16.6](https://github.com/deepspeedai/DeepSpeed/releases/tag/v0.16.6).

# Design Overview

DeepCompile extends the capabilities of deep learning compilers to support distributed training. It starts from a standard single-GPU model implementation, such as those available on the Hugging Face model hub, and automatically transforms it by inserting the necessary distributed training operations, such as parameter sharding and communication primitives. Users are not required to embed any distributed logic into the model code.

The process begins by compiling the model into an intermediate representation (IR), which forms a computation graph. DeepCompile then applies a sequence of *optimization passes*, each responsible for a specific transformation of the computation graph or a targeted performance improvement, to incrementally introduce distributed behavior and optimize the graph. These include operations such as all-gather for sharded parameters or offloading of optimizer states, all while preserving the original computation semantics (Fig. 1).
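To make the idea of an optimization pass concrete, here is a minimal sketch (ours, not DeepCompile's implementation) that rewrites a `torch.fx` graph so every use of a sharded parameter goes through a gather step. The `gather_param` helper is a hypothetical placeholder; the real passes in this PR insert DeepCompile's custom ops (registered under `torch.ops.dc`, such as `allgather_param` and `release_param`) and run inside the compiler backend.

```python
import torch
import torch.fx as fx


def gather_param(p: torch.Tensor) -> torch.Tensor:
    # Placeholder for an all-gather of a sharded parameter (identity here).
    return p


def insert_gather_before_use(gm: fx.GraphModule, sharded: set) -> fx.GraphModule:
    """Route every use of a sharded parameter through gather_param."""
    for node in list(gm.graph.nodes):
        if node.op == "get_attr" and node.target in sharded:
            with gm.graph.inserting_after(node):
                gathered = gm.graph.call_function(gather_param, args=(node,))
            # Redirect consumers of the raw parameter to the gathered tensor,
            # except the gather node itself, which must keep reading the shard.
            node.replace_all_uses_with(gathered, delete_user_cb=lambda u: u is not gathered)
    gm.recompile()
    return gm


if __name__ == "__main__":
    gm = fx.symbolic_trace(torch.nn.Linear(4, 4))
    insert_gather_before_use(gm, {"weight", "bias"})
    print(gm.graph)  # gather_param now appears in front of the linear call
```

Because the rewrite happens on the whole graph rather than inside module hooks, the placement of each inserted operator can later be adjusted by other passes using profiling information.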
<div align="center">

<img src="media/workflow.png" width="400">

*Figure 1: Workflow of compilation and optimization with DeepCompile.*

</div>

At its core, DeepCompile builds on two key capabilities:

- **Automatic parallelization**: DeepCompile allows optimization passes to rewrite the single-GPU computation graph into a distributed multi-GPU version, incorporating strategies such as ZeRO, FSDP, and more. This eliminates the need for manual implementation of distributed training logic, drastically reducing engineering effort.
- **Profile-guided performance tuning**: At runtime, DeepCompile collects profiling data such as operator-level memory usage and execution latency. It uses this information to dynamically schedule computation and communication operators, improving the overlap between communication and computation and avoiding memory bottlenecks. Fine-grained tuning through these optimization passes often leads to better performance than even manually engineered implementations.

Figure 2 illustrates the optimization cycle employed by DeepCompile. After the initial computation graph is generated by the compiler, DeepCompile profiles its behavior by measuring operator execution time, communication overhead, and memory usage throughout the forward and backward passes.

<div align="center">

<img src="media/opt_loop.png" width="600">

*Figure 2. Optimization cycle.*

</div>

Based on the collected profiling data, DeepCompile applies a sequence of optimization passes. These passes modify the computation graph by inserting, removing, or reordering operators to improve overall efficiency. The modified graph is then re-profiled, and this cycle of profiling and optimization is repeated.
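The cycle can be summarized in a few lines of schematic Python. This is a sketch of the loop described above, not DeepCompile's internal API; `profiler` and the entries of `passes` stand in for the real components.

```python
def optimize(graph, passes, profiler, num_rounds=3):
    for _ in range(num_rounds):
        stats = profiler(graph)             # run a few iterations; collect per-op time and memory
        for opt_pass in passes:
            graph = opt_pass(graph, stats)  # each pass rewrites the graph using the profile
    return graph                            # deploy for the remaining training iterations
```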
Once a stable set of optimizations has been applied, the graph is deployed for the remaining training iterations. During execution, memory usage and other runtime characteristics may change. In such cases, DeepCompile can resume the profiling and optimization cycle according to the predefined schedule of passes, allowing the graph to adapt and maintain high performance.

# Optimizations

DeepCompile is designed as a general compiler framework for applying and optimizing a wide range of parallelization strategies. In the following, we describe several optimizations that have been implemented as optimization passes within DeepCompile.

## ZeRO-3

As an initial step, we have used DeepCompile to implement and enhance ZeRO-3-style optimizations at the compiler level. ZeRO-3 partitions model parameters, gradients, and optimizer states across devices, reducing memory usage and enabling large-scale training.

In conventional ZeRO-3 implementations, operations such as all-gather, reduce-scatter, and buffer release are typically inserted using Python hooks at runtime. DeepCompile replaces this approach by injecting these operations directly into the computation graph during compilation. This allows the compiler to determine their placement precisely, guided by both the static structure of the graph and runtime profiling information.
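For contrast, the snippet below sketches the hook-based mechanism that conventional ZeRO-3 relies on (and that DeepCompile removes): gather and release logic runs from Python callbacks as each module executes, so its placement cannot be planned from a whole-graph view. The `gather` and `release` callables are placeholders, not DeepSpeed's actual hook implementations.

```python
import torch.nn as nn


def attach_zero3_style_hooks(module: nn.Module, gather, release):
    """Runtime hook-based parameter management in the style of conventional ZeRO-3."""

    def pre_forward(mod, inputs):
        gather(mod)    # all-gather this module's sharded parameters just in time

    def post_forward(mod, inputs, output):
        release(mod)   # free the full parameters as soon as the module finishes
        return output

    module.register_forward_pre_hook(pre_forward)
    module.register_forward_hook(post_forward)
```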
One of the key optimizations is **proactive prefetching**, which launches all-gather operations earlier in the computation based on memory usage profiling. This reordering increases the overlap between communication and computation, improving throughput while avoiding out-of-memory errors. In addition, small communication operations are often fused to reduce launch latency and improve efficiency.
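The scheduling decision can be illustrated with a toy heuristic. This is not DeepCompile's actual algorithm; it only shows the kind of constraint the pass solves with its profiled memory trace: move the all-gather as early as possible without pushing peak memory over the device limit.

```python
def earliest_prefetch_step(mem_profile, current_step, prefetch_bytes, limit_bytes):
    """mem_profile[i] = profiled live memory (bytes) right before step i."""
    best = current_step
    for step in range(current_step - 1, -1, -1):
        # Launching at `step` keeps the gathered parameter alive over [step, current_step).
        if max(mem_profile[step:current_step]) + prefetch_bytes > limit_bytes:
            break
        best = step
    return best


# Toy trace in GiB: prefetching 6 GiB one step earlier still fits under 80 GiB.
trace_gib = [60, 72, 75, 70, 68]
print(earliest_prefetch_step([g * 2**30 for g in trace_gib], 4, 6 * 2**30, 80 * 2**30))  # 3
```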
Another optimization is **selective unsharding**, which keeps certain parameters in an unsharded form during the forward and backward passes when memory conditions permit. This reduces the frequency of all-gather operations and avoids redundant communication, particularly in scenarios where gradient accumulation is enabled.

## Offloading

DeepCompile also supports **adaptive offloading**, which offloads optimizer states to reduce GPU memory pressure. Unlike approaches that offload all the optimizer states, adaptive offloading identifies only the portions that exceed the memory limit—such as momentum and variance used by the Adam optimizer—and schedules data transfers to overlap with computation. This selective and asynchronous strategy minimizes overhead and enables efficient training even in memory-constrained environments.
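A toy calculation illustrates the "offload only the overflow" idea; the numbers and helper name are ours, not DeepCompile's.

```python
def bytes_to_offload(profiled_peak_bytes: int, device_limit_bytes: int) -> int:
    """Offload only the overflow beyond the device limit, not the whole optimizer state."""
    return max(0, profiled_peak_bytes - device_limit_bytes)


# A profiled peak of 92 GiB on an 80 GiB GPU: offload roughly 12 GiB of Adam
# state (momentum/variance) and keep the rest on the device.
print(bytes_to_offload(92 * 2**30, 80 * 2**30) / 2**30)  # 12.0
```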
## ZeRO-1

ZeRO-1 differs from ZeRO-3 in that it shards only the optimizer states across devices, while keeping parameters and gradients fully replicated. This approach reduces memory usage with minimal changes to the computation flow, making it a lightweight alternative for certain training scenarios.

DeepCompile implements ZeRO-1-style optimization by inserting reduce-scatter operations directly into the computation graph. By avoiding Python-level hooks, this graph-level integration reduces overhead and improves execution efficiency.

# Performance Improvements

## ZeRO-3

We evaluated DeepCompile on Llama-3-70B and Mixtral 8x7B using parameter sharding on top of Hugging Face model implementations.
Figure 3 shows training throughput (TFLOPs/GPU) across different gradient accumulation steps, using 32 H100 GPUs with a sequence length of 1024.
We compare DeepCompile against two DeepSpeed ZeRO-3 baselines: (i) an eager-mode version without compiler support (labelled ZeRO3+Eager), and (ii) a compiled version using the PyTorch compiler (labelled ZeRO3+Compile). For DeepCompile, we enabled both proactive prefetching and selective unsharding to demonstrate the combined effect of these optimization passes.

<div align="center">

<img src="media/perf_zero3.png" width="800">

*Figure 3. Achieved throughput for ZeRO-3 training of Llama-3 70B and Mixtral 8x7B models.*

</div>

Across both models, DeepCompile consistently delivers higher throughput. The benefit becomes more pronounced at higher accumulation steps, where the reduced frequency of parameter updates makes selective unsharding more effective. DeepCompile with proactive prefetching and selective unsharding achieves up to 1.28× speedup over ZeRO-3 on Llama-3-70B and 1.54× on Mixtral 8x7B.

Meanwhile, enabling the PyTorch compiler with ZeRO-3, i.e., ZeRO3+Compile, introduces minor overheads in some settings. This is because ZeRO-3 includes many conditional branches for runtime features such as prefetching. When the compiler encounters branches that cannot be statically resolved, it splits the computation into multiple graph segments. These fragmented segments can reduce optimization opportunities and introduce additional overheads during execution.

## Offloading

Training models as large as Llama-3 70B with ZeRO-3 typically requires 32 GPUs with 80GB of memory.
DeepSpeed addresses this challenge by offering offloading capabilities, which transfer optimizer states and optionally model parameters to CPU memory to reduce GPU memory usage. DeepCompile also supports offloading through a dedicated optimization pass, but with a few key differences in design.

Unlike the traditional approach of offloading both optimizer computation and memory, DeepCompile offloads only optimizer memory (e.g., the momentum, variance, and master weights of the Adam optimizer) while the optimizer computation remains on the GPU. DeepCompile profiles memory usage during both the forward and backward passes to identify when offloading is necessary, and transfers only the required data. This fine-grained approach avoids unnecessary overhead and helps maintain high computational throughput.
Furthermore, DeepCompile overlaps data transfers with computation whenever possible, dynamically adjusting the timing based on observed memory usage patterns. This asynchronous behavior is a crucial aspect of DeepCompile’s offloading strategy, allowing it to reduce GPU memory pressure without stalling execution.

We evaluated DeepCompile's offloading using Llama-3 70B on 16x H100-80GB GPUs (half the typically required GPU count) and present the results in Figure 4.

<div align="center">

<img src="media/perf_offload.png" width="400">

*Figure 4. Achieved throughput of optimizer offloading for Llama-3 70B on 16x80GB GPUs.*

</div>

We compare against two ZeRO-3 offloading baselines: (i) an eager-mode version without compiler support (ZeRO3+Eager), and (ii) a compiled version using the PyTorch compiler (ZeRO3+Compile). As the results show, DeepCompile significantly improves offloading efficiency and provides up to 7× speedup over ZeRO3+Eager. In contrast, ZeRO3+Compile achieves performance similar to ZeRO3+Eager.

## ZeRO-1

We also evaluated DeepCompile with ZeRO-1 using the Llama-3-8B model. We compare DeepCompile against two ZeRO-1 baselines: (i) an eager-mode version without compiler support (ZeRO1+Eager), and (ii) a compiled version using the PyTorch compiler (ZeRO1+Compile). In our experiment with 8 GPUs and a batch size of 2, DeepCompile achieved consistent throughput improvements across different sequence lengths, as shown in Figure 5.

<div align="center">

<img src="media/perf_zero1.png" width="800">

*Figure 5. Achieved throughput of ZeRO-1 training of Llama-3 8B.*

</div>

The most significant speedup was observed with batch size 1 and sequence length 512, where DeepCompile outperformed ZeRO1+Eager by up to 1.9×, and ZeRO1+Compile by up to 2.5×.

While compiler-based approaches can be effective for large batch sizes and long sequences by replacing suboptimal operations with more efficient kernels, they may also introduce overheads in ZeRO-1-style training in the form of *graph breaks* around the communication operations. These overheads become more pronounced with smaller batch sizes and sequence lengths, hurting performance compared to non-compiled execution. In contrast, DeepCompile inserts communication operators directly into the computation graph during compilation, avoiding graph fragmentation and minimizing the associated overhead. This makes DeepCompile more robust to small-scale workloads, while still benefiting from compiler-level optimizations.

## Additional Results and Analysis

Please refer to our [arXiv paper](https://arxiv.org/abs/2504.09983) for additional results, such as detailed comparisons across different batch sizes, sequence lengths, and memory usage.

# Looking Ahead

DeepCompile brings the power of compiler-based optimizations to distributed deep learning. By transforming computation graphs and applying profile-guided optimization passes, it enables more efficient training without requiring changes to model code.

This release is just the beginning. We’re actively working on expanding the set of optimization passes and improving integration with a broader range of distributed training strategies. Future directions include automated parallelization (sequence and tensor parallelism), smarter memory management, and dynamic adaptation to runtime behavior.

We invite the community to try DeepCompile, explore its capabilities, and contribute to its evolution. Let’s build the next generation of scalable deep learning together.

# Contributors

This project is the result of a close collaboration between Microsoft and the University of Virginia. The contributors are: Masahiro Tanaka, Du Li, Umesh Chand, and Olatunji Ruwase (Microsoft); and Ali Zafar and Haiying Shen (University of Virginia).

# Appendix

## Examples and Benchmarks

Our DeepSpeedExamples repository provides [example code](https://github.com/deepspeedai/DeepSpeedExamples/tree/master/benchmarks/deepcompile) to enable DeepCompile.
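As a quick orientation, the sketch below shows roughly how a DeepCompile run is set up. It is a hedged reading of this PR, not the authoritative API: the `"compile": {"deepcompile": true}` switch and the `engine.compile()` call reflect our understanding and may differ from the final configuration schema, so verify against the linked benchmarks. The script is meant to be run under the `deepspeed` launcher.

```python
import torch
import deepspeed

ds_config = {
    "train_micro_batch_size_per_gpu": 1,
    "optimizer": {"type": "Adam", "params": {"lr": 1e-4}},
    "zero_optimization": {"stage": 3},   # DeepCompile currently targets ZeRO-1 and ZeRO-3
    "bf16": {"enabled": True},
    "compile": {"deepcompile": True},    # assumed switch; see the examples for the exact key
}

model = torch.nn.Linear(1024, 1024)      # stands in for a Hugging Face model
engine, _, _, _ = deepspeed.initialize(model=model,
                                       model_parameters=model.parameters(),
                                       config=ds_config)
engine.compile()  # capture the graph with Dynamo and apply DeepCompile's optimization passes
```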
## Optimization Passes

The following optimization passes are currently available in DeepCompile:

- All-gather & reduce-scatter insertion (ZeRO-3)
- Proactive prefetching (ZeRO-3)
- Selective unsharding (ZeRO-3)
- Reduce-scatter insertion (ZeRO-1)
- Adaptive offloading

We used the following combinations of passes in the experiments presented above:

- Improved communication scheduling for ZeRO-3: All-gather & reduce-scatter insertion → Proactive prefetching → Selective unsharding
- Offloading optimizer states for ZeRO-3: All-gather & reduce-scatter insertion → Adaptive offloading
- Reduced overhead and improved overlap for ZeRO-1: Reduce-scatter insertion
BIN blogs/deepcompile/media/opt_loop.png (new binary file, 355 KiB)
BIN blogs/deepcompile/media/perf_offload.png (new binary file, 63 KiB)
BIN blogs/deepcompile/media/perf_summary.png (new binary file, 193 KiB)
BIN blogs/deepcompile/media/perf_zero1.png (new binary file, 117 KiB)
BIN blogs/deepcompile/media/perf_zero3.png (new binary file, 129 KiB)
BIN blogs/deepcompile/media/workflow.png (new binary file, 85 KiB)
@@ -10,6 +10,7 @@ set DS_BUILD_FP_QUANTIZER=0
 set DS_BUILD_GDS=0
 set DS_BUILD_RAGGED_DEVICE_OPS=0
 set DS_BUILD_SPARSE_ATTN=0
+set DS_BUILD_DEEP_COMPILE=0
 
 python -m build --wheel --no-isolation
 
188 csrc/compile/deepcompile.cpp Normal file
@ -0,0 +1,188 @@
|
||||
// Copyright (c) Microsoft Corporation.
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
// DeepSpeed Team
|
||||
|
||||
#include "deepcompile.h"
|
||||
|
||||
#define USE_C10D_NCCL
|
||||
|
||||
namespace dc {
|
||||
|
||||
std::shared_ptr<DSParamRegistry> param_registry;
|
||||
std::unordered_map<long, std::shared_ptr<CustomOpExecutor>> executors;
|
||||
std::shared_ptr<DoubleBufferedReduceBucket> reduce_buckets = nullptr;
|
||||
|
||||
c10::intrusive_ptr<c10d::ProcessGroup> process_group = nullptr;
|
||||
c10::intrusive_ptr<c10d::symmetric_memory::SymmetricMemory> symm_mem = nullptr;
|
||||
ncclComm_t nccl_comm;
|
||||
bool use_symm_mem;
|
||||
bool clone_custom_op_output;
|
||||
bool profile = false;
|
||||
bool pre_div_reduce = true;
|
||||
|
||||
bool sync_before_reduce; // for debugging
|
||||
bool sync_after_reduce; // for debugging
|
||||
bool sync_before_allgather; // for debugging
|
||||
bool sync_after_allgather; // for debugging
|
||||
|
||||
std::vector<int64_t> sizes_to_int_vector(at::IntArrayRef sizes)
|
||||
{
|
||||
std::vector<int64_t> result;
|
||||
for (int i = 0; i < sizes.size(); i++) { result.push_back(sizes[i]); }
|
||||
return result;
|
||||
}
|
||||
|
||||
void enable_profiling(bool enable) { profile = enable; }
|
||||
|
||||
bool is_profiling() { return profile; }
|
||||
|
||||
c10::intrusive_ptr<c10d::symmetric_memory::SymmetricMemory> getSymmMemWorkspace(int64_t size)
|
||||
{
|
||||
c10::Device device = c10::Device(c10::kCUDA, c10::cuda::current_device());
|
||||
std::vector<int64_t> sizes = {size};
|
||||
std::vector<int64_t> strides = {1};
|
||||
at::Tensor sym_mem_ws = c10d::symmetric_memory::empty_strided_p2p(
|
||||
{size}, {1}, c10::ScalarType::Byte, device, process_group->getGroupName(), std::nullopt);
|
||||
return c10d::symmetric_memory::rendezvous(sym_mem_ws);
|
||||
}
|
||||
|
||||
void lazy_init_symm_memory()
|
||||
{
|
||||
if (use_symm_mem && !symm_mem) {
|
||||
int64_t max_param_size = 0;
|
||||
for (const auto& it : param_registry->getParams()) {
|
||||
int64_t size = it.second.getDSTensor().numel() * it.second.getDSTensor().element_size();
|
||||
if (size > max_param_size) { max_param_size = size; }
|
||||
}
|
||||
symm_mem = getSymmMemWorkspace(max_param_size);
|
||||
}
|
||||
}
|
||||
|
||||
ncclDataType_t get_nccl_data_type(at::ScalarType scalar_type)
|
||||
{
|
||||
switch (scalar_type) {
|
||||
case at::kFloat: return ncclFloat;
|
||||
case at::kHalf: return ncclHalf;
|
||||
case at::kDouble: return ncclDouble;
|
||||
case at::kBFloat16: return ncclBfloat16;
|
||||
case at::kLong: return ncclInt64;
|
||||
case at::kInt: return ncclInt;
|
||||
case at::kChar: return ncclInt8;
|
||||
default: throw std::runtime_error("Unsupported scalar type");
|
||||
}
|
||||
}
|
||||
|
||||
void reset()
|
||||
{
|
||||
executors.clear();
|
||||
// We keep the buckets for memory estimation
|
||||
// reduce_buckets->clear();
|
||||
}
|
||||
|
||||
void cleanup()
|
||||
{
|
||||
reset();
|
||||
|
||||
ncclCommDestroy(nccl_comm);
|
||||
process_group = nullptr;
|
||||
symm_mem = nullptr;
|
||||
}
|
||||
|
||||
at::Tensor reduce_grad(at::Tensor grad_tensor, long graph_id, long ds_id)
|
||||
{
|
||||
if (sync_before_reduce) { c10::cuda::device_synchronize(); }
|
||||
|
||||
assert(hasKey(executors, graph_id));
|
||||
if (!profile) { executors[graph_id]->reduceGrad(grad_tensor, ds_id); }
|
||||
|
||||
if (sync_after_reduce) { c10::cuda::device_synchronize(); }
|
||||
|
||||
return at::Tensor();
|
||||
}
|
||||
|
||||
at::Tensor reduce_grad_meta(at::Tensor grad_tensor, long graph_id, long ds_id)
|
||||
{
|
||||
return at::Tensor();
|
||||
}
|
||||
|
||||
void free_tensors(std::vector<at::Tensor> tensors)
|
||||
{
|
||||
int64_t THRESHOLD = 10 * 1024 * 1024;
|
||||
|
||||
if (!profile) {
|
||||
for (auto& tensor : tensors) {
|
||||
if (tensor.is_cuda() && tensor.numel() > THRESHOLD) {
|
||||
tensor.record_stream(at::cuda::getCurrentCUDAStream());
|
||||
tensor.set_data(torch::empty({0}, tensor.options()));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void free_tensors_meta(std::vector<at::Tensor> tensors) {}
|
||||
|
||||
void init(c10::intrusive_ptr<c10d::ProcessGroup> pg,
|
||||
int64_t initial_reduce_bucket_size,
|
||||
bool enable_double_buffer,
|
||||
bool _use_symm_mem,
|
||||
bool _clone_custom_op_output,
|
||||
bool _sync_before_reduce,
|
||||
bool _sync_after_reduce,
|
||||
bool _sync_before_allgather,
|
||||
bool _sync_after_allgather)
|
||||
{
|
||||
process_group = pg;
|
||||
|
||||
ncclUniqueId ncclID;
|
||||
ncclGetUniqueId(&ncclID);
|
||||
|
||||
// ProcessGroup doesn't have an API to get the CUDA stream for comm calls.
|
||||
// So we create a NCCL communicator and call NCCL APIs directly.
|
||||
auto vec = std::vector<uint8_t>(reinterpret_cast<uint8_t*>(&ncclID),
|
||||
reinterpret_cast<uint8_t*>(&ncclID) + NCCL_UNIQUE_ID_BYTES);
|
||||
auto device = torch::Device(torch::kCUDA);
|
||||
at::Tensor tensor = torch::from_blob(vec.data(), {static_cast<long>(vec.size())}, torch::kUInt8)
|
||||
.to(torch::Device(torch::kCUDA));
|
||||
std::vector<at::Tensor> bcast_input = {tensor};
|
||||
|
||||
process_group->broadcast(bcast_input, c10d::BroadcastOptions())->wait();
|
||||
|
||||
// create a new nccl communicator
|
||||
std::memcpy(&ncclID, tensor.to(torch::Device(torch::kCPU)).data_ptr(), NCCL_UNIQUE_ID_BYTES);
|
||||
ncclCommInitRank(&nccl_comm, process_group->getSize(), ncclID, process_group->getRank());
|
||||
|
||||
param_registry = std::make_shared<DSParamRegistry>();
|
||||
reduce_buckets = std::make_shared<DoubleBufferedReduceBucket>(initial_reduce_bucket_size,
|
||||
enable_double_buffer);
|
||||
use_symm_mem = _use_symm_mem;
|
||||
clone_custom_op_output = _clone_custom_op_output;
|
||||
|
||||
sync_before_reduce = _sync_before_reduce;
|
||||
sync_after_reduce = _sync_after_reduce;
|
||||
sync_before_allgather = _sync_before_allgather;
|
||||
sync_after_allgather = _sync_after_allgather;
|
||||
}
|
||||
|
||||
void start_forward()
|
||||
{
|
||||
lazy_init_symm_memory();
|
||||
for (auto& it : executors) { it.second->startForward(); }
|
||||
}
|
||||
|
||||
void end_forward()
|
||||
{
|
||||
for (auto& it : executors) { it.second->endForward(); }
|
||||
}
|
||||
|
||||
void start_backward(bool update)
|
||||
{
|
||||
for (auto& it : executors) { it.second->startBackward(update); }
|
||||
}
|
||||
|
||||
// We don't call this
|
||||
// void end_backward(bool update)
|
||||
// {
|
||||
// }
|
||||
|
||||
} // namespace dc
|
99 csrc/compile/init.cpp Normal file
@ -0,0 +1,99 @@
|
||||
// Copyright (c) Microsoft Corporation.
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
// DeepSpeed Team
|
||||
|
||||
#include "deepcompile.h"
|
||||
#include "z1.h"
|
||||
#include "z3.h"
|
||||
|
||||
TORCH_LIBRARY(dc, m)
|
||||
{
|
||||
m.def("allgather_param(Tensor a, int graph_id, int id) -> Tensor");
|
||||
m.def("prefetch_params_fused(int graph_id, Tensor[] params, int[] ids) -> ()");
|
||||
m.def("wait_allgather(Tensor a, int graph_id, int id) -> Tensor");
|
||||
m.def("release_param(Tensor a, int graph_id, int id, int n_users) -> Tensor");
|
||||
m.def("reduce_grad(Tensor a, int graph_id, int id) -> Tensor");
|
||||
m.def("free_tensors(Tensor[] a) -> ()");
|
||||
m.def("offload_tensor(Tensor a, int id, int id) -> Tensor");
|
||||
m.def("reload_tensor(Tensor a, int id, int id) -> Tensor");
|
||||
m.def("wait_offload(Tensor a, int id, int id) -> Tensor");
|
||||
m.def("wait_reload(Tensor a, int id, int id) -> Tensor");
|
||||
m.def("offload_parameter(Tensor a, int id, int id) -> ()");
|
||||
m.def("reload_parameter(Tensor a, int id, int id) -> ()");
|
||||
|
||||
m.def("test_call(Tensor a) -> Tensor");
|
||||
}
|
||||
|
||||
TORCH_LIBRARY_IMPL(dc, CPU, m)
|
||||
{
|
||||
m.impl("allgather_param", &dc::allgather_param);
|
||||
m.impl("prefetch_params_fused", &dc::prefetch_params_fused);
|
||||
m.impl("wait_allgather", &dc::wait_allgather);
|
||||
m.impl("release_param", &dc::release_param);
|
||||
m.impl("reduce_grad", &dc::reduce_grad);
|
||||
m.impl("free_tensors", &dc::free_tensors);
|
||||
m.impl("offload_tensor", &dc::offload_tensor);
|
||||
m.impl("reload_tensor", &dc::reload_tensor);
|
||||
m.impl("wait_offload", &dc::wait_offload);
|
||||
m.impl("wait_reload", &dc::wait_reload);
|
||||
m.impl("offload_parameter", &dc::offload_parameter);
|
||||
m.impl("reload_parameter", &dc::reload_parameter);
|
||||
|
||||
m.impl("test_call", &dc::test_call);
|
||||
}
|
||||
|
||||
TORCH_LIBRARY_IMPL(dc, CUDA, m)
|
||||
{
|
||||
m.impl("allgather_param", &dc::allgather_param);
|
||||
m.impl("prefetch_params_fused", &dc::prefetch_params_fused);
|
||||
m.impl("wait_allgather", &dc::wait_allgather);
|
||||
m.impl("release_param", &dc::release_param);
|
||||
m.impl("reduce_grad", &dc::reduce_grad);
|
||||
m.impl("free_tensors", &dc::free_tensors);
|
||||
m.impl("offload_tensor", &dc::offload_tensor);
|
||||
m.impl("reload_tensor", &dc::reload_tensor);
|
||||
m.impl("wait_offload", &dc::wait_offload);
|
||||
m.impl("wait_reload", &dc::wait_reload);
|
||||
m.impl("offload_parameter", &dc::offload_parameter);
|
||||
m.impl("reload_parameter", &dc::reload_parameter);
|
||||
|
||||
m.impl("test_call", &dc::test_call);
|
||||
}
|
||||
|
||||
TORCH_LIBRARY_IMPL(dc, Meta, m)
|
||||
{
|
||||
m.impl("allgather_param", &dc::allgather_param_meta);
|
||||
m.impl("prefetch_params_fused", &dc::prefetch_params_fused_meta);
|
||||
m.impl("release_param", &dc::release_param_meta);
|
||||
m.impl("wait_allgather", &dc::wait_allgather_meta);
|
||||
m.impl("reduce_grad", &dc::reduce_grad_meta);
|
||||
m.impl("free_tensors", &dc::free_tensors_meta);
|
||||
m.impl("reload_parameter", &dc::reload_parameter_meta);
|
||||
m.impl("offload_parameter", &dc::offload_parameter_meta);
|
||||
}
|
||||
|
||||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
|
||||
{
|
||||
m.def("set_persistent", &dc::set_persistent, "Set persistent flag for a parameter");
|
||||
m.def("enable_profiling", &dc::enable_profiling, "Enable profiling");
|
||||
m.def("is_profiling", &dc::is_profiling, "Check if profiling is enabled");
|
||||
m.def("init", &dc::init, "Set the process group");
|
||||
m.def("cleanup", &dc::cleanup, "Cleanup the process group");
|
||||
m.def("register_z1_param", &dc::register_z1_param, "Register a parameter");
|
||||
m.def("register_graph_z1",
|
||||
&dc::register_graph_z1,
|
||||
"Register graph with a list of ds parameter ids");
|
||||
m.def("register_z3_param", &dc::register_z3_param, "Register a parameter");
|
||||
m.def("register_graph_z3",
|
||||
&dc::register_graph_z3,
|
||||
"Register graph with a list of ds parameter ids");
|
||||
m.def("start_forward", &dc::start_forward, "Start forward pass");
|
||||
m.def("end_forward", &dc::end_forward, "End forward pass");
|
||||
m.def("start_backward", &dc::start_backward, "Start backward pass");
|
||||
// m.def("end_backward", &dc::end_backward, "End backward pass");
|
||||
m.def("cleanup", &dc::cleanup, "Clean up DeepCompile");
|
||||
m.def("reset", &dc::reset, "Reset the state");
|
||||
m.def("invalidate_gathered_param", &dc::invalidate_gathered_param, "Invalidate gathered param");
|
||||
m.def("clear_all_gathered_params", &dc::clear_all_gathered_params, "Clear all gathered params");
|
||||
}
|
89 csrc/compile/util.cpp Normal file
@ -0,0 +1,89 @@
|
||||
// Copyright (c) Microsoft Corporation.
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
// DeepSpeed Team
|
||||
|
||||
#include "deepcompile.h"
|
||||
|
||||
#include <ATen/ATen.h>
|
||||
|
||||
namespace dc {
|
||||
|
||||
std::string tensorToString(const at::Tensor& t, size_t max_elem, size_t max_str_len)
|
||||
{
|
||||
auto t_cpu = t.flatten()
|
||||
.slice(0, 0, std::min((int64_t)max_elem, t.numel()))
|
||||
.to(c10::Device(c10::kCPU), false, true);
|
||||
|
||||
size_t size = std::min(max_elem, productDim(t.sizes()));
|
||||
|
||||
if (t.scalar_type() == c10::ScalarType::Half || t.scalar_type() == c10::ScalarType::BFloat16) {
|
||||
auto float_ten = t_cpu.to(c10::ScalarType::Float, false, true).contiguous();
|
||||
return tensorPtrToString((float*)float_ten.data_ptr(), size, max_str_len);
|
||||
} else if (t.scalar_type() == c10::ScalarType::Float) {
|
||||
return tensorPtrToString((float*)t_cpu.data_ptr(), size, max_str_len);
|
||||
} else if (t.scalar_type() == c10::ScalarType::Double) {
|
||||
return tensorPtrToString((double*)t_cpu.data_ptr(), size, max_str_len);
|
||||
} else if (t.scalar_type() == c10::ScalarType::Int) {
|
||||
int* ptr = static_cast<int*>(t_cpu.data_ptr());
|
||||
return tensorPtrToString(ptr, size, max_str_len);
|
||||
} else if (t.scalar_type() == c10::ScalarType::Long) {
|
||||
long* ptr = static_cast<long*>(t_cpu.data_ptr());
|
||||
return tensorPtrToString(ptr, size, max_str_len);
|
||||
} else if (t.scalar_type() == c10::ScalarType::Byte) {
|
||||
unsigned char* ptr = static_cast<unsigned char*>(t_cpu.data_ptr());
|
||||
std::vector<unsigned short> vec;
|
||||
vec.reserve(size);
|
||||
for (size_t i = 0; i < size; i++) {
|
||||
vec.push_back(*ptr);
|
||||
ptr++;
|
||||
}
|
||||
return tensorPtrToString(&vec[0], size, max_str_len);
|
||||
} else if (t.scalar_type() == c10::ScalarType::Bool) {
|
||||
bool* ptr = static_cast<bool*>(t_cpu.data_ptr());
|
||||
std::vector<int> vec;
|
||||
vec.reserve(size);
|
||||
for (size_t i = 0; i < size; i++) {
|
||||
vec.push_back(*ptr);
|
||||
ptr++;
|
||||
}
|
||||
return tensorPtrToString(&vec[0], size, max_str_len);
|
||||
}
|
||||
std::stringstream ss;
|
||||
ss << "Failed to convert tensor to string. Invalid type of tensor: "
|
||||
<< toString(t.scalar_type());
|
||||
throw std::invalid_argument(ss.str());
|
||||
}
|
||||
|
||||
std::string tensorPtrToString(void* ptr,
|
||||
size_t size,
|
||||
c10::ScalarType datatype,
|
||||
size_t max_elem,
|
||||
size_t max_str_len)
|
||||
{
|
||||
int64_t elem_size = std::min((size_t)max_elem, size);
|
||||
|
||||
if (datatype == c10::ScalarType::Long) {
|
||||
return tensorPtrToString(static_cast<long*>(ptr), elem_size, max_str_len);
|
||||
} else if (datatype == c10::ScalarType::Int) {
|
||||
return tensorPtrToString(static_cast<int*>(ptr), elem_size, max_str_len);
|
||||
} else if (datatype == c10::ScalarType::Double) {
|
||||
return tensorPtrToString(static_cast<double*>(ptr), elem_size, max_str_len);
|
||||
} else if (datatype == c10::ScalarType::Float) {
|
||||
return tensorPtrToString(static_cast<float*>(ptr), elem_size, max_str_len);
|
||||
} else if (datatype == c10::ScalarType::Half || datatype == c10::ScalarType::BFloat16) {
|
||||
const auto ten = torch::from_blob(ptr, {(int64_t)elem_size}, datatype);
|
||||
auto float_ten = ten.to(c10::ScalarType::Float, false, true).contiguous();
|
||||
return tensorPtrToString((float*)float_ten.data_ptr(), elem_size, max_str_len);
|
||||
}
|
||||
std::stringstream ss;
|
||||
ss << "Failed to convert tensor ptr to string. Invalid type of tensor: " << toString(datatype);
|
||||
throw std::invalid_argument(ss.str());
|
||||
}
|
||||
|
||||
std::string tensorDimToString(const at::Tensor& t)
|
||||
{
|
||||
const auto dim = t.sizes();
|
||||
return join_as_str(dim);
|
||||
}
|
||||
} // namespace dc
|
141 csrc/compile/z1.cpp Normal file
@ -0,0 +1,141 @@
|
||||
// Copyright (c) Microsoft Corporation.
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
// DeepSpeed Team
|
||||
|
||||
#include "z1.h"
|
||||
#include "deepcompile.h"
|
||||
|
||||
#define USE_C10D_NCCL
|
||||
|
||||
#include <ATen/cuda/CUDAEvent.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
#include <c10/cuda/CUDAStream.h>
|
||||
#include <torch/csrc/cuda/nccl.h>
|
||||
#include <torch/csrc/distributed/c10d/NCCLUtils.hpp>
|
||||
#include <torch/csrc/distributed/c10d/ProcessGroup.hpp>
|
||||
|
||||
#include <torch/csrc/distributed/c10d/SymmetricMemory.hpp>
|
||||
|
||||
namespace dc {
|
||||
|
||||
class Z1CustomOpExecutor : public CustomOpExecutor {
|
||||
public:
|
||||
Z1CustomOpExecutor(c10::intrusive_ptr<c10d::ProcessGroup> process_group,
|
||||
std::shared_ptr<DSParamRegistry> param_registry,
|
||||
std::shared_ptr<DoubleBufferedReduceBucket> reduce_buckets,
|
||||
std::vector<long> ds_ids,
|
||||
ncclComm_t nccl_comm,
|
||||
at::cuda::CUDAStream rs_stream,
|
||||
at::cuda::CUDAStream copy_stream,
|
||||
bool pre_div_reduce)
|
||||
: CustomOpExecutor(process_group,
|
||||
param_registry,
|
||||
reduce_buckets,
|
||||
ds_ids,
|
||||
nccl_comm,
|
||||
rs_stream,
|
||||
copy_stream,
|
||||
pre_div_reduce)
|
||||
{
|
||||
}
|
||||
~Z1CustomOpExecutor() {}
|
||||
|
||||
void endBackward() override
|
||||
{
|
||||
if (param_updated_) {
|
||||
for (auto& it : has_acc_grad_) { it.second = false; }
|
||||
}
|
||||
}
|
||||
|
||||
void flushReduceBucket(at::ScalarType scalar_type) override
|
||||
{
|
||||
int rank = process_group_->getRank();
|
||||
|
||||
if (!hasKey(reduce_tasks_, scalar_type)) { return; }
|
||||
|
||||
int64_t tmp_recv_numel = 0;
|
||||
for (const ReduceTask& t : reduce_tasks_.at(scalar_type)) {
|
||||
auto copy_done_event = rs_copy_done_events_.at(t.getDSId());
|
||||
copy_done_event->block(rs_stream_);
|
||||
}
|
||||
|
||||
ncclGroupStart();
|
||||
for (const ReduceTask& t : reduce_tasks_.at(scalar_type)) {
|
||||
ncclRedOp_t op = pre_div_reduce_ ? ncclSum : ncclAvg;
|
||||
if (pre_div_reduce_) {
|
||||
at::cuda::CUDAStreamGuard guard(rs_stream_);
|
||||
t.getSendBuf().div_(process_group_->getSize());
|
||||
}
|
||||
|
||||
// inplace
|
||||
ncclResult_t result = ncclAllReduce(t.getSendBuf().data_ptr(),
|
||||
t.getSendBuf().data_ptr(),
|
||||
t.getSendBuf().numel(),
|
||||
get_nccl_data_type(scalar_type),
|
||||
op,
|
||||
nccl_comm_,
|
||||
rs_stream_);
|
||||
if (result != ncclSuccess) { throw std::runtime_error("NCCL AllReduce failed"); }
|
||||
}
|
||||
ncclGroupEnd();
|
||||
|
||||
{
|
||||
at::cuda::CUDAStreamGuard guard(rs_stream_);
|
||||
for (const ReduceTask& t : reduce_tasks_.at(scalar_type)) {
|
||||
bool acc_grad = has_acc_grad_.at(t.getDSId());
|
||||
auto param = param_registry_->getParam(t.getDSId());
|
||||
auto grad_buf = param.getGradBuffer().flatten();
|
||||
|
||||
if (grad_buf.numel() == 0) { continue; }
|
||||
|
||||
int64_t offset = param.getOffset();
|
||||
auto recv_buf = t.getSendBuf().flatten().index(
|
||||
{torch::indexing::Slice(offset, offset + grad_buf.numel())});
|
||||
if (acc_grad) {
|
||||
grad_buf.add_(recv_buf);
|
||||
} else {
|
||||
grad_buf.copy_(recv_buf);
|
||||
}
|
||||
has_acc_grad_[t.getDSId()] = true;
|
||||
}
|
||||
}
|
||||
|
||||
reduce_buckets_->swap(scalar_type, rs_stream_, copy_stream_);
|
||||
|
||||
// Not very sure if this is necessary
|
||||
// Want to prevent grad tensor from being released before the copy is done
|
||||
auto comp_stream = at::cuda::getCurrentCUDAStream();
|
||||
for (const ReduceTask& t : reduce_tasks_.at(scalar_type)) {
|
||||
auto copy_done_event = rs_copy_done_events_.at(t.getDSId());
|
||||
copy_done_event->block(comp_stream);
|
||||
}
|
||||
reduce_tasks_[scalar_type].clear();
|
||||
}
|
||||
};
|
||||
|
||||
static at::cuda::CUDAStream rs_stream = at::cuda::getStreamFromPool(true);
|
||||
static at::cuda::CUDAStream copy_stream = at::cuda::getStreamFromPool(true);
|
||||
|
||||
void register_graph_z1(long graph_id, const std::vector<long>& ds_ids)
|
||||
{
|
||||
executors[graph_id] = std::make_shared<Z1CustomOpExecutor>(process_group,
|
||||
param_registry,
|
||||
reduce_buckets,
|
||||
ds_ids,
|
||||
nccl_comm,
|
||||
rs_stream,
|
||||
copy_stream,
|
||||
pre_div_reduce);
|
||||
}
|
||||
|
||||
void register_z1_param(long ds_id,
|
||||
const std::vector<int64_t>& ds_shape,
|
||||
at::Tensor ds_tensor,
|
||||
at::Tensor grad_buffer,
|
||||
int64_t offset)
|
||||
{
|
||||
param_registry->registerParam(ds_id, ds_shape, ds_tensor, grad_buffer, false, offset, false);
|
||||
}
|
||||
|
||||
} // namespace dc
|
18 csrc/compile/z1.h Normal file
@@ -0,0 +1,18 @@
// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0

// DeepSpeed Team

#include "deepcompile.h"

#pragma once

namespace dc {

void register_graph_z1(long graph_id, const std::vector<long>& ds_ids);
void register_z1_param(long ds_id,
                       const std::vector<int64_t>& ds_shape,
                       at::Tensor ds_tensor,
                       at::Tensor grad_buffer,
                       int64_t offset);
}  // namespace dc
544 csrc/compile/z3.cpp Normal file
@ -0,0 +1,544 @@
|
||||
// Copyright (c) Microsoft Corporation.
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
// DeepSpeed Team
|
||||
|
||||
#include "z3.h"
|
||||
#include "deepcompile.h"
|
||||
|
||||
#define USE_C10D_NCCL
|
||||
|
||||
#include <ATen/cuda/CUDAEvent.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
#include <c10/cuda/CUDAStream.h>
|
||||
#include <torch/csrc/cuda/nccl.h>
|
||||
#include <torch/csrc/distributed/c10d/NCCLUtils.hpp>
|
||||
#include <torch/csrc/distributed/c10d/ProcessGroup.hpp>
|
||||
|
||||
#include <torch/csrc/distributed/c10d/SymmetricMemory.hpp>
|
||||
|
||||
namespace dc {
|
||||
|
||||
const size_t TIMEOUT_SYMMETRIC_MEMORY_BARRIER = 60000;
|
||||
|
||||
class Z3CustomOpExecutor : public CustomOpExecutor {
|
||||
public:
|
||||
Z3CustomOpExecutor(c10::intrusive_ptr<c10d::ProcessGroup> process_group,
|
||||
std::shared_ptr<DSParamRegistry> param_registry,
|
||||
std::shared_ptr<DoubleBufferedReduceBucket> reduce_buckets,
|
||||
std::vector<long> ds_ids,
|
||||
ncclComm_t nccl_comm,
|
||||
at::cuda::CUDAStream ag_stream,
|
||||
at::cuda::CUDAStream rs_stream,
|
||||
at::cuda::CUDAStream copy_stream,
|
||||
at::cuda::CUDAStream offload_stream,
|
||||
at::cuda::CUDAStream reload_stream,
|
||||
bool pre_div_reduce)
|
||||
: CustomOpExecutor(process_group,
|
||||
param_registry,
|
||||
reduce_buckets,
|
||||
ds_ids,
|
||||
nccl_comm,
|
||||
rs_stream,
|
||||
copy_stream,
|
||||
pre_div_reduce),
|
||||
ag_stream_(ag_stream),
|
||||
offload_stream_(offload_stream),
|
||||
reload_stream_(reload_stream)
|
||||
{
|
||||
for (long ds_id : ds_ids_) {
|
||||
ag_comm_done_events_[ds_id] =
|
||||
std::make_shared<at::cuda::CUDAEvent>(cudaEventDisableTiming);
|
||||
ag_comp_done_events_[ds_id] =
|
||||
std::make_shared<at::cuda::CUDAEvent>(cudaEventDisableTiming);
|
||||
|
||||
param_use_count_[ds_id] = 0;
|
||||
}
|
||||
}
|
||||
~Z3CustomOpExecutor() {}
|
||||
|
||||
void endBackward() override
|
||||
{
|
||||
if (param_updated_) {
|
||||
for (auto& it : has_acc_grad_) {
|
||||
it.second = false;
|
||||
param_registry_->setValid(it.first, false);
|
||||
}
|
||||
}
|
||||
|
||||
for (auto& it : reload_buffers_) {
|
||||
it.second.record_stream(at::cuda::getCurrentCUDAStream());
|
||||
}
|
||||
reload_buffers_.clear();
|
||||
}
|
||||
|
||||
void launchAllGather(at::Tensor output_buf,
|
||||
long ds_id,
|
||||
c10::intrusive_ptr<c10d::symmetric_memory::SymmetricMemory> symm_mem)
|
||||
{
|
||||
const DSParam& param = param_registry_->getParam(ds_id);
|
||||
const at::Tensor& ds_tensor = param.getDSTensor();
|
||||
|
||||
if (symm_mem == nullptr) {
|
||||
ncclResult_t result = ncclAllGather(ds_tensor.contiguous().data_ptr(),
|
||||
output_buf.data_ptr(),
|
||||
ds_tensor.numel(),
|
||||
get_nccl_data_type(ds_tensor.scalar_type()),
|
||||
nccl_comm_,
|
||||
ag_stream_);
|
||||
|
||||
if (result != ncclSuccess) { throw std::runtime_error("NCCL AllGather failed"); }
|
||||
} else {
|
||||
at::cuda::CUDAStreamGuard guard(ag_stream_);
|
||||
int world_size = process_group_->getSize();
|
||||
int rank = process_group_->getRank();
|
||||
|
||||
at::Tensor local_buf =
|
||||
symm_mem->get_buffer(rank, ds_tensor.sizes(), ds_tensor.scalar_type(), 0);
|
||||
local_buf.copy_(ds_tensor, true);
|
||||
|
||||
symm_mem->barrier(0, TIMEOUT_SYMMETRIC_MEMORY_BARRIER);
|
||||
auto chunks = output_buf.flatten().chunk(world_size);
|
||||
for (int step = 0; step < world_size; step++) {
|
||||
int remote_rank = (rank - step + world_size) % world_size;
|
||||
auto src_buf = symm_mem->get_buffer(
|
||||
remote_rank, ds_tensor.sizes(), ds_tensor.scalar_type(), 0);
|
||||
chunks[remote_rank].copy_(src_buf.flatten(), true);
|
||||
}
|
||||
symm_mem->barrier(0, TIMEOUT_SYMMETRIC_MEMORY_BARRIER);
|
||||
}
|
||||
|
||||
param_registry_->registerGatheredParam(ds_id, output_buf);
|
||||
param_registry_->setValid(ds_id, true);
|
||||
}
|
||||
|
||||
at::Tensor allgatherParam(long ds_id,
|
||||
c10::intrusive_ptr<c10d::symmetric_memory::SymmetricMemory> symm_mem)
|
||||
{
|
||||
if (param_registry_->isValid(ds_id)) { return param_registry_->getGatheredParam(ds_id); }
|
||||
|
||||
const DSParam& param = param_registry_->getParam(ds_id);
|
||||
const at::Tensor& ds_tensor = param.getDSTensor();
|
||||
at::Tensor output_buf = param_registry_->hasGatheredParam(ds_id)
|
||||
? param_registry_->getGatheredParam(ds_id)
|
||||
: torch::empty(param.getShape(), ds_tensor.options());
|
||||
|
||||
assert(hasKey(ag_comp_done_events_, ds_id));
|
||||
ag_comp_done_events_[ds_id]->record();
|
||||
ag_comp_done_events_[ds_id]->block(ag_stream_);
|
||||
|
||||
launchAllGather(output_buf, ds_id, symm_mem);
|
||||
|
||||
ag_comm_done_events_[ds_id]->record(ag_stream_);
|
||||
return output_buf;
|
||||
}
|
||||
|
||||
void prefetchParamsFused(std::vector<int64_t> ds_ids,
|
||||
c10::intrusive_ptr<c10d::symmetric_memory::SymmetricMemory> symm_mem)
|
||||
{
|
||||
std::vector<int64_t> invalid_ds_ids;
|
||||
for (const auto& ds_id : ds_ids) {
|
||||
if (!param_registry_->isValid(ds_id)) { invalid_ds_ids.push_back(ds_id); }
|
||||
}
|
||||
|
||||
std::unordered_map<long, at::Tensor> output_bufs;
|
||||
for (long ds_id : invalid_ds_ids) {
|
||||
const DSParam& param = param_registry_->getParam(ds_id);
|
||||
if (param_registry_->hasGatheredParam(ds_id)) {
|
||||
output_bufs[ds_id] = param_registry_->getGatheredParam(ds_id);
|
||||
} else {
|
||||
output_bufs[ds_id] = torch::empty(param.getShape(), param.getDSTensor().options());
|
||||
}
|
||||
}
|
||||
|
||||
for (long ds_id : invalid_ds_ids) {
|
||||
ag_comp_done_events_[ds_id]->record();
|
||||
ag_comp_done_events_[ds_id]->block(ag_stream_);
|
||||
}
|
||||
|
||||
ncclGroupStart();
|
||||
for (long ds_id : invalid_ds_ids) {
|
||||
assert(hasKey(output_bufs, ds_id));
|
||||
launchAllGather(output_bufs.at(ds_id), ds_id, symm_mem);
|
||||
}
|
||||
ncclGroupEnd();
|
||||
|
||||
for (long ds_id : invalid_ds_ids) { ag_comm_done_events_[ds_id]->record(ag_stream_); }
|
||||
}
|
||||
|
||||
void releaseParam(long ds_id, long n_users)
|
||||
{
|
||||
const DSParam& param = param_registry_->getParam(ds_id);
|
||||
|
||||
assert(hasKey(param_use_count_, ds_id));
|
||||
if (param_use_count_[ds_id] == 0) { param_use_count_[ds_id] = n_users; }
|
||||
param_use_count_[ds_id]--;
|
||||
|
||||
if (param_use_count_[ds_id] == 0 && !param.isPersistent()) {
|
||||
at::Tensor gathered_param = param_registry_->getGatheredParam(ds_id);
|
||||
|
||||
if (gathered_param.defined()) { // gathered param is undefined while profiling
|
||||
const auto options = gathered_param.options();
|
||||
at::Tensor empty_buffer = torch::empty({0}, options);
|
||||
gathered_param.set_data(empty_buffer);
|
||||
}
|
||||
|
||||
param_registry_->unregisterGatheredParam(ds_id);
|
||||
}
|
||||
}
|
||||
|
||||
at::Tensor waitAllgather(at::Tensor v, long ds_id)
|
||||
{
|
||||
assert(hasKey(ag_comm_done_events_, ds_id));
|
||||
ag_comm_done_events_[ds_id]->block(at::cuda::getCurrentCUDAStream());
|
||||
return v;
|
||||
}
|
||||
|
||||
void flushReduceBucket(at::ScalarType scalar_type) override
|
||||
{
|
||||
if (!hasKey(reduce_tasks_, scalar_type)) { return; }
|
||||
|
||||
int64_t tmp_recv_numel = 0;
|
||||
for (const ReduceTask& t : reduce_tasks_.at(scalar_type)) {
|
||||
auto copy_done_event = rs_copy_done_events_.at(t.getDSId());
|
||||
copy_done_event->block(rs_stream_);
|
||||
|
||||
if (has_acc_grad_.at(t.getDSId())) {
|
||||
tmp_recv_numel += param_registry_->getParam(t.getDSId()).getGradBuffer().numel();
|
||||
}
|
||||
}
|
||||
|
||||
at::Tensor tmp_recv_buf = at::Tensor();
|
||||
if (tmp_recv_numel > 0) {
|
||||
at::cuda::CUDAStreamGuard guard(rs_stream_);
|
||||
tmp_recv_buf = torch::empty({tmp_recv_numel},
|
||||
at::TensorOptions().dtype(scalar_type).device(at::kCUDA));
|
||||
}
|
||||
|
||||
ncclGroupStart();
|
||||
int64_t offset = 0;
|
||||
for (const ReduceTask& t : reduce_tasks_.at(scalar_type)) {
|
||||
auto recv_buf = param_registry_->getParam(t.getDSId()).getGradBuffer();
|
||||
|
||||
bool acc_grad = has_acc_grad_.at(t.getDSId());
|
||||
|
||||
if (acc_grad) {
|
||||
recv_buf =
|
||||
tmp_recv_buf.index({torch::indexing::Slice(offset, offset + recv_buf.numel())});
|
||||
}
|
||||
|
||||
ncclRedOp_t op = pre_div_reduce_ ? ncclSum : ncclAvg;
|
||||
if (pre_div_reduce_) {
|
||||
at::cuda::CUDAStreamGuard guard(rs_stream_);
|
||||
t.getSendBuf().div_(process_group_->getSize());
|
||||
}
|
||||
ncclResult_t result = ncclReduceScatter(t.getSendBuf().data_ptr(),
|
||||
recv_buf.data_ptr(),
|
||||
recv_buf.numel(),
|
||||
get_nccl_data_type(scalar_type),
|
||||
op,
|
||||
nccl_comm_,
|
||||
rs_stream_);
|
||||
if (result != ncclSuccess) { throw std::runtime_error("NCCL ReduceScatter failed"); }
|
||||
|
||||
if (acc_grad) { offset += recv_buf.numel(); }
|
||||
}
|
||||
ncclGroupEnd();
|
||||
|
||||
{
|
||||
at::cuda::CUDAStreamGuard guard(rs_stream_);
|
||||
int64_t offset = 0;
|
||||
for (const ReduceTask& t : reduce_tasks_.at(scalar_type)) {
|
||||
bool acc_grad = has_acc_grad_.at(t.getDSId());
|
||||
|
||||
if (acc_grad) {
|
||||
auto recv_buf = param_registry_->getParam(t.getDSId()).getGradBuffer();
|
||||
recv_buf.add_(tmp_recv_buf.index(
|
||||
{torch::indexing::Slice(offset, offset + recv_buf.numel())}));
|
||||
offset += recv_buf.numel();
|
||||
}
|
||||
has_acc_grad_[t.getDSId()] = true;
|
||||
}
|
||||
}
|
||||
|
||||
reduce_buckets_->swap(scalar_type, rs_stream_, copy_stream_);
|
||||
|
||||
// Not very sure if this is necessary
|
||||
// Want to prevent grad tensor from being released before the copy is done
|
||||
auto comp_stream = at::cuda::getCurrentCUDAStream();
|
||||
for (const ReduceTask& t : reduce_tasks_.at(scalar_type)) {
|
||||
auto copy_done_event = rs_copy_done_events_.at(t.getDSId());
|
||||
copy_done_event->block(comp_stream);
|
||||
}
|
||||
reduce_tasks_[scalar_type].clear();
|
||||
|
||||
if (tmp_recv_numel > 0) { tmp_recv_buf.record_stream(rs_stream_); }
|
||||
}
|
||||
|
||||
at::Tensor offloadTensor(at::Tensor tensor, long id)
|
||||
{
|
||||
if (!hasKey(offload_events_, id)) {
|
||||
offload_events_[id] = std::make_shared<at::cuda::CUDAEvent>(cudaEventDisableTiming);
|
||||
offload_comp_done_events_[id] =
|
||||
std::make_shared<at::cuda::CUDAEvent>(cudaEventDisableTiming);
|
||||
|
||||
const auto options = at::TensorOptions().pinned_memory(true).device(torch::kCPU);
|
||||
offload_buffers_[id] = at::empty_like(tensor, options);
|
||||
}
|
||||
|
||||
offload_comp_done_events_[id]->record();
|
||||
offload_comp_done_events_[id]->block(offload_stream_);
|
||||
{
|
||||
at::cuda::CUDAStreamGuard guard(offload_stream_);
|
||||
offload_buffers_.at(id).copy_(tensor, true);
|
||||
}
|
||||
|
||||
tensor.record_stream(offload_stream_);
|
||||
|
||||
offload_events_[id]->record(offload_stream_);
|
||||
assert(hasKey(offload_buffers_, id));
|
||||
return offload_buffers_.at(id);
|
||||
}
|
||||
|
||||
at::Tensor reloadTensor(at::Tensor tensor, long id)
|
||||
{
|
||||
if (!hasKey(reload_events_, id)) {
|
||||
reload_events_[id] = std::make_shared<at::cuda::CUDAEvent>(cudaEventDisableTiming);
|
||||
}
|
||||
|
||||
assert(hasKey(offload_buffers_, id));
|
||||
offload_events_[id]->block(reload_stream_);
|
||||
|
||||
at::Tensor ten;
|
||||
{
|
||||
at::cuda::CUDAStreamGuard guard(reload_stream_);
|
||||
|
||||
assert(hasKey(offload_buffers_, id));
|
||||
at::Tensor buf = offload_buffers_.at(id);
|
||||
const auto options = at::TensorOptions().device(torch::kCUDA);
|
||||
ten = at::empty_like(buf, options);
|
||||
ten.copy_(buf, true);
|
||||
|
||||
reload_buffers_[id] = ten;
|
||||
}
|
||||
|
||||
reload_events_[id]->record(reload_stream_);
|
||||
return ten;
|
||||
}
|
||||
|
||||
at::Tensor waitOffload(at::Tensor tensor, long id)
|
||||
{
|
||||
assert(hasKey(offload_events_, id));
|
||||
offload_events_[id]->block(at::cuda::getCurrentCUDAStream());
|
||||
|
||||
assert(hasKey(offload_buffers_, id));
|
||||
return offload_buffers_.at(id);
|
||||
}
|
||||
|
||||
at::Tensor waitReload(at::Tensor tensor, long id)
|
||||
{
|
||||
assert(hasKey(reload_events_, id));
|
||||
reload_events_[id]->block(at::cuda::getCurrentCUDAStream());
|
||||
|
||||
assert(hasKey(reload_buffers_, id));
|
||||
auto ten = reload_buffers_.at(id);
|
||||
|
||||
// We can't release here because the tensor is still being used
|
||||
// We will need "freeReloadedTensor" after the last user of the tensor to call
|
||||
// ".record_stream". As it is a bit complicated, we clear the buffer and do at the end of
|
||||
// the backward pass for now. reload_buffers_.erase(id);
|
||||
return ten;
|
||||
}
|
||||
|
||||
void offloadParameter(at::Tensor tensor, long ds_id) { param_registry_->offload(ds_id); }
|
||||
void reloadParameter(at::Tensor tensor, long ds_id) { param_registry_->reload(ds_id); }
|
||||
|
||||
bool hasReloadBuffer(long id) { return hasKey(reload_buffers_, id); }
|
||||
|
||||
bool hasParam(long ds_id) const { return hasKey(has_acc_grad_, ds_id); }
|
||||
|
||||
private:
|
||||
at::cuda::CUDAStream ag_stream_;
|
||||
at::cuda::CUDAStream offload_stream_;
|
||||
at::cuda::CUDAStream reload_stream_;
|
||||
|
||||
std::unordered_map<long, std::shared_ptr<at::cuda::CUDAEvent>> ag_comp_done_events_;
|
||||
std::unordered_map<long, std::shared_ptr<at::cuda::CUDAEvent>> ag_comm_done_events_;
|
||||
|
||||
std::unordered_map<long, std::shared_ptr<at::cuda::CUDAEvent>> offload_events_;
|
||||
std::unordered_map<long, std::shared_ptr<at::cuda::CUDAEvent>> offload_comp_done_events_;
|
||||
std::unordered_map<long, std::shared_ptr<at::cuda::CUDAEvent>> reload_events_;
|
||||
std::unordered_map<long, at::Tensor> offload_buffers_;
|
||||
std::unordered_map<long, at::Tensor> reload_buffers_;
|
||||
|
||||
std::unordered_map<long, long> param_use_count_;
|
||||
};
|
||||
|
||||
static at::cuda::CUDAStream ag_stream = at::cuda::getStreamFromPool(true);
|
||||
static at::cuda::CUDAStream rs_stream = at::cuda::getStreamFromPool(true);
|
||||
static at::cuda::CUDAStream copy_stream = at::cuda::getStreamFromPool(true);
|
||||
static at::cuda::CUDAStream offload_stream = at::cuda::getStreamFromPool(true);
|
||||
static at::cuda::CUDAStream reload_stream = at::cuda::getStreamFromPool(true);
|
||||
|
||||
void register_graph_z3(long graph_id, const std::vector<long>& ds_ids)
|
||||
{
|
||||
executors[graph_id] = std::make_shared<Z3CustomOpExecutor>(process_group,
|
||||
param_registry,
|
||||
reduce_buckets,
|
||||
ds_ids,
|
||||
nccl_comm,
|
||||
ag_stream,
|
||||
rs_stream,
|
||||
copy_stream,
|
||||
offload_stream,
|
||||
reload_stream,
|
||||
pre_div_reduce);
|
||||
}
|
||||
|
||||
void register_z3_param(long ds_id,
|
||||
const std::vector<int64_t>& ds_shape,
|
||||
at::Tensor ds_tensor,
|
||||
at::Tensor grad_buffer,
|
||||
bool persistent)
|
||||
{
|
||||
param_registry->registerParam(ds_id, ds_shape, ds_tensor, grad_buffer, true, 0, persistent);
|
||||
if (persistent) { param_registry->registerGatheredParam(ds_id, ds_tensor); }
|
||||
}
|
||||
|
||||
at::Tensor allgather_param(at::Tensor param_tensor, long graph_id, long ds_id)
|
||||
{
|
||||
auto executor = getExecutor<Z3CustomOpExecutor>(graph_id, executors);
|
||||
|
||||
if (sync_before_allgather) { c10::cuda::device_synchronize(); }
|
||||
auto ret = executor->allgatherParam(ds_id, symm_mem);
|
||||
if (sync_after_allgather) { c10::cuda::device_synchronize(); }
|
||||
return ret;
|
||||
}
|
||||
|
||||
void set_persistent(long ds_id)
|
||||
{
|
||||
param_registry->setPersistent(ds_id, true);
|
||||
|
||||
// Allocate buffer here
|
||||
// Memory fragmentation will be more severe if we allocate in forward/backward
|
||||
for (auto& it : executors) {
|
||||
if (it.second->hasParam(ds_id)) {
|
||||
auto executor = getExecutor<Z3CustomOpExecutor>(it.first, executors);
|
||||
executor->allgatherParam(ds_id, symm_mem);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void prefetch_params_fused(long graph_id,
|
||||
const std::vector<at::Tensor> params,
|
||||
const std::vector<long>& ds_ids)
|
||||
{
|
||||
auto executor = getExecutor<Z3CustomOpExecutor>(graph_id, executors);
|
||||
executor->prefetchParamsFused(ds_ids, symm_mem);
|
||||
}
|
||||
|
||||
void prefetch_params_fused_meta(long graph_id,
|
||||
const std::vector<at::Tensor> params,
|
||||
const std::vector<long>& ds_ids)
|
||||
{
|
||||
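// Intentionally empty: the meta variant only needs to keep the captured graph well
// formed when tracing on the meta/fake device; no prefetch is issued there.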
}
|
||||
|
||||
// for profiling
|
||||
void invalidate_gathered_param(long ds_id)
|
||||
{
|
||||
const DSParam& param = param_registry->getParam(ds_id);
|
||||
if (param.isPersistent()) { return; }
|
||||
|
||||
param_registry->unregisterGatheredParam(ds_id);
|
||||
param_registry->registerGatheredParam(ds_id, at::Tensor());
|
||||
}
|
||||
|
||||
void clear_all_gathered_params()
|
||||
{
|
||||
for (const auto& it : param_registry->getParams()) {
|
||||
long ds_id = it.first;
|
||||
const DSParam& param = param_registry->getParam(ds_id);
|
||||
if (param.isPersistent()) { continue; }
|
||||
if (param_registry->hasGatheredParam(ds_id)) {
|
||||
param_registry->unregisterGatheredParam(ds_id);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
at::Tensor allgather_param_meta(at::Tensor param_tensor, long graph_id, long ds_id)
|
||||
{
|
||||
const DSParam& param = param_registry->getParam(ds_id);
|
||||
auto options = param.getDSTensor().options().device(c10::kMeta);
|
||||
at::Tensor output_buf = torch::empty(param.getShape(), options);
|
||||
return output_buf;
|
||||
}
|
||||
|
||||
at::Tensor release_param(at::Tensor dummy, long graph_id, long ds_id, long n_users)
|
||||
{
|
||||
auto executor = getExecutor<Z3CustomOpExecutor>(graph_id, executors);
|
||||
executor->releaseParam(ds_id, n_users);
|
||||
|
||||
if (clone_custom_op_output) { return dummy.clone(); }
|
||||
return dummy;
|
||||
}
|
||||
|
||||
at::Tensor release_param_meta(at::Tensor dummy, long graph_id, long ds_id, long n_users)
|
||||
{
|
||||
return dummy;
|
||||
}
|
||||
|
||||
at::Tensor wait_allgather(at::Tensor v, long graph_id, long ds_id)
|
||||
{
|
||||
auto executor = getExecutor<Z3CustomOpExecutor>(graph_id, executors);
|
||||
executor->waitAllgather(v, ds_id);
|
||||
return v;
|
||||
}
|
||||
|
||||
at::Tensor wait_allgather_meta(at::Tensor v, long graph_id, long ds_id) { return v; }
|
||||
|
||||
at::Tensor offload_tensor(at::Tensor tensor, long graph_id, long id)
|
||||
{
|
||||
auto executor = getExecutor<Z3CustomOpExecutor>(graph_id, executors);
|
||||
return executor->offloadTensor(tensor, id);
|
||||
}
|
||||
|
||||
at::Tensor reload_tensor(at::Tensor tensor, long graph_id, long id)
|
||||
{
|
||||
auto executor = getExecutor<Z3CustomOpExecutor>(graph_id, executors);
|
||||
return executor->reloadTensor(tensor, id);
|
||||
}
|
||||
|
||||
at::Tensor wait_offload(at::Tensor tensor, long graph_id, long id)
|
||||
{
|
||||
auto executor = getExecutor<Z3CustomOpExecutor>(graph_id, executors);
|
||||
return executor->waitOffload(tensor, id);
|
||||
}
|
||||
|
||||
at::Tensor wait_reload(at::Tensor tensor, long graph_id, long id)
|
||||
{
|
||||
auto executor = getExecutor<Z3CustomOpExecutor>(graph_id, executors);
|
||||
if (profile && !executor->hasReloadBuffer(id)) { return tensor; }
|
||||
return executor->waitReload(tensor, id);
|
||||
}
|
||||
|
||||
at::Tensor test_call(at::Tensor a)
|
||||
{
|
||||
std::cout << "test_call" << std::endl;
|
||||
return a;
|
||||
}
|
||||
|
||||
void reload_parameter(at::Tensor tensor, long graph_id, long ds_id)
|
||||
{
|
||||
auto executor = getExecutor<Z3CustomOpExecutor>(graph_id, executors);
|
||||
executor->reloadParameter(tensor, ds_id);
|
||||
}
|
||||
|
||||
void offload_parameter(at::Tensor tensor, long graph_id, long ds_id)
|
||||
{
|
||||
auto executor = getExecutor<Z3CustomOpExecutor>(graph_id, executors);
|
||||
executor->offloadParameter(tensor, ds_id);
|
||||
}
|
||||
void reload_parameter_meta(at::Tensor param_tensor, long graph_id, long ds_id) {}
|
||||
void offload_parameter_meta(at::Tensor tensor, long graph_id, long ds_id) {}
|
||||
|
||||
} // namespace dc
|
48
csrc/compile/z3.h
Normal file
@ -0,0 +1,48 @@
|
||||
// Copyright (c) Microsoft Corporation.
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
// DeepSpeed Team
|
||||
|
||||
#include "deepcompile.h"
|
||||
|
||||
#pragma once
|
||||
|
||||
namespace dc {
|
||||
|
||||
void register_graph_z3(long graph_id, const std::vector<long>& ds_ids);
|
||||
void register_graph_ops_z3(long graph_id,
|
||||
const std::vector<std::string>& op_names,
|
||||
const std::vector<long>& n_args);
|
||||
void register_bwd_graph_ops_z3(long graph_id,
|
||||
const std::vector<std::string>& op_names,
|
||||
const std::vector<long>& n_args);
|
||||
void register_z3_param(long ds_id,
|
||||
const std::vector<int64_t>& ds_shape,
|
||||
at::Tensor ds_tensor,
|
||||
at::Tensor grad_buffer,
|
||||
bool persistent);
|
||||
at::Tensor allgather_param(at::Tensor param_tensor, long graph_id, long ds_id);
|
||||
void set_persistent(long ds_id);
|
||||
void prefetch_params_fused(long graph_id,
|
||||
const std::vector<at::Tensor> params,
|
||||
const std::vector<long>& ds_ids);
|
||||
void prefetch_params_fused_meta(long graph_id,
|
||||
const std::vector<at::Tensor> params,
|
||||
const std::vector<long>& ds_ids);
|
||||
// for profiling
|
||||
void invalidate_gathered_param(long ds_id);
|
||||
void clear_all_gathered_params();
|
||||
at::Tensor allgather_param_meta(at::Tensor param_tensor, long graph_id, long ds_id);
|
||||
at::Tensor release_param(at::Tensor dummy, long graph_id, long ds_id, long n_users);
|
||||
at::Tensor release_param_meta(at::Tensor dummy, long graph_id, long ds_id, long n_users);
|
||||
at::Tensor wait_allgather(at::Tensor v, long graph_id, const long ds_id);
|
||||
at::Tensor wait_allgather_meta(at::Tensor v, long graph_id, long ds_id);
|
||||
at::Tensor offload_tensor(at::Tensor tensor, long graph_id, long id);
|
||||
at::Tensor reload_tensor(at::Tensor tensor, long graph_id, long id);
|
||||
at::Tensor wait_offload(at::Tensor tensor, long graph_id, long id);
|
||||
at::Tensor wait_reload(at::Tensor tensor, long graph_id, long id);
|
||||
void reload_parameter(at::Tensor tensor, long graph_id, long id);
|
||||
void offload_parameter(at::Tensor tensor, long graph_id, long id);
|
||||
void reload_parameter_meta(at::Tensor tensor, long graph_id, long id);
|
||||
void offload_parameter_meta(at::Tensor tensor, long graph_id, long id);
|
||||
} // namespace dc
|
576
csrc/includes/deepcompile.h
Normal file
@ -0,0 +1,576 @@
|
||||
// Copyright (c) Microsoft Corporation.
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
// DeepSpeed Team
|
||||
|
||||
#pragma once
|
||||
|
||||
#define NOMINMAX // Windows idiosyncrasy
|
||||
// https://stackoverflow.com/questions/4913922/possible-problems-with-nominmax-on-visual-c
|
||||
|
||||
#define USE_C10D_NCCL
|
||||
|
||||
#include <stdio.h>
|
||||
#include <torch/extension.h>
|
||||
|
||||
#include <ATen/cuda/CUDAEvent.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
#include <c10/cuda/CUDAStream.h>
|
||||
#include <torch/csrc/cuda/nccl.h>
|
||||
#include <torch/csrc/distributed/c10d/NCCLUtils.hpp>
|
||||
#include <torch/csrc/distributed/c10d/ProcessGroup.hpp>
|
||||
#include <torch/csrc/distributed/c10d/SymmetricMemory.hpp>
|
||||
|
||||
namespace dc {
|
||||
|
||||
template <typename K, typename V>
|
||||
static bool hasKey(const std::unordered_map<K, V>& map, const K& key)
|
||||
{
|
||||
return map.find(key) != map.end();
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
inline std::string to_string(const T& v)
|
||||
{
|
||||
std::stringstream ss;
|
||||
ss << v;
|
||||
return ss.str();
|
||||
}
|
||||
|
||||
template <typename L>
|
||||
size_t productDim(const L& dim)
|
||||
{
|
||||
size_t prod = 1;
|
||||
for (auto d : dim) { prod *= d; }
|
||||
return prod;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
std::string join_as_str(const T& v, const char* delim = ",", const size_t maxlen = 0)
|
||||
{
|
||||
std::stringstream ss;
|
||||
|
||||
if (!v.empty()) {
|
||||
auto it = v.begin();
|
||||
ss << to_string(*it);
|
||||
it++;
|
||||
for (; it != v.end(); ++it) {
|
||||
if (delim) ss << delim;
|
||||
ss << to_string(*it);
|
||||
}
|
||||
}
|
||||
|
||||
std::string s = ss.str();
|
||||
if (maxlen > 0 && s.length() > maxlen) { s = s.substr(0, maxlen) + " ..."; }
|
||||
|
||||
return "[" + s + "]";
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
std::string tensorPtrToString(T* ptr, size_t size, size_t str_len = 100)
|
||||
{
|
||||
std::vector<T> vals;
|
||||
for (size_t i = 0; i < size; i++) {
|
||||
vals.push_back(*ptr);
|
||||
ptr++;
|
||||
}
|
||||
return join_as_str(vals, ",", str_len);
|
||||
}
|
||||
|
||||
std::string tensorPtrToString(void* ptr,
|
||||
size_t size,
|
||||
c10::ScalarType datatype,
|
||||
size_t max_elem = 20,
|
||||
size_t max_str_len = 100);
|
||||
|
||||
std::string tensorToString(const at::Tensor& t, size_t max_elem = 20, size_t max_str_len = 100);
|
||||
|
||||
std::string tensorDimToString(const at::Tensor& t);
|
||||
|
||||
at::Tensor test_call(at::Tensor param);
|
||||
|
||||
extern c10::intrusive_ptr<c10d::ProcessGroup> process_group;
|
||||
extern c10::intrusive_ptr<c10d::symmetric_memory::SymmetricMemory> symm_mem;
|
||||
extern ncclComm_t nccl_comm;
|
||||
extern bool use_symm_mem;
|
||||
extern bool clone_custom_op_output;
|
||||
extern bool profile;
|
||||
extern bool pre_div_reduce;
|
||||
|
||||
extern bool sync_before_reduce; // for debugging
|
||||
extern bool sync_after_reduce; // for debugging
|
||||
extern bool sync_before_allgather; // for debugging
|
||||
extern bool sync_after_allgather; // for debugging
|
||||
|
||||
std::vector<int64_t> sizes_to_int_vector(at::IntArrayRef sizes);
|
||||
void enable_profiling(bool enable);
|
||||
bool is_profiling();
|
||||
|
||||
c10::intrusive_ptr<c10d::symmetric_memory::SymmetricMemory> getSymmMemWorkspace(int64_t size);
|
||||
void lazy_init_symm_memory();
|
||||
ncclDataType_t get_nccl_data_type(at::ScalarType scalar_type);
|
||||
void cleanup();
|
||||
|
||||
class ReduceTask {
|
||||
public:
|
||||
ReduceTask(long ds_id, at::Tensor grad, at::Tensor send_buf)
|
||||
: ds_id_(ds_id), grad_(std::move(grad)), send_buf_(std::move(send_buf))
|
||||
{
|
||||
}
|
||||
|
||||
long getDSId() const { return ds_id_; }
|
||||
at::Tensor getSendBuf() const { return send_buf_; }
|
||||
|
||||
private:
|
||||
long ds_id_;
|
||||
at::Tensor grad_;
|
||||
at::Tensor send_buf_;
|
||||
};
|
||||
|
||||
class ReduceBucket {
|
||||
public:
|
||||
ReduceBucket(int64_t size, at::ScalarType scalar_type) : size_(size), scalar_type_(scalar_type)
|
||||
{
|
||||
buffer_ = torch::empty({size}, at::TensorOptions().dtype(scalar_type).device(at::kCUDA));
|
||||
offset_ = 0;
|
||||
}
|
||||
|
||||
int64_t getSize() const { return size_; }
|
||||
int64_t getOffset() const { return offset_; }
|
||||
at::Tensor getBuffer() const { return buffer_; }
|
||||
at::ScalarType getScalarType() const { return scalar_type_; }
|
||||
|
||||
void reserve(int64_t size)
|
||||
{
|
||||
if (size > size_) {
|
||||
buffer_ =
|
||||
torch::empty({size}, at::TensorOptions().dtype(scalar_type_).device(at::kCUDA));
|
||||
size_ = size;
|
||||
}
|
||||
}
|
||||
|
||||
at::Tensor allocate(int64_t numel)
|
||||
{
|
||||
if (offset_ + numel > size_) {
|
||||
throw std::runtime_error("Buffer size exceeds the reduce bucket size");
|
||||
}
|
||||
|
||||
at::Tensor result = buffer_.index({torch::indexing::Slice(offset_, offset_ + numel)});
|
||||
offset_ += numel;
|
||||
return result;
|
||||
}
|
||||
|
||||
bool shouldFlush(int64_t numel) { return offset_ > 0 && offset_ + numel > size_; }
|
||||
|
||||
void reset() { offset_ = 0; }
|
||||
|
||||
private:
|
||||
int64_t size_;
|
||||
int64_t offset_;
|
||||
at::Tensor buffer_;
|
||||
at::ScalarType scalar_type_;
|
||||
};
|
||||
|
||||
class DoubleBufferedReduceBucket {
|
||||
public:
|
||||
DoubleBufferedReduceBucket(int64_t initial_bucket_size, bool enable_double_buffer)
|
||||
: initial_bucket_size_(initial_bucket_size), enable_double_buffer_(enable_double_buffer)
|
||||
{
|
||||
}
|
||||
|
||||
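// Reset the active bucket and record its reuse event on rs_stream. When double
// buffering is enabled, also rotate the current and shadow buckets so that new
// gradients can be packed while the previous reduce-scatter is still in flight.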
void swap(at::ScalarType scalar_type,
|
||||
at::cuda::CUDAStream rs_stream,
|
||||
at::cuda::CUDAStream copy_stream)
|
||||
{
|
||||
assert(hasKey(current_buffer_, scalar_type));
|
||||
assert(hasKey(current_buffer_events_, scalar_type));
|
||||
|
||||
current_buffer_.at(scalar_type)->reset();
|
||||
current_buffer_events_.at(scalar_type)->record(rs_stream);
|
||||
|
||||
if (enable_double_buffer_) {
|
||||
assert(hasKey(shadow_buffer_, scalar_type));
|
||||
assert(hasKey(shadow_buffer_events_, scalar_type));
|
||||
|
||||
auto tmp = current_buffer_.at(scalar_type);
|
||||
current_buffer_[scalar_type] = shadow_buffer_.at(scalar_type);
|
||||
shadow_buffer_[scalar_type] = tmp;
|
||||
|
||||
auto tmp_event = current_buffer_events_.at(scalar_type);
|
||||
current_buffer_events_[scalar_type] = shadow_buffer_events_.at(scalar_type);
|
||||
shadow_buffer_events_[scalar_type] = tmp_event;
|
||||
}
|
||||
}
|
||||
|
||||
std::shared_ptr<ReduceBucket> getBuffer(at::ScalarType scalar_type)
|
||||
{
|
||||
if (!hasKey(current_buffer_, scalar_type)) {
|
||||
current_buffer_[scalar_type] =
|
||||
std::make_shared<ReduceBucket>(initial_bucket_size_, scalar_type);
|
||||
current_buffer_events_[scalar_type] =
|
||||
std::make_shared<at::cuda::CUDAEvent>(cudaEventDisableTiming);
|
||||
|
||||
if (enable_double_buffer_) {
|
||||
shadow_buffer_[scalar_type] =
|
||||
std::make_shared<ReduceBucket>(initial_bucket_size_, scalar_type);
|
||||
shadow_buffer_events_[scalar_type] =
|
||||
std::make_shared<at::cuda::CUDAEvent>(cudaEventDisableTiming);
|
||||
}
|
||||
}
|
||||
|
||||
return current_buffer_.at(scalar_type);
|
||||
}
|
||||
|
||||
std::shared_ptr<at::cuda::CUDAEvent> getEvent(at::ScalarType scalar_type)
|
||||
{
|
||||
assert(hasKey(current_buffer_events_, scalar_type));
|
||||
return current_buffer_events_.at(scalar_type);
|
||||
}
|
||||
|
||||
void clear()
|
||||
{
|
||||
current_buffer_.clear();
|
||||
shadow_buffer_.clear();
|
||||
current_buffer_events_.clear();
|
||||
shadow_buffer_events_.clear();
|
||||
}
|
||||
|
||||
private:
|
||||
int64_t initial_bucket_size_;
|
||||
bool enable_double_buffer_;
|
||||
std::unordered_map<at::ScalarType, std::shared_ptr<ReduceBucket>> current_buffer_;
|
||||
std::unordered_map<at::ScalarType, std::shared_ptr<ReduceBucket>> shadow_buffer_;
|
||||
std::unordered_map<at::ScalarType, std::shared_ptr<at::cuda::CUDAEvent>> current_buffer_events_;
|
||||
std::unordered_map<at::ScalarType, std::shared_ptr<at::cuda::CUDAEvent>> shadow_buffer_events_;
|
||||
};
|
||||
|
||||
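// Bookkeeping for a DeepSpeed-managed parameter: its ds tensor, optional
// offloaded/reloaded copies, and the CUDA streams/events that order host<->device
// transfers against the compute stream.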
class DSParam {
|
||||
public:
|
||||
DSParam(long id,
|
||||
std::vector<int64_t> ds_shape,
|
||||
at::Tensor ds_tensor,
|
||||
at::Tensor grad_buffer,
|
||||
bool partitioned,
|
||||
int64_t offset, // for Z1
|
||||
bool persistent // for Z3
|
||||
)
|
||||
: id_(id),
|
||||
shape_(std::move(ds_shape)),
|
||||
ds_tensor_(ds_tensor),
|
||||
grad_buffer_(grad_buffer),
|
||||
partitioned_(partitioned),
|
||||
offset_(offset),
|
||||
persistent_(persistent),
|
||||
offload_stream_(at::cuda::getStreamFromPool()),
|
||||
reload_stream_(at::cuda::getStreamFromPool())
|
||||
{
|
||||
}
|
||||
|
||||
long getId() const { return id_; }
|
||||
std::vector<int64_t> getShape() const { return shape_; }
|
||||
at::Tensor getDSTensor() const
|
||||
{
|
||||
// If a reload has been issued, wait for it to complete and return the reloaded tensor (if defined)
|
||||
if (reload_done_event_) {
|
||||
if (!reload_done_event_->query()) {
|
||||
reload_done_event_->block(at::cuda::getCurrentCUDAStream());
|
||||
}
|
||||
if (ds_reload_tensor_.defined()) { return ds_reload_tensor_; }
|
||||
}
|
||||
// Otherwise, if an offload event exists, wait for it to complete
|
||||
if (offload_done_event_) {
|
||||
if (!offload_done_event_->query()) {
|
||||
offload_done_event_->block(at::cuda::getCurrentCUDAStream());
|
||||
}
|
||||
}
|
||||
return ds_tensor_;
|
||||
}
|
||||
at::Tensor getGradBuffer() const { return grad_buffer_; }
|
||||
bool isPartitioned() const { return partitioned_; }
|
||||
int64_t getOffset() const { return offset_; }
|
||||
void setPersistent(bool persistent) { persistent_ = persistent; }
|
||||
bool isPersistent() const { return persistent_; }
|
||||
|
||||
void offload()
|
||||
{
|
||||
// If a reloaded tensor exists, offload its data back to ds_tensor_
|
||||
if (ds_reload_tensor_.defined()) {
|
||||
auto comp_stream = at::cuda::getCurrentCUDAStream();
|
||||
comp_done_event_ = std::make_shared<at::cuda::CUDAEvent>(cudaEventDisableTiming);
|
||||
// Record completion and wait on the offload stream
|
||||
comp_done_event_->record(comp_stream);
|
||||
comp_done_event_->block(offload_stream_);
|
||||
offload_done_event_ = std::make_shared<at::cuda::CUDAEvent>(cudaEventDisableTiming);
|
||||
|
||||
{
|
||||
at::cuda::CUDAStreamGuard guard(offload_stream_);
|
||||
ds_tensor_.copy_(ds_reload_tensor_, /*non_blocking=*/true);
|
||||
ds_reload_tensor_.reset(); // Clear the reloaded tensor
|
||||
offload_done_event_->record(offload_stream_);
|
||||
}
|
||||
// Reset the reload event to indicate that no valid reload is present.
|
||||
if (reload_done_event_) { reload_done_event_.reset(); }
|
||||
}
|
||||
}
|
||||
|
||||
void reload()
|
||||
{
|
||||
// Reload only if the current ds_tensor_ is on CPU
|
||||
if (ds_tensor_.device().is_cpu()) {
|
||||
auto comp_stream = at::cuda::getCurrentCUDAStream();
|
||||
comp_done_event_ = std::make_shared<at::cuda::CUDAEvent>(cudaEventDisableTiming);
|
||||
// Record and wait on the reload stream
|
||||
comp_done_event_->record(comp_stream);
|
||||
comp_done_event_->block(reload_stream_);
|
||||
reload_done_event_ = std::make_shared<at::cuda::CUDAEvent>(cudaEventDisableTiming);
|
||||
|
||||
{
|
||||
at::cuda::CUDAStreamGuard guard(reload_stream_);
|
||||
ds_reload_tensor_ =
|
||||
at::empty_like(ds_tensor_, ds_tensor_.options().device(torch::kCUDA));
|
||||
ds_reload_tensor_.copy_(ds_tensor_, /*non_blocking=*/true);
|
||||
reload_done_event_->record(reload_stream_);
|
||||
}
|
||||
// Reset offload_done_event if it exists to clear any stale offload state.
|
||||
if (offload_done_event_) { offload_done_event_.reset(); }
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
long id_;
|
||||
std::vector<int64_t> shape_;
|
||||
at::Tensor ds_tensor_;
|
||||
at::Tensor ds_reload_tensor_;
|
||||
at::Tensor grad_buffer_;
|
||||
bool partitioned_;
|
||||
int64_t offset_; // for Z1
|
||||
bool persistent_; // for Z3
|
||||
mutable bool is_reloaded = false;
|
||||
|
||||
at::cuda::CUDAStream offload_stream_;
|
||||
at::cuda::CUDAStream reload_stream_;
|
||||
std::shared_ptr<at::cuda::CUDAEvent> comp_done_event_;
|
||||
std::shared_ptr<at::cuda::CUDAEvent> offload_done_event_;
|
||||
std::shared_ptr<at::cuda::CUDAEvent> reload_done_event_;
|
||||
};
|
||||
|
||||
class DSParamRegistry {
|
||||
public:
|
||||
DSParamRegistry() {}
|
||||
~DSParamRegistry() {}
|
||||
|
||||
void registerParam(long ds_id,
|
||||
const std::vector<int64_t>& ds_shape,
|
||||
at::Tensor ds_tensor,
|
||||
at::Tensor grad_buffer,
|
||||
bool partitioned,
|
||||
int64_t offset, // for Z1
|
||||
bool persistent // for Z3
|
||||
)
|
||||
{
|
||||
grad_buffer.zero_();
|
||||
params_.emplace(
|
||||
ds_id,
|
||||
DSParam(ds_id, ds_shape, ds_tensor, grad_buffer, partitioned, offset, persistent));
|
||||
valid_[ds_id] = false;
|
||||
}
|
||||
|
||||
void registerGatheredParam(long ds_id, at::Tensor ds_tensor)
|
||||
{
|
||||
gathered_params_.emplace(ds_id, ds_tensor);
|
||||
}
|
||||
|
||||
void unregisterGatheredParam(long ds_id)
|
||||
{
|
||||
assert(hasKey(gathered_params_, ds_id));
|
||||
gathered_params_.erase(ds_id);
|
||||
valid_[ds_id] = false;
|
||||
}
|
||||
|
||||
const std::unordered_map<long, DSParam>& getParams() const { return params_; }
|
||||
|
||||
const DSParam& getParam(long ds_id) const { return params_.at(ds_id); }
|
||||
const size_t getNumParams() const { return params_.size(); }
|
||||
const at::Tensor& getGatheredParam(long ds_id) const
|
||||
{
|
||||
assert(hasKey(gathered_params_, ds_id));
|
||||
return gathered_params_.at(ds_id);
|
||||
}
|
||||
bool hasGatheredParam(long ds_id) const { return hasKey(gathered_params_, ds_id); }
|
||||
void setPersistent(long ds_id, bool persistent) { params_.at(ds_id).setPersistent(persistent); }
|
||||
void offload(long ds_id) { params_.at(ds_id).offload(); }
|
||||
void reload(long ds_id) { params_.at(ds_id).reload(); }
|
||||
|
||||
void setValid(long ds_id, bool valid) { valid_[ds_id] = valid; }
|
||||
bool isValid(long ds_id) const
|
||||
{
|
||||
assert(hasKey(valid_, ds_id));
|
||||
return valid_.at(ds_id);
|
||||
}
|
||||
|
||||
private:
|
||||
std::unordered_map<long, DSParam> params_;
|
||||
std::unordered_map<long, at::Tensor> gathered_params_;
|
||||
std::unordered_map<long, bool> valid_;
|
||||
};
|
||||
|
||||
class CustomOpExecutor {
|
||||
public:
|
||||
CustomOpExecutor(c10::intrusive_ptr<c10d::ProcessGroup> process_group,
|
||||
std::shared_ptr<DSParamRegistry> param_registry,
|
||||
std::shared_ptr<DoubleBufferedReduceBucket> reduce_buckets,
|
||||
std::vector<long> ds_ids,
|
||||
ncclComm_t nccl_comm,
|
||||
at::cuda::CUDAStream rs_stream,
|
||||
at::cuda::CUDAStream copy_stream,
|
||||
bool pre_div_reduce)
|
||||
: process_group_(process_group),
|
||||
param_registry_(std::move(param_registry)),
|
||||
reduce_buckets_(std::move(reduce_buckets)),
|
||||
ds_ids_(std::move(ds_ids)),
|
||||
nccl_comm_(nccl_comm),
|
||||
rs_stream_(rs_stream),
|
||||
copy_stream_(copy_stream),
|
||||
pre_div_reduce_(pre_div_reduce)
|
||||
{
|
||||
for (long ds_id : ds_ids_) {
|
||||
has_acc_grad_[ds_id] = false;
|
||||
|
||||
rs_comp_done_events_[ds_id] =
|
||||
std::make_shared<at::cuda::CUDAEvent>(cudaEventDisableTiming);
|
||||
rs_copy_done_events_[ds_id] =
|
||||
std::make_shared<at::cuda::CUDAEvent>(cudaEventDisableTiming);
|
||||
}
|
||||
reduce_counter_ = ds_ids_.size();
|
||||
}
|
||||
~CustomOpExecutor() {}
|
||||
|
||||
virtual void startForward() {}
|
||||
|
||||
virtual void endForward() {}
|
||||
|
||||
virtual void startBackward(bool update) { param_updated_ = update; }
|
||||
|
||||
virtual void endBackward() {}
|
||||
|
||||
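// Pack the incoming gradient into the per-dtype reduce bucket via a side copy stream,
// flushing the bucket first when it cannot hold this gradient. Once the last expected
// gradient has arrived, flush all buckets and synchronize before the optimizer step.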
at::Tensor reduceGrad(at::Tensor grad_tensor, long ds_id)
|
||||
{
|
||||
int world_size = process_group_->getSize();
|
||||
const DSParam& param = param_registry_->getParam(ds_id);
|
||||
const auto scalar_type = grad_tensor.scalar_type();
|
||||
std::shared_ptr<ReduceBucket> reduce_bucket = reduce_buckets_->getBuffer(scalar_type);
|
||||
|
||||
auto comp_stream = at::cuda::getCurrentCUDAStream();
|
||||
|
||||
if (reduce_bucket->shouldFlush(grad_tensor.numel())) {
|
||||
int rank = process_group_->getRank();
|
||||
|
||||
flushReduceBucket(scalar_type);
|
||||
|
||||
// reduce_bucket is swapped in flushReduceBucket if double buffering is enabled
|
||||
reduce_bucket = reduce_buckets_->getBuffer(scalar_type);
|
||||
}
|
||||
|
||||
if (grad_tensor.numel() > reduce_bucket->getSize()) {
|
||||
// extend buckets
|
||||
at::cuda::stream_synchronize(rs_stream_);
|
||||
reduce_bucket->reserve(grad_tensor.numel());
|
||||
}
|
||||
|
||||
at::Tensor reduce_in_buffer = reduce_bucket->allocate(grad_tensor.numel());
|
||||
|
||||
// This ensures the order of reduce_scatter -> copy
|
||||
// Without this block, copy may start while reduce_scatter is still running
|
||||
reduce_buckets_->getEvent(scalar_type)->block(comp_stream);
|
||||
auto copy_src = grad_tensor.contiguous().view({-1}).detach();
|
||||
// keep references to copy src
|
||||
reduce_tasks_[scalar_type].emplace_back(ds_id, copy_src, reduce_in_buffer);
|
||||
|
||||
// computation must be done before copy
|
||||
rs_comp_done_events_[ds_id]->record(comp_stream);
|
||||
rs_comp_done_events_[ds_id]->block(copy_stream_);
|
||||
{
|
||||
at::cuda::CUDAStreamGuard guard(copy_stream_);
|
||||
reduce_in_buffer.copy_(copy_src, true);
|
||||
rs_copy_done_events_[ds_id]->record(copy_stream_);
|
||||
}
|
||||
|
||||
reduce_counter_--;
|
||||
|
||||
if (reduce_counter_ == 0) {
|
||||
flushAllReduceBuckets();
|
||||
|
||||
reduce_counter_ = ds_ids_.size();
|
||||
|
||||
// This synchronization ensures that all reduce calls are done before the optimizer's step.
|
||||
at::cuda::stream_synchronize(rs_stream_);
|
||||
|
||||
endBackward();
|
||||
}
|
||||
|
||||
return at::Tensor();
|
||||
}
|
||||
|
||||
bool hasParam(long ds_id) const { return hasKey(has_acc_grad_, ds_id); }
|
||||
|
||||
protected:
|
||||
c10::intrusive_ptr<c10d::ProcessGroup> process_group_;
|
||||
std::shared_ptr<DSParamRegistry> param_registry_;
|
||||
std::shared_ptr<DoubleBufferedReduceBucket> reduce_buckets_;
|
||||
std::vector<long> ds_ids_;
|
||||
ncclComm_t nccl_comm_;
|
||||
at::cuda::CUDAStream rs_stream_;
|
||||
at::cuda::CUDAStream copy_stream_;
|
||||
|
||||
std::unordered_map<long, std::shared_ptr<at::cuda::CUDAEvent>> rs_comp_done_events_;
|
||||
std::unordered_map<long, std::shared_ptr<at::cuda::CUDAEvent>> rs_copy_done_events_;
|
||||
|
||||
size_t reduce_counter_ = 0;
|
||||
bool param_updated_ = false;
|
||||
std::unordered_map<at::ScalarType, std::vector<ReduceTask>> reduce_tasks_;
|
||||
std::unordered_map<long, bool> has_acc_grad_;
|
||||
bool pre_div_reduce_;
|
||||
|
||||
virtual void flushReduceBucket(at::ScalarType scalar_type) = 0;
|
||||
|
||||
void flushAllReduceBuckets()
|
||||
{
|
||||
for (const auto& it : reduce_tasks_) { flushReduceBucket(it.first); }
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T, typename U>
|
||||
std::shared_ptr<T> getExecutor(long graph_id,
|
||||
const std::unordered_map<long, std::shared_ptr<U>>& executors)
|
||||
{
|
||||
assert(hasKey(executors, graph_id));
|
||||
if (auto executor = std::dynamic_pointer_cast<T>(executors.at(graph_id))) { return executor; }
|
||||
throw std::runtime_error("Invalid executor type");
|
||||
}
|
||||
|
||||
extern std::shared_ptr<DSParamRegistry> param_registry;
|
||||
extern std::unordered_map<long, std::shared_ptr<CustomOpExecutor>> executors;
|
||||
extern std::shared_ptr<DoubleBufferedReduceBucket> reduce_buckets;
|
||||
|
||||
at::Tensor reduce_grad(at::Tensor grad_tensor, long graph_id, long ds_id);
|
||||
at::Tensor reduce_grad_meta(at::Tensor grad_tensor, long graph_id, long ds_id);
|
||||
void free_tensors(std::vector<at::Tensor> tensors);
|
||||
void free_tensors_meta(std::vector<at::Tensor> tensors);
|
||||
|
||||
void init(c10::intrusive_ptr<c10d::ProcessGroup> pg,
|
||||
int64_t initial_reduce_bucket_size,
|
||||
bool enable_double_buffer,
|
||||
bool _use_symm_mem,
|
||||
bool _clone_custom_op_output,
|
||||
bool _sync_before_reduce,
|
||||
bool _sync_after_reduce,
|
||||
bool _sync_before_allgather,
|
||||
bool _sync_after_allgather);
|
||||
void reset();
|
||||
void cleanup();
|
||||
|
||||
void start_forward();
|
||||
void end_forward();
|
||||
void start_backward(bool update);
|
||||
|
||||
} // namespace dc
|
@ -621,6 +621,17 @@ def initialize_mesh_device(mesh_shape, mesh_dim_names):
|
||||
return mesh_device
|
||||
|
||||
|
||||
def enable_symm_mem_for_group(group_name: str):
|
||||
global cdb
|
||||
assert cdb is not None and cdb.is_initialized(
|
||||
), 'DeepSpeed backend not set, please initialize it using init_process_group()'
|
||||
|
||||
if hasattr(cdb, 'enable_symm_mem_for_group'):
|
||||
cdb.enable_symm_mem_for_group(group_name)
|
||||
else:
|
||||
raise RuntimeError(f"Backend {cdb.name} does not support symmetric memory initialization")
|
||||
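# Illustrative usage (a sketch, not called from this file): after initialization, a
# caller can opt a process group into symmetric memory. The group name below is a
# placeholder string, not a value defined in DeepSpeed:
#
#   import deepspeed.comm as dist
#   dist.init_distributed()
#   dist.enable_symm_mem_for_group("my_group")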
|
||||
|
||||
# Main DeepSpeed Comms. public API.
|
||||
def init_distributed(dist_backend=None,
|
||||
auto_mpi_discovery=True,
|
||||
|
@ -409,6 +409,13 @@ class TorchBackend(Backend):
|
||||
mesh_shape,
|
||||
mesh_dim_names=mesh_dim_names)
|
||||
|
||||
def enable_symm_mem_for_group(self, group_name):
|
||||
if not required_torch_version(min_version=2.5):
|
||||
raise RuntimeError(f"Torch version must be 2.5 or higher to use symmetric memory. "
|
||||
f"Current version: {torch.__version__}")
|
||||
from torch.distributed._symmetric_memory import enable_symm_mem_for_group
|
||||
return enable_symm_mem_for_group(group_name)
|
||||
|
||||
|
||||
# This will become a light-weight wrapper around torch.distributed functions
|
||||
# TODO: create some example to show how this wrapper can help profile communication
|
||||
|
4
deepspeed/compile/__init__.py
Normal file
@ -0,0 +1,4 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
# DeepSpeed Team
|
279
deepspeed/compile/backend.py
Normal file
@ -0,0 +1,279 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
# DeepSpeed Team
|
||||
|
||||
from typing import Dict, List, Callable
|
||||
import time
|
||||
import gc
|
||||
|
||||
import torch
|
||||
from torch.fx import Graph, GraphModule
|
||||
|
||||
try:
|
||||
import torch.utils._pytree as pytree
|
||||
import torch._dynamo
|
||||
import torch._inductor.scheduler
|
||||
from functorch.compile import make_boxed_func
|
||||
from torch._functorch.aot_autograd import aot_module_simplified
|
||||
from torch._subclasses.fake_tensor import unset_fake_temporarily
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
from deepspeed.accelerator import get_accelerator
|
||||
|
||||
from .fx import add_free_activations
|
||||
from .graph_param import DSGraphParamManager
|
||||
from .profilers import ProfilingResult
|
||||
from .profilers.graph_profile import MemoryProfilingInterpreter
|
||||
from .patch_compiled_func import patch_compiled_func, unpatch_compiled_func, get_backward_inputs
|
||||
from .util import get_input_nodes, get_activation_node_names, get_index_by_graph_id, get_deepcompile_handle, log_rank0
|
||||
from .partitioner import get_wrapped_partitioner
|
||||
from .inductor import register_custom_ops, patch_create_aot_dispatcher_function
|
||||
|
||||
remaining_schedule = None
|
||||
next_pass_step = -1
|
||||
next_passes = None
|
||||
current_passes = None
|
||||
|
||||
param_manager: Dict[int, DSGraphParamManager] = {}
|
||||
graph_order = []
|
||||
profiling_results: Dict[int, ProfilingResult] = {}
|
||||
opt_pass_times = []
|
||||
|
||||
opt_passes = {}
|
||||
|
||||
fwd_real_inputs = []
|
||||
remaining_bwd_compile_count = 0
|
||||
|
||||
|
||||
def register_compile_pass(name: str, opt_pass_fn):
|
||||
opt_passes[name] = opt_pass_fn
|
||||
|
||||
|
||||
def init_schedule(schedule):
|
||||
|
||||
assert isinstance(schedule, list), f"schedule should be a list, but got {type(schedule)}"
|
||||
|
||||
for step, passes in schedule:
|
||||
assert isinstance(step, int), f"Each step in schedule should be an integer, but got {type(step)}"
|
||||
assert isinstance(passes, list), f"Passes at a certain step should be a list, but got {type(passes)}"
|
||||
|
||||
global remaining_schedule
|
||||
remaining_schedule = schedule
|
||||
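# The schedule is a list of (global step, [passes]) pairs. For example, ZeRO-1
# initialization builds the single-entry schedule below (see init_z1.py); later
# entries could switch to other passes once profiling data is available:
#
#   init_schedule([(0, [zero1_compile.add_z1_reduce])])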
|
||||
|
||||
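# When the current step matches the head of the remaining schedule, pop its passes,
# reset dynamo and DeepCompile state, and force recompilation so the new passes are
# applied to freshly captured graphs.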
def launch_compile_passes(global_steps: int):
|
||||
global next_pass_step, next_passes
|
||||
|
||||
if len(remaining_schedule) > 0 and global_steps == remaining_schedule[0][0]:
|
||||
_, next_passes = remaining_schedule.pop(0)
|
||||
log_rank0(f"Launching compile passes: global_steps={global_steps} passes={next_passes}", True)
|
||||
|
||||
torch._dynamo.reset()
|
||||
get_deepcompile_handle().reset()
|
||||
patch_compiled_func()
|
||||
graph_order.clear()
|
||||
profiling_results.clear()
|
||||
param_manager.clear()
|
||||
|
||||
|
||||
def set_time_and_tensor_size(graph_id, graph: Graph, mem, bwd, profiling_results):
|
||||
node_time = []
|
||||
tensor_sizes = []
|
||||
|
||||
for n in graph.nodes:
|
||||
node_time.append((n.name, n.meta["device_time"] if "device_time" in n.meta else 0.0,
|
||||
n.meta["wall_time"] if "wall_time" in n.meta else 0.0))
|
||||
tensor_sizes.append((n.name, n.meta["tensor_size"] if "tensor_size" in n.meta else 0))
|
||||
|
||||
if bwd:
|
||||
profiling_results[graph_id].bwd_graph = graph
|
||||
profiling_results[graph_id].bwd_time = node_time
|
||||
profiling_results[graph_id].bwd_tensor_sizes = tensor_sizes
|
||||
profiling_results[graph_id].bwd_mem = mem
|
||||
else:
|
||||
profiling_results[graph_id].fwd_graph = graph
|
||||
profiling_results[graph_id].fwd_time = node_time
|
||||
profiling_results[graph_id].fwd_tensor_sizes = tensor_sizes
|
||||
profiling_results[graph_id].fwd_mem = mem
|
||||
|
||||
|
||||
def run_opt_passes(opt_passes: List[Callable],
|
||||
gm: GraphModule,
|
||||
graph_id: int,
|
||||
graph_order: List[int],
|
||||
profiling_results,
|
||||
create_inputs_fn,
|
||||
mem_budget: float,
|
||||
param_manager,
|
||||
bwd: bool,
|
||||
debug_log=False) -> None:
|
||||
|
||||
with unset_fake_temporarily():
|
||||
get_accelerator().synchronize()
|
||||
gc.collect()
|
||||
get_accelerator().empty_cache()
|
||||
|
||||
for i, opt_pass_fn in enumerate(opt_passes):
|
||||
log_rank0(f"Running opt pass {i} for graph {graph_id}. bwd={bwd}", enable=debug_log)
|
||||
|
||||
gm_new = opt_pass_fn(gm, graph_id, graph_order, profiling_results, create_inputs_fn, mem_budget, param_manager,
|
||||
bwd)
|
||||
if gm_new is not None:
|
||||
gm = gm_new
|
||||
gm.graph.lint()
|
||||
gm.recompile()
|
||||
|
||||
mem_prof = MemoryProfilingInterpreter(gm, debug_log=debug_log)
|
||||
mem_prof.run(*create_inputs_fn())
|
||||
mem = [(name, current_alloc, delta, peak) for name, current_alloc, delta, peak in mem_prof.mem_record]
|
||||
|
||||
set_time_and_tensor_size(graph_id, gm.graph, mem, bwd, profiling_results)
|
||||
|
||||
with unset_fake_temporarily():
|
||||
get_accelerator().synchronize()
|
||||
gc.collect()
|
||||
get_accelerator().empty_cache()
|
||||
|
||||
|
||||
def make_backend(backend, compile_kwargs={}, free_activation=False, debug_log=False):
|
||||
|
||||
register_custom_ops()
|
||||
|
||||
def backend_fn(gm: GraphModule, real_inputs):
|
||||
graph_id = id(gm.graph)
|
||||
needs_backward = pytree.tree_any(lambda x: x.requires_grad if torch.is_tensor(x) else False, real_inputs)
|
||||
|
||||
global graph_order
|
||||
graph_order.append((graph_id, needs_backward))
|
||||
|
||||
z3_partition = any(hasattr(v, "ds_id") for v in real_inputs)
|
||||
if z3_partition:
|
||||
param_indices = [(i, input_val.ds_id, input_val.ds_shape) for i, input_val in enumerate(real_inputs)
|
||||
if isinstance(input_val, torch.nn.Parameter)]
|
||||
else:
|
||||
assert all(hasattr(v, "param_id") for v in real_inputs
|
||||
if isinstance(v, torch.nn.Parameter)), "All param inputs should have param_id"
|
||||
param_indices = [(i, input_val.param_id, input_val.shape) for i, input_val in enumerate(real_inputs)
|
||||
if isinstance(input_val, torch.nn.Parameter)]
|
||||
|
||||
global fwd_real_inputs
|
||||
fwd_real_inputs.append(real_inputs)
|
||||
|
||||
global profiling_results
|
||||
if graph_id not in profiling_results:
|
||||
profiling_results[graph_id] = ProfilingResult()
|
||||
profiling_results[graph_id].param_indices = param_indices
|
||||
profiling_results[graph_id].needs_backward = needs_backward
|
||||
|
||||
def make_fw_graph(gm, sample_inputs):
|
||||
time_start = time.time()
|
||||
graph_index = len(graph_order) - 1
|
||||
real_inputs = fwd_real_inputs.pop(0)
|
||||
|
||||
param_manager[graph_id] = DSGraphParamManager(gm.graph, real_inputs, param_indices)
|
||||
|
||||
real_inputs_with_rng = real_inputs + sample_inputs[len(real_inputs):]
|
||||
run_opt_passes(
|
||||
opt_passes=next_passes,
|
||||
gm=gm,
|
||||
graph_id=graph_id,
|
||||
graph_order=graph_order,
|
||||
profiling_results=profiling_results,
|
||||
create_inputs_fn=lambda: real_inputs_with_rng,
|
||||
mem_budget=.0, # unused
|
||||
param_manager=param_manager,
|
||||
bwd=False,
|
||||
debug_log=debug_log)
|
||||
|
||||
if needs_backward:
|
||||
global remaining_bwd_compile_count
|
||||
remaining_bwd_compile_count += 1
|
||||
|
||||
opt_pass_times.append(("fwd", graph_index, graph_id, time.time() - time_start))
|
||||
|
||||
log_rank0(
|
||||
f"Fwd end {graph_index} graph_id={graph_id} alloc_mem={get_accelerator().memory_allocated()} graph={gm.graph}",
|
||||
enable=debug_log)
|
||||
|
||||
return gm.graph
|
||||
|
||||
def make_bw_graph(gm, sample_inputs):
|
||||
time_start = time.time()
|
||||
|
||||
graph_index = get_index_by_graph_id(graph_order, graph_id)
|
||||
log_rank0(
|
||||
f"Bwd start {graph_index} graph_id={graph_id} alloc_mem={get_accelerator().memory_allocated()} graph={gm.graph}",
|
||||
enable=debug_log)
|
||||
|
||||
bwd_inputs_stack = get_backward_inputs()
|
||||
|
||||
if len(bwd_inputs_stack) == 0:
|
||||
# Dynamo calls the bw compiler ahead of time when symints are saved for backward. See the details of aot_dispatch_autograd in jit_compile_runtime_wrappers.
# As we currently rely on the actual backward input values in the bw compiler, we return None to skip the compilation there.
# This will need to be handled properly in the future.
|
||||
return None
|
||||
|
||||
bwd_real_inputs = bwd_inputs_stack.pop()
|
||||
run_opt_passes(
|
||||
opt_passes=next_passes,
|
||||
gm=gm,
|
||||
graph_id=graph_id,
|
||||
graph_order=graph_order,
|
||||
profiling_results=profiling_results,
|
||||
create_inputs_fn=lambda: tuple(bwd_real_inputs),
|
||||
mem_budget=.0, # unused
|
||||
param_manager=param_manager,
|
||||
bwd=True,
|
||||
debug_log=debug_log)
|
||||
|
||||
# assert graph_id in param_manager, f"Graph {graph_id} not found in param_manager"
|
||||
|
||||
if free_activation:
|
||||
param_nodes_bw, _ = param_manager[graph_id].get_bwd_mapping(gm.graph)
|
||||
param_names = [n.name for n in param_nodes_bw]
|
||||
non_param_input_names = [n.name for n in get_input_nodes(gm.graph) if n.name not in param_names]
|
||||
add_free_activations(graph_id, gm.graph,
|
||||
get_activation_node_names(gm.graph, param_nodes_bw, non_param_input_names))
|
||||
|
||||
global remaining_bwd_compile_count
|
||||
remaining_bwd_compile_count -= 1
|
||||
if remaining_bwd_compile_count == 0:
|
||||
unpatch_compiled_func()
|
||||
|
||||
log_rank0(
|
||||
f"Bwd end {graph_index} graph_id={graph_id} alloc_mem={get_accelerator().memory_allocated()} graph={gm.graph}",
|
||||
enable=debug_log)
|
||||
|
||||
gm.recompile()
|
||||
opt_pass_times.append(("bwd", graph_index, graph_id, time.time() - time_start))
|
||||
|
||||
return gm.graph
|
||||
|
||||
if backend == "eager":
|
||||
|
||||
def make_compiler_fn(make_graph_fn):
|
||||
|
||||
def compiler_fn(gm, sample_inputs):
|
||||
return None if make_graph_fn(gm, sample_inputs) is None else make_boxed_func(gm.forward)
|
||||
|
||||
return compiler_fn
|
||||
|
||||
aot_mod = aot_module_simplified(gm,
|
||||
real_inputs,
|
||||
fw_compiler=make_compiler_fn(make_fw_graph),
|
||||
bw_compiler=make_compiler_fn(make_bw_graph),
|
||||
partition_fn=get_wrapped_partitioner(param_indices))
|
||||
return torch._dynamo.optimize(**compile_kwargs)(aot_mod)
|
||||
elif backend == "inductor":
|
||||
patch_create_aot_dispatcher_function(graph_id, z3_partition, make_fw_graph, make_bw_graph, real_inputs,
|
||||
param_indices, param_manager)
|
||||
from .partitioner import get_wrapped_choose_saved_values_set
|
||||
torch._functorch.partitioners.choose_saved_values_set = get_wrapped_choose_saved_values_set(param_indices)
|
||||
|
||||
return torch._inductor.compile(gm, real_inputs)
|
||||
|
||||
raise ValueError(f"Unsupported backend {backend}")
|
||||
|
||||
return backend_fn
|
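# Rough sketch of how a backend built by make_backend could be handed to torch.compile
# (the DeepSpeed engine wires this up internally; the call below is illustrative only):
#
#   backend_fn = make_backend("inductor")
#   compiled_model = torch.compile(model, backend=backend_fn)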
46
deepspeed/compile/config.py
Normal file
@ -0,0 +1,46 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
# DeepSpeed Team
|
||||
|
||||
from deepspeed.runtime.config_utils import DeepSpeedConfigModel
|
||||
|
||||
|
||||
class CompileConfig(DeepSpeedConfigModel):
|
||||
""" Configure compile settings """
|
||||
|
||||
deepcompile: bool = False
|
||||
""" Turn on/off the DeepCompile mode """
|
||||
|
||||
free_activation: bool = False
|
||||
""" Turn on/off the free activation mode """
|
||||
|
||||
offload_activation: bool = False
|
||||
""" Turn on/off the activation offloading """
|
||||
|
||||
offload_opt_states: bool = False
|
||||
""" Turn on/off the optimizer states offloading """
|
||||
|
||||
double_buffer: bool = True
|
||||
""" Turn on/off the double buffering """
|
||||
|
||||
symmetric_memory: bool = False
|
||||
""" Turn on/off the symmetric memory """
|
||||
|
||||
debug_log: bool = False
|
||||
""" Turn on/off the graph dumping """
|
||||
|
||||
offload_parameters: bool = False
|
||||
""" Turn on/off the parameter offloading """
|
||||
|
||||
sync_before_reduce: bool = False
|
||||
""" Turn on/off the sync before reduce """
|
||||
|
||||
sync_after_reduce: bool = False
|
||||
""" Turn on/off the sync after reduce """
|
||||
|
||||
sync_before_allgather: bool = False
|
||||
""" Turn on/off the sync before allgather """
|
||||
|
||||
sync_after_allgather: bool = False
|
||||
""" Turn on/off the sync after allgather """
|
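# Illustrative DeepSpeed config snippet exercising these fields (assuming this model is
# exposed under the "compile" section of the DeepSpeed config; values shown are examples):
#
#   ds_config = {
#       "zero_optimization": {"stage": 3},
#       "compile": {
#           "deepcompile": True,
#           "offload_activation": False,
#           "double_buffer": True,
#       },
#   }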
139
deepspeed/compile/fx.py
Normal file
@ -0,0 +1,139 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
# DeepSpeed Team
|
||||
|
||||
from typing import Callable, Any, List
|
||||
from collections import defaultdict
|
||||
|
||||
import torch
|
||||
from torch.fx import Node, Graph
|
||||
|
||||
from .util import get_last_uses
|
||||
|
||||
|
||||
def get_output_node(graph: Graph):
|
||||
for v in graph.nodes:
|
||||
if v.target == "output":
|
||||
return v
|
||||
raise ValueError("No output node found")
|
||||
|
||||
|
||||
def move_primals_to_head(graph: Graph):
|
||||
|
||||
# Move primals to the head of the graph
|
||||
primals = [n for n in graph.nodes if n.op == "placeholder"]
|
||||
non_primals = [n for n in graph.nodes if n.op != "placeholder"]
|
||||
all_nodes = primals + non_primals
|
||||
|
||||
new_graph = Graph()
|
||||
env = {}
|
||||
for node in all_nodes:
|
||||
new_node = new_graph.node_copy(node, lambda n: env[n.name])
|
||||
env[node.name] = new_node
|
||||
new_graph.lint()
|
||||
|
||||
return new_graph
|
||||
|
||||
|
||||
def add_args_process(graph: Graph,
|
||||
node: Node,
|
||||
fn: Callable[..., Any],
|
||||
extra_args: List[int] = [],
|
||||
name=None,
|
||||
meta={}) -> List[Node]:
|
||||
# Apply fn to all args of node
|
||||
new_nodes = []
|
||||
with graph.inserting_before(node):
|
||||
target_args = [arg for arg in node.args if isinstance(arg, Node)]
|
||||
|
||||
for arg in target_args:
|
||||
new_node = graph.create_node('call_function', fn, (arg, ) + tuple(extra_args), name=name)
|
||||
for k, v in meta.items():
|
||||
new_node.meta[k] = v
|
||||
node.replace_input_with(arg, new_node)
|
||||
new_nodes.append(new_node)
|
||||
|
||||
return new_nodes
|
||||
|
||||
|
||||
def add_postprocess(graph: Graph,
|
||||
node: Node,
|
||||
fn: Callable[..., Any],
|
||||
extra_args: List[int] = [],
|
||||
name=None,
|
||||
meta={}) -> Node:
|
||||
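# Insert a call_function node fn(node, *extra_args) right after `node` and reroute all
# of node's existing users to consume the new node instead.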
# https://github.com/pytorch/examples/blob/main/fx/wrap_output_dynamically.py
|
||||
with graph.inserting_after(node):
|
||||
args = (node, )
|
||||
for a in extra_args: # To add ds_id
|
||||
args += (a, )
|
||||
|
||||
node_users = node.users.keys()
|
||||
new_node = graph.create_node('call_function', fn, args, {}, name=name)
|
||||
users = {}
|
||||
for u in node_users:
|
||||
if u != new_node:
|
||||
users[u] = (node, new_node)
|
||||
for u, (old_in, new_in) in users.items():
|
||||
u.replace_input_with(old_in, new_in)
|
||||
|
||||
for k, v in meta.items():
|
||||
new_node.meta[k] = v
|
||||
|
||||
return new_node
|
||||
|
||||
|
||||
def _make_node_meta(node: Node, ds_id: int, comm: bool):
|
||||
meta = {"param_name": node.name, "ds_id": ds_id, "comm": comm}
|
||||
if "tensor_meta" in node.meta:
|
||||
meta["tensor_meta"] = node.meta["tensor_meta"]
|
||||
return meta
|
||||
|
||||
|
||||
def add_free_activations(graph_id: int, graph: Graph, activation_node_names: List[str]):
|
||||
node_to_last_use, _ = get_last_uses(graph)
|
||||
activation_nodes_set = set([n for n in graph.nodes if n.op == "placeholder" and n.name in activation_node_names])
|
||||
|
||||
offload_id_to_node = {}
|
||||
node_to_wait_reload = {}
|
||||
for node in graph.nodes:
|
||||
if node.target == torch.ops.dc.reload_tensor.default:
|
||||
offload_act = node.args[0]
|
||||
# node_to_offload_id[offload_act] = node.args[2]
|
||||
offload_id_to_node[node.args[2]] = offload_act
|
||||
elif node.target == torch.ops.dc.wait_reload.default:
|
||||
offload_id = node.args[2]
|
||||
node_to_wait_reload[offload_id_to_node[offload_id]] = node
|
||||
|
||||
activation_nodes_set = set(node_to_wait_reload[n] if n in node_to_wait_reload else n for n in activation_nodes_set)
|
||||
|
||||
last_user_to_uses = defaultdict(list)
|
||||
for node, last_user in node_to_last_use.items():
|
||||
last_user_to_uses[last_user].append(node)
|
||||
|
||||
def _should_free(node: Node) -> bool:
|
||||
if not hasattr(node, "meta"):
|
||||
return False
|
||||
if not "tensor_meta" in node.meta:
|
||||
return False
|
||||
return True
|
||||
|
||||
def free_tensors(tensors: List[torch.Tensor]):
|
||||
for a in tensors:
|
||||
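# Only free large activations (> 10M elements); replacing .data with an empty tensor
# releases the underlying storage while keeping the Python object alive.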
if a.numel() > 10_000_000:
|
||||
a.data = torch.empty([0], device=a.device, dtype=a.dtype)
|
||||
|
||||
for last_user, used_nodes in last_user_to_uses.items():
|
||||
activation_args = [an for an in used_nodes if an in activation_nodes_set and _should_free(an)]
|
||||
|
||||
if len(activation_args) == 0:
|
||||
continue
|
||||
|
||||
node_name = f"free_activations_{[n.name for n in used_nodes]}"
|
||||
with graph.inserting_after(last_user):
|
||||
args = (activation_args, )
|
||||
graph.create_node('call_function', torch.ops.dc.free_tensors.default, args, {}, name=node_name)
|
||||
|
||||
# Python version for debugging
|
||||
# graph.create_node('call_function', free_tensors, args, {}, name=node_name)
|
84
deepspeed/compile/graph_param.py
Normal file
@ -0,0 +1,84 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
# DeepSpeed Team
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Dict, List, Tuple
|
||||
from functools import reduce
|
||||
|
||||
import torch
|
||||
from torch.fx import Graph, Node
|
||||
|
||||
from .fx import get_output_node
|
||||
from .util import get_param_nodes
|
||||
|
||||
|
||||
@dataclass
|
||||
class DSGraphParam:
|
||||
name: str
|
||||
shape: torch.Size
|
||||
dtype: torch.dtype
|
||||
device: torch.device
|
||||
node: Node
|
||||
allgather_node: Node
|
||||
release_node: Node
|
||||
param: torch.Tensor
|
||||
numel: int = field(init=False)
|
||||
|
||||
def __post_init__(self):
|
||||
self.numel = reduce(lambda x, y: x * y, self.shape)
|
||||
|
||||
|
||||
class DSGraphParamManager:
|
||||
|
||||
def __init__(self, fw_graph: Graph, sample_inputs: Any, index_to_ds_ids: List[Tuple[int, int, int]]):
|
||||
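# index_to_ds_ids holds (input index, ds_id, ds_shape) tuples collected when the graph
# was captured (see param_indices in backend.py).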
self._fw_graph = fw_graph
|
||||
self._bw_graph = None
|
||||
self._params: Dict[str, DSGraphParam] = {}
|
||||
self._param_name_to_grad: Dict[str, Node] = {}
|
||||
self._ds_ids: Dict[str, int] = {}
|
||||
|
||||
param_nodes = get_param_nodes(fw_graph, index_to_ds_ids)
|
||||
self._param_names = [pn.name for pn in param_nodes]
|
||||
self._param_indices = [i for i, _, _ in index_to_ds_ids]
|
||||
|
||||
param_inputs = [sample_inputs[i] for i, _, _ in index_to_ds_ids]
|
||||
ds_ids = [ds_id for _, ds_id, _ in index_to_ds_ids]
|
||||
ds_shapes = [ds_shape for _, _, ds_shape in index_to_ds_ids]
|
||||
|
||||
for pn, pi, ds_id, ds_shape in zip(param_nodes, param_inputs, ds_ids, ds_shapes):
|
||||
self._params[pn.name] = DSGraphParam(name=pn.name,
|
||||
shape=ds_shape,
|
||||
dtype=pi.dtype,
|
||||
device=pi.device,
|
||||
node=pn,
|
||||
allgather_node=None,
|
||||
release_node=None,
|
||||
param=pi)
|
||||
self._ds_ids[pn.name] = ds_id
|
||||
|
||||
def get_bwd_mapping(self, bw_graph: Graph):
|
||||
self._bw_graph = bw_graph
|
||||
|
||||
output_node = get_output_node(bw_graph)
|
||||
param_nodes_bw = [n for n in self._bw_graph.nodes if n.name in self.param_names]
|
||||
grad_outputs = [output_node.args[0][i] for i in self._param_indices]
|
||||
param_name_to_grad = {param_name: grad for param_name, grad in zip(self.param_names, grad_outputs)}
|
||||
return param_nodes_bw, param_name_to_grad
|
||||
|
||||
@property
|
||||
def param_names(self) -> List[str]:
|
||||
return self._param_names
|
||||
|
||||
@property
|
||||
def params(self) -> Dict[str, DSGraphParam]:
|
||||
return self._params
|
||||
|
||||
@property
|
||||
def ds_ids(self) -> Dict[str, int]:
|
||||
return self._ds_ids
|
||||
|
||||
def get_grad_name(self, param_name) -> str:
|
||||
assert self._param_name_to_grad is not None, "Backward graph is not added yet"
|
||||
return self._param_name_to_grad[param_name]
|
214
deepspeed/compile/inductor.py
Normal file
@ -0,0 +1,214 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
# DeepSpeed Team
|
||||
|
||||
import torch
|
||||
|
||||
try:
|
||||
import torch.utils._pytree as pytree
|
||||
from torch._functorch.aot_autograd import create_aot_dispatcher_function
|
||||
from torch._inductor.lowering import register_lowering, fallbacks, add_needs_realized_inputs
|
||||
from torch._inductor.ir import TensorBox, FallbackKernel, Layout, IRNode
|
||||
from torch._inductor.virtualized import V
|
||||
from torch._inductor.scheduler import Scheduler
|
||||
|
||||
original_create_aot_dispatcher_function = create_aot_dispatcher_function
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
from .util import get_input_nodes
|
||||
from .graph_param import DSGraphParamManager
|
||||
|
||||
|
||||
def patch_compiler(original_compiler, dc_compiler, z3_partition: bool, graph_id, graph_param_manager, bwd: bool):
|
||||
|
||||
def wrapped_compiler(gm, fake_inputs):
|
||||
mod_graph = dc_compiler(gm, fake_inputs)
|
||||
|
||||
# For symint case
|
||||
if mod_graph is None:
|
||||
return None
|
||||
|
||||
if z3_partition:
|
||||
# Inductor validates input sizes estimated by the first trace, where the ds tensor is materialized.
# We need to patch the input tensors to avoid that validation error.
|
||||
patched_inputs = []
|
||||
if bwd:
|
||||
param_nodes_bw, _ = graph_param_manager[graph_id].get_bwd_mapping(gm.graph)
|
||||
param_names = [n.name for n in param_nodes_bw]
|
||||
else:
|
||||
param_names = graph_param_manager[graph_id].param_names
|
||||
input_nodes = get_input_nodes(gm.graph)
|
||||
|
||||
for in_node, in_v in zip(input_nodes, fake_inputs):
|
||||
ds_param = in_node.name in param_names
|
||||
if ds_param:
|
||||
from torch._subclasses.fake_tensor import is_fake
|
||||
from torch._dynamo.utils import to_fake_tensor
|
||||
assert is_fake(in_v), f"Input {in_v} should be fake tensor"
|
||||
patched_inputs.append(
|
||||
to_fake_tensor(torch.empty([0], dtype=in_v.dtype, device=in_v.device), in_v.fake_mode))
|
||||
else:
|
||||
patched_inputs.append(in_v)
|
||||
|
||||
patched_inputs = tuple(patched_inputs)
|
||||
else:
|
||||
patched_inputs = fake_inputs
|
||||
|
||||
return original_compiler(gm, patched_inputs)
|
||||
|
||||
return wrapped_compiler
|
||||
|
||||
|
||||
def wrap_partition_fn(partition_fn, real_inputs, param_indices):
|
||||
|
||||
def wrapped_partition_fn(*args, **kwargs):
|
||||
|
||||
fw_module, bw_module = partition_fn(*args, **kwargs)
|
||||
|
||||
# get parameter names
|
||||
pm = DSGraphParamManager(fw_module.graph, real_inputs, param_indices)
|
||||
|
||||
def fix_placeholder_meta(graph):
|
||||
for n in graph.nodes:
|
||||
if n.op == "placeholder" and n.name in pm.param_names:
|
||||
n.meta["val"] = torch.empty([0], dtype=n.meta["val"].dtype, device=n.meta["val"].device)
|
||||
|
||||
fix_placeholder_meta(fw_module.graph)
|
||||
fix_placeholder_meta(bw_module.graph)
|
||||
|
||||
return fw_module, bw_module
|
||||
|
||||
return wrapped_partition_fn
|
||||
|
||||
|
||||
def patch_create_aot_dispatcher_function(graph_id: int, z3_partition: bool, make_fw_graph, make_bw_graph, real_inputs,
|
||||
param_indices, param_manager):
|
||||
|
||||
from torch._dynamo.backends.common import AotAutograd
|
||||
import functools
|
||||
|
||||
def patch_aotautograd():
|
||||
# Unpatch if it was already patched
|
||||
if hasattr(AotAutograd, "__original_init"):
|
||||
AotAutograd.__init__ = AotAutograd.__original_init
|
||||
|
||||
original_init = AotAutograd.__init__
|
||||
|
||||
@functools.wraps(original_init)
|
||||
def patched_init(self, **kwargs):
|
||||
kwargs["fw_compiler"] = patch_compiler(kwargs["fw_compiler"],
|
||||
make_fw_graph,
|
||||
z3_partition,
|
||||
graph_id,
|
||||
param_manager,
|
||||
bwd=False)
|
||||
kwargs["bw_compiler"] = patch_compiler(kwargs["bw_compiler"],
|
||||
make_bw_graph,
|
||||
z3_partition,
|
||||
graph_id,
|
||||
param_manager,
|
||||
bwd=True)
|
||||
kwargs["inference_compiler"] = kwargs["fw_compiler"]
|
||||
|
||||
if z3_partition:
|
||||
kwargs["partition_fn"] = wrap_partition_fn(kwargs["partition_fn"], real_inputs, param_indices)
|
||||
|
||||
original_init(self, **kwargs)
|
||||
|
||||
AotAutograd.__original_init = original_init
|
||||
AotAutograd.__init__ = patched_init
|
||||
|
||||
patch_aotautograd()
|
||||
|
||||
|
||||
def register_custom_ops():
|
||||
|
||||
def fallback_handler_no_reuse(kernel,
|
||||
never_reuse_input,
|
||||
never_reuse_output,
|
||||
force_free_input,
|
||||
add_to_fallback_set=True):
|
||||
if add_to_fallback_set:
|
||||
fallbacks.add(kernel)
|
||||
|
||||
def handler(*args, **kwargs):
|
||||
|
||||
def wrap_tensors(x):
|
||||
out = TensorBox.create(x) if isinstance(x, torch._inductor.ir.IRNode) else x
|
||||
if out is not None and never_reuse_output:
|
||||
V.graph.never_reuse_buffers.add(out.get_name())
|
||||
return out
|
||||
|
||||
class CustomDCKernel(FallbackKernel):
|
||||
|
||||
def __init__(self, op, *args, **kwargs):
|
||||
super().__init__(op, *args, **kwargs)
|
||||
|
||||
def add_to_never_reuse(x):
|
||||
if isinstance(x, IRNode):
|
||||
assert hasattr(x, "get_name"), f"x doesn't have get_name {x.__class__}"
|
||||
V.graph.never_reuse_buffers.add(x.get_name())
|
||||
|
||||
if never_reuse_input:
|
||||
pytree.tree_map(add_to_never_reuse, args)
|
||||
|
||||
def get_var_name_for_arg(self, arg: str):
|
||||
if arg.isidentifier():
|
||||
return arg
|
||||
|
||||
import re
|
||||
match = re.match(r"reinterpret_tensor\((\w+),", arg)
|
||||
if match:
|
||||
return match.group(1)
|
||||
return None
|
||||
|
||||
def codegen(self, wrapper):
|
||||
if not force_free_input:
|
||||
return super().codegen(wrapper)
|
||||
|
||||
kernel = self.op_overload
|
||||
self.codegen_comment(wrapper)
|
||||
args = [*self.codegen_args(), *self.codegen_kwargs()]
|
||||
|
||||
V.graph.wrapper_code.generate_fallback_kernel(self, args)
|
||||
if isinstance(self.layout, Layout):
|
||||
self.codegen_size_asserts(wrapper)
|
||||
|
||||
var_name = self.get_var_name_for_arg(args[0])
|
||||
if var_name:
|
||||
wrapper.writeline(f"{var_name} = None")
|
||||
|
||||
self.codegen_unbacked_symbol_defs(wrapper)
|
||||
|
||||
kernel_cls = CustomDCKernel if force_free_input else FallbackKernel
|
||||
return pytree.tree_map(wrap_tensors, kernel_cls.create(kernel, *args, **kwargs))
|
||||
|
||||
return handler
|
||||
|
||||
def register_fallback_no_reuse(op_overload,
|
||||
never_reuse_input=False,
|
||||
never_reuse_output=False,
|
||||
force_free_input=False):
|
||||
add_needs_realized_inputs(op_overload)
|
||||
return register_lowering(op_overload, type_promotion_kind=None)(fallback_handler_no_reuse(
|
||||
op_overload,
|
||||
never_reuse_input=never_reuse_input,
|
||||
never_reuse_output=never_reuse_output,
|
||||
force_free_input=force_free_input))
|
||||
|
||||
# Inductor tries to reuse output buffers when possible. We need to disable this behavior for some custom ops.
# -> It seems that the memory region is still reused in some cases, so we clone the inputs for some ops.
|
||||
register_fallback_no_reuse(torch.ops.dc.allgather_param.default, never_reuse_input=False, never_reuse_output=True)
|
||||
register_fallback_no_reuse(torch.ops.dc.wait_allgather.default, never_reuse_input=True, never_reuse_output=True)
|
||||
register_fallback_no_reuse(torch.ops.dc.release_param.default, never_reuse_input=True, never_reuse_output=False)
|
||||
register_fallback_no_reuse(torch.ops.dc.reduce_grad.default,
|
||||
never_reuse_input=True,
|
||||
never_reuse_output=True,
|
||||
force_free_input=True)
|
||||
register_fallback_no_reuse(torch.ops.dc.free_tensors.default, never_reuse_input=True, never_reuse_output=True)
|
||||
|
||||
if not hasattr(Scheduler, "is_dc_patched") or not Scheduler.is_dc_patched:
|
||||
Scheduler.is_dc_patched = True
|
||||
Scheduler.dead_node_elimination = lambda _: None
|
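Note (not part of the diff): any additional DeepCompile custom op whose buffers are produced or consumed asynchronously would be registered through the same helper. The op name below is hypothetical; it only illustrates the pattern established above.

    # Hedged sketch: keep a made-up async op's buffers out of Inductor's reuse pool.
    # `torch.ops.dc.my_async_op` does not exist in this PR and stands in for any new op.
    register_fallback_no_reuse(torch.ops.dc.my_async_op.default,
                               never_reuse_input=True,
                               never_reuse_output=True)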
82
deepspeed/compile/init_z1.py
Normal file
@ -0,0 +1,82 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import copy

import torch

from deepspeed.accelerator import get_accelerator
from .passes import zero1_compile, zero3_compile
from .backend import make_backend, launch_compile_passes, init_schedule
from .util import get_deepcompile_handle, add_pre_backward_hook, is_backend_inductor

WARMUP = 5


def init_z1(engine, backend, compile_config, compile_kwargs, schedule=None):

    optimizer = engine.optimizer
    optimizer.contiguous_gradients = False  # Avoid creating an unnecessary buffer
    for hook in optimizer._grad_acc_hooks:
        hook.remove()
    optimizer._grad_acc_hooks.clear()

    dc = get_deepcompile_handle()
    dc.init(engine.data_parallel_group,
            engine.zero_reduce_bucket_size(), compile_config.double_buffer, compile_config.symmetric_memory,
            is_backend_inductor(backend), compile_config.sync_before_reduce, compile_config.sync_after_reduce, False,
            False)

    grad_buffer = {}

    for i, group in enumerate(optimizer.bit16_groups):

        grad_buffer[i] = optimizer.get_flat_partition(optimizer.params_in_partition[i],
                                                      optimizer.first_offset[i],
                                                      optimizer.partition_size[i],
                                                      dtype=optimizer.gradient_accumulation_dtype,
                                                      device=get_accelerator().current_device_name(),
                                                      return_tensor_list=True)
        grad_buffer[i] = [p.clone().detach() for p in grad_buffer[i]]  # Maybe not necessary

        index_in_partition = 0
        first_in_partition = True
        for p in group:
            param_id = optimizer.get_param_id(p)
            p.param_id = param_id
            in_partition = optimizer.is_param_in_current_partition[param_id]

            if in_partition:
                buf = grad_buffer[i][index_in_partition]
                offset = optimizer.first_offset[i] if first_in_partition else 0
                # print(f"[r{dist.get_rank()}] Registering group {i} param {param_id} in_partition={in_partition} p={p.shape} buf={buf.shape} partition_offset={offset}")
                dc.register_z1_param(p.param_id, p.shape, p, buf, int(offset))
                index_in_partition += 1
                first_in_partition = False
            else:
                # print(f"[r{dist.get_rank()}] Registering group {i} param {param_id} in_partition={in_partition} p={p.shape} buf=None")
                dc.register_z1_param(p.param_id, p.shape, p, torch.empty([0], dtype=p.dtype, device=p.device), 0)

    def set_grad_buffer():
        optimizer.averaged_gradients = copy.copy(grad_buffer)

    add_pre_backward_hook(set_grad_buffer)

    if schedule is None:
        schedule = []
        schedule.append((0, [zero1_compile.add_z1_reduce]))
    else:
        for opt in schedule:
            # avoid a typical misconfiguration
            if zero3_compile.add_z3_gather_release in opt[1]:
                raise ValueError("A pass for ZeRO-3 is specified even though ZeRO-1 is enabled")

    init_schedule(schedule)

    engine.launch_compile_passes = launch_compile_passes
    return make_backend(backend,
                        compile_kwargs=compile_kwargs,
                        free_activation=False,
                        debug_log=compile_config.debug_log)
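Note (not part of the diff): a minimal sketch of how the returned backend could be handed to torch.compile. The wiring shown here is an assumption for illustration; in DeepSpeed the engine performs this hookup internally, and `engine`, `compile_config`, and `compile_kwargs` are assumed to be prepared by the usual initialization.

    # Hedged sketch: init_z1 returns a Dynamo backend callable built by make_backend.
    backend_fn = init_z1(engine, "inductor", compile_config, compile_kwargs)
    compiled_module = torch.compile(engine.module, backend=backend_fn)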
94
deepspeed/compile/init_z3.py
Normal file
@ -0,0 +1,94 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import torch

from deepspeed import comm as dist
from deepspeed.accelerator import get_accelerator
from deepspeed.runtime.zero.partition_parameters import InsertPostInitMethodToModuleSubClasses

from .passes import zero3_compile, prefetch, selective_gather, offload_parameters
from .backend import make_backend, launch_compile_passes, init_schedule
from .patch_fake_tensor import patch_fake_tensor
from .util import get_deepcompile_handle, add_pre_backward_hook, is_backend_inductor

WARMUP = 5


def init_z3(engine, backend, compile_config, compile_kwargs, schedule=None):

    optimizer = engine.optimizer
    if optimizer is not None and hasattr(optimizer, '_DeepSpeedZeroOptimizer_Stage3__ipg_bucket_flat_buffer'):
        optimizer._DeepSpeedZeroOptimizer_Stage3__ipg_bucket_flat_buffer = None
        get_accelerator().empty_cache()

    dc = get_deepcompile_handle()
    dc.init(engine.data_parallel_group,
            engine.zero_reduce_bucket_size(), compile_config.double_buffer, compile_config.symmetric_memory,
            is_backend_inductor(backend), compile_config.sync_before_reduce, compile_config.sync_after_reduce,
            compile_config.sync_before_allgather, compile_config.sync_after_allgather)

    # Unset hooks
    for m in engine.module.modules():
        m._parameters = m._original_parameters
    optimizer.parameter_offload._remove_module_hooks()

    for hook in optimizer._grad_acc_hooks:
        hook.remove()
    optimizer._grad_acc_hooks.clear()

    # Unpatch linear
    if hasattr(InsertPostInitMethodToModuleSubClasses, "linear_bk"):
        torch.nn.functional.linear = InsertPostInitMethodToModuleSubClasses.linear_bk

    if compile_config.symmetric_memory:
        group_name = engine.data_parallel_group.group_name
        dist.enable_symm_mem_for_group(group_name)

    for p in engine.module.parameters():
        grad_buffer = optimizer._DeepSpeedZeroOptimizer_Stage3__param_id_to_grad_partition[p.ds_id]

        # Disable persistent param
        p.ds_persist = False
        dc.register_z3_param(p.ds_id, p.ds_shape, p.ds_tensor, grad_buffer, p.ds_persist)

    def set_grad_buffer():
        for i, sub_group in enumerate(optimizer.fp16_groups):
            optimizer.averaged_gradients[i] = [
                optimizer._DeepSpeedZeroOptimizer_Stage3__param_id_to_grad_partition[param.ds_id]
                if param.requires_grad else torch.zeros_like(param.ds_tensor) for param in sub_group
            ]

    add_pre_backward_hook(set_grad_buffer)

    if schedule is None:
        schedule = []
        if compile_config.offload_parameters:
            schedule.append((0, [zero3_compile.add_z3_gather_release, offload_parameters.offload_parameter_fwd]))
        else:
            schedule.append((0, [zero3_compile.add_z3_gather_release]))
        schedule.append(
            (WARMUP,
             [zero3_compile.add_z3_gather_release, prefetch.schedule_prefetch, selective_gather.selective_gather]))

    init_schedule(schedule)

    # Offloading optimizer states needs additional setup
    from .passes.offload_adam_states import move_opt_states, move_opt_states_sync, init_offload_opt_states
    for _, passes in schedule:
        if move_opt_states in passes or move_opt_states_sync in passes:
            init_offload_opt_states(optimizer, dc)

    engine.launch_compile_passes = launch_compile_passes

    patch_fake_tensor()
    free_activation = compile_config.free_activation and not is_backend_inductor(backend)

    torch._inductor.config.size_asserts = False

    return make_backend(backend,
                        compile_kwargs=compile_kwargs,
                        free_activation=free_activation,
                        debug_log=compile_config.debug_log)
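Note (not part of the diff): the schedule passed to init_schedule is a list of (warmup_step, [passes]) pairs, as the default above shows. A custom schedule would follow the same shape; which passes are safe to combine is an assumption based on the defaults.

    # Hedged sketch of a custom ZeRO-3 pass schedule:
    # from step 0, only insert gather/release ops; after WARMUP steps, also prefetch.
    custom_schedule = [
        (0, [zero3_compile.add_z3_gather_release]),
        (WARMUP, [zero3_compile.add_z3_gather_release, prefetch.schedule_prefetch]),
    ]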
431
deepspeed/compile/list_schedule.py
Normal file
@ -0,0 +1,431 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
# DeepSpeed Team
|
||||
|
||||
from collections import defaultdict
|
||||
from typing import List, Dict
|
||||
from copy import copy
|
||||
from dataclasses import dataclass
|
||||
|
||||
import torch
|
||||
from torch.fx import Graph, Node
|
||||
from torch.fx.node import map_arg
|
||||
|
||||
try:
|
||||
from torch.utils._pytree import tree_iter
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
from .util import get_last_uses, is_release_node
|
||||
from .fx import get_output_node
|
||||
|
||||
|
||||
def make_graph_from_schedule(scheduled: List[Node]):
|
||||
new_graph = Graph()
|
||||
env = {}
|
||||
for node in scheduled:
|
||||
new_node = new_graph.node_copy(node, lambda n: env[n.name])
|
||||
env[node.name] = new_node
|
||||
|
||||
return new_graph
|
||||
|
||||
|
||||
def get_original_args_num(node: Node):
|
||||
if node.name.startswith("allgather_ds_param") \
|
||||
or node.name.startswith("release_ds_param") \
|
||||
or node.name.startswith("wait_allgather_ds_param") \
|
||||
or node.name.startswith("reduce_ds_param"):
|
||||
return 1
|
||||
|
||||
return len(node.args)
|
||||
|
||||
|
||||
def flat_nodes_in_args(args: List[Node]):
|
||||
return [a for a in tree_iter(args) if isinstance(a, Node)]
|
||||
|
||||
|
||||
def filter_args(node: Node):
|
||||
args = node.args[:get_original_args_num(node)]
|
||||
return flat_nodes_in_args(args)
|
||||
|
||||
|
||||
def init_schedule(graph: Graph):
|
||||
mem_table = create_mem_table(graph)
|
||||
remaining_users = defaultdict(set)
|
||||
user_to_producer = {}
|
||||
|
||||
scheduled = []
|
||||
unscheduled = []
|
||||
edges = defaultdict(list)
|
||||
for node in graph.nodes:
|
||||
filtered_args = filter_args(node)
|
||||
# print(f"Node: {node} args: {node.args}")
|
||||
if len(filtered_args) == 0:
|
||||
scheduled.append(node)
|
||||
|
||||
remaining_users[node] = set(node.users.keys())
|
||||
for user in node.users.keys():
|
||||
user_to_producer[user] = node
|
||||
else:
|
||||
unscheduled.append(node)
|
||||
for a in filtered_args:
|
||||
for elem_a in tree_iter(a):
|
||||
if isinstance(elem_a, Node):
|
||||
if node not in edges[elem_a]:
|
||||
edges[elem_a].append(node)
|
||||
|
||||
return scheduled, unscheduled, edges, mem_table, remaining_users, user_to_producer
|
||||
|
||||
|
||||
def get_runnable_nodes(scheduled: List[Node], unscheduled: List[Node]):
|
||||
scheduled = set(scheduled)
|
||||
return [node for node in unscheduled if all(arg in scheduled for arg in filter_args(node))]
|
||||
|
||||
|
||||
def choose_next_node(scheduled: List[Node], unscheduled: List[Node], mem_table: Dict[str, int]):
|
||||
runnable_nodes = get_runnable_nodes(scheduled, unscheduled)
|
||||
|
||||
# sort by memory usage
|
||||
runnable_nodes = sorted(runnable_nodes, key=lambda n: mem_table[n.name])
|
||||
return runnable_nodes[0]
|
||||
|
||||
|
||||
def create_mem_table(graph: Graph) -> Dict[str, int]:
|
||||
mem_table = {}
|
||||
for node in graph.nodes:
|
||||
if node.name.startswith("allgather_ds_param"):
|
||||
mem_table[node.name] = node.meta["tensor_size"]
|
||||
elif node.name.startswith("release_ds_param") or node.name.startswith("reduce_ds_param"):
|
||||
mem_table[node.name] = -node.meta["tensor_size"]
|
||||
else:
|
||||
mem_table[node.name] = 0
|
||||
|
||||
return mem_table
|
||||
|
||||
|
||||
def list_schedule(graph: Graph) -> Graph:
|
||||
|
||||
scheduled, unscheduled, mem_table = init_schedule(graph)
|
||||
|
||||
while len(unscheduled) > 0:
|
||||
next_node = choose_next_node(scheduled, unscheduled, mem_table)
|
||||
scheduled.append(next_node)
|
||||
unscheduled.remove(next_node)
|
||||
|
||||
return make_graph_from_schedule(scheduled)
|
||||
|
||||
|
||||
###############################
|
||||
|
||||
|
||||
def get_new_runnable_nodes_with(scheduled: List[Node], edges: Dict[Node, List[Node]], new_scheduled: Node):
|
||||
scheduled = set(scheduled)
|
||||
new_runnables = []
|
||||
for node in edges[new_scheduled]:
|
||||
if all(arg in scheduled for arg in filter_args(node) if arg != new_scheduled):
|
||||
new_runnables.append(node)
|
||||
|
||||
return new_runnables
|
||||
|
||||
|
||||
def _do_schedule_without_allgather(scheduled: List[Node], unscheduled: List[Node], edges: Dict[Node, List[Node]],
|
||||
non_ag_runnable: List[Node]):
|
||||
|
||||
while len(non_ag_runnable) > 0:
|
||||
next_node = non_ag_runnable.pop()
|
||||
|
||||
new_runnables = get_new_runnable_nodes_with(scheduled, edges, next_node)
|
||||
non_ag_runnable += [n for n in new_runnables if not n.name.startswith("allgather_ds_param")]
|
||||
|
||||
scheduled.append(next_node)
|
||||
unscheduled.remove(next_node)
|
||||
|
||||
return scheduled, unscheduled
|
||||
|
||||
|
||||
def schedule_without_allgather(scheduled: List[Node], unscheduled: List[Node], edges: Dict[Node, List[Node]]):
|
||||
runnable = get_runnable_nodes(scheduled, unscheduled)
|
||||
non_ag_runnable = [n for n in runnable if not n.name.startswith("allgather_ds_param")]
|
||||
|
||||
tmp_scheduled = copy(scheduled)
|
||||
tmp_unscheduled = copy(unscheduled)
|
||||
|
||||
return _do_schedule_without_allgather(tmp_scheduled, tmp_unscheduled, edges, non_ag_runnable)
|
||||
|
||||
|
||||
def try_schedule_with_new_allgather(scheduled: List[Node], unscheduled: List[Node], edges: Dict[Node, List[Node]],
|
||||
new_scheduled: Node):
|
||||
new_runnables = get_new_runnable_nodes_with(scheduled, edges, new_scheduled)
|
||||
non_ag_runnable = [n for n in new_runnables if not n.name.startswith("allgather_ds_param")]
|
||||
|
||||
tmp_scheduled = copy(scheduled)
|
||||
tmp_unscheduled = copy(unscheduled)
|
||||
|
||||
tmp_scheduled.append(new_scheduled)
|
||||
tmp_unscheduled.remove(new_scheduled)
|
||||
|
||||
return _do_schedule_without_allgather(tmp_scheduled, tmp_unscheduled, edges, non_ag_runnable)
|
||||
|
||||
|
||||
def simple_prefetch(graph: Graph, available_mem: int, output_size: int, debug_log: bool) -> Graph:
|
||||
|
||||
scheduled, unscheduled, edges, mem_table, remaining_users, user_to_producer = init_schedule(graph)
|
||||
tmp_scheduled, tmp_unscheduled = schedule_without_allgather(scheduled, unscheduled, edges)
|
||||
|
||||
while len(tmp_unscheduled) > 0:
|
||||
|
||||
runnable = get_runnable_nodes(tmp_scheduled, tmp_unscheduled)
|
||||
ag_with_unblock_time = []
|
||||
|
||||
for ag_node in runnable:
|
||||
ag_scheduled, ag_unscheduled = try_schedule_with_new_allgather(tmp_scheduled, tmp_unscheduled, edges,
|
||||
ag_node)
|
||||
unblock_time = sum(n.meta["device_time"] for n in ag_scheduled[len(tmp_scheduled) + 1:])
|
||||
ag_with_unblock_time.append((ag_node, unblock_time, ag_scheduled, ag_unscheduled))
|
||||
|
||||
ag_with_unblock_time = sorted(ag_with_unblock_time, key=lambda x: x[1], reverse=True)
|
||||
best_ag_node = ag_with_unblock_time[0][0]
|
||||
best_ag_scheduled = ag_with_unblock_time[0][2]
|
||||
|
||||
no_ag_runnables = tmp_scheduled[len(scheduled):]
|
||||
after_ag_runnables = best_ag_scheduled[len(tmp_scheduled) + 1:]
|
||||
|
||||
scheduled.append(best_ag_node)
|
||||
unscheduled.remove(best_ag_node)
|
||||
for n in no_ag_runnables:
|
||||
scheduled.append(n)
|
||||
unscheduled.remove(n)
|
||||
|
||||
tmp_scheduled = copy(scheduled)
|
||||
tmp_unscheduled = copy(unscheduled)
|
||||
for n in after_ag_runnables:
|
||||
tmp_scheduled.append(n)
|
||||
tmp_unscheduled.remove(n)
|
||||
|
||||
return make_graph_from_schedule(tmp_scheduled)
|
||||
|
||||
|
||||
###############################
|
||||
|
||||
|
||||
def init_schedule_with_placeholders(graph: Graph):
|
||||
mem_table = create_mem_table(graph)
|
||||
remaining_users = defaultdict(set)
|
||||
user_to_producer = {}
|
||||
|
||||
scheduled = []
|
||||
unscheduled = []
|
||||
edges = defaultdict(list)
|
||||
for node in graph.nodes:
|
||||
if node.op == 'placeholder':
|
||||
scheduled.append(node)
|
||||
|
||||
remaining_users[node] = set(node.users.keys())
|
||||
for user in node.users.keys():
|
||||
user_to_producer[user] = node
|
||||
else:
|
||||
unscheduled.append(node)
|
||||
|
||||
return scheduled, unscheduled, edges, mem_table, remaining_users, user_to_producer
|
||||
|
||||
|
||||
def get_node_requirements(target_node: Node, scheduled: List[Node]):
|
||||
scheduled = set(scheduled)
|
||||
visited = set()
|
||||
ordered_nodes = []
|
||||
|
||||
def dfs(node: Node):
|
||||
if node in scheduled:
|
||||
return
|
||||
if node in visited:
|
||||
return
|
||||
visited.add(node)
|
||||
|
||||
args = []
|
||||
|
||||
def register_arg(n: Node):
|
||||
args.append(n)
|
||||
|
||||
map_arg(node.args, register_arg)
|
||||
|
||||
for arg in args:
|
||||
dfs(arg)
|
||||
ordered_nodes.append(node)
|
||||
|
||||
dfs(target_node)
|
||||
|
||||
return ordered_nodes
|
||||
|
||||
|
||||
@dataclass
|
||||
class AllgatherTask:
|
||||
node: Node
|
||||
allgather_cost: float
|
||||
free_cost: float
|
||||
allgathered_mem: int
|
||||
allgather_acc_mem: int
|
||||
free_acc_mem: int
|
||||
last_use: Node
|
||||
n_scheduled_ags: int
|
||||
schedule_until_ag: List[Node]
|
||||
schedule_until_free: List[Node]
|
||||
|
||||
|
||||
def fast_free_schedule(graph: Graph, available_mem: int, output_size: int, debug_log: bool) -> Graph:
|
||||
node_to_last_use, user_to_last_uses = get_last_uses(graph)
|
||||
|
||||
# check tensor size
|
||||
for node in graph.nodes:
|
||||
if "tensor_size" not in node.meta:
|
||||
# Our profiler may not visit all nodes because of the control flow.
|
||||
node.meta["tensor_size"] = 0
|
||||
|
||||
scheduled, unscheduled, edges, mem_table, remaining_users, user_to_producer = init_schedule_with_placeholders(
|
||||
graph)
|
||||
|
||||
unscheduled_ags = [n for n in unscheduled if n.target == torch.ops.dc.allgather_param.default]
|
||||
|
||||
release_nodes = defaultdict(list)
|
||||
for n in unscheduled:
|
||||
if is_release_node(n):
|
||||
release_nodes[n.args[2]].append(n)
|
||||
|
||||
ag_nodes_in_path = {}
|
||||
for ag_node in unscheduled_ags:
|
||||
last_use = node_to_last_use[ag_node]
|
||||
required_nodes = get_node_requirements(last_use, scheduled)
|
||||
ag_nodes_in_path[ag_node] = set(n for n in required_nodes if n.target == torch.ops.dc.allgather_param.default)
|
||||
|
||||
reduce_nodes = [n for n in unscheduled if n.target == torch.ops.dc.reduce_grad.default]
|
||||
ag_nodes_in_path_to_reduce_nodes = {}
|
||||
for reduce_node in reduce_nodes:
|
||||
ag_nodes_in_path_to_reduce_nodes[reduce_node] = set(n for n in get_node_requirements(reduce_node, scheduled)
|
||||
if n.target == torch.ops.dc.allgather_param.default)
|
||||
|
||||
output_nodes = [
|
||||
n for n in get_output_node(graph).args[0]
|
||||
if isinstance(n, Node) and n.target != torch.ops.dc.reduce_grad.default
|
||||
]
|
||||
ag_nodes_in_path_to_output_nodes = {}
|
||||
for output_node in output_nodes:
|
||||
ag_nodes_in_path_to_output_nodes[output_node] = set(n for n in get_node_requirements(output_node, scheduled)
|
||||
if n.target == torch.ops.dc.allgather_param.default)
|
||||
|
||||
while len(unscheduled_ags) > 0:
|
||||
|
||||
ag_nodes_count = {ag_node: len(nodes) for ag_node, nodes in ag_nodes_in_path.items()}
|
||||
count_list = sorted(set(ag_nodes_count.values()))
|
||||
|
||||
runnable_ags = []
|
||||
for ag_count in count_list:
|
||||
|
||||
target_unscheduled_ags = [ag for ag in unscheduled_ags if ag_nodes_count[ag] == ag_count]
|
||||
|
||||
for node in target_unscheduled_ags:
|
||||
ds_id = node.args[2]
|
||||
|
||||
schedule_until_ag = get_node_requirements(node, scheduled)
|
||||
if schedule_until_ag is None:
|
||||
continue
|
||||
|
||||
last_use = node_to_last_use[node]
|
||||
|
||||
diff_required_nodes = get_node_requirements(last_use, scheduled + schedule_until_ag)
|
||||
|
||||
allgather_cost = sum(n.meta["device_time"] for n in schedule_until_ag)
|
||||
free_cost = sum(n.meta["device_time"] for n in diff_required_nodes)
|
||||
allgathered_mem = node.meta["tensor_size"]
|
||||
allgather_acc_mem = sum(n.meta["tensor_size"] for n in schedule_until_ag
|
||||
if n.target == torch.ops.dc.allgather_param.default)
|
||||
free_acc_mem = sum(n.meta["tensor_size"] for n in diff_required_nodes
|
||||
if n.target == torch.ops.dc.allgather_param.default)
|
||||
|
||||
schedule_until_free = schedule_until_ag + diff_required_nodes
|
||||
for release_node in release_nodes[ds_id]:
|
||||
if release_node not in schedule_until_free:
|
||||
schedule_until_free.append(release_node)
|
||||
|
||||
n_scheduled_ags = len(
|
||||
[n for n in schedule_until_free if n.target == torch.ops.dc.allgather_param.default])
|
||||
|
||||
task = AllgatherTask(node, allgather_cost, free_cost, allgathered_mem, allgather_acc_mem, free_acc_mem,
|
||||
last_use, n_scheduled_ags, schedule_until_ag, schedule_until_free)
|
||||
|
||||
# print(f" ag_count {ag_count} allgather runnable {i}: {node} last_use: {node_to_last_use[node]} t: {t2-t1:.2f}")
|
||||
runnable_ags.append(task)
|
||||
|
||||
if len(runnable_ags) > 0:
|
||||
break
|
||||
|
||||
assert len(runnable_ags) > 0, "No runnable allgather nodes"
|
||||
|
||||
# Criteria of the choice:
|
||||
# We want to choose allgather that does not require additional allgather until releasing the param.
|
||||
# When we can find such a node, free_acc_mem will be zero. In that case, we choose the one with the smallest cost until free to minimize the period of occupying memory for the gathered param.
|
||||
# If there is no such node, we choose the one with the smallest free_cost to minimize the period of occupying memory for the gathered param.
|
||||
ags_with_no_additional_ag = [ag for ag in runnable_ags if ag.free_acc_mem == 0]
|
||||
if len(ags_with_no_additional_ag) > 0:
|
||||
sorted_ags = sorted(runnable_ags, key=lambda x: x.free_cost)
|
||||
next_ag = sorted_ags[0]
|
||||
nodes_to_schedule = next_ag.schedule_until_free
|
||||
else:
|
||||
# sorted_ags = sorted(runnable_ags, key=lambda x: x.allgathered_mem)
|
||||
sorted_ags = sorted(runnable_ags, key=lambda x: x.free_acc_mem)
|
||||
next_ag = sorted_ags[0]
|
||||
nodes_to_schedule = next_ag.schedule_until_ag
|
||||
|
||||
# print(f" next_ag {next_ag}")
|
||||
for n in nodes_to_schedule:
|
||||
scheduled.append(n)
|
||||
unscheduled.remove(n)
|
||||
|
||||
unscheduled_ags.remove(next_ag.node)
|
||||
|
||||
ag_nodes_in_path.pop(next_ag.node)
|
||||
for ag_node, nodes in ag_nodes_in_path.items():
|
||||
if next_ag.node in nodes:
|
||||
nodes.remove(next_ag.node)
|
||||
|
||||
# Schedule reduce nodes when possible to free memory earlier
|
||||
reduces_to_schedule = []
|
||||
for reduce_node in reduce_nodes:
|
||||
if next_ag.node in ag_nodes_in_path_to_reduce_nodes[reduce_node]:
|
||||
ag_nodes_in_path_to_reduce_nodes[reduce_node].remove(next_ag.node)
|
||||
if len(ag_nodes_in_path_to_reduce_nodes[reduce_node]) == 0:
|
||||
reduces_to_schedule.append(reduce_node)
|
||||
|
||||
for n in reduces_to_schedule:
|
||||
need_to_schedule = get_node_requirements(n, scheduled)
|
||||
for nn in need_to_schedule:
|
||||
scheduled.append(nn)
|
||||
unscheduled.remove(nn)
|
||||
|
||||
# Do the same for output nodes
|
||||
outputs_to_schedule = []
|
||||
for output_node in output_nodes:
|
||||
if next_ag.node in ag_nodes_in_path_to_output_nodes[output_node]:
|
||||
ag_nodes_in_path_to_output_nodes[output_node].remove(next_ag.node)
|
||||
if len(ag_nodes_in_path_to_output_nodes[output_node]) == 0:
|
||||
outputs_to_schedule.append(output_node)
|
||||
|
||||
for n in outputs_to_schedule:
|
||||
need_to_schedule = get_node_requirements(n, scheduled)
|
||||
for nn in need_to_schedule:
|
||||
scheduled.append(nn)
|
||||
unscheduled.remove(nn)
|
||||
|
||||
# print(f"After ag scheduled: scheduled: {scheduled}")
|
||||
|
||||
scheduled_set = set(scheduled)
|
||||
for node in graph.nodes:
|
||||
if node in scheduled_set:
|
||||
continue
|
||||
scheduled.append(node)
|
||||
unscheduled.remove(node)
|
||||
|
||||
assert len(unscheduled) == 0, f"There are unscheduled nodes: {unscheduled}"
|
||||
|
||||
ret_graph = make_graph_from_schedule(scheduled)
|
||||
ret_graph.lint()
|
||||
return ret_graph
|
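Note (not part of the diff): the list-scheduling passes above reorder a captured FX graph so that memory-expanding nodes (allgathers) run as late as possible and memory-releasing nodes as early as possible. A self-contained toy version of that idea, using only public torch.fx APIs and a made-up per-node cost, might look like this:

    import torch
    from torch.fx import symbolic_trace, Graph

    class Tiny(torch.nn.Module):
        def forward(self, x):
            a = x + 1
            b = x * 2
            return a + b

    gm = symbolic_trace(Tiny())

    # Toy per-node memory cost: +1 for compute nodes, 0 for placeholders/output.
    cost = {n: (1 if n.op == "call_function" else 0) for n in gm.graph.nodes}

    scheduled, done = [], set()
    remaining = list(gm.graph.nodes)
    while remaining:
        # Nodes whose inputs have all been scheduled are runnable.
        runnable = [n for n in remaining if all(a in done for a in n.all_input_nodes)]
        nxt = min(runnable, key=lambda n: cost[n])  # greedy: cheapest memory increase first
        scheduled.append(nxt)
        done.add(nxt)
        remaining.remove(nxt)

    # Rebuild a graph in the chosen order, mirroring make_graph_from_schedule above.
    new_graph = Graph()
    env = {}
    for n in scheduled:
        env[n] = new_graph.node_copy(n, lambda old: env[old])
    gm.graph = new_graph
    gm.recompile()
    print(gm(torch.randn(3)))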
158
deepspeed/compile/partitioner.py
Normal file
@ -0,0 +1,158 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
# DeepSpeed Team
|
||||
|
||||
# This file was copied from PyTorch and modified for DeepSpeed.
|
||||
|
||||
from typing import Tuple, List
|
||||
import operator
|
||||
|
||||
import torch
|
||||
from torch.fx import GraphModule, Graph, Node
|
||||
|
||||
try:
|
||||
from torch._functorch.partitioners import is_sym_node, _is_primal, _is_fwd_seed_offset, _extract_fwd_bwd_outputs, _extract_graph_with_inputs_outputs, _extract_fwd_bwd_modules, has_recomputable_ops, min_cut_rematerialization_partition, choose_saved_values_set
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
from .util import get_no_copy_ops
|
||||
|
||||
_recompute_ops = {torch.ops.aten.t.default}
|
||||
|
||||
|
||||
def _find_recompute_nodes(graph: Graph, ds_param_node: Node) -> List[Node]:
|
||||
"""
|
||||
Given a graph and a node that represents a parameter that was allgathered,
|
||||
find all nodes that use the parameter and require recomputation.
|
||||
"""
|
||||
no_copy_ops = get_no_copy_ops()
|
||||
recompute_nodes = set()
|
||||
for node in graph.nodes:
|
||||
if node.target in no_copy_ops:
|
||||
if ds_param_node in node.args:
|
||||
recompute_nodes.add(node)
|
||||
if any(a in recompute_nodes for a in node.args):
|
||||
recompute_nodes.add(node)
|
||||
|
||||
return recompute_nodes
|
||||
|
||||
|
||||
def _get_values_from_ds_params(joint_graph, param_indices):
|
||||
primal_inputs = list(filter(_is_primal, joint_graph.nodes))
|
||||
ds_param_inputs = [primal_inputs[arg_idx] for arg_idx, _, _ in param_indices]
|
||||
|
||||
no_copy_ops = get_no_copy_ops()
|
||||
|
||||
ds_param_inputs = set(ds_param_inputs)
|
||||
ds_param_users = {}
|
||||
|
||||
for node in joint_graph.nodes:
|
||||
if node.target in no_copy_ops and any((a in ds_param_inputs or a in ds_param_users) for a in node.args):
|
||||
for a in node.args:
|
||||
if a in ds_param_inputs:
|
||||
ds_param_users[node] = a
|
||||
elif a in ds_param_users:
|
||||
ds_param_users[node] = ds_param_users[a]
|
||||
|
||||
return ds_param_users
|
||||
|
||||
|
||||
def get_wrapped_choose_saved_values_set(param_indices: List[Tuple[int, int, torch.Size]]):
|
||||
|
||||
def ds_choose_saved_values_set(joint_graph: torch.fx.Graph, node_info, memory_budget=1) -> List[Node]:
|
||||
saved_values = choose_saved_values_set(joint_graph, node_info, memory_budget)
|
||||
ds_param_users = _get_values_from_ds_params(joint_graph, param_indices)
|
||||
|
||||
new_saved_values = []
|
||||
for v in saved_values:
|
||||
if v in ds_param_users:
|
||||
ds_val = ds_param_users[v]
|
||||
if ds_val not in new_saved_values:
|
||||
new_saved_values.append(ds_val)
|
||||
else:
|
||||
new_saved_values.append(v)
|
||||
|
||||
return new_saved_values
|
||||
|
||||
return ds_choose_saved_values_set
|
||||
|
||||
|
||||
def get_wrapped_partitioner(param_indices: List[Tuple[int, int, torch.Size]]):
|
||||
|
||||
def partition_recompute_ds_params(joint_module: GraphModule, _joint_inputs, *,
|
||||
num_fwd_outputs) -> Tuple[GraphModule, GraphModule]:
|
||||
"""
|
||||
This is basically the same as the default_partition function, but
|
||||
it doesn't save the gathered params and values computed from them.
|
||||
"""
|
||||
if has_recomputable_ops(joint_module):
|
||||
return min_cut_rematerialization_partition(joint_module, _joint_inputs, num_fwd_outputs=num_fwd_outputs)
|
||||
|
||||
primal_inputs = list(filter(_is_primal, joint_module.graph.nodes))
|
||||
fwd_seed_offset_inputs = list(filter(_is_fwd_seed_offset, joint_module.graph.nodes))
|
||||
inputs = primal_inputs + fwd_seed_offset_inputs
|
||||
fwd_outputs, bwd_outputs = _extract_fwd_bwd_outputs(joint_module, num_fwd_outputs=num_fwd_outputs)
|
||||
forward_only_graph = _extract_graph_with_inputs_outputs(joint_module.graph, inputs, fwd_outputs, "forward")
|
||||
forward_node_names = {node.name for node in forward_only_graph.nodes if node.op != "output"}
|
||||
saved_values = []
|
||||
saved_sym_nodes = []
|
||||
|
||||
fwd_inputs = list(filter(_is_primal, forward_only_graph.nodes))
|
||||
ds_param_inputs = [fwd_inputs[arg_idx] for arg_idx, _, _ in param_indices]
|
||||
ds_param_input_names = {node.name for node in ds_param_inputs}
|
||||
|
||||
ds_param_recompute_nodes = set()
|
||||
|
||||
for node in joint_module.graph.nodes:
|
||||
if node.name not in forward_node_names:
|
||||
continue
|
||||
|
||||
if is_sym_node(node):
|
||||
# Symints must be kept separate from tensors so that PythonFunction only calls
|
||||
# save_for_backward on tensors and stashes symints in autograd .ctx
|
||||
saved_sym_nodes.append(node)
|
||||
elif "tensor_meta" not in node.meta and node.op == "call_function":
|
||||
# Since we can't save tuple of tensor values, we need to flatten out what we're saving
|
||||
users = node.users
|
||||
assert all(user.target == operator.getitem for user in users)
|
||||
saved_values.extend(users)
|
||||
else:
|
||||
backward_usages = [n for n in node.users if n.name not in forward_node_names]
|
||||
|
||||
if "tensor_meta" in node.meta and all(is_sym_node(n) for n in backward_usages):
|
||||
# If we have a tensor in the forward, where only its sizes/strides are needed in the backward,
|
||||
# and not the actual tensor data,
|
||||
# then it will be a lot cheaper to save only the sizes/strides, and not the actual tensor.
|
||||
#
|
||||
# Note that saving the tensor could also cause compilation problems:
|
||||
# If the user mutated an input in the forward and uses its sizes/strides in the backward,
|
||||
# then we would be obligated to clone the input before saving it to appease autograd.
|
||||
# (This is how we originally found this bug).
|
||||
saved_sym_nodes.extend(backward_usages)
|
||||
|
||||
if node.name in ds_param_input_names:
|
||||
saved_values.append(node)
|
||||
recompute_nodes = _find_recompute_nodes(joint_module.graph, node)
|
||||
recompute_nodes = [n for n in recompute_nodes if n.name in forward_node_names]
|
||||
for recompute_node in recompute_nodes:
|
||||
ds_param_recompute_nodes.add(recompute_node)
|
||||
|
||||
if len(recompute_nodes) > 0:
|
||||
saved_values.append(node)
|
||||
else:
|
||||
if node not in ds_param_recompute_nodes:
|
||||
saved_values.append(node)
|
||||
saved_values = list(dict.fromkeys(saved_values).keys())
|
||||
saved_sym_nodes = list(dict.fromkeys(saved_sym_nodes).keys())
|
||||
|
||||
f_gm, b_gm = _extract_fwd_bwd_modules(
|
||||
joint_module,
|
||||
saved_values,
|
||||
saved_sym_nodes=saved_sym_nodes,
|
||||
num_fwd_outputs=num_fwd_outputs,
|
||||
)
|
||||
|
||||
return f_gm, b_gm
|
||||
|
||||
return partition_recompute_ds_params
|
48
deepspeed/compile/passes/__init__.py
Normal file
@ -0,0 +1,48 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

from ..profilers.graph_profile import MemoryProfilingInterpreter

import deepspeed.comm as dist


def run_opt_passes(nz3,
                   graph_index,
                   graph_id,
                   gm,
                   create_inputs_fn,
                   opt_passes,
                   graph_order,
                   profiling_results,
                   param_manager,
                   bwd,
                   debug_log=False):
    profile = profiling_results[graph_id]
    rank = dist.get_rank()

    for i, opt_pass in enumerate(opt_passes):

        opt_pass_fn, mem_budget = opt_pass

        graph = opt_pass_fn(gm.graph, graph_id, graph_order, profiling_results, mem_budget, param_manager, bwd)
        graph.lint()
        gm.graph = graph
        gm.recompile()

        if debug_log:
            print(f"Prefetching enabled for {'bwd' if bwd else 'fwd'} graph_id={graph_id} {graph}")

        mem_prof = MemoryProfilingInterpreter(nz3, gm)
        mem_prof.run(*create_inputs_fn())
        if debug_log and rank == 0:
            mem_prof.dump(f"mem_prof_r{rank}_{'bwd' if bwd else 'fwd'}_{graph_index}_{graph_id}_pass_{i}.csv")

        mem = [(name, current_alloc, delta, peak) for name, current_alloc, delta, peak in mem_prof.mem_record]
        if bwd:
            profile.bwd_mem = mem
        else:
            profile.fwd_mem = mem

    return gm
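Note (not part of the diff): each optimization pass invoked here takes the FX graph plus profiling context and returns a (possibly rewritten) graph. A do-nothing pass matching the call signature used by run_opt_passes could be sketched as follows; the exact signature is inferred from the call site above.

    from torch.fx import Graph

    def identity_pass(graph: Graph, graph_id, graph_order, profiling_results,
                      mem_budget, param_manager, bwd) -> Graph:
        # No rewriting: hand the graph back unchanged.
        return graph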
116
deepspeed/compile/passes/offload_activation.py
Normal file
@ -0,0 +1,116 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
# DeepSpeed Team
|
||||
|
||||
from typing import List, Dict, Set, Tuple
|
||||
import random
|
||||
from collections import defaultdict
|
||||
|
||||
import torch
|
||||
from torch.fx import Graph, Node
|
||||
|
||||
from ..fx import get_output_node, move_primals_to_head
|
||||
from ..graph_param import DSGraphParamManager
|
||||
|
||||
value_to_id: Dict[int, Dict[str, int]] = defaultdict(dict)
|
||||
used_ids: Set[int] = set()
|
||||
|
||||
|
||||
def get_random_id() -> int:
|
||||
|
||||
def _gen():
|
||||
# generate random int
|
||||
return random.randint(10000, 2**31)
|
||||
|
||||
global used_ids
|
||||
v = _gen()
|
||||
while v in used_ids:
|
||||
v = _gen()
|
||||
used_ids.add(v)
|
||||
return v
|
||||
|
||||
|
||||
def _should_offload(node: Node) -> bool:
|
||||
if not hasattr(node, "meta"):
|
||||
return False
|
||||
if not "tensor_meta" in node.meta:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def offload_activation_fwd(graph: Graph, graph_id: int, nodes_to_offload_with_names: List[Tuple[str, Node]],
|
||||
graph_order: List[int], mem_budget: float, param_manager: DSGraphParamManager) -> Graph:
|
||||
param_names = set(param_manager.param_names)
|
||||
|
||||
import copy
|
||||
cl_graph = copy.deepcopy(graph)
|
||||
cl_graph.erase_node(get_output_node(cl_graph))
|
||||
|
||||
global value_to_id
|
||||
for name, node in nodes_to_offload_with_names:
|
||||
if node.name in param_names:
|
||||
continue
|
||||
|
||||
if not _should_offload(node):
|
||||
continue
|
||||
|
||||
val_id = get_random_id()
|
||||
with graph.inserting_after(node):
|
||||
offload_node = graph.create_node('call_function',
|
||||
torch.ops.dc.offload_tensor.default, (node, graph_id, val_id), {},
|
||||
name=f"offload_{node.name}_{val_id}")
|
||||
with graph.inserting_after(offload_node):
|
||||
wait_node = graph.create_node('call_function',
|
||||
torch.ops.dc.wait_offload.default, (offload_node, graph_id, val_id), {},
|
||||
name=f"wait_copy_{node.name}_{val_id}")
|
||||
|
||||
output_node = get_output_node(graph)
|
||||
output_node.replace_input_with(node, wait_node)
|
||||
|
||||
value_to_id[graph_id][name] = val_id
|
||||
|
||||
graph = move_primals_to_head(graph)
|
||||
|
||||
graph.lint()
|
||||
return graph
|
||||
|
||||
|
||||
def reload_activation_bwd(graph: Graph, graph_id: int, graph_order: List[int], mem_budget: float,
|
||||
param_manager: DSGraphParamManager) -> Graph:
|
||||
|
||||
graph_value_to_id = value_to_id[graph_id]
|
||||
name_to_node = {n.name: n for n in graph.nodes}
|
||||
act_nodes = [name_to_node[n] for n in graph_value_to_id.keys()]
|
||||
|
||||
node_to_first_user = {}
|
||||
for act in act_nodes:
|
||||
for node in graph.nodes:
|
||||
if act in node.args:
|
||||
node_to_first_user[act] = node
|
||||
break
|
||||
|
||||
for node in act_nodes:
|
||||
val_id = graph_value_to_id[node.name]
|
||||
|
||||
with graph.inserting_before(node_to_first_user[node]):
|
||||
reload_node = graph.create_node('call_function',
|
||||
torch.ops.dc.reload_tensor.default, (node, graph_id, val_id), {},
|
||||
name=f"reload_{node.name}_{val_id}")
|
||||
with graph.inserting_after(reload_node):
|
||||
wait_node = graph.create_node('call_function',
|
||||
torch.ops.dc.wait_reload.default, (reload_node, graph_id, val_id), {},
|
||||
name=f"wait_copy_{node.name}_{val_id}")
|
||||
|
||||
# replace all uses of node with wait_node
|
||||
users = {}
|
||||
for u in node.users.keys():
|
||||
if u != reload_node:
|
||||
users[u] = (node, wait_node)
|
||||
for u, (old_in, new_in) in users.items():
|
||||
u.replace_input_with(old_in, new_in)
|
||||
|
||||
graph = move_primals_to_head(graph)
|
||||
graph.lint()
|
||||
return graph
|
546
deepspeed/compile/passes/offload_adam_states.py
Normal file
@ -0,0 +1,546 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
# DeepSpeed Team
|
||||
|
||||
import copy
|
||||
from typing import List
|
||||
|
||||
import torch
|
||||
from torch.fx import Graph, GraphModule
|
||||
|
||||
from deepspeed.accelerator import get_accelerator
|
||||
from deepspeed.runtime.zero.offload_states import _make_offload_state_key
|
||||
|
||||
try:
|
||||
from torch._subclasses.fake_tensor import unset_fake_temporarily
|
||||
except ImportError:
|
||||
# Unsupported torch version
|
||||
pass
|
||||
|
||||
from ..profilers import ProfilingResult
|
||||
from ..graph_param import DSGraphParamManager
|
||||
from ..fx import move_primals_to_head
|
||||
|
||||
import deepspeed.comm as dist
|
||||
|
||||
NAME = "offload_adam_states"
|
||||
|
||||
|
||||
def print_r0(msg):
|
||||
if dist.get_rank() == 0:
|
||||
print(msg)
|
||||
|
||||
|
||||
MARGIN = 0.2
|
||||
|
||||
copy_stream = None
|
||||
offload_event = None
|
||||
reload_event = None
|
||||
|
||||
offload_key_events = {}
|
||||
reload_key_events = {}
|
||||
|
||||
max_memory = 0
|
||||
|
||||
|
||||
def lazy_init():
|
||||
global copy_stream
|
||||
global offload_event
|
||||
global reload_event
|
||||
|
||||
if copy_stream is None:
|
||||
|
||||
copy_stream = get_accelerator().Stream()
|
||||
offload_event = get_accelerator().Event()
|
||||
reload_event = get_accelerator().Event()
|
||||
|
||||
|
||||
optimizer = None
|
||||
device = None
|
||||
nz3 = None
|
||||
|
||||
|
||||
def move_key(state, key, key_event=None):
|
||||
offload_buf_key = _make_offload_state_key(key)
|
||||
if offload_buf_key not in state:
|
||||
state[offload_buf_key] = get_accelerator().pin_memory(torch.empty_like(state[key], device="cpu"))
|
||||
|
||||
if key not in state:
|
||||
return
|
||||
|
||||
with get_accelerator().stream(copy_stream):
|
||||
state[offload_buf_key].copy_(state[key], non_blocking=True)
|
||||
|
||||
if key_event is None:
|
||||
offload_event.record(stream=copy_stream)
|
||||
else:
|
||||
key_event.record(stream=copy_stream)
|
||||
|
||||
|
||||
def move_back_key(state, key, key_event=None):
|
||||
|
||||
with get_accelerator().stream(copy_stream):
|
||||
state[key] = torch.empty_like(state[_make_offload_state_key(key)], device=device)
|
||||
state[key].copy_(state[_make_offload_state_key(key)], non_blocking=True)
|
||||
|
||||
if key_event is None:
|
||||
reload_event.record(stream=copy_stream)
|
||||
else:
|
||||
key_event.record(stream=copy_stream)
|
||||
|
||||
|
||||
def move_hp_param(src_tensor, dest_buf, key_event=None):
|
||||
with get_accelerator().stream(copy_stream):
|
||||
dest_buf.copy_(src_tensor, non_blocking=True)
|
||||
src_tensor.data = dest_buf
|
||||
|
||||
if key_event is None:
|
||||
reload_event.record(stream=copy_stream)
|
||||
else:
|
||||
key_event.record(stream=copy_stream)
|
||||
|
||||
|
||||
def move_back_hp_param(src_tensor, dest_buf, key_event=None):
|
||||
with get_accelerator().stream(copy_stream):
|
||||
dest_buf.data = torch.empty_like(src_tensor, device=device)
|
||||
dest_buf.copy_(src_tensor, non_blocking=True)
|
||||
|
||||
if key_event is None:
|
||||
reload_event.record(stream=copy_stream)
|
||||
else:
|
||||
key_event.record(stream=copy_stream)
|
||||
|
||||
|
||||
def offload_adam_states_sync():
|
||||
|
||||
with unset_fake_temporarily():
|
||||
|
||||
if not hasattr(optimizer, "hp_params_pin_buffers"):
|
||||
optimizer.hp_params_pin_buffers = [
|
||||
get_accelerator().pin_memory(torch.empty_like(t, device="cpu"))
|
||||
for t in optimizer.fp32_partitioned_groups_flat
|
||||
]
|
||||
|
||||
for i, (k, state) in enumerate(optimizer.state.items()):
|
||||
if "exp_avg" in state:
|
||||
move_key(state, "exp_avg")
|
||||
if "exp_avg_sq" in state:
|
||||
move_key(state, "exp_avg_sq")
|
||||
|
||||
for _, state in optimizer.state.items():
|
||||
if "exp_avg" in state:
|
||||
del state["exp_avg"]
|
||||
if "exp_avg_sq" in state:
|
||||
del state["exp_avg_sq"]
|
||||
|
||||
for src_tensor, dest_buf in zip(optimizer.fp32_partitioned_groups_flat, optimizer.hp_params_pin_buffers):
|
||||
move_hp_param(src_tensor, dest_buf)
|
||||
|
||||
get_accelerator().synchronize()
|
||||
|
||||
|
||||
def reload_adam_states_sync():
|
||||
|
||||
with unset_fake_temporarily():
|
||||
# print_r0("Reloading Adam states")
|
||||
|
||||
for _, state in optimizer.state.items():
|
||||
if _make_offload_state_key("exp_avg") in state:
|
||||
move_back_key(state, "exp_avg")
|
||||
if _make_offload_state_key("exp_avg_sq") in state:
|
||||
move_back_key(state, "exp_avg_sq")
|
||||
|
||||
for src, dest in zip(optimizer.hp_params_pin_buffers, optimizer.fp32_partitioned_groups_flat):
|
||||
move_back_hp_param(src, dest)
|
||||
|
||||
get_accelerator().synchronize()
|
||||
|
||||
|
||||
def sync_offload_states(event=None):
|
||||
if nz3.is_profiling():
|
||||
offload_adam_states_sync()
|
||||
else:
|
||||
if event is None:
|
||||
offload_event.wait(copy_stream)
|
||||
else:
|
||||
event.wait(copy_stream)
|
||||
|
||||
|
||||
def sync_reload_states(event=None):
|
||||
if nz3.is_profiling():
|
||||
reload_adam_states_sync()
|
||||
else:
|
||||
if event is None:
|
||||
reload_event.wait(copy_stream)
|
||||
else:
|
||||
event.wait(copy_stream)
|
||||
|
||||
|
||||
def make_offload_task(task):
|
||||
|
||||
def run_offload_task():
|
||||
# if not nz3.is_profiling():
|
||||
# print_r0(f"run_offload_task {task[0]} {task[2]} {task[3]} {task[4]}")
|
||||
|
||||
if offload_key_events.get(task[1]) is None:
|
||||
offload_key_events[task[1]] = get_accelerator().Event()
|
||||
|
||||
if task[2] == "hp_param":
|
||||
move_hp_param(task[1][0], task[1][1], offload_key_events[task[1][0]])
|
||||
else:
|
||||
assert task[1] in optimizer.state, f"State {task[1]} not found in optimizer"
|
||||
state = optimizer.state[task[1]]
|
||||
# if offload_key_events.get(task[1]) is None:
|
||||
# offload_key_events[task[1]] = get_accelerator().Event()
|
||||
move_key(state, task[2], offload_key_events[task[1]])
|
||||
|
||||
return run_offload_task
|
||||
|
||||
|
||||
def make_offload_sync(task):
|
||||
|
||||
def run_offload_sync():
|
||||
# if not nz3.is_profiling():
|
||||
event = offload_key_events[task[1]]
|
||||
event.synchronize()
|
||||
|
||||
if task[2] != "hp_param":
|
||||
state = optimizer.state[task[1]]
|
||||
key = task[2]
|
||||
if key in state:
|
||||
del state[key]
|
||||
# print_r0(f"run_offload_sync {task[0]} {task[2]} alloc_mem={get_accelerator().memory_allocated()}")
|
||||
|
||||
return run_offload_sync
|
||||
|
||||
|
||||
def make_reload_task(task):
|
||||
|
||||
def run_reload_task():
|
||||
if not nz3.is_profiling():
|
||||
if reload_key_events.get(task[1]) is None:
|
||||
reload_key_events[task[1]] = get_accelerator().Event()
|
||||
|
||||
if task[2] == "hp_param":
|
||||
move_back_hp_param(task[1][1], task[1][0], reload_key_events[task[1]])
|
||||
else:
|
||||
state = optimizer.state[task[1]]
|
||||
# print_r0(f"run_reload_task {task[0]} {task[2]} {task[3]} {task[4]}")
|
||||
move_back_key(state, task[2], reload_key_events[task[1]])
|
||||
|
||||
return run_reload_task
|
||||
|
||||
|
||||
def update_max_memory(name):
|
||||
|
||||
global max_memory
|
||||
mem = get_accelerator().max_memory_allocated()
|
||||
max_memory = max(max_memory, mem)
|
||||
|
||||
|
||||
def empty_cache():
|
||||
get_accelerator().empty_cache()
|
||||
|
||||
|
||||
offload_tasks = []
|
||||
offload_tasks_remaining = []
|
||||
offload_tasks_scheduled = []
|
||||
reload_tasks_remaining = []
|
||||
total_reload_mem = 0
|
||||
|
||||
|
||||
def offload_opt_states_inc(graph: Graph, graph_id: int, graph_order: List[int], profiling_results: ProfilingResult,
|
||||
mem_budget: float, param_manager: DSGraphParamManager, bwd: bool) -> Graph:
|
||||
|
||||
to_remove = []
|
||||
for node in graph.nodes:
|
||||
if node.op == 'call_function' and \
|
||||
node.target in [offload_adam_states_sync, sync_offload_states, reload_adam_states_sync, sync_reload_states, update_max_memory]:
|
||||
to_remove.append(node)
|
||||
|
||||
for node in to_remove:
|
||||
graph.erase_node(node)
|
||||
|
||||
accelerator = get_accelerator()
|
||||
total_mem = accelerator.total_memory() * (1 - MARGIN)
|
||||
print_r0(f"offload_opt_states_inc start graph {graph_id} bwd={bwd} max_memory={max_memory} total_mem={total_mem}")
|
||||
|
||||
mem = profiling_results[graph_id].bwd_mem if bwd else profiling_results[graph_id].fwd_mem
|
||||
mem_dict = {name: peak for name, alloc_mem, delta, peak in mem}
|
||||
|
||||
current_peak_mem = 0
|
||||
peak_mem = {}
|
||||
|
||||
ordered_node = reversed(graph.nodes) if bwd else graph.nodes
|
||||
for node in ordered_node:
|
||||
# print(f"Node: {node.name} mem: {mem_dict[node.name]}")
|
||||
if mem_dict[node.name] > current_peak_mem:
|
||||
current_peak_mem = mem_dict[node.name]
|
||||
peak_mem[node.name] = current_peak_mem
|
||||
|
||||
# fwd_max_mem = max(m[3] for m in prof.fwd_mem)
|
||||
# bwd_max_mem = max(m[3] for m in prof.bwd_mem) if len(prof.bwd_mem) > 0 else 0
|
||||
# peak_mem = max(peak_mem, fwd_max_mem, bwd_max_mem)
|
||||
|
||||
global offload_tasks_remaining, reload_tasks_remaining, offload_tasks_scheduled
|
||||
|
||||
if not bwd:
|
||||
is_first_graph = graph_id == graph_order[0][0]
|
||||
# print_r0(
|
||||
# f"offload_opt_states_inc start graph {graph_id} graph_order {graph_order} fwd is_first_graph {is_first_graph}"
|
||||
# )
|
||||
|
||||
# At the beginning of the first graph, we schedule offload tasks to launch all offloading
|
||||
if is_first_graph:
|
||||
# print_r0(
|
||||
# f"offload_opt_states_inc fwd before reload graph {graph_id} allocated_mem={get_accelerator().memory_allocated()}"
|
||||
# )
|
||||
|
||||
with unset_fake_temporarily():
|
||||
offload_adam_states_sync()
|
||||
reload_adam_states_sync()
|
||||
sync_reload_states()
|
||||
|
||||
reload_size = 0
|
||||
|
||||
for i, ((k, state), hp_param, hp_param_cpu) in enumerate(
|
||||
zip(optimizer.state.items(), optimizer.fp32_partitioned_groups_flat,
|
||||
optimizer.hp_params_pin_buffers)):
|
||||
# print_r0(
|
||||
# f"Checking key for offloading {i} {k.shape} has_key {_make_offload_state_key('exp_avg') in state}")
|
||||
|
||||
if _make_offload_state_key("exp_avg") in state:
|
||||
key = _make_offload_state_key("exp_avg")
|
||||
size = state[key].numel() * state[key].element_size()
|
||||
|
||||
# if total_mem < max_memory + reload_size + size:
|
||||
offload_tasks.append(
|
||||
(i, k, "exp_avg", state[key].numel() * state[key].element_size(), state[key].dtype))
|
||||
# print_r0(
|
||||
# f"Offloading task {i} exp_avg reload_size={reload_size} size={size} estimated_mem={max_memory + reload_size + size}"
|
||||
# )
|
||||
|
||||
if _make_offload_state_key("exp_avg_sq") in state:
|
||||
key = _make_offload_state_key("exp_avg_sq")
|
||||
size = state[key].numel() * state[key].element_size()
|
||||
|
||||
# if total_mem < max_memory + reload_size + size:
|
||||
offload_tasks.append(
|
||||
(i, k, "exp_avg_sq", state[key].numel() * state[key].element_size(), state[key].dtype))
|
||||
# print_r0(
|
||||
# f"Offloading task {i} exp_avg_sq reload_size={reload_size} size={size} estimated_mem={max_memory + reload_size + size}"
|
||||
# )
|
||||
|
||||
hp_param_size = hp_param.numel() * hp_param.element_size()
|
||||
# if total_mem < max_memory + reload_size + hp_param_size:
|
||||
offload_tasks.append((i, (hp_param, hp_param_cpu), "hp_param",
|
||||
hp_param.numel() * hp_param.element_size(), hp_param.dtype))
|
||||
# print_r0(
|
||||
# f"Offloading task {i} hp_param reload_size={reload_size} size={hp_param_size} estimated_mem={max_memory + reload_size + hp_param_size}"
|
||||
# )
|
||||
|
||||
# print_r0(f"offload_opt_states_inc fwd graph {graph_id} allocated_mem={get_accelerator().memory_allocated()}")
|
||||
|
||||
for node in graph.nodes:
|
||||
# print_r0(f"checking sync node insert node: {node.name}")
|
||||
|
||||
if node.name not in peak_mem \
|
||||
or node.op == 'placeholder' \
|
||||
or "offload_opt_" in node.name:
|
||||
continue
|
||||
|
||||
to_offload = []
|
||||
optim_size = sum([task[3] for task in offload_tasks])
|
||||
|
||||
# print_r0(
|
||||
# f" optim_size: {optim_size} total_mem: {total_mem} peak_mem: {peak_mem[node.name]} available: {total_mem - peak_mem[node.name] - optim_size} #tasks={len(offload_tasks)}"
|
||||
# )
|
||||
while total_mem - peak_mem[node.name] - optim_size < 0:
|
||||
if len(offload_tasks) == 0:
|
||||
break
|
||||
|
||||
task = offload_tasks.pop(0)
|
||||
to_offload.append(task)
|
||||
optim_size = sum([task[3] for task in offload_tasks])
|
||||
# print_r0(
|
||||
# f" scheduled task {task[0]} {task[2]} {task[3]} optim_size: {optim_size} peak_mem: {peak_mem[node.name]} available: {total_mem - peak_mem[node.name] - optim_size} #tasks={len(offload_tasks)}"
|
||||
# )
|
||||
|
||||
for task in to_offload:
|
||||
with graph.inserting_before(node):
|
||||
graph.create_node('call_function',
|
||||
make_offload_sync(task), (), {},
|
||||
name=f"offload_opt_sync_{task[0]}_{task[2]}")
|
||||
print_r0(f"Inserting fwd offload_opt_sync_{task[0]}_{task[2]}")
|
||||
offload_tasks_scheduled.append(task)
|
||||
|
||||
for node in graph.nodes:
|
||||
# print(f"Node: {node.name} mem: {mem_dict[node.name]}")
|
||||
if node.op != 'placeholder':
|
||||
print_r0(f"Inserting all offload tasks before {node.name}")
|
||||
for task in offload_tasks_scheduled:
|
||||
name = f"offload_opt_{task[0]}_{task[2]}"
|
||||
with graph.inserting_before(node):
|
||||
offload_node = graph.create_node('call_function', make_offload_task(task), (), {}, name=name)
|
||||
break
|
||||
|
||||
# print_r0(f"offload_opt_states_inc finish graph {graph_id} fwd graph {graph}")
|
||||
print_r0(f"offload_opt_states_inc finish graph {graph_id}")
|
||||
else:
|
||||
|
||||
graph_order_with_backward = [g[0] for g in graph_order if g[1]]
|
||||
is_first_graph = graph_id == graph_order_with_backward[-1]
|
||||
is_last_graph = graph_id == graph_order_with_backward[0]
|
||||
|
||||
# print_r0(
|
||||
# f"offload_opt_states_inc bwd graph {graph_id} graph_order_with_backward {graph_order_with_backward} is_first_graph {is_first_graph} is_last_graph {is_last_graph}"
|
||||
# )
|
||||
|
||||
if is_first_graph:
|
||||
inserted_sync = False
|
||||
for node in graph.nodes:
|
||||
if node.op != 'placeholder' and not inserted_sync:
|
||||
# print(f"Inserting offload_sync before {node.name}")
|
||||
with graph.inserting_before(node):
|
||||
graph.create_node('call_function', empty_cache, (), {}, name="empty_cache")
|
||||
|
||||
inserted_sync = True
|
||||
reload_tasks_remaining = copy.copy(offload_tasks_scheduled)
|
||||
|
||||
global total_reload_mem
|
||||
for node in graph.nodes:
|
||||
if node.name not in peak_mem \
|
||||
or node.op == 'placeholder' \
|
||||
or node.op == 'output' \
|
||||
or "offload_opt_sync_" in node.name:
|
||||
continue
|
||||
|
||||
if len(reload_tasks_remaining) > 0:
|
||||
task = reload_tasks_remaining[0]
|
||||
next_reload_mem = task[3]
|
||||
|
||||
insert_pos = node
|
||||
while total_mem > peak_mem[node.name] + total_reload_mem + next_reload_mem:
|
||||
expected_mem = peak_mem[node.name] + total_reload_mem
|
||||
print_r0(
|
||||
f" Inserting reload_opt reload_opt_{task[0]}_{task[2]} after {insert_pos.name} next_inc={next_reload_mem} peak_mem[{node.name}]={peak_mem[node.name]} inc_total={total_reload_mem} expected_mem={expected_mem}"
|
||||
)
|
||||
|
||||
with graph.inserting_after(insert_pos):
|
||||
insert_pos = graph.create_node('call_function',
|
||||
make_reload_task(task), (), {},
|
||||
name=f"reload_opt_{task[0]}_{task[2]}")
|
||||
|
||||
total_reload_mem += next_reload_mem
|
||||
reload_tasks_remaining.pop(0)
|
||||
if len(reload_tasks_remaining) == 0:
|
||||
break
|
||||
|
||||
task = reload_tasks_remaining[0]
|
||||
next_reload_mem = task[3]
|
||||
|
||||
# prev_node = node
|
||||
|
||||
if is_last_graph:
|
||||
for node in graph.nodes:
|
||||
# print(f"Node: {node.name} mem: {mem_dict[node.name]}")
|
||||
if node.op == 'output':
|
||||
for task in reload_tasks_remaining:
|
||||
with graph.inserting_before(node):
|
||||
graph.create_node('call_function',
|
||||
make_reload_task(task), (), {},
|
||||
name=f"reload_opt_{task[0]}_{task[2]}")
|
||||
|
||||
sync_fn = lambda: copy_stream.synchronize()
|
||||
with graph.inserting_before(node):
|
||||
graph.create_node('call_function', sync_fn, (), {}, name="sync_offload_copy_stream")
|
||||
|
||||
print_r0(
|
||||
f"offload_opt_states_inc graph {graph_id} graph_order {graph_order} bwd is_first_graph {is_first_graph} is_last_graph {is_last_graph}"
|
||||
)
|
||||
|
||||
return graph
|
||||
|
||||
|
||||
def add_record_max_mem_nodes(graph: Graph):
|
||||
|
||||
nodes = list(graph.nodes)
|
||||
for node in nodes:
|
||||
if node.op == "output" or node.op == "placeholder":
|
||||
continue
|
||||
|
||||
with graph.inserting_after(node):
|
||||
name = f"update_max_memory_{node.name}"
|
||||
graph.create_node('call_function', update_max_memory, (name, ), {}, name=name)
|
||||
|
||||
|
||||
def insert_offload_opt_states(graph: Graph, graph_id: int, graph_order: List[int], profiling_results: ProfilingResult,
|
||||
mem_budget: float, param_manager: DSGraphParamManager, bwd: bool) -> Graph:
|
||||
|
||||
if bwd:
|
||||
graph_order_with_backward = [g[0] for g in graph_order if g[1]]
|
||||
is_last_graph = graph_id == graph_order_with_backward[0]
|
||||
|
||||
inserted_reload = False
|
||||
for node in graph.nodes:
|
||||
# print(f"Node: {node.name} mem: {mem_dict[node.name]}")
|
||||
if node.op == 'output' and not inserted_reload and is_last_graph:
|
||||
# print(f"Inserting reload_opt before {node.name}")
|
||||
with graph.inserting_before(node):
|
||||
graph.create_node('call_function', reload_adam_states_sync, (), {}, name="reload_opt")
|
||||
inserted_reload = True
|
||||
|
||||
# add_record_max_mem_nodes(graph)
|
||||
|
||||
else:
|
||||
is_first_graph = graph_id == graph_order[0][0]
|
||||
|
||||
graph = move_primals_to_head(graph)
|
||||
|
||||
inserted_offload = False
|
||||
for node in graph.nodes:
|
||||
# print(f"Node: {node.name} mem: {mem_dict[node.name]}")
|
||||
if node.op != 'placeholder' and not inserted_offload and is_first_graph:
|
||||
print(f"Inserting offload_opt before {node.name}")
|
||||
with graph.inserting_before(node):
|
||||
graph.create_node('call_function', offload_adam_states_sync, (), {}, name="offload_opt")
|
||||
inserted_offload = True
|
||||
|
||||
add_record_max_mem_nodes(graph)
|
||||
|
||||
return graph
|
||||
|
||||
|
||||
def move_opt_states(gm: GraphModule, graph_id: int, graph_order: List[int], profiling_results, create_inputs_fn,
|
||||
mem_budget: float, param_manager: DSGraphParamManager, bwd: bool) -> GraphModule:
|
||||
gm.graph = offload_opt_states_inc(gm.graph, graph_id, graph_order, profiling_results, mem_budget, param_manager,
|
||||
bwd)
|
||||
return gm
|
||||
|
||||
|
||||
def move_opt_states_sync(gm: GraphModule, graph_id: int, graph_order: List[int], profiling_results, create_inputs_fn,
|
||||
mem_budget: float, param_manager: DSGraphParamManager, bwd: bool) -> GraphModule:
|
||||
gm.graph = insert_offload_opt_states(gm.graph, graph_id, graph_order, profiling_results, mem_budget, param_manager,
|
||||
bwd)
|
||||
return gm
|
||||
|
||||
|
||||
def offload_adam_states_for_init(gm: GraphModule, graph_id: int, graph_order: List[int], profiling_results,
|
||||
create_inputs_fn, mem_budget: float, param_manager: DSGraphParamManager,
|
||||
bwd: bool) -> GraphModule:
|
||||
if not bwd and graph_id == graph_order[0][0]:
|
||||
with unset_fake_temporarily():
|
||||
offload_adam_states_sync()
|
||||
# returns None, and profiling will be skipped
|
||||
|
||||
|
||||
def init_offload_opt_states(adam_optimizer, _nz3):
|
||||
lazy_init()
|
||||
|
||||
global optimizer
|
||||
optimizer = adam_optimizer
|
||||
global device
|
||||
device = torch.device(get_accelerator().current_device())
|
||||
global nz3
|
||||
nz3 = _nz3
|
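Note (not part of the diff): the optimizer-state offloading above relies on pinned host buffers, a dedicated copy stream, and events so that copies overlap with compute. A stripped-down sketch of that pattern in plain PyTorch (a CUDA device is assumed; the buffer names are illustrative) is:

    import torch

    # Pinned host buffer + side stream + event: the core of the offload path above.
    copy_stream = torch.cuda.Stream()
    copied = torch.cuda.Event()

    state = {"exp_avg": torch.randn(1 << 20, device="cuda")}
    pin_buf = torch.empty_like(state["exp_avg"], device="cpu").pin_memory()

    with torch.cuda.stream(copy_stream):
        pin_buf.copy_(state["exp_avg"], non_blocking=True)  # device -> pinned host
        copied.record(copy_stream)

    # Compute on the default stream can overlap with the copy here.

    copied.synchronize()      # the host copy is complete
    del state["exp_avg"]      # drop the device tensor to free GPU memory until reload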
54
deepspeed/compile/passes/offload_parameters.py
Normal file
54
deepspeed/compile/passes/offload_parameters.py
Normal file
@ -0,0 +1,54 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

from typing import List

import torch
from torch.fx import Node, GraphModule
from deepspeed.compile.util import get_last_uses
from ..graph_param import DSGraphParamManager


def add_offload_parameter(graph_id: int, gm: GraphModule, node: Node, ds_id: int):
    new_node = None
    with gm.graph.inserting_after(node):
        args = (node, )
        for a in [graph_id, ds_id]:  # To add ds_id
            args += (a, )
        new_node = gm.graph.create_node('call_function',
                                        torch.ops.dc.offload_parameter.default,
                                        args, {},
                                        name="offload_parameter")

    return new_node


def add_reload_parameter(graph_id: int, gm: GraphModule, node: Node, ds_id: int):
    new_node = None
    with gm.graph.inserting_after(node):
        args = (node, )
        for a in [graph_id, ds_id]:  # To add ds_id
            args += (a, )
        new_node = gm.graph.create_node('call_function',
                                        torch.ops.dc.reload_parameter.default,
                                        args, {},
                                        name="reload_parameter")
    return new_node


def get_ds_id(node: Node):
    assert node.target == torch.ops.dc.allgather_param.default
    return node.args[2]


def offload_parameter_fwd(gm: GraphModule, graph_id: int, graph_order: List[int], profiling_results, create_inputs_fn,
                          mem_budget: float, param_manager: DSGraphParamManager, bwd: bool) -> GraphModule:
    node_to_last_use, user_to_last_uses = get_last_uses(gm.graph)
    for node in gm.graph.nodes:
        if (isinstance(node, Node) and node.target == torch.ops.dc.allgather_param.default):
            add_reload_parameter(graph_id, gm, node.args[0], get_ds_id(node))
            add_offload_parameter(graph_id, gm, node_to_last_use[node], get_ds_id(node))
    gm.graph.lint()
    return gm
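The pass above pairs each allgather_param with a reload inserted after the parameter placeholder and an offload inserted after the parameter's last use. The sketch below shows only the core last-use bookkeeping on a toy graph, assuming nothing beyond public torch.fx APIs; the real helper, get_last_uses, additionally looks through no-copy ops such as views and transposes.

import torch
from torch.fx import symbolic_trace
from torch.fx.node import map_arg


class Tiny(torch.nn.Module):

    def forward(self, x, w):
        y = x * w
        z = y + x  # x is used again here
        return z.sum()


gm = symbolic_trace(Tiny())
last_use = {}
# walk in reverse so the first user we see for a value is its last use
for node in reversed(gm.graph.nodes):
    map_arg(node.args, lambda n: last_use.setdefault(n, node))
    map_arg(node.kwargs, lambda n: last_use.setdefault(n, node))

for value, user in last_use.items():
    print(f"{value.name} is last used by {user.name}")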
|
174
deepspeed/compile/passes/prefetch.py
Normal file
@ -0,0 +1,174 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
# DeepSpeed Team
|
||||
|
||||
from typing import List
|
||||
|
||||
import torch
|
||||
from torch.fx import Graph, Node, GraphModule
|
||||
|
||||
from deepspeed.accelerator import get_accelerator
|
||||
import deepspeed.comm as dist
|
||||
|
||||
from ..profilers.comm_profile import create_predictor
|
||||
from ..graph_param import DSGraphParamManager
|
||||
|
||||
NAME = "prefetch"
|
||||
|
||||
FUSE_FACTOR = 0.8
|
||||
MARGIN = 0.1
|
||||
MAX_FUSE_SIZE = 1e9
|
||||
MAX_BUFFERED_SIZE = 4e9
|
||||
|
||||
run_prefetch_pass = False
|
||||
|
||||
|
||||
def print_rank_0(message):
|
||||
if dist.get_rank() == 0:
|
||||
print(message)
|
||||
|
||||
|
||||
def get_ds_id(node: Node):
|
||||
assert node.target == torch.ops.dc.allgather_param.default
|
||||
return node.args[2]
|
||||
|
||||
|
||||
def schedule_prefetch(gm: GraphModule, graph_id: int, graph_order: List[int], profiling_results, create_inputs_fn,
|
||||
mem_budget: float, param_manager: DSGraphParamManager, bwd: bool) -> GraphModule:
|
||||
|
||||
max_mem = get_accelerator().total_memory() * (1 - MARGIN)
|
||||
vals_to_bcast = torch.tensor([max_mem], device=torch.device(get_accelerator().current_device()))
|
||||
dist.all_reduce(vals_to_bcast, dist.ReduceOp.MIN)
|
||||
max_mem = vals_to_bcast[0].item()
|
||||
|
||||
mem = profiling_results[graph_id].bwd_mem if bwd else profiling_results[graph_id].fwd_mem
|
||||
op_time = profiling_results[graph_id].bwd_time if bwd else profiling_results[graph_id].fwd_time
|
||||
tensor_sizes = profiling_results[graph_id].bwd_tensor_sizes if bwd else profiling_results[graph_id].fwd_tensor_sizes
|
||||
|
||||
mem_dict = {name: (alloc_mem, peak) for name, alloc_mem, delta, peak in mem}
|
||||
time_dict = {name: (device_time, wall_time) for name, device_time, wall_time in op_time}
|
||||
tensor_size_dict = {name: size for name, size in tensor_sizes}
|
||||
|
||||
graph = gm.graph
|
||||
total_param_size = sum(
|
||||
[tensor_size_dict[n.name] for n in graph.nodes if n.target == torch.ops.dc.allgather_param.default])
|
||||
|
||||
print_rank_0(
|
||||
f"schedule_prefetch graph_id={graph_id} max_mem={max_mem} available_memory={get_accelerator().available_memory()} memory_allocated={get_accelerator().memory_allocated()} max_allocated={get_accelerator().max_memory_allocated()} total_param_size={total_param_size} margin={MARGIN}"
|
||||
)
|
||||
|
||||
# Fill missing values
|
||||
prev_mem = 0
|
||||
prev_peak = 0
|
||||
for node in graph.nodes:
|
||||
if node.name in mem_dict:
|
||||
prev_mem = mem_dict[node.name][0]
|
||||
prev_peak = mem_dict[node.name][1]
|
||||
else:
|
||||
print_rank_0(f"node {node.name} not in mem_dict")
|
||||
mem_dict[node.name] = (prev_mem, prev_peak)
|
||||
|
||||
comm_predictor = create_predictor()
|
||||
|
||||
order_rev = list(reversed(graph.nodes))
|
||||
new_order_rev = []
|
||||
prefetch_ags = []
|
||||
prefetch_ag_groups = []
|
||||
ag_tensor_size_sum = 0
|
||||
for i, node in enumerate(order_rev):
|
||||
# print_rank_0(
|
||||
# f"Checking node reverse order {node.name} {node.target} ag_tensor_size_sum={ag_tensor_size_sum} max_mem={max_mem}"
|
||||
# )
|
||||
|
||||
if node.op != "placeholder":
|
||||
assert i < len(order_rev) - 1
|
||||
assert node.name in mem_dict
|
||||
next_node = order_rev[i + 1]
|
||||
next_alloc_mem, next_peak = mem_dict[next_node.name]
|
||||
|
||||
# Free up memory
|
||||
while next_peak + ag_tensor_size_sum > max_mem or ag_tensor_size_sum > MAX_BUFFERED_SIZE:
|
||||
if len(prefetch_ag_groups) > 0:
|
||||
# launch prefetch
|
||||
fused_ag_nodes = prefetch_ag_groups.pop(0)
|
||||
total_ag_tensor_size = sum([tensor_size_dict[ag_node.name] for ag_node in fused_ag_nodes])
|
||||
ag_tensor_size_sum -= total_ag_tensor_size
|
||||
new_order_rev.append(fused_ag_nodes)
|
||||
assert len(fused_ag_nodes) > 0
|
||||
# print_rank_0(
|
||||
# f"Free up memory fused_ag_nodes={fused_ag_nodes} next_alloc_mem={next_alloc_mem} total_ag_tensor_size={total_ag_tensor_size} ag_tensor_size_sum={ag_tensor_size_sum} max_mem={max_mem}"
|
||||
# )
|
||||
elif len(prefetch_ags) > 0:
|
||||
prefetch_ag_groups.append(prefetch_ags)
|
||||
prefetch_ags = []
|
||||
# print_rank_0(
|
||||
# f"Free up memory prefetch_ags={prefetch_ag_groups} next_alloc_mem={next_alloc_mem} ag_tensor_size_sum={ag_tensor_size_sum} max_mem={max_mem}"
|
||||
# )
|
||||
else:
|
||||
break
|
||||
|
||||
if node.target == torch.ops.dc.allgather_param.default:
|
||||
|
||||
current_ag_size = sum([tensor_size_dict[ag_node.name] for ag_node in prefetch_ags])
|
||||
pred_time_current = comm_predictor(current_ag_size)
|
||||
pred_time_next = comm_predictor(tensor_size_dict[node.name])
|
||||
pred_time_fused = comm_predictor(current_ag_size + tensor_size_dict[node.name])
|
||||
|
||||
do_fuse = max(pred_time_current, pred_time_next) * 1.2 > pred_time_fused and (
|
||||
current_ag_size + tensor_size_dict[node.name]) < MAX_FUSE_SIZE
|
||||
# print_rank_0(
|
||||
# f"found allgather_param do_fuse={do_fuse} current_ag_size={current_ag_size} tensor_size_dict[node.name]={tensor_size_dict[node.name]} pred_time_current={pred_time_current} pred_time_next={pred_time_next} pred_time_fused={pred_time_fused}"
|
||||
# )
|
||||
|
||||
if len(prefetch_ags) > 0 and not do_fuse:
|
||||
# stop fusing here
|
||||
prefetch_ag_groups.append(prefetch_ags)
|
||||
prefetch_ags = []
|
||||
# print_rank_0(
|
||||
# f"stop fusing prefetch_ags={prefetch_ag_groups} ag_tensor_size_sum={ag_tensor_size_sum}")
|
||||
# else:
|
||||
# print_rank_0(
|
||||
# f"continue fusing ag_tensor_size_sum={ag_tensor_size_sum} ag_size={tensor_size_dict[node.name]} prefetch_ags={prefetch_ags} prefetch_ag_groups={prefetch_ag_groups}"
|
||||
# )
|
||||
prefetch_ags.append(node)
|
||||
ag_tensor_size_sum += tensor_size_dict[node.name]
|
||||
|
||||
new_order_rev.append(node)
|
||||
|
||||
if (node.op != "placeholder"
|
||||
and node.target != torch.ops.dc.reload_parameter) and order_rev[i + 1].op == "placeholder":
|
||||
for ag_group in prefetch_ag_groups:
|
||||
assert len(ag_group) > 0
|
||||
new_order_rev.append(ag_group)
|
||||
total_ag_tensor_size = sum([tensor_size_dict[ag_node.name] for ag_node in ag_group])
|
||||
ag_tensor_size_sum -= total_ag_tensor_size
|
||||
if len(prefetch_ags) > 0:
|
||||
new_order_rev.append(prefetch_ags)
|
||||
ag_tensor_size_sum -= sum([tensor_size_dict[ag_node.name] for ag_node in prefetch_ags])
|
||||
assert ag_tensor_size_sum == 0
|
||||
|
||||
# print_rank_0(
|
||||
# f"node={node} next_alloc_mem={next_alloc_mem} pending_ags={len(prefetch_ags)} ag_tensor_size_sum={ag_tensor_size_sum}"
|
||||
# )
|
||||
|
||||
assert ag_tensor_size_sum >= 0
|
||||
|
||||
new_graph = Graph()
|
||||
env = {}
|
||||
for node in reversed(new_order_rev):
|
||||
if isinstance(node, Node):
|
||||
#print(f"reconstruct {node.name} {node.target}")
|
||||
new_node = new_graph.node_copy(node, lambda n: env[n.name])
|
||||
env[node.name] = new_node
|
||||
else:
|
||||
param_nodes = [ag_node.args[0] for ag_node in node]
|
||||
param_nodes_copy = [env[param_node.name] for param_node in param_nodes]
|
||||
|
||||
ds_ids = [get_ds_id(ag_node) for ag_node in node]
|
||||
new_graph.call_function(torch.ops.dc.prefetch_params_fused.default,
|
||||
args=(graph_id, param_nodes_copy, ds_ids))
|
||||
new_graph.lint()
|
||||
gm.graph = new_graph
|
||||
|
||||
return gm
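The scheduling loop above walks the graph in reverse, accumulates allgather_param nodes into groups, and flushes a group early whenever the predicted peak memory or the buffered prefetch size would exceed the budget; two all-gathers are fused when the predicted fused transfer is within roughly 1.2x of the larger individual one. The standalone sketch below models only the fusion decision, with an invented latency-plus-bandwidth cost model in place of the measured predictor used above.

def comm_time(nbytes, latency=1e-4, bw=100e9):
    # toy cost model: fixed launch latency plus bytes / bandwidth
    return latency + nbytes / bw


def plan_fusion(ag_sizes, max_buffered=4e9):
    groups, current, buffered = [], [], 0
    for size in ag_sizes:
        cur = sum(current)
        fuse_ok = max(comm_time(cur), comm_time(size)) * 1.2 > comm_time(cur + size)
        if current and (not fuse_ok or buffered + size > max_buffered):
            groups.append(current)
            current = []
        current.append(size)
        # unlike the real pass, `buffered` is never drained here; it only caps the total
        buffered += size
    if current:
        groups.append(current)
    return groups


print(plan_fusion([1e6, 2e6, 3e6, 2e9, 3e9]))  # prints the planned launch groups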
|
146
deepspeed/compile/passes/selective_gather.py
Normal file
@ -0,0 +1,146 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
# DeepSpeed Team
|
||||
|
||||
from collections import defaultdict
|
||||
from typing import List
|
||||
|
||||
import torch
|
||||
from torch.fx import GraphModule
|
||||
|
||||
import deepspeed.comm as dist
|
||||
from deepspeed.accelerator import get_accelerator
|
||||
|
||||
from ..util import get_deepcompile_handle
|
||||
from ..graph_param import DSGraphParamManager
|
||||
|
||||
NAME = "selective_gather"
|
||||
|
||||
max_alloc_mem = 0
|
||||
last_optimize_step = 0
|
||||
|
||||
|
||||
def selective_gather(gm: GraphModule, graph_id: int, graph_order: List[int], profiling_results, create_inputs_fn,
|
||||
mem_budget: float, param_manager: DSGraphParamManager, bwd: bool) -> GraphModule:
|
||||
|
||||
if not bwd:
|
||||
return gm
|
||||
|
||||
last_backward_graph_id = None
|
||||
for g_id, needs_bwd in graph_order:
|
||||
if needs_bwd:
|
||||
last_backward_graph_id = g_id
|
||||
break
|
||||
|
||||
# Run only on the last backward graph
|
||||
if last_backward_graph_id is None or graph_id != last_backward_graph_id:
|
||||
return gm
|
||||
|
||||
peak_mem = 0
|
||||
for graph_id, prof in profiling_results.items():
|
||||
# Use peak memory
|
||||
fwd_max_mem = max(m[3] for m in prof.fwd_mem)
|
||||
bwd_max_mem = max(m[3] for m in prof.bwd_mem) if len(prof.bwd_mem) > 0 else 0
|
||||
peak_mem = max(peak_mem, fwd_max_mem, bwd_max_mem)
|
||||
if dist.get_rank() == 0:
|
||||
print(
|
||||
f"selective_gather graph_id={graph_id} max_mem={peak_mem} fwd_max_mem={fwd_max_mem} bwd_max_mem={bwd_max_mem}"
|
||||
)
|
||||
|
||||
persistent_ds_ids = set()
|
||||
for graph_id, pm in param_manager.items():
|
||||
for name, ds_param in pm.params.items():
|
||||
if ds_param.param.ds_persist:
|
||||
persistent_ds_ids.add(pm.ds_ids[name])
|
||||
|
||||
ds_id_to_size = {}
|
||||
ds_id_to_time = defaultdict(float)
|
||||
ds_id_to_prof_dtime = defaultdict(float)
|
||||
ds_id_to_prof_wtime = defaultdict(float)
|
||||
|
||||
for graph_id, pm in param_manager.items():
|
||||
params = pm.params
|
||||
for param_name, param in params.items():
|
||||
ds_id = pm.ds_ids[param_name]
|
||||
ds_id_to_size[ds_id] = param.numel * param.dtype.itemsize
|
||||
|
||||
profile = profiling_results[graph_id]
|
||||
for n in profile.fwd_graph.nodes:
|
||||
if n.target == torch.ops.dc.allgather_param.default:
|
||||
assert "tensor_size" in n.meta
|
||||
ds_id_to_size[n.args[2]] = n.meta["tensor_size"]
|
||||
assert "device_time" in n.meta
|
||||
ds_id_to_time[n.args[2]] += n.meta["device_time"]
|
||||
|
||||
ds_id_to_prof_dtime[n.args[2]] = n.meta["device_time"]
|
||||
ds_id_to_prof_wtime[n.args[2]] = n.meta["wall_time"]
|
||||
|
||||
if profile.bwd_graph is not None:
|
||||
for n in profile.bwd_graph.nodes:
|
||||
if n.target == torch.ops.dc.allgather_param.default:
|
||||
assert "tensor_size" in n.meta
|
||||
ds_id_to_size[n.args[2]] = n.meta["tensor_size"]
|
||||
assert "device_time" in n.meta
|
||||
ds_id_to_time[n.args[2]] += n.meta["device_time"]
|
||||
|
||||
ds_ids = [ds_id for ds_id in ds_id_to_size if ds_id not in persistent_ds_ids]
|
||||
ds_ids.sort(key=lambda ds_id: ds_id_to_time[ds_id] / ds_id_to_size[ds_id], reverse=True)
|
||||
|
||||
# print(f"ds_id_to_size={ds_id_to_size}")
|
||||
# print(f"ds_id_to_time={ds_id_to_time}")
|
||||
|
||||
# if dist.get_rank() == 0:
|
||||
# for ds_id in ds_ids:
|
||||
# dtime_in_sec = ds_id_to_prof_dtime[ds_id]
|
||||
# wtime_in_sec = ds_id_to_prof_wtime[ds_id]
|
||||
# size_in_mb = ds_id_to_size[ds_id] / 1024 / 1024
|
||||
# print(
|
||||
# f"ds_id={ds_id} time_per_size={ds_id_to_time[ds_id] / ds_id_to_size[ds_id]:.5f} dtime={dtime_in_sec:.3f} wtime={wtime_in_sec:.3f} size={size_in_mb:.2f}MB bw={size_in_mb/dtime_in_sec:.2f}MB/s"
|
||||
# )
|
||||
|
||||
sorted_ds_ids = {ds_id: ds_id_to_size[ds_id] for ds_id in ds_ids}
|
||||
|
||||
accelerator = get_accelerator()
|
||||
total_mem = accelerator.total_memory()
|
||||
vals_to_bcast = torch.tensor([total_mem], device=torch.device(get_accelerator().current_device()))
|
||||
dist.all_reduce(vals_to_bcast, dist.ReduceOp.MIN)
|
||||
total_mem = vals_to_bcast[0].item()
|
||||
|
||||
MEM_MARGIN = 0.1
|
||||
available_mem = total_mem * (1 - MEM_MARGIN) - peak_mem
|
||||
|
||||
if dist.get_rank() == 0:
|
||||
print(
|
||||
f"selective_gather max_mem={peak_mem} total_mem={total_mem} MEM_MARGIN={MEM_MARGIN} available_mem={available_mem}"
|
||||
)
|
||||
|
||||
ds_id_to_param = {}
|
||||
for g_id, g_pm in param_manager.items():
|
||||
for name, ds_param in g_pm.params.items():
|
||||
ds_id_to_param[g_pm.ds_ids[name]] = ds_param.param
|
||||
|
||||
persistent_mem = 0
|
||||
nz3 = get_deepcompile_handle()
|
||||
for ds_id, size in sorted_ds_ids.items():
|
||||
if persistent_mem + size > available_mem:
|
||||
break
|
||||
persistent_mem += size
|
||||
|
||||
param_obj = ds_id_to_param[ds_id]
|
||||
|
||||
nz3.set_persistent(ds_id)
|
||||
if dist.get_rank() == 0:
|
||||
print(f"Set persistent: {ds_id} size: {size} persistent_mem: {persistent_mem} shape: {param_obj.ds_shape}")
|
||||
|
||||
return gm
|
||||
|
||||
|
||||
# def make_selective_gather(z3_optimizer, nz3):
|
||||
|
||||
# def selective_gather_wrapper(graph: Graph, graph_id: int, graph_order: List[int], profiling_results,
|
||||
# mem_budget: float, param_manager, bwd: bool) -> Graph:
|
||||
# return selective_gather(graph, graph_id, graph_order, profiling_results, mem_budget, param_manager, bwd,
|
||||
# z3_optimizer, nz3)
|
||||
|
||||
# return selective_gather_wrapper
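In short, the pass ranks non-persistent parameters by all-gather time per byte and greedily marks them persistent (kept unsharded) until the remaining memory budget is used up. A standalone sketch of that policy, with made-up ids, sizes, and timings, is shown below.

def pick_persistent(sizes, times, budget):
    # most expensive per byte first
    order = sorted(sizes, key=lambda ds_id: times[ds_id] / sizes[ds_id], reverse=True)
    chosen, used = [], 0
    for ds_id in order:
        if used + sizes[ds_id] > budget:
            break
        used += sizes[ds_id]
        chosen.append(ds_id)
    return chosen


sizes = {0: 4e8, 1: 1e8, 2: 2e8}   # bytes
times = {0: 2.0, 1: 1.5, 2: 0.4}   # arbitrary all-gather time units
print(pick_persistent(sizes, times, budget=5e8))  # -> [1, 0]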
|
55
deepspeed/compile/passes/zero1_compile.py
Normal file
@ -0,0 +1,55 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
# DeepSpeed Team
|
||||
|
||||
from typing import List
|
||||
|
||||
import torch
|
||||
from torch.fx import GraphModule
|
||||
|
||||
from ..util import get_deepcompile_handle
|
||||
from ..fx import add_postprocess, move_primals_to_head, _make_node_meta
|
||||
|
||||
NAME = "zero1_compile"
|
||||
|
||||
|
||||
def add_z1_reduce_fw(gm: GraphModule, graph_id: int, profiling_results, param_manager) -> GraphModule:
|
||||
|
||||
dc = get_deepcompile_handle()
|
||||
param_indices = profiling_results[graph_id].param_indices
|
||||
dc.register_graph_z1(graph_id, [v[1] for v in param_indices]) # Need this before profiling
|
||||
|
||||
return gm
|
||||
|
||||
|
||||
def add_z1_reduce_bw(gm: GraphModule, graph_id: int, param_manager) -> GraphModule:
|
||||
|
||||
graph = gm.graph
|
||||
pm = param_manager[graph_id]
|
||||
_, param_name_to_grad = pm.get_bwd_mapping(graph)
|
||||
|
||||
for param_name in pm.param_names:
|
||||
|
||||
grad_node = param_name_to_grad[param_name]
|
||||
|
||||
assert param_name in pm.ds_ids, f"param_name={param_name} not in ds_ids"
|
||||
ds_id = pm.ds_ids[param_name]
|
||||
|
||||
new_node = add_postprocess(graph,
|
||||
grad_node,
|
||||
torch.ops.dc.reduce_grad.default,
|
||||
extra_args=[graph_id, ds_id],
|
||||
name=f"reduce_param_{param_name}",
|
||||
meta=_make_node_meta(grad_node, param_name, True))
|
||||
new_node.meta["val"] = None
|
||||
|
||||
gm.graph = move_primals_to_head(graph)
|
||||
return gm
|
||||
|
||||
|
||||
def add_z1_reduce(gm: GraphModule, graph_id: int, graph_order: List[int], profiling_results, create_inputs_fn,
|
||||
mem_budget: float, param_manager, bwd: bool) -> GraphModule:
|
||||
if bwd:
|
||||
return add_z1_reduce_bw(gm, graph_id, param_manager)
|
||||
return add_z1_reduce_fw(gm, graph_id, profiling_results, param_manager)
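The helper add_postprocess imported from ..fx above inserts a call that consumes a value right after the node that produces it. The sketch below reproduces that pattern with plain torch.fx; log_grad is a hypothetical stand-in for torch.ops.dc.reduce_grad and is not part of this patch.

import operator

import torch
from torch.fx import symbolic_trace


def log_grad(t, ds_id):  # hypothetical stand-in for torch.ops.dc.reduce_grad
    print(f"would reduce ds_id={ds_id}, shape={tuple(t.shape)}")


class Tiny(torch.nn.Module):

    def forward(self, x):
        return (x * 3).sum()


gm = symbolic_trace(Tiny())
mul_node = next(n for n in gm.graph.nodes if n.op == "call_function" and n.target is operator.mul)
with gm.graph.inserting_after(mul_node):
    gm.graph.create_node("call_function", log_grad, (mul_node, 7), {}, name="reduce_param_demo")
gm.recompile()
gm(torch.ones(4))  # prints: would reduce ds_id=7, shape=(4,)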
|
186
deepspeed/compile/passes/zero3_compile.py
Normal file
@ -0,0 +1,186 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
# DeepSpeed Team
|
||||
|
||||
import gc
|
||||
from typing import List, Dict
|
||||
|
||||
import torch
|
||||
from torch.fx import Graph, Node, GraphModule
|
||||
|
||||
from ..util import get_input_nodes, get_param_nodes, get_index_by_graph_id, get_deepcompile_handle, get_real_uses
|
||||
from ..fx import add_postprocess, _make_node_meta, get_output_node, move_primals_to_head
|
||||
from ..profilers.graph_profile import ProfilingInterpreter
|
||||
from ..list_schedule import fast_free_schedule
|
||||
|
||||
import deepspeed.comm as dist
|
||||
from deepspeed.accelerator import get_accelerator
|
||||
|
||||
NAME = "zero3_compile"
|
||||
|
||||
|
||||
def add_allgather(graph_id: int, graph: Graph, node: Node, ds_id: int):
|
||||
new_ag_node = add_postprocess(graph,
|
||||
node,
|
||||
torch.ops.dc.allgather_param.default,
|
||||
extra_args=[graph_id, ds_id],
|
||||
name=f"allgather_ds_param_{node.target}_{ds_id}",
|
||||
meta=_make_node_meta(node, ds_id, True))
|
||||
new_ag_node.meta["val"] = node.meta["val"]
|
||||
|
||||
# Set the previous node back to output
|
||||
# We don't want to change the output node to allgather
|
||||
output_node = get_output_node(graph)
|
||||
output_node.replace_input_with(new_ag_node, node)
|
||||
|
||||
# Add wait as well
|
||||
new_wait_node = add_postprocess(graph,
|
||||
new_ag_node,
|
||||
torch.ops.dc.wait_allgather.default,
|
||||
extra_args=[graph_id, ds_id],
|
||||
name=f"wait_allgather_ds_param__{node.target}_{ds_id}",
|
||||
meta=_make_node_meta(node, ds_id, False))
|
||||
new_wait_node.meta["val"] = node.meta["val"]
|
||||
|
||||
return new_ag_node
|
||||
|
||||
|
||||
def add_release(graph_id: int, graph: Graph, node: Node, release_node: Node, ds_id: int, n_users: int):
|
||||
new_node = add_postprocess(graph,
|
||||
node,
|
||||
torch.ops.dc.release_param.default,
|
||||
extra_args=[graph_id, ds_id, n_users],
|
||||
name=f"release_ds_param_{release_node.target}_{node.name}_{ds_id}",
|
||||
meta=_make_node_meta(node, ds_id, False))
|
||||
new_node.meta["val"] = None
|
||||
|
||||
|
||||
def add_reduce(graph_id: int, graph: Graph, grad_node: Node, param_name: str, ds_id: int):
|
||||
new_node = add_postprocess(graph,
|
||||
grad_node,
|
||||
torch.ops.dc.reduce_grad.default,
|
||||
extra_args=[graph_id, ds_id],
|
||||
name=f"reduce_ds_param_{param_name}",
|
||||
meta=_make_node_meta(grad_node, ds_id, True))
|
||||
new_node.meta["val"] = None
|
||||
|
||||
|
||||
def add_gather_and_release(graph_id: int, graph: Graph, param_manager, param_nodes: List[Node]) -> Graph:
|
||||
|
||||
node_to_uses = get_real_uses(graph)
|
||||
for pn in param_nodes:
|
||||
add_allgather(graph_id, graph, pn, param_manager.ds_ids[pn.name])
|
||||
ds_id = param_manager.ds_ids[pn.name]
|
||||
users = node_to_uses[pn]
|
||||
for user in users:
|
||||
add_release(graph_id, graph, user, pn, ds_id, len(users))
|
||||
|
||||
return move_primals_to_head(graph)
|
||||
|
||||
|
||||
def add_gather_and_reduce(graph_id: int, graph: Graph, param_manager, param_nodes_bw: List[Node],
|
||||
param_name_to_grad: Dict[str, Node]) -> Graph:
|
||||
|
||||
add_gather_and_release(graph_id, graph, param_manager, param_nodes_bw)
|
||||
|
||||
for param_name in param_manager.param_names:
|
||||
add_reduce(graph_id, graph, param_name_to_grad[param_name], param_name, param_manager.ds_ids[param_name])
|
||||
|
||||
return move_primals_to_head(graph)
|
||||
|
||||
|
||||
def add_z3_gather_release_fw(gm: GraphModule,
|
||||
graph_id: int,
|
||||
graph_order: List[int],
|
||||
profiling_results,
|
||||
create_inputs_fn,
|
||||
param_manager,
|
||||
debug_log=False) -> GraphModule:
|
||||
|
||||
nz3 = get_deepcompile_handle()
|
||||
|
||||
real_inputs = create_inputs_fn()
|
||||
param_indices = profiling_results[graph_id].param_indices
|
||||
|
||||
gm.graph = add_gather_and_release(graph_id, gm.graph, param_manager[graph_id],
|
||||
get_param_nodes(gm.graph, param_indices))
|
||||
|
||||
nz3.register_graph_z3(graph_id, [v[1] for v in param_indices]) # Need this before profiling
|
||||
|
||||
profiler = ProfilingInterpreter(gm, debug_log=debug_log)
|
||||
profiler.run(*real_inputs)
|
||||
del profiler
|
||||
gc.collect()
|
||||
get_accelerator().empty_cache()
|
||||
|
||||
rank = dist.get_rank()
|
||||
graph_index = get_index_by_graph_id(graph_order, graph_id)
|
||||
if rank == 0 and debug_log:
|
||||
print(f"Fwd before scheduling graph {graph_index} graph_id={graph_id} {gm.graph}")
|
||||
|
||||
for n in gm.graph.nodes:
|
||||
is_ds_param = n.name in param_manager[graph_id].ds_ids
|
||||
if "val" in n.meta and is_ds_param:
|
||||
# Used for Inductor's validation
|
||||
n.meta["val"] = torch.empty([0], dtype=n.meta['val'].dtype, device=n.meta['val'].device)
|
||||
|
||||
gm.graph = fast_free_schedule(
|
||||
gm.graph,
|
||||
get_accelerator().available_memory(),
|
||||
0, # unused
|
||||
debug_log=debug_log)
|
||||
|
||||
if rank == 0 and debug_log:
|
||||
print(f"Fwd after scheduling graph {graph_index} graph_id={graph_id} {gm.graph}")
|
||||
|
||||
return gm
|
||||
|
||||
|
||||
def add_z3_gather_release_bw(gm: GraphModule,
|
||||
graph_id: int,
|
||||
graph_order: List[int],
|
||||
profiling_results,
|
||||
create_inputs_fn,
|
||||
param_manager,
|
||||
debug_log=False) -> GraphModule:
|
||||
|
||||
param_nodes_bw, param_name_to_grad = param_manager[graph_id].get_bwd_mapping(gm.graph)
|
||||
gm.graph = add_gather_and_reduce(graph_id, gm.graph, param_manager[graph_id], param_nodes_bw, param_name_to_grad)
|
||||
|
||||
input_nodes = get_input_nodes(gm.graph)
|
||||
real_inputs = create_inputs_fn()
|
||||
assert len(input_nodes) == len(real_inputs), f"Expected {len(real_inputs)} inputs, got {len(input_nodes)}"
|
||||
|
||||
real_outputs = ProfilingInterpreter(gm, debug_log=debug_log).run(*real_inputs)
|
||||
|
||||
del real_outputs
|
||||
gc.collect()
|
||||
get_accelerator().empty_cache()
|
||||
|
||||
rank = dist.get_rank()
|
||||
graph_index = get_index_by_graph_id(graph_order, graph_id)
|
||||
if rank == 0 and debug_log:
|
||||
print(f"Bwd before scheduling graph {graph_index} graph_id={graph_id} {gm.graph}")
|
||||
|
||||
# gm.graph = fast_free_schedule(gm.graph, get_accelerator().available_memory(), 0, debug_log=debug_log)
|
||||
return gm
|
||||
|
||||
|
||||
def add_z3_gather_release(gm: GraphModule, graph_id: int, graph_order: List[int], profiling_results, create_inputs_fn,
|
||||
mem_budget: float, param_manager, bwd: bool) -> GraphModule:
|
||||
if bwd:
|
||||
return add_z3_gather_release_bw(gm,
|
||||
graph_id,
|
||||
graph_order,
|
||||
profiling_results,
|
||||
create_inputs_fn,
|
||||
param_manager,
|
||||
debug_log=False)
|
||||
return add_z3_gather_release_fw(gm,
|
||||
graph_id,
|
||||
graph_order,
|
||||
profiling_results,
|
||||
create_inputs_fn,
|
||||
param_manager,
|
||||
debug_log=False)
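For ZeRO-3, every parameter access is wrapped as allgather -> wait -> use -> release. The sketch below shows the rerouting step in isolation: a new node is inserted after a parameter node and the existing users are redirected to it, using only public torch.fx APIs; gather is a placeholder that simply returns its input, standing in for the allgather/wait pair.

import torch
from torch.fx import symbolic_trace


def gather(p):  # placeholder for the allgather/wait pair: just pass the value through
    return p


class Tiny(torch.nn.Module):

    def __init__(self):
        super().__init__()
        self.w = torch.nn.Parameter(torch.ones(4))

    def forward(self, x):
        return (x * self.w).sum()


gm = symbolic_trace(Tiny())
param_node = next(n for n in gm.graph.nodes if n.op == "get_attr" and n.target == "w")
with gm.graph.inserting_after(param_node):
    new_node = gm.graph.create_node("call_function", gather, (param_node, ), {}, name="allgather_w")
# redirect every other user of the parameter to the gathered value
param_node.replace_all_uses_with(new_node, delete_user_cb=lambda user: user is not new_node)
gm.recompile()
print(gm(torch.ones(4)))  # tensor(4.)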
|
93
deepspeed/compile/patch_compiled_func.py
Normal file
@ -0,0 +1,93 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
# DeepSpeed Team
|
||||
|
||||
import torch
|
||||
from deepspeed.utils.torch import required_torch_version
|
||||
|
||||
backward_inputs = []
|
||||
|
||||
enabled_patched_func = False
|
||||
original_grad_fn = None
|
||||
base_meta = type(torch.autograd.Function)
|
||||
|
||||
if required_torch_version(min_version=2.7):
|
||||
|
||||
class FunctionMeta(base_meta):
|
||||
|
||||
def __new__(cls, name, bases, dct):
|
||||
if name == "CompiledFunction":
|
||||
original_backward_impl = dct.get("_backward_impl")
|
||||
|
||||
def wrapped_backward_impl(ctx, all_args):
|
||||
assert original_backward_impl is not None
|
||||
|
||||
if enabled_patched_func:
|
||||
backward_inputs.append(all_args)
|
||||
wrapped_backward_impl.owner_class.compiled_bw = None
|
||||
|
||||
return original_backward_impl(ctx, all_args)
|
||||
|
||||
wrapped_backward_impl.owner_class = None
|
||||
dct["_backward_impl"] = staticmethod(wrapped_backward_impl)
|
||||
new_class = super().__new__(cls, name, bases, dct)
|
||||
wrapped_backward_impl.owner_class = new_class
|
||||
|
||||
return new_class
|
||||
|
||||
return super().__new__(cls, name, bases, dct)
|
||||
|
||||
elif required_torch_version(min_version=2.6):
|
||||
|
||||
class FunctionMeta(base_meta):
|
||||
|
||||
def __new__(cls, name, bases, dct):
|
||||
if name == "CompiledFunction":
|
||||
original_backward_prologue = dct.get("_backward_prologue")
|
||||
|
||||
def wrapped_backward_prologue(ctx, *grad_outputs):
|
||||
assert original_backward_prologue is not None
|
||||
|
||||
all_args = original_backward_prologue(ctx, *grad_outputs)
|
||||
if enabled_patched_func:
|
||||
backward_inputs.append(all_args)
|
||||
wrapped_backward_prologue.owner_class.compiled_bw = None
|
||||
|
||||
return all_args
|
||||
|
||||
wrapped_backward_prologue.owner_class = None
|
||||
dct["_backward_prologue"] = staticmethod(wrapped_backward_prologue)
|
||||
new_class = super().__new__(cls, name, bases, dct)
|
||||
wrapped_backward_prologue.owner_class = new_class
|
||||
|
||||
return new_class
|
||||
|
||||
return super().__new__(cls, name, bases, dct)
|
||||
|
||||
|
||||
def patch_compiled_func():
|
||||
|
||||
global enabled_patched_func
|
||||
enabled_patched_func = True
|
||||
|
||||
class PatchedFunction(torch.autograd.Function, metaclass=FunctionMeta):
|
||||
pass
|
||||
|
||||
global original_grad_fn
|
||||
original_grad_fn = torch.autograd.Function
|
||||
torch.autograd.Function = PatchedFunction
|
||||
|
||||
return backward_inputs
|
||||
|
||||
|
||||
def unpatch_compiled_func():
|
||||
global enabled_patched_func
|
||||
enabled_patched_func = False
|
||||
|
||||
global original_grad_fn
|
||||
torch.autograd.Function = original_grad_fn
|
||||
|
||||
|
||||
def get_backward_inputs():
|
||||
return backward_inputs
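The metaclass above intercepts the creation of the compiler-generated CompiledFunction class and wraps its backward entry point so the saved backward inputs can be captured. A generic, torch-free sketch of the same trick, with invented class and method names, follows.

captured = []


class WrapMeta(type):

    def __new__(cls, name, bases, dct):
        if name == "Target" and "_impl" in dct:
            original = dct["_impl"].__func__  # unwrap the staticmethod

            def wrapped(*args):
                captured.append(args)       # record the call's arguments
                return original(*args)      # then delegate to the original

            dct["_impl"] = staticmethod(wrapped)
        return super().__new__(cls, name, bases, dct)


class Target(metaclass=WrapMeta):

    @staticmethod
    def _impl(x):
        return x + 1


print(Target._impl(41), captured)  # 42 [(41,)]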
|
53
deepspeed/compile/patch_fake_tensor.py
Normal file
@ -0,0 +1,53 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
# DeepSpeed Team
|
||||
|
||||
import torch
|
||||
|
||||
try:
|
||||
from torch._subclasses import FakeTensorMode
|
||||
from torch._subclasses.fake_tensor import unset_fake_temporarily
|
||||
from torch._dynamo.variables.builder import wrap_to_fake_tensor_and_record
|
||||
except ImportError:
|
||||
# Unsupported torch version
|
||||
pass
|
||||
|
||||
|
||||
def wrap_if_ds_param(t):
|
||||
if hasattr(t, 'ds_id'):
|
||||
data = torch.rand(t.ds_shape,
|
||||
dtype=t.dtype,
|
||||
layout=t.layout,
|
||||
device=t.device,
|
||||
pin_memory=t.is_pinned(),
|
||||
requires_grad=t.requires_grad)
|
||||
if isinstance(t, torch.nn.Parameter):
|
||||
t = torch.nn.Parameter(data, requires_grad=t.requires_grad)
|
||||
else:
|
||||
t = data
|
||||
return t
|
||||
|
||||
|
||||
def patch_fake_tensor():
|
||||
# dynamo tracer uses wrap_to_fake_tensor_and_record
|
||||
# Wrapping FakeTensorMode.from_tensor is not sufficient as dynamo generates SymbolicContext before calling from_tensor
|
||||
original_wrap_to_fake_tensor_and_record = wrap_to_fake_tensor_and_record
|
||||
|
||||
def wrap_to_fake_tensor_and_record_wrapper(t, *args, **kwargs):
|
||||
dummy_tensor = wrap_if_ds_param(t)
|
||||
ret = original_wrap_to_fake_tensor_and_record(dummy_tensor, *args, **kwargs)
|
||||
if tracing_context := torch._guards.TracingContext.try_get():
|
||||
tracing_context.tensor_to_context[t] = tracing_context.tensor_to_context.pop(dummy_tensor)
|
||||
return ret
|
||||
|
||||
torch._dynamo.variables.builder.wrap_to_fake_tensor_and_record = wrap_to_fake_tensor_and_record_wrapper
|
||||
|
||||
# aot_module_simplified uses fake_mode.from_tensor to process inputs
|
||||
original_from_tensor = FakeTensorMode.from_tensor
|
||||
|
||||
def from_tensor_wrapper(self, t, *args, **kwargs):
|
||||
with unset_fake_temporarily():
|
||||
return original_from_tensor(self, wrap_if_ds_param(t), *args, **kwargs)
|
||||
|
||||
FakeTensorMode.from_tensor = from_tensor_wrapper
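Both patches above follow the same wrapper-substitution pattern: keep a reference to the original callable, install a closure that adjusts the argument, then delegate. A minimal, torch-free illustration (ToyAPI and its method are invented names, not real torch internals):

class ToyAPI:  # invented stand-in for the patched torch internals

    @staticmethod
    def from_tensor(t):
        return t * 2


original_from_tensor = ToyAPI.from_tensor


def from_tensor_wrapper(t):
    # pre-process the argument, then delegate to the original implementation
    return original_from_tensor(abs(t))


ToyAPI.from_tensor = from_tensor_wrapper
print(ToyAPI.from_tensor(-3))  # 6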
|
23
deepspeed/compile/profilers/__init__.py
Normal file
@ -0,0 +1,23 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

from typing import List, Tuple
from dataclasses import dataclass, field

from torch.fx import Graph


@dataclass
class ProfilingResult:
    fwd_graph: Graph = None
    bwd_graph: Graph = None
    needs_backward: bool = False
    fwd_mem: List[Tuple[str, int, int, int]] = field(default_factory=list)  # name, current_alloc, delta, peak
    bwd_mem: List[Tuple[str, int, int, int]] = field(default_factory=list)
    fwd_time: List[Tuple[str, int, int]] = field(default_factory=list)  # name, device_time, wall_time
    bwd_time: List[Tuple[str, int, int]] = field(default_factory=list)
    fwd_tensor_sizes: List[Tuple[str, int]] = field(default_factory=list)  # name, size
    bwd_tensor_sizes: List[Tuple[str, int]] = field(default_factory=list)
    param_indices: List[Tuple[int, int, Tuple[int, ...]]] = field(default_factory=list)  # index, ds_id, ds_shape
|
171
deepspeed/compile/profilers/comm_profile.py
Normal file
@ -0,0 +1,171 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
# DeepSpeed Team
|
||||
|
||||
import os
|
||||
import torch
|
||||
|
||||
try:
|
||||
from torch._subclasses.fake_tensor import unset_fake_temporarily
|
||||
except ImportError:
|
||||
# Unsupported torch version
|
||||
pass
|
||||
|
||||
import deepspeed
|
||||
import deepspeed.comm as dist
|
||||
from deepspeed.accelerator import get_accelerator
|
||||
|
||||
|
||||
def sync_all():
|
||||
get_accelerator().synchronize()
|
||||
dist.barrier()
|
||||
|
||||
|
||||
def get_bw(comm_op, size, duration):
|
||||
n = dist.get_world_size()
|
||||
tput = 0
|
||||
busbw = 0
|
||||
|
||||
if duration == 0:
|
||||
raise ValueError("Error. Duration is 0.")
|
||||
|
||||
if comm_op == "all_to_all":
|
||||
tput = (size / duration)
|
||||
busbw = (size / duration) * ((n - 1) / n)
|
||||
elif comm_op == "all_gather":
|
||||
size *= n
|
||||
tput = (size / duration)
|
||||
busbw = (size / duration) * ((n - 1) / n)
|
||||
elif comm_op == "all_reduce":
|
||||
tput = (size * 2 / duration)
|
||||
busbw = (size / duration) * (2 * (n - 1) / n)
|
||||
elif comm_op == "pt2pt" or comm_op == "broadcast":
|
||||
tput = (size / duration)
|
||||
busbw = tput
|
||||
else:
|
||||
raise ValueError("wrong comm_op specified")
|
||||
|
||||
return tput, busbw
|
||||
|
||||
|
||||
# Run all_gather and print metrics
|
||||
def timed_all_gather(device, input, output, start_event, end_event, warmup, trials, async_op):
|
||||
sync_all()
|
||||
# Warmups, establish connections, etc.
|
||||
for i in range(warmup):
|
||||
dist.all_gather_into_tensor(output, input, async_op=async_op)
|
||||
sync_all()
|
||||
|
||||
# time the actual comm op trials times and average it
|
||||
start_event.record()
|
||||
for i in range(trials):
|
||||
dist.all_gather_into_tensor(output, input, async_op=async_op)
|
||||
end_event.record()
|
||||
sync_all()
|
||||
duration = start_event.elapsed_time(end_event) / 1000
|
||||
|
||||
# maintain and clean performance data
|
||||
avg_duration = duration / trials
|
||||
size = input.element_size() * input.nelement() * dist.get_world_size()
|
||||
# tput, busbw = get_bw('all_gather', size, avg_duration)
|
||||
|
||||
avg_duration_ten = torch.tensor([avg_duration], device=device)
|
||||
if dist.get_world_size() > 1:
|
||||
dist.all_reduce(avg_duration_ten, dist.ReduceOp.AVG)
|
||||
|
||||
return size, avg_duration_ten.item()
|
||||
|
||||
|
||||
def run_all_gather(device, dtype, maxsize, warmup=5, trials=10, async_op=False):
|
||||
|
||||
# Prepare benchmark header
|
||||
global_rank = dist.get_rank()
|
||||
world_size = dist.get_world_size()
|
||||
|
||||
start_event = get_accelerator().Event(enable_timing=True)
|
||||
end_event = get_accelerator().Event(enable_timing=True)
|
||||
|
||||
# Create list of message sizes
|
||||
M_LIST = []
|
||||
for x in (2**p for p in range(1, maxsize)):
|
||||
m = x // world_size
|
||||
if m > 0:
|
||||
M_LIST.append(m)
|
||||
|
||||
results = [(0, 0)]
|
||||
sync_all()
|
||||
# loop over various tensor sizes
|
||||
for M in M_LIST:
|
||||
global_rank = dist.get_rank()
|
||||
try:
|
||||
mat = torch.ones(M, dtype=dtype, device=device)
|
||||
sync_all()
|
||||
input = ((mat.mul_(float(global_rank))).view(-1))
|
||||
# Delete original mat to avoid OOM
|
||||
del mat
|
||||
get_accelerator().empty_cache()
|
||||
output = torch.zeros(input.nelement() * world_size, dtype=dtype, device=device)
|
||||
except RuntimeError as e:
|
||||
if 'out of memory' in str(e):
|
||||
if dist.get_rank() == 0:
|
||||
print('WARNING: Ran out of GPU memory. Exiting comm op.')
|
||||
sync_all()
|
||||
break
|
||||
else:
|
||||
raise e
|
||||
sync_all()
|
||||
results.append(timed_all_gather(device, input, output, start_event, end_event, warmup, trials, async_op))
|
||||
|
||||
return results
|
||||
|
||||
|
||||
profile_results = None
|
||||
|
||||
|
||||
def create_predictor():
|
||||
global profile_results
|
||||
if profile_results is None:
|
||||
with unset_fake_temporarily():
|
||||
device = get_accelerator().current_device()
|
||||
profile_results = run_all_gather(device, torch.bfloat16, 31)
|
||||
if dist.get_rank() == 0:
|
||||
for size, avg_duration in profile_results:
|
||||
print(f"size: {size}, avg_duration: {avg_duration}")
|
||||
|
||||
# Extract size and avg_duration from results
|
||||
sizes = [result[0] for result in profile_results]
|
||||
durations = [result[1] for result in profile_results]
|
||||
|
||||
try:
|
||||
from scipy.interpolate import interp1d
|
||||
except ImportError:
|
||||
raise RuntimeError("Please install scipy to use communication profiler in DeepCompile")
|
||||
|
||||
predictor = interp1d(sizes, durations, kind='linear', fill_value="extrapolate")
|
||||
|
||||
def f(size):
|
||||
if size == 0:
|
||||
return 0
|
||||
return predictor(size)
|
||||
|
||||
# Create an interpolation function
|
||||
return f
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
local_rank = int(os.environ['LOCAL_RANK'])
|
||||
get_accelerator().set_device(local_rank)
|
||||
print(f"local_rank={local_rank}")
|
||||
|
||||
deepspeed.init_distributed(dist_backend='nccl')
|
||||
|
||||
# Create predictor function
|
||||
predictor = create_predictor()
|
||||
|
||||
# Predict time for a specific data size
|
||||
example_size = 1e9
|
||||
predicted_time = predictor(example_size)
|
||||
print(f"Predicted time for size {example_size}: {predicted_time:.6f} seconds")
|
||||
|
||||
dist.destroy_process_group()
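The predictor returned by create_predictor is a piecewise-linear interpolation over the measured (message size, time) pairs, extrapolated beyond the largest measured message. The snippet below shows the same construction offline with invented sample points; it assumes scipy is installed, as the code above already requires.

from scipy.interpolate import interp1d

sizes = [0, 1e6, 1e7, 1e8, 1e9]            # bytes (invented samples)
durations = [0.0, 1e-4, 5e-4, 4e-3, 4e-2]  # seconds (invented samples)

predictor = interp1d(sizes, durations, kind="linear", fill_value="extrapolate")
print(float(predictor(5e8)))  # interpolated time for a 500 MB message
print(float(predictor(2e9)))  # extrapolated beyond the measured range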
|
295
deepspeed/compile/profilers/graph_profile.py
Normal file
@ -0,0 +1,295 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
# DeepSpeed Team
|
||||
|
||||
import time
|
||||
from typing import Any, Tuple, Dict
|
||||
import statistics
|
||||
|
||||
import torch
|
||||
from torch.fx import GraphModule, Interpreter
|
||||
from torch.fx.node import map_aggregate
|
||||
|
||||
try:
|
||||
from torch.utils._pytree import tree_all, tree_leaves
|
||||
from torch._subclasses.fake_tensor import unset_fake_temporarily, is_fake
|
||||
except ImportError:
|
||||
# Unsupported torch version
|
||||
pass
|
||||
|
||||
import deepspeed.comm as dist
|
||||
from deepspeed.accelerator import get_accelerator
|
||||
from ..util import is_comm_op, is_release_node, get_deepcompile_handle
|
||||
|
||||
|
||||
def _all_real_if_tensor(args):
|
||||
return tree_all(lambda x: not torch.is_tensor(x) or not is_fake(x), args)
|
||||
|
||||
|
||||
def _to(v, device):
|
||||
if torch.is_tensor(v):
|
||||
with unset_fake_temporarily():
|
||||
return v.to(device)
|
||||
return v
|
||||
|
||||
|
||||
def _args_to_key(v):
|
||||
|
||||
def _tensor_to_key(v) -> str:
|
||||
if torch.is_tensor(v):
|
||||
if v.numel() == 1:
|
||||
return f"{v.dtype}{v.device}{v.item()}"
|
||||
else:
|
||||
return f"{v.dtype}{v.device}{v.shape}"
|
||||
return str(v)
|
||||
|
||||
return map_aggregate(v, _tensor_to_key)
|
||||
|
||||
|
||||
def _node_size(out):
|
||||
return sum([v.element_size() * v.numel() for v in tree_leaves(out) if torch.is_tensor(v)])
|
||||
|
||||
|
||||
def _get_mem_usage_out_of_torch():
|
||||
|
||||
adjust = 0
|
||||
try:
|
||||
import pynvml
|
||||
pynvml.nvmlInit()
|
||||
|
||||
current_dev_id = get_accelerator().current_device()
|
||||
handle = pynvml.nvmlDeviceGetHandleByIndex(current_dev_id)
|
||||
info = pynvml.nvmlDeviceGetMemoryInfo(handle)
|
||||
|
||||
torch_alloc = get_accelerator().memory_allocated()
|
||||
adjust = info.used - torch_alloc
|
||||
except:
|
||||
# pynvml not available
|
||||
pass
|
||||
|
||||
return adjust
|
||||
|
||||
|
||||
# https://pytorch.org/tutorials/intermediate/fx_profiling_tutorial.html
|
||||
class ProfilingInterpreter(Interpreter):
|
||||
|
||||
def __init__(self, gm: GraphModule, iteration: int = 10, warmup: int = 5, debug_log=False):
|
||||
super().__init__(gm)
|
||||
|
||||
self.nz3 = get_deepcompile_handle()
|
||||
|
||||
assert iteration > 0
|
||||
assert warmup >= 0
|
||||
self.iteration = iteration
|
||||
self.warmup = warmup
|
||||
self.device = torch.device(get_accelerator().current_device())
|
||||
self.cache: Dict[Tuple, Any] = {}
|
||||
self.distributed = dist.is_initialized()
|
||||
self.allgather_mem: Dict[int, int] = {}
|
||||
self.debug_log = debug_log
|
||||
self.mem_usage_out_of_torch = 0
|
||||
|
||||
def run(self, *args) -> Any:
|
||||
"""Run the graph with profiling enabled.
|
||||
|
||||
args: inputs to the graph. Tensors in the inputs must be real tensors, not fake tensors. args can contain ds parameters.
returns: The output of the graph. Tensors in the output are real tensors.
|
||||
"""
|
||||
return_val = None
try:
|
||||
assert _all_real_if_tensor(args), "Inputs must be real tensors"
|
||||
self.nz3.enable_profiling(True)
|
||||
|
||||
with unset_fake_temporarily():
|
||||
with get_accelerator().random().fork_rng(devices=[self.device]):
|
||||
self.mem_usage_out_of_torch = _get_mem_usage_out_of_torch()
|
||||
return_val = super().run(*args)
|
||||
except Exception as e:
|
||||
msg = e.msg if "msg" in dir(e) else str(e)
|
||||
print(f"Profiling error {msg}")
|
||||
finally:
|
||||
self.nz3.clear_all_gathered_params()
|
||||
self.nz3.enable_profiling(False)
|
||||
return return_val
|
||||
|
||||
def run_node(self, n: torch.fx.Node) -> Any:
|
||||
|
||||
if n.op in {"placeholder", "output"}:
|
||||
n.meta["device_time"] = 0.0
|
||||
n.meta["wall_time"] = 0.0
|
||||
n.meta["memory"] = 0
|
||||
n.meta["max_memory"] = 0
|
||||
n.meta["tensor_size"] = _node_size(n)
|
||||
return super().run_node(n)
|
||||
|
||||
args, kwargs = self.fetch_args_kwargs_from_env(n)
|
||||
assert isinstance(args, tuple)
|
||||
assert isinstance(kwargs, dict)
|
||||
|
||||
def rebuild_param_if_necessary(v):
|
||||
if hasattr(v, "ds_id"):
|
||||
v.all_gather(param_list=[v])
|
||||
return v
|
||||
|
||||
args = map_aggregate(args, lambda x: rebuild_param_if_necessary(x))
|
||||
|
||||
args = map_aggregate(args, lambda x: _to(x, self.device))
|
||||
kwargs = map_aggregate(kwargs, lambda x: _to(x, self.device))
|
||||
|
||||
cache_key = (n.target, _args_to_key(args), _args_to_key(kwargs))
|
||||
cache_hit = cache_key in self.cache
|
||||
|
||||
cache_hit_flag = torch.tensor([0 if cache_hit else 1], device=self.device, dtype=torch.int)
|
||||
if self.distributed:
|
||||
dist.all_reduce(cache_hit_flag, dist.ReduceOp.SUM)
|
||||
cache_hit = cache_hit_flag.item() == 0
|
||||
|
||||
if cache_hit:
|
||||
device_time, wall_time, alloc_mem, max_mem, tensor_size = self.cache[cache_key]
|
||||
n.meta["device_time"] = device_time
|
||||
n.meta["wall_time"] = wall_time
|
||||
n.meta["alloc_memory"] = alloc_mem
|
||||
n.meta["max_memory"] = max_mem
|
||||
n.meta["tensor_size"] = tensor_size
|
||||
|
||||
is_release_op = is_release_node(n)
|
||||
run_only_once = cache_hit or is_release_op
|
||||
iteration = 1 if run_only_once else self.iteration
|
||||
accelerator = get_accelerator()
|
||||
start_events = [accelerator.Event(enable_timing=True) for _ in range(iteration)]
|
||||
end_events = [accelerator.Event(enable_timing=True) for _ in range(iteration)]
|
||||
|
||||
get_accelerator().reset_peak_memory_stats()
|
||||
alloc_mem_start = get_accelerator().memory_allocated()
|
||||
max_mem_start = get_accelerator().max_memory_allocated()
|
||||
|
||||
if not run_only_once:
|
||||
for i in range(self.warmup):
|
||||
out = getattr(self, n.op)(n.target, args, kwargs)
|
||||
|
||||
if is_comm_op(n):
|
||||
assert self.distributed, f"Distributed environment is not initialized but comm operator {n.name} {n.target} is used."
|
||||
dist.barrier()
|
||||
|
||||
start = time.time()
|
||||
for i in range(iteration):
|
||||
start_events[i].record()
|
||||
out = getattr(self, n.op)(n.target, args, kwargs)
|
||||
end_events[i].record()
|
||||
accelerator.synchronize()
|
||||
walltime_sum = time.time() - start
|
||||
|
||||
if is_comm_op(n):
|
||||
dist.barrier()
|
||||
|
||||
alloc_mem = get_accelerator().memory_allocated() - alloc_mem_start + self.mem_usage_out_of_torch
|
||||
max_memory = get_accelerator().max_memory_allocated() - max_mem_start + self.mem_usage_out_of_torch
|
||||
tensor_size = _node_size(out)
|
||||
|
||||
def partition_param_if_necessary(v):
|
||||
if hasattr(v, "ds_id") and not v.ds_persist:
|
||||
v.partition(param_list=[v], has_been_updated=False)
|
||||
return v
|
||||
|
||||
args = map_aggregate(args, lambda x: partition_param_if_necessary(x))
|
||||
|
||||
if not cache_hit:
|
||||
device_time = statistics.mean([s.elapsed_time(e) for s, e in zip(start_events, end_events)])
|
||||
wall_time = walltime_sum / iteration * 1000
|
||||
|
||||
with unset_fake_temporarily():
|
||||
vals_to_bcast = torch.tensor([device_time, wall_time, alloc_mem, max_memory, tensor_size],
|
||||
device=self.device)
|
||||
if self.distributed:
|
||||
dist.all_reduce(vals_to_bcast, dist.ReduceOp.AVG)
|
||||
n.meta["device_time"] = vals_to_bcast[0].item()
|
||||
n.meta["wall_time"] = vals_to_bcast[1].item()
|
||||
n.meta["alloc_mem"] = int(vals_to_bcast[2].item())
|
||||
n.meta["max_mem"] = int(vals_to_bcast[3].item())
|
||||
n.meta["tensor_size"] = int(vals_to_bcast[4].item())
|
||||
self.cache[cache_key] = (n.meta["device_time"], n.meta["wall_time"], n.meta["alloc_mem"],
|
||||
n.meta["max_mem"], n.meta["tensor_size"])
|
||||
|
||||
if is_release_op:
|
||||
n.meta["alloc_mem"] = -self.allgather_mem.get(args[2], 0)
|
||||
|
||||
if dist.get_rank() == 0 and self.debug_log:
|
||||
print(
|
||||
f"{n.target} {n.meta['device_time']:.2f}ms {n.meta['wall_time']:.2f}ms alloc_mem={n.meta['alloc_mem'] / 1024 / 1024:.2f}MB max_mem={n.meta['max_mem'] / 1024 / 1024:.2f}MB tensor_size={n.meta['tensor_size']}"
|
||||
)
|
||||
|
||||
if n.target == torch.ops.dc.allgather_param.default:
|
||||
out = args[0]
|
||||
assert hasattr(out, "ds_id")
|
||||
if not out.ds_persist:
|
||||
self.nz3.invalidate_gathered_param(args[2])
|
||||
self.allgather_mem[out.ds_id] = n.meta["alloc_mem"]
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class MemoryProfilingInterpreter(Interpreter):
|
||||
|
||||
def __init__(self, gm: GraphModule, debug_log=False):
|
||||
super().__init__(gm)
|
||||
self.nz3 = get_deepcompile_handle()
|
||||
self.device = torch.device(get_accelerator().current_device())
|
||||
self.mem_record = []
|
||||
self.last_alloc = get_accelerator().memory_allocated()
|
||||
|
||||
self.node_counter = 0
|
||||
self.node_num = len(gm.graph.nodes)
|
||||
self.debug_log = debug_log
|
||||
|
||||
def run(self, *args) -> Any:
|
||||
return_val = None
try:
|
||||
assert _all_real_if_tensor(args), "Inputs must be real tensors"
|
||||
self.nz3.enable_profiling(True)
|
||||
self.mem_usage_out_of_torch = _get_mem_usage_out_of_torch()
|
||||
|
||||
with unset_fake_temporarily():
|
||||
with get_accelerator().random().fork_rng(devices=[self.device]):
|
||||
return_val = super().run(*args)
|
||||
except Exception as e:
|
||||
print(f"MemoryProfiling error {e}")
|
||||
finally:
|
||||
self.nz3.enable_profiling(False)
|
||||
|
||||
return return_val
|
||||
|
||||
def run_node(self, n: torch.fx.Node) -> Any:
|
||||
get_accelerator().reset_peak_memory_stats()
|
||||
|
||||
if n.op in {"placeholder", "output"}:
|
||||
ret = super().run_node(n)
|
||||
else:
|
||||
args, kwargs = self.fetch_args_kwargs_from_env(n)
|
||||
args = map_aggregate(args, lambda x: _to(x, self.device))
|
||||
kwargs = map_aggregate(kwargs, lambda x: _to(x, self.device))
|
||||
ret = getattr(self, n.op)(n.target, args, kwargs)
|
||||
|
||||
del args, kwargs
|
||||
|
||||
current_alloc = get_accelerator().memory_allocated() + self.mem_usage_out_of_torch
|
||||
max_alloc = get_accelerator().max_memory_allocated() + self.mem_usage_out_of_torch
|
||||
vals_to_bcast = torch.tensor([current_alloc, max_alloc], device=self.device)
|
||||
dist.all_reduce(vals_to_bcast, dist.ReduceOp.MAX)
|
||||
current_alloc = vals_to_bcast[0].item()
|
||||
max_alloc = vals_to_bcast[1].item()
|
||||
|
||||
self.mem_record.append((n.name, current_alloc, current_alloc - self.last_alloc, max_alloc))
|
||||
|
||||
self.node_counter += 1
|
||||
if self.debug_log and dist.get_rank() == 0:
|
||||
print(
|
||||
f"Mem prof Node {self.node_counter}/{self.node_num} {n.name} memory {current_alloc / 1024 / 1024:.2f}MB delta {(current_alloc - self.last_alloc) / 1024 / 1024:.2f}MB"
|
||||
)
|
||||
|
||||
self.last_alloc = current_alloc
|
||||
|
||||
return ret
|
||||
|
||||
def dump(self, path):
|
||||
import pandas as pd
|
||||
df = pd.DataFrame(self.mem_record, columns=["node", "memory", "delta", "max_mem"])
|
||||
df.to_csv(path, index=False)
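At its core, ProfilingInterpreter is a torch.fx.Interpreter whose run_node wraps each op with measurements. The pared-down sketch below keeps only wall-clock timing; the device events, result caching, DS-parameter handling, and cross-rank averaging used above are omitted.

import time

import torch
from torch.fx import symbolic_trace, Interpreter


class TimingInterpreter(Interpreter):

    def __init__(self, gm):
        super().__init__(gm)
        self.times = {}

    def run_node(self, n):
        start = time.perf_counter()
        result = super().run_node(n)          # actually execute the node
        self.times[n.name] = time.perf_counter() - start
        return result


class Tiny(torch.nn.Module):

    def forward(self, x):
        return torch.relu(x @ x.t()).sum()


gm = symbolic_trace(Tiny())
interp = TimingInterpreter(gm)
interp.run(torch.randn(256, 256))
for name, t in interp.times.items():
    print(f"{name}: {t * 1e6:.1f} us")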
|
429
deepspeed/compile/util.py
Normal file
@ -0,0 +1,429 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
# DeepSpeed Team
|
||||
|
||||
import functools
|
||||
import operator
|
||||
from typing import List, Tuple, Dict
|
||||
from collections import defaultdict
|
||||
|
||||
import torch
|
||||
from torch.fx import Node, Graph
|
||||
from torch.fx.node import map_aggregate, Argument, map_arg
|
||||
|
||||
try:
|
||||
from torch._subclasses.fake_tensor import unset_fake_temporarily
|
||||
except ImportError:
|
||||
# Unsupported torch version
|
||||
pass
|
||||
|
||||
import deepspeed.comm as dist
|
||||
from deepspeed.accelerator import get_accelerator
|
||||
from deepspeed.utils.torch import required_torch_version
|
||||
from deepspeed.ops.op_builder.dc import DeepCompileBuilder
|
||||
|
||||
|
||||
def is_deepcompile_supported() -> bool:
|
||||
return required_torch_version(min_version=2.6, max_version=2.7) and get_accelerator().device_name() == "cuda"
|
||||
|
||||
|
||||
dc_handle = None
|
||||
|
||||
if is_deepcompile_supported():
|
||||
sym_size_ops = {
|
||||
operator.ge,
|
||||
operator.le,
|
||||
operator.eq,
|
||||
operator.ne,
|
||||
operator.gt,
|
||||
operator.lt,
|
||||
torch.ops.aten.sym_size.int,
|
||||
operator.getitem,
|
||||
}
|
||||
|
||||
|
||||
def get_deepcompile_handle():
|
||||
global dc_handle
|
||||
if dc_handle is None:
|
||||
dc_handle = DeepCompileBuilder().load()
|
||||
return dc_handle
|
||||
|
||||
|
||||
def is_backend_inductor(backend):
|
||||
return backend == "inductor"
|
||||
|
||||
|
||||
backward_started = False
|
||||
pre_backward_hooks = []
|
||||
|
||||
|
||||
def add_pre_backward_hook(hook):
|
||||
pre_backward_hooks.append(hook)
|
||||
|
||||
|
||||
def deepcompile_backward_prologue(is_gradient_accumulation_boundary):
|
||||
|
||||
for hook in pre_backward_hooks:
|
||||
hook()
|
||||
|
||||
dc = get_deepcompile_handle()
|
||||
dc.start_backward(is_gradient_accumulation_boundary)
|
||||
|
||||
|
||||
def log_rank0(msg: str, enable: bool = False):
|
||||
if dist.get_rank() == 0 and enable:
|
||||
print(msg)
|
||||
|
||||
|
||||
def get_no_copy_ops():
|
||||
# Need to compile custom ops
|
||||
get_deepcompile_handle()
|
||||
return {
|
||||
torch.ops.aten.t.default, torch.ops.aten.view.default, torch.ops.aten.detach.default,
|
||||
torch.ops.aten.permute.default, torch.ops.dc.wait_allgather.default
|
||||
}
|
||||
|
||||
|
||||
def get_input_nodes(graph: Graph) -> List[Node]:
|
||||
return [n for n in graph.nodes if n.op == "placeholder"]
|
||||
|
||||
|
||||
def get_param_nodes(graph: Graph, index_to_ds_ids: List[Tuple[int, int]]) -> List[Node]:
|
||||
all_input_nodes = get_input_nodes(graph)
|
||||
return [all_input_nodes[i] for i, _, _ in index_to_ds_ids]
|
||||
|
||||
|
||||
def is_comm_op(node: Node) -> bool:
|
||||
return "comm" in node.meta and node.meta["comm"]
|
||||
|
||||
|
||||
def exclude_from_act_offload(node: Node) -> bool:
|
||||
return node.target in sym_size_ops
|
||||
|
||||
|
||||
def dtype_to_elem_size(dtype: torch.dtype) -> int:
|
||||
if dtype == torch.float32:
|
||||
elem_size = 4
|
||||
elif dtype == torch.float64:
|
||||
elem_size = 8
|
||||
elif dtype == torch.float16:
|
||||
elem_size = 2
|
||||
else:
|
||||
raise ValueError(f"Unsupported dtype: {dtype}")
|
||||
return elem_size
|
||||
|
||||
|
||||
def tensor_meta_size(tensor_meta) -> int:
|
||||
numel = 1 if len(tensor_meta.shape) == 0 else functools.reduce(operator.mul, tensor_meta.shape)
|
||||
|
||||
dtype = tensor_meta.dtype
|
||||
if dtype == torch.float32:
|
||||
elem_size = 4
|
||||
elif dtype == torch.float64 or dtype == torch.int64:
|
||||
elem_size = 8
|
||||
elif dtype == torch.float16 or dtype == torch.bfloat16:
|
||||
elem_size = 2
|
||||
elif dtype == torch.bool:
|
||||
elem_size = 1
|
||||
else:
|
||||
raise ValueError(f"Unsupported dtype: {dtype}")
|
||||
|
||||
return numel * elem_size
|
||||
|
||||
|
||||
class NodeValueOffloadHelper:
|
||||
|
||||
def __init__(self, device):
|
||||
self.device = device
|
||||
self.env_values: Dict[str, Argument] = {}
|
||||
self.original_device: Dict[torch.Tensor, torch.device] = {}
|
||||
|
||||
def _to_cpu(self, v):
|
||||
if torch.is_tensor(v):
|
||||
with unset_fake_temporarily():
|
||||
device = v.device
|
||||
offloaded = v.to('cpu').detach()
|
||||
self.original_device[offloaded] = device
|
||||
return offloaded
|
||||
return v
|
||||
|
||||
def _from_cpu(self, v):
|
||||
if torch.is_tensor(v) and v in self.original_device:
|
||||
return v.to(self.original_device[v])
|
||||
return v
|
||||
|
||||
def save(self, name: str, v: Argument, offload) -> None:
|
||||
self.env_values[name] = map_aggregate(v, lambda x: self._to_cpu(x) if offload else x)
|
||||
|
||||
def load(self, name: str) -> Argument:
|
||||
return map_aggregate(self.env_values[name], lambda x: self._from_cpu(x))
|
||||
|
||||
def get_offloaded_value(self, name: str) -> Argument:
|
||||
return self.env_values[name]
|
||||
|
||||
def has_value(self, name: str) -> bool:
|
||||
return name in self.env_values
|
||||
|
||||
def clear(self) -> None:
|
||||
self.env_values.clear()
|
||||
self.original_device.clear()
|
||||
|
||||
|
||||
def materialize_fake(v, device=None):
|
||||
from torch._subclasses.fake_tensor import is_fake
|
||||
|
||||
def convert(t):
|
||||
if is_fake(t):
|
||||
with unset_fake_temporarily():
|
||||
if t.is_floating_point():
|
||||
return torch.randn(t.shape,
|
||||
dtype=t.dtype,
|
||||
device=t.device if device is None else device,
|
||||
layout=t.layout,
|
||||
requires_grad=t.requires_grad,
|
||||
pin_memory=t.is_pinned())
|
||||
else:
|
||||
return torch.zeros(t.shape,
|
||||
dtype=t.dtype,
|
||||
device=t.device if device is None else device,
|
||||
requires_grad=t.requires_grad)
|
||||
|
||||
return t
|
||||
|
||||
return map_aggregate(v, lambda x: convert(x))
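map_aggregate, used throughout these helpers, applies a function to every leaf of an arbitrarily nested argument structure (tuples, lists, dicts). A small standalone example:

import torch
from torch.fx.node import map_aggregate

args = (torch.ones(2), [torch.zeros(3), {"k": torch.arange(4)}], 5)
# replace every tensor leaf with its size in bytes, leave other leaves as 0
sizes = map_aggregate(args, lambda x: x.numel() * x.element_size() if torch.is_tensor(x) else 0)
print(sizes)  # (8, [12, {'k': 32}], 0)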
|
||||
|
||||
|
||||
def get_last_uses(graph: Graph):
|
||||
position = {node: i for i, node in enumerate(graph.nodes)}
|
||||
|
||||
node_to_last_use: Dict[Node, Node] = {}
|
||||
user_to_last_uses: Dict[Node, List[Node]] = {}
|
||||
no_copy_ops = get_no_copy_ops()
|
||||
|
||||
def register_last_uses(n: Node, user: Node):
|
||||
update = False
|
||||
known_last_use = None
|
||||
|
||||
if user.target in no_copy_ops and n in node_to_last_use:
|
||||
last_user = node_to_last_use[user]
|
||||
last_use_position = position[last_user]
|
||||
|
||||
known_last_use = node_to_last_use[n]
|
||||
known_last_use_position = position[known_last_use]
|
||||
update = last_use_position > known_last_use_position
|
||||
|
||||
if n not in node_to_last_use or update:
|
||||
if user.target in no_copy_ops:
|
||||
user = node_to_last_use[user]
|
||||
|
||||
node_to_last_use[n] = user
|
||||
user_to_last_uses.setdefault(user, []).append(n)
|
||||
|
||||
if known_last_use:
|
||||
user_to_last_uses[known_last_use].remove(n)
|
||||
|
||||
for node in reversed(graph.nodes):
|
||||
map_arg(node.args, lambda n: register_last_uses(n, node))
|
||||
map_arg(node.kwargs, lambda n: register_last_uses(n, node))
|
||||
|
||||
return node_to_last_use, user_to_last_uses
|
||||
|
||||
|
||||
def get_real_uses(graph: Graph):
|
||||
node_to_uses: Dict[Node, List[Node]] = defaultdict(list)
|
||||
no_copy_ops = get_no_copy_ops()
|
||||
|
||||
def register_last_uses(n: Node, user: Node):
|
||||
if user.target == "output":
|
||||
return
|
||||
|
||||
if user.target in no_copy_ops:
|
||||
users = node_to_uses[user]
|
||||
node_to_uses[n].extend(users)
|
||||
else:
|
||||
node_to_uses[n].append(user)
|
||||
|
||||
for node in reversed(graph.nodes):
|
||||
map_arg(node.args, lambda n: register_last_uses(n, node))
|
||||
map_arg(node.kwargs, lambda n: register_last_uses(n, node))
|
||||
|
||||
return node_to_uses
|
||||
|
||||
|
||||
def count_inflight_values(graph: Graph, file_path: str):
    position = {node: i for i, node in enumerate(graph.nodes)}

    node_to_last_use, user_to_last_uses = get_last_uses(graph)

    max_inflight_size = 0
    inflight_values = set()

    # Output csv.
    csv_filename = file_path
    csv_data = []
    header = [
        'Node', 'tensor_size', 'inflight_size', 'inflight_size_in_output', 'args', 'users', 'node_to_last_use',
        'lifetime', 'user_to_last_uses', 'inflight_values'
    ]
    csv_data.append(header)

    from .fx import get_output_node
    output_node = get_output_node(graph)
    values_in_output = set([n for n in output_node.args[0] if isinstance(n, Node)])

    for node in graph.nodes:
        inflight_values.add(node)
        if node in user_to_last_uses:
            for to_delete in user_to_last_uses[node]:
                inflight_values.remove(to_delete)

        assert "tensor_size" in node.meta, f"Node {node} does not have tensor_size"
        inflight_size = sum(n.meta["tensor_size"] for n in inflight_values)
        inflight_size_in_output = sum(n.meta["tensor_size"] for n in inflight_values if n in values_in_output)

        lifetime = position[node_to_last_use[node]] - position[node] if node in node_to_last_use else 0

        row = [
            node.name, node.meta["tensor_size"], inflight_size, inflight_size_in_output,
            [a.name for a in node.args if isinstance(a, Node)],
            list(node.users.keys()), node_to_last_use[node] if node in node_to_last_use else 'NA', lifetime,
            user_to_last_uses[node] if node in user_to_last_uses else 'NA',
            list(inflight_values)
        ]
        csv_data.append(row)

        # print(
        #     f"Node: {node.name} users: {list(node.users.keys())} node_to_last_use: {node_to_last_use[node] if node in node_to_last_use else 'NA'} user_to_last_uses: {user_to_last_uses[node] if node in user_to_last_uses else 'NA'} inflight_values: {inflight_values} inflight_size: {inflight_size}"
        # )
        max_inflight_size = max(max_inflight_size, inflight_size)

    import csv
    with open(csv_filename, mode='w', newline='') as file:
        writer = csv.writer(file)
        writer.writerows(csv_data)

    print(f"Max inflight size: {max_inflight_size}")
    print(f"Data successfully written to {csv_filename}")

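Again for illustration only, the CSV dump can be driven on a toy graph. `count_inflight_values` assumes every node already carries a `tensor_size` entry in `node.meta` (normally filled in by DeepCompile's profiling step) and that it runs inside this package (it uses a relative import); here the sizes are faked with zeros just to show the call shape:

```python
# Hedged sketch: write the per-node inflight-size CSV for a toy traced module.
import torch
from torch.fx import symbolic_trace


class TwoOutputs(torch.nn.Module):

    def forward(self, x):
        y = torch.relu(x)
        return y, y + 1


gm = symbolic_trace(TwoOutputs())
for node in gm.graph.nodes:
    node.meta["tensor_size"] = 0  # placeholder byte count
count_inflight_values(gm.graph, "inflight_values.csv")
```
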
def get_activation_node_names(graph: Graph, param_nodes_bw: List[Node], fwd_output_names: List[str]):

    input_nodes = get_input_nodes(graph)
    param_node_names = set([n.name for n in param_nodes_bw])

    activation_node_names = []
    for in_node in input_nodes:
        if in_node.name in fwd_output_names:
            if in_node.name not in param_node_names:
                activation_node_names.append(in_node.name)

    return activation_node_names


class TensorOffloadHelper():

    def __init__(self):
        self.devices = {}
        self.base_tensors = {}
        self.views = {}
        self.arg_list = []
        self.offloaded = {}
        self.non_tensor = {}

    def offload(self, argument):

        def is_base_tensor(tensor):
            return torch.is_tensor(tensor) and not tensor._is_view() and not hasattr(tensor, "ds_id")

        base_tensor_ids = set()
        for a in argument:
            if is_base_tensor(a):
                base_tensor_ids.add(id(a))

        for a in argument:
            a_id = id(a)

            if is_base_tensor(a):
                # Base tensor
                self.devices[a_id] = a.device
                self.base_tensors[a_id] = a
            # elif torch.is_tensor(a) and not hasattr(a, "ds_id") and id(a._base) in base_tensor_ids:
            #     # View
            #     self.views[a_id] = {
            #         "base_id": id(a._base),
            #         "size": a.size(),
            #         "stride": a.stride(),
            #         "offset": a.storage_offset(),
            #     }
            else:
                # other types or ds tensor
                self.non_tensor[a_id] = a

            self.arg_list.append(a_id)

        for a in argument:
            if is_base_tensor(a):
                a.data = a.data.to("cpu")

    def reload(self, in_place):

        loaded_base_tensors = {}
        for a_id in self.arg_list:
            if a_id in self.base_tensors:
                device = self.devices[a_id]

                if in_place:
                    self.base_tensors[a_id].data = self.base_tensors[a_id].to(device)
                    loaded_base_tensors[a_id] = self.base_tensors[a_id]
                else:
                    loaded_base_tensors[a_id] = self.base_tensors[a_id].to(device)

        results = []
        for a_id in self.arg_list:
            if a_id in self.base_tensors:
                results.append(loaded_base_tensors[a_id])

            # elif a_id in self.views:
            #     view_info = self.views[a_id]
            #     # print(f"load_args loading view {a_id} base_id={view_info['base_id']} size={view_info['size']} stride={view_info['stride']} offset={view_info['offset']}")
            #     base_tensor = loaded_base_tensors[view_info["base_id"]]
            #     view_tensor = base_tensor.as_strided(
            #         view_info["size"], view_info["stride"], view_info["offset"]
            #     )
            #     results.append(view_tensor)

            elif a_id in self.non_tensor:
                results.append(self.non_tensor[a_id])

        return results


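A minimal usage sketch (not from the patch): the helper moves the base tensors of an argument list to CPU and later restores them to their recorded devices; the `cuda` device string is an assumption for the example.

```python
import torch

args = [torch.randn(4, 4, device="cuda"), 3.14, torch.randn(8, device="cuda")]

helper = TensorOffloadHelper()
helper.offload(args)                     # base tensors now live on CPU
# ... GPU memory is free for other work here ...
restored = helper.reload(in_place=True)  # same objects, back on their original devices
```
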
def add_mem_profile_nodes(graph: Graph, prefix: str):

    def show_memory(label: str):
        if dist.get_rank() == 0:
            print(
                f"{prefix} {label} alloc_mem={get_accelerator().memory_allocated()} max_mem={get_accelerator().max_memory_allocated()}"
            )

    nodes = list(graph.nodes)
    for node in nodes:
        if node.op == "output":
            continue

        with graph.inserting_after(node):
            msg = f"Mem {node.name}"
            name = f"show_memory_{node.name}"
            graph.create_node('call_function', show_memory, (msg, ), {}, name=name)


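For reference, a hedged sketch of how these probes might be attached to a captured graph. The `symbolic_trace` call is a stand-in for a dynamo-captured `GraphModule`, and running the instrumented module requires the distributed backend to be initialized so that `dist.get_rank()` works:

```python
import torch
from torch.fx import symbolic_trace

gm = symbolic_trace(torch.nn.Linear(8, 8))   # stand-in for a dynamo-captured graph
add_mem_profile_nodes(gm.graph, prefix="[fwd]")
gm.recompile()                               # regenerate code so the probes actually run
```
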
def is_release_node(n: Node) -> bool:
    return n.target == torch.ops.dc.release_param.default


def get_index_by_graph_id(graph_order, target_graph_id):
    for index, (graph_id, _) in enumerate(graph_order):
        if graph_id == target_graph_id:
            return index
    return -1
6
deepspeed/ops/compile/__init__.py
Executable file
@ -0,0 +1,6 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

from ..op_builder import DeepCompileBuilder
@ -31,6 +31,7 @@ from .activation_checkpointing.config import DeepSpeedActivationCheckpointingCon
from ..comm.config import DeepSpeedCommsConfig
from ..monitor.config import get_monitor_config
from ..inference.config import WeightQuantConfig
from ..compile.config import CompileConfig

from deepspeed import comm as dist
from deepspeed.runtime.config_utils import DeepSpeedConfigModel
@ -912,6 +913,8 @@ class DeepSpeedConfig(object):
        self.weight_quantization_config = WeightQuantConfig(
            **param_dict['weight_quantization']) if 'weight_quantization' in param_dict else None

        self.compile_config = CompileConfig(**param_dict.get('compile', {}))

        self.timers_config = get_timers_config(param_dict)
        self.tensor_parallel_config = get_tensor_parallel_config(param_dict)
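For context (not part of the diff), the new `compile` section is what users place in their DeepSpeed config to feed `CompileConfig`. Only the `deepcompile` flag is shown, since it is the one exercised elsewhere in this PR; the surrounding keys are an illustrative minimal config.

```python
# Hedged example: a minimal DeepSpeed config dict with the new `compile` block.
ds_config = {
    "train_micro_batch_size_per_gpu": 1,
    "zero_optimization": {
        "stage": 3
    },
    "bf16": {
        "enabled": True
    },
    "compile": {
        "deepcompile": True
    }
}
```
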
@ -107,6 +107,12 @@ from deepspeed.accelerator import get_accelerator

from deepspeed.runtime.config import DtypeEnum

from deepspeed.compile.util import is_deepcompile_supported, get_deepcompile_handle, deepcompile_backward_prologue
from deepspeed.compile.backend import register_compile_pass, opt_passes
from deepspeed.compile.passes import zero3_compile, prefetch, selective_gather, offload_adam_states
from deepspeed.compile.init_z1 import init_z1
from deepspeed.compile.init_z3 import init_z3

MEMORY_OPT_ALLREDUCE_SIZE = 500000000

DeepSpeedOptimizerCallable = \
@ -271,8 +277,9 @@ class DeepSpeedEngine(Module):
        # Configure distributed model
        self._configure_distributed_model(model)

        self.module_forward_pre_hook = self._create_module_forward_pre_hook()
        self.module_forward_post_hook = self._create_module_forward_post_hook()
        if not self.is_deepcompile_enabled():
            self.module_forward_pre_hook = self._create_module_forward_pre_hook()
            self.module_forward_post_hook = self._create_module_forward_post_hook()

        # needed for zero_to_fp32 weights reconstruction to remap nameless data to state_dict
        self.param_names = {param: name for name, param in model.named_parameters()}
@ -377,6 +384,12 @@ class DeepSpeedEngine(Module):
        self.unflatten = _unflatten_dense_tensors

        self._is_compiled = False
        if is_deepcompile_supported():
            # Predefined compile passes
            self.register_compile_pass(zero3_compile.NAME, zero3_compile.add_z3_gather_release)
            self.register_compile_pass(prefetch.NAME, prefetch.schedule_prefetch)
            self.register_compile_pass(selective_gather.NAME, selective_gather.selective_gather)
            self.register_compile_pass(offload_adam_states.NAME, offload_adam_states.move_opt_states)

    def _optimized_linear_offload_setup(self):
        self.optimized_linear_base_weight_sharding = False
@ -486,6 +499,8 @@
    def destroy(self):
        if self.optimizer is not None and hasattr(self.optimizer, 'destroy'):
            self.optimizer.destroy()
        if self.is_deepcompile_enabled():
            get_deepcompile_handle().cleanup()
        debug_clear_module_and_param_names()

    def _get_model_parameters(self):
@ -2032,6 +2047,10 @@
        if self.autotuning_profile_model_info():
            ma = get_ma_status()

        if self.is_deepcompile_enabled() and hasattr(self, "launch_compile_passes"):
            # We can't have this in forward prologue as the compiler compiles hooks including the forward prologue.
            self.launch_compile_passes(self.global_steps)

        loss = self.module(*inputs, **kwargs)

        if self.autotuning_profile_model_info():
@ -2104,7 +2123,8 @@
        scale_wrt_gas = self.scale_wrt_gas

        # scale loss w.r.t. gradient accumulation if reduction is not disabled
        do_gradient_reduction = self.enable_backward_allreduce and not self.inside_no_sync_ctxt
        do_gradient_reduction = self.enable_backward_allreduce and not self.inside_no_sync_ctxt and not self.is_deepcompile_enabled(
        )
        if do_gradient_reduction and self.gradient_accumulation_steps() > 1 and scale_wrt_gas:
            loss = self._scale_loss_by_gas(loss.float())
@ -2121,6 +2141,9 @@
            )]
            self.monitor.write_events(self.summary_events)

        if self.is_deepcompile_enabled():
            deepcompile_backward_prologue(self.is_gradient_accumulation_boundary())

        return loss

    def _backward_epilogue(self):
@ -2128,6 +2151,7 @@
        if self.enable_backward_allreduce and not self.inside_no_sync_ctxt:
            # Traditional code path that allreduces the module parameter grads
            self.allreduce_gradients()

        self._stop_timers(self.engine_timers.backward_reduce_timers)
        see_memory_usage("Engine after backward", force=self.memory_breakdown())
@ -3849,7 +3873,7 @@
        gc.collect()
        get_accelerator().empty_cache()

    def compile(self, backend=get_accelerator().get_compile_backend(), compile_kwargs={}) -> None:
    def compile(self, backend=get_accelerator().get_compile_backend(), compile_kwargs={}, schedule=None) -> None:
        """Compile the module using the specified backend and kwargs.
        If a compiler_fn is set, it will be used instead of torch.compile().
        """
@ -3865,10 +3889,50 @@
        if 'backend' in compile_kwargs:
            logger.warning("The `backend` in `compile_kwargs` will be overridden. Use the `backend` argument instead.")

        print(f"Compiling deepcompile={self.is_deepcompile_enabled()} backend={backend}")

        if self.is_deepcompile_enabled():
            assert self.zero_optimization_stage() == ZeroStageEnum.optimizer_states \
                or self.zero_optimization_stage() == ZeroStageEnum.weights \
                , "Currently DeepCompile supports stage 1 or 3 only."

            if schedule is not None:

                def passes_name_to_fn(passes):
                    for p in passes:
                        assert callable(p) or p in opt_passes, f"Unknown pass {p}"
                    return [p if callable(p) else opt_passes[p] for p in passes]

                schedule = [(step, passes_name_to_fn(passes)) for step, passes in schedule]

            assert backend in ['inductor', 'eager'], f"Backend {backend} is not supported for DeepCompile."

            compile_config = self._config.compile_config
            if (("zero_optimization" in self.config and "offload_optimizer" in self.config["zero_optimization"]
                 and "offload_param" in self.config["zero_optimization"])
                    and self._config.zero_config.offload_param.device == "cpu"
                    and self._config.zero_config.offload_optimizer.device == "cpu"):
                compile_config.offload_parameters = True
            if self.zero_optimization_stage() == ZeroStageEnum.optimizer_states:
                backend = init_z1(self, backend, compile_config, compile_kwargs, schedule)
            elif self.zero_optimization_stage() == ZeroStageEnum.weights:
                backend = init_z3(self, backend, compile_config, compile_kwargs, schedule)

        # create new dict to avoid modifying original dict
        self.module.compile(**{**compile_kwargs, 'backend': backend})

        self._is_compiled = True

    def get_compile_time(self):
        from deepspeed.compile.backend import opt_pass_times
        return opt_pass_times

    def register_compile_pass(self, pass_name: str, pass_fn: Callable) -> None:
        register_compile_pass(pass_name, pass_fn)

    def is_deepcompile_enabled(self):
        return self._config.compile_config.deepcompile

    @property
    def is_compiled(self) -> bool:
        return self._is_compiled
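To make the new `schedule` argument concrete, here is a hedged usage sketch. `model`, `ds_config`, the pass names, and the step numbers are illustrative; valid pass names are whatever has been registered through `register_compile_pass` (see the predefined passes registered in `__init__` above).

```python
import deepspeed

engine, _, _, _ = deepspeed.initialize(model=model,
                                       model_parameters=model.parameters(),
                                       config=ds_config)

# With "deepcompile": true in ds_config, compile() routes through init_z1/init_z3.
engine.compile(backend="inductor")

# Optionally stage extra optimization passes at given warm-up steps
# (pass names here are hypothetical):
# engine.compile(backend="inductor",
#                schedule=[(0, ["zero3_compile"]),
#                          (5, ["zero3_compile", "prefetch"])])
```
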
@ -73,6 +73,8 @@ class ZeROOrderedDict(OrderedDict):

def _inject_parameters(module, cls):
    for module in module.modules():
        module._original_parameters = module._parameters

        if cls == ZeROOrderedDict:
            new_param = cls(parent_module=module)
        else:
@ -80,6 +82,7 @@ def _inject_parameters(module, cls):

        for key, param in module._parameters.items():
            new_param[key] = param

        module._parameters = new_param


@ -232,6 +235,8 @@ class DeepSpeedZeRoOffload(object):
        for hook in self.backward_hooks:
            hook.remove()

        self.fwd_pre_hook.remove()

        print_rank_0(f'Deleted module hooks: forward = {num_forward_hooks}, backward = {num_backward_hooks}',
                     force=False)

@ -244,7 +249,7 @@ class DeepSpeedZeRoOffload(object):

            self.get_param_coordinator().reset_step()

        self.module.register_forward_pre_hook(_start_of_forward_hook)
        self.fwd_pre_hook = self.module.register_forward_pre_hook(_start_of_forward_hook)

        #likely one of them should be enough but just to be safe
        self._register_deepspeed_module(self.module)
@ -287,7 +292,7 @@ class DeepSpeedZeRoOffload(object):
                count[0] = count[0] + 1
                self._register_deepspeed_module(child, count=count)

        @instrument_w_nvtx
        @torch.compiler.disable
        def _pre_forward_module_hook(module, *args):
            self.pre_sub_module_forward_function(module)

@ -365,6 +370,7 @@ class DeepSpeedZeRoOffload(object):
            return _apply_forward_and_backward_to_tensors_only(module, _run_before_forward_function,
                                                               _run_after_backward_hook, inputs)

        @torch.compiler.disable
        def _post_backward_module_hook(module, inputs):
            if not hasattr(module, "ds_grads_remaining"):
                module.ds_grads_remaining = 0
@ -556,8 +556,9 @@ class InsertPostInitMethodToModuleSubClasses(object):
            print_rank_0(
                "nn.functional.linear has been overridden with a more memory efficient version. This will persist unless manually reset.",
                force=False)
            self.linear_bk = torch.nn.functional.linear
            torch.nn.functional.linear = zero3_linear_wrap
            if not hasattr(InsertPostInitMethodToModuleSubClasses, "linear_bk"):
                InsertPostInitMethodToModuleSubClasses.linear_bk = torch.nn.functional.linear
            torch.nn.functional.linear = zero3_linear_wrap

        if self.quantized_initialization:
            print_rank_0("nn.functional.linear has been overridden with quantized linear version.", force=False)
@ -1511,7 +1511,8 @@ class DeepSpeedZeroOptimizer_Stage3(ZeROOptimizer):

            # free the gradient
            if not get_accelerator().is_synchronized_device():
                param.grad.record_stream(get_accelerator().current_stream())
                if param.grad is not None:
                    param.grad.record_stream(get_accelerator().current_stream())
            param.grad = None

        if self.offload_optimizer and self.swap_optimizer:
32
op_builder/dc.py
Normal file
@ -0,0 +1,32 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

from .builder import TorchCPUOpBuilder


class DeepCompileBuilder(TorchCPUOpBuilder):
    BUILD_VAR = "DS_BUILD_DEEP_COMPILE"
    NAME = "dc"

    def __init__(self):
        super().__init__(name=self.NAME)

    def absolute_name(self):
        return f'deepspeed.ops.{self.NAME}_op'

    def sources(self):
        return [
            'csrc/compile/deepcompile.cpp', 'csrc/compile/init.cpp', 'csrc/compile/z1.cpp', 'csrc/compile/z3.cpp',
            'csrc/compile/util.cpp'
        ]

    def libraries_args(self):
        args = super().libraries_args()
        return args

    def include_paths(self):
        import os
        import torch
        return ['csrc/includes', os.path.join(torch.utils.cpp_extension.CUDA_HOME, "include")]
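As a usage note (an assumption based on the standard DeepSpeed op-builder flow, not something added by this PR), the op can be pre-built by setting the `BUILD_VAR` above (`DS_BUILD_DEEP_COMPILE=1`) at install time, or JIT-loaded roughly like this:

```python
from deepspeed.ops.op_builder import DeepCompileBuilder

dc_builder = DeepCompileBuilder()
if dc_builder.is_compatible():
    dc_ops = dc_builder.load()  # builds csrc/compile/*.cpp on first use
```
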
1
requirements/requirements-deepcompile.txt
Normal file
@ -0,0 +1 @@
scipy
1
setup.py
@ -91,6 +91,7 @@ extras_require = {
    'inf': fetch_requirements('requirements/requirements-inf.txt'),
    'sd': fetch_requirements('requirements/requirements-sd.txt'),
    'triton': fetch_requirements('requirements/requirements-triton.txt'),
    'deepcompile': fetch_requirements('requirements/requirements-deepcompile.txt'),
}

# Only install pynvml on nvidia gpus.
@ -66,3 +66,49 @@ class TestZeRO(DistributedTest):
            config_dict["bf16"] = {"enabled": True}

        compare_loss(self, config_dict, dtype)


class TestDeepCompile(DistributedTest):
    world_size = 2
    non_daemonic_procs = True

    @pytest.mark.parametrize('dtype', [torch.bfloat16, torch.float16, torch.float32])
    @pytest.mark.parametrize('zero_stage', [1, 3])
    @pytest.mark.parametrize('deepcompile', [True])  # deepcompile==False is included in test_compile_zero
    def test(self, zero_stage, dtype, deepcompile):
        if not required_torch_version(min_version=2.6):
            pytest.skip("DeepCompile requires PyTorch >= v2.6")

        if dtype == torch.bfloat16:
            skip_on_arch(min_arch=8)
        if dtype == torch.bfloat16 and not bf16_required_version_check():
            pytest.skip(
                "DeepSpeed BFloat16 tests need NCCL >= 2.10.3, CUDA >=11.0, and HW support for BFloat16 to run correctly"
            )
        if get_accelerator().device_name() == "cpu":
            pytest.skip("CPU does not support this test yet")

        config_dict = {
            "train_micro_batch_size_per_gpu": 1,
            "steps_per_print": 1,
            "optimizer": {
                "type": "Adam",
                "params": {
                    "lr": 0.00015
                }
            },
            "zero_optimization": {
                "stage": zero_stage,
            },
            "compile": {
                "deepcompile": deepcompile
            }
        }

        if dtype == torch.float16:
            config_dict["fp16"] = {"enabled": True, "initial_scale_power": 8}
        elif dtype == torch.bfloat16:
            config_dict["bf16"] = {"enabled": True}

        # Need warmup steps
        compare_loss(self, config_dict, dtype, iteration=10)
@ -70,8 +70,7 @@ def enable_determinism(seed: int):


@enable_determinism(123)
def compare_loss(self, config, dtype):
    iteration = 5
def compare_loss(self, config, dtype, iteration=5):
    hidden_dim = 10
    RTOL = 5e-1
    ATOL = 1e-2
@ -116,9 +115,12 @@ def compare_loss(self, config, dtype):
        baseline_engine.backward(baseline_loss)
        target_engine.backward(target_loss)

        baseline_optimizer.step()
        target_optimizer.step()
        baseline_engine.step()
        target_engine.step()

    with GatheredParameters(target_engine.parameters()):
        for p1, p2 in zip(baseline_engine.parameters(), target_engine.parameters()):
            assert torch.allclose(p1.to(dtype), p2, rtol=RTOL, atol=ATOL)

    baseline_engine.destroy()
    target_engine.destroy()