[Kernel] correct cpu worker function parameter type (#19745)

Signed-off-by: Andy Xie <andy.xning@gmail.com>
Author: Ning Xie
Date: 2025-06-20 18:50:13 +08:00
Committed by: GitHub
Parent: e384f2f108
Commit: 71d1219545

2 changed files with 5 additions and 5 deletions


@@ -29,7 +29,7 @@ class _PagedAttention:
         head_size: int,
         *args,
     ) -> Tuple[int, ...]:
-        return (2, num_blocks, block_size * num_kv_heads * head_size)
+        return 2, num_blocks, block_size * num_kv_heads * head_size

     @staticmethod
     def split_kv_cache(

@@ -3,7 +3,7 @@
 """A CPU worker class."""
 import os
 from importlib import util
-from typing import Dict, List, Optional, Set, Tuple, Type
+from typing import List, Optional, Set, Tuple, Type

 import torch
 import torch.distributed
@@ -88,13 +88,13 @@ class CPUCacheEngine:
                 torch.empty(kv_cache_shape, dtype=self.dtype, device="cpu"))
         return kv_cache

-    def swap_in(self, src_to_dst: Dict[int, int]) -> None:
+    def swap_in(self, src_to_dst: torch.Tensor) -> None:
         raise NotImplementedError("Swap is not supported in CPUCacheEngine.")

-    def swap_out(self, src_to_dst: Dict[int, int]) -> None:
+    def swap_out(self, src_to_dst: torch.Tensor) -> None:
         raise NotImplementedError("Swap is not supported in CPUCacheEngine.")

-    def copy(self, src_to_dsts: Dict[int, List[int]]) -> None:
+    def copy(self, src_to_dsts: torch.Tensor) -> None:
         self.attn_backend.copy_blocks(self.cpu_cache, src_to_dsts)

     @staticmethod
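For context, the corrected annotations match what callers actually pass: the cache engine hands block mappings over as tensors of (src, dst) index pairs rather than Python dicts. A minimal sketch of the corrected interface (the _DemoCPUCacheEngine class and the [N, 2] pair layout are illustrative assumptions, not the real implementation):

import torch

class _DemoCPUCacheEngine:
    # Stand-in mirroring the signatures corrected by this commit.
    def swap_in(self, src_to_dst: torch.Tensor) -> None:
        raise NotImplementedError("Swap is not supported in CPUCacheEngine.")

    def copy(self, src_to_dsts: torch.Tensor) -> None:
        # The real class forwards to self.attn_backend.copy_blocks(...).
        print(f"copying {src_to_dsts.shape[0]} block pairs")

# Block mappings arrive as an [N, 2] int64 tensor of (src, dst) rows.
engine = _DemoCPUCacheEngine()
engine.copy(torch.tensor([[0, 4], [1, 5]], dtype=torch.int64))
try:
    engine.swap_in(torch.tensor([[2, 6]], dtype=torch.int64))
except NotImplementedError:
    pass  # swapping is unsupported on the CPU backend, per the diff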