mirror of
https://github.com/vllm-project/vllm.git
synced 2025-10-20 14:53:52 +08:00
[Kernel] correct cpu worker function parameter type (#19745)
Signed-off-by: Andy Xie <andy.xning@gmail.com>
This commit is contained in:
@ -29,7 +29,7 @@ class _PagedAttention:
|
||||
head_size: int,
|
||||
*args,
|
||||
) -> Tuple[int, ...]:
|
||||
return (2, num_blocks, block_size * num_kv_heads * head_size)
|
||||
return 2, num_blocks, block_size * num_kv_heads * head_size
|
||||
|
||||
@staticmethod
|
||||
def split_kv_cache(
|
||||
|
@ -3,7 +3,7 @@
|
||||
"""A CPU worker class."""
|
||||
import os
|
||||
from importlib import util
|
||||
from typing import Dict, List, Optional, Set, Tuple, Type
|
||||
from typing import List, Optional, Set, Tuple, Type
|
||||
|
||||
import torch
|
||||
import torch.distributed
|
||||
@ -88,13 +88,13 @@ class CPUCacheEngine:
|
||||
torch.empty(kv_cache_shape, dtype=self.dtype, device="cpu"))
|
||||
return kv_cache
|
||||
|
||||
def swap_in(self, src_to_dst: Dict[int, int]) -> None:
|
||||
def swap_in(self, src_to_dst: torch.Tensor) -> None:
|
||||
raise NotImplementedError("Swap is not supported in CPUCacheEngine.")
|
||||
|
||||
def swap_out(self, src_to_dst: Dict[int, int]) -> None:
|
||||
def swap_out(self, src_to_dst: torch.Tensor) -> None:
|
||||
raise NotImplementedError("Swap is not supported in CPUCacheEngine.")
|
||||
|
||||
def copy(self, src_to_dsts: Dict[int, List[int]]) -> None:
|
||||
def copy(self, src_to_dsts: torch.Tensor) -> None:
|
||||
self.attn_backend.copy_blocks(self.cpu_cache, src_to_dsts)
|
||||
|
||||
@staticmethod
|
||||
|
Reference in New Issue
Block a user