Mirror of https://github.com/vllm-project/vllm.git
[Bugfix][TPU] Fix KV cache size calculation (#5860)
@@ -118,14 +118,15 @@ class TPUWorker(LoraNotSupportedWorkerBase):
         xm.wait_device_ops()

         m = xm.get_memory_info(self.device)
-        program_size = 1024 * 1024 * 1024  # 1GB
-        free_bytes = max(m["bytes_limit"] - m["bytes_used"] - program_size, 0)
-        kv_cache_bytes = int(free_bytes *
-                             self.cache_config.gpu_memory_utilization)
-        kv_cache_dtype_btyes = get_dtype_size(self.cache_dtype)
+        total_memory_size = m["bytes_limit"]
+        usable_memory_size = int(total_memory_size *
+                                 self.cache_config.gpu_memory_utilization)
+        profiled = m["bytes_used"]  # Weights + intermediate activations.
+        kv_cache_bytes = max(usable_memory_size - profiled, 0)
+        dtype_btyes = get_dtype_size(self.cache_dtype)
         block_size = self.cache_config.block_size
         num_tpu_blocks = (kv_cache_bytes //
-                          (kv_cache_dtype_btyes * block_size * num_layers * 2 *
+                          (dtype_btyes * block_size * num_layers * 2 *
                            head_size * num_kv_heads))
         num_tpu_blocks = (num_tpu_blocks // 8) * 8  # Round down to 8.
         return num_tpu_blocks, 0
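The old path reserved a fixed 1 GB for the compiled program and applied gpu_memory_utilization to the estimated free bytes; the fix instead applies gpu_memory_utilization to the total device memory, subtracts whatever the profiling run already consumed (weights plus intermediate activations), and converts the remaining budget into whole cache blocks. A minimal standalone sketch of that arithmetic follows; the function name and the example figures are hypothetical, not vLLM API, standing in for the values TPUWorker reads from the model config and from xm.get_memory_info().

# Hypothetical sketch of the new sizing arithmetic, not vLLM code.
def estimate_num_tpu_blocks(total_memory_size: int,
                            profiled: int,
                            gpu_memory_utilization: float,
                            dtype_bytes: int,
                            block_size: int,
                            num_layers: int,
                            head_size: int,
                            num_kv_heads: int) -> int:
    # Apply the utilization cap to *total* device memory, then subtract what
    # the profiling run already consumed (weights + intermediate activations).
    usable_memory_size = int(total_memory_size * gpu_memory_utilization)
    kv_cache_bytes = max(usable_memory_size - profiled, 0)
    # Bytes per block: the factor of 2 covers the key and value tensors,
    # stored for every layer.
    block_bytes = dtype_bytes * block_size * num_layers * 2 * head_size * num_kv_heads
    num_blocks = kv_cache_bytes // block_bytes
    return (num_blocks // 8) * 8  # Round down to a multiple of 8.

# Example (hypothetical numbers): 16 GiB of device memory, 90% utilization,
# 6 GiB profiled, bf16 cache (2 bytes), block_size=16, 32 layers,
# head_size=128, 8 KV heads.
print(estimate_num_tpu_blocks(16 * 1024**3, 6 * 1024**3, 0.9,
                              2, 16, 32, 128, 8))  # -> 4296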