[Bugfix][TPU] Fix V1 TPU worker for sliding window (#16059)

Signed-off-by: Michael Goin <mgoin64@gmail.com>
This commit is contained in:
Michael Goin
2025-04-04 17:31:19 -06:00
committed by GitHub
parent d6fc629f4d
commit 70ad3f9e98

View File

@ -18,7 +18,7 @@ from vllm.logger import init_logger
from vllm.model_executor import set_random_seed
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
from vllm.v1.kv_cache_interface import (AttentionSpec, KVCacheConfig,
KVCacheSpec)
from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.utils import bind_kv_cache
@ -137,7 +137,7 @@ class TPUWorker:
kv_caches: dict[str, torch.Tensor] = {}
kv_cache_spec = self.model_runner.get_kv_cache_spec()
for layer_name, layer_spec in kv_cache_spec.items():
if isinstance(layer_spec, FullAttentionSpec):
if isinstance(layer_spec, AttentionSpec):
dtype = layer_spec.dtype
# Use an empty tensor instead of `None`` to force Dynamo to pass
@ -147,7 +147,8 @@ class TPUWorker:
device=self.device)
kv_caches[layer_name] = tpu_kv_cache
else:
raise NotImplementedError
raise NotImplementedError(
f"Unsupported KV cache spec '{type(layer_spec)}'")
runner_kv_caches: list[torch.Tensor] = []
bind_kv_cache(