DeepNVMe update (#7215)

- FastPersist
- ZeRO-Inference+SGLang

---------

Signed-off-by: Olatunji Ruwase <olruwase@microsoft.com>
Signed-off-by: Olatunji Ruwase <tunji.ruwase@snowflake.com>
Co-authored-by: jerryyangli <jerryyangli@gmail.com>
Co-authored-by: Yang Li <yangli2@microsoft.com>
Co-authored-by: Guanhua Wang <alexwgh333@gmail.com>
Co-authored-by: Connor Holmes <connorholmes@microsoft.com>
Co-authored-by: Bing Xie <67908712+xiexbing@users.noreply.github.com>
Co-authored-by: cassieesvelt <73311224+cassieesvelt@users.noreply.github.com>
Co-authored-by: Jeff Rasley <jerasley@microsoft.com>
Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>
Co-authored-by: Michael Wyatt <michaelwyatt@microsoft.com>
Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
Co-authored-by: swli <47371259+lucasleesw@users.noreply.github.com>
Co-authored-by: Cheng Li <pistasable@gmail.com>
Co-authored-by: Molly Smith <112220543+molly-smith@users.noreply.github.com>
Co-authored-by: Ubuntu <jomayeri@microsoft.com>
Co-authored-by: Olatunji Ruwase <tunji.ruwase@snowflake.com>
Co-authored-by: Zhipeng Wang <zhipeng.rainbowserie@gmail.com>
Commit 24a1d8f936 (parent cb3ad0c176), authored by Olatunji Ruwase on 2025-06-06 18:49:41 -04:00 and committed via GitHub.
107 changed files with 3484 additions and 920 deletions

View File

@ -0,0 +1,137 @@
<div align="center">
# DeepNVMe: Affordable I/O scaling for Deep Learning Applications.
</div>
# Introduction
We introduced [DeepNVMe](https://github.com/deepspeedai/DeepSpeed/blob/master/blogs/deepnvme/08-2024/README.md) in summer 2024 as a suite of optimizations for tackling I/O bottlenecks in Deep Learning (DL). DeepNVMe delivers significant speedups for I/O bound DL workloads by leveraging storage innovations including local NVMe SSDs, NVIDIA Magnum IO<sup>TM</sup> GPUDirect® Storage (GDS), and Linux Asynchronous I/O (AIO).
In this update, we are delighted to announce DeepNVMe improvements on multiple fronts: (i) expanded application coverage, now including FastPersist model checkpointing and SGLang inference, (ii) I/O performance scaling through an upgrade from PCIe Gen4 to Gen5 NVMe SSDs, and (iii) broader usability through support for CPU-only environments, offset-based I/O operations, and tensor data type casting. The results reported in this blog are available in DeepSpeed versions >= [0.17.1](https://github.com/deepspeedai/DeepSpeed/releases/tag/v0.17.1).
# Evaluation environments
Our experiments are conducted on an Azure [ND-H200-v5](https://learn.microsoft.com/en-us/azure/virtual-machines/sizes/gpu-accelerated/nd-h200-v5-series?tabs=sizebasic) VM. The key software configurations are summarized in the following table.
| Software | Version |
|---|---|
| Ubuntu | 24.04.2 |
| PyTorch | 2.6.0 |
| CUDA | 12.6 |
| SGLang | 0.4.4.post4 |
# Addressing I/O Bottlenecks of Deep Learning
We used DeepNVMe to develop FastPersist and ZeRO-Inference, which target I/O bottlenecks in DL training and inference respectively. Our experiments are conducted on a single VM, in which we combine the available NVMe SSDs into a single RAID-0 (i.e., disk striping) volume to leverage their aggregate read and write bandwidths. Since DeepNVMe can transfer tensors either through CPU bounce buffers (a.k.a. AIO) or via NVIDIA GPUDirect Storage (a.k.a. GDS), we report results for both modes.
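For context, the two modes correspond to two handle types exposed by the DeepNVMe ops. The snippet below is a minimal sketch, modeled on the benchmark code in this update, of how each handle is constructed; the tuning values (block size, queue depth, etc.) are illustrative assumptions rather than the settings used in our experiments.
```python
# Minimal sketch of constructing handles for the two DeepNVMe offload modes.
# Parameter values below are illustrative assumptions.
import torch
from deepspeed.ops.op_builder import AsyncIOBuilder, GDSBuilder

block_size, queue_depth = 1024 * 1024, 32
single_submit, overlap_events, intra_op_parallelism = False, True, 1

# AIO mode: transfers are staged through pinned CPU bounce buffers.
aio_handle = AsyncIOBuilder().load().aio_handle(block_size, queue_depth, single_submit,
                                                overlap_events, intra_op_parallelism)

# GDS mode: direct GPU<->NVMe transfers; GPU buffers are registered (pinned) with the handle.
gds_handle = GDSBuilder().load().gds_handle(block_size, queue_depth, single_submit,
                                            overlap_events, intra_op_parallelism)
gpu_buffer = torch.empty(1024 * 1024, dtype=torch.uint8, device='cuda')
gds_handle.pin_device_tensor(gpu_buffer)
```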
## FastPersist: Faster Model Checkpoint Creation
Although saving model checkpoints to persistent storage is critical in model training, it is also a major bottleneck due to the inefficiencies of existing approaches. We developed [FastPersist](https://arxiv.org/abs/2406.13768) to address the performance challenges of checkpointing. FastPersist makes checkpointing overheads negligible during training through three key techniques: (i) DeepNVMe, (ii) data parallelism, and (iii) overlapping I/O and computation.
Our goal here is to demonstrate the impact of DeepNVMe in FastPersist using single-process micro-benchmarks (available [here](https://github.com/deepspeedai/DeepSpeedExamples/tree/master/deepnvme/fastpersist)) which serialize a model checkpoint state from HBM to local NVMe. We use the popular PyTorch `torch.save()` as the baseline in our experiments, and integrate FastPersist into `torch.save()` to simplify adoption and performance comparisons.
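To make the integration concrete, the snippet below is a minimal sketch, modeled on the benchmark code in this update, of saving a model through FastPersist's file writer: a DeepNVMe `aio_handle` and a pinned staging buffer are created, and a `deepspeed.io.FastFileWriter` is passed to `torch.save()` as the file object. The output path and buffer size are illustrative assumptions.
```python
# Minimal sketch of saving a checkpoint through FastPersist's file writer.
# The output path and pinned-buffer size are illustrative assumptions.
import torch
from deepspeed.ops.op_builder import AsyncIOBuilder
from deepspeed.io import FastFileWriter

model = torch.nn.Linear(4096, 4096)

# DeepNVMe handle (default settings) and a pinned CPU staging buffer.
aio_handle = AsyncIOBuilder().load().aio_handle()
pinned_buffer = aio_handle.new_cpu_locked_tensor(64 * 1024 * 1024, torch.empty(0, dtype=torch.uint8))

# FastPersist plugs into torch.save() as a file-like object.
writer = FastFileWriter(file_path='/mnt/nvme/phi3_mini.pt',
                        aio_handle=aio_handle,
                        pinned_tensor=pinned_buffer)
torch.save(obj=model.state_dict(), f=writer)
writer.close()  # flush pending writes to NVMe

aio_handle.free_cpu_locked_tensor(pinned_buffer)
```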
### Faster Saving of PyTorch Models to local NVMe Storage
We measure the throughput of serializing the Phi-3-Mini checkpoint state from HBM to local NVMe storage. The results are summarized in the Figure below. We observe significantly faster checkpointing with FastPersist compared to the baseline, with speedups of over 20X in the 8xGen5 NVMe setting. We also observe FastPersist throughput scaling with the increased NVMe bandwidth of 8xGen5 compared with 4xGen5.
<img src="./media/fastpersist_phi3_mini.png">
<div align="center">
FastPersist provides significantly faster model checkpointing to local NVMe.
</div>
## ZeRO-Inference: Democratizing Generative AI
[ZeRO-Inference](https://github.com/deepspeedai/DeepSpeedExamples/blob/master/inference/huggingface/zero_inference/README.md) is a technology that democratizes access to state-of-the-art models by reducing the GPU costs of model inference. ZeRO-Inference enables inference computations of massive models (hundreds of billions of parameters) on as few as one GPU by offloading the model weights to DRAM and NVMe storage. ZeRO-Inference is designed for offline or throughput-oriented inference scenarios. In this blog, we share two updates on ZeRO-Inference. First, we have integrated ZeRO-Inference into SGLang, a state-of-the-art model serving framework. Second, we observed that ZeRO-Inference performance scales with the faster NVMe SSDs available in the latest Azure SKUs.
### Democratizing SGLang through ZeRO-Inference integration
[SGLang](https://docs.sglang.ai/) is a state-of-the-art serving framework for large language models (LLMs) and vision language models (VLMs). Our integration of ZeRO-Inference into SGLang makes SGLang available to budget-constrained users, and offers a cost-reduction option to existing SGLang users. We used SGLang's [offline benchmarking tool](https://github.com/sgl-project/sglang/blob/main/python/sglang/bench_offline_throughput.py) to measure the generation throughput of LLAMA3-70B on a single H200 with NVMe offloading (LLAMA3-70B cannot fit in the 141GB VRAM without offloading). The experiment is configured with prompt length of 512, generation length of 32, and batch size of 128. We summarize the results in the figure below for both AIO and GDS offloading.
<img src="./media/sg_zinf_llama_70b.png">
<div align="center">
ZeRO-Inference improves SGLang inference with NVMe offloading to reduce hardware costs.
</div>
### Scaling HF Transformer Generation with Faster NVMe SSDs
ZeRO-Inference enhances HF Transformer inference with efficient model offloading to DRAM or NVMe. We previously [evaluated](https://github.com/deepspeedai/DeepSpeed/blob/master/blogs/deepspeed-gds/README.md#high-performance-offloading-via-nvme-scaling) LLAMA-3-70B generation performance with NVMe offloading on a single GPU and four Gen4 NVMes in an Azure [NC_A100_v4](https://learn.microsoft.com/en-us/azure/virtual-machines/sizes/gpu-accelerated/nca100v4-series?tabs=sizebasic) VM. We measured the generation speed for a prompt of 512 tokens, an output of 32 tokens, and a batch size of 96. Since NVMe bandwidth was the main bottleneck, we repeated the experiments on Azure ND-H200-v5, which offers Gen5 NVMes. The results, summarized in the Figure below, show that ZeRO-Inference exploits the increased NVMe bandwidth to improve generation speeds. For example, with GDS, generation speed improves from 7 tokens/sec with four Gen4 NVMes to 17 tokens/sec with four Gen5 NVMes, and further to 26 tokens/sec with eight Gen5 NVMes. We observe similar improvements without GDS. These results show that ZeRO-Inference performance can be improved in a cost-effective manner by increasing NVMe bandwidth.
<img src="./media/hf_zinf_llama_70b.png">
<div align="center">
ZeRO-Inference leverages available NVMe bandwidth to scale LLAMA-3-70B generation.
</div>
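For reference, NVMe offloading in these HF Transformer experiments is driven by a ZeRO stage-3 configuration that places model parameters on NVMe. The snippet below is a minimal illustrative sketch of such a configuration; the paths, buffer counts, and AIO tuning values are assumptions rather than the exact settings used in our runs.
```python
# Minimal sketch of a ZeRO-3 NVMe-offload (ZeRO-Inference style) configuration.
# nvme_path, buffer sizes, and aio tuning values are illustrative assumptions.
ds_config = {
    "fp16": {"enabled": True},
    "train_micro_batch_size_per_gpu": 1,
    "zero_optimization": {
        "stage": 3,
        "offload_param": {
            "device": "nvme",               # offload model weights to NVMe
            "nvme_path": "/mnt/nvme_raid0",
            "pin_memory": True,
            "buffer_count": 5,
            "buffer_size": 1_000_000_000,
        },
    },
    "aio": {                                # DeepNVMe tuning knobs
        "block_size": 1_048_576,
        "queue_depth": 8,
        "single_submit": False,
        "overlap_events": True,
    },
}
```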
# I/O performance scaling
We used our `ds_io` benchmarking tool to demonstrate that DeepNVMe scales I/O performance proportionally with the available NVMe bandwidth. This empowers users to accelerate I/O-bound DL applications at modest cost by adding more or faster NVMe SSDs. In our experiments, we measure the achieved read and write bandwidths of 1GB data transfers between HBM and NVMe. We evaluate scaling up NVMes from PCIe Gen4 to Gen5, and scaling out from 4 to 8 SSDs. The SSDs are combined into a single RAID-0 (disk striping) volume. We summarize the results in the Figure below, which shows that DeepNVMe scales I/O performance along both dimensions. Scaling up from 4xGen4 SSDs to 4xGen5 SSDs improves reads from 10GB/sec to 27GB/sec, and writes from 5GB/sec to 11GB/sec. Scaling out from 4xGen5 to 8xGen5 further improves reads to 48GB/sec, and writes to 26GB/sec.
<img src="./media/dnvme_scaling.png">
<div align="center">
Microbenchmark shows DeepNVMe scales I/O performance with available NVMe bandwidth
</div>
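For readers who want to reproduce a similar measurement without the full `ds_io` harness, the snippet below is a minimal sketch that times a synchronous 1GB write from a pinned CPU buffer to the RAID-0 volume using an `aio_handle`; the GDS path would instead use a `gds_handle` with a pinned GPU buffer. The destination path is an illustrative assumption.
```python
# Minimal sketch of a 1GB write-bandwidth measurement with an aio_handle.
# The destination path is an illustrative assumption.
import time
import torch
from deepspeed.ops.op_builder import AsyncIOBuilder

num_bytes = 1024 ** 3  # 1GB transfer
handle = AsyncIOBuilder().load().aio_handle()  # default block size, queue depth, etc.
buffer = handle.new_cpu_locked_tensor(num_bytes, torch.empty(0, dtype=torch.uint8))

start = time.time()
ret = handle.pwrite(buffer, '/mnt/nvme_raid0/ds_io_test.bin', False, False)  # validate=False, async=False
assert ret != -1
elapsed = time.time() - start
print(f'write bandwidth: {num_bytes / elapsed / 1e9:.1f} GB/sec')

handle.free_cpu_locked_tensor(buffer)
```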
# Broadening usability
We have broadened the applicability of DeepNVMe by removing restrictions on hardware environments and I/O operations, as explained below.
## CPU-Only environments
Although GPUs (and similar accelerators) dominate DL, CPUs are still used in important machine learning (ML) workloads such as recommendation systems. However, DeepNVMe was previously unusable in CPU-only environments because it relied on `torch.Tensor.pin_memory()` to obtain page-locked CPU tensors, and `pin_memory()` does not work in CPU-only builds of `torch`, as illustrated below.
```python
>>> import torch
>>> torch.__version__
'2.6.0+cpu'
>>> x = torch.empty(1024).pin_memory()
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
RuntimeError: Cannot access accelerator device when none is available.
>>>
```
We have made DeepNVMe usable in CPU environments by adding mechanisms for allocating (`new_cpu_locked_tensor()`) and releasing (`free_cpu_locked_tensor()`) page-locked CPU tensors. The snippet below illustrates allocating a pinned CPU tensor (`x`).
```python
>>> import torch
>>> torch.__version__
'2.6.0+cpu'
>>> from deepspeed.ops.op_builder import AsyncIOBuilder
>>> h = AsyncIOBuilder().load().aio_handle()
>>> x = h.new_cpu_locked_tensor(1024, torch.Tensor())
>>> x.shape
torch.Size([1024])
>>> x.dtype
torch.float32
```
## Offset-based I/O operations
Previously, DeepNVMe functionality was restricted to reading or writing the entire contents of a file. We have now improved DeepNVMe to read or write a user-specified portion of file content starting from a given offset. In particular, we have extended the existing read/write APIs to accept a user-specified `file_offset` argument (with a default value of 0), as shown below:
```python
>>> from deepspeed.ops.op_builder import AsyncIOBuilder
>>> help(AsyncIOBuilder().load().aio_handle().pread)
Help on method pread in module async_io:
pread(...) method of async_io.aio_handle instance
pread(self: async_io.aio_handle, buffer: torch.Tensor, filename: str, validate: bool, async: bool, file_offset: int = 0) -> int
```
## Tensor data type casting
While developing FastPersist, we needed to manipulate model tensors, typically of floating point data types, in byte format for both performance and convenience of I/O operations. However, we could not find a zero-copy mechanism for casting tensors from arbitrary data types to a byte data type (i.e., `torch.uint8`), so we decided to create one. This functionality is available via the `UtilsBuilder` op as demonstrated in the example below. In the example, we cast a `torch.bfloat16` tensor into `torch.uint8`. Note that due to the zero-copy nature of the functionality, `bf16_tensor` and `byte_tensor` are aliases.
```python
>>> import torch
>>> from deepspeed.ops.op_builder import UtilsBuilder
>>> util_ops = UtilsBuilder().load()
>>> bf16_tensor = torch.zeros(1024, dtype=torch.bfloat16, device='cuda')
>>> bf16_tensor
tensor([0., 0., 0., ..., 0., 0., 0.], device='cuda:0', dtype=torch.bfloat16)
>>> byte_tensor = util_ops.cast_to_byte_tensor(bf16_tensor)
>>> byte_tensor
tensor([0, 0, 0, ..., 0, 0, 0], device='cuda:0', dtype=torch.uint8)
>>> bf16_tensor += 1.0
>>> bf16_tensor
tensor([1., 1., 1., ..., 1., 1., 1.], device='cuda:0', dtype=torch.bfloat16)
>>> byte_tensor
tensor([128, 63, 128, ..., 63, 128, 63], device='cuda:0',
dtype=torch.uint8)
```
# Summary
This blog post has provided updates on our continued development of DeepNVMe, an I/O optimization technology for accelerating DL applications. We have announced DeepNVMe improvements on multiple fronts, including application coverage, I/O performance scaling, and usability.
# Acknowledgements
This blog describes work done by Joe Mayer, Logan Adams, and Olatunji Ruwase of the DeepSpeed team at Microsoft.

(6 new blog figure images added and 4 existing images updated; binary files not shown.)

View File

@ -101,7 +101,7 @@ int io_prep_generator::prep_iocbs(const int n_iocbs, std::vector<struct iocb*>*
return actual_n_iocbs;
}
int get_file_size(const char* filename, int64_t& size)
int64_t get_file_size(const char* filename, int64_t& size)
{
struct stat st;
if (stat(filename, &st) == -1) { return -1; }
@ -109,6 +109,14 @@ int get_file_size(const char* filename, int64_t& size)
return 0;
}
int64_t get_fd_file_size(const int fd, int64_t& size)
{
struct stat st;
if (fstat(fd, &st) == -1) { return -1; }
size = st.st_size;
return 0;
}
void* ds_page_aligned_alloc(const int64_t size, const bool lock)
{
void* ptr;

View File

@ -78,4 +78,5 @@ struct io_prep_generator {
void* ds_page_aligned_alloc(const int64_t size, const bool lock = false);
int get_file_size(const char* filename, int64_t& size);
int64_t get_file_size(const char* filename, int64_t& size);
int64_t get_fd_file_size(const int fd, int64_t& size);

View File

@ -11,20 +11,19 @@ io_op_desc_t::io_op_desc_t(const bool read_op,
const torch::Tensor& buffer,
const int fd,
const char* filename,
const int64_t file_num_bytes,
const int intra_op_parallelism,
const bool validate,
const int64_t file_offset)
: _read_op(read_op),
_buffer(buffer),
_fd(fd),
_filename(filename),
_file_num_bytes(file_num_bytes),
_filename((filename == nullptr) ? std::string() : filename),
_file_offset(file_offset),
_intra_op_parallelism(intra_op_parallelism),
_num_bytes_per_thread(static_cast<int64_t>(buffer.nbytes()) / intra_op_parallelism),
_validate(validate)
{
if (validate) { assert(nullptr != filename); }
}
char* io_op_desc_t::data_ptr() const { return (char*)_contiguous_buffer.data_ptr(); }

View File

@ -13,8 +13,7 @@ struct io_op_desc_t {
const bool _read_op;
torch::Tensor _buffer;
int _fd;
const std::string _filename;
const int64_t _file_num_bytes;
std::string _filename;
const int _intra_op_parallelism;
const int64_t _num_bytes_per_thread;
torch::Tensor _contiguous_buffer;
@ -25,7 +24,6 @@ struct io_op_desc_t {
const torch::Tensor& buffer,
const int fd,
const char* filename,
const int64_t file_num_bytes,
const int intra_op_parallelism,
const bool validate,
const int64_t file_offset);

View File

@ -9,23 +9,15 @@
using namespace std;
cpu_op_desc_t::cpu_op_desc_t(
const std::unique_ptr<struct deepspeed_pin_tensor_t>& pinned_tensor_mgr,
const bool read_op,
const torch::Tensor& buffer,
const std::unique_ptr<struct deepspeed_pin_tensor_t>& pinned_tensor_mgr,
const int fd,
const char* filename,
const int64_t file_num_bytes,
const int intra_op_parallelism,
const bool validate,
const int64_t file_offset)
: io_op_desc_t(read_op,
buffer,
fd,
filename,
file_num_bytes,
intra_op_parallelism,
validate,
file_offset),
: io_op_desc_t(read_op, buffer, fd, filename, intra_op_parallelism, validate, file_offset),
_cpu_buffer(buffer),
_pinned_tensor_mgr(pinned_tensor_mgr),
_is_managed_bounce_buffer(false)
@ -66,7 +58,8 @@ void cpu_op_desc_t::finish()
void cpu_op_desc_t::validate()
{
validate_aio_operation(_read_op, _filename.c_str(), data_ptr(), _file_num_bytes);
const auto num_io_bytes = static_cast<int64_t>(_contiguous_buffer.nbytes());
validate_aio_operation(_read_op, _filename.c_str(), data_ptr(), num_io_bytes);
}
void cpu_op_desc_t::run(const int tid,

View File

@ -13,12 +13,11 @@ struct cpu_op_desc_t : io_op_desc_t {
bool _is_managed_bounce_buffer;
const std::unique_ptr<struct deepspeed_pin_tensor_t>& _pinned_tensor_mgr;
cpu_op_desc_t(const bool read_op,
cpu_op_desc_t(const std::unique_ptr<struct deepspeed_pin_tensor_t>& pinned_tensor_mgr,
const bool read_op,
const torch::Tensor& buffer,
const std::unique_ptr<struct deepspeed_pin_tensor_t>& pinned_tensor_mgr,
const int fd,
const char* filename,
const int64_t file_num_bytes,
const int intra_op_parallelism,
const bool validate,
const int64_t file_offset);

View File

@ -6,7 +6,6 @@
/*
Functionality for swapping optimizer tensors to/from (NVMe) storage devices.
*/
#include <condition_variable>
#include <memory>
#include "deepspeed_py_io_handle.h"

View File

@ -10,10 +10,30 @@ Functionality for swapping optimizer tensors to/from (NVMe) storage devices.
#include "deepspeed_py_io_handle.h"
#include <cstdlib>
#define O_DIRECT_ALIGNMENT 512
using namespace std;
static void _start_aio_thread(std::shared_ptr<struct deepspeed_aio_thread_t> ctxt) { ctxt->run(); }
static bool is_valid_bytes_to_read(const char* filename,
const int64_t file_offset,
const int64_t num_bytes_to_read)
{
int64_t num_file_bytes;
if (-1 == get_file_size(filename, num_file_bytes)) {
const auto error_code = errno;
report_file_error(filename, " fstat for read", error_code);
return false;
}
if ((file_offset + num_bytes_to_read) > num_file_bytes) {
std::cout << filename << ": file_offset + buffer nbytes > file bytes "
<< (file_offset + num_bytes_to_read) << " > " << num_file_bytes << std::endl;
}
assert((file_offset + num_bytes_to_read) <= num_file_bytes);
return true;
}
deepspeed_io_handle_t::deepspeed_io_handle_t(const int block_size,
const int queue_depth,
const bool single_submit,
@ -58,6 +78,11 @@ const bool deepspeed_io_handle_t::get_overlap_events() const { return _overlap_e
const int deepspeed_io_handle_t::get_intra_op_parallelism() const { return _intra_op_parallelism; }
const int deepspeed_io_handle_t::get_alignment() const
{
return _intra_op_parallelism * O_DIRECT_ALIGNMENT;
}
int deepspeed_io_handle_t::read(torch::Tensor& buffer,
const char* filename,
const bool validate,
@ -185,7 +210,7 @@ int deepspeed_io_handle_t::wait()
completed_op->finish();
close(completed_op->_fd);
if (!completed_op->_filename.empty()) { close(completed_op->_fd); }
--_num_pending_ops;
++num_completed_ops;
@ -199,7 +224,8 @@ bool deepspeed_io_handle_t::_is_valid_parallel_aio_op(const bool read_op, const
const auto op_string = read_op ? "Read" : "Write";
if (num_bytes % get_intra_op_parallelism()) {
std::cout << "deepspeed_aio failure: parallel " << op_string << " num_bytes = " << num_bytes
<< " not divisible by thread count = " << get_intra_op_parallelism() << std::endl;
<< " not divisible by intra op parallelism = " << get_intra_op_parallelism()
<< std::endl;
return false;
}
@ -211,45 +237,61 @@ std::shared_ptr<struct io_op_desc_t> deepspeed_io_handle_t::_create_io_op_desc(
const torch::Tensor& buffer,
const int fd,
const char* filename,
const int64_t file_num_bytes,
const bool validate,
const int64_t file_offset)
{
return std::make_shared<cpu_op_desc_t>(read_op,
return std::make_shared<cpu_op_desc_t>(_pinned_tensor_mgr,
read_op,
buffer,
_pinned_tensor_mgr,
fd,
filename,
file_num_bytes,
_intra_op_parallelism,
validate,
file_offset);
}
int deepspeed_io_handle_t::_pread(const torch::Tensor& buffer,
const int fd,
const char* filename,
const bool validate,
const bool async,
const int64_t file_offset)
{
auto scheduled_op = _create_io_op_desc(true, buffer, fd, filename, validate, file_offset);
_schedule_aio_work(scheduled_op);
if (async) { return 0; }
return wait();
}
int deepspeed_io_handle_t::pread(const torch::Tensor& buffer,
const char* filename,
const bool validate,
const bool async,
const int64_t file_offset)
{
int64_t num_file_bytes;
if (-1 == get_file_size(filename, num_file_bytes)) {
const auto error_code = errno;
report_file_error(filename, " fstat for read", error_code);
return -1;
}
// buffer can exceed file size to enable 4k alignment
const auto buffer_bytes = static_cast<int64_t>(buffer.nbytes());
assert((num_file_bytes % _intra_op_parallelism) == 0);
if (!is_valid_bytes_to_read(filename, file_offset, buffer_bytes)) { return -1; }
if (!_is_valid_parallel_aio_op(true, buffer_bytes)) { return -1; }
const auto fd = open_file(filename, true);
if (fd == -1) { return -1; }
auto scheduled_op =
_create_io_op_desc(true, buffer, fd, filename, num_file_bytes, validate, file_offset);
return _pread(buffer, fd, filename, validate, async, file_offset);
}
int deepspeed_io_handle_t::_pwrite(const torch::Tensor& buffer,
const int fd,
const char* filename,
const bool validate,
const bool async,
const int64_t file_offset)
{
auto scheduled_op = _create_io_op_desc(false, buffer, fd, filename, validate, file_offset);
_schedule_aio_work(scheduled_op);
@ -265,21 +307,13 @@ int deepspeed_io_handle_t::pwrite(const torch::Tensor& buffer,
const int64_t file_offset)
{
const auto num_write_bytes = static_cast<int64_t>(buffer.nbytes());
assert((num_write_bytes % _intra_op_parallelism) == 0);
if (!_is_valid_parallel_aio_op(false, num_write_bytes)) { return -1; }
const auto fd = open_file(filename, false);
if (fd == -1) { return -1; }
auto scheduled_op =
_create_io_op_desc(false, buffer, fd, filename, num_write_bytes, validate, file_offset);
_schedule_aio_work(scheduled_op);
if (async) { return 0; }
return wait();
return _pwrite(buffer, fd, filename, validate, async, file_offset);
}
int deepspeed_io_handle_t::sync_pread(torch::Tensor& buffer,
@ -310,6 +344,16 @@ int deepspeed_io_handle_t::async_pwrite(const torch::Tensor& buffer,
return pwrite(buffer, filename, false, true, file_offset);
}
int deepspeed_io_handle_t::async_pwrite(const torch::Tensor& buffer,
const int fd,
const int64_t file_offset = 0)
{
const auto num_write_bytes = static_cast<int64_t>(buffer.nbytes());
if (!_is_valid_parallel_aio_op(false, num_write_bytes)) { return -1; }
return _pwrite(buffer, fd, nullptr, false, true, file_offset);
}
at::Tensor deepspeed_io_handle_t::new_cpu_locked_tensor(const int64_t num_elem,
const torch::Tensor& example_tensor)
{

View File

@ -37,6 +37,7 @@ struct deepspeed_io_handle_t {
const bool get_single_submit() const;
const bool get_overlap_events() const;
const int get_intra_op_parallelism() const;
const int get_alignment() const;
int read(torch::Tensor& buffer,
const char* filename,
@ -67,6 +68,7 @@ struct deepspeed_io_handle_t {
int async_pread(torch::Tensor& buffer, const char* filename, const int64_t file_offset);
int async_pwrite(const torch::Tensor& buffer, const char* filename, const int64_t file_offset);
int async_pwrite(const torch::Tensor& buffer, const int fd, const int64_t file_offset);
// TODO: Make API's args to be shape and dtype.
torch::Tensor new_cpu_locked_tensor(const int64_t num_elem,
@ -84,11 +86,24 @@ struct deepspeed_io_handle_t {
bool _is_valid_parallel_aio_op(const bool read_op, const int64_t num_bytes);
int _pread(const torch::Tensor& buffer,
const int fd,
const char* filename,
const bool validate,
const bool async,
const int64_t file_offset);
int _pwrite(const torch::Tensor& buffer,
const int fd,
const char* filename,
const bool validate,
const bool async,
const int64_t file_offset);
virtual std::shared_ptr<struct io_op_desc_t> _create_io_op_desc(const bool read_op,
const torch::Tensor& buffer,
const int fd,
const char* filename,
const int64_t file_num_bytes,
const bool validate,
const int64_t file_offset);
};

View File

@ -6,7 +6,6 @@
/*
Functionality for swapping optimizer tensors to/from (NVMe) storage devices.
*/
#include <torch/extension.h>
#include "deepspeed_py_aio_handle.h"
#include "deepspeed_py_copy.h"
@ -34,6 +33,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
.def("get_single_submit", &deepspeed_aio_handle_t::get_single_submit)
.def("get_overlap_events", &deepspeed_aio_handle_t::get_overlap_events)
.def("get_intra_op_parallelism", &deepspeed_aio_handle_t::get_intra_op_parallelism)
.def("get_alignment", &deepspeed_aio_handle_t::get_alignment)
.def("read",
&deepspeed_aio_handle_t::read,
@ -53,7 +53,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
.def("pread",
&deepspeed_aio_handle_t::pread,
"Parallel file read with option of parallelism. Returns count of completed read ops",
"Parallel file read with option of asynchronous completion. If synchronous, returns "
"count of completed read ops",
"buffer"_a,
"filename"_a,
"validate"_a,
@ -62,7 +63,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
.def("pwrite",
&deepspeed_aio_handle_t::pwrite,
"Parallel file write with option of parallelism. Returns count of completed write ops",
"Parallel file write with option of asynchronous completion. If synchronous, returns "
"count of completed write ops",
"buffer"_a,
"filename"_a,
"validate"_a,
@ -71,7 +73,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
.def("sync_pread",
&deepspeed_aio_handle_t::sync_pread,
"Synchrononous parallel file read. Returns count of completed read ops",
"Synchronous parallel file read. Returns count of completed read ops",
"buffer"_a,
"filename"_a,
"file_offset"_a = 0)
@ -86,17 +88,27 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
.def("async_pread",
&deepspeed_aio_handle_t::async_pread,
"Asynchronous parallel file read. Returns 0 on success. Returns 0 on success, and "
"following wait() returns count of completed ops.",
"subsequent wait() returns count of completed ops.",
"buffer"_a,
"filename"_a,
"file_offset"_a = 0)
.def(
"async_pwrite",
py::overload_cast<const torch::Tensor&, const char*, const int64_t>(
&deepspeed_aio_handle_t::async_pwrite),
"Asynchronous parallel file write. Returns 0 on success, and subsequent wait() returns "
"count of completed ops.",
"buffer"_a,
"filename"_a,
"file_offset"_a = 0)
.def("async_pwrite",
&deepspeed_aio_handle_t::async_pwrite,
"Asynchronous parallel file write. Returns 0 on success, and following wait() returns "
"count of completed ops.",
py::overload_cast<const torch::Tensor&, const int, const int64_t>(
&deepspeed_aio_handle_t::async_pwrite),
"Asynchronous parallel file write using opened python file object.",
"buffer"_a,
"filename"_a,
"fd"_a,
"file_offset"_a = 0)
.def("new_cpu_locked_tensor",

View File

@ -17,7 +17,7 @@ from perf_sweep_utils import READ_OP_DESC, WRITE_OP_DESC, BENCH_LOG_DIR, \
READ_LOG_DIR, WRITE_LOG_DIR
from deepspeed.ops.op_builder import AsyncIOBuilder
OTHER_OPTIONS = '--handle'
OTHER_OPTIONS = '--engine aio_handle'
PERF_SCRIPT = 'test_ds_aio.py'
DEFAULT_SWEEP_CONFIG = {
"block_size": ["128K", "1M"],
@ -109,6 +109,20 @@ def get_sweep_config_dict(sweep_config_json):
return sweep_config
QUEUE_DEPTH = "--queue_depth"
BLOCK_SIZE = "--block_size"
SINGLE_SUBMIT = "--single_submit"
SEQUENTIAL_REQUESTS = "--sequential_requests"
THREAD_COUNT = "--threads"
IO_PARALLEL = "--io_parallel"
DEPRECATED_KEYS = {THREAD_COUNT: "multi_process"}
def _handle_key_deprecation(key):
return DEPRECATED_KEYS.get(f'--{key}', key)
def get_sweep_cmd_lines(sweep_config_dict):
def flatten_options(key, value_list):
@ -123,7 +137,7 @@ def get_sweep_cmd_lines(sweep_config_dict):
return flat_list
flat_list = [flatten_options(key, value) for key, value in sweep_config_dict.items()]
flat_list = [flatten_options(_handle_key_deprecation(key), value) for key, value in sweep_config_dict.items()]
cmd_list = list(itertools.product(*flat_list))
cmd_list = [list(cmd) for cmd in cmd_list]
#dump_cmd_lines(cmd_list)

View File

@ -0,0 +1,21 @@
python test_ds_aio.py \
--read \
--handle --io_size 400M \
--loops 3 \
--folder_to_device_mapping \
/mnt/nvme23/aio:0 \
/mnt/nvme23/aio:1 \
/mnt/nvme23/aio:2 \
/mnt/nvme23/aio:3 \
/mnt/nvme45/aio:4 \
/mnt/nvme45/aio:5 \
/mnt/nvme45/aio:6 \
/mnt/nvme45/aio:7 \
/mnt/nvme67/aio:8 \
/mnt/nvme67/aio:9 \
/mnt/nvme67/aio:10 \
/mnt/nvme67/aio:11 \
/mnt/nvme89/aio:12 \
/mnt/nvme89/aio:13 \
/mnt/nvme89/aio:14 \
/mnt/nvme89/aio:15 \

View File

@ -0,0 +1,20 @@
python test_ds_aio.py \
--handle --io_size 400M \
--loops 3 \
--folder_to_device_mapping \
/mnt/nvme23/aio:0 \
/mnt/nvme23/aio:1 \
/mnt/nvme23/aio:2 \
/mnt/nvme23/aio:3 \
/mnt/nvme45/aio:4 \
/mnt/nvme45/aio:5 \
/mnt/nvme45/aio:6 \
/mnt/nvme45/aio:7 \
/mnt/nvme67/aio:8 \
/mnt/nvme67/aio:9 \
/mnt/nvme67/aio:10 \
/mnt/nvme67/aio:11 \
/mnt/nvme89/aio:12 \
/mnt/nvme89/aio:13 \
/mnt/nvme89/aio:14 \
/mnt/nvme89/aio:15 \

View File

@ -0,0 +1,6 @@
python test_ds_aio.py \
--read \
--handle --io_size 400M \
--loops 3 \
--folder /mnt/nvme23/aio \
--multi_process 16

View File

@ -0,0 +1,5 @@
python test_ds_aio.py \
--handle --io_size 400M \
--loops 3 \
--folder /mnt/nvme23/aio \
--multi_process 16

View File

@ -9,6 +9,7 @@ Functionality of swapping optimizer tensors to/from (NVMe) storage devices.
import argparse
import os
from test_ds_aio_utils import refine_integer_value
from ds_aio_constants import AIO_HANDLE, AIO_BASIC, TORCH_FAST_IO, TORCH_IO, VALID_ENGINES
from deepspeed.accelerator import get_accelerator
MAPPING_DELIMITER = ':'
@ -21,6 +22,9 @@ def refine_args(args):
if args.block_size and type(args.block_size) == str:
args.block_size = refine_integer_value(args.block_size)
if args.fast_io_size and type(args.fast_io_size) == str:
args.fast_io_size = refine_integer_value(args.fast_io_size)
return args
@ -83,6 +87,19 @@ def validate_args(args):
no_error = no_error and no_mapping_error
error_messages += mapping_error_messages
# Validate --engine
if args.engine not in VALID_ENGINES:
no_error = False
error_messages.append(f'Invalid engine {args.engine}. Valid options = {VALID_ENGINES}')
# Validate --engine=torch_io
if args.engine == TORCH_IO:
if args.read:
no_error = False
error_messages.append(f'Read not currently supported for --engine={TORCH_IO}')
if not no_error:
print(f'Found {len(error_messages)} validation error(s)')
# Validate --gpu, --use_gds
if args.use_gds and not args.gpu:
error_messages.append(f'--gpu must be set to transfer with --use_gds')
@ -111,6 +128,8 @@ def parse_arguments():
parser.add_argument('--io_size', type=str, default=None, required=True, help='Number of bytes to read or write.')
parser.add_argument('--fast_io_size', type=str, default='64M', help='Size of fast_io pinned buffer (bytes).')
parser.add_argument('--read', action='store_true', help='Perform read I/O (default is write)')
parser.add_argument('--multi_process',
@ -138,7 +157,13 @@ def parse_arguments():
parser.add_argument('--validate', action='store_true', help='Perform validation of I/O transfer in library.')
parser.add_argument('--handle', action='store_true', help='Use AIO handle.')
parser.add_argument(
'--engine',
type=str,
default=AIO_HANDLE,
help=
f'Engine to perform I/O. Options are [{AIO_HANDLE}, {AIO_BASIC}, {TORCH_IO}, {TORCH_FAST_IO}]. Default is aio_handle'
)
parser.add_argument('--loops', type=int, default=3, help='Count of operation repetitions')
@ -152,6 +177,20 @@ def parse_arguments():
action='store_true',
help='For GPU memory transfers, measure impact of bounce buffer pinning on critical path.')
parser.add_argument('--torch_legacy_save', action='store_true', help='Use torch legacy save approach')
parser.add_argument('--use_accelerator_pin_memory',
action='store_true',
help='Obtain pinned (CPU page-locked) tensors from accelerator')
parser.add_argument('--warmup_loops', type=int, default=1, help='Count of operation warmup repetitions')
parser.add_argument('--include_warmup_time', action='store_true', help='Include warmup latency in results')
parser.add_argument('--different_file_each_iteration',
action='store_true',
help='Read/write a different file on each iteration.')
args = parser.parse_args()
print(f'args = {args}')
return args
@ -163,7 +202,7 @@ def get_validated_args():
if not validate_args(args):
quit()
print(f'Successful validation of command line arguments')
args.total_loops = args.warmup_loops + args.loops
peer_tag = 'gpu' if args.gpu else 'process'
args.mapping_dict = _get_mapping_dict(args)
args.mapping_list = [(device_id, folder) for device_id, folder in args.mapping_dict.items()]

View File

@ -6,129 +6,59 @@
Functionality of swapping optimizer tensors to/from (NVMe) storage devices.
"""
import torch
import os
import time
from deepspeed.ops.aio import AsyncIOBuilder
from multiprocessing import Pool, Barrier
from test_ds_aio_utils import report_results, task_log, task_barrier
from test_ds_aio_utils import task_log, create_filename, create_file, create_page_locked_tensor
from ds_aio_constants import *
def pre_basic(args, tid, read_op):
class AIOBasic_Engine(object):
def __init__(self, args, tid, read_op):
self.ctxt = self._create_context(args, tid, read_op)
def fini(self):
self.ctxt[BUFFER].detach()
self.ctxt[BUFFER] = None
def read(self, args, tid, loop_id):
start_time = time.time()
AsyncIOBuilder().load().aio_read(self.ctxt[BUFFER], self.ctxt[FILE], args.block_size, args.queue_depth,
args.single_submit, not args.sequential_requests, args.validate)
end_time = time.time()
self.ctxt[ELAPSED_SEC] += end_time - start_time
def write(self, args, tid, loop_id):
# Avoid overwriting existing files as it could be artificially faster
if os.path.isfile(self.ctxt[FILE]):
os.remove(self.ctxt[FILE])
start_time = time.time()
AsyncIOBuilder().load().aio_write(self.ctxt[BUFFER], self.ctxt[FILE], args.block_size, args.queue_depth,
args.single_submit, not args.sequential_requests, args.validate)
end_time = time.time()
self.ctxt[ELAPSED_SEC] += end_time - start_time
def _create_context(self, args, tid, read_op):
io_string = "Read" if read_op else "Write"
num_bytes = os.path.getsize(args.read_file) if read_op else args.write_size
file = args.read_file if read_op else f'{args.write_file}.{tid}'
device_id, folder = args.mapping_list[tid]
filename = create_filename(folder, args.read, args.io_size, tid)
if args.read and not (os.path.isfile(filename) and os.path.getsize(filename) == args.io_size):
create_file(filename, args.io_size)
task_log(tid, f'Allocate tensor of size {num_bytes} bytes')
buffer = torch.empty(num_bytes, dtype=torch.uint8, device='cpu').pin_memory()
task_log(tid, f'{io_string} file {file} of size {num_bytes} bytes from buffer on device {buffer.device}')
task_log(tid, f'Allocate tensor of size {args.io_size} bytes')
buffer = create_page_locked_tensor(args.io_size, True)
task_log(tid,
f'{io_string} file {filename} of size {args.io_size} bytes from buffer on device {buffer.device}')
task_log(tid, f'created deepspeed aio basic engine')
ctxt = {}
ctxt['file'] = file
ctxt['num_bytes'] = num_bytes
ctxt['buffer'] = buffer
ctxt['elapsed_sec'] = 0
ctxt[FILE] = filename
ctxt[NUM_BYTES] = args.io_size
ctxt[BUFFER] = buffer
ctxt[ELAPSED_SEC] = 0
return ctxt
def pre_basic_read(pool_params):
args, tid = pool_params
ctxt = pre_basic(args, tid, True)
return ctxt
def pre_basic_write(pool_params):
args, tid = pool_params
ctxt = pre_basic(args, tid, False)
return ctxt
def post_basic(pool_params):
_, _, ctxt = pool_params
ctxt["buffer"].detach()
ctxt["buffer"] = None
return ctxt
def main_basic_read(pool_params):
args, tid, ctxt = pool_params
start_time = time.time()
AsyncIOBuilder().load().aio_read(ctxt['buffer'], ctxt['file'], args.block_size, args.queue_depth,
args.single_submit, not args.sequential_requests, args.validate)
end_time = time.time()
ctxt['elapsed_sec'] += end_time - start_time
return ctxt
def main_basic_write(pool_params):
args, tid, ctxt = pool_params
start_time = time.time()
AsyncIOBuilder().load().aio_write(ctxt['buffer'], ctxt['file'], args.block_size, args.queue_depth,
args.single_submit, not args.sequential_requests, args.validate)
end_time = time.time()
ctxt['elapsed_sec'] += end_time - start_time
return ctxt
def get_schedule(args, read_op):
schedule = {}
if read_op:
schedule['pre'] = pre_basic_read
schedule['post'] = post_basic
schedule['main'] = main_basic_read
else:
schedule['pre'] = pre_basic_write
schedule['post'] = post_basic
schedule['main'] = main_basic_write
return schedule
def _aio_handle_tasklet(pool_params):
args, tid, read_op = pool_params
num_processes = len(args.mapping_dict)
# Create schedule
schedule = get_schedule(args, read_op)
task_log(tid, f'schedule = {schedule}')
task_barrier(aio_barrier, num_processes)
# Run pre task
task_log(tid, f'running pre-task')
ctxt = schedule["pre"]((args, tid))
task_barrier(aio_barrier, num_processes)
# Run main tasks in a loop
ctxt["main_task_sec"] = 0
for i in range(args.loops):
task_log(tid, f'running main task {i}')
start_time = time.time()
ctxt = schedule["main"]((args, tid, ctxt))
task_barrier(aio_barrier, num_processes)
stop_time = time.time()
ctxt["main_task_sec"] += stop_time - start_time
# Run post task
task_log(tid, f'running post-task')
ctxt = schedule["post"]((args, tid, ctxt))
task_barrier(aio_barrier, num_processes)
return ctxt["main_task_sec"], ctxt["elapsed_sec"], ctxt["num_bytes"] * args.loops
def _init_tasklet(b):
global aio_barrier
aio_barrier = b
def aio_basic_multiprocessing(args, read_op):
num_processes = len(args.mapping_dict)
b = Barrier(num_processes)
pool_params = [(args, p, read_op) for p in range(num_processes)]
with Pool(processes=num_processes, initializer=_init_tasklet, initargs=(b, )) as p:
pool_results = p.map(_aio_handle_tasklet, pool_params)
report_results(args, read_op, pool_results)

View File

@ -0,0 +1,19 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
AIO_HANDLE = 'aio_handle'
AIO_BASIC = 'aio_basic'
TORCH_IO = 'torch_io'
TORCH_FAST_IO = 'torch_fastio'
VALID_ENGINES = [AIO_HANDLE, AIO_BASIC, TORCH_IO, TORCH_FAST_IO]
BUFFER = 'buffer'
BOUNCE_BUFFER = 'bounce_buffer'
NUM_BYTES = 'num_bytes'
FILE = 'file'
HANDLE = 'handle'
ELAPSED_SEC = 'elapsed_sec'
FAST_IO_BUFFER = 'fast_io_buffer'
USE_CPU_LOCKED_TENSOR = 'cpu_locked_tensor'

View File

@ -2,221 +2,105 @@
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
"""
Functionality of swapping optimizer tensors to/from (NVMe) storage devices.
"""
import torch
import os
import time
from multiprocessing import Pool, Barrier
from deepspeed.ops.aio import AsyncIOBuilder
from deepspeed.ops.op_builder import GDSBuilder
from test_ds_aio_utils import report_results, task_log, task_barrier, create_filename, create_file
from deepspeed.accelerator import get_accelerator
BUFFER = 'buffer'
BOUNCE_BUFFER = 'bounce_buffer'
from test_ds_aio_utils import task_log, create_filename, create_file, create_page_locked_tensor
from ds_aio_constants import *
def pre_handle(args, tid, read_op):
class AIOHandle_Engine(object):
def __init__(self, args, tid, read_op):
self.ctxt = self._create_context(args, tid, read_op)
def fini(self):
for buf in [BUFFER, BOUNCE_BUFFER]:
if self.ctxt[buf] is not None:
if self.ctxt[USE_CPU_LOCKED_TENSOR]:
self.ctxt[HANDLE].free_cpu_locked_tensor(self.ctxt[buf])
self.ctxt[buf].detach()
self.ctxt[buf] = None
def read(self, args, tid, loop_id):
handle = self.ctxt[HANDLE]
start_time = time.time()
dest_buffer = BOUNCE_BUFFER if self.ctxt[BOUNCE_BUFFER] is not None else BUFFER
ret = handle.pread(self.ctxt[dest_buffer], self.ctxt[FILE][loop_id], args.validate, True)
assert ret != -1
handle.wait()
if dest_buffer == BOUNCE_BUFFER:
self.ctxt[BUFFER].data.copy_(self.ctxt[BOUNCE_BUFFER].data)
end_time = time.time()
self.ctxt[ELAPSED_SEC].append(end_time - start_time)
def write(self, args, tid, loop_id):
handle = self.ctxt[HANDLE]
start_time = time.time()
if self.ctxt[BOUNCE_BUFFER] is not None:
source_buffer = BOUNCE_BUFFER
self.ctxt[BOUNCE_BUFFER].data.copy_(self.ctxt[BUFFER].data)
else:
source_buffer = BUFFER
ret = handle.pwrite(self.ctxt[source_buffer], self.ctxt[FILE][loop_id], args.validate, True)
assert ret != -1
handle.wait()
end_time = time.time()
self.ctxt[ELAPSED_SEC].append(end_time - start_time)
def _create_files(self, args, folder, tid):
if args.different_file_each_iteration:
filenames = [
create_filename(folder, args.read, args.io_size, f'{tid}_{l}') for l in range(args.total_loops)
]
else:
filenames = [
create_filename(folder, args.read, args.io_size, f'{tid}_{0}') for _ in range(args.total_loops)
]
if args.read:
for f in filenames:
if not (os.path.isfile(f) and os.path.getsize(f) == args.io_size):
create_file(f, args.io_size)
else:
for f in filenames:
if os.path.isfile(f):
os.remove(f)
return filenames
def _create_context(self, args, tid, read_op):
io_string = "Read" if read_op else "Write"
gds = True if args.use_gds else False
device_id, folder = args.mapping_list[tid]
filename = create_filename(folder, args.read, args.io_size, tid)
if args.read and not (os.path.isfile(filename) and os.path.getsize(filename) == args.io_size):
create_file(filename, args.io_size)
task_log(tid, f'Allocate tensor of size {args.io_size} bytes')
bounce_buffer = None
if args.gpu:
device_name = get_accelerator().device_name(device_id)
buffer = torch.randint(high=128, size=(args.io_size, ), dtype=torch.uint8, device=device_name)
if not (args.slow_bounce_buffer or gds):
bounce_buffer = torch.randint(high=128, size=(args.io_size, ), dtype=torch.uint8,
device='cpu').pin_memory()
else:
buffer = torch.randint(high=128, size=(args.io_size, ), dtype=torch.uint8, device='cpu').pin_memory()
task_log(tid,
f'{io_string} file {filename} of size {args.io_size} bytes from buffer on device {buffer.device}',
force=True)
filenames = self._create_files(args, folder, tid)
io_parallel = args.io_parallel if args.io_parallel else 1
if gds:
handle = GDSBuilder().load().gds_handle(args.block_size, args.queue_depth, args.single_submit,
not args.sequential_requests, io_parallel)
handle.pin_device_tensor(buffer)
else:
handle = AsyncIOBuilder().load().aio_handle(args.block_size, args.queue_depth, args.single_submit,
not args.sequential_requests, io_parallel)
task_log(tid, f'created deepspeed aio handle')
task_log(tid, f'created deepspeed aio handle engine')
bounce_buffer = None
if args.gpu:
buffer = torch.randint(high=128, size=(args.io_size, ), dtype=torch.uint8, device=f'cuda:{device_id}')
bounce_buffer = create_page_locked_tensor(args.io_size, args.use_accelerator_pin_memory, handle)
else:
buffer = create_page_locked_tensor(args.io_size, args.use_accelerator_pin_memory, handle)
task_log(tid, f'Allocate tensor of size {args.io_size} bytes')
ctxt = {}
ctxt['file'] = filename
ctxt['num_bytes'] = args.io_size
ctxt['handle'] = handle
ctxt['gds'] = gds
ctxt[FILE] = filenames
ctxt[NUM_BYTES] = args.io_size
ctxt[HANDLE] = handle
ctxt[BUFFER] = buffer
ctxt[BOUNCE_BUFFER] = bounce_buffer
ctxt['elapsed_sec'] = 0
ctxt[ELAPSED_SEC] = []
ctxt[USE_CPU_LOCKED_TENSOR] = not args.use_accelerator_pin_memory
task_log(tid,
f'{io_string} file {filenames} of size {args.io_size} bytes from buffer on device {buffer.device}',
force=True)
return ctxt
def pre_handle_read(pool_params):
args, tid = pool_params
ctxt = pre_handle(args, tid, True)
return ctxt
def pre_handle_write(pool_params):
args, tid = pool_params
ctxt = pre_handle(args, tid, False)
return ctxt
def post_handle(pool_params):
_, _, ctxt = pool_params
for buf in [BUFFER, BOUNCE_BUFFER]:
if ctxt[buf] is not None:
if ctxt['gds']:
ctxt['handle'].unpin_device_tensor(ctxt[buf])
ctxt[buf].detach()
ctxt[buf] = None
return ctxt
def main_parallel_read(pool_params):
args, tid, ctxt = pool_params
handle = ctxt['handle']
start_time = time.time()
dest_buffer = BOUNCE_BUFFER if ctxt[BOUNCE_BUFFER] is not None else BUFFER
ret = handle.pread(ctxt[dest_buffer], ctxt['file'], args.validate, 0, True)
assert ret != -1
handle.wait()
if dest_buffer == BOUNCE_BUFFER:
ctxt[BUFFER].data.copy_(ctxt[BOUNCE_BUFFER].data)
end_time = time.time()
ctxt['elapsed_sec'] += end_time - start_time
return ctxt
def main_parallel_write(pool_params):
args, tid, ctxt = pool_params
# Avoid overwriting existing files as it could be artificially faster
if os.path.isfile(ctxt['file']):
os.remove(ctxt['file'])
handle = ctxt['handle']
start_time = time.time()
if ctxt[BOUNCE_BUFFER] is not None:
source_buffer = BOUNCE_BUFFER
ctxt[BOUNCE_BUFFER].data.copy_(ctxt[BUFFER].data)
else:
source_buffer = BUFFER
ret = handle.pwrite(ctxt[source_buffer], ctxt['file'], args.validate, True)
assert ret != -1
handle.wait()
end_time = time.time()
ctxt['elapsed_sec'] += end_time - start_time
return ctxt
def main_handle_read(pool_parms):
args, tid, ctxt = pool_parms
handle = ctxt['handle']
start_time = time.time()
dest_buffer = BOUNCE_BUFFER if ctxt[BOUNCE_BUFFER] is not None else BUFFER
ret = handle.read(ctxt[dest_buffer], ctxt['file'], args.validate)
assert ret != -1
if dest_buffer == BOUNCE_BUFFER:
ctxt[BUFFER].data.copy_(ctxt[BOUNCE_BUFFER].data)
end_time = time.time()
ctxt['elapsed_sec'] += end_time - start_time
return ctxt
def main_handle_write(pool_parms):
args, tid, ctxt = pool_parms
# Avoid overwriting existing files as it could be artificially faster
if os.path.isfile(ctxt['file']):
os.remove(ctxt['file'])
handle = ctxt['handle']
start_time = time.time()
if ctxt[BOUNCE_BUFFER] is not None:
source_buffer = BOUNCE_BUFFER
ctxt[BOUNCE_BUFFER].data.copy_(ctxt[BUFFER].data)
else:
source_buffer = BUFFER
ret = handle.write(ctxt[source_buffer], ctxt['file'], args.validate)
assert ret != -1
end_time = time.time()
ctxt['elapsed_sec'] += end_time - start_time
return ctxt
def get_schedule(args, read_op):
schedule = {}
if read_op:
schedule['pre'] = pre_handle_read
schedule['post'] = post_handle
schedule['main'] = main_parallel_read
else:
schedule['pre'] = pre_handle_write
schedule['post'] = post_handle
schedule['main'] = main_parallel_write
return schedule
def _aio_handle_tasklet(pool_params):
args, tid, read_op = pool_params
num_processes = len(args.mapping_dict)
# Create schedule
schedule = get_schedule(args, read_op)
task_log(tid, f'schedule = {schedule}')
task_barrier(aio_barrier, num_processes)
# Run pre task
task_log(tid, f'running pre-task')
ctxt = schedule["pre"]((args, tid))
task_barrier(aio_barrier, num_processes)
# Run main tasks in a loop
ctxt["main_task_sec"] = 0
for i in range(args.loops):
task_log(tid, f'running main task {i}')
start_time = time.time()
ctxt = schedule["main"]((args, tid, ctxt))
task_barrier(aio_barrier, num_processes)
stop_time = time.time()
ctxt["main_task_sec"] += stop_time - start_time
# Run post task
task_log(tid, f'running post-task')
ctxt = schedule["post"]((args, tid, ctxt))
task_barrier(aio_barrier, num_processes)
return ctxt["main_task_sec"], ctxt["elapsed_sec"], ctxt["num_bytes"] * args.loops
def _init_tasklet(b):
global aio_barrier
aio_barrier = b
def aio_handle_multiprocessing(args, read_op):
num_processes = len(args.mapping_dict)
b = Barrier(num_processes)
pool_params = [(args, p, read_op) for p in range(num_processes)]
with Pool(processes=num_processes, initializer=_init_tasklet, initargs=(b, )) as p:
pool_results = p.map(_aio_handle_tasklet, pool_params)
report_results(args, read_op, pool_results)

View File

@ -0,0 +1,126 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
import time
from multiprocessing import Pool, Barrier
from ds_aio_constants import AIO_BASIC, TORCH_FAST_IO, TORCH_IO
from test_ds_aio_utils import report_results, task_log, task_barrier
from ds_aio_handle import AIOHandle_Engine
from ds_aio_basic import AIOBasic_Engine
from torch_io import TorchIO_Engine
from torch_fastio_engine import Torch_FastIO_Engine
def prepare_operation(args, tid, read_op):
if args.engine == TORCH_IO:
io_engine = TorchIO_Engine(args, tid, read_op)
elif args.engine == AIO_BASIC:
io_engine = AIOBasic_Engine(args, tid, read_op)
elif args.engine == TORCH_FAST_IO:
io_engine = Torch_FastIO_Engine(args, tid, read_op)
else:
io_engine = AIOHandle_Engine(args, tid, read_op)
return io_engine
def prepare_read(pool_params):
args, tid = pool_params
return prepare_operation(args, tid, True)
def prepare_write(pool_params):
args, tid = pool_params
return prepare_operation(args, tid, False)
def post_operation(pool_params):
_, _, io_engine = pool_params
io_engine.fini()
def read_operation(pool_params):
args, tid, loop_id, io_engine = pool_params
return io_engine.read(args, tid, loop_id)
def write_operation(pool_params):
args, tid, loop_id, io_engine = pool_params
return io_engine.write(args, tid, loop_id)
def get_schedule(args, read_op):
schedule = {}
if read_op:
schedule['pre'] = prepare_read
schedule['post'] = post_operation
schedule['main'] = read_operation
else:
schedule['pre'] = prepare_write
schedule['post'] = post_operation
schedule['main'] = write_operation
return schedule
def io_engine_tasklet(pool_params):
args, tid, read_op = pool_params
num_processes = len(args.mapping_dict)
# Create schedule
schedule = get_schedule(args, read_op)
task_log(tid, f'schedule = {schedule}')
task_barrier(aio_barrier, num_processes)
# Run pre task
task_log(tid, f'running pre-task')
io_engine = schedule["pre"]((args, tid))
task_barrier(aio_barrier, num_processes)
# Run main tasks in a loop
io_engine.ctxt["main_task_sec"] = []
for i in range(args.total_loops):
task_log(tid, f'running main task {i}')
start_time = time.time()
schedule["main"]((args, tid, i, io_engine))
task_barrier(aio_barrier, num_processes)
stop_time = time.time()
io_engine.ctxt["main_task_sec"].append(stop_time - start_time)
# Run post task
task_log(tid, f'running post-task')
schedule["post"]((args, tid, io_engine))
task_barrier(aio_barrier, num_processes)
ctxt = io_engine.ctxt
# return ctxt["main_task_sec"], ctxt["elapsed_sec"], ctxt["num_bytes"] * args.loops
if args.include_warmup_time:
e2e_latency_sec = sum(ctxt["main_task_sec"])
task_latency_sec = sum(ctxt["elapsed_sec"])
actual_loops = args.total_loops
else:
e2e_latency_sec = sum(ctxt["main_task_sec"][args.warmup_loops:])
task_latency_sec = sum(ctxt["elapsed_sec"][args.warmup_loops:])
actual_loops = args.loops
l = ctxt["elapsed_sec"]
task_log(tid, f'task_latency_sec = {l}')
return e2e_latency_sec, task_latency_sec, ctxt["num_bytes"] * actual_loops
def _init_takslet(b):
global aio_barrier
aio_barrier = b
def io_engine_multiprocessing(args, read_op):
num_processes = len(args.mapping_dict)
b = Barrier(num_processes)
pool_params = [(args, p, read_op) for p in range(num_processes)]
with Pool(processes=num_processes, initializer=_init_takslet, initargs=(b, )) as p:
pool_results = p.map(io_engine_tasklet, pool_params)
report_results(args, read_op, pool_results)

View File

@ -79,9 +79,9 @@ for xtype in cpu gpu gds; do
gpu_opt="--gpu"
gds_opt="--use_gds"
fi
for sub in single block; do
if [[ $sub == "single" ]]; then
sub_opt="--single_submit"
for ov in overlap sequential; do
if [[ $ov == "sequential" ]]; then
ov_opt="--sequential_requests"
else
sub_opt=""
fi

View File

@ -25,11 +25,42 @@ function validate_environment()
validate_environment
IO_SIZE=$1
LOG_DIR=$2/aio_perf_sweep
MAP_DIR=$2/aio
GPU_MEM=$3
USE_GDS=$4
if [[ $# -ne 3 ]]; then
echo "Usage: $0 <write size in [K,M,G]> <write dir ><output log dir>"
exit 1
fi
SIZE=$1
WRITE_DIR=$2
LOG_DIR=$3/aio_perf_sweep
WRITE_OPT="--folder ${WRITE_DIR} --io_size ${SIZE} --loops 3"
IO_ENGINE="torch_fastio"
ENGINE_OPTS=""
if [[ $IO_ENGINE == "aio_handle" ]]; then
IO_PARALLEL="1" # "1 2 4 8"
QUEUE_DEPTH="8 16 32 64 128"
BLOCK_SIZE="128K 256K 512K 1M 2M 4M 8M 16M"
SUBMIT="block"
OVERLAP="overlap"
elif [[ $IO_ENGINE == "torch_fastio" ]]; then
IO_PARALLEL="1" # "1 2 4 8"
QUEUE_DEPTH="8 16 32 64 128"
BLOCK_SIZE="128K 256K 512K 1M 2M 4M 8M 16M"
SUBMIT="block"
OVERLAP="overlap"
ENGINE_OPTS="--torch_legacy --fast_io_size ${SIZE}"
else
IO_PARALLEL="1"
QUEUE_DEPTH="8"
BLOCK_SIZE="128K"
SUBMIT="single"
OVERLAP="sequential"
fi
prep_folder ${WRITE_DIR}
prep_folder ${LOG_DIR}
RUN_SCRIPT=./test_ds_aio.py
OUTPUT_FILE=${MAP_DIR}/ds_aio_write_${SIZE}B.pt
@ -54,24 +85,24 @@ fi
DISABLE_CACHE="sync; bash -c 'echo 1 > /proc/sys/vm/drop_caches' "
SYNC="sync"
for sub in single block; do
for sub in ${SUBMIT}; do
if [[ $sub == "single" ]]; then
sub_opt="--single_submit"
else
sub_opt=""
fi
for ov in overlap sequential; do
for ov in ${OVERLAP}; do
if [[ $ov == "sequential" ]]; then
ov_opt="--sequential_requests"
else
ov_opt=""
fi
for p in 1 2 4 8; do
for t in 1 2 4 8; do
for d in 32 64 128; do
for bs in 256K 512K 1M; do
SCHED_OPTS="${sub_opt} ${ov_opt} --handle ${gpu_opt} ${gds_opt} --folder ${MAP_DIR}"
OPTS="--queue_depth ${d} --block_size ${bs} --io_size ${IO_SIZE} --multi_process ${p} --io_parallel ${t}"
for p in 1; do
for t in ${IO_PARALLEL}; do
for d in ${QUEUE_DEPTH}; do
for bs in ${BLOCK_SIZE}; do
SCHED_OPTS="${sub_opt} ${ov_opt} --engine ${IO_ENGINE} --io_parallel ${t} ${ENGINE_OPTS}"
OPTS="--multi_process ${p} --queue_depth ${d} --block_size ${bs}"
LOG="${LOG_DIR}/write_${sub}_${ov}_t${t}_p${p}_d${d}_bs${bs}.txt"
cmd="python ${RUN_SCRIPT} ${OPTS} ${SCHED_OPTS} &> ${LOG}"
echo ${DISABLE_CACHE}

View File

@ -2,12 +2,17 @@
"block_size": [
"128K",
"256K",
"1M"
"1M",
"2M",
"4M",
"8M",
"16M"
],
"queue_depth": [
4,
8,
16,
32
32,
64
],
"io_parallel": [
1,
@ -19,7 +24,7 @@
true,
false
],
"overlap_events": [
"sequential_requests": [
true,
false
],

View File

@ -7,17 +7,16 @@ Functionality of swapping optimizer tensors to/from (NVMe) storage devices.
"""
import multiprocessing as mp
from ds_aio_basic import aio_basic_multiprocessing
from ds_aio_handle import aio_handle_multiprocessing
from ds_aio_args import get_validated_args
from io_engine import io_engine_multiprocessing
def main():
print(f'Testing deepspeed_aio python frontend')
args = get_validated_args()
mp.set_start_method('spawn')
multiprocess_function = aio_handle_multiprocessing if args.handle else aio_basic_multiprocessing
mp.set_start_method('spawn', force=True)
multiprocess_function = io_engine_multiprocessing
multiprocess_function(args, args.read)

View File

@ -8,6 +8,8 @@ Functionality of swapping optimizer tensors to/from (NVMe) storage devices.
import os
from ds_aio_job import Job, run_job
import torch
from deepspeed.accelerator import get_accelerator
BYTES_PER_GB = 1024**3
BYTES_PER_MB = 1024**2
@ -79,3 +81,11 @@ def create_file(filename, num_bytes):
print(f'[Start] Create {filename} of {num_bytes} bytes by running {dd_job.cmd()} ....')
run_job(dd_job)
print(f'[Done] Create read file of {num_bytes} bytes by running {dd_job.cmd()} ....')
def create_page_locked_tensor(num_elem, use_accelerator, aio_handle=None):
if use_accelerator:
return get_accelerator().pin_memory(torch.randint(high=128, size=(num_elem, ), dtype=torch.uint8,
device='cpu'))
else:
return aio_handle.new_cpu_locked_tensor(num_elem, torch.empty(0, dtype=torch.uint8))

View File

@ -0,0 +1,87 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
import torch
import os
import time
from deepspeed.ops.aio import AsyncIOBuilder
from test_ds_aio_utils import task_log, create_filename, create_file, create_page_locked_tensor
from ds_aio_constants import *
from deepspeed.io import FastFileWriter
class Torch_FastIO_Engine(object):
def __init__(self, args, tid, read_op):
assert read_op is False, f'Read operation is not currently supported'
self.ctxt = self._create_context(args, tid, read_op)
self.zipfile_serialization = not args.torch_legacy_save
def fini(self):
if self.ctxt[USE_CPU_LOCKED_TENSOR]:
for buf in [BUFFER, FAST_IO_BUFFER]:
self.ctxt[HANDLE].free_cpu_locked_tensor(self.ctxt[buf])
self.ctxt[BUFFER].detach()
self.ctxt[BUFFER] = None
def read(self, args, tid):
start_time = time.time()
torch.load(f=self.ctxt[FILE], map_location=self.ctxt[BUFFER].device)
end_time = time.time()
self.ctxt[ELAPSED_SEC] += end_time - start_time
def write(self, args, tid):
# Avoid overwriting existing files as it could be artificially faster
if os.path.isfile(self.ctxt[FILE]):
os.remove(self.ctxt[FILE])
ds_file_writer = FastFileWriter(file_path=self.ctxt[FILE],
aio_handle=self.ctxt[HANDLE],
pinned_tensor=self.ctxt[FAST_IO_BUFFER])
start_time = time.time()
torch.save(obj=self.ctxt[BUFFER], f=ds_file_writer, _use_new_zipfile_serialization=self.zipfile_serialization)
ds_file_writer.close() # Force flush to storage
end_time = time.time()
self.ctxt[ELAPSED_SEC] += end_time - start_time
ds_file_writer._dump_state()
def _create_context(self, args, tid, read_op):
io_string = "Read" if read_op else "Write"
device_id, folder = args.mapping_list[tid]
filename = create_filename(folder, args.read, args.io_size, tid)
if args.read and not (os.path.isfile(filename) and os.path.getsize(filename) == args.io_size):
create_file(filename, args.io_size)
io_parallel = args.io_parallel if args.io_parallel else 1
aio_handle = AsyncIOBuilder().load().aio_handle(args.block_size, args.queue_depth, args.single_submit,
not args.sequential_requests, io_parallel)
if args.gpu:
buffer = torch.randint(high=128, size=(args.io_size, ), dtype=torch.uint8, device=f'cuda:{device_id}')
else:
buffer = create_page_locked_tensor(args.io_size, args.use_accelerator_pin_memory, aio_handle)
task_log(tid, f'Allocate tensor of size {args.io_size} bytes')
fast_io_buffer = create_page_locked_tensor(args.fast_io_size, args.use_accelerator_pin_memory, aio_handle)
task_log(tid, f'created torch_fastio engine')
ctxt = {}
ctxt[FILE] = filename
ctxt[NUM_BYTES] = args.io_size
ctxt[BUFFER] = buffer
ctxt[HANDLE] = aio_handle
ctxt[FAST_IO_BUFFER] = fast_io_buffer
ctxt[ELAPSED_SEC] = 0
ctxt[USE_CPU_LOCKED_TENSOR] = not args.use_accelerator_pin_memory
task_log(tid,
f'{io_string} file {filename} of size {args.io_size} bytes from buffer on device {buffer.device}',
force=True)
return ctxt


@ -0,0 +1,64 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
import torch
import os
import time
from test_ds_aio_utils import task_log, create_filename, create_file, create_page_locked_tensor
from ds_aio_constants import *
class TorchIO_Engine(object):
def __init__(self, args, tid, read_op):
self.ctxt = self._create_context(args, tid, read_op)
self.zipfile_serialization = not args.torch_legacy_save
def fini(self):
self.ctxt[BUFFER].detach()
self.ctxt[BUFFER] = None
def read(self, args, tid):
start_time = time.time()
torch.load(f=self.ctxt[FILE], map_location=self.ctxt[BUFFER].device)
end_time = time.time()
self.ctxt[ELAPSED_SEC] += end_time - start_time
def write(self, args, tid):
# Avoid overwriting existing files as it could be artificially faster
if os.path.isfile(self.ctxt[FILE]):
os.remove(self.ctxt[FILE])
start_time = time.time()
torch.save(obj=self.ctxt[BUFFER], f=self.ctxt[FILE], _use_new_zipfile_serialization=self.zipfile_serialization)
end_time = time.time()
self.ctxt[ELAPSED_SEC] += end_time - start_time
def _create_context(self, args, tid, read_op):
io_string = "Read" if read_op else "Write"
device_id, folder = args.mapping_list[tid]
filename = create_filename(folder, args.read, args.io_size, tid)
if args.read and not (os.path.isfile(filename) and os.path.getsize(filename) == args.io_size):
create_file(filename, args.io_size)
task_log(tid, f'Allocate tensor of size {args.io_size} bytes')
if args.gpu:
buffer = torch.randint(high=128, size=(args.io_size, ), dtype=torch.uint8, device=f'cuda:{device_id}')
else:
buffer = create_page_locked_tensor(args.io_size, True)
task_log(tid,
f'{io_string} file {filename} of size {args.io_size} bytes from buffer on device {buffer.device}',
force=True)
task_log(tid, f'created torch_io engine')
ctxt = {}
ctxt[FILE] = filename
ctxt[NUM_BYTES] = args.io_size
ctxt[BUFFER] = buffer
ctxt[ELAPSED_SEC] = 0
return ctxt


@ -0,0 +1,15 @@
#!/bin/bash
MOUNT_CMD="sudo mount -v -o data=ordered"
for dir in nvme23 nvme45 nvme67 nvme89; do
mnt_point=/mnt/${dir}
sudo mkdir -p ${mnt_point}
sudo chmod -R a+rw ${mnt_point}
done
${MOUNT_CMD} /dev/md127 /mnt/nvme23
${MOUNT_CMD} /dev/md126 /mnt/nvme45
${MOUNT_CMD} /dev/md125 /mnt/nvme67
${MOUNT_CMD} /dev/md124 /mnt/nvme89
lsblk -f


@ -0,0 +1,10 @@
#!/bin/bash
UMOUNT_CMD="sudo umount -v"
for md in md127 md126 md125 md124; do
mnt_device=/dev/${md}
${UMOUNT_CMD} ${mnt_device}
done
lsblk -f


@ -93,18 +93,10 @@ gds_op_desc_t::gds_op_desc_t(const bool read_op,
const torch::Tensor& buffer,
const int fd,
const char* filename,
const int64_t file_num_bytes,
const int intra_op_parallelism,
const bool validate,
const int64_t file_offset)
: io_op_desc_t(read_op,
buffer,
fd,
filename,
file_num_bytes,
intra_op_parallelism,
validate,
file_offset)
: io_op_desc_t(read_op, buffer, fd, filename, intra_op_parallelism, validate, file_offset)
{
_contiguous_buffer = _buffer.contiguous();
const int64_t device = _buffer.get_device();
@ -122,8 +114,9 @@ void gds_op_desc_t::validate()
{
check_cudaruntimecall(cudaSetDevice(_buffer.get_device()));
const auto cpu_buffer = _buffer.to(torch::kCPU);
const auto num_io_bytes = static_cast<int64_t>(_contiguous_buffer.nbytes());
validate_aio_operation(
_read_op, _filename.c_str(), (char*)(cpu_buffer.data_ptr()), _file_num_bytes);
_read_op, _filename.c_str(), (char*)(cpu_buffer.data_ptr()), num_io_bytes);
}
void gds_op_desc_t::run(const int tid,
@ -155,7 +148,7 @@ void gds_op_desc_t::_report_error(const ssize_t return_code,
const auto error_code = IS_CUFILE_ERR(return_code) ? cuFileGetErrorString(return_code)
: cuFileGetErrorString(error_num);
std::cerr << op_string << error_string << error_code << " return code = " << return_code
<< " filename = " << _filename.c_str() << " num bytes = " << _num_bytes_per_thread
<< " filename = " << _filename << " num bytes = " << _num_bytes_per_thread
<< " offset = " << offset << std::endl;
exit(EXIT_FAILURE);
}


@ -22,7 +22,6 @@ struct gds_op_desc_t : io_op_desc_t {
const torch::Tensor& buffer,
const int fd,
const char* filename,
const int64_t file_num_bytes,
const int intra_op_parallelism,
const bool validate,
const int64_t file_offset);


@ -106,20 +106,13 @@ std::shared_ptr<struct io_op_desc_t> deepspeed_gds_handle_t::_create_io_op_desc(
const torch::Tensor& buffer,
const int fd,
const char* filename,
const int64_t file_num_bytes,
const bool validate,
const int64_t file_offset)
{
if (buffer.is_cuda()) {
return std::make_shared<gds_op_desc_t>(read_op,
buffer,
fd,
filename,
file_num_bytes,
_intra_op_parallelism,
validate,
file_offset);
return std::make_shared<gds_op_desc_t>(
read_op, buffer, fd, filename, _intra_op_parallelism, validate, file_offset);
}
return deepspeed_io_handle_t::_create_io_op_desc(
read_op, buffer, fd, filename, file_num_bytes, validate, file_offset);
read_op, buffer, fd, filename, validate, file_offset);
}


@ -41,7 +41,6 @@ struct deepspeed_gds_handle_t : deepspeed_io_handle_t {
const torch::Tensor& buffer,
const int fd,
const char* filename,
const int64_t file_num_bytes,
const bool validate,
const int64_t file_offset);


@ -27,6 +27,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
.def("get_single_submit", &deepspeed_gds_handle_t::get_single_submit)
.def("get_overlap_events", &deepspeed_gds_handle_t::get_overlap_events)
.def("get_intra_op_parallelism", &deepspeed_gds_handle_t::get_intra_op_parallelism)
.def("get_alignment", &deepspeed_gds_handle_t::get_alignment)
.def("read",
&deepspeed_gds_handle_t::read,
@ -84,14 +85,24 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
"filename"_a,
"file_offset"_a = 0)
.def("async_pwrite",
&deepspeed_gds_handle_t::async_pwrite,
"Asynchronous parallel file write. Returns 0 on success, and following wait() returns "
.def(
"async_pwrite",
py::overload_cast<const torch::Tensor&, const char*, const int64_t>(
&deepspeed_gds_handle_t::async_pwrite),
"Asynchronous parallel file write. Returns 0 on success, and subsequent wait() returns "
"count of completed ops.",
"buffer"_a,
"filename"_a,
"file_offset"_a = 0)
.def("async_pwrite",
py::overload_cast<const torch::Tensor&, const int, const int64_t>(
&deepspeed_gds_handle_t::async_pwrite),
"Asynchronous parallel file write using opened python file object.",
"buffer"_a,
"fd"_a,
"file_offset"_a = 0)
.def("new_cpu_locked_tensor",
&deepspeed_gds_handle_t::new_cpu_locked_tensor,
"Allocate pinned CPU tensor.",


@ -1,29 +0,0 @@
// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0
// DeepSpeed Team
/*
Copyright NVIDIA/apex
This file is adapted from fused adam in NVIDIA/apex, commit a109f85
*/
#include <torch/csrc/utils/tensor_flatten.h>
#include <torch/extension.h>
// https://github.com/pytorch/pytorch/blob/master/torch/csrc/utils/tensor_flatten.h
at::Tensor flatten(std::vector<at::Tensor> tensors)
{
return torch::utils::flatten_dense_tensors(tensors);
}
std::vector<at::Tensor> unflatten(at::Tensor flat, std::vector<at::Tensor> tensors)
{
return torch::utils::unflatten_dense_tensors(flat, tensors);
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
m.def("flatten", &flatten, "Flatten dense tensors");
m.def("unflatten", &unflatten, "Unflatten dense tensors");
}


@ -0,0 +1,25 @@
// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0
// DeepSpeed Team
/*
Collection of system utilities.
*/
#include <torch/extension.h>
#include "tensor_cast.h"
using namespace pybind11::literals;
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
m.def("cast_to_byte_tensor",
py::overload_cast<at::Tensor&>(&cast_to_byte_tensor),
"Cast a 1-dimensional tensor of any type to byte tensor.",
"src_tensor"_a);
m.def("cast_to_byte_tensor",
py::overload_cast<std::vector<at::Tensor>&>(&cast_to_byte_tensor),
"Cast a multi-dimensional tensor of any type to byte tensor.",
"src_tensor"_a);
}


@ -0,0 +1,26 @@
// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0
// DeepSpeed Team
#include "tensor_cast.h"
at::Tensor cast_to_byte_tensor(at::Tensor& src_tensor)
{
if (src_tensor.nbytes() <= 1) return src_tensor;
auto options = torch::TensorOptions()
.dtype(torch::kUInt8)
.layout(src_tensor.layout())
.device(src_tensor.device());
return at::from_blob(
src_tensor.data_ptr(), static_cast<long int>(src_tensor.nbytes()), options);
}
std::vector<at::Tensor> cast_to_byte_tensor(std::vector<at::Tensor>& tensor_list)
{
std::vector<at::Tensor> byte_tensors;
for (auto src_tensor : tensor_list) { byte_tensors.push_back(cast_to_byte_tensor(src_tensor)); }
return byte_tensors;
}

csrc/utils/tensor_cast.h Normal file

@ -0,0 +1,16 @@
// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0
// DeepSpeed Team
/*
Utilities for type casting torch tensors without data movement.
*/
#include <torch/extension.h>
#include <vector>
using namespace std;
at::Tensor cast_to_byte_tensor(at::Tensor& src_tensor);
std::vector<at::Tensor> cast_to_byte_tensor(std::vector<at::Tensor>& tensor_list);
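A hedged sketch of how this op is consumed from Python, mirroring the FastFileWriter code later in this commit (the tensor below is illustrative):

import torch
from deepspeed.ops.op_builder import UtilsBuilder

cast_to_byte_tensor = UtilsBuilder().load().cast_to_byte_tensor
src = torch.randn(1024, dtype=torch.float32)  # 1-D tensor of any dtype
byte_view = cast_to_byte_tensor(src)          # zero-copy view over the same storage
assert byte_view.dtype == torch.uint8
assert byte_view.numel() == src.numel() * src.element_size()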


@ -19,3 +19,7 @@ TORCH_DISTRIBUTED_DEFAULT_PORT = 29500
default_pg_timeout = timedelta(minutes=int(os.getenv("DEEPSPEED_TIMEOUT", default=30)))
INFERENCE_GENERIC_MODE = 'generic'
INFERENCE_SPECIALIZED_MODE = 'specialized'
CROSS_RANK = "CROSS_RANK"
CROSS_SIZE = "CROSS_SIZE"
LOCAL_RANK = 'LOCAL_RANK'

deepspeed/io/__init__.py Normal file

@ -0,0 +1,8 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
from .py_file_writer import PyFileWriter
from .fast_file_writer import FastFileWriter, FastFileWriterConfig
from .mock_file_writer import MockFileWriter


@ -0,0 +1,43 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
from .constants import *
BASE_STAT_KEYS = [
CLOSE_COUNT_KEY, FILENO_COUNT_KEY, FLUSH_COUNT_KEY, WRITE_COUNT_KEY, WRITE_BYTES_KEY, WRITE_SEC_KEY,
WRITE_SPEED_KEY
]
class BaseFileWriter(object):
def __init__(self, file_path):
self._file_path = file_path
self._stats = {k: 0 for k in BASE_STAT_KEYS}
def close(self):
pass
def fileno(self):
pass
def flush(self):
pass
def write(self, buffer):
pass
def file_path(self):
return self._file_path
def _incr_stats(self, key, incr=1):
self._stats[key] += incr
def _dump_state(self):
if self._stats[WRITE_SEC_KEY] > 0:
self._stats[WRITE_SPEED_KEY] = (self._stats[WRITE_BYTES_KEY] / self._stats[WRITE_SEC_KEY] / (1024**3))
state = self._stats
state[FILE_PATH_KEY] = self.file_path()
print(f'stats = {self._stats}')


@ -0,0 +1,69 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
import torch
class Base_IO_Buffer(object):
def __init__(self, pinned_tensor, dnvme_handle):
assert pinned_tensor.numel() % dnvme_handle.get_alignment() == 0
self._dnvme_handle = dnvme_handle
self._pinned_tensor = pinned_tensor
def fill(self, src_tensor, src_offset):
pass
def drain(self, num_bytes, fd, file_offset):
pass
def is_empty(self):
pass
def is_full(self):
pass
def get_buffer(self):
pass
def get_offset(self):
pass
def get_aligned_num_bytes(self):
pass
def get_unaligned_num_bytes(self):
pass
def reset(self):
pass
def complete_ongoing_drain(self):
pass
def _drain(self, num_bytes, fd, file_offset, blocking=False):
assert num_bytes <= self.get_offset()
assert num_bytes % self._dnvme_handle.get_alignment() == 0
buffer = self.get_buffer()
r = self._dnvme_handle.async_pwrite(torch.narrow(buffer, 0, 0, num_bytes), fd, file_offset)
assert 0 == r
if blocking:
assert 1 == self._dnvme_handle.wait()
@staticmethod
def fill_buffer(src_tensor, src_offset, buffer_tensor, buffer_offset):
src_bytes = src_tensor.numel() - src_offset
assert src_bytes > 0
dst_bytes = buffer_tensor.numel() - buffer_offset
copy_bytes = min(src_bytes, dst_bytes)
assert (buffer_offset + copy_bytes) <= buffer_tensor.numel()
if copy_bytes > 0:
src_slice = torch.narrow(src_tensor, 0, src_offset, copy_bytes)
dst_slice = torch.narrow(buffer_tensor, 0, buffer_offset, copy_bytes)
dst_slice.data.copy_(src_slice.data)
return copy_bytes

deepspeed/io/constants.py Normal file

@ -0,0 +1,31 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
INVALID_FD = -1
FILE_PATH_KEY = 'path'
FLUSH_COUNT_KEY = 'flush'
WRITE_COUNT_KEY = 'write'
CLOSE_COUNT_KEY = 'close'
FILENO_COUNT_KEY = 'fileno'
WRITE_BYTES_KEY = 'bytes'
WRITE_SEC_KEY = 'write_secs'
WRITE_SPEED_KEY = 'write_GB/s'
AIO_WRITE_SEC_KEY = 'aio_write_secs'
AIO_WRITE_BYTES_KEY = 'aio_bytes'
AIO_SPEED_KEY = 'aio_GB/s'
SLOW_WRITE_BYTES_KEY = 'slow_bytes'
SLOW_WRITE_SEC_KEY = 'slow_write_secs'
AIO_FILL_BUFFER_SEC_KEY = 'fill_buffer_secs'
AIO_FILL_BUFFER_COUNT_KEY = 'fill_buffer_count'
AIO_FILL_BUFFER_SPEED_KEY = 'fill_buffer_GB/s'
SAVE_STORAGE_KEY = 'save_storage'
SAVE_STORAGE_BYTES_KEY = 'save_storage_bytes'
SAVE_STORAGE_SEC_KEY = 'save_storage_secs'
STORAGE_OBJ_SIZE = 8
RANK_KEY = 'rank'


@ -0,0 +1,84 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
import torch
from .base_io_buffer import Base_IO_Buffer
NUM_BUFFERS = 2
INVALID_BUFFER_INDEX = -1
class Double_IO_Buffer(Base_IO_Buffer):
def __init__(self, pinned_tensor, dnvme_handle):
super(Double_IO_Buffer, self).__init__(pinned_tensor, dnvme_handle)
assert self._pinned_tensor.numel() % (NUM_BUFFERS * self._dnvme_handle.get_alignment()) == 0
self._buffers = self._split_buffer()
self._fill_index = 0
self._drain_index = INVALID_BUFFER_INDEX
self._buffer_offset = 0
def fill(self, src_tensor, src_offset):
self._validate_buffer_index(self._fill_index)
copy_bytes = Base_IO_Buffer.fill_buffer(src_tensor, src_offset, self._buffers[self._fill_index],
self._buffer_offset)
self._buffer_offset += copy_bytes
return copy_bytes
def drain(self, num_bytes, fd, file_offset):
self._validate_buffer_index(self._fill_index)
self.complete_ongoing_drain()
assert self._drain_index == INVALID_BUFFER_INDEX
self._drain(num_bytes, fd, file_offset, blocking=False)
self._drain_index = self._fill_index
self._fill_index = (self._fill_index + 1) % NUM_BUFFERS
self._buffer_offset = 0
def get_buffer(self):
self._validate_buffer_index(self._fill_index)
return self._buffers[self._fill_index]
def get_offset(self):
self._validate_buffer_index(self._fill_index)
return self._buffer_offset
def get_aligned_num_bytes(self):
self._validate_buffer_index(self._fill_index)
aligned_size = self._dnvme_handle.get_alignment()
return (self._buffer_offset // aligned_size) * aligned_size
def get_unaligned_num_bytes(self):
self._validate_buffer_index(self._fill_index)
return self._buffer_offset % self._dnvme_handle.get_alignment()
def is_full(self):
self._validate_buffer_index(self._fill_index)
return self._buffer_offset == self._buffers[self._fill_index].numel()
def is_empty(self):
self._validate_buffer_index(self._fill_index)
return self._buffer_offset == 0 and not self._is_ongoing_drain()
def reset(self):
self._buffer_offset = 0
def complete_ongoing_drain(self):
if self._is_ongoing_drain():
self._wait_for_drain()
def _split_buffer(self):
buffer_size = self._pinned_tensor.numel() // NUM_BUFFERS
return [torch.narrow(self._pinned_tensor, 0, (i * buffer_size), buffer_size) for i in range(NUM_BUFFERS)]
def _validate_buffer_index(self, index):
assert index in [0, 1]
def _wait_for_drain(self):
self._validate_buffer_index(self._drain_index)
assert 1 == self._dnvme_handle.wait()
self._drain_index = INVALID_BUFFER_INDEX
def _is_ongoing_drain(self):
return self._drain_index != INVALID_BUFFER_INDEX
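A self-contained sketch of the fill/drain cycle this double buffer enables (handle settings and the output path are illustrative; O_DIRECT requires a filesystem-backed path):

import os
import torch
from deepspeed.ops.aio import AsyncIOBuilder
from deepspeed.io.double_io_buffer import Double_IO_Buffer

handle = AsyncIOBuilder().load().aio_handle(1024 * 1024, 32, False, True, 1)
align = handle.get_alignment()
pinned = handle.new_cpu_locked_tensor(4 * align, torch.empty(0, dtype=torch.uint8))
io_buffer = Double_IO_Buffer(pinned, handle)

src = torch.randint(high=128, size=(8 * align, ), dtype=torch.uint8)
fd = os.open('/mnt/nvme23/double_buffer_demo.bin', os.O_DIRECT | os.O_CREAT | os.O_WRONLY)

src_offset, file_offset = 0, 0
while src_offset < src.numel():
    src_offset += io_buffer.fill(src, src_offset)    # copy into the active half
    if io_buffer.is_full():
        num_bytes = io_buffer.get_offset()
        io_buffer.drain(num_bytes, fd, file_offset)  # async write, then swap halves
        file_offset += num_bytes
io_buffer.complete_ongoing_drain()                   # wait for the in-flight write
os.close(fd)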


@ -0,0 +1,267 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
import os
import torch
import time
from dataclasses import dataclass
from .constants import *
from .base_file_writer import BaseFileWriter
from .single_io_buffer import Single_IO_Buffer
from .double_io_buffer import Double_IO_Buffer
from deepspeed.ops.op_builder import UtilsBuilder
from deepspeed.accelerator import get_accelerator
from .utils import (tensor_to_bytes, bytes_to_tensor, obj_serialization_details)
FASTIO_STAT_KEYS = [
AIO_WRITE_SEC_KEY,
AIO_WRITE_BYTES_KEY,
AIO_SPEED_KEY,
SLOW_WRITE_BYTES_KEY,
SLOW_WRITE_SEC_KEY,
AIO_FILL_BUFFER_COUNT_KEY,
AIO_FILL_BUFFER_SEC_KEY,
AIO_FILL_BUFFER_SPEED_KEY,
SAVE_STORAGE_KEY,
SAVE_STORAGE_BYTES_KEY,
]
@dataclass
class FastFileWriterConfig:
dnvme_handle: object
pinned_tensor: torch.Tensor
double_buffer: bool = True
num_parallel_writers: int = 1
writer_rank: int = 0
global_rank: int = 0
class FastFileWriter(BaseFileWriter):
def __init__(self, file_path, config):
super(FastFileWriter, self).__init__(file_path)
self._aio_fd = os.open(self._file_path, flags=os.O_DIRECT | os.O_CREAT | os.O_WRONLY)
self._dnvme_handle = config.dnvme_handle
self._file_offset = 0
io_buffer_type = Double_IO_Buffer if config.double_buffer else Single_IO_Buffer
self._io_buffer = io_buffer_type(config.pinned_tensor, self._dnvme_handle)
self._cast_to_byte_tensor = UtilsBuilder().load().cast_to_byte_tensor
self._get_serialization_details = obj_serialization_details()
self._num_parallel_writers = config.num_parallel_writers
self._writer_rank = config.writer_rank
self._global_rank = config.global_rank
for k in FASTIO_STAT_KEYS:
self._stats[k] = 0
def write(self, buffer):
assert self._file_offset % self._dnvme_handle.get_alignment() == 0
buffer_num_bytes = len(buffer)
num_written_bytes = self._write_from_tensor(bytes_to_tensor(buffer))
assert buffer_num_bytes == num_written_bytes
return buffer_num_bytes
def split_index_list(self, storage_obj_list, num_splits):
assert num_splits > 0
split_list = [-1] * num_splits
# t[0] is data, t[1] is data_type
tensor_bytes_list = [len(t[0]) for t in storage_obj_list]
print(tensor_bytes_list)
total_bytes = sum(tensor_bytes_list)
bytes_per_group = total_bytes / num_splits
split_counter = 0
tmp_size = 0
for i in range(len(tensor_bytes_list)):
tmp_size += tensor_bytes_list[i]
if tmp_size > bytes_per_group:
split_list[split_counter] = i
tmp_size = 0
split_counter += 1
if split_list[num_splits - 1] == -1:
split_list[num_splits - 1] = len(tensor_bytes_list)
return split_list
def save_torch_storage_object_list(self, storage_obj_list, save_size):
assert self._file_offset % self._dnvme_handle.get_alignment() == 0
num_bytes_written = self._save_storage_list(storage_obj_list, save_size)
return num_bytes_written
def close(self):
self._fini()
self._incr_stats(CLOSE_COUNT_KEY)
def fileno(self):
self._incr_stats(FILENO_COUNT_KEY)
return INVALID_FD # self._aio_fd
def flush(self):
self._incr_stats(FLUSH_COUNT_KEY)
def __del__(self):
self._fini()
assert self._aio_fd == INVALID_FD
assert self._io_buffer.get_offset() == 0, \
f'__del__ assert: pinned_offset {self._io_buffer.get_offset()} != 0'
assert self._file_offset == self._stats[WRITE_BYTES_KEY], \
f'__del__ assert: file_offset != write_bytes - {self._file_offset} != {self._stats[WRITE_BYTES_KEY]}'
def _fini(self):
if not self._io_buffer_is_empty():
self._force_drain()
self._io_buffer.reset()
self._aio_fd = INVALID_FD
def _fill_io_buffer(self, src_tensor, src_offset):
st = time.time()
copy_bytes = self._io_buffer.fill(src_tensor, src_offset)
self._incr_stats(AIO_FILL_BUFFER_SEC_KEY, time.time() - st)
self._incr_stats(AIO_FILL_BUFFER_COUNT_KEY)
return copy_bytes
def _drain_io_buffer(self, num_bytes):
st = time.time()
self._io_buffer.drain(num_bytes, self._aio_fd, self._file_offset)
self._incr_stats(AIO_WRITE_SEC_KEY, time.time() - st)
self._incr_stats(AIO_WRITE_BYTES_KEY, num_bytes)
self._file_offset += num_bytes
def _io_buffer_is_full(self):
return self._io_buffer.is_full()
def _io_buffer_is_empty(self):
return self._io_buffer.is_empty()
def _force_drain(self):
st = time.time()
aligned_num_bytes = self._io_buffer.get_aligned_num_bytes()
# Important to retrieve unaligned drain bytes and tensor before doing aligned drain because of the side effects.
# TODO: Need to eliminate this dependency
unaligned_num_bytes = self._io_buffer.get_unaligned_num_bytes()
unaligned_tensor = torch.narrow(self._io_buffer.get_buffer(), 0, aligned_num_bytes, unaligned_num_bytes)
if aligned_num_bytes > 0:
self._drain_io_buffer(aligned_num_bytes)
self._io_buffer.complete_ongoing_drain()
self._incr_stats(AIO_WRITE_SEC_KEY, time.time() - st)
if unaligned_num_bytes > 0:
self._unaligned_drain(unaligned_tensor)
self._incr_stats(WRITE_SEC_KEY, time.time() - st)
def _unaligned_drain(self, unaligned_tensor):
os.close(self._aio_fd)
st = time.time()
fp = open(self._file_path, 'ab')
fp.write(tensor_to_bytes(unaligned_tensor.cpu()))
fp.close()
self._file_offset += unaligned_tensor.numel()
self._incr_stats(SLOW_WRITE_SEC_KEY, time.time() - st)
self._incr_stats(SLOW_WRITE_BYTES_KEY, unaligned_tensor.numel())
self._aio_fd = os.open(self._file_path, flags=os.O_DIRECT | os.O_WRONLY | os.O_APPEND)
def _dump_state(self):
if self._stats[AIO_WRITE_SEC_KEY] > 0:
self._stats[AIO_SPEED_KEY] = (self._stats[AIO_WRITE_BYTES_KEY] / self._stats[AIO_WRITE_SEC_KEY] /
(1024**3))
if self._stats[AIO_FILL_BUFFER_SEC_KEY] > 0:
self._stats[AIO_FILL_BUFFER_SPEED_KEY] = (self._stats[AIO_WRITE_BYTES_KEY] /
self._stats[AIO_FILL_BUFFER_SEC_KEY] / (1024**3))
super()._dump_state()
def _update_write_stats(self, num_bytes, secs_latency):
self._incr_stats(WRITE_COUNT_KEY)
self._incr_stats(WRITE_BYTES_KEY, num_bytes)
self._incr_stats(WRITE_SEC_KEY, secs_latency)
def _write_from_tensor(self, buffer_tensor):
st = time.time()
buffer_offset = 0
while (buffer_offset < buffer_tensor.numel()):
num_copied_bytes = self._fill_io_buffer(buffer_tensor, buffer_offset)
if self._io_buffer_is_full():
self._drain_io_buffer(self._io_buffer.get_offset())
buffer_offset += num_copied_bytes
self._update_write_stats(buffer_offset, time.time() - st)
return buffer_offset
def _save_storage_list(self, obj_list, save_size):
byte_tensor_list, byte_tensor_nbytes = self._convert_to_byte_tensors(obj_list, save_size)
if self._num_parallel_writers > 1:
my_byte_tensor_list = self._partition_byte_tensors(byte_tensor_list, byte_tensor_nbytes,
self._num_parallel_writers, self._writer_rank)
else:
my_byte_tensor_list = byte_tensor_list
num_object_bytes_written = 0
for byte_tensor in my_byte_tensor_list:
num_object_bytes_written += self._write_from_tensor(byte_tensor)
self._incr_stats(SAVE_STORAGE_KEY, len(obj_list))
self._incr_stats(SAVE_STORAGE_BYTES_KEY, num_object_bytes_written)
return num_object_bytes_written
# Convert list of storage objects into list of byte tensors of object and size bytes
def _convert_to_byte_tensors(self, obj_list, save_size):
tensor_list = []
num_bytes = 0
for storage_obj in obj_list:
details = self._get_serialization_details(storage_obj)
if save_size:
tensor_list.append(
torch.tensor(
details.size,
dtype=torch.int64,
).to(get_accelerator().device_name()))
tensor_list.append(torch.empty(0, dtype=details.dtype, device=details.obj.device).set_(details.obj))
num_bytes += details.nbytes
if save_size:
num_bytes += STORAGE_OBJ_SIZE * len(obj_list)
return self._cast_to_byte_tensor(tensor_list), num_bytes
def _partition_byte_tensors(self, byte_tensor_list, byte_tensor_nbytes, num_ranks, my_rank):
assert my_rank >= 0, f'Invalid for rank number to be negative: {my_rank}'
assert num_ranks > my_rank, f'Number of ranks {num_ranks} must be greater than rank {my_rank}'
partition_size = int(byte_tensor_nbytes // num_ranks)
num_remainder_bytes = byte_tensor_nbytes % num_ranks
if num_remainder_bytes == 0:
partition_start = partition_size * my_rank
else:
# Spread extra bytes evenly among early ranks
if num_remainder_bytes > my_rank:
partition_size += 1
partition_start = partition_size * my_rank
else:
# Account for allocation of extra bytes to earlier ranks
partition_start = (partition_size * my_rank) + num_remainder_bytes
partition_end = min(partition_start + partition_size, byte_tensor_nbytes)
partition_tensor_list = []
current_offset = 0
for byte_tensor in byte_tensor_list:
byte_tensor_end = current_offset + byte_tensor.numel()
if current_offset < partition_end and byte_tensor_end > partition_start:
fragment_start = max(current_offset, partition_start)
fragment_end = min(byte_tensor_end, partition_end)
assert fragment_start < fragment_end, \
f'fragment start {fragment_start} should be < fragment_end {fragment_end}'
fragment_numel = fragment_end - fragment_start
partition_tensor_list.append(byte_tensor.narrow(0, fragment_start - current_offset, fragment_numel))
current_offset += byte_tensor.numel()
actual_partition_nbytes = sum([t.numel() for t in partition_tensor_list])
assert actual_partition_nbytes == partition_size, \
f'Incorrect partition bytes for rank {my_rank}, expected = {partition_size} actual = {actual_partition_nbytes}'
return partition_tensor_list
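A hedged end-to-end sketch of using FastFileWriter as a torch.save target (path, sizes, and handle settings are illustrative):

import torch
from deepspeed.ops.aio import AsyncIOBuilder
from deepspeed.io import FastFileWriter, FastFileWriterConfig

handle = AsyncIOBuilder().load().aio_handle(1024 * 1024, 32, False, True, 1)
# pinned buffer size must be a multiple of 2 * alignment when double buffering is enabled
pinned = handle.new_cpu_locked_tensor(64 * 1024 * 1024, torch.empty(0, dtype=torch.uint8))

config = FastFileWriterConfig(dnvme_handle=handle, pinned_tensor=pinned)
writer = FastFileWriter('/mnt/nvme23/model.pt', config)
torch.save(obj=torch.randn(1024, 1024), f=writer)
writer.close()  # forces the final (possibly unaligned) drain to storage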


@ -0,0 +1,49 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
from .constants import *
from .base_file_writer import BaseFileWriter
from .utils import obj_serialization_details
class MockFileWriter(BaseFileWriter):
def __init__(self, file_path):
super(MockFileWriter, self).__init__(file_path)
self._fp = open(file_path, 'wb')
self._stats[SAVE_STORAGE_KEY] = 0
self._stats[SAVE_STORAGE_BYTES_KEY] = 0
self._get_serialization_details = obj_serialization_details()
def close(self):
self._incr_stats(CLOSE_COUNT_KEY)
self._fp.close()
def fileno(self):
self._incr_stats(FILENO_COUNT_KEY)
return INVALID_FD # self._fp.fileno()
def flush(self):
self._incr_stats(FLUSH_COUNT_KEY)
self._fp.flush()
def write(self, buffer):
return self._write(len(buffer))
def save_torch_storage_object_list(self, storage_obj_list, save_size):
num_bytes = sum([self._save_torch_storage_object(obj, save_size) for obj in storage_obj_list])
return num_bytes
def _save_torch_storage_object(self, storage_obj, save_size):
details = self._get_serialization_details(storage_obj)
self._incr_stats(SAVE_STORAGE_KEY)
self._incr_stats(SAVE_STORAGE_BYTES_KEY, details.size)
num_written_bytes = self._write(STORAGE_OBJ_SIZE) if save_size else 0
return num_written_bytes + self._write(details.size)
def _write(self, num_bytes):
self._incr_stats(WRITE_COUNT_KEY)
self._incr_stats(WRITE_BYTES_KEY, num_bytes)
return num_bytes


@ -0,0 +1,35 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
import time
from .constants import *
from .base_file_writer import BaseFileWriter
class PyFileWriter(BaseFileWriter):
def __init__(self, file_path):
super(PyFileWriter, self).__init__(file_path)
self._fp = open(file_path, 'wb')
def close(self):
self._incr_stats(CLOSE_COUNT_KEY)
self._fp.close()
def fileno(self):
self._incr_stats(FILENO_COUNT_KEY)
return INVALID_FD # self._fp.fileno()
def flush(self):
self._incr_stats(FLUSH_COUNT_KEY)
self._fp.flush()
def write(self, buffer):
st = time.time()
self._fp.write(buffer)
self._incr_stats(WRITE_SEC_KEY, time.time() - st)
self._incr_stats(WRITE_COUNT_KEY)
self._incr_stats(WRITE_BYTES_KEY, len(buffer))
return len(buffer)


@ -0,0 +1,44 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
from .base_io_buffer import Base_IO_Buffer
class Single_IO_Buffer(Base_IO_Buffer):
def __init__(self, pinned_tensor, dnvme_handle):
super(Single_IO_Buffer, self).__init__(pinned_tensor, dnvme_handle)
self._pinned_offset = 0
def fill(self, src_tensor, src_offset):
copy_bytes = Base_IO_Buffer.fill_buffer(src_tensor, src_offset, self._pinned_tensor, self._pinned_offset)
self._pinned_offset += copy_bytes
return copy_bytes
def drain(self, num_bytes, fd, file_offset):
self._drain(num_bytes, fd, file_offset, blocking=True)
self._pinned_offset = 0
def get_buffer(self):
return self._pinned_tensor
def get_offset(self):
return self._pinned_offset
def get_aligned_num_bytes(self):
aligned_size = self._dnvme_handle.get_alignment()
return (self._pinned_offset // aligned_size) * aligned_size
def get_unaligned_num_bytes(self):
return self._pinned_offset % self._dnvme_handle.get_alignment()
def is_full(self):
return self._pinned_offset == self._pinned_tensor.numel()
def is_empty(self):
return self._pinned_offset == 0
def reset(self):
self._pinned_offset = 0

deepspeed/io/utils.py Normal file

@ -0,0 +1,56 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
import numpy
import torch
from dataclasses import dataclass
@dataclass
class serialize_details:
obj: object
dtype: torch.dtype
size: int
nbytes: int
def tensor_to_bytes(tensor):
return tensor.numpy().tobytes()
def bytes_to_tensor(buffer):
return torch.from_numpy(numpy.array(numpy.frombuffer(buffer, dtype=numpy.uint8)))
def required_minimum_torch_version(major_version, minor_version):
TORCH_MAJOR = int(torch.__version__.split('.')[0])
TORCH_MINOR = int(torch.__version__.split('.')[1])
if TORCH_MAJOR < major_version:
return False
return TORCH_MAJOR > major_version or TORCH_MINOR >= minor_version
# torch < 1.12
def _legacy_obj_serialization_details(storage_obj):
nbytes = storage_obj.element_size() * storage_obj.size()
return serialize_details(obj=storage_obj, dtype=storage_obj.dtype, size=nbytes, nbytes=nbytes)
# torch >= 1.12
def _new_obj_serialization_details(storage_obj):
obj, dtype = storage_obj
return serialize_details(obj=obj,
dtype=dtype,
size=obj.size() // torch._utils._element_size(dtype),
nbytes=obj.size())
def obj_serialization_details():
if required_minimum_torch_version(1, 12):
return _new_obj_serialization_details
return _legacy_obj_serialization_details
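A small hedged round-trip check for the byte-conversion helpers above (values are illustrative):

import torch
from deepspeed.io.utils import tensor_to_bytes, bytes_to_tensor

t = torch.arange(16, dtype=torch.uint8)
buf = tensor_to_bytes(t)  # raw bytes via numpy
assert torch.equal(bytes_to_tensor(buf), t)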


@ -22,8 +22,8 @@ import psutil
from collections import defaultdict
from typing import Dict
from argparse import ArgumentParser, REMAINDER
from ..constants import TORCH_DISTRIBUTED_DEFAULT_PORT, CROSS_RANK, CROSS_SIZE
from deepspeed.accelerator import get_accelerator
from ..constants import TORCH_DISTRIBUTED_DEFAULT_PORT
from ..nebula.constants import DLTS_POD_ENV_PATH
from ..utils import logger, get_numactl_cmd
from ..elasticity import is_torch_elastic_compatible
@ -171,8 +171,8 @@ def main():
current_env["MASTER_ADDR"] = args.master_addr
current_env["MASTER_PORT"] = str(args.master_port)
current_env["WORLD_SIZE"] = str(dist_world_size)
current_env["CROSS_RANK"] = str(args.node_rank)
current_env["CROSS_SIZE"] = str(args.nnodes)
current_env[CROSS_RANK] = str(args.node_rank)
current_env[CROSS_SIZE] = str(args.nnodes)
current_env["LOCAL_SIZE"] = str(num_local_procs)
if args.save_pid:


@ -100,6 +100,8 @@ class PDSHRunner(MultiNodeRunner):
f'--world_info={self.world_info_base64}', "--node_rank=%n", f"--master_addr={self.args.master_addr}",
f"--master_port={self.args.master_port}"
]
if self.args.venv_script is not None:
deepspeed_launch = [f"source {self.args.venv_script}"] + deepspeed_launch
if self.args.no_python:
deepspeed_launch.append("--no_python")
if self.args.module:


@ -207,6 +207,11 @@ def parse_args(args=None):
parser.add_argument("--ssh_port", type=int, default=None, help="SSH port to use for remote connections")
parser.add_argument("--venv_script",
type=str,
default=None,
help="Python virtual environment activation script for job.")
return parser.parse_args(args=args)


@ -9,6 +9,7 @@ Functionality of swapping optimizer tensors to/from (NVMe) storage devices.
import argparse
import os
from .test_ds_aio_utils import refine_integer_value
from .ds_aio_constants import AIO_HANDLE, AIO_BASIC, TORCH_FAST_IO, TORCH_IO, VALID_ENGINES
from deepspeed.accelerator import get_accelerator
MAPPING_DELIMITER = ':'
@ -21,6 +22,9 @@ def refine_args(args):
if args.block_size and type(args.block_size) == str:
args.block_size = refine_integer_value(args.block_size)
if args.fast_io_size and type(args.fast_io_size) == str:
args.fast_io_size = refine_integer_value(args.fast_io_size)
return args
@ -83,6 +87,19 @@ def validate_args(args):
no_error = no_error and no_mapping_error
error_messages += mapping_error_messages
# Validate --engine
if args.engine not in VALID_ENGINES:
no_error = False
error_messages.append(f'Invalid engine {args.engine}. Valid options = {VALID_ENGINES}')
# Validate --engine=torch_io
if args.engine == TORCH_IO:
if args.read:
no_error = False
error_messages.append(f'Read not currently supported for --engine={TORCH_IO}')
if not no_error:
print(f'Found {len(error_messages)} validation error(s)')
# Validate --gpu, --use_gds
if args.use_gds and not args.gpu:
error_messages.append(f'--gpu must be set to transfer with --use_gds')
@ -111,6 +128,8 @@ def parse_arguments():
parser.add_argument('--io_size', type=str, default=None, required=True, help='Number of bytes to read or write.')
parser.add_argument('--fast_io_size', type=str, default='64M', help='Size of fast_io pinned buffer (bytes).')
parser.add_argument('--read', action='store_true', help='Perform read I/O (default is write)')
parser.add_argument('--multi_process',
@ -138,7 +157,13 @@ def parse_arguments():
parser.add_argument('--validate', action='store_true', help='Perform validation of I/O transfer in library.')
parser.add_argument('--handle', action='store_true', help='Use AIO handle.')
parser.add_argument(
'--engine',
type=str,
default=AIO_HANDLE,
help=
f'Engine to perform I/O. Options are [{AIO_HANDLE}, {AIO_BASIC}, {TORCH_IO}, {TORCH_FAST_IO}]. Default is aio_handle'
)
parser.add_argument('--loops', type=int, default=3, help='Count of operation repetitions')
@ -152,6 +177,20 @@ def parse_arguments():
action='store_true',
help='For GPU memory transfers, measure impact of bounce buffer pinning on critical path.')
parser.add_argument('--torch_legacy_save', action='store_true', help='Use torch legacy save approach')
parser.add_argument('--use_accelerator_pin_memory',
action='store_true',
help='Obtain pinned (CPU page-locked) tensors from accelerator')
parser.add_argument('--warmup_loops', type=int, default=1, help='Count of operation warmup repetitions')
parser.add_argument('--include_warmup_time', action='store_true', help='Include warmup latency in results')
parser.add_argument('--different_file_each_iteration',
action='store_true',
help='Read/write a different file on each iteration.')
args = parser.parse_args()
print(f'args = {args}')
return args
@ -163,7 +202,7 @@ def get_validated_args():
if not validate_args(args):
quit()
print(f'Successful validation of command line arguments')
args.total_loops = args.warmup_loops + args.loops
peer_tag = 'gpu' if args.gpu else 'process'
args.mapping_dict = _get_mapping_dict(args)
args.mapping_list = [(device_id, folder) for device_id, folder in args.mapping_dict.items()]


@ -6,129 +6,59 @@
Functionality of swapping optimizer tensors to/from (NVMe) storage devices.
"""
import torch
import os
import time
from deepspeed.ops.aio import AsyncIOBuilder
from multiprocessing import Pool, Barrier
from .test_ds_aio_utils import report_results, task_log, task_barrier
from .test_ds_aio_utils import task_log, create_filename, create_file, create_page_locked_tensor
from .ds_aio_constants import *
def pre_basic(args, tid, read_op):
class AIOBasic_Engine(object):
def __init__(self, args, tid, read_op):
self.ctxt = self._create_context(args, tid, read_op)
def fini(self):
self.ctxt[BUFFER].detach()
self.ctxt[BUFFER] = None
def read(self, args, tid, loop_id):
start_time = time.time()
AsyncIOBuilder().load().aio_read(self.ctxt[BUFFER], self.ctxt[FILE], args.block_size, args.queue_depth,
args.single_submit, not args.sequential_requests, args.validate)
end_time = time.time()
self.ctxt[ELAPSED_SEC] += end_time - start_time
def write(self, args, tid, loop_id):
# Avoid overwriting existing files as it could be artificially faster
if os.path.isfile(self.ctxt[FILE]):
os.remove(self.ctxt[FILE])
start_time = time.time()
AsyncIOBuilder().load().aio_write(self.ctxt[BUFFER], self.ctxt[FILE], args.block_size, args.queue_depth,
args.single_submit, not args.sequential_requests, args.validate)
end_time = time.time()
self.ctxt[ELAPSED_SEC] += end_time - start_time
def _create_context(self, args, tid, read_op):
io_string = "Read" if read_op else "Write"
num_bytes = os.path.getsize(args.read_file) if read_op else args.write_size
file = args.read_file if read_op else f'{args.write_file}.{tid}'
device_id, folder = args.mapping_list[tid]
filename = create_filename(folder, args.read, args.io_size, tid)
if args.read and not (os.path.isfile(filename) and os.path.getsize(filename) == args.io_size):
create_file(filename, args.io_size)
task_log(tid, f'Allocate tensor of size {num_bytes} bytes')
buffer = torch.empty(num_bytes, dtype=torch.uint8, device='cpu').pin_memory()
task_log(tid, f'{io_string} file {file} of size {num_bytes} bytes from buffer on device {buffer.device}')
task_log(tid, f'Allocate tensor of size {args.io_size} bytes')
buffer = create_page_locked_tensor(args.io_size, True)
task_log(tid,
f'{io_string} file {filename} of size {args.io_size} bytes from buffer on device {buffer.device}')
task_log(tid, f'created deepspeed aio basic engine')
ctxt = {}
ctxt['file'] = file
ctxt['num_bytes'] = num_bytes
ctxt['buffer'] = buffer
ctxt['elapsed_sec'] = 0
ctxt[FILE] = filename
ctxt[NUM_BYTES] = args.io_size
ctxt[BUFFER] = buffer
ctxt[ELAPSED_SEC] = 0
return ctxt
def pre_basic_read(pool_params):
args, tid = pool_params
ctxt = pre_basic(args, tid, True)
return ctxt
def pre_basic_write(pool_params):
args, tid = pool_params
ctxt = pre_basic(args, tid, False)
return ctxt
def post_basic(pool_params):
_, _, ctxt = pool_params
ctxt["buffer"].detach()
ctxt["buffer"] = None
return ctxt
def main_basic_read(pool_params):
args, tid, ctxt = pool_params
start_time = time.time()
AsyncIOBuilder().load().aio_read(ctxt['buffer'], ctxt['file'], args.block_size, args.queue_depth,
args.single_submit, not args.sequential_requests, args.validate)
end_time = time.time()
ctxt['elapsed_sec'] += end_time - start_time
return ctxt
def main_basic_write(pool_params):
args, tid, ctxt = pool_params
start_time = time.time()
AsyncIOBuilder().load().aio_write(ctxt['buffer'], ctxt['file'], args.block_size, args.queue_depth,
args.single_submit, not args.sequential_requests, args.validate)
end_time = time.time()
ctxt['elapsed_sec'] += end_time - start_time
return ctxt
def get_schedule(args, read_op):
schedule = {}
if read_op:
schedule['pre'] = pre_basic_read
schedule['post'] = post_basic
schedule['main'] = main_basic_read
else:
schedule['pre'] = pre_basic_write
schedule['post'] = post_basic
schedule['main'] = main_basic_write
return schedule
def _aio_handle_tasklet(pool_params):
args, tid, read_op = pool_params
num_processes = len(args.mapping_dict)
# Create schedule
schedule = get_schedule(args, read_op)
task_log(tid, f'schedule = {schedule}')
task_barrier(aio_barrier, num_processes)
# Run pre task
task_log(tid, f'running pre-task')
ctxt = schedule["pre"]((args, tid))
task_barrier(aio_barrier, num_processes)
# Run main tasks in a loop
ctxt["main_task_sec"] = 0
for i in range(args.loops):
task_log(tid, f'running main task {i}')
start_time = time.time()
ctxt = schedule["main"]((args, tid, ctxt))
task_barrier(aio_barrier, num_processes)
stop_time = time.time()
ctxt["main_task_sec"] += stop_time - start_time
# Run post task
task_log(tid, f'running post-task')
ctxt = schedule["post"]((args, tid, ctxt))
task_barrier(aio_barrier, num_processes)
return ctxt["main_task_sec"], ctxt["elapsed_sec"], ctxt["num_bytes"] * args.loops
def _init_tasklet(b):
global aio_barrier
aio_barrier = b
def aio_basic_multiprocessing(args, read_op):
num_processes = len(args.mapping_dict)
b = Barrier(num_processes)
pool_params = [(args, p, read_op) for p in range(num_processes)]
with Pool(processes=num_processes, initializer=_init_tasklet, initargs=(b, )) as p:
pool_results = p.map(_aio_handle_tasklet, pool_params)
report_results(args, read_op, pool_results)


@ -0,0 +1,20 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
AIO_HANDLE = 'aio_handle'
AIO_BASIC = 'aio_basic'
TORCH_IO = 'torch_io'
TORCH_FAST_IO = 'torch_fastio'
VALID_ENGINES = [AIO_HANDLE, AIO_BASIC, TORCH_IO, TORCH_FAST_IO]
BUFFER = 'buffer'
BOUNCE_BUFFER = 'bounce_buffer'
NUM_BYTES = 'num_bytes'
FILE = 'file'
HANDLE = 'handle'
ELAPSED_SEC = 'elapsed_sec'
FAST_IO_BUFFER = 'fast_io_buffer'
USE_CPU_LOCKED_TENSOR = 'cpu_locked_tensor'
USE_GDS = 'gds'


@ -9,214 +9,118 @@ Functionality of swapping optimizer tensors to/from (NVMe) storage devices.
import torch
import os
import time
from multiprocessing import Pool, Barrier
from deepspeed.ops.aio import AsyncIOBuilder
from deepspeed.ops.op_builder import GDSBuilder
from deepspeed.accelerator import get_accelerator
from .test_ds_aio_utils import report_results, task_log, task_barrier, create_filename, create_file
BUFFER = 'buffer'
BOUNCE_BUFFER = 'bounce_buffer'
from .test_ds_aio_utils import task_log, create_filename, create_file, create_page_locked_tensor
from .ds_aio_constants import *
def pre_handle(args, tid, read_op):
io_string = "Read" if read_op else "Write"
gds = True if args.use_gds else False
device_id, folder = args.mapping_list[tid]
filename = create_filename(folder, args.read, args.io_size, tid)
if args.read and not (os.path.isfile(filename) and os.path.getsize(filename) == args.io_size):
create_file(filename, args.io_size)
class AIOHandle_Engine(object):
task_log(tid, f'Allocate tensor of size {args.io_size} bytes')
bounce_buffer = None
if args.gpu:
device_name = get_accelerator().device_name(device_id)
buffer = torch.randint(high=128, size=(args.io_size, ), dtype=torch.uint8, device=device_name)
if not (args.slow_bounce_buffer or gds):
bounce_buffer = torch.randint(high=128, size=(args.io_size, ), dtype=torch.uint8,
device='cpu').pin_memory()
def __init__(self, args, tid, read_op):
self.ctxt = self._create_context(args, tid, read_op)
def fini(self):
for buf in [BUFFER, BOUNCE_BUFFER]:
if self.ctxt[buf] is not None:
if self.ctxt[USE_CPU_LOCKED_TENSOR]:
self.ctxt[HANDLE].free_cpu_locked_tensor(self.ctxt[buf])
self.ctxt[buf].detach()
self.ctxt[buf] = None
def read(self, args, tid, loop_id):
handle = self.ctxt[HANDLE]
start_time = time.time()
dest_buffer = BOUNCE_BUFFER if self.ctxt[BOUNCE_BUFFER] is not None else BUFFER
ret = handle.pread(self.ctxt[dest_buffer], self.ctxt[FILE][loop_id], args.validate, True)
assert ret != -1
handle.wait()
if dest_buffer == BOUNCE_BUFFER:
self.ctxt[BUFFER].data.copy_(self.ctxt[BOUNCE_BUFFER].data)
end_time = time.time()
self.ctxt[ELAPSED_SEC].append(end_time - start_time)
def write(self, args, tid, loop_id):
# Note: stale files are removed up front in _create_files() to avoid artificially fast overwrites
handle = self.ctxt[HANDLE]
start_time = time.time()
if self.ctxt[BOUNCE_BUFFER] is not None:
source_buffer = BOUNCE_BUFFER
self.ctxt[BOUNCE_BUFFER].data.copy_(self.ctxt[BUFFER].data)
else:
buffer = torch.randint(high=128, size=(args.io_size, ), dtype=torch.uint8, device='cpu').pin_memory()
task_log(tid,
f'{io_string} file {filename} of size {args.io_size} bytes from buffer on device {buffer.device}',
force=True)
source_buffer = BUFFER
ret = handle.pwrite(self.ctxt[source_buffer], self.ctxt[FILE][loop_id], args.validate, True)
assert ret != -1
handle.wait()
end_time = time.time()
self.ctxt[ELAPSED_SEC].append(end_time - start_time)
def _create_files(self, args, folder, tid):
if args.different_file_each_iteration:
filenames = [
create_filename(folder, args.read, args.io_size, f'{tid}_{l}') for l in range(args.total_loops)
]
else:
filenames = [
create_filename(folder, args.read, args.io_size, f'{tid}_{0}') for _ in range(args.total_loops)
]
if args.read:
for f in filenames:
if not (os.path.isfile(f) and os.path.getsize(f) == args.io_size):
create_file(f, args.io_size)
else:
for f in filenames:
if os.path.isfile(f):
os.remove(f)
return filenames
def _create_context(self, args, tid, read_op):
io_string = "Read" if read_op else "Write"
device_id, folder = args.mapping_list[tid]
filenames = self._create_files(args, folder, tid)
gds = True if args.use_gds else False
io_parallel = args.io_parallel if args.io_parallel else 1
if gds:
handle = GDSBuilder().load().gds_handle(args.block_size, args.queue_depth, args.single_submit,
not args.sequential_requests, io_parallel)
handle.pin_device_tensor(buffer)
else:
handle = AsyncIOBuilder().load().aio_handle(args.block_size, args.queue_depth, args.single_submit,
not args.sequential_requests, io_parallel)
task_log(tid, f'created deepspeed aio handle')
task_log(tid, f'Created DeepNVMe handle engine')
bounce_buffer = None
if args.gpu:
device_name = get_accelerator().device_name(device_id)
buffer = torch.randint(high=128, size=(args.io_size, ), dtype=torch.uint8, device=device_name)
if gds:
handle.pin_device_tensor(buffer)
elif not args.slow_bounce_buffer:
bounce_buffer = create_page_locked_tensor(args.io_size, args.use_accelerator_pin_memory, handle)
else:
buffer = create_page_locked_tensor(args.io_size, args.use_accelerator_pin_memory, handle)
task_log(tid, f'Allocate tensor of size {args.io_size} bytes')
ctxt = {}
ctxt['file'] = filename
ctxt['num_bytes'] = args.io_size
ctxt['handle'] = handle
ctxt['gds'] = gds
ctxt[FILE] = filenames
ctxt[NUM_BYTES] = args.io_size
ctxt[HANDLE] = handle
ctxt[USE_GDS] = gds
ctxt[BUFFER] = buffer
ctxt[BOUNCE_BUFFER] = bounce_buffer
ctxt['elapsed_sec'] = 0
ctxt[ELAPSED_SEC] = []
ctxt[USE_CPU_LOCKED_TENSOR] = not args.use_accelerator_pin_memory
task_log(tid,
f'{io_string} file {filenames} of size {args.io_size} bytes from buffer on device {buffer.device}',
force=True)
return ctxt
def pre_handle_read(pool_params):
args, tid = pool_params
ctxt = pre_handle(args, tid, True)
return ctxt
def pre_handle_write(pool_params):
args, tid = pool_params
ctxt = pre_handle(args, tid, False)
return ctxt
def post_handle(pool_params):
_, _, ctxt = pool_params
for buf in [BUFFER, BOUNCE_BUFFER]:
if ctxt[buf] is not None:
if ctxt['gds']:
ctxt['handle'].unpin_device_tensor(ctxt[buf])
ctxt[buf].detach()
ctxt[buf] = None
return ctxt
def main_parallel_read(pool_params):
args, tid, ctxt = pool_params
handle = ctxt['handle']
start_time = time.time()
dest_buffer = BOUNCE_BUFFER if ctxt[BOUNCE_BUFFER] is not None else BUFFER
ret = handle.pread(ctxt[dest_buffer], ctxt['file'], args.validate, True)
assert ret != -1
handle.wait()
if dest_buffer == BOUNCE_BUFFER:
ctxt[BUFFER].data.copy_(ctxt[BOUNCE_BUFFER].data)
end_time = time.time()
ctxt['elapsed_sec'] += end_time - start_time
return ctxt
def main_parallel_write(pool_params):
args, tid, ctxt = pool_params
# Avoid overwriting existing files as it could be artificially faster
if os.path.isfile(ctxt['file']):
os.remove(ctxt['file'])
handle = ctxt['handle']
start_time = time.time()
if ctxt[BOUNCE_BUFFER] is not None:
source_buffer = BOUNCE_BUFFER
ctxt[BOUNCE_BUFFER].data.copy_(ctxt[BUFFER].data)
else:
source_buffer = BUFFER
ret = handle.pwrite(ctxt[source_buffer], ctxt['file'], args.validate, True)
assert ret != -1
handle.wait()
end_time = time.time()
ctxt['elapsed_sec'] += end_time - start_time
return ctxt
def main_handle_read(pool_parms):
args, tid, ctxt = pool_parms
handle = ctxt['handle']
start_time = time.time()
dest_buffer = BOUNCE_BUFFER if ctxt[BOUNCE_BUFFER] is not None else BUFFER
ret = handle.read(ctxt[dest_buffer], ctxt['file'], args.validate)
assert ret != -1
if dest_buffer == BOUNCE_BUFFER:
ctxt[BUFFER].data.copy_(ctxt[BOUNCE_BUFFER].data)
end_time = time.time()
ctxt['elapsed_sec'] += end_time - start_time
return ctxt
def main_handle_write(pool_parms):
args, tid, ctxt = pool_parms
# Avoid overwriting existing files as it could be artificially faster
if os.path.isfile(ctxt['file']):
os.remove(ctxt['file'])
handle = ctxt['handle']
start_time = time.time()
if ctxt[BOUNCE_BUFFER] is not None:
source_buffer = BOUNCE_BUFFER
ctxt[BOUNCE_BUFFER].data.copy_(ctxt[BUFFER].data)
else:
source_buffer = BUFFER
ret = handle.write(ctxt[source_buffer], ctxt['file'], args.validate)
assert ret != -1
end_time = time.time()
ctxt['elapsed_sec'] += end_time - start_time
return ctxt
def get_schedule(args, read_op):
schedule = {}
if read_op:
schedule['pre'] = pre_handle_read
schedule['post'] = post_handle
schedule['main'] = main_parallel_read
else:
schedule['pre'] = pre_handle_write
schedule['post'] = post_handle
schedule['main'] = main_parallel_write
return schedule
def _aio_handle_tasklet(pool_params):
args, tid, read_op = pool_params
num_processes = len(args.mapping_dict)
# Create schedule
schedule = get_schedule(args, read_op)
task_log(tid, f'schedule = {schedule}')
task_barrier(aio_barrier, num_processes)
# Run pre task
task_log(tid, f'running pre-task')
ctxt = schedule["pre"]((args, tid))
task_barrier(aio_barrier, num_processes)
# Run main tasks in a loop
ctxt["main_task_sec"] = 0
for i in range(args.loops):
task_log(tid, f'running main task {i}')
start_time = time.time()
ctxt = schedule["main"]((args, tid, ctxt))
task_barrier(aio_barrier, num_processes)
stop_time = time.time()
ctxt["main_task_sec"] += stop_time - start_time
# Run post task
task_log(tid, f'running post-task')
ctxt = schedule["post"]((args, tid, ctxt))
task_barrier(aio_barrier, num_processes)
return ctxt["main_task_sec"], ctxt["elapsed_sec"], ctxt["num_bytes"] * args.loops
def _init_tasklet(b):
global aio_barrier
aio_barrier = b
def aio_handle_multiprocessing(args, read_op):
num_processes = len(args.mapping_dict)
b = Barrier(num_processes)
pool_params = [(args, p, read_op) for p in range(num_processes)]
with Pool(processes=num_processes, initializer=_init_tasklet, initargs=(b, )) as p:
pool_results = p.map(_aio_handle_tasklet, pool_params)
report_results(args, read_op, pool_results)

deepspeed/nvme/io_engine.py Normal file

@ -0,0 +1,126 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
import time
from multiprocessing import Pool, Barrier
from .ds_aio_constants import AIO_BASIC, TORCH_FAST_IO, TORCH_IO
from .test_ds_aio_utils import report_results, task_log, task_barrier
from .ds_aio_handle import AIOHandle_Engine
from .ds_aio_basic import AIOBasic_Engine
from .torch_io import TorchIO_Engine
from .torch_fastio_engine import Torch_FastIO_Engine
def prepare_operation(args, tid, read_op):
if args.engine == TORCH_IO:
io_engine = TorchIO_Engine(args, tid, read_op)
elif args.engine == AIO_BASIC:
io_engine = AIOBasic_Engine(args, tid, read_op)
elif args.engine == TORCH_FAST_IO:
io_engine = Torch_FastIO_Engine(args, tid, read_op)
else:
io_engine = AIOHandle_Engine(args, tid, read_op)
return io_engine
def prepare_read(pool_params):
args, tid = pool_params
return prepare_operation(args, tid, True)
def prepare_write(pool_params):
args, tid = pool_params
return prepare_operation(args, tid, False)
def post_operation(pool_params):
_, _, io_engine = pool_params
io_engine.fini()
def read_operation(pool_params):
args, tid, loop_id, io_engine = pool_params
return io_engine.read(args, tid, loop_id)
def write_operation(pool_params):
args, tid, loop_id, io_engine = pool_params
return io_engine.write(args, tid, loop_id)
def get_schedule(args, read_op):
schedule = {}
if read_op:
schedule['pre'] = prepare_read
schedule['post'] = post_operation
schedule['main'] = read_operation
else:
schedule['pre'] = prepare_write
schedule['post'] = post_operation
schedule['main'] = write_operation
return schedule
def io_engine_tasklet(pool_params):
args, tid, read_op = pool_params
num_processes = len(args.mapping_dict)
# Create schedule
schedule = get_schedule(args, read_op)
task_log(tid, f'schedule = {schedule}')
task_barrier(aio_barrier, num_processes)
# Run pre task
task_log(tid, f'running pre-task')
io_engine = schedule["pre"]((args, tid))
task_barrier(aio_barrier, num_processes)
# Run main tasks in a loop
io_engine.ctxt["main_task_sec"] = []
for i in range(args.total_loops):
task_log(tid, f'running main task {i}')
start_time = time.time()
schedule["main"]((args, tid, i, io_engine))
task_barrier(aio_barrier, num_processes)
stop_time = time.time()
io_engine.ctxt["main_task_sec"].append(stop_time - start_time)
# Run post task
task_log(tid, f'running post-task')
schedule["post"]((args, tid, io_engine))
task_barrier(aio_barrier, num_processes)
ctxt = io_engine.ctxt
# return ctxt["main_task_sec"], ctxt["elapsed_sec"], ctxt["num_bytes"] * args.loops
if args.include_warmup_time:
e2e_latency_sec = sum(ctxt["main_task_sec"])
task_latency_sec = sum(ctxt["elapsed_sec"])
actual_loops = args.total_loops
else:
e2e_latency_sec = sum(ctxt["main_task_sec"][args.warmup_loops:])
task_latency_sec = sum(ctxt["elapsed_sec"][args.warmup_loops:])
actual_loops = args.loops
l = ctxt["elapsed_sec"]
task_log(tid, f'task_latency_sec = {l}')
return e2e_latency_sec, task_latency_sec, ctxt["num_bytes"] * actual_loops
def _init_tasklet(b):
global aio_barrier
aio_barrier = b
def io_engine_multiprocessing(args, read_op):
num_processes = len(args.mapping_dict)
b = Barrier(num_processes)
pool_params = [(args, p, read_op) for p in range(num_processes)]
with Pool(processes=num_processes, initializer=_init_tasklet, initargs=(b, )) as p:
pool_results = p.map(io_engine_tasklet, pool_params)
report_results(args, read_op, pool_results)


@ -52,8 +52,12 @@ def convert_to_param(key):
def generate_aio_param(read_log_dir, write_log_dir):
_, read_results = get_sorted_results(read_log_dir, READ_SPEED)
_, write_results = get_sorted_results(write_log_dir, WRITE_SPEED)
combined_perf = {key[1:]: value for key, value in read_results.items()}
read_results_count = len(read_results.items())
write_results_count = len(write_results.items())
assert read_results_count == write_results_count, f"Mismatch in number of read & write results: {read_results_count=} != {write_results_count=}"
combined_perf = {key[1:]: value for key, value in read_results.items()}
for key, value in write_results.items():
new_key = key[1:]
if new_key in combined_perf:


@ -12,20 +12,19 @@ import json
import itertools
import shutil
from deepspeed.ops.op_builder import AsyncIOBuilder
from deepspeed.ops.op_builder import GDSBuilder
from deepspeed.ops.op_builder import AsyncIOBuilder, GDSBuilder
from .ds_aio_job import Job, run_job
from .perf_sweep_utils import READ_OP_DESC, WRITE_OP_DESC, BENCH_LOG_DIR, \
READ_LOG_DIR, WRITE_LOG_DIR
OTHER_OPTIONS = '--handle'
OTHER_OPTIONS = '--engine aio_handle'
PERF_SCRIPT = 'ds_io'
DEFAULT_SWEEP_CONFIG = {
"block_size": ["1M", "8M"],
"queue_depth": [32, 128],
"sequential_requests": [False],
"single_submit": [False],
"io_parallel": [1, 8],
"sequential_requests": [True, False],
"single_submit": [False, True],
"io_parallel": [1, 2, 4, 8],
}


@ -7,17 +7,16 @@ Functionality of swapping optimizer tensors to/from (NVMe) storage devices.
"""
import multiprocessing as mp
from .ds_aio_basic import aio_basic_multiprocessing
from .ds_aio_handle import aio_handle_multiprocessing
from .ds_aio_args import get_validated_args
from .io_engine import io_engine_multiprocessing
def ds_io_main():
print(f'Testing deepspeed_aio python frontend')
print(f'Testing DeepNVMe python frontend')
args = get_validated_args()
mp.set_start_method('spawn')
multiprocess_function = aio_handle_multiprocessing if args.handle else aio_basic_multiprocessing
mp.set_start_method('spawn', force=True)
multiprocess_function = io_engine_multiprocessing
multiprocess_function(args, args.read)


@ -8,6 +8,8 @@ Functionality of swapping optimizer tensors to/from (NVMe) storage devices.
import os
from .ds_aio_job import Job, run_job
import torch
from deepspeed.accelerator import get_accelerator
BYTES_PER_GB = 1024**3
BYTES_PER_MB = 1024**2
@ -79,3 +81,11 @@ def create_file(filename, num_bytes):
print(f'[Start] Create {filename} of {num_bytes} bytes by running {dd_job.cmd()} ....')
run_job(dd_job)
print(f'[Done] Create read file of {num_bytes} bytes by running {dd_job.cmd()} ....')
def create_page_locked_tensor(num_elem, use_accelerator, aio_handle=None):
if use_accelerator:
return get_accelerator().pin_memory(torch.randint(high=128, size=(num_elem, ), dtype=torch.uint8,
device='cpu'))
else:
return aio_handle.new_cpu_locked_tensor(num_elem, torch.empty(0, dtype=torch.uint8))


@ -0,0 +1,87 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
import torch
import os
import time
from deepspeed.ops.aio import AsyncIOBuilder
from .test_ds_aio_utils import task_log, create_filename, create_file, create_page_locked_tensor
from .ds_aio_constants import *
from deepspeed.io import FastFileWriter
class Torch_FastIO_Engine(object):
def __init__(self, args, tid, read_op):
assert read_op is False, f'Read operation is not currently supported'
self.ctxt = self._create_context(args, tid, read_op)
self.zipfile_serialization = not args.torch_legacy_save
def fini(self):
if self.ctxt[USE_CPU_LOCKED_TENSOR]:
for buf in [BUFFER, FAST_IO_BUFFER]:
self.ctxt[HANDLE].free_cpu_locked_tensor(self.ctxt[buf])
self.ctxt[BUFFER].detach()
self.ctxt[BUFFER] = None
def read(self, args, tid):
start_time = time.time()
torch.load(f=self.ctxt[FILE], map_location=self.ctxt[BUFFER].device)
end_time = time.time()
self.ctxt[ELAPSED_SEC] += end_time - start_time
def write(self, args, tid):
# Avoid overwriting existing files as it could be artificially faster
if os.path.isfile(self.ctxt[FILE]):
os.remove(self.ctxt[FILE])
ds_file_writer = FastFileWriter(file_path=self.ctxt[FILE],
aio_handle=self.ctxt[HANDLE],
pinned_tensor=self.ctxt[FAST_IO_BUFFER])
start_time = time.time()
torch.save(obj=self.ctxt[BUFFER], f=ds_file_writer, _use_new_zipfile_serialization=self.zipfile_serialization)
ds_file_writer.close() # Force flush to storage
end_time = time.time()
self.ctxt[ELAPSED_SEC] += end_time - start_time
ds_file_writer._dump_state()
def _create_context(self, args, tid, read_op):
io_string = "Read" if read_op else "Write"
device_id, folder = args.mapping_list[tid]
filename = create_filename(folder, args.read, args.io_size, tid)
if args.read and not (os.path.isfile(filename) and os.path.getsize(filename) == args.io_size):
create_file(filename, args.io_size)
io_parallel = args.io_parallel if args.io_parallel else 1
aio_handle = AsyncIOBuilder().load().aio_handle(args.block_size, args.queue_depth, args.single_submit,
not args.sequential_requests, io_parallel)
if args.gpu:
buffer = torch.randint(high=128, size=(args.io_size, ), dtype=torch.uint8, device=f'cuda:{device_id}')
else:
buffer = create_page_locked_tensor(args.io_size, args.use_accelerator_pin_memory, aio_handle)
task_log(tid, f'Allocate tensor of size {args.io_size} bytes')
fast_io_buffer = create_page_locked_tensor(args.fast_io_size, args.use_accelerator_pin_memory, aio_handle)
task_log(tid, f'created torch_fastio engine')
ctxt = {}
ctxt[FILE] = filename
ctxt[NUM_BYTES] = args.io_size
ctxt[BUFFER] = buffer
ctxt[HANDLE] = aio_handle
ctxt[FAST_IO_BUFFER] = fast_io_buffer
ctxt[ELAPSED_SEC] = 0
ctxt[USE_CPU_LOCKED_TENSOR] = not args.use_accelerator_pin_memory
task_log(tid,
f'{io_string} file {filename} of size {args.io_size} bytes from buffer on device {buffer.device}',
force=True)
return ctxt

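Distilled from the engine above, the essential `FastFileWriter` flow is: build an `aio_handle`, allocate a pinned staging buffer, hand both to the writer, and pass the writer as the file object to `torch.save`. A hedged sketch (the NVMe path, block size, queue depth, and buffer size are illustrative values, not tuned recommendations):

```python
import torch
from deepspeed.ops.op_builder import AsyncIOBuilder
from deepspeed.io import FastFileWriter

# aio_handle(block_size, queue_depth, single_submit, overlap_events, intra_op_parallelism)
aio_handle = AsyncIOBuilder().load().aio_handle(1024**2, 32, False, True, 1)
pinned = aio_handle.new_cpu_locked_tensor(64 * 1024**2, torch.empty(0, dtype=torch.uint8))

writer = FastFileWriter(file_path='/mnt/nvme0/tensor.pt',   # assumed NVMe mount point
                        aio_handle=aio_handle,
                        pinned_tensor=pinned)
torch.save(obj=torch.randn(1024, 1024), f=writer)
writer.close()   # force the final flush to storage

aio_handle.free_cpu_locked_tensor(pinned)   # release the pinned staging buffer
```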

@ -0,0 +1,64 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
import torch
import os
import time
from .test_ds_aio_utils import task_log, create_filename, create_file, create_page_locked_tensor
from .ds_aio_constants import *
class TorchIO_Engine(object):
def __init__(self, args, tid, read_op):
self.ctxt = self._create_context(args, tid, read_op)
self.zipfile_serialization = not args.torch_legacy_save
def fini(self):
self.ctxt[BUFFER].detach()
self.ctxt[BUFFER] = None
def read(self, args, tid):
start_time = time.time()
torch.load(f=self.ctxt[FILE], map_location=self.ctxt[BUFFER].device)
end_time = time.time()
self.ctxt[ELAPSED_SEC] += end_time - start_time
def write(self, args, tid):
# Avoid overwriting existing files as it could be artificially faster
if os.path.isfile(self.ctxt[FILE]):
os.remove(self.ctxt[FILE])
start_time = time.time()
torch.save(obj=self.ctxt[BUFFER], f=self.ctxt[FILE], _use_new_zipfile_serialization=self.zipfile_serialization)
end_time = time.time()
self.ctxt[ELAPSED_SEC] += end_time - start_time
def _create_context(self, args, tid, read_op):
io_string = "Read" if read_op else "Write"
device_id, folder = args.mapping_list[tid]
filename = create_filename(folder, args.read, args.io_size, tid)
if args.read and not (os.path.isfile(filename) and os.path.getsize(filename) == args.io_size):
create_file(filename, args.io_size)
task_log(tid, f'Allocate tensor of size {args.io_size} bytes')
if args.gpu:
buffer = torch.randint(high=128, size=(args.io_size, ), dtype=torch.uint8, device=f'cuda:{device_id}')
else:
buffer = create_page_locked_tensor(args.io_size, True)
task_log(tid,
f'{io_string} file {filename} of size {args.io_size} bytes from buffer on device {buffer.device}',
force=True)
task_log(tid, f'created torch_io engine')
ctxt = {}
ctxt[FILE] = filename
ctxt[NUM_BYTES] = args.io_size
ctxt[BUFFER] = buffer
ctxt[ELAPSED_SEC] = 0
return ctxt


@ -20,7 +20,7 @@ class CheckpointEngine(object):
def __init__(self, config_params=None):
pass
def create(self, tag):
def create(self, info:CheckpointCommitInfo):
# create checkpoint on given tag for save/load.
pass
@ -30,8 +30,8 @@ class CheckpointEngine(object):
def load(self, path: str, map_location=None):
pass
def commit(self, tag):
# to tell checkpoint services if all files are ready.
def commit(self, info:CheckpointCommitInfo):
# to tell checkpoint services if all files are ready.
pass
```


@ -3,3 +3,9 @@
# DeepSpeed Team
'''Copyright The Microsoft DeepSpeed Team'''
from .fast_checkpoint_engine import FastCheckpointEngine
from .torch_checkpoint_engine import TorchCheckpointEngine
from .decoupled_checkpoint_engine import DecoupledCheckpointEngine
from .checkpoint_engine import CheckpointCommitInfo
from .utils import create_checkpoint_engine


@ -5,26 +5,56 @@
import os
import abc
from abc import ABC
class CheckpointEngine(object):
from dataclasses import dataclass
@dataclass
class CheckpointCommitInfo(object):
tag: str
save_dir: str
save_latest: bool
class CheckpointEngine(ABC):
# init checkpoint engine for save/load
def __init__(self, config_params=None):
pass
self.name = None
def create(self, tag):
@abc.abstractmethod
def create(self, info: CheckpointCommitInfo):
# create checkpoint on given tag for save/load.
pass
...
@abc.abstractmethod
def save(self, state_dict, path: str):
...
def makedirs(self, path, exist_ok=False):
os.makedirs(path, exist_ok=exist_ok)
def save(self, state_dict, path: str):
pass
@abc.abstractmethod
def load(self, path: str, map_location=None):
...
@abc.abstractmethod
def commit(self, info: CheckpointCommitInfo):
# to tell checkpoint services if all files are ready.
...
def is_data_parallel_writer(self, dp_rank):
return dp_rank == 0
def is_decoupled(self):
return False
def set_commit_info(self, info: CheckpointCommitInfo):
pass
def commit(self, tag):
# to tell checkpoint services if all files are ready.
def get_commit_info(self):
return None
def cleanup(self):
pass

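With the base class above now abstract, a custom engine only has to implement `create`, `save`, `load`, and `commit`, all keyed off the new `CheckpointCommitInfo` dataclass. A hedged sketch of a minimal subclass (this class is illustrative and is not shipped with DeepSpeed):

```python
import torch
from deepspeed.runtime.checkpoint_engine.checkpoint_engine import (
    CheckpointEngine, CheckpointCommitInfo)

class SimpleTorchEngine(CheckpointEngine):

    def __init__(self, config_params=None):
        super().__init__(config_params)
        self.name = 'SimpleTorchEngine'

    def create(self, info: CheckpointCommitInfo):
        # Called once before any save() for this checkpoint; info carries tag/save_dir.
        self._info = info

    def save(self, state_dict, path: str):
        torch.save(state_dict, path)

    def load(self, path: str, map_location=None):
        return torch.load(path, map_location=map_location)

    def commit(self, info: CheckpointCommitInfo):
        # All files for info.tag have been written when this returns True.
        return True
```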

@ -0,0 +1,161 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
import torch
import torch.multiprocessing as mp
from deepspeed.runtime.checkpoint_engine.checkpoint_engine import \
CheckpointEngine, CheckpointCommitInfo
from deepspeed.runtime.checkpoint_engine.fast_checkpoint_engine import FastCheckpointEngine
from deepspeed import comm as dist
from deepspeed.runtime.utils import get_checkpoint_folder_size
from enum import Enum
class DecoupledEvent(Enum):
SAVE_EVENT = 1
COMMIT_EVENT = 2
EXIT_EVENT = 3
class CheckpointSize(object):
def __init__(self):
self._pre = None
self._post = None
self._gigabytes = None
def gb_size(self):
return self._gigabytes
def set_pre_size(self, size):
self._pre = size
def set_post_size(self, size):
self._post = size
self._gigabytes = (self._post - self._pre) / (1024**3)
mp.set_start_method('spawn', force=True)
def init_decoupled_checkpoint(config_params, dp_writer_config, save_event, save_queue, optimize_dp_state):
checkpoint_engine = FastCheckpointEngine(config_params, dp_writer_config, optimize_dp_state)
print(f'Created FastCheckpointEngine for Decoupled Checkpointing')
save_path_list = []
while True:
(save_info, event_type) = save_queue.get()
if event_type == DecoupledEvent.SAVE_EVENT and save_info is not None:
state_dict, save_path = save_info
# print(f'Received decoupled checkpoint request for {save_path=}')
save_path_list.append(save_path)
checkpoint_engine.save(state_dict, save_path)
del state_dict
# print(f'Completed decoupled checkpoint request for {save_path=}')
if event_type == DecoupledEvent.COMMIT_EVENT:
# print(f'Received commit request for {save_path_list=}')
save_path_list = []
save_event.set()
if event_type == DecoupledEvent.EXIT_EVENT:
# print(f'Received decoupled exit request')
break
ENGINE_NAME = "DecoupledCheckpointEngine"
class DecoupledCheckpointEngine(CheckpointEngine):
def __init__(self, config_params, dp_writer_config, optimize_dp_state):
super().__init__(config_params)
self.name = ENGINE_NAME
self.dp_writer_config = dp_writer_config
self.commit_info = None
self.checkpoint_size = CheckpointSize()
self.global_rank = dist.get_rank()
self.optimize_dp_state = optimize_dp_state
if dp_writer_config is None:
self.save_event = None
self.save_queue = None
self.ckpt_process = None
self.local_rank = None
print(
f'[{ENGINE_NAME}]: No checkpoint process self.global_rank={self.global_rank} self.dp_writer_config={self.dp_writer_config}'
)
else:
self.save_event = mp.Event()
self.save_queue = mp.SimpleQueue()
engine_args = (config_params, dp_writer_config, self.save_event, self.save_queue, self.optimize_dp_state)
self.ckpt_process = mp.Process(target=init_decoupled_checkpoint, args=engine_args)
self.ckpt_process.start()
self.local_rank = dp_writer_config.local_rank
print(
f'[{ENGINE_NAME}]: Create checkpoint process self.global_rank={self.global_rank} self.ckpt_process.pid={self.ckpt_process.pid} self.dp_writer_config={self.dp_writer_config}'
)
def __del__(self):
self.cleanup()
def create(self, info: CheckpointCommitInfo):
self.commit_info = info
if self.checkpoint_size.gb_size() is None:
pre_size = get_checkpoint_folder_size(info.save_dir, info.tag, self.local_rank)
self.checkpoint_size.set_pre_size(pre_size)
def load(self, path: str, map_location=None):
sd = torch.load(path, map_location=map_location)
return sd
def save(self, state_dict, path: str):
if self.ckpt_process is None:
return
save_info = (state_dict, path)
self.save_queue.put((save_info, DecoupledEvent.SAVE_EVENT))
def commit(self, info: CheckpointCommitInfo):
assert info == self.commit_info
if self.ckpt_process is not None:
self.save_queue.put((None, DecoupledEvent.COMMIT_EVENT))
# print(f'[begin] wait for decoupled complete for {info.tag}')
self.save_event.wait()
# print(f'[end] wait for decoupled complete for {info.tag}')
self.save_event.clear()
self.commit_info = None
if self.checkpoint_size.gb_size() is None:
dist.barrier()
post_size = get_checkpoint_folder_size(info.save_dir, info.tag, self.local_rank)
self.checkpoint_size.set_post_size(post_size)
if self.global_rank == 0:
print(
f'{self.name} self.global_rank={self.global_rank} created checkpoint of {round(self.checkpoint_size.gb_size(), 2)} GB'
)
return True
def get_commit_info(self):
# print(f'getting commit info {self.commit_info=}')
return self.commit_info
def is_decoupled(self):
return True
def cleanup(self):
# print(f'Inside {self.name} cleanup')
if self.get_commit_info() is not None:
self.commit(self.commit_info)
if self.ckpt_process is not None:
self.save_queue.put((None, DecoupledEvent.EXIT_EVENT))
self.ckpt_process.join()
self.ckpt_process = None
self.save_queue = None
def is_data_parallel_writer(self, dp_rank):
return self.ckpt_process is not None

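`DecoupledCheckpointEngine` above pushes the expensive `torch.save` work into a background process: `save()` only enqueues `(state_dict, path)`, while `commit()` sends a commit marker and waits on an event. A stripped-down, hedged sketch of that producer/consumer flow without any DeepSpeed plumbing (names are illustrative):

```python
import torch
import torch.multiprocessing as mp

def _writer_proc(done_event, work_queue):
    while True:
        item = work_queue.get()
        if item is None:              # exit request
            break
        if item == 'commit':          # everything queued so far is now on disk
            done_event.set()
            continue
        state_dict, path = item       # save request
        torch.save(state_dict, path)

if __name__ == '__main__':
    mp.set_start_method('spawn', force=True)
    done, queue = mp.Event(), mp.SimpleQueue()
    writer = mp.Process(target=_writer_proc, args=(done, queue))
    writer.start()

    queue.put(({'step': 1, 'w': torch.zeros(4)}, '/tmp/ckpt-step1.pt'))  # returns immediately
    # ... training continues while the checkpoint is persisted in the background ...
    queue.put('commit')
    done.wait()                       # block only when durability is actually required
    done.clear()

    queue.put(None)
    writer.join()
```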

@ -0,0 +1,50 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
import torch
from deepspeed.runtime.checkpoint_engine.checkpoint_engine import \
CheckpointEngine, CheckpointCommitInfo
from deepspeed.runtime.model_checkpointing import (
CHECKPOINT_WRITER,
CHECKPOINT_SERIALIZATION,
CheckpointWriterFactory,
)
class FastCheckpointEngine(CheckpointEngine):
def __init__(self, config_params, dp_writer_config, optimize_dp_state):
super().__init__(config_params)
self.name = 'FastCheckpointEngine'
self.serialization_enabled = config_params.checkpoint_config[CHECKPOINT_SERIALIZATION]
self.optimize_dp_state = optimize_dp_state
if dp_writer_config is None:
self._writer = None
else:
self._writer = CheckpointWriterFactory(writer_config=config_params.checkpoint_config[CHECKPOINT_WRITER],
aio_config=config_params.aio_config,
dp_writer_config=dp_writer_config)
def create(self, info: CheckpointCommitInfo):
pass
def save(self, state_dict, path: str):
if self._writer is None:
return
torch.save(obj=state_dict,
f=self._writer.create_writer(path, self.optimize_dp_state),
_use_new_zipfile_serialization=self.serialization_enabled)
self._writer.release_writer()
def load(self, path: str, map_location=None):
sd = torch.load(path, map_location=map_location)
return sd
def commit(self, info: CheckpointCommitInfo):
return True
def is_data_parallel_writer(self, dp_rank):
return self._writer is not None


@ -8,7 +8,7 @@ import torch
import torch_nebula
from deepspeed.runtime.checkpoint_engine.checkpoint_engine import \
CheckpointEngine
CheckpointEngine, CheckpointCommitInfo
from deepspeed.utils import logger, log_dist
from deepspeed.nebula.constants import *
@ -21,6 +21,7 @@ class NebulaCheckpointEngine(CheckpointEngine):
def __init__(self, config_params=None):
super().__init__(config_params)
self.name = "NebulaCheckpointEngine"
self.checkpoint = None
self.tag_flag = None
self.enable_nebula_load = config_params.enable_nebula_load
@ -35,22 +36,21 @@ class NebulaCheckpointEngine(CheckpointEngine):
}
torch_nebula.init(**nebula_config_params)
def create(self, tag):
log_dist(f"[Nebula] Start Checkpoint for tag:{tag}", ranks=[0])
def create(self, info: CheckpointCommitInfo):
log_dist(f"[Nebula] Start Checkpoint for tag:{info.tag}", ranks=[0])
# -2 means: customer needs to explicitly tell nebula
# current checkpoint is complete by commit method.
self.checkpoint = torch_nebula.Checkpoint(tag, -2)
# current checkpoint is complete by commit method.
self.checkpoint = torch_nebula.Checkpoint(info.tag, -2)
def save(self, state_dict, path: str):
log_dist(f"[Nebula] Create dummy files for loading.")
torch.save("", path)
tag = _get_tag_from_path(path)
partition_name = os.path.basename(path)
logger.info(f"[Nebula] Saving {partition_name} under tag {tag}...")
self.checkpoint.save(partition_name, state_dict)
logger.info(f"[Nebula] Saved {partition_name} under tag {tag}.")
return None
partititon_name = os.path.basename(path)
logger.info(f"[Nebula] Saving {partititon_name} under tag {tag}...")
self.checkpoint.save(partititon_name, state_dict)
logger.info(f"[Nebula] Saved {partititon_name} under tag {tag}.")
def load(self, path: str, map_location=None):
tag = _get_tag_from_path(path)
@ -97,7 +97,8 @@ class NebulaCheckpointEngine(CheckpointEngine):
logger.info(f"[Nebula] Loaded {path} under tag {tag} from {self.nebula_load_path}.")
return partition
def commit(self, tag):
def commit(self, info: CheckpointCommitInfo):
tag = info.tag
# nebula commit will be called when all files under the given tag are ready to be persisted asynchronously.
logger.info(f"[Nebula] all files for {tag} are saved in tier1. It is ready to start persisting")
commit_rls = self.checkpoint.commit()


@ -4,31 +4,40 @@
# DeepSpeed Team
import torch
from deepspeed.utils import logger, log_dist
from deepspeed.utils import log_dist
from deepspeed.runtime.checkpoint_engine.checkpoint_engine import \
CheckpointEngine
CheckpointEngine, CheckpointCommitInfo
from deepspeed.runtime.model_checkpointing import CHECKPOINT_SERIALIZATION
ENGINE_NAME = "TorchCheckpointEngine"
class TorchCheckpointEngine(CheckpointEngine):
def __init__(self, config_params=None):
super().__init__(config_params)
self.name = ENGINE_NAME
if config_params is None:
self.zipfile_serialization = False
else:
self.zipfile_serialization = config_params.checkpoint_config[CHECKPOINT_SERIALIZATION]
log_dist(f'[{ENGINE_NAME}] Initialized with serialization = {self.zipfile_serialization}', ranks=[0])
def create(self, tag):
log_dist(f"[Torch] Checkpoint {tag} is about to be saved!", ranks=[0])
def create(self, info: CheckpointCommitInfo):
log_dist(f"[Torch] Checkpoint {info.tag} is begin to save!", ranks=[0])
pass
def save(self, state_dict, path: str):
logger.info(f"[Torch] Saving {path}...")
torch.save(state_dict, path)
logger.info(f"[Torch] Saved {path}.")
return None
# log_dist(f"[Torch] Saving [begin] {path}... {self.zipfile_serialization=}", ranks=[0])
torch.save(state_dict, path, _use_new_zipfile_serialization=self.zipfile_serialization)
# log_dist(f"[Torch] Saving [end] {path}... {self.zipfile_serialization=}", ranks=[0])
def load(self, path: str, map_location=None):
logger.info(f"[Torch] Loading checkpoint from {path}...")
log_dist(f"[Torch] Begin Load checkpoint from {path}...", ranks=[0])
partition = torch.load(path, map_location=map_location, weights_only=False)
logger.info(f"[Torch] Loaded checkpoint from {path}.")
log_dist(f"[Torch] End Load checkpoint from {path}...", ranks=[0])
return partition
def commit(self, tag):
logger.info(f"[Torch] Checkpoint {tag} is ready now!")
def commit(self, info: CheckpointCommitInfo):
#logger.info(f"[Torch] Checkpoint {tag} is ready now!")
return True


@ -0,0 +1,38 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
from deepspeed.runtime.model_checkpointing.constants import *
from deepspeed.runtime.model_checkpointing.utils import create_data_parallel_writer_config
from deepspeed.utils import logger
from .decoupled_checkpoint_engine import DecoupledCheckpointEngine
from .fast_checkpoint_engine import FastCheckpointEngine
from .torch_checkpoint_engine import TorchCheckpointEngine
def create_checkpoint_engine(config_params, groups, zero_stage, has_moe_layers, optimize_dp_state):
if config_params is not None:
if config_params.checkpoint_config[CHECKPOINT_WRITER] is not None:
writer_config = config_params.checkpoint_config[CHECKPOINT_WRITER]
dp_writer_config = create_data_parallel_writer_config(
groups=groups,
parallel_unit=writer_config[CHECKPOINT_DATA_PARALLEL],
zero_stage=zero_stage,
has_moe_layers=has_moe_layers)
if writer_config[CHECKPOINT_WRITER_DECOUPLED]:
return DecoupledCheckpointEngine(config_params, dp_writer_config, optimize_dp_state)
else:
return FastCheckpointEngine(config_params, dp_writer_config, optimize_dp_state)
if config_params is not None and config_params.nebula_config.enabled:
try:
from .nebula_checkpoint_engine import NebulaCheckpointEngine
except ImportError as err:
logger.error(f"No torch_nebula was found! Will fall back to torch.save. Details: {err}")
return TorchCheckpointEngine(config_params)
else:
return NebulaCheckpointEngine(config_params=config_params.nebula_config)
return TorchCheckpointEngine(config_params)


@ -62,6 +62,7 @@ from ..nebula.config import DeepSpeedNebulaConfig
from ..compression.config import get_compression_config, get_quantize_enabled
from ..compression.constants import *
from .swap_tensor.aio_config import get_aio_config
from .model_checkpointing.config import get_checkpoint_config
from .tensor_parallel import get_tensor_parallel_config
from .data_pipeline.config import get_data_efficiency_enabled, get_data_efficiency_config, get_curriculum_enabled_legacy, get_curriculum_params_legacy
@ -918,6 +919,7 @@ class DeepSpeedConfig(object):
self.dataloader_drop_last = get_dataloader_drop_last(param_dict)
self.nebula_config = DeepSpeedNebulaConfig(param_dict)
self.checkpoint_config = get_checkpoint_config(param_dict)
self.weight_quantization_config = WeightQuantConfig(
**param_dict['weight_quantization']) if 'weight_quantization' in param_dict else None


@ -43,6 +43,9 @@ from deepspeed.runtime.config import DEEPSPEED_OPTIMIZERS, \
TORCH_ADAM_PARAM, ADAM_W_MODE, ADAM_W_MODE_DEFAULT, ZERO_ONE_ADAM_OPTIMIZER, MUADAM_OPTIMIZER, MUADAMW_OPTIMIZER, \
MUSGD_OPTIMIZER, LION_OPTIMIZER
from deepspeed.runtime.model_checkpointing.constants import ValidationMode, \
CHECKPOINT_TAG_VALIDATION, CHECKPOINT_WRITER, CHECKPOINT_SERIALIZATION
from deepspeed.runtime.dataloader import DeepSpeedDataLoader
from deepspeed.runtime.constants import \
ROUTE_TRAIN, ROUTE_PREDICT, ROUTE_EVAL, \
@ -62,6 +65,7 @@ from deepspeed.compression.constants import \
WEIGHT_QUANTIZE_VERBOSE, \
WEIGHT_QUANTIZE_KERNEL
from deepspeed.checkpoint.constants import OPTIMIZER_STATE_DICT, FROZEN_PARAM_FRAGMENTS
from deepspeed.checkpoint.utils import clone_tensors_for_torch_save
from deepspeed.runtime.sparse_tensor import SparseTensor
from deepspeed.runtime import lr_schedules
@ -84,11 +88,12 @@ from deepspeed.runtime.data_pipeline.constants import DATA_SAMPLING, \
RANDOM_LTD_LAYER_TOKEN_LR_SCHEDULE, RANDOM_LTD_LAYER_TOKEN_LR_ENABLED, \
RANDOM_LTD_GLOBAL_BATCH_SIZE, RANDOM_LTD_MICRO_BATCH_SIZE, DATA_EFFICIENCY
from deepspeed.runtime.data_pipeline.curriculum_scheduler import CurriculumScheduler
from deepspeed.runtime.checkpoint_engine import (create_checkpoint_engine, TorchCheckpointEngine, CheckpointCommitInfo)
from deepspeed.runtime.data_pipeline.data_routing.scheduler import RandomLTDScheduler
from deepspeed.runtime.data_pipeline.data_routing.helper import remove_random_ltd_state_dict
from deepspeed.runtime.data_pipeline.data_routing.basic_layer import RandomLayerTokenDrop
from deepspeed.runtime.checkpoint_engine.torch_checkpoint_engine import TorchCheckpointEngine
from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint
from .pipe.module import PipelineModule
@ -351,7 +356,7 @@ class DeepSpeedEngine(Module):
self.save_non_zero_checkpoint = False
self.save_zero_checkpoint = False
if not isinstance(self.optimizer, DeepSpeedZeRoOffload):
self._configure_checkpointing(dist_init_required)
self._configure_checkpointing()
if self.eigenvalue_enabled():
self.eigenvalue = self._configure_eigenvalue()
@ -496,6 +501,9 @@ class DeepSpeedEngine(Module):
prepend=True,
with_kwargs=True)
def __del__(self):
self.destroy()
def destroy(self):
if self.optimizer is not None and hasattr(self.optimizer, 'destroy'):
self.optimizer.destroy()
@ -503,6 +511,9 @@ class DeepSpeedEngine(Module):
get_deepcompile_handle().cleanup()
debug_clear_module_and_param_names()
if self.checkpoint_engine is not None and self.checkpoint_engine.is_decoupled():
self.checkpoint_engine.cleanup()
def _get_model_parameters(self):
if self.autotuning_profile_model_info():
self.autotuning_model_info = {}
@ -605,11 +616,17 @@ class DeepSpeedEngine(Module):
else:
raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")
def checkpoint_serialization_enabled(self):
return self._config.checkpoint_config[CHECKPOINT_SERIALIZATION]
def checkpoint_writer_enabled(self):
return self._config.checkpoint_config[CHECKPOINT_WRITER] is not None
def checkpoint_tag_validation_enabled(self):
return self._config.checkpoint_tag_validation_enabled
return self._config.checkpoint_config[CHECKPOINT_TAG_VALIDATION] != ValidationMode.IGNORE
def checkpoint_tag_validation_fail(self):
return self._config.checkpoint_tag_validation_fail
return self._config.checkpoint_config[CHECKPOINT_TAG_VALIDATION] == ValidationMode.FAIL
def elasticity_enabled(self):
return self._config.elasticity_enabled
@ -1069,27 +1086,22 @@ class DeepSpeedEngine(Module):
log_dist(f'DeepSpeed LR Scheduler = {self.lr_scheduler}', ranks=[0])
def _configure_checkpointing(self, dist_init_required):
self.checkpoint_engine = TorchCheckpointEngine()
if self._config is not None and self._config.nebula_config.enabled:
try:
from deepspeed.runtime.checkpoint_engine.nebula_checkpoint_engine import \
NebulaCheckpointEngine
self.checkpoint_engine = NebulaCheckpointEngine(config_params=self._config.nebula_config)
except ImportError as err:
logger.error(f"No torch_nebula was found! Will fall back to torch.save. Details: {err}")
self.checkpoint_engine = TorchCheckpointEngine()
def _configure_checkpointing(self):
# Enable optimization to parallelize checkpointing of DP state
optimize_dp_state = not self.zero_optimization_partition_weights()
self.checkpoint_engine = create_checkpoint_engine(config_params=self._config,
groups=groups,
zero_stage=self.zero_optimization_stage(),
has_moe_layers=self.has_moe_layers,
optimize_dp_state=optimize_dp_state)
dp_rank = groups._get_sequence_data_parallel_rank()
rank = self.local_rank if self.use_node_local_storage() else dp_rank
# only the first data parallel process needs to store the model checkpoint
# if you want to use node local storage this must be done by rank 0 on each
# node
self.save_non_zero_checkpoint = (rank == 0) or (self.zero_optimization_partition_weights()
and self.is_first_weights_partition_group())
# Determine if this data parallel process needs to store the model checkpoint
if self.checkpoint_engine.is_data_parallel_writer(rank) \
or (self.zero_optimization_partition_weights() and self.is_first_weights_partition_group()):
self.save_non_zero_checkpoint = True
if self.zero_optimization() or self.bfloat16_enabled():
param_rank = dist.get_rank(group=self.optimizer.dp_process_group)
@ -1409,7 +1421,7 @@ class DeepSpeedEngine(Module):
self._check_for_duplicates(basic_optimizer)
self.basic_optimizer = basic_optimizer
log_dist("DeepSpeed Basic Optimizer = {}".format(basic_optimizer.__class__.__name__), ranks=[0])
log_dist(f"DeepSpeed Basic Optimizer = {basic_optimizer.__class__.__name__}", ranks=[0])
optimizer_wrapper = self._do_optimizer_sanity_check(basic_optimizer)
@ -1574,10 +1586,12 @@ class DeepSpeedEngine(Module):
initial_dynamic_scale = self.initial_dynamic_scale()
dynamic_loss_args = self.dynamic_loss_scale_args()
clip_grad = self.gradient_clipping()
if APEX_INSTALLED:
fused_opts = (apex.optimizers.FusedAdam, FusedAdam)
else:
fused_opts = FusedAdam
if isinstance(optimizer, fused_opts) \
or self.optimizer_name() in [ONEBIT_ADAM_OPTIMIZER, ZERO_ONE_ADAM_OPTIMIZER]:
if self.dynamic_loss_scale():
@ -2373,6 +2387,9 @@ class DeepSpeedEngine(Module):
if self.is_gradient_accumulation_boundary():
self.gas_boundary_ctr += 1
if self.checkpoint_engine.is_decoupled():
self._commit_decoupled_checkpoint()
if (self.eigenvalue_enabled() and (self.gas_boundary_ctr % self.eigenvalue_gas_boundary_resolution() == 0)
and self.quantizer.any_precision_switch()):
log_dist(f"computing eigenvalue...", ranks=[0])
@ -3317,7 +3334,9 @@ class DeepSpeedEngine(Module):
# Ensure tag is a string
tag = str(tag)
self.checkpoint_engine.create(tag)
commit_info = CheckpointCommitInfo(tag=tag, save_dir=save_dir, save_latest=save_latest)
self.checkpoint_engine.create(commit_info)
# Ensure checkpoint tag is consistent across ranks
self._checkpoint_tag_validation(tag)
@ -3364,8 +3383,9 @@ class DeepSpeedEngine(Module):
self.optimizer.checkpoint_event_epilogue()
# Save latest checkpoint tag
if not self.checkpoint_engine.is_decoupled():
self.checkpoint_engine.commit(tag)
if save_latest and rank == 0:
if save_latest and self.global_rank == 0:
with open(os.path.join(save_dir, 'latest'), 'w') as fd:
fd.write(tag)
@ -3373,6 +3393,22 @@ class DeepSpeedEngine(Module):
return True
def _commit_decoupled_checkpoint(self):
assert self.checkpoint_engine.is_decoupled(), \
f'{self.checkpoint_engine} is not a Decoupled Checkpoint Engine'
commit_info = self.checkpoint_engine.get_commit_info()
if commit_info is None:
return
self.checkpoint_engine.commit(commit_info)
if self.global_rank == 0 and commit_info.save_latest:
with open(os.path.join(commit_info.save_dir, 'latest'), 'w') as fd:
fd.write(commit_info.tag)
dist.barrier()
def _get_non_moe_state_dict(self, full_state_dict):
"""
Get the state dict of the non-moe layers
@ -3385,6 +3421,7 @@ class DeepSpeedEngine(Module):
def _save_moe_checkpoint(self, save_dir, tag, client_state={}, exclude_frozen_parameters=False):
save_path = self._get_ckpt_name(save_dir, tag)
# A hack to save the checkpointing directory. Pipeline parallelism overrides
# module_state_dict() and uses this path to save the model. module_state_dict()
# then instead just returns None.
@ -3398,7 +3435,8 @@ class DeepSpeedEngine(Module):
expp_rank = groups._get_expert_parallel_rank(group_name)
exp_dp_rank = groups._get_expert_data_parallel_rank(group_name)
# print(expp_rank, exp_dp_rank)
if exp_dp_rank != 0:
# if exp_dp_rank != 0:
if not self.checkpoint_engine.is_data_parallel_writer(exp_dp_rank):
moe_layer_id += 1
continue
@ -3434,7 +3472,8 @@ class DeepSpeedEngine(Module):
moe_save_path = self._get_expert_ckpt_name(save_dir, moe_layer_id, global_expert_id, tag, self.mpu)
if self.random_ltd_enabled():
expert_state_dict = remove_random_ltd_state_dict(expert_state_dict)
self.checkpoint_engine.save(expert_state_dict, moe_save_path)
saveable_state_dict = clone_tensors_for_torch_save(expert_state_dict)
self.checkpoint_engine.save(saveable_state_dict, moe_save_path)
moe_layer_id += 1
self._curr_ckpt_path = os.path.join(save_dir, tag)
@ -3446,14 +3485,17 @@ class DeepSpeedEngine(Module):
# In the case of E + D parallelism, only the
# first expert parallel group should save the expert weights
# since each expert parallel group is a copy of the model's experts
if exp_dp_rank == 0:
if not self.checkpoint_engine.is_data_parallel_writer(exp_dp_rank):
return
# Save optimizer states. They are different across each exp parallel rank.
optimizer_state = {
'optimizer': self.optimizer.state_dict() if self.optimizer and not self.zero_optimization() else None
}
# TODO: why use BufferedWriter not the path
file_path = self._get_optimizer_ckpt_name(save_dir, tag, expp_rank)
self.checkpoint_engine.save(optimizer_state, file_path)
saveable_state_dict = clone_tensors_for_torch_save(optimizer_state)
self.checkpoint_engine.save(saveable_state_dict, file_path)
# Load flow uses below saved file for model parameters, RNG and more
if groups._get_data_parallel_rank() == 0:
@ -3492,7 +3534,8 @@ class DeepSpeedEngine(Module):
}
state.update(client_state)
logger.info(f'Saving model checkpoint: {save_path}')
self.checkpoint_engine.save(state, save_path)
saveable_state_dict = clone_tensors_for_torch_save(state)
self.checkpoint_engine.save(saveable_state_dict, save_path)
def _create_checkpoint_file(self, save_dir, tag, zero_checkpoint):
name_function = (self._get_zero_ckpt_name if zero_checkpoint else self._get_ckpt_name)
@ -3513,8 +3556,6 @@ class DeepSpeedEngine(Module):
if rank == self.global_rank:
success = self._create_checkpoint_file(save_dir, tag, True)
dist.barrier(group=self.optimizer.dp_process_group)
return success
def _save_checkpoint(self, save_dir, tag, client_state={}, exclude_frozen_parameters=False):
@ -3555,10 +3596,10 @@ class DeepSpeedEngine(Module):
ds_config=self.config,
ds_version=version)
state.update(client_state)
log_dist(message=f'Saving model checkpoint: {save_path}', ranks=[0])
if self.save_non_zero_checkpoint:
log_dist(message=f'Saving model checkpoint: {save_path}', ranks=[0, 1])
self.checkpoint_engine.save(state, save_path)
self.checkpoint_engine.save(state_dict=state, path=save_path)
def _get_buffer_names(self):
buffer_names = []
@ -3709,7 +3750,7 @@ class DeepSpeedEngine(Module):
if self.global_rank == 0:
self._copy_recovery_script(save_path)
ckpt_type = 'zero' if self.zero_optimization() else 'bf16_zero'
logger.info(f'{ckpt_type} checkpoint saved {zero_checkpoint_name}')
#logger.info(f'{ckpt_type} checkpoint saved {zero_checkpoint_name}')
def _replace_module_consolidated_state_dict(self):
"""
@ -3864,7 +3905,8 @@ class DeepSpeedEngine(Module):
tag = f"global_step{self.global_steps}"
tag = str(tag)
self.checkpoint_engine.create(tag)
commit_info = CheckpointCommitInfo(tag=tag, save_dir=save_dir, save_latest=False)
self.checkpoint_engine.create(commit_info)
if dist.get_rank() == 0:
self.checkpoint_engine.makedirs(save_dir, exist_ok=True)

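From the training-script side, none of the engine rewiring above changes the checkpointing API: `save_checkpoint` is still the entry point, and when the decoupled writer is configured the call returns quickly while the commit (and the `latest` file) is finalized at a later gradient-accumulation boundary. A hedged usage sketch (`net`, `train_loader`, and `ds_config` are placeholders):

```python
import deepspeed

engine, _, _, _ = deepspeed.initialize(model=net,
                                       model_parameters=net.parameters(),
                                       config=ds_config)

for step, batch in enumerate(train_loader):
    loss = engine(batch)          # assumes the model returns its loss
    engine.backward(loss)
    engine.step()

    if step % 1000 == 0:
        # With writer "decoupled": True this enqueues the save and returns;
        # the background process persists the files asynchronously.
        engine.save_checkpoint(save_dir='/mnt/nvme0/ckpts', tag=f'global_step{step}')
```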

@ -55,7 +55,7 @@ class FP16_Optimizer(DeepSpeedOptimizer):
self.timers = timers
self.deepspeed = deepspeed
self.has_moe_layers = has_moe_layers
self.using_pipeline = self.deepspeed.pipeline_parallelism
self.using_pipeline = getattr(self.deepspeed, 'pipeline_parallelism', False)
if not get_accelerator().is_available():
raise SystemError("Cannot use fp16 without accelerator.")
self.optimizer = init_optimizer
@ -252,11 +252,13 @@ class FP16_Optimizer(DeepSpeedOptimizer):
return self.step_fused_adam()
# First determine if there is overflow.
if self.timers:
self.timers(OVERFLOW_CHECK_TIMER).start()
fp16_params = []
for i, group in enumerate(self.fp16_groups):
fp16_params.extend([p for p in group if p.grad is not None])
self.overflow = self.overflow_checker.has_overflow(fp16_params)
if self.timers:
self.timers(OVERFLOW_CHECK_TIMER).stop()
prev_scale = self.cur_scale
self._update_scale(self.overflow)
@ -271,6 +273,7 @@ class FP16_Optimizer(DeepSpeedOptimizer):
for p in group:
p.grad = None
if self.timers:
self.timers.log(OVERFLOW_TIMERS)
return self.overflow
@ -309,6 +312,7 @@ class FP16_Optimizer(DeepSpeedOptimizer):
for p in group:
p.grad = None
if self.timers:
self.timers(COMPUTE_NORM_TIMER).start()
all_groups_norm = get_flattened_grad_norm(non_experts_grads_for_norm,
@ -322,23 +326,29 @@ class FP16_Optimizer(DeepSpeedOptimizer):
norm_type=self.norm_type)
scaled_global_grad_norm = get_global_norm(norm_list=[all_groups_norm])
if self.timers:
self.timers(COMPUTE_NORM_TIMER).stop()
# Stash unscaled gradient norm
self._global_grad_norm = scaled_global_grad_norm / self.cur_scale
if self.timers:
self.timers(UNSCALE_AND_CLIP_TIMER).start()
self.unscale_and_clip_grads(grads_groups_flat, scaled_global_grad_norm)
if self.timers:
self.timers(UNSCALE_AND_CLIP_TIMER).stop()
if self.timers:
self.timers(BASIC_STEP_TIMER).start()
self.optimizer.step()
if self.timers:
self.timers(BASIC_STEP_TIMER).stop()
#get rid of the fp32 gradients. Not needed anymore
for group in self.fp32_groups_flat:
group.grad = None
if self.timers:
self.timers(UPDATE_FP16_TIMER).start()
for i in range(len(self.fp16_groups)):
@ -346,8 +356,10 @@ class FP16_Optimizer(DeepSpeedOptimizer):
for p, q in zip(self.fp16_groups[i], updated_params):
p.data.copy_(q.data)
self.has_executed_step = True
if self.timers:
self.timers(UPDATE_FP16_TIMER).stop()
if self.timers:
self.timers.log(STEP_TIMERS)
return self.overflow


@ -22,6 +22,7 @@ Commit: 93ab4bea59dc5cbf97c079d313741866af4deac9
"""
import torch
from deepspeed.runtime.config_utils import DeepSpeedConfigObject
from deepspeed import comm as dist
from deepspeed.utils import logger
@ -39,12 +40,13 @@ def to_python_float(t):
return t[0]
class LossScalerBase:
class LossScalerBase(DeepSpeedConfigObject):
"""LossScalarBase
Base class for a loss scaler
"""
def __init__(self, cur_scale):
super(LossScalerBase, self).__init__()
self.cur_scale = cur_scale
self.dynamic = False


@ -0,0 +1,7 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
from .constants import *
from .writer_factory import CheckpointWriterFactory


@ -0,0 +1,77 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
from deepspeed.runtime.config_utils import get_scalar_param
from .constants import *
VALID_VALUES = {
CHECKPOINT_TAG_VALIDATION: CHECKPOINT_TAG_VALIDATION_MODES,
CHECKPOINT_WRITER_TYPE: CHECKPOINT_WRITER_TYPES,
CHECKPOINT_DATA_PARALLEL: CHECKPOINT_DATA_PARALLEL_UNITS
}
CHECKPOINT_DEFAULT_DICT = {
CHECKPOINT_TAG_VALIDATION: CHECKPOINT_TAG_VALIDATION_DEFAULT,
CHECKPOINT_SERIALIZATION: CHECKPOINT_SERIALIZATION_DEFAULT,
CHECKPOINT_WRITER: CHECKPOINT_WRITER_DEFAULT
}
def _validate_config_values(config_name, config_dict, valid_values):
for key, value in config_dict.items():
if value is None:
continue
if key in valid_values.keys():
assert value in valid_values[key], \
f"{config_name} contains invalid value {value} for {key}, expecting one of {valid_values[key]}"
def _make_upper_case(value):
return value if value is None else value.upper()
def get_checkpoint_writer_config(param_dict):
writer_dict = param_dict.get(CHECKPOINT_WRITER, None)
if writer_dict is None:
return CHECKPOINT_WRITER_DEFAULT
writer_config = {
CHECKPOINT_WRITER_TYPE:
_make_upper_case(get_scalar_param(writer_dict, CHECKPOINT_WRITER_TYPE, CHECKPOINT_WRITER_TYPE_DEFAULT)),
CHECKPOINT_IO_BUFFER_SIZE:
get_scalar_param(writer_dict, CHECKPOINT_IO_BUFFER_SIZE, CHECKPOINT_IO_BUFFER_SIZE_DEFAULT),
CHECKPOINT_IO_BUFFER_DOUBLE:
get_scalar_param(writer_dict, CHECKPOINT_IO_BUFFER_DOUBLE, CHECKPOINT_IO_BUFFER_DOUBLE_DEFAULT),
CHECKPOINT_IO_STATISTICS:
get_scalar_param(writer_dict, CHECKPOINT_IO_STATISTICS, CHECKPOINT_IO_STATISTICS_DEFAULT),
CHECKPOINT_DATA_PARALLEL:
_make_upper_case(get_scalar_param(writer_dict, CHECKPOINT_DATA_PARALLEL, CHECKPOINT_DATA_PARALLEL_DEFAULT)),
CHECKPOINT_WRITER_DECOUPLED:
get_scalar_param(writer_dict, CHECKPOINT_WRITER_DECOUPLED, CHECKPOINT_WRITER_DECOUPLED_DEFAULT),
CHECKPOINT_IO_MULTIPLIER:
get_scalar_param(writer_dict, CHECKPOINT_IO_MULTIPLIER, CHECKPOINT_IO_MULTIPLIER_DEFAULT),
}
_validate_config_values(CHECKPOINT_WRITER, writer_config, VALID_VALUES)
return writer_config
def get_checkpoint_config(param_dict):
checkpoint_dict = param_dict.get(CHECKPOINT, None)
if checkpoint_dict is None:
return CHECKPOINT_DEFAULT_DICT
checkpoint_config = {
CHECKPOINT_TAG_VALIDATION:
get_scalar_param(checkpoint_dict, CHECKPOINT_TAG_VALIDATION, CHECKPOINT_TAG_VALIDATION_DEFAULT).upper(),
CHECKPOINT_SERIALIZATION:
get_scalar_param(checkpoint_dict, CHECKPOINT_SERIALIZATION, CHECKPOINT_SERIALIZATION_DEFAULT),
CHECKPOINT_WRITER:
get_checkpoint_writer_config(checkpoint_dict)
}
_validate_config_values(CHECKPOINT, checkpoint_config, VALID_VALUES)
return checkpoint_config


@ -0,0 +1,85 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
#########################################
# Validation modes
#########################################
class ValidationMode:
WARN = "WARN"
IGNORE = "IGNORE"
FAIL = "FAIL"
#########################################
# Checkpoint config params
#########################################
# "checkpoint": {tag_validation=["Ignore"|"Warn"|"Fail"]}
CHECKPOINT_FORMAT = '''
"checkpoint": {
"tag_validation": [Ignore|Warn|Fail],
"checkpoint_serialization": False,
"writer": {
"type": [mock|python|fast],
"decoupled": [True|False]
"io_buffer_size": 64e6,
"io_buffer_double": True,
"show_statistics": False,
"data_parallel": [replica|socket|machine],
"io_multiplier": 1,
}
}
'''
CHECKPOINT = "checkpoint"
CHECKPOINT_TAG_VALIDATION = "tag_validation"
CHECKPOINT_TAG_VALIDATION_DEFAULT = ValidationMode.WARN
CHECKPOINT_TAG_VALIDATION_MODES = [ValidationMode.WARN, ValidationMode.IGNORE, ValidationMode.FAIL]
CHECKPOINT_SERIALIZATION = "checkpoint_serialization"
CHECKPOINT_SERIALIZATION_DEFAULT = True
CHECKPOINT_WRITER = "writer"
CHECKPOINT_WRITER_DEFAULT = None
CHECKPOINT_WRITER_TYPE = "type"
class CheckpointWriterType:
MOCK = "MOCK"
PYTHON = "PYTHON"
FAST = "FAST"
CHECKPOINT_WRITER_TYPE_DEFAULT = CheckpointWriterType.FAST
CHECKPOINT_WRITER_TYPES = [CheckpointWriterType.MOCK, CheckpointWriterType.PYTHON, CheckpointWriterType.FAST]
CHECKPOINT_IO_BUFFER_SIZE = "io_buffer_size"
CHECKPOINT_IO_BUFFER_SIZE_DEFAULT = 64 * (1024**2)
CHECKPOINT_IO_BUFFER_DOUBLE = "io_buffer_double"
CHECKPOINT_IO_BUFFER_DOUBLE_DEFAULT = True
CHECKPOINT_IO_MULTIPLIER = "io_multiplier"
CHECKPOINT_IO_MULTIPLIER_DEFAULT = 1
CHECKPOINT_IO_STATISTICS = "show_statistics"
CHECKPOINT_IO_STATISTICS_DEFAULT = False
CHECKPOINT_DATA_PARALLEL = "data_parallel"
CHECKPOINT_DATA_PARALLEL_DEFAULT = None
class CheckpointDataParallel:
REPLICA = "REPLICA"
SOCKET = "SOCKET"
MACHINE = "MACHINE"
CHECKPOINT_DATA_PARALLEL_UNITS = [
CheckpointDataParallel.REPLICA, CheckpointDataParallel.SOCKET, CheckpointDataParallel.MACHINE
]
CHECKPOINT_WRITER_DECOUPLED = "decoupled"
CHECKPOINT_WRITER_DECOUPLED_DEFAULT = False

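Putting the constants above together, the new `checkpoint` section of a DeepSpeed config might look as follows. This is a hedged example: the values are illustrative, `train_micro_batch_size_per_gpu` is just a placeholder for the rest of the config, and string values are upper-cased and validated against the lists above at load time.

```python
ds_config = {
    "train_micro_batch_size_per_gpu": 1,
    "checkpoint": {
        "tag_validation": "Warn",            # Ignore | Warn | Fail
        "checkpoint_serialization": True,
        "writer": {
            "type": "fast",                  # mock | python | fast
            "decoupled": True,               # persist checkpoints from a background process
            "io_buffer_size": 64e6,
            "io_buffer_double": True,
            "show_statistics": False,
            "data_parallel": "replica",      # replica | socket | machine
            "io_multiplier": 1,
        },
    },
}
```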

@ -0,0 +1,216 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
from dataclasses import dataclass
from deepspeed.checkpoint.reshape_utils import partition_data
from deepspeed.runtime.zero.config import ZeroStageEnum
from .constants import *
@dataclass
class DataParallelWriterConfig(object):
world_size: int
rank: int
global_rank: int
local_rank: int
pure_dp: bool
class DataParallelWriterFactory(object):
def __init__(self, uni_parallel_info, parallel_unit):
self._uni_parallel_info = uni_parallel_info
self._parallel_unit = parallel_unit
if parallel_unit == CheckpointDataParallel.SOCKET:
self._num_resources = uni_parallel_info.num_sockets
else:
self._num_resources = uni_parallel_info.num_machines
self._ranks_per_resource = max(1, self._uni_parallel_info.global_world_size // self._num_resources)
def create_config(self, zero_stage, has_moe_layers):
if zero_stage == ZeroStageEnum.weights:
return self._create_config(1, 0)
if has_moe_layers:
writer_config = self._get_expert_data_parallel_config()
else:
writer_config = self._get_data_parallel_config()
if writer_config is None and zero_stage >= ZeroStageEnum.optimizer_states:
return self._create_config(1, 0)
return writer_config
def _create_config(self, world_size, rank):
return DataParallelWriterConfig(world_size=world_size,
rank=rank,
global_rank=self._uni_parallel_info.global_rank,
local_rank=self._uni_parallel_info.local_rank,
pure_dp=self._uni_parallel_info.pure_dp)
def _get_expert_data_parallel_config(self):
ep_info = self._uni_parallel_info.ep_info
if self._parallel_unit is None:
dp_rank = ep_info.dp_rank
return self._create_config(1, 0) if dp_rank == 0 else None
assert self._uni_parallel_info.pure_dp, \
f'3D parallelism is not yet supported for data parallel checkpointing.'
if self._parallel_unit == CheckpointDataParallel.REPLICA or ep_info.ep_world_size == 1:
return self._get_parallel_write_for_ddp(ep_info.dp_world_size, ep_info.dp_rank)
return self._get_expert_parallel_write_for_2d()
def _get_expert_parallel_write_for_2d(self):
ep_info = self._uni_parallel_info.ep_info
def _get_expert_slice_resources(expert_resources, resource_name):
ep_world_size = ep_info.ep_world_size
slices_per_resource = min(self._ranks_per_resource, ep_world_size)
assert slices_per_resource <= len(expert_resources)
ep_num_resources = len(expert_resources)
assert ep_num_resources % slices_per_resource == 0, f'{resource_name}: Expected ep_num_resources={ep_num_resources} to be a multiple of slices_per_resource={slices_per_resource} for ep_world_size={ep_world_size}'
slice_partitions = partition_data(expert_resources, slices_per_resource)
# print(
# f'edp_resource_partition: self._uni_parallel_info.global_rank={self._uni_parallel_info.global_rank} expert_resources={expert_resources} slices_per_resource={slices_per_resource} ep_world_size={ep_world_size} slice_partitions={slice_partitions}'
# )
resource_index = ep_info.ep_rank % slice_resources
return slice_partitions[resource_index]
dp_ranks = ep_info.dp_peer_ranks
expert_resources = [r // self._ranks_per_resource for r in dp_ranks]
slice_resources = _get_expert_slice_resources(expert_resources, self._parallel_unit)
assert all([idx < self._num_resources for idx in expert_resources]), \
f'Detected invalid resource index in expert_resources={expert_resources}, self._num_resources={self._num_resources}'
return self._assign_resources_to_tensor_slice(slice_resources, ep_info.ep_rank, dp_ranks)
def _get_data_parallel_config(self):
mpu_info = self._uni_parallel_info.mpu_info
if self._parallel_unit is None:
dp_rank = self._uni_parallel_info.dp_rank if mpu_info is None else mpu_info.dp_rank
return self._create_config(1, 0) if dp_rank == 0 else None
if self._uni_parallel_info.pure_dp:
return self._get_parallel_write_for_ddp(self._uni_parallel_info.global_world_size,
self._uni_parallel_info.global_rank)
if self._parallel_unit == CheckpointDataParallel.REPLICA:
return self._create_config(mpu_info.dp_world_size, mpu_info.dp_rank)
return self._get_parallel_write_for_3d()
def _get_parallel_write_for_3d(self):
mpu_info = self._uni_parallel_info.mpu_info
my_global_rank = self._uni_parallel_info.global_rank
def _expand_resources(resource_list, new_size):
old_size = len(resource_list)
if old_size >= new_size:
return resource_list
assert new_size % old_size == 0, f'Expect new_size={new_size} to be multiple of old_size={old_size}'
multiplier = new_size // old_size
new_resource_list = []
for r in resource_list:
new_resource_list += [r] * multiplier
# print(f'expand_resources: {my_global_rank=} {old_size=} {new_size=} {resource_list=} {new_resource_list=}')
return new_resource_list
# Getting resource partition for a tensor slice is a 2-step process
# 1. Get resource partitions for all pipeline stages. A pipeline stage is a 2D grid of size TP x DP
def _get_pipeline_stage_resources(resource_indices):
num_resources = len(resource_indices)
pp_world_size = mpu_info.pp_world_size
if num_resources < pp_world_size:
resource_indices = _expand_resources(resource_indices, pp_world_size)
num_resources = pp_world_size
global_resource_partitions = partition_data(resource_indices, pp_world_size)
pp_rank = mpu_info.pp_rank
return global_resource_partitions[pp_rank]
# 2. Get resource partition for tensor slice. A tensor slice is a 1D vector of size DP
def _get_tensor_slice_resources(resource_indices, resource_name):
pipe_stage_resources = _get_pipeline_stage_resources(resource_indices)
tp_world_size = mpu_info.tp_world_size
if len(pipe_stage_resources) < tp_world_size:
pipe_stage_resources = _expand_resources(pipe_stage_resources, tp_world_size)
tp_num_resources = len(pipe_stage_resources)
assert tp_num_resources % tp_world_size == 0, \
f'{resource_name}: Expected tp_num_resources={tp_num_resources} to be a multiple of tp_world_size={tp_world_size}'
pipe_stage_resource_partitions = partition_data(pipe_stage_resources, tp_world_size)
tp_rank = mpu_info.tp_rank
return pipe_stage_resource_partitions[tp_rank]
def _get_model_parallel_slice_resources():
# Get resources of my dp peer ranks
resources = [(r // self._ranks_per_resource) for r in mpu_info.dp_peer_ranks]
if len(resources) < self._ranks_per_resource:
resources = _expand_resources(resources, self._ranks_per_resource)
resource_partitions = partition_data(resources, self._ranks_per_resource)
mp_rank = (mpu_info.pp_rank * mpu_info.tp_world_size) + mpu_info.tp_rank
slice_rank = mp_rank % self._ranks_per_resource
return resource_partitions[slice_rank]
num_slices = mpu_info.tp_world_size * mpu_info.pp_world_size
if num_slices > self._ranks_per_resource:
slice_resources = _get_model_parallel_slice_resources()
else:
all_resources = list(range(self._num_resources))
slice_resources = _get_tensor_slice_resources(all_resources, self._parallel_unit)
return self._assign_resources_to_tensor_slice(slice_resources, mpu_info.tp_rank, mpu_info.dp_peer_ranks)
def _get_slice_writers(self, slice_resources, my_dp_ranks):
resource_map = {}
for res in slice_resources:
resource_map[res] = [r for r in my_dp_ranks if (r // self._ranks_per_resource) == res]
# Only one writer per resource, and we conventionally pick the first rank as writer.
return [ranks[0] for ranks in resource_map.values()]
def _assign_resources_to_tensor_slice(self, slice_resources, my_slice_index, my_dp_ranks):
my_global_rank = self._uni_parallel_info.global_rank
slice_writer_ranks = self._get_slice_writers(slice_resources, my_dp_ranks)
my_resource_index = my_global_rank // self._ranks_per_resource
print(
f'resource_assign: my_global_rank={my_global_rank} my_slice_index={my_slice_index} my_dp_ranks={my_dp_ranks} slice_resources={slice_resources} slice_writer_ranks={slice_writer_ranks}'
)
if my_resource_index in slice_resources and my_global_rank in slice_writer_ranks:
my_writer_index = (my_global_rank - slice_writer_ranks[0]) // self._ranks_per_resource
num_slice_writers = len(slice_writer_ranks)
print(
f'slice_writer: my_global_rank={my_global_rank} my_writer_index={my_writer_index} num_slice_writers={num_slice_writers}'
)
return self._create_config(num_slice_writers, my_writer_index)
return None
def _get_parallel_write_for_ddp(self, dp_world_size, dp_rank):
if self._parallel_unit == CheckpointDataParallel.REPLICA:
return self._create_config(dp_world_size, dp_rank)
num_machines = self._uni_parallel_info.num_machines
if self._parallel_unit == CheckpointDataParallel.SOCKET:
if dp_world_size == num_machines:
# There is one rank per machine
return self._create_config(num_machines, dp_rank)
num_sockets = self._uni_parallel_info.num_sockets
ranks_per_socket = dp_world_size // num_sockets
if dp_rank % ranks_per_socket == 0:
return self._create_config(num_sockets, dp_rank // ranks_per_socket)
else:
return None
ranks_per_machine = dp_world_size // num_machines
if dp_rank % ranks_per_machine == 0:
return self._create_config(num_machines, self._uni_parallel_info.machine_rank)
return None

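For the pure data-parallel case, `_get_parallel_write_for_ddp` above elects one writer per replica, socket, or machine. A toy sketch of the machine-level selection (assuming ranks are laid out contiguously per machine, so the machine rank equals `dp_rank // ranks_per_machine`):

```python
def machine_writer(dp_rank, dp_world_size, num_machines):
    ranks_per_machine = dp_world_size // num_machines
    if dp_rank % ranks_per_machine == 0:
        # (writer world size, this rank's writer index)
        return (num_machines, dp_rank // ranks_per_machine)
    return None   # this rank does not write the checkpoint

print([machine_writer(r, dp_world_size=8, num_machines=2) for r in range(8)])
# [(2, 0), None, None, None, (2, 1), None, None, None]
```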

@ -0,0 +1,84 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
import os
from dataclasses import dataclass
from deepspeed import comm as dist
from deepspeed.constants import CROSS_RANK, CROSS_SIZE, LOCAL_RANK
from .data_parallel_writer_factory import DataParallelWriterFactory
# TODO: parse socket number from env.
SOCKETS_PER_MACHINE = 2
@dataclass
class MPUInfo(object):
pp_world_size: int
pp_rank: int
tp_world_size: int
tp_rank: int
dp_world_size: int
dp_peer_ranks: list
dp_rank: int
def _create_model_parallel_info(mpu):
return MPUInfo(pp_world_size=mpu.get_pipeline_model_parallel_world_size(),
pp_rank=mpu.get_pipeline_model_parallel_rank(),
tp_world_size=mpu.get_tensor_model_parallel_world_size(),
tp_rank=mpu.get_tensor_model_parallel_rank(),
dp_world_size=mpu.get_data_parallel_world_size(),
dp_peer_ranks=mpu.get_data_parallel_group_ranks(),
dp_rank=mpu.get_data_parallel_rank())
@dataclass
class ExpertParallelInfo(object):
ep_world_size: int
ep_rank: int
dp_world_size: int
dp_peer_ranks: list
dp_rank: int
def _create_expert_parallel_info(groups):
group_name = groups._get_max_expert_size_name()
return ExpertParallelInfo(ep_world_size=groups._get_expert_parallel_world_size(group_name),
ep_rank=groups._get_expert_parallel_rank(group_name),
dp_world_size=groups._get_expert_data_parallel_world_size(group_name),
dp_peer_ranks=groups._get_expert_data_parallel_group_ranks(group_name),
dp_rank=groups._get_expert_data_parallel_rank(group_name))
@dataclass
class UniversalParallelInfo(object):
global_world_size: int
global_rank: int
local_rank: int
mpu_info: MPUInfo
ep_info: ExpertParallelInfo
pure_dp: bool
num_machines: int
machine_rank: int
num_sockets: int
def create_universal_parallel_info(groups, has_moe_layers):
return UniversalParallelInfo(global_world_size=dist.get_world_size(),
global_rank=dist.get_rank(),
local_rank=int(os.environ[LOCAL_RANK]),
mpu_info=None if groups.mpu is None else _create_model_parallel_info(groups.mpu),
ep_info=_create_expert_parallel_info(groups) if has_moe_layers else None,
pure_dp=groups.mpu is None
or groups.mpu.get_data_parallel_world_size() == dist.get_world_size(),
num_machines=int(os.environ[CROSS_SIZE]),
machine_rank=int(os.environ[CROSS_RANK]),
num_sockets=int(os.environ[CROSS_SIZE]) * SOCKETS_PER_MACHINE)
def create_data_parallel_writer_config(groups, parallel_unit, zero_stage, has_moe_layers):
uni_parallel_info = create_universal_parallel_info(groups, has_moe_layers)
writer_factory = DataParallelWriterFactory(uni_parallel_info, parallel_unit)
return writer_factory.create_config(zero_stage, has_moe_layers)


@ -0,0 +1,95 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
import torch
from deepspeed.ops.op_builder import AsyncIOBuilder, GDSBuilder
from deepspeed.io import MockFileWriter, PyFileWriter, FastFileWriter, FastFileWriterConfig
from deepspeed.runtime.swap_tensor.constants import *
from .constants import *
from deepspeed.accelerator import get_accelerator
class CheckpointWriterFactory(object):
def __init__(self, writer_config, aio_config, dp_writer_config):
self._type = writer_config[CHECKPOINT_WRITER_TYPE]
self._io_buffer_size = writer_config[CHECKPOINT_IO_BUFFER_SIZE]
self._io_buffer_double = writer_config[CHECKPOINT_IO_BUFFER_DOUBLE]
self._data_parallel_writer = dp_writer_config
self._io_multiplier = writer_config[CHECKPOINT_IO_MULTIPLIER]
if self._data_parallel_writer.pure_dp:
self._show_statistics = writer_config[CHECKPOINT_IO_STATISTICS] and self._data_parallel_writer is not None
else:
self._show_statistics = writer_config[CHECKPOINT_IO_STATISTICS] and self._data_parallel_writer is not None
self._io_buffer = None
self._dnvme_handle = None
self._writer = None
self._use_gds = False
if self._type == CheckpointWriterType.FAST:
self._use_gds = aio_config[AIO_USE_GDS]
if self._use_gds:
self._setup_for_gds(aio_config)
else:
self._setup_for_aio(aio_config)
print(
f'WriterFactory: self._data_parallel_writer={self._data_parallel_writer} self._show_statistics={self._show_statistics}'
)
def create_writer(self, file_path, optimize_dp_state):
assert self._writer is None, \
f'Cannot create checkpoint writer for {file_path} because writer is currently used for {self._writer.file_path()}.\
Must call writer.release() before reusing to avoid this error.'
if self._type == CheckpointWriterType.MOCK:
self._writer = MockFileWriter(file_path)
elif self._type == CheckpointWriterType.PYTHON:
self._writer = PyFileWriter(file_path)
else:
if optimize_dp_state:
num_parallel_writers = self._data_parallel_writer.world_size * self._io_multiplier
writer_rank = self._data_parallel_writer.rank
file_path = f'{file_path}-{writer_rank}.{num_parallel_writers}'
# print(f'create_dp_writer: {self._data_parallel_writer.global_rank=} {writer_rank=} {num_parallel_writers=} {file_path=}')
else:
num_parallel_writers = 1
writer_rank = 0
# print(f'create_rank0_writer: {self._data_parallel_writer.global_rank=} {writer_rank=} {num_parallel_writers=} {file_path=}')
config = FastFileWriterConfig(dnvme_handle=self._dnvme_handle,
pinned_tensor=self._io_buffer,
double_buffer=self._io_buffer_double,
num_parallel_writers=num_parallel_writers,
writer_rank=writer_rank,
global_rank=self._data_parallel_writer.global_rank)
self._writer = FastFileWriter(file_path=file_path, config=config)
return self._writer
def release_writer(self):
self._writer.close()
if self._show_statistics:
self._writer._dump_state()
self._writer = None
def _setup_for_aio(self, aio_config):
self._io_buffer = torch.zeros(self._io_buffer_size, dtype=torch.uint8, device='cpu').pin_memory()
self._dnvme_handle = AsyncIOBuilder().load().aio_handle(
block_size=aio_config[AIO_BLOCK_SIZE],
queue_depth=aio_config[AIO_QUEUE_DEPTH],
single_submit=aio_config[AIO_SINGLE_SUBMIT],
overlap_events=aio_config[AIO_OVERLAP_EVENTS],
intra_op_parallelism=aio_config[AIO_INTRA_OP_PARALLELISM])
def _setup_for_gds(self, aio_config):
self._io_buffer = torch.zeros(self._io_buffer_size,
dtype=torch.uint8,
device=get_accelerator().current_device_name())
self._dnvme_handle = GDSBuilder().load().gds_handle(block_size=aio_config[AIO_BLOCK_SIZE],
queue_depth=aio_config[AIO_QUEUE_DEPTH],
single_submit=aio_config[AIO_SINGLE_SUBMIT],
overlap_events=aio_config[AIO_OVERLAP_EVENTS],
intra_op_parallelism=aio_config[AIO_INTRA_OP_PARALLELISM])
self._dnvme_handle.pin_device_tensor(self._io_buffer)


@ -599,6 +599,8 @@ class PipelineModule(nn.Module):
return ckpt_files
def save_state_dict(self, save_dir, checkpoint_engine, exclude_frozen_params=False):
# TODO: Need to validate interaction of checkpoint_parallel_write_pipeline and fastwriter
# Processes having the same model parallel rank on different data parallel instances
# have identical layer weights. We can distribute the task of saving the layer weights
# among the data parallel ranks. For example, if a pipeline stage has 9 layers and
@ -629,7 +631,7 @@ class PipelineModule(nn.Module):
for n in self._get_frozen_parameter_names(layer):
del orig_state_dict[n]
final_state_dict = clone_tensors_for_torch_save(orig_state_dict)
checkpoint_engine.save(final_state_dict, model_ckpt_path)
checkpoint_engine.save(state_dict=final_state_dict, path=model_ckpt_path)
def load_state_dir(self, load_dir, checkpoint_engine, strict=True):
for idx, layer in enumerate(self.forward_funcs):


@@ -411,10 +411,16 @@ class PipelineParallelGrid:
         """ The stage of the pipeline this rank resides in. """
         return self.get_stage_id()
 
+    def get_pipeline_model_parallel_rank(self):
+        return self.get_pipe_parallel_rank()
+
     def get_pipe_parallel_world_size(self):
         """ The number of stages in the pipeline. """
         return self.pipe_parallel_size
 
+    def get_pipeline_model_parallel_world_size(self):
+        return self.get_pipe_parallel_world_size()
+
     def get_pipe_parallel_group(self):
         """ The group of ranks within the same pipeline. """
         return self.pp_proc_group
@@ -431,6 +437,10 @@ class PipelineParallelGrid:
         """ The group of ranks within the same stage of all pipelines. """
         return self.dp_proc_group
 
+    def get_data_parallel_group_ranks(self):
+        """ List of ranks in the data parallel group. """
+        return self.dp_group
+
     # These are model parallel groups across all types of model parallelism.
     # Deepspeed uses them to detect overflow, etc.
     def get_model_parallel_rank(self):
@@ -449,8 +459,14 @@ class PipelineParallelGrid:
         else:
             return 0
 
+    def get_tensor_model_parallel_rank(self):
+        return self.get_slice_parallel_rank()
+
     def get_slice_parallel_world_size(self):
         return self.slice_parallel_size
 
+    def get_tensor_model_parallel_world_size(self):
+        return self.get_slice_parallel_world_size()
+
     def get_slice_parallel_group(self):
         return self.slice_proc_group
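The new grid methods are thin aliases that expose the existing pipe/slice-parallel queries under Megatron-style names, so external code written against that convention can interrogate DeepSpeed's topology directly. A small usage sketch (`grid` is assumed to be an already-constructed `PipelineParallelGrid`):

```python
# Hedged sketch: each Megatron-style accessor simply forwards to the DeepSpeed-native one.
def describe_grid(grid):
    return {
        "pp_rank": grid.get_pipeline_model_parallel_rank(),              # == get_pipe_parallel_rank()
        "pp_world_size": grid.get_pipeline_model_parallel_world_size(),  # == get_pipe_parallel_world_size()
        "tp_rank": grid.get_tensor_model_parallel_rank(),                # == get_slice_parallel_rank()
        "tp_world_size": grid.get_tensor_model_parallel_world_size(),    # == get_slice_parallel_world_size()
        "dp_ranks": grid.get_data_parallel_group_ranks(),                # ranks sharing this stage
    }
```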

@@ -120,18 +120,15 @@ class AsyncPartitionedParameterSwapper(object):
                                                            overlap_events=self.aio_config[AIO_OVERLAP_EVENTS],
                                                            intra_op_parallelism=self.aio_config[AIO_INTRA_OP_PARALLELISM])
 
-        if self.use_gds:
+        buffer_device = get_accelerator().device_name() if self.use_gds else "cpu"
         self.buffers = torch.empty(int(self.aligned_elements_per_buffer * self.param_buffer_count),
                                    dtype=self.dtype,
-                                   device=get_accelerator().device_name(),
+                                   device=buffer_device,
                                    requires_grad=False)
+        if self.use_gds:
             self.aio_read_handle.pin_device_tensor(self.buffers)
         else:
-            self.buffers = get_accelerator().pin_memory(torch.empty(int(self.aligned_elements_per_buffer *
-                                                                        self.param_buffer_count),
-                                                                    dtype=self.dtype,
-                                                                    requires_grad=False),
-                                                        align_bytes=0)
+            self.buffers = get_accelerator().pin_memory(self.buffers, align_bytes=0)
 
         self.swap_out_params = []
@@ -357,7 +354,7 @@ class AsyncPartitionedParameterSwapper(object):
         assert self.available_swap_in_buffers(
         ) > 0, f"No swap buffers to allocate for fp16 param {param_id} of numel = {numel}"
-        assert numel < self.elements_per_buffer, f"More elements {numel} than buffer size {self.elements_per_buffer}"
+        assert numel <= self.elements_per_buffer, f"More elements {numel} than buffer size {self.elements_per_buffer}"
 
         self.param_id_to_numel[param_id] = numel
         buffer_id = self.available_buffer_ids.pop()

@@ -1007,6 +1007,37 @@ def all_gather_dp_groups(groups_flat, partitioned_param_groups, dp_process_group
         dist.all_gather(shard_list, shard_list[partition_id], dp_process_group[group_id])
 
 
+def get_tensor_bytes(item):
+    if torch.is_tensor(item):
+        return item.numel() * item.element_size()
+    elif isinstance(item, list):
+        return sum([get_tensor_bytes(v) for v in item])
+    elif isinstance(item, tuple):
+        return sum([get_tensor_bytes(v) for v in item])
+    elif isinstance(item, dict):
+        return sum([get_tensor_bytes(v) for v in item.values()])
+    else:
+        return 0
+
+
+def _get_folder_size(folder):
+    size = 0
+    for path, _, files in os.walk(folder):
+        size += sum([os.path.getsize(os.path.join(path, f)) for f in files])
+    return size
+
+
+def get_checkpoint_folder_size(save_dir, tag, local_rank=None):
+    if local_rank == 0:
+        folder = os.path.join(save_dir, tag)
+        size_tensor = torch.tensor(_get_folder_size(folder)).to(get_accelerator().device_name())
+    else:
+        size_tensor = torch.tensor(0).to(get_accelerator().device_name())
+    dist.reduce(tensor=size_tensor, dst=0)
+    return int(size_tensor)
+
+
 class TLinear(torch.nn.Linear):
 
     def __init__(self, orig_layer, name=""):
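These helpers make it straightforward to report checkpoint size and effective write bandwidth, which is how FastPersist-style improvements are usually quantified. A hedged usage sketch (`save_dir`, `tag`, and `local_rank` are placeholders, and every rank must call `get_checkpoint_folder_size` because it performs a collective reduce):

```python
import time
import torch.distributed as dist

t0 = time.time()
# ... write the checkpoint under os.path.join(save_dir, tag) ...
elapsed = time.time() - t0

ckpt_bytes = get_checkpoint_folder_size(save_dir, tag, local_rank=local_rank)
if dist.get_rank() == 0:
    gb = ckpt_bytes / 1e9
    print(f"checkpoint size: {gb:.2f} GB, write throughput: {gb / elapsed:.2f} GB/s")
```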

@@ -8,6 +8,8 @@ from .partition_parameters import ZeroParamStatus
 from .partition_parameters import Init
 from .partition_parameters import GatheredParameters
 from .partition_parameters import register_external_parameter
+from .parameter_offload import DeepSpeedZeRoOffload
+from .partition_parameters import DeepSpeedTensorOverride
 from .tiling import TiledLinear
 from .tiling import TiledLinearReturnBias
