[V1] Get supported tasks from model runner instead of model config (#21585)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Cyrus Leung authored on 2025-07-25 20:36:45 +08:00 (committed via GitHub)
parent 5c3f2628d5
commit 46d81d6951
19 changed files with 200 additions and 54 deletions


@ -14,6 +14,7 @@ from pydantic import ValidationError
from tqdm.auto import tqdm
from typing_extensions import TypeVar, deprecated
import vllm.envs as envs
from vllm.beam_search import (BeamSearchInstance, BeamSearchOutput,
BeamSearchSequence,
create_sort_beams_key_function)
@ -44,9 +45,10 @@ from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.outputs import (ClassificationRequestOutput, EmbeddingRequestOutput,
PoolingRequestOutput, RequestOutput,
ScoringRequestOutput)
from vllm.pooling_params import PoolingParams, PoolingTask
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import (BeamSearchParams, GuidedDecodingParams,
RequestOutputKind, SamplingParams)
from vllm.tasks import PoolingTask
from vllm.transformers_utils.tokenizer import (AnyTokenizer, MistralTokenizer,
get_cached_tokenizer)
from vllm.usage.usage_lib import UsageContext
@ -277,6 +279,16 @@ class LLM:
self.request_counter = Counter()
self.default_sampling_params: Union[dict[str, Any], None] = None
if envs.VLLM_USE_V1:
supported_tasks = self.llm_engine \
.get_supported_tasks() # type: ignore
else:
supported_tasks = self.llm_engine.model_config.supported_tasks
logger.info("Supported_tasks: %s", supported_tasks)
self.supported_tasks = supported_tasks
def get_tokenizer(
self,
lora_request: Optional[LoRARequest] = None,
@ -1170,8 +1182,7 @@ class LLM:
A list of `EmbeddingRequestOutput` objects containing the
embedding vectors in the same order as the input prompts.
"""
model_config = self.llm_engine.model_config
if "embed" not in model_config.supported_tasks:
if "embed" not in self.supported_tasks:
raise ValueError("Embedding API is not supported by this model. "
"Please set `--task embed`.")
@ -1215,8 +1226,7 @@ class LLM:
A list of `ClassificationRequestOutput` objects containing the
embedding vectors in the same order as the input prompts.
"""
model_config = self.llm_engine.model_config
if "classify" not in model_config.supported_tasks:
if "classify" not in self.supported_tasks:
raise ValueError(
"Classification API is not supported by this model. "
"Please set `--task classify`.")
@ -1397,8 +1407,8 @@ class LLM:
raise ValueError(" ".join(messages))
if all(t not in model_config.supported_tasks
for t in ("embed", "classify")):
supported_tasks = self.supported_tasks
if all(t not in supported_tasks for t in ("embed", "classify")):
raise ValueError("Score API is not supported by this model. "
"Please set `--task embed` or `--task classify`.")


@ -1586,6 +1586,14 @@ async def init_app_state(
state.vllm_config = vllm_config
model_config = vllm_config.model_config
if envs.VLLM_USE_V1:
supported_tasks = await engine_client \
.get_supported_tasks() # type: ignore
else:
supported_tasks = model_config.supported_tasks
logger.info("Supported_tasks: %s", supported_tasks)
resolved_chat_template = load_chat_template(args.chat_template)
if resolved_chat_template is not None:
# Get the tokenizer to check official template
@ -1647,7 +1655,7 @@ async def init_app_state(
reasoning_parser=args.reasoning_parser,
enable_prompt_tokens_details=args.enable_prompt_tokens_details,
enable_force_include_usage=args.enable_force_include_usage,
) if "generate" in model_config.supported_tasks else None
) if "generate" in supported_tasks else None
state.openai_serving_chat = OpenAIServingChat(
engine_client,
model_config,
@ -1664,7 +1672,7 @@ async def init_app_state(
reasoning_parser=args.reasoning_parser,
enable_prompt_tokens_details=args.enable_prompt_tokens_details,
enable_force_include_usage=args.enable_force_include_usage,
) if "generate" in model_config.supported_tasks else None
) if "generate" in supported_tasks else None
state.openai_serving_completion = OpenAIServingCompletion(
engine_client,
model_config,
@ -1673,7 +1681,7 @@ async def init_app_state(
return_tokens_as_token_ids=args.return_tokens_as_token_ids,
enable_prompt_tokens_details=args.enable_prompt_tokens_details,
enable_force_include_usage=args.enable_force_include_usage,
) if "generate" in model_config.supported_tasks else None
) if "generate" in supported_tasks else None
state.openai_serving_pooling = OpenAIServingPooling(
engine_client,
model_config,
@ -1681,7 +1689,7 @@ async def init_app_state(
request_logger=request_logger,
chat_template=resolved_chat_template,
chat_template_content_format=args.chat_template_content_format,
) if "encode" in model_config.supported_tasks else None
) if "encode" in supported_tasks else None
state.openai_serving_embedding = OpenAIServingEmbedding(
engine_client,
model_config,
@ -1689,24 +1697,22 @@ async def init_app_state(
request_logger=request_logger,
chat_template=resolved_chat_template,
chat_template_content_format=args.chat_template_content_format,
) if "embed" in model_config.supported_tasks else None
) if "embed" in supported_tasks else None
state.openai_serving_classification = ServingClassification(
engine_client,
model_config,
state.openai_serving_models,
request_logger=request_logger,
) if "classify" in model_config.supported_tasks else None
) if "classify" in supported_tasks else None
enable_serving_reranking = ("classify" in model_config.supported_tasks
and getattr(model_config.hf_config,
"num_labels", 0) == 1)
enable_serving_reranking = ("classify" in supported_tasks and getattr(
model_config.hf_config, "num_labels", 0) == 1)
state.openai_serving_scores = ServingScores(
engine_client,
model_config,
state.openai_serving_models,
request_logger=request_logger,
) if ("embed" in model_config.supported_tasks
or enable_serving_reranking) else None
) if ("embed" in supported_tasks or enable_serving_reranking) else None
state.openai_serving_tokenization = OpenAIServingTokenization(
engine_client,
@ -1721,13 +1727,13 @@ async def init_app_state(
model_config,
state.openai_serving_models,
request_logger=request_logger,
) if "transcription" in model_config.supported_tasks else None
) if "transcription" in supported_tasks else None
state.openai_serving_translation = OpenAIServingTranslation(
engine_client,
model_config,
state.openai_serving_models,
request_logger=request_logger,
) if "transcription" in model_config.supported_tasks else None
) if "transcription" in supported_tasks else None
state.task = model_config.task
state.enable_server_load_tracking = args.enable_server_load_tracking
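
The same reported task set now drives handler construction in `init_app_state`: each OpenAI-compatible handler is built only when its task is present, and the reranking path additionally checks `num_labels`. An illustrative evaluation of those gates for a hypothetical single-label classification model:

# Hypothetical report from a sequence-classification model served for scoring.
supported_tasks = ("encode", "classify", "score")
num_labels = 1  # stand-in for getattr(model_config.hf_config, "num_labels", 0)

serve_chat = "generate" in supported_tasks            # False -> handler is None
serve_embedding = "embed" in supported_tasks          # False -> handler is None
enable_serving_reranking = ("classify" in supported_tasks and num_labels == 1)  # True
serve_scores = "embed" in supported_tasks or enable_serving_reranking           # True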


@ -14,6 +14,7 @@ import torch
from prometheus_client import start_http_server
from tqdm import tqdm
import vllm.envs as envs
from vllm.config import VllmConfig
from vllm.engine.arg_utils import AsyncEngineArgs, optional_type
from vllm.engine.protocol import EngineClient
@ -335,6 +336,14 @@ async def run_batch(
model_config = vllm_config.model_config
if envs.VLLM_USE_V1:
supported_tasks = await engine_client \
.get_supported_tasks() # type: ignore
else:
supported_tasks = model_config.supported_tasks
logger.info("Supported_tasks: %s", supported_tasks)
# Create the openai serving objects.
openai_serving_models = OpenAIServingModels(
engine_client=engine_client,
@ -351,7 +360,7 @@ async def run_batch(
chat_template=None,
chat_template_content_format="auto",
enable_prompt_tokens_details=args.enable_prompt_tokens_details,
) if "generate" in model_config.supported_tasks else None
) if "generate" in supported_tasks else None
openai_serving_embedding = OpenAIServingEmbedding(
engine_client,
model_config,
@ -359,19 +368,17 @@ async def run_batch(
request_logger=request_logger,
chat_template=None,
chat_template_content_format="auto",
) if "embed" in model_config.supported_tasks else None
) if "embed" in supported_tasks else None
enable_serving_reranking = ("classify" in model_config.supported_tasks
and getattr(model_config.hf_config,
"num_labels", 0) == 1)
enable_serving_reranking = ("classify" in supported_tasks and getattr(
model_config.hf_config, "num_labels", 0) == 1)
openai_serving_scores = ServingScores(
engine_client,
model_config,
openai_serving_models,
request_logger=request_logger,
) if ("embed" in model_config.supported_tasks
or enable_serving_reranking) else None
) if ("embed" in supported_tasks or enable_serving_reranking) else None
tracker = BatchProgressTracker()
logger.info("Reading batch from %s...", args.input_file)


@ -16,8 +16,8 @@ from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.pooling_params import PoolingTask
from vllm.sequence import ExecuteModelRequest, PoolerOutput
from vllm.tasks import SupportedTask
from vllm.utils import make_async
from vllm.worker.worker_base import WorkerBase
@ -136,9 +136,9 @@ class ExecutorBase(ABC):
return self.collective_rpc(rpc_func)
@cached_property # Avoid unnecessary RPC calls
def supported_pooling_tasks(self) -> tuple[PoolingTask, ...]:
output = self.collective_rpc("get_supported_pooling_tasks")
return tuple({task for tasks in output for task in tasks})
def supported_tasks(self) -> tuple[SupportedTask, ...]:
output = self.collective_rpc("get_supported_tasks")
return output[0]
def execute_model(
self, execute_model_req: ExecuteModelRequest
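
The executor property above replaces the pooling-only RPC with a single `collective_rpc("get_supported_tasks")` and returns the first worker's reply, since every worker in the group hosts the same model; unioning across workers (the old behaviour) is no longer needed. A small sketch of that aggregation change with hypothetical per-worker replies:

# Hypothetical replies, one tuple per worker in the same group.
per_worker = [("generate", "transcription"), ("generate", "transcription")]

# Old supported_pooling_tasks: union of the per-worker task sets.
old_style = tuple({task for tasks in per_worker for task in tasks})

# New supported_tasks: the first worker's reply stands for the whole group.
new_style = per_worker[0]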


@ -16,8 +16,9 @@ from vllm.config import ModelConfig, PoolerConfig
from vllm.model_executor.pooling_metadata import ( # noqa: E501
PoolingMetadata as V0PoolingMetadata)
from vllm.model_executor.pooling_metadata import PoolingTensors
from vllm.pooling_params import PoolingParams, PoolingTask
from vllm.pooling_params import PoolingParams
from vllm.sequence import PoolerOutput, PoolingSequenceGroupOutput
from vllm.tasks import PoolingTask
from vllm.utils import resolve_obj_by_qualname
from vllm.v1.pool.metadata import PoolingMetadata as V1PoolingMetadata


@ -26,8 +26,8 @@ from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding)
from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.pooling_params import PoolingTask
from vllm.sequence import IntermediateTensors
from vllm.tasks import PoolingTask
from .interfaces import SupportsCrossEncoding, SupportsQuant, SupportsV0Only
from .utils import AutoWeightsLoader, WeightsMapper, maybe_prefix


@ -16,8 +16,8 @@ from vllm.model_executor.layers.pooler import (DispatchPooler, Pooler,
get_prompt_token_ids)
from vllm.model_executor.models.llama import LlamaForCausalLM
from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.pooling_params import PoolingTask
from vllm.sequence import PoolerOutput
from vllm.tasks import PoolingTask
from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config
from .interfaces import SupportsV0Only


@ -23,8 +23,8 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.pooling_params import PoolingTask
from vllm.sequence import IntermediateTensors
from vllm.tasks import PoolingTask
from .interfaces import SupportsCrossEncoding, SupportsV0Only
from .utils import WeightsMapper, maybe_prefix


@ -1,17 +1,16 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import TYPE_CHECKING, Literal, Optional
from typing import TYPE_CHECKING, Optional
import msgspec
from vllm.sampling_params import RequestOutputKind
from vllm.tasks import PoolingTask
if TYPE_CHECKING:
from vllm.config import ModelConfig
PoolingTask = Literal["encode", "embed", "classify", "score"]
class PoolingParams(
msgspec.Struct,

vllm/tasks.py (new file, 11 additions)

@ -0,0 +1,11 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Literal, get_args
GenerationTask = Literal["generate", "transcription"]
GENERATION_TASKS = get_args(GenerationTask)
PoolingTask = Literal["encode", "embed", "classify", "score"]
POOLING_TASKS = get_args(PoolingTask)
SupportedTask = Literal[GenerationTask, PoolingTask]
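
The new `vllm/tasks.py` module keeps task names as `Literal` types for static checking and exposes the matching `get_args()` tuples for runtime membership tests; `EngineCore.add_request` (further below) filters the reported tasks against `POOLING_TASKS` in exactly this way. A short sketch, assuming the module as added here:

from vllm.tasks import GENERATION_TASKS, POOLING_TASKS, SupportedTask

# What the runner of a generation-only model might report.
supported: tuple[SupportedTask, ...] = ("generate", "transcription")

# Runtime filtering, mirroring the validation added to EngineCore.add_request.
supported_pooling_tasks = [t for t in supported if t in POOLING_TASKS]
assert supported_pooling_tasks == []
assert all(t in GENERATION_TASKS for t in supported)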


@ -21,6 +21,7 @@ from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
from vllm.outputs import PoolingRequestOutput, RequestOutput
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams
from vllm.tasks import SupportedTask
from vllm.transformers_utils.config import (
maybe_register_config_serialize_by_value)
from vllm.transformers_utils.tokenizer import AnyTokenizer
@ -211,6 +212,9 @@ class AsyncLLM(EngineClient):
if handler := getattr(self, "output_handler", None):
handler.cancel()
async def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
return await self.engine_core.get_supported_tasks_async()
async def add_request(
self,
request_id: str,


@ -23,6 +23,7 @@ from vllm.executor.multiproc_worker_utils import _add_prefix
from vllm.logger import init_logger
from vllm.logging_utils.dump_input import dump_engine_exception
from vllm.lora.request import LoRARequest
from vllm.tasks import POOLING_TASKS, SupportedTask
from vllm.transformers_utils.config import (
maybe_register_config_serialize_by_value)
from vllm.utils import (bind_process_name, make_zmq_socket,
@ -195,11 +196,17 @@ class EngineCore:
"warmup model) took %.2f seconds"), elapsed)
return num_gpu_blocks, num_cpu_blocks, scheduler_kv_cache_config
def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
return self.model_executor.supported_tasks
def add_request(self, request: EngineCoreRequest):
"""Add request to the scheduler."""
if pooling_params := request.pooling_params:
supported_pooling_tasks = (
self.model_executor.supported_pooling_tasks)
supported_pooling_tasks = [
task for task in self.get_supported_tasks()
if task in POOLING_TASKS
]
if pooling_params.task not in supported_pooling_tasks:
raise ValueError(f"Unsupported task: {pooling_params.task!r} "
f"Supported tasks: {supported_pooling_tasks}")


@ -21,6 +21,7 @@ import zmq.asyncio
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.tasks import SupportedTask
from vllm.utils import get_open_port, get_open_zmq_inproc_path, make_zmq_socket
from vllm.v1.engine import (EngineCoreOutputs, EngineCoreRequest,
EngineCoreRequestType,
@ -104,6 +105,9 @@ class EngineCoreClient(ABC):
def get_output(self) -> EngineCoreOutputs:
raise NotImplementedError
def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
raise NotImplementedError
def add_request(self, request: EngineCoreRequest) -> None:
raise NotImplementedError
@ -170,6 +174,9 @@ class EngineCoreClient(ABC):
async def get_output_async(self) -> EngineCoreOutputs:
raise NotImplementedError
async def get_supported_tasks_async(self) -> tuple[SupportedTask, ...]:
raise NotImplementedError
async def add_request_async(self, request: EngineCoreRequest) -> None:
raise NotImplementedError
@ -238,6 +245,9 @@ class InprocClient(EngineCoreClient):
outputs, _ = self.engine_core.step()
return outputs.get(0) or EngineCoreOutputs()
def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
return self.engine_core.get_supported_tasks()
def add_request(self, request: EngineCoreRequest) -> None:
self.engine_core.add_request(request)
@ -608,6 +618,9 @@ class SyncMPClient(MPClient):
return future.result()
def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
return self.call_utility("get_supported_tasks")
def add_request(self, request: EngineCoreRequest) -> None:
if self.is_dp:
self.engines_running = True
@ -802,6 +815,9 @@ class AsyncMPClient(MPClient):
self._ensure_output_queue_task()
return await future
async def get_supported_tasks_async(self) -> tuple[SupportedTask, ...]:
return await self.call_utility_async("get_supported_tasks")
async def add_request_async(self, request: EngineCoreRequest) -> None:
request.client_index = self.client_index
await self._send_input(EngineCoreRequestType.ADD, request)
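
On the client side, `get_supported_tasks` reuses the existing utility-call plumbing: `InprocClient` calls the core directly while the multiprocess clients route through `call_utility` / `call_utility_async`. A hedged sketch of how a V1 front end consumes it, mirroring `init_app_state` and `run_batch` above (the helper name is illustrative):

import vllm.envs as envs

async def resolve_supported_tasks(engine_client, model_config):
    # V1 asks the engine core (and therefore the model runner); V0 still
    # falls back to the statically derived model_config field.
    if envs.VLLM_USE_V1:
        return await engine_client.get_supported_tasks()
    return model_config.supported_tasks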


@ -18,6 +18,7 @@ from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
from vllm.outputs import PoolingRequestOutput, RequestOutput
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams
from vllm.tasks import SupportedTask
from vllm.transformers_utils.tokenizer_group import (
TokenizerGroup, init_tokenizer_from_configs)
from vllm.usage.usage_lib import UsageContext
@ -176,6 +177,9 @@ class LLMEngine:
def validate_outputs(cls, outputs, output_type):
return outputs
def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
return self.engine_core.get_supported_tasks()
def abort_request(self, request_ids: list[str]) -> None:
"""Remove request_ids from EngineCore and Detokenizer."""


@ -30,15 +30,17 @@ from vllm.logger import init_logger
from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaBase
from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
from vllm.model_executor.model_loader import TensorizerLoader, get_model_loader
from vllm.model_executor.models.interfaces import is_mixture_of_experts
from vllm.model_executor.models.interfaces_base import (VllmModelForPooling,
is_pooling_model)
from vllm.model_executor.models.interfaces import (is_mixture_of_experts,
supports_transcription)
from vllm.model_executor.models.interfaces_base import (
VllmModelForPooling, is_pooling_model, is_text_generation_model)
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
from vllm.multimodal.utils import group_mm_inputs_by_modality
from vllm.pooling_params import PoolingParams, PoolingTask
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingType
from vllm.sequence import IntermediateTensors, PoolerOutput
from vllm.tasks import GenerationTask, PoolingTask, SupportedTask
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
GiB_bytes, LazyLoader, check_use_alibi, get_dtype_size,
is_pin_memory_available, round_up)
@ -1153,6 +1155,21 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
def get_model(self) -> nn.Module:
return self.model
def get_supported_generation_tasks(self) -> list[GenerationTask]:
model = self.get_model()
supported_tasks = list[GenerationTask]()
if is_text_generation_model(model):
supported_tasks.append("generate")
if supports_transcription(model):
if model.supports_transcription_only:
return ["transcription"]
supported_tasks.append("transcription")
return supported_tasks
def get_supported_pooling_tasks(self) -> list[PoolingTask]:
model = self.get_model()
if not is_pooling_model(model):
@ -1160,6 +1177,16 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
return list(model.pooler.get_supported_tasks())
def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
tasks = list[SupportedTask]()
if self.model_config.runner_type == "generate":
tasks.extend(self.get_supported_generation_tasks())
if self.model_config.runner_type == "pooling":
tasks.extend(self.get_supported_pooling_tasks())
return tuple(tasks)
def apply_grammar_bitmask(
self,
scheduler_output: "SchedulerOutput",
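
`GPUModelRunner.get_supported_tasks` (and the matching TPU and V0 runner variants later in this diff) dispatches on `runner_type`: generation runners probe for text generation and transcription support (with a transcription-only short-circuit), while pooling runners defer to the model's pooler. A conceptual restatement of that dispatch with hypothetical inputs:

# Standalone restatement of the dispatch above; the inputs are hypothetical.
def resolve_tasks(runner_type, generation_tasks, pooling_tasks):
    tasks = []
    if runner_type == "generate":
        tasks.extend(generation_tasks)
    if runner_type == "pooling":
        tasks.extend(pooling_tasks)
    return tuple(tasks)

assert resolve_tasks("generate", ["generate", "transcription"], []) == (
    "generate", "transcription")
assert resolve_tasks("pooling", [], ["encode", "embed"]) == ("encode", "embed")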


@ -23,8 +23,8 @@ from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.model_executor import set_random_seed
from vllm.platforms import current_platform
from vllm.pooling_params import PoolingTask
from vllm.sequence import IntermediateTensors
from vllm.tasks import SupportedTask
from vllm.utils import GiB_bytes, MemorySnapshot, memory_profiling
from vllm.v1.engine import ReconfigureDistributedRequest, ReconfigureRankType
from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
@ -320,8 +320,8 @@ class Worker(WorkerBase):
def get_model(self) -> nn.Module:
return self.model_runner.get_model()
def get_supported_pooling_tasks(self) -> list[PoolingTask]:
return self.model_runner.get_supported_pooling_tasks()
def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
return self.model_runner.get_supported_tasks()
@torch.inference_mode()
def execute_model(


@ -27,13 +27,15 @@ from vllm.logger import init_logger
from vllm.lora.layers import BaseLayerWithLoRA
from vllm.model_executor.model_loader import get_model_loader
from vllm.model_executor.model_loader.tpu import TPUModelLoader
from vllm.model_executor.models.interfaces_base import is_pooling_model
from vllm.model_executor.models.interfaces import supports_transcription
from vllm.model_executor.models.interfaces_base import (
is_pooling_model, is_text_generation_model)
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (BatchedTensorInputs, MultiModalKwargs,
PlaceholderRange)
from vllm.multimodal.utils import group_mm_inputs_by_modality
from vllm.pooling_params import PoolingTask
from vllm.sequence import IntermediateTensors
from vllm.tasks import GenerationTask, PoolingTask, SupportedTask
from vllm.utils import (LayerBlockType, cdiv, is_pin_memory_available,
prev_power_of_2)
from vllm.v1.attention.backends.pallas import (TPU_STR_DTYPE_TO_TORCH_DTYPE,
@ -489,6 +491,21 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
def get_model(self) -> nn.Module:
return self.model
def get_supported_generation_tasks(self) -> list[GenerationTask]:
model = self.get_model()
supported_tasks = list[GenerationTask]()
if is_text_generation_model(model):
supported_tasks.append("generate")
if supports_transcription(model):
if model.supports_transcription_only:
return ["transcription"]
supported_tasks.append("transcription")
return supported_tasks
def get_supported_pooling_tasks(self) -> list[PoolingTask]:
model = self.get_model()
if not is_pooling_model(model):
@ -496,6 +513,16 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
return list(model.pooler.get_supported_tasks())
def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
tasks = list[SupportedTask]()
if self.model_config.runner_type == "generate":
tasks.extend(self.get_supported_generation_tasks())
if self.model_config.runner_type == "pooling":
tasks.extend(self.get_supported_pooling_tasks())
return tuple(tasks)
def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
"""
Generates the KVCacheSpec by parsing the kv cache format from each


@ -21,7 +21,7 @@ from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.model_executor import set_random_seed
from vllm.platforms import current_platform
from vllm.pooling_params import PoolingTask
from vllm.tasks import SupportedTask
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, cdiv
from vllm.v1.attention.backends.pallas import TPU_HEAD_SIZE_ALIGNMENT
from vllm.v1.core.sched.output import SchedulerOutput
@ -282,8 +282,8 @@ class TPUWorker:
def get_model(self) -> nn.Module:
return self.model_runner.get_model()
def get_supported_pooling_tasks(self) -> list[PoolingTask]:
return self.model_runner.get_supported_pooling_tasks()
def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
return self.model_runner.get_supported_tasks()
def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
return self.model_runner.get_kv_cache_spec()


@ -12,9 +12,11 @@ import torch.nn as nn
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.models.interfaces_base import is_pooling_model
from vllm.pooling_params import PoolingTask
from vllm.model_executor.models.interfaces import supports_transcription
from vllm.model_executor.models.interfaces_base import (
is_pooling_model, is_text_generation_model)
from vllm.sequence import IntermediateTensors, SequenceGroupMetadata
from vllm.tasks import GenerationTask, PoolingTask, SupportedTask
if TYPE_CHECKING:
from vllm.attention import AttentionMetadata
@ -224,6 +226,21 @@ class ModelRunnerBase(ABC, Generic[T]):
def get_model(self) -> nn.Module:
raise NotImplementedError
def get_supported_generation_tasks(self) -> list[GenerationTask]:
model = self.get_model()
supported_tasks = list[GenerationTask]()
if is_text_generation_model(model):
supported_tasks.append("generate")
if supports_transcription(model):
if model.supports_transcription_only:
return ["transcription"]
supported_tasks.append("transcription")
return supported_tasks
def get_supported_pooling_tasks(self) -> list[PoolingTask]:
model = self.get_model()
if not is_pooling_model(model):
@ -231,6 +248,16 @@ class ModelRunnerBase(ABC, Generic[T]):
return list(model.pooler.get_supported_tasks())
def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
tasks = list[SupportedTask]()
if self.model_config.runner_type == "generate":
tasks.extend(self.get_supported_generation_tasks())
if self.model_config.runner_type == "pooling":
tasks.extend(self.get_supported_pooling_tasks())
return tuple(tasks)
def execute_model(
self,
model_input: T,