[CI/Build] Fix multimodal tests (#22491)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Author: Cyrus Leung
Date: 2025-08-08 15:31:19 +08:00 (committed by GitHub)
parent 808a7b69df
commit 1712543df6
4 changed files with 17 additions and 15 deletions


@@ -845,7 +845,8 @@ class LLMEngine:
 
     def reset_mm_cache(self) -> bool:
         """Reset the multi-modal cache."""
-        return self.input_preprocessor.mm_registry.reset_processor_cache()
+        return self.input_preprocessor.mm_registry.reset_processor_cache(
+            self.model_config)
 
     def reset_prefix_cache(self, device: Optional[Device] = None) -> bool:
         """Reset prefix cache for all devices."""


@@ -2,6 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 from collections.abc import Mapping
 from dataclasses import dataclass
+from functools import lru_cache
 from typing import TYPE_CHECKING, Generic, Optional, Protocol, TypeVar
 
 import torch.nn as nn
@@ -86,6 +87,13 @@ class _ProcessorFactories(Generic[_I]):
         return self.processor(info, dummy_inputs_builder, cache=cache)
 
 
+# Make sure a different cache is used for each model config
+# NOTE: ModelConfig is not hashable so it cannot be passed directly
+@lru_cache(maxsize=1)
+def _get_processor_cache(model_id: str, capacity_gb: int):
+    return ProcessingCache(capacity_gb) if capacity_gb > 0 else None
+
+
 class MultiModalRegistry:
     """
     A registry that dispatches data processing according to the model.
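
As an aside on the technique: functools.lru_cache keys on the argument tuple, so @lru_cache(maxsize=1) means the process holds exactly one ProcessingCache at a time, rebuilt whenever the (model_id, capacity_gb) pair changes. A minimal standalone sketch of that behavior, using a stub ProcessingCache and made-up model ids rather than vLLM's real ones:

    from functools import lru_cache

    class ProcessingCache:  # stub standing in for vLLM's ProcessingCache
        def __init__(self, capacity_gb: int) -> None:
            self.capacity_gb = capacity_gb

    @lru_cache(maxsize=1)
    def _get_processor_cache(model_id: str, capacity_gb: int):
        return ProcessingCache(capacity_gb) if capacity_gb > 0 else None

    a = _get_processor_cache("model-a", 4)
    assert _get_processor_cache("model-a", 4) is a  # same config -> same cache
    b = _get_processor_cache("model-b", 4)
    assert b is not a  # new config -> fresh cache, old entry evicted
    assert _get_processor_cache("model-c", 0) is None  # capacity 0 -> caching disabled
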
@@ -95,22 +103,15 @@ class MultiModalRegistry:
         self._processor_factories = ClassRegistry[nn.Module,
                                                   _ProcessorFactories]()
 
-        self._processor_cache: Optional[ProcessingCache] = None
-
     def _get_processor_cache(self, model_config: "ModelConfig"):
-        capacity_gb = model_config.mm_processor_cache_gb
-
-        if capacity_gb is None:
-            return None  # Overrides `disable_cache` argument
-
-        if self._processor_cache is None:
-            self._processor_cache = ProcessingCache(capacity_gb)
-
-        return self._processor_cache
+        model_id = model_config.model
+        capacity_gb = model_config.mm_processor_cache_gb
+
+        return _get_processor_cache(model_id, capacity_gb)
 
-    def reset_processor_cache(self) -> bool:
+    def reset_processor_cache(self, model_config: "ModelConfig") -> bool:
         """Reset the multi-modal processing cache."""
-        if self._processor_cache:
-            self._processor_cache.reset()
+        if processor_cache := self._get_processor_cache(model_config):
+            processor_cache.reset()
 
         return True  # Success
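
Taken together with the hunk above, the registry no longer stores a long-lived self._processor_cache; each call re-derives the cache from the ModelConfig it is given. A hedged sketch of the new reset flow, with a hypothetical two-field stand-in for ModelConfig (the real class has many more fields):

    from dataclasses import dataclass
    from functools import lru_cache

    class ProcessingCache:  # stub; the real cache is an LRU over processed inputs
        def __init__(self, capacity_gb: int) -> None:
            self.capacity_gb = capacity_gb
            self._entries: dict = {}

        def reset(self) -> None:
            self._entries.clear()

    @lru_cache(maxsize=1)
    def _get_processor_cache(model_id: str, capacity_gb: int):
        return ProcessingCache(capacity_gb) if capacity_gb > 0 else None

    @dataclass(frozen=True)
    class FakeModelConfig:  # hypothetical stand-in for vllm.config.ModelConfig
        model: str
        mm_processor_cache_gb: int

    def reset_processor_cache(model_config: FakeModelConfig) -> bool:
        # Mirrors MultiModalRegistry.reset_processor_cache after this commit:
        # reset only if a cache exists for this model config.
        if processor_cache := _get_processor_cache(model_config.model,
                                                   model_config.mm_processor_cache_gb):
            processor_cache.reset()
        return True  # Success

    assert reset_processor_cache(FakeModelConfig("model-a", 4))  # resets live cache
    assert reset_processor_cache(FakeModelConfig("model-b", 0))  # no cache, still True
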


@@ -566,7 +566,7 @@ class AsyncLLM(EngineClient):
         await self.engine_core.profile_async(False)
 
     async def reset_mm_cache(self) -> None:
-        self.processor.mm_registry.reset_processor_cache()
+        self.processor.mm_registry.reset_processor_cache(self.model_config)
         self.processor.mm_input_cache_client.reset()
 
         await self.engine_core.reset_mm_cache_async()


@@ -271,7 +271,7 @@ class LLMEngine:
         self.engine_core.profile(False)
 
     def reset_mm_cache(self):
-        self.processor.mm_registry.reset_processor_cache()
+        self.processor.mm_registry.reset_processor_cache(self.model_config)
         self.processor.mm_input_cache_client.reset()
         self.engine_core.reset_mm_cache()
 
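
All three engine entry points now pass their ModelConfig because the cache lookup is keyed on it. For intuition about the failure mode this removes, a simplified sketch of the old registry pattern (illustrative class, not vLLM's): a single lazily-built cache was handed back even after the model changed, so tests that constructed several models in one process shared state.

    class ProcessingCache:  # stub for illustration
        def __init__(self, capacity_gb: int) -> None:
            self.capacity_gb = capacity_gb

    class OldRegistry:
        """Pre-commit pattern: one lazily-built cache per registry instance."""

        def __init__(self) -> None:
            self._processor_cache = None

        def _get_processor_cache(self, model_id: str, capacity_gb: int):
            if self._processor_cache is None:
                self._processor_cache = ProcessingCache(capacity_gb)
            return self._processor_cache  # same object, whatever the model

    registry = OldRegistry()
    first = registry._get_processor_cache("model-a", capacity_gb=4)
    second = registry._get_processor_cache("model-b", capacity_gb=8)
    assert second is first and second.capacity_gb == 4  # model-b got model-a's cache

The lru_cache version sketched after the registry hunk returns a fresh, correctly sized cache for model-b instead.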