[Bugfix][TPU] Fix CPU cache allocation (#5869)

This commit is contained in:
Woosuk Kwon
2024-06-26 13:42:40 -07:00
committed by GitHub
parent cbc53b6b8d
commit f5c8628fdc
2 changed files with 8 additions and 5 deletions

View File

@@ -37,11 +37,10 @@ class PallasAttentionBackend(AttentionBackend):
) -> None:
src_k_cache, src_v_cache = src_kv_cache
dst_k_cache, dst_v_cache = dst_kv_cache
src_indices, dst_indices = src_to_dst
device = dst_k_cache.device
torch.ops.xla.dynamo_set_buffer_donor_(dst_k_cache, True)
torch.ops.xla.dynamo_set_buffer_donor_(dst_v_cache, True)
device = dst_k_cache.device
src_indices, dst_indices = src_to_dst
dst_k_cache[:, dst_indices] = src_k_cache[:, src_indices].to(device)
dst_v_cache[:, dst_indices] = src_v_cache[:, src_indices].to(device)

View File

@@ -156,14 +156,18 @@ class TPUWorker(LoraNotSupportedWorkerBase):
self.tpu_cache = []
tpu_cache_shape = self.model_runner.attn_backend.get_kv_cache_shape(
num_gpu_blocks, self.block_size, num_kv_heads, head_size)
cpu_cache_shape = self.model_runner.attn_backend.get_kv_cache_shape(
num_cpu_blocks, self.block_size, num_kv_heads, head_size)
for _ in range(num_layers):
tpu_k_cache = torch.zeros(tpu_cache_shape,
dtype=dtype,
device=self.device)
tpu_v_cache = torch.zeros_like(tpu_k_cache)
self.tpu_cache.append((tpu_k_cache, tpu_v_cache))
cpu_k_cache = torch.zeros_like(tpu_k_cache, device="cpu")
cpu_v_cache = torch.zeros_like(tpu_v_cache, device="cpu")
cpu_k_cache = torch.zeros(cpu_cache_shape,
dtype=dtype,
device="cpu")
cpu_v_cache = torch.zeros_like(cpu_k_cache)
self.cpu_cache.append((cpu_k_cache, cpu_v_cache))
self._warmup_model()