mirror of
https://github.com/deepspeedai/DeepSpeed.git
synced 2025-10-20 23:53:48 +08:00
- FastPersist - ZeRO-Inference+SGLang --------- Signed-off-by: Olatunji Ruwase <olruwase@microsoft.com> Signed-off-by: Olatunji Ruwase <tunji.ruwase@snowflake.com> Co-authored-by: jerryyangli <jerryyangli@gmail.com> Co-authored-by: Yang Li <yangli2@microsoft.com> Co-authored-by: Guanhua Wang <alexwgh333@gmail.com> Co-authored-by: Connor Holmes <connorholmes@microsoft.com> Co-authored-by: Bing Xie <67908712+xiexbing@users.noreply.github.com> Co-authored-by: cassieesvelt <73311224+cassieesvelt@users.noreply.github.com> Co-authored-by: Jeff Rasley <jerasley@microsoft.com> Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com> Co-authored-by: Michael Wyatt <michaelwyatt@microsoft.com> Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com> Co-authored-by: swli <47371259+lucasleesw@users.noreply.github.com> Co-authored-by: Cheng Li <pistasable@gmail.com> Co-authored-by: Molly Smith <112220543+molly-smith@users.noreply.github.com> Co-authored-by: Ubuntu <jomayeri@microsoft.com> Co-authored-by: Olatunji Ruwase <tunji.ruwase@snowflake.com> Co-authored-by: Zhipeng Wang <zhipeng.rainbowserie@gmail.com>
85 lines
3.0 KiB
Python
85 lines
3.0 KiB
Python
# Copyright (c) Microsoft Corporation.
|
|
# SPDX-License-Identifier: Apache-2.0
|
|
|
|
# DeepSpeed Team
|
|
|
|
import torch
|
|
from .base_io_buffer import Base_IO_Buffer
|
|
|
|
# Number of pinned sub-buffers used for double buffering:
# one half fills from source tensors while the other drains to storage.
NUM_BUFFERS = 2

# Sentinel value for a buffer index meaning "no drain currently in flight".
INVALID_BUFFER_INDEX = -1
|
|
|
|
|
|
class Double_IO_Buffer(Base_IO_Buffer):
    """Double-buffered pinned I/O buffer.

    The pinned tensor is split into ``NUM_BUFFERS`` equal halves so that one
    half can be filled from source tensors while the other half is being
    drained asynchronously to storage through the dnvme handle.
    """

    def __init__(self, pinned_tensor, dnvme_handle):
        super(Double_IO_Buffer, self).__init__(pinned_tensor, dnvme_handle)
        # The pinned region must split evenly into NUM_BUFFERS aligned halves.
        assert self._pinned_tensor.numel() % (NUM_BUFFERS * self._dnvme_handle.get_alignment()) == 0
        self._buffers = self._split_buffer()
        self._fill_index = 0                        # half currently being filled
        self._drain_index = INVALID_BUFFER_INDEX    # half with an in-flight write, if any
        self._buffer_offset = 0                     # fill position within the active half

    def fill(self, src_tensor, src_offset):
        """Copy bytes from ``src_tensor`` (starting at ``src_offset``) into the
        active fill buffer and return the number of bytes copied."""
        self._validate_buffer_index(self._fill_index)
        num_copied = Base_IO_Buffer.fill_buffer(src_tensor, src_offset,
                                                self._buffers[self._fill_index],
                                                self._buffer_offset)
        self._buffer_offset += num_copied
        return num_copied

    def drain(self, num_bytes, fd, file_offset):
        """Start an asynchronous write of ``num_bytes`` from the fill buffer to
        ``fd`` at ``file_offset``, then swap buffer roles so filling continues
        in the other half."""
        self._validate_buffer_index(self._fill_index)
        # Allow at most one outstanding asynchronous write at a time.
        self.complete_ongoing_drain()
        assert self._drain_index == INVALID_BUFFER_INDEX
        self._drain(num_bytes, fd, file_offset, blocking=False)
        # Swap: the just-filled half is now draining; fill the other half.
        self._drain_index, self._fill_index = (self._fill_index,
                                               (self._fill_index + 1) % NUM_BUFFERS)
        self._buffer_offset = 0

    def get_buffer(self):
        """Return the tensor half currently being filled."""
        self._validate_buffer_index(self._fill_index)
        return self._buffers[self._fill_index]

    def get_offset(self):
        """Return the current fill position (in elements) within the active half."""
        self._validate_buffer_index(self._fill_index)
        return self._buffer_offset

    def get_aligned_num_bytes(self):
        """Return the filled byte count rounded down to the device alignment."""
        self._validate_buffer_index(self._fill_index)
        alignment = self._dnvme_handle.get_alignment()
        return (self._buffer_offset // alignment) * alignment

    def get_unaligned_num_bytes(self):
        """Return the filled byte count that falls past the last aligned boundary."""
        self._validate_buffer_index(self._fill_index)
        return self._buffer_offset % self._dnvme_handle.get_alignment()

    def is_full(self):
        """True when the active half has no remaining capacity."""
        self._validate_buffer_index(self._fill_index)
        return self._buffer_offset == self._buffers[self._fill_index].numel()

    def is_empty(self):
        """True when nothing is buffered and no drain is in flight."""
        self._validate_buffer_index(self._fill_index)
        return self._buffer_offset == 0 and not self._is_ongoing_drain()

    def reset(self):
        """Discard the contents of the active fill buffer."""
        self._buffer_offset = 0

    def complete_ongoing_drain(self):
        """Block until any in-flight asynchronous write has finished."""
        if self._is_ongoing_drain():
            self._wait_for_drain()

    def _split_buffer(self):
        # Partition the pinned tensor into NUM_BUFFERS contiguous, equal views.
        half_numel = self._pinned_tensor.numel() // NUM_BUFFERS
        return [self._pinned_tensor.narrow(0, i * half_numel, half_numel)
                for i in range(NUM_BUFFERS)]

    def _validate_buffer_index(self, index):
        assert index in (0, 1)

    def _wait_for_drain(self):
        self._validate_buffer_index(self._drain_index)
        # Exactly one completed I/O is expected from the outstanding write.
        assert 1 == self._dnvme_handle.wait()
        self._drain_index = INVALID_BUFFER_INDEX

    def _is_ongoing_drain(self):
        return self._drain_index != INVALID_BUFFER_INDEX
|