[Core][Optimization] change python dict to pytorch tensor (#4607)
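This change replaces the Python dictionaries that carried block mappings (source block number -> list of destination block numbers, typed `Dict[int, List[int]]`) with a flat list of `(src, dst)` pairs in the scheduler and an int64 `torch.Tensor` of shape `(num_pairs, 2)` on the kernel side. A minimal sketch of the two representations, illustrative only and not part of the diff:

```python
import torch

# Old representation: source block -> list of destination blocks.
old_mapping = {2: [3, 5], 7: [8]}

# New representation: flat (src, dst) pairs produced by the scheduler ...
new_mapping = [(2, 3), (2, 5), (7, 8)]

# ... which the workers turn into an int64 tensor of shape (num_pairs, 2)
# before calling the copy kernels.
block_mapping = torch.tensor(new_mapping, dtype=torch.int64).view(-1, 2)
print(block_mapping.shape)  # torch.Size([3, 2])
```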
@@ -13,7 +13,7 @@ void swap_blocks(
 void copy_blocks(
   std::vector<torch::Tensor>& key_caches,
   std::vector<torch::Tensor>& value_caches,
-  const std::map<int64_t, std::vector<int64_t>>& block_mapping);
+  torch::Tensor& block_mapping);
 
 void reshape_and_cache(
   torch::Tensor& key,
@@ -97,7 +97,7 @@ __global__ void copy_blocks_kernel(
 void copy_blocks(
   std::vector<torch::Tensor>& key_caches,
   std::vector<torch::Tensor>& value_caches,
-  const std::map<int64_t, std::vector<int64_t>>& block_mapping) {
+  torch::Tensor& block_mapping) {
   int num_layers = key_caches.size();
   TORCH_CHECK(num_layers == value_caches.size());
   if (num_layers == 0) {
@@ -114,17 +114,9 @@ void copy_blocks(
     key_cache_ptrs[layer_idx] = reinterpret_cast<int64_t>(key_caches[layer_idx].data_ptr());
     value_cache_ptrs[layer_idx] = reinterpret_cast<int64_t>(value_caches[layer_idx].data_ptr());
   }
-  // Create block mapping array.
-  std::vector<int64_t> block_mapping_vec;
-  for (const auto& pair : block_mapping) {
-    int64_t src_block_number = pair.first;
-    for (int64_t dst_block_number : pair.second) {
-      block_mapping_vec.push_back(src_block_number);
-      block_mapping_vec.push_back(dst_block_number);
-    }
-  }
-  int64_t* block_mapping_array = block_mapping_vec.data();
-  int num_pairs = block_mapping_vec.size() / 2;
+
+  // block_mapping is a 2D tensor with shape (num_pairs, 2).
+  int num_pairs = block_mapping.size(0);
 
   // Move the data structures to the GPU.
   // NOTE: This synchronizes the CPU and GPU.
@@ -132,8 +124,6 @@ void copy_blocks(
     key_cache_ptrs, {num_layers}, torch::kInt64).to(cache_device);
   torch::Tensor value_cache_ptrs_tensor = torch::from_blob(
     value_cache_ptrs, {num_layers}, torch::kInt64).to(cache_device);
-  torch::Tensor block_mapping_tensor = torch::from_blob(
-    block_mapping_array, {2 * num_pairs}, torch::kInt64).to(cache_device);
 
   // Launch the kernel.
   const int numel_per_block = key_caches[0][0].numel();
@@ -146,7 +136,7 @@ void copy_blocks(
       vllm::copy_blocks_kernel<scalar_t><<<grid, block, 0, stream>>>(
         key_cache_ptrs_tensor.data_ptr<int64_t>(),
         value_cache_ptrs_tensor.data_ptr<int64_t>(),
-        block_mapping_tensor.data_ptr<int64_t>(),
+        block_mapping.data_ptr<int64_t>(),
         numel_per_block);
     }));
 }
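For reference, the behavior the GPU path now expects from the `(num_pairs, 2)` tensor can be sketched in plain PyTorch. This is an illustrative equivalent of what the kernel does across layers (it matches the reference loop used by the updated test further down), not the actual implementation:

```python
from typing import List

import torch


def copy_blocks_reference(key_caches: List[torch.Tensor],
                          value_caches: List[torch.Tensor],
                          block_mapping: torch.Tensor) -> None:
    # block_mapping is an int64 tensor of shape (num_pairs, 2); each row is a
    # (src, dst) pair of block indices, which the kernel reads as a flat
    # int64 array via block_mapping.data_ptr<int64_t>().
    for src, dst in block_mapping.tolist():
        for key_cache, value_cache in zip(key_caches, value_caches):
            key_cache[dst].copy_(key_cache[src])
            value_cache[dst].copy_(value_cache[src])
```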
@@ -8,16 +8,16 @@ template <typename scalar_t>
 void copy_blocks_cpu_impl(
     std::vector<torch::Tensor> &key_caches,
     std::vector<torch::Tensor> &value_caches,
-    const std::vector<std::pair<int64_t, int64_t>> mapping_pairs,
+    const torch::Tensor& mapping_pairs,
     const int element_num_per_block, const int layer_num) {
-  const size_t pair_num = mapping_pairs.size();
+  const size_t pair_num = mapping_pairs.size(0);
   const size_t block_bytes = sizeof(scalar_t) * element_num_per_block;
 #pragma omp parallel for collapse(2)
   for (int layer = 0; layer < layer_num; ++layer) {
     for (size_t pair = 0; pair < pair_num; ++pair) {
-      int64_t source_offset = element_num_per_block * mapping_pairs[pair].first;
+      int64_t source_offset = element_num_per_block * mapping_pairs[pair][0].item<int64_t>();
       int64_t target_offset =
-          element_num_per_block * mapping_pairs[pair].second;
+          element_num_per_block * mapping_pairs[pair][1].item<int64_t>();
       scalar_t *key_cache_ptr = key_caches[layer].data_ptr<scalar_t>();
       scalar_t *source_ptr = key_cache_ptr + source_offset;
       scalar_t *target_ptr = key_cache_ptr + target_offset;
@@ -83,26 +83,18 @@ void reshape_and_cache_cpu_impl(
 
 void copy_blocks(std::vector<torch::Tensor> &key_caches,
                  std::vector<torch::Tensor> &value_caches,
-                 const std::map<int64_t, std::vector<int64_t>> &block_mapping) {
+                 torch::Tensor& block_mapping) {
   int num_layers = key_caches.size();
   TORCH_CHECK(num_layers == value_caches.size());
   if (num_layers == 0) {
     return;
   }
 
-  std::vector<std::pair<int64_t, int64_t>> mapping_pairs;
-  mapping_pairs.reserve(block_mapping.size());
-  for (const auto &pair : block_mapping) {
-    for (const auto &dst : pair.second) {
-      mapping_pairs.emplace_back(pair.first, dst);
-    }
-  }
-
   const int element_num_per_block = key_caches[0][0].numel();
   VLLM_DISPATCH_FLOATING_TYPES(
       key_caches[0].scalar_type(), "copy_blocks_cpu_impl", [&] {
         CPU_KERNEL_GUARD_IN(copy_blocks_cpu_impl)
-        copy_blocks_cpu_impl<scalar_t>(key_caches, value_caches, mapping_pairs,
+        copy_blocks_cpu_impl<scalar_t>(key_caches, value_caches, block_mapping,
                                        element_num_per_block, num_layers);
         CPU_KERNEL_GUARD_OUT(copy_blocks_cpu_impl)
       });
@@ -568,7 +568,7 @@ def test_decode_schedule_preempted():
     # Both should be preempted, not swapped.
     assert output.blocks_to_swap_out == {}
     # Nothing is copied.
-    assert output.blocks_to_copy == {}
+    assert output.blocks_to_copy == []
 
 
 def test_decode_swap_beam_search():
@@ -618,7 +618,7 @@ def test_decode_swap_beam_search():
     # Both should be preempted, not swapped.
     assert output.blocks_to_swap_out == expected_swap_mapping
     # Nothing is copied.
-    assert output.blocks_to_copy == {}
+    assert output.blocks_to_copy == []
 
 
 def test_schedule_decode_blocks_to_copy_update():
@@ -650,7 +650,7 @@ def test_schedule_decode_blocks_to_copy_update():
     assert output.blocks_to_swap_out == {}
     # Since append_slot returns the source -> dist mapping, it should
     # applied.
-    assert output.blocks_to_copy == {2: [3]}
+    assert output.blocks_to_copy == [(2, 3)]
 
 
 def test_schedule_swapped_simple():
@@ -853,7 +853,7 @@ def test_schedule_swapped_blocks_to_copy():
     assert len(remaining_swapped) == 0
     assert len(output.decode_seq_groups) == 1
     assert len(output.prefill_seq_groups) == 0
-    assert output.blocks_to_copy == {2: [3]}
+    assert output.blocks_to_copy == [(2, 3)]
 
 
 def test_scheduling_budget():
@@ -63,12 +63,13 @@ def test_copy_blocks(
     src_blocks = random.sample(range(num_blocks), num_mappings)
     remainig_blocks = list(set(range(num_blocks)) - set(src_blocks))
     dst_blocks = random.sample(remainig_blocks, 2 * num_mappings)
-    block_mapping = {}
+    block_mapping = []
     for i in range(num_mappings):
         src = src_blocks[i]
         dst1 = dst_blocks[2 * i]
        dst2 = dst_blocks[2 * i + 1]
-        block_mapping[src] = [dst1, dst2]
+        block_mapping.append((src, dst1))
+        block_mapping.append((src, dst2))
 
     # Create the KV caches.
     key_caches, value_caches = kv_cache_factory(num_blocks, block_size,
@@ -81,15 +82,17 @@ def test_copy_blocks(
     cloned_value_caches = [value_cache.clone() for value_cache in value_caches]
 
     # Call the copy blocks kernel.
-    ops.copy_blocks(key_caches, value_caches, block_mapping)
+    block_mapping_tensor = torch.tensor(block_mapping,
+                                        dtype=torch.int64,
+                                        device=device).view(-1, 2)
+    ops.copy_blocks(key_caches, value_caches, block_mapping_tensor)
 
     # Run the reference implementation.
-    for src, dsts in block_mapping.items():
-        for dst in dsts:
-            for cloned_key_cache in cloned_key_caches:
-                cloned_key_cache[dst].copy_(cloned_key_cache[src])
-            for cloned_value_cache in cloned_value_caches:
-                cloned_value_cache[dst].copy_(cloned_value_cache[src])
+    for src, dst in block_mapping:
+        for cloned_key_cache in cloned_key_caches:
+            cloned_key_cache[dst].copy_(cloned_key_cache[src])
+        for cloned_value_cache in cloned_value_caches:
+            cloned_value_cache[dst].copy_(cloned_value_cache[src])
 
     # Compare the results.
     for key_cache, cloned_key_cache in zip(key_caches, cloned_key_caches):
@@ -59,7 +59,7 @@ def test_swap() -> None:
         seq_group_metadata_list=[],
         blocks_to_swap_in={},
         blocks_to_swap_out=blocks_to_swap_out,
-        blocks_to_copy={},
+        blocks_to_copy=[],
     )
     worker.execute_model(execute_model_req=execute_model_req)
 
@@ -42,7 +42,7 @@ class AttentionBackend(ABC):
     @abstractmethod
     def copy_blocks(
         kv_caches: List[torch.Tensor],
-        src_to_dists: Dict[int, List[int]],
+        src_to_dists: torch.Tensor,
     ) -> None:
         raise NotImplementedError
 
@@ -48,7 +48,7 @@ class FlashAttentionBackend(AttentionBackend):
     @staticmethod
     def copy_blocks(
         kv_caches: List[torch.Tensor],
-        src_to_dists: Dict[int, List[int]],
+        src_to_dists: torch.Tensor,
     ) -> None:
         PagedAttention.copy_blocks(kv_caches, src_to_dists)
 
@@ -48,7 +48,7 @@ class FlashInferBackend(AttentionBackend):
     @staticmethod
     def copy_blocks(
         kv_caches: List[torch.Tensor],
-        src_to_dists: Dict[int, List[int]],
+        src_to_dists: torch.Tensor,
     ) -> None:
         raise NotImplementedError
 
@@ -46,7 +46,7 @@ class ROCmFlashAttentionBackend(AttentionBackend):
     @staticmethod
     def copy_blocks(
         kv_caches: List[torch.Tensor],
-        src_to_dists: Dict[int, List[int]],
+        src_to_dists: torch.Tensor,
     ) -> None:
         PagedAttention.copy_blocks(kv_caches, src_to_dists)
 
@@ -44,7 +44,7 @@ class TorchSDPABackend(AttentionBackend):
     @staticmethod
     def copy_blocks(
         kv_caches: List[torch.Tensor],
-        src_to_dists: Dict[int, List[int]],
+        src_to_dists: torch.Tensor,
     ) -> None:
         PagedAttention.copy_blocks(kv_caches, src_to_dists)
 
@@ -49,7 +49,7 @@ class XFormersBackend(AttentionBackend):
     @staticmethod
     def copy_blocks(
         kv_caches: List[torch.Tensor],
-        src_to_dists: Dict[int, List[int]],
+        src_to_dists: torch.Tensor,
     ) -> None:
         PagedAttention.copy_blocks(kv_caches, src_to_dists)
 
@@ -209,7 +209,7 @@ class PagedAttention:
     @staticmethod
     def copy_blocks(
         kv_caches: List[torch.Tensor],
-        src_to_dists: Dict[int, List[int]],
+        src_to_dists: torch.Tensor,
     ) -> None:
         key_caches = [kv_cache[0] for kv_cache in kv_caches]
         value_caches = [kv_cache[1] for kv_cache in kv_caches]
@@ -13,7 +13,6 @@ from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
 from vllm.sequence import (Sequence, SequenceData, SequenceGroup,
                            SequenceGroupMetadata, SequenceStatus)
-from vllm.utils import merge_dicts
 
 logger = init_logger(__name__)
 
@@ -122,8 +121,8 @@ class SchedulerOutputs:
     blocks_to_swap_in: Dict[int, int]
     # Blocks to swap out. Dict of GPU -> CPU block number.
     blocks_to_swap_out: Dict[int, int]
-    # Blocks to copy. Source to a list of dest blocks.
-    blocks_to_copy: Dict[int, List[int]]
+    # Blocks to copy. Source to dest block.
+    blocks_to_copy: List[Tuple[int, int]]
     # Sequence groups that are going to be ignored.
     ignored_seq_groups: List[SequenceGroup]
     # The number of slots for lookahead decoding.
@@ -177,7 +176,7 @@ class SchedulerRunningOutputs:
     # The blocks to swap out.
     blocks_to_swap_out: Dict[int, int]
     # The blocks to copy.
-    blocks_to_copy: Dict[int, List[int]]
+    blocks_to_copy: List[Tuple[int, int]]
     # The number of slots for lookahead decoding.
     num_lookahead_slots: int
 
@@ -189,7 +188,7 @@ class SchedulerRunningOutputs:
             preempted=[],
             swapped_out=[],
             blocks_to_swap_out={},
-            blocks_to_copy={},
+            blocks_to_copy=[],
             num_lookahead_slots=0,
         )
 
@@ -209,7 +208,7 @@ class SchedulerSwappedInOutputs:
     # The blocks to swap in.
     blocks_to_swap_in: Dict[int, int]
     # The blocks to copy.
-    blocks_to_copy: Dict[int, List[int]]
+    blocks_to_copy: List[Tuple[int, int]]
     # The number of slots for lookahead decoding.
     num_lookahead_slots: int
     # Infeasible sequence groups.
@@ -221,7 +220,7 @@ class SchedulerSwappedInOutputs:
             decode_seq_groups=[],
             prefill_seq_groups=[],
             blocks_to_swap_in={},
-            blocks_to_copy={},
+            blocks_to_copy=[],
             num_lookahead_slots=0,
             infeasible_seq_groups=[],
         )
@@ -394,7 +393,7 @@ class Scheduler:
         """
         # Blocks that need to be swapped or copied before model execution.
         blocks_to_swap_out: Dict[int, int] = {}
-        blocks_to_copy: Dict[int, List[int]] = {}
+        blocks_to_copy: List[Tuple[int, int]] = []
 
         decode_seq_groups: List[ScheduledSequenceGroup] = []
         prefill_seq_groups: List[ScheduledSequenceGroup] = []
@@ -511,7 +510,7 @@ class Scheduler:
         """
         # Blocks that need to be swapped or copied before model execution.
         blocks_to_swap_in: Dict[int, int] = {}
-        blocks_to_copy: Dict[int, List[int]] = {}
+        blocks_to_copy: List[Tuple[int, int]] = []
         decode_seq_groups: List[ScheduledSequenceGroup] = []
         prefill_seq_groups: List[ScheduledSequenceGroup] = []
         now = time.time()
@@ -794,8 +793,8 @@ class Scheduler:
             num_batched_tokens=budget.num_batched_tokens,
             blocks_to_swap_in=swapped_in.blocks_to_swap_in,
             blocks_to_swap_out=running_scheduled.blocks_to_swap_out,
-            blocks_to_copy=merge_dicts(running_scheduled.blocks_to_copy,
-                                       swapped_in.blocks_to_copy),
+            blocks_to_copy=running_scheduled.blocks_to_copy +
+            swapped_in.blocks_to_copy,
             ignored_seq_groups=prefills.ignored_seq_groups +
             swapped_in.infeasible_seq_groups,
             num_lookahead_slots=running_scheduled.num_lookahead_slots,
@@ -882,8 +881,8 @@ class Scheduler:
             num_batched_tokens=budget.num_batched_tokens,
             blocks_to_swap_in=swapped_in.blocks_to_swap_in,
             blocks_to_swap_out=running_scheduled.blocks_to_swap_out,
-            blocks_to_copy=merge_dicts(running_scheduled.blocks_to_copy,
-                                       swapped_in.blocks_to_copy),
+            blocks_to_copy=running_scheduled.blocks_to_copy +
+            swapped_in.blocks_to_copy,
             ignored_seq_groups=prefills.ignored_seq_groups,
             num_lookahead_slots=running_scheduled.num_lookahead_slots,
             running_queue_size=len(self.running),
@@ -1011,17 +1010,18 @@ class Scheduler:
     def _append_slots(
         self,
         seq_group: SequenceGroup,
-        blocks_to_copy: Dict[int, List[int]],
+        blocks_to_copy: List[Tuple[int, int]],
     ) -> None:
         """Appends new slots to the sequences in the given sequence group.
 
         Args:
             seq_group (SequenceGroup): The sequence group containing the
                 sequences to append slots to.
-            blocks_to_copy (Dict[int, List[int]]): A dictionary mapping source
-                block indices to lists of destination block indices. This
-                dictionary is updated with the new source and destination block
-                indices for the appended slots.
+            blocks_to_copy (List[Tuple[int, int]]): A list of tuple of two
+                ints, the first int is the source block index, and the second
+                int is the destination block index. This list is updated with
+                the new source and destination block indices for the appended
+                slots.
         """
         num_lookahead_slots = self._get_num_lookahead_slots(is_prefill=False)
 
@@ -1029,9 +1029,8 @@ class Scheduler:
             cows = self.block_manager.append_slots(seq, num_lookahead_slots)
 
             for src, dests in cows.items():
-                if src not in blocks_to_copy:
-                    blocks_to_copy[src] = []
-                blocks_to_copy[src].extend(dests)
+                for dest in dests:
+                    blocks_to_copy.append((src, dest))
 
     def _preempt(
         self,
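A small sketch of the scheduler-side effect of this hunk, illustrative only: the copy-on-write map returned by `block_manager.append_slots` is flattened into the new `(src, dst)` list, and the lists from the running and swapped-in passes are later merged with plain `+` instead of `merge_dicts` (see the hunks above):

```python
# Hypothetical copy-on-write result from append_slots(): src -> [dests].
cows = {2: [3], 7: [8, 9]}

blocks_to_copy = []
for src, dests in cows.items():
    for dest in dests:
        blocks_to_copy.append((src, dest))

assert blocks_to_copy == [(2, 3), (7, 8), (7, 9)]
```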
@@ -203,6 +203,9 @@ def broadcast_tensor_dict(
                                              group=metadata_group)
         async_handles = []
         for tensor in tensor_list:
+            if tensor.numel() == 0:
+                # Skip broadcasting empty tensors.
+                continue
             async_handles.append(
                 torch.distributed.broadcast(tensor,
                                             src=src,
@@ -224,6 +227,10 @@ def broadcast_tensor_dict(
             tensor = torch.empty(value.size,
                                  dtype=value.dtype,
                                  device="cuda")
+            if tensor.numel() == 0:
+                # Skip broadcasting empty tensors.
+                tensor_dict[key] = tensor
+                continue
             async_handle = torch.distributed.broadcast(tensor,
                                                        src=src,
                                                        async_op=True,
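The two `broadcast_tensor_dict` hunks above add an explicit guard for zero-element tensors. A plausible reason, given the rest of the diff, is that `blocks_to_copy` is now frequently an empty `(0, 2)` tensor rather than a falsy `{}`; the snippet below only illustrates the condition being added and is not vLLM code:

```python
import torch

# On a step with nothing to copy, the worker-side tensor is empty but still
# has a well-defined (0, 2) shape, so the broadcast path checks numel().
blocks_to_copy = torch.empty((0, 2), dtype=torch.int64)
if blocks_to_copy.numel() == 0:
    print("skip broadcast")  # mirrors the "Skip broadcasting empty tensors" guard
```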
@@ -2,7 +2,7 @@
 import copy
 import enum
 from dataclasses import dataclass, field
-from typing import TYPE_CHECKING, Dict, List, Optional, Union
+from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
 
 from vllm.block import LogicalTokenBlock
 from vllm.lora.request import LoRARequest
@@ -745,8 +745,8 @@ class ExecuteModelRequest:
     blocks_to_swap_in: Dict[int, int] = field(default_factory=dict)
     # Blocks to swap out. Dict of GPU -> CPU block number.
     blocks_to_swap_out: Dict[int, int] = field(default_factory=dict)
-    # Blocks to copy. Source to a list of dest blocks.
-    blocks_to_copy: Dict[int, List[int]] = field(default_factory=dict)
+    # Blocks to copy. Source to dest block.
+    blocks_to_copy: List[Tuple[int, int]] = field(default_factory=list)
     # The number of slots for lookahead decoding.
     num_lookahead_slots: int = 0
     # The number of requests in the running queue.
@@ -77,7 +77,7 @@ class CacheEngine:
             self.attn_backend.swap_blocks(self.gpu_cache[i], self.cpu_cache[i],
                                           src_to_dst)
 
-    def copy(self, src_to_dsts: Dict[int, List[int]]) -> None:
+    def copy(self, src_to_dsts: torch.Tensor) -> None:
         self.attn_backend.copy_blocks(self.gpu_cache, src_to_dsts)
 
     @staticmethod
@@ -248,9 +248,9 @@ class CPUWorker(LoraNotSupportedWorkerBase):
 
     def cache_copy(
         self,
-        blocks_to_copy: Dict[int, List[int]],
+        blocks_to_copy: torch.Tensor,
     ) -> None:
-        if blocks_to_copy:
+        if blocks_to_copy.numel() > 0:
             self.cache_engine.copy(blocks_to_copy)
 
     @torch.inference_mode()
@@ -269,6 +269,9 @@ class CPUWorker(LoraNotSupportedWorkerBase):
         num_seq_groups: int = len(seq_group_metadata_list)
         assert execute_model_req is not None
-        blocks_to_copy = execute_model_req.blocks_to_copy
+        blocks_to_copy = torch.tensor(execute_model_req.blocks_to_copy,
+                                      device="cpu",
+                                      dtype=torch.int64).view(-1, 2)
         assert len(execute_model_req.blocks_to_swap_in) == 0
         assert len(execute_model_req.blocks_to_swap_out) == 0
         data: Dict[str, Any] = {
@@ -197,7 +197,7 @@ class Worker(WorkerBase):
         self,
         blocks_to_swap_in: Dict[int, int],
         blocks_to_swap_out: Dict[int, int],
-        blocks_to_copy: Dict[int, List[int]],
+        blocks_to_copy: torch.Tensor,
     ) -> None:
         # Issue cache operations.
         # TODO(woosuk): Profile swapping overhead and optimize if needed.
@@ -205,7 +205,7 @@ class Worker(WorkerBase):
             self.cache_engine.swap_in(blocks_to_swap_in)
         if blocks_to_swap_out:
             self.cache_engine.swap_out(blocks_to_swap_out)
-        if blocks_to_copy:
+        if blocks_to_copy.numel() > 0:
             self.cache_engine.copy(blocks_to_copy)
 
     @torch.inference_mode()
@@ -225,7 +225,9 @@ class Worker(WorkerBase):
         num_seq_groups = len(seq_group_metadata_list)
         blocks_to_swap_in = execute_model_req.blocks_to_swap_in
         blocks_to_swap_out = execute_model_req.blocks_to_swap_out
-        blocks_to_copy = execute_model_req.blocks_to_copy
+        blocks_to_copy = torch.tensor(execute_model_req.blocks_to_copy,
+                                      device=self.device,
+                                      dtype=torch.int64).view(-1, 2)
         data: Dict[str, Any] = {
             "num_seq_groups": num_seq_groups,
             "blocks_to_swap_in": blocks_to_swap_in,
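Taken together, the worker-side hunks replace the falsy-dict check with an explicit `numel()` guard, and `.view(-1, 2)` keeps the empty case well formed. A small sketch of the conversion done in `execute_model`; `to_block_mapping_tensor` is a hypothetical helper name, and the real code uses `self.device` rather than the CPU device:

```python
import torch


def to_block_mapping_tensor(blocks_to_copy, device="cpu"):
    # Mirrors the conversion in Worker.execute_model / CPUWorker.execute_model:
    # an empty list becomes a (0, 2) tensor, so the copy step is skipped via
    # the numel() > 0 guard rather than truthiness.
    return torch.tensor(blocks_to_copy, device=device,
                        dtype=torch.int64).view(-1, 2)


assert to_block_mapping_tensor([]).shape == (0, 2)
assert to_block_mapping_tensor([(2, 3), (2, 5)]).shape == (2, 2)
```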