[Bugfix][TPU] Fix V1 TPU worker for sliding window (#16059)
Signed-off-by: Michael Goin <mgoin64@gmail.com>
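
Summary (inferred from the diff below): the V1 TPUWorker's KV-cache allocation loop previously accepted only layers whose spec was a FullAttentionSpec; any other spec, such as the one produced for sliding-window attention layers, fell through to a bare `raise NotImplementedError`. This commit widens the isinstance check to the AttentionSpec base class so those layers get a KV cache allocated like any other attention layer, and attaches a message naming the offending spec type to the remaining NotImplementedError.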
@@ -18,7 +18,7 @@ from vllm.logger import init_logger
 from vllm.model_executor import set_random_seed
 from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE
 from vllm.v1.core.sched.output import SchedulerOutput
-from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
+from vllm.v1.kv_cache_interface import (AttentionSpec, KVCacheConfig,
                                         KVCacheSpec)
 from vllm.v1.outputs import ModelRunnerOutput
 from vllm.v1.utils import bind_kv_cache
@@ -137,7 +137,7 @@ class TPUWorker:
         kv_caches: dict[str, torch.Tensor] = {}
         kv_cache_spec = self.model_runner.get_kv_cache_spec()
         for layer_name, layer_spec in kv_cache_spec.items():
-            if isinstance(layer_spec, FullAttentionSpec):
+            if isinstance(layer_spec, AttentionSpec):
                 dtype = layer_spec.dtype

                 # Use an empty tensor instead of `None`` to force Dynamo to pass
@@ -147,7 +147,8 @@ class TPUWorker:
                                             device=self.device)
                 kv_caches[layer_name] = tpu_kv_cache
             else:
-                raise NotImplementedError
+                raise NotImplementedError(
+                    f"Unsupported KV cache spec '{type(layer_spec)}'")

         runner_kv_caches: list[torch.Tensor] = []
         bind_kv_cache(
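
For context, here is a minimal, self-contained sketch of the dispatch this commit fixes. The class names mirror vllm/v1/kv_cache_interface.py, but these stand-in dataclasses, the allocate_kv_caches helper, and its strict flag are illustrative assumptions rather than vLLM code; the real specs carry more fields (block_size, num_kv_heads, head_size, ...) and the real worker allocates torch tensors on the TPU device.

from dataclasses import dataclass


# Stand-ins for vLLM's spec hierarchy; the inheritance shape is what matters.
@dataclass
class AttentionSpec:
    dtype: str


@dataclass
class FullAttentionSpec(AttentionSpec):
    pass


@dataclass
class SlidingWindowSpec(AttentionSpec):
    sliding_window: int


def allocate_kv_caches(kv_cache_spec: dict[str, AttentionSpec],
                       strict: bool) -> dict[str, str]:
    """Mimics TPUWorker's allocation loop (hypothetical helper).

    strict=True reproduces the pre-fix check (FullAttentionSpec only);
    strict=False reproduces the post-fix check (any AttentionSpec).
    """
    kv_caches: dict[str, str] = {}
    for layer_name, layer_spec in kv_cache_spec.items():
        accepted = (isinstance(layer_spec, FullAttentionSpec)
                    if strict else isinstance(layer_spec, AttentionSpec))
        if accepted:
            # The real worker allocates an empty torch tensor here.
            kv_caches[layer_name] = f"kv_cache[{layer_spec.dtype}]"
        else:
            raise NotImplementedError(
                f"Unsupported KV cache spec '{type(layer_spec)}'")
    return kv_caches


specs = {
    "layers.0": FullAttentionSpec(dtype="bfloat16"),
    "layers.1": SlidingWindowSpec(dtype="bfloat16", sliding_window=4096),
}

print(allocate_kv_caches(specs, strict=False))  # both layers allocated
try:
    allocate_kv_caches(specs, strict=True)      # pre-fix behavior
except NotImplementedError as e:
    print(e)  # names the rejected SlidingWindowSpec

Because a sliding-window spec subclasses AttentionSpec but not FullAttentionSpec, only the widened check lets sliding-window layers through; the old check raised for them during KV-cache setup.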