[Core][Optimization] change python dict to pytorch tensor (#4607)

Author: youkaichao
Date: 2024-05-06 21:30:27 -07:00
Committed by: GitHub
Parent: a98187cf72
Commit: 63575bc2e1
19 changed files with 77 additions and 81 deletions

View File

@ -13,7 +13,7 @@ void swap_blocks(
void copy_blocks(
std::vector<torch::Tensor>& key_caches,
std::vector<torch::Tensor>& value_caches,
const std::map<int64_t, std::vector<int64_t>>& block_mapping);
torch::Tensor& block_mapping);
void reshape_and_cache(
torch::Tensor& key,

View File

@ -97,7 +97,7 @@ __global__ void copy_blocks_kernel(
void copy_blocks(
std::vector<torch::Tensor>& key_caches,
std::vector<torch::Tensor>& value_caches,
const std::map<int64_t, std::vector<int64_t>>& block_mapping) {
torch::Tensor& block_mapping) {
int num_layers = key_caches.size();
TORCH_CHECK(num_layers == value_caches.size());
if (num_layers == 0) {
@ -114,17 +114,9 @@ void copy_blocks(
key_cache_ptrs[layer_idx] = reinterpret_cast<int64_t>(key_caches[layer_idx].data_ptr());
value_cache_ptrs[layer_idx] = reinterpret_cast<int64_t>(value_caches[layer_idx].data_ptr());
}
// Create block mapping array.
std::vector<int64_t> block_mapping_vec;
for (const auto& pair : block_mapping) {
int64_t src_block_number = pair.first;
for (int64_t dst_block_number : pair.second) {
block_mapping_vec.push_back(src_block_number);
block_mapping_vec.push_back(dst_block_number);
}
}
int64_t* block_mapping_array = block_mapping_vec.data();
int num_pairs = block_mapping_vec.size() / 2;
// block_mapping is a 2D tensor with shape (num_pairs, 2).
int num_pairs = block_mapping.size(0);
// Move the data structures to the GPU.
// NOTE: This synchronizes the CPU and GPU.
@ -132,8 +124,6 @@ void copy_blocks(
key_cache_ptrs, {num_layers}, torch::kInt64).to(cache_device);
torch::Tensor value_cache_ptrs_tensor = torch::from_blob(
value_cache_ptrs, {num_layers}, torch::kInt64).to(cache_device);
torch::Tensor block_mapping_tensor = torch::from_blob(
block_mapping_array, {2 * num_pairs}, torch::kInt64).to(cache_device);
// Launch the kernel.
const int numel_per_block = key_caches[0][0].numel();
@ -146,7 +136,7 @@ void copy_blocks(
vllm::copy_blocks_kernel<scalar_t><<<grid, block, 0, stream>>>(
key_cache_ptrs_tensor.data_ptr<int64_t>(),
value_cache_ptrs_tensor.data_ptr<int64_t>(),
block_mapping_tensor.data_ptr<int64_t>(),
block_mapping.data_ptr<int64_t>(),
numel_per_block);
}));
}
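
The CUDA path now reads the copy mapping straight out of a (num_pairs, 2) int64 tensor via block_mapping.data_ptr<int64_t>(), so the caller must build that tensor and place it on the cache device up front; the removed torch::from_blob(...).to(cache_device) round trip is no longer needed. A minimal sketch of constructing such a tensor (values illustrative):

    import torch

    # Copies to perform, as (src_block, dst_block) pairs.
    pairs = [(2, 3), (2, 4), (7, 9)]
    block_mapping = torch.tensor(pairs, dtype=torch.int64,
                                 device="cuda").view(-1, 2)  # must live on the cache device
    # The kernel walks the rows as block_mapping[pair][0] / block_mapping[pair][1].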

View File

@ -8,16 +8,16 @@ template <typename scalar_t>
void copy_blocks_cpu_impl(
std::vector<torch::Tensor> &key_caches,
std::vector<torch::Tensor> &value_caches,
const std::vector<std::pair<int64_t, int64_t>> mapping_pairs,
const torch::Tensor& mapping_pairs,
const int element_num_per_block, const int layer_num) {
const size_t pair_num = mapping_pairs.size();
const size_t pair_num = mapping_pairs.size(0);
const size_t block_bytes = sizeof(scalar_t) * element_num_per_block;
#pragma omp parallel for collapse(2)
for (int layer = 0; layer < layer_num; ++layer) {
for (size_t pair = 0; pair < pair_num; ++pair) {
int64_t source_offset = element_num_per_block * mapping_pairs[pair].first;
int64_t source_offset = element_num_per_block * mapping_pairs[pair][0].item<int64_t>();
int64_t target_offset =
element_num_per_block * mapping_pairs[pair].second;
element_num_per_block * mapping_pairs[pair][1].item<int64_t>();
scalar_t *key_cache_ptr = key_caches[layer].data_ptr<scalar_t>();
scalar_t *source_ptr = key_cache_ptr + source_offset;
scalar_t *target_ptr = key_cache_ptr + target_offset;
@ -83,26 +83,18 @@ void reshape_and_cache_cpu_impl(
void copy_blocks(std::vector<torch::Tensor> &key_caches,
std::vector<torch::Tensor> &value_caches,
const std::map<int64_t, std::vector<int64_t>> &block_mapping) {
torch::Tensor& block_mapping) {
int num_layers = key_caches.size();
TORCH_CHECK(num_layers == value_caches.size());
if (num_layers == 0) {
return;
}
std::vector<std::pair<int64_t, int64_t>> mapping_pairs;
mapping_pairs.reserve(block_mapping.size());
for (const auto &pair : block_mapping) {
for (const auto &dst : pair.second) {
mapping_pairs.emplace_back(pair.first, dst);
}
}
const int element_num_per_block = key_caches[0][0].numel();
VLLM_DISPATCH_FLOATING_TYPES(
key_caches[0].scalar_type(), "copy_blocks_cpu_impl", [&] {
CPU_KERNEL_GUARD_IN(copy_blocks_cpu_impl)
copy_blocks_cpu_impl<scalar_t>(key_caches, value_caches, mapping_pairs,
copy_blocks_cpu_impl<scalar_t>(key_caches, value_caches, block_mapping,
element_num_per_block, num_layers);
CPU_KERNEL_GUARD_OUT(copy_blocks_cpu_impl)
});
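
On the CPU backend the mapping tensor stays on the host and each (src, dst) row is read with .item<int64_t>() inside the OpenMP loop; the copy itself is a block-sized memcpy within each layer's key and value cache. A rough Python equivalent of that semantics, assuming caches indexed by block along their first dimension:

    import torch

    def copy_blocks_reference(key_caches, value_caches, block_mapping):
        # block_mapping: int64 CPU tensor of shape (num_pairs, 2), row = (src, dst)
        for key_cache, value_cache in zip(key_caches, value_caches):
            for src, dst in block_mapping.tolist():
                key_cache[dst].copy_(key_cache[src])
                value_cache[dst].copy_(value_cache[src])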

View File

@ -568,7 +568,7 @@ def test_decode_schedule_preempted():
# Both should be preempted, not swapped.
assert output.blocks_to_swap_out == {}
# Nothing is copied.
assert output.blocks_to_copy == {}
assert output.blocks_to_copy == []
def test_decode_swap_beam_search():
@ -618,7 +618,7 @@ def test_decode_swap_beam_search():
# Both should be preempted, not swapped.
assert output.blocks_to_swap_out == expected_swap_mapping
# Nothing is copied.
assert output.blocks_to_copy == {}
assert output.blocks_to_copy == []
def test_schedule_decode_blocks_to_copy_update():
@ -650,7 +650,7 @@ def test_schedule_decode_blocks_to_copy_update():
assert output.blocks_to_swap_out == {}
# Since append_slot returns the source -> dest mapping, it should
# be applied.
assert output.blocks_to_copy == {2: [3]}
assert output.blocks_to_copy == [(2, 3)]
def test_schedule_swapped_simple():
@ -853,7 +853,7 @@ def test_schedule_swapped_blocks_to_copy():
assert len(remaining_swapped) == 0
assert len(output.decode_seq_groups) == 1
assert len(output.prefill_seq_groups) == 0
assert output.blocks_to_copy == {2: [3]}
assert output.blocks_to_copy == [(2, 3)]
def test_scheduling_budget():
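
The scheduler tests above reflect the new contract: copy-on-write blocks are reported as a flat list of (src, dst) pairs instead of a dict of source to destination lists, so what used to be {2: [3, 4]} is now [(2, 3), (2, 4)]. A one-line conversion for comparison:

    old_style = {2: [3, 4]}
    new_style = [(src, dst) for src, dsts in old_style.items() for dst in dsts]
    assert new_style == [(2, 3), (2, 4)]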

View File

@ -63,12 +63,13 @@ def test_copy_blocks(
src_blocks = random.sample(range(num_blocks), num_mappings)
remainig_blocks = list(set(range(num_blocks)) - set(src_blocks))
dst_blocks = random.sample(remainig_blocks, 2 * num_mappings)
block_mapping = {}
block_mapping = []
for i in range(num_mappings):
src = src_blocks[i]
dst1 = dst_blocks[2 * i]
dst2 = dst_blocks[2 * i + 1]
block_mapping[src] = [dst1, dst2]
block_mapping.append((src, dst1))
block_mapping.append((src, dst2))
# Create the KV caches.
key_caches, value_caches = kv_cache_factory(num_blocks, block_size,
@ -81,15 +82,17 @@ def test_copy_blocks(
cloned_value_caches = [value_cache.clone() for value_cache in value_caches]
# Call the copy blocks kernel.
ops.copy_blocks(key_caches, value_caches, block_mapping)
block_mapping_tensor = torch.tensor(block_mapping,
dtype=torch.int64,
device=device).view(-1, 2)
ops.copy_blocks(key_caches, value_caches, block_mapping_tensor)
# Run the reference implementation.
for src, dsts in block_mapping.items():
for dst in dsts:
for cloned_key_cache in cloned_key_caches:
cloned_key_cache[dst].copy_(cloned_key_cache[src])
for cloned_value_cache in cloned_value_caches:
cloned_value_cache[dst].copy_(cloned_value_cache[src])
for src, dst in block_mapping:
for cloned_key_cache in cloned_key_caches:
cloned_key_cache[dst].copy_(cloned_key_cache[src])
for cloned_value_cache in cloned_value_caches:
cloned_value_cache[dst].copy_(cloned_value_cache[src])
# Compare the results.
for key_cache, cloned_key_cache in zip(key_caches, cloned_key_caches):

View File

@ -59,7 +59,7 @@ def test_swap() -> None:
seq_group_metadata_list=[],
blocks_to_swap_in={},
blocks_to_swap_out=blocks_to_swap_out,
blocks_to_copy={},
blocks_to_copy=[],
)
worker.execute_model(execute_model_req=execute_model_req)

View File

@ -42,7 +42,7 @@ class AttentionBackend(ABC):
@abstractmethod
def copy_blocks(
kv_caches: List[torch.Tensor],
src_to_dists: Dict[int, List[int]],
src_to_dists: torch.Tensor,
) -> None:
raise NotImplementedError

View File

@ -48,7 +48,7 @@ class FlashAttentionBackend(AttentionBackend):
@staticmethod
def copy_blocks(
kv_caches: List[torch.Tensor],
src_to_dists: Dict[int, List[int]],
src_to_dists: torch.Tensor,
) -> None:
PagedAttention.copy_blocks(kv_caches, src_to_dists)

View File

@ -48,7 +48,7 @@ class FlashInferBackend(AttentionBackend):
@staticmethod
def copy_blocks(
kv_caches: List[torch.Tensor],
src_to_dists: Dict[int, List[int]],
src_to_dists: torch.Tensor,
) -> None:
raise NotImplementedError

View File

@ -46,7 +46,7 @@ class ROCmFlashAttentionBackend(AttentionBackend):
@staticmethod
def copy_blocks(
kv_caches: List[torch.Tensor],
src_to_dists: Dict[int, List[int]],
src_to_dists: torch.Tensor,
) -> None:
PagedAttention.copy_blocks(kv_caches, src_to_dists)

View File

@ -44,7 +44,7 @@ class TorchSDPABackend(AttentionBackend):
@staticmethod
def copy_blocks(
kv_caches: List[torch.Tensor],
src_to_dists: Dict[int, List[int]],
src_to_dists: torch.Tensor,
) -> None:
PagedAttention.copy_blocks(kv_caches, src_to_dists)

View File

@ -49,7 +49,7 @@ class XFormersBackend(AttentionBackend):
@staticmethod
def copy_blocks(
kv_caches: List[torch.Tensor],
src_to_dists: Dict[int, List[int]],
src_to_dists: torch.Tensor,
) -> None:
PagedAttention.copy_blocks(kv_caches, src_to_dists)

View File

@ -209,7 +209,7 @@ class PagedAttention:
@staticmethod
def copy_blocks(
kv_caches: List[torch.Tensor],
src_to_dists: Dict[int, List[int]],
src_to_dists: torch.Tensor,
) -> None:
key_caches = [kv_cache[0] for kv_cache in kv_caches]
value_caches = [kv_cache[1] for kv_cache in kv_caches]
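
PagedAttention.copy_blocks splits each fused KV cache into its key and value halves and forwards the (num_pairs, 2) tensor to the copy kernel unchanged. A hedged sketch of that pass-through; the import path of the op wrapper is an assumption based on the kernel tests in this commit:

    from typing import List
    import torch
    from vllm import _custom_ops as ops  # assumed op wrapper, as used by the tests

    def copy_blocks(kv_caches: List[torch.Tensor], src_to_dists: torch.Tensor) -> None:
        # Fused layout: kv_cache[0] holds keys, kv_cache[1] holds values.
        key_caches = [kv_cache[0] for kv_cache in kv_caches]
        value_caches = [kv_cache[1] for kv_cache in kv_caches]
        ops.copy_blocks(key_caches, value_caches, src_to_dists)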

View File

@ -13,7 +13,6 @@ from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.sequence import (Sequence, SequenceData, SequenceGroup,
SequenceGroupMetadata, SequenceStatus)
from vllm.utils import merge_dicts
logger = init_logger(__name__)
@ -122,8 +121,8 @@ class SchedulerOutputs:
blocks_to_swap_in: Dict[int, int]
# Blocks to swap out. Dict of GPU -> CPU block number.
blocks_to_swap_out: Dict[int, int]
# Blocks to copy. Source to a list of dest blocks.
blocks_to_copy: Dict[int, List[int]]
# Blocks to copy. List of (source block, dest block) pairs.
blocks_to_copy: List[Tuple[int, int]]
# Sequence groups that are going to be ignored.
ignored_seq_groups: List[SequenceGroup]
# The number of slots for lookahead decoding.
@ -177,7 +176,7 @@ class SchedulerRunningOutputs:
# The blocks to swap out.
blocks_to_swap_out: Dict[int, int]
# The blocks to copy.
blocks_to_copy: Dict[int, List[int]]
blocks_to_copy: List[Tuple[int, int]]
# The number of slots for lookahead decoding.
num_lookahead_slots: int
@ -189,7 +188,7 @@ class SchedulerRunningOutputs:
preempted=[],
swapped_out=[],
blocks_to_swap_out={},
blocks_to_copy={},
blocks_to_copy=[],
num_lookahead_slots=0,
)
@ -209,7 +208,7 @@ class SchedulerSwappedInOutputs:
# The blocks to swap in.
blocks_to_swap_in: Dict[int, int]
# The blocks to copy.
blocks_to_copy: Dict[int, List[int]]
blocks_to_copy: List[Tuple[int, int]]
# The number of slots for lookahead decoding.
num_lookahead_slots: int
# Infeasible sequence groups.
@ -221,7 +220,7 @@ class SchedulerSwappedInOutputs:
decode_seq_groups=[],
prefill_seq_groups=[],
blocks_to_swap_in={},
blocks_to_copy={},
blocks_to_copy=[],
num_lookahead_slots=0,
infeasible_seq_groups=[],
)
@ -394,7 +393,7 @@ class Scheduler:
"""
# Blocks that need to be swapped or copied before model execution.
blocks_to_swap_out: Dict[int, int] = {}
blocks_to_copy: Dict[int, List[int]] = {}
blocks_to_copy: List[Tuple[int, int]] = []
decode_seq_groups: List[ScheduledSequenceGroup] = []
prefill_seq_groups: List[ScheduledSequenceGroup] = []
@ -511,7 +510,7 @@ class Scheduler:
"""
# Blocks that need to be swapped or copied before model execution.
blocks_to_swap_in: Dict[int, int] = {}
blocks_to_copy: Dict[int, List[int]] = {}
blocks_to_copy: List[Tuple[int, int]] = []
decode_seq_groups: List[ScheduledSequenceGroup] = []
prefill_seq_groups: List[ScheduledSequenceGroup] = []
now = time.time()
@ -794,8 +793,8 @@ class Scheduler:
num_batched_tokens=budget.num_batched_tokens,
blocks_to_swap_in=swapped_in.blocks_to_swap_in,
blocks_to_swap_out=running_scheduled.blocks_to_swap_out,
blocks_to_copy=merge_dicts(running_scheduled.blocks_to_copy,
swapped_in.blocks_to_copy),
blocks_to_copy=running_scheduled.blocks_to_copy +
swapped_in.blocks_to_copy,
ignored_seq_groups=prefills.ignored_seq_groups +
swapped_in.infeasible_seq_groups,
num_lookahead_slots=running_scheduled.num_lookahead_slots,
@ -882,8 +881,8 @@ class Scheduler:
num_batched_tokens=budget.num_batched_tokens,
blocks_to_swap_in=swapped_in.blocks_to_swap_in,
blocks_to_swap_out=running_scheduled.blocks_to_swap_out,
blocks_to_copy=merge_dicts(running_scheduled.blocks_to_copy,
swapped_in.blocks_to_copy),
blocks_to_copy=running_scheduled.blocks_to_copy +
swapped_in.blocks_to_copy,
ignored_seq_groups=prefills.ignored_seq_groups,
num_lookahead_slots=running_scheduled.num_lookahead_slots,
running_queue_size=len(self.running),
@ -1011,17 +1010,18 @@ class Scheduler:
def _append_slots(
self,
seq_group: SequenceGroup,
blocks_to_copy: Dict[int, List[int]],
blocks_to_copy: List[Tuple[int, int]],
) -> None:
"""Appends new slots to the sequences in the given sequence group.
Args:
seq_group (SequenceGroup): The sequence group containing the
sequences to append slots to.
blocks_to_copy (Dict[int, List[int]]): A dictionary mapping source
block indices to lists of destination block indices. This
dictionary is updated with the new source and destination block
indices for the appended slots.
blocks_to_copy (List[Tuple[int, int]]): A list of (source block
index, destination block index) pairs. The list is updated with
the new source and destination block indices for the appended
slots.
"""
num_lookahead_slots = self._get_num_lookahead_slots(is_prefill=False)
@ -1029,9 +1029,8 @@ class Scheduler:
cows = self.block_manager.append_slots(seq, num_lookahead_slots)
for src, dests in cows.items():
if src not in blocks_to_copy:
blocks_to_copy[src] = []
blocks_to_copy[src].extend(dests)
for dest in dests:
blocks_to_copy.append((src, dest))
def _preempt(
self,
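
block_manager.append_slots still returns copy-on-write information as a dict of source block to destination blocks; _append_slots now flattens it into (src, dst) pairs, and the running/swapped-in schedules are combined with plain list concatenation instead of merge_dicts. A compact sketch of the flattening (values illustrative):

    from typing import Dict, List, Tuple

    cows: Dict[int, List[int]] = {2: [3], 5: [6, 7]}      # as returned by append_slots
    blocks_to_copy: List[Tuple[int, int]] = []
    for src, dests in cows.items():
        for dest in dests:
            blocks_to_copy.append((src, dest))

    # Merging two schedules is now just list concatenation:
    combined = blocks_to_copy + [(9, 10)]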

View File

@ -203,6 +203,9 @@ def broadcast_tensor_dict(
group=metadata_group)
async_handles = []
for tensor in tensor_list:
if tensor.numel() == 0:
# Skip broadcasting empty tensors.
continue
async_handles.append(
torch.distributed.broadcast(tensor,
src=src,
@ -224,6 +227,10 @@ def broadcast_tensor_dict(
tensor = torch.empty(value.size,
dtype=value.dtype,
device="cuda")
if tensor.numel() == 0:
# Skip broadcasting empty tensors.
tensor_dict[key] = tensor
continue
async_handle = torch.distributed.broadcast(tensor,
src=src,
async_op=True,
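
broadcast_tensor_dict now skips zero-element tensors on both the sending and the receiving side, keeping the locally allocated empty tensor instead; with blocks_to_copy travelling as a tensor, an empty mapping is a common case. A hedged helper sketching the guard (hypothetical function name; running it requires an initialized process group):

    import torch
    import torch.distributed as dist

    def broadcast_if_nonempty(tensor: torch.Tensor, src: int, group=None):
        # Zero-element tensors carry no payload, so the collective is skipped
        # and the caller keeps its local empty tensor.
        if tensor.numel() == 0:
            return None
        return dist.broadcast(tensor, src=src, group=group, async_op=True)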

View File

@ -2,7 +2,7 @@
import copy
import enum
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Dict, List, Optional, Union
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
from vllm.block import LogicalTokenBlock
from vllm.lora.request import LoRARequest
@ -745,8 +745,8 @@ class ExecuteModelRequest:
blocks_to_swap_in: Dict[int, int] = field(default_factory=dict)
# Blocks to swap out. Dict of GPU -> CPU block number.
blocks_to_swap_out: Dict[int, int] = field(default_factory=dict)
# Blocks to copy. Source to a list of dest blocks.
blocks_to_copy: Dict[int, List[int]] = field(default_factory=dict)
# Blocks to copy. List of (source block, dest block) pairs.
blocks_to_copy: List[Tuple[int, int]] = field(default_factory=list)
# The number of slots for lookahead decoding.
num_lookahead_slots: int = 0
# The number of requests in the running queue.
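
The blocks_to_copy field keeps a per-instance mutable default, so default_factory=dict simply becomes default_factory=list. A minimal sketch of the pattern (class name hypothetical):

    from dataclasses import dataclass, field
    from typing import List, Tuple

    @dataclass
    class _Request:  # stand-in for ExecuteModelRequest
        # A bare `= []` default would be shared across instances; default_factory
        # gives every request its own list of (src, dst) block pairs.
        blocks_to_copy: List[Tuple[int, int]] = field(default_factory=list)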

View File

@ -77,7 +77,7 @@ class CacheEngine:
self.attn_backend.swap_blocks(self.gpu_cache[i], self.cpu_cache[i],
src_to_dst)
def copy(self, src_to_dsts: Dict[int, List[int]]) -> None:
def copy(self, src_to_dsts: torch.Tensor) -> None:
self.attn_backend.copy_blocks(self.gpu_cache, src_to_dsts)
@staticmethod

View File

@ -248,9 +248,9 @@ class CPUWorker(LoraNotSupportedWorkerBase):
def cache_copy(
self,
blocks_to_copy: Dict[int, List[int]],
blocks_to_copy: torch.Tensor,
) -> None:
if blocks_to_copy:
if blocks_to_copy.numel() > 0:
self.cache_engine.copy(blocks_to_copy)
@torch.inference_mode()
@ -269,6 +269,9 @@ class CPUWorker(LoraNotSupportedWorkerBase):
num_seq_groups: int = len(seq_group_metadata_list)
assert execute_model_req is not None
blocks_to_copy = execute_model_req.blocks_to_copy
blocks_to_copy = torch.tensor(execute_model_req.blocks_to_copy,
device="cpu",
dtype=torch.int64).view(-1, 2)
assert len(execute_model_req.blocks_to_swap_in) == 0
assert len(execute_model_req.blocks_to_swap_out) == 0
data: Dict[str, Any] = {
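
On the CPU worker the pair list becomes a host-resident int64 tensor, and cache_copy now tests numel() > 0 instead of dict truthiness; the trailing .view(-1, 2) keeps the empty case well formed. A small sketch of that edge case:

    import torch

    empty = torch.tensor([], dtype=torch.int64, device="cpu").view(-1, 2)
    assert empty.shape == (0, 2) and empty.numel() == 0   # copy is skipped

    pairs = torch.tensor([(2, 3), (2, 4)], dtype=torch.int64, device="cpu").view(-1, 2)
    assert pairs.numel() > 0                              # copy proceeds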

View File

@ -197,7 +197,7 @@ class Worker(WorkerBase):
self,
blocks_to_swap_in: Dict[int, int],
blocks_to_swap_out: Dict[int, int],
blocks_to_copy: Dict[int, List[int]],
blocks_to_copy: torch.Tensor,
) -> None:
# Issue cache operations.
# TODO(woosuk): Profile swapping overhead and optimize if needed.
@ -205,7 +205,7 @@ class Worker(WorkerBase):
self.cache_engine.swap_in(blocks_to_swap_in)
if blocks_to_swap_out:
self.cache_engine.swap_out(blocks_to_swap_out)
if blocks_to_copy:
if blocks_to_copy.numel() > 0:
self.cache_engine.copy(blocks_to_copy)
@torch.inference_mode()
@ -225,7 +225,9 @@ class Worker(WorkerBase):
num_seq_groups = len(seq_group_metadata_list)
blocks_to_swap_in = execute_model_req.blocks_to_swap_in
blocks_to_swap_out = execute_model_req.blocks_to_swap_out
blocks_to_copy = execute_model_req.blocks_to_copy
blocks_to_copy = torch.tensor(execute_model_req.blocks_to_copy,
device=self.device,
dtype=torch.int64).view(-1, 2)
data: Dict[str, Any] = {
"num_seq_groups": num_seq_groups,
"blocks_to_swap_in": blocks_to_swap_in,