mirror of
https://github.com/vllm-project/vllm.git
synced 2025-10-20 14:53:52 +08:00
[CI/Build] Fix multimodal tests (#22491)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
@ -845,7 +845,8 @@ class LLMEngine:
|
||||
|
||||
def reset_mm_cache(self) -> bool:
|
||||
"""Reset the multi-modal cache."""
|
||||
return self.input_preprocessor.mm_registry.reset_processor_cache()
|
||||
return self.input_preprocessor.mm_registry.reset_processor_cache(
|
||||
self.model_config)
|
||||
|
||||
def reset_prefix_cache(self, device: Optional[Device] = None) -> bool:
|
||||
"""Reset prefix cache for all devices."""
|
||||
|
@ -2,6 +2,7 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from collections.abc import Mapping
|
||||
from dataclasses import dataclass
|
||||
from functools import lru_cache
|
||||
from typing import TYPE_CHECKING, Generic, Optional, Protocol, TypeVar
|
||||
|
||||
import torch.nn as nn
|
||||
@ -86,6 +87,13 @@ class _ProcessorFactories(Generic[_I]):
|
||||
return self.processor(info, dummy_inputs_builder, cache=cache)
|
||||
|
||||
|
||||
# Make sure a different cache is used for each model config
|
||||
# NOTE: ModelConfig is not hashable so it cannot be passed directly
|
||||
@lru_cache(maxsize=1)
|
||||
def _get_processor_cache(model_id: str, capacity_gb: int):
|
||||
return ProcessingCache(capacity_gb) if capacity_gb > 0 else None
|
||||
|
||||
|
||||
class MultiModalRegistry:
|
||||
"""
|
||||
A registry that dispatches data processing according to the model.
|
||||
@ -95,22 +103,15 @@ class MultiModalRegistry:
|
||||
self._processor_factories = ClassRegistry[nn.Module,
|
||||
_ProcessorFactories]()
|
||||
|
||||
self._processor_cache: Optional[ProcessingCache] = None
|
||||
|
||||
def _get_processor_cache(self, model_config: "ModelConfig"):
|
||||
model_id = model_config.model
|
||||
capacity_gb = model_config.mm_processor_cache_gb
|
||||
if capacity_gb is None:
|
||||
return None # Overrides `disable_cache` argument
|
||||
return _get_processor_cache(model_id, capacity_gb)
|
||||
|
||||
if self._processor_cache is None:
|
||||
self._processor_cache = ProcessingCache(capacity_gb)
|
||||
|
||||
return self._processor_cache
|
||||
|
||||
def reset_processor_cache(self) -> bool:
|
||||
def reset_processor_cache(self, model_config: "ModelConfig") -> bool:
|
||||
"""Reset the multi-modal processing cache."""
|
||||
if self._processor_cache:
|
||||
self._processor_cache.reset()
|
||||
if processor_cache := self._get_processor_cache(model_config):
|
||||
processor_cache.reset()
|
||||
|
||||
return True # Success
|
||||
|
||||
|
@ -566,7 +566,7 @@ class AsyncLLM(EngineClient):
|
||||
await self.engine_core.profile_async(False)
|
||||
|
||||
async def reset_mm_cache(self) -> None:
|
||||
self.processor.mm_registry.reset_processor_cache()
|
||||
self.processor.mm_registry.reset_processor_cache(self.model_config)
|
||||
self.processor.mm_input_cache_client.reset()
|
||||
await self.engine_core.reset_mm_cache_async()
|
||||
|
||||
|
@ -271,7 +271,7 @@ class LLMEngine:
|
||||
self.engine_core.profile(False)
|
||||
|
||||
def reset_mm_cache(self):
|
||||
self.processor.mm_registry.reset_processor_cache()
|
||||
self.processor.mm_registry.reset_processor_cache(self.model_config)
|
||||
self.processor.mm_input_cache_client.reset()
|
||||
self.engine_core.reset_mm_cache()
|
||||
|
||||
|
Reference in New Issue
Block a user