diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py
index 2f766a2dae..2c961156bc 100644
--- a/vllm/entrypoints/llm.py
+++ b/vllm/entrypoints/llm.py
@@ -14,6 +14,7 @@ from pydantic import ValidationError
 from tqdm.auto import tqdm
 from typing_extensions import TypeVar, deprecated
 
+import vllm.envs as envs
 from vllm.beam_search import (BeamSearchInstance, BeamSearchOutput,
                               BeamSearchSequence,
                               create_sort_beams_key_function)
@@ -44,9 +45,10 @@ from vllm.model_executor.layers.quantization import QuantizationMethods
 from vllm.outputs import (ClassificationRequestOutput, EmbeddingRequestOutput,
                           PoolingRequestOutput, RequestOutput,
                           ScoringRequestOutput)
-from vllm.pooling_params import PoolingParams, PoolingTask
+from vllm.pooling_params import PoolingParams
 from vllm.sampling_params import (BeamSearchParams, GuidedDecodingParams,
                                   RequestOutputKind, SamplingParams)
+from vllm.tasks import PoolingTask
 from vllm.transformers_utils.tokenizer import (AnyTokenizer, MistralTokenizer,
                                                get_cached_tokenizer)
 from vllm.usage.usage_lib import UsageContext
@@ -277,6 +279,16 @@ class LLM:
         self.request_counter = Counter()
         self.default_sampling_params: Union[dict[str, Any], None] = None
 
+        if envs.VLLM_USE_V1:
+            supported_tasks = self.llm_engine \
+                .get_supported_tasks()  # type: ignore
+        else:
+            supported_tasks = self.llm_engine.model_config.supported_tasks
+
+        logger.info("Supported_tasks: %s", supported_tasks)
+
+        self.supported_tasks = supported_tasks
+
     def get_tokenizer(
         self,
         lora_request: Optional[LoRARequest] = None,
@@ -1170,8 +1182,7 @@ class LLM:
             A list of `EmbeddingRequestOutput` objects containing the
             embedding vectors in the same order as the input prompts.
         """
-        model_config = self.llm_engine.model_config
-        if "embed" not in model_config.supported_tasks:
+        if "embed" not in self.supported_tasks:
             raise ValueError("Embedding API is not supported by this model. "
                              "Please set `--task embed`.")
 
@@ -1215,8 +1226,7 @@ class LLM:
             A list of `ClassificationRequestOutput` objects containing the
             embedding vectors in the same order as the input prompts.
         """
-        model_config = self.llm_engine.model_config
-        if "classify" not in model_config.supported_tasks:
+        if "classify" not in self.supported_tasks:
             raise ValueError(
                 "Classification API is not supported by this model. "
                 "Please set `--task classify`.")
@@ -1397,8 +1407,8 @@ class LLM:
 
             raise ValueError(" ".join(messages))
 
-        if all(t not in model_config.supported_tasks
-               for t in ("embed", "classify")):
+        supported_tasks = self.supported_tasks
+        if all(t not in supported_tasks for t in ("embed", "classify")):
             raise ValueError("Score API is not supported by this model. "
                              "Please set `--task embed` or `--task classify`.")
 
diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py
index 8540d25d4e..5b87aed06e 100644
--- a/vllm/entrypoints/openai/api_server.py
+++ b/vllm/entrypoints/openai/api_server.py
@@ -1586,6 +1586,14 @@ async def init_app_state(
     state.vllm_config = vllm_config
     model_config = vllm_config.model_config
 
+    if envs.VLLM_USE_V1:
+        supported_tasks = await engine_client \
+            .get_supported_tasks()  # type: ignore
+    else:
+        supported_tasks = model_config.supported_tasks
+
+    logger.info("Supported_tasks: %s", supported_tasks)
+
     resolved_chat_template = load_chat_template(args.chat_template)
     if resolved_chat_template is not None:
         # Get the tokenizer to check official template
@@ -1647,7 +1655,7 @@
         reasoning_parser=args.reasoning_parser,
         enable_prompt_tokens_details=args.enable_prompt_tokens_details,
         enable_force_include_usage=args.enable_force_include_usage,
-    ) if "generate" in model_config.supported_tasks else None
+    ) if "generate" in supported_tasks else None
     state.openai_serving_chat = OpenAIServingChat(
         engine_client,
         model_config,
@@ -1664,7 +1672,7 @@
         reasoning_parser=args.reasoning_parser,
         enable_prompt_tokens_details=args.enable_prompt_tokens_details,
         enable_force_include_usage=args.enable_force_include_usage,
-    ) if "generate" in model_config.supported_tasks else None
+    ) if "generate" in supported_tasks else None
     state.openai_serving_completion = OpenAIServingCompletion(
         engine_client,
         model_config,
@@ -1673,7 +1681,7 @@
         return_tokens_as_token_ids=args.return_tokens_as_token_ids,
         enable_prompt_tokens_details=args.enable_prompt_tokens_details,
         enable_force_include_usage=args.enable_force_include_usage,
-    ) if "generate" in model_config.supported_tasks else None
+    ) if "generate" in supported_tasks else None
     state.openai_serving_pooling = OpenAIServingPooling(
         engine_client,
         model_config,
@@ -1681,7 +1689,7 @@
         request_logger=request_logger,
         chat_template=resolved_chat_template,
         chat_template_content_format=args.chat_template_content_format,
-    ) if "encode" in model_config.supported_tasks else None
+    ) if "encode" in supported_tasks else None
     state.openai_serving_embedding = OpenAIServingEmbedding(
         engine_client,
         model_config,
@@ -1689,24 +1697,22 @@
         request_logger=request_logger,
         chat_template=resolved_chat_template,
         chat_template_content_format=args.chat_template_content_format,
-    ) if "embed" in model_config.supported_tasks else None
+    ) if "embed" in supported_tasks else None
     state.openai_serving_classification = ServingClassification(
         engine_client,
         model_config,
         state.openai_serving_models,
         request_logger=request_logger,
-    ) if "classify" in model_config.supported_tasks else None
+    ) if "classify" in supported_tasks else None
 
-    enable_serving_reranking = ("classify" in model_config.supported_tasks
-                                and getattr(model_config.hf_config,
-                                            "num_labels", 0) == 1)
+    enable_serving_reranking = ("classify" in supported_tasks and getattr(
+        model_config.hf_config, "num_labels", 0) == 1)
     state.openai_serving_scores = ServingScores(
         engine_client,
         model_config,
         state.openai_serving_models,
         request_logger=request_logger,
-    ) if ("embed" in model_config.supported_tasks
-          or enable_serving_reranking) else None
+    ) if ("embed" in supported_tasks or enable_serving_reranking) else None
 
     state.openai_serving_tokenization = OpenAIServingTokenization(
         engine_client,
@@ -1721,13 +1727,13 @@ async def init_app_state(
         model_config,
         state.openai_serving_models,
         request_logger=request_logger,
-    ) if "transcription" in model_config.supported_tasks else None
+    ) if "transcription" in supported_tasks else None
     state.openai_serving_translation = OpenAIServingTranslation(
         engine_client,
         model_config,
         state.openai_serving_models,
         request_logger=request_logger,
-    ) if "transcription" in model_config.supported_tasks else None
+    ) if "transcription" in supported_tasks else None
     state.task = model_config.task
 
     state.enable_server_load_tracking = args.enable_server_load_tracking
diff --git a/vllm/entrypoints/openai/run_batch.py b/vllm/entrypoints/openai/run_batch.py
index 5770550923..137b368dad 100644
--- a/vllm/entrypoints/openai/run_batch.py
+++ b/vllm/entrypoints/openai/run_batch.py
@@ -14,6 +14,7 @@ import torch
 from prometheus_client import start_http_server
 from tqdm import tqdm
 
+import vllm.envs as envs
 from vllm.config import VllmConfig
 from vllm.engine.arg_utils import AsyncEngineArgs, optional_type
 from vllm.engine.protocol import EngineClient
@@ -335,6 +336,14 @@ async def run_batch(
 
     model_config = vllm_config.model_config
 
+    if envs.VLLM_USE_V1:
+        supported_tasks = await engine_client \
+            .get_supported_tasks()  # type: ignore
+    else:
+        supported_tasks = model_config.supported_tasks
+
+    logger.info("Supported_tasks: %s", supported_tasks)
+
     # Create the openai serving objects.
     openai_serving_models = OpenAIServingModels(
         engine_client=engine_client,
@@ -351,7 +360,7 @@ async def run_batch(
         chat_template=None,
         chat_template_content_format="auto",
         enable_prompt_tokens_details=args.enable_prompt_tokens_details,
-    ) if "generate" in model_config.supported_tasks else None
+    ) if "generate" in supported_tasks else None
     openai_serving_embedding = OpenAIServingEmbedding(
         engine_client,
         model_config,
@@ -359,19 +368,17 @@ async def run_batch(
         request_logger=request_logger,
         chat_template=None,
         chat_template_content_format="auto",
-    ) if "embed" in model_config.supported_tasks else None
+    ) if "embed" in supported_tasks else None
 
-    enable_serving_reranking = ("classify" in model_config.supported_tasks
-                                and getattr(model_config.hf_config,
-                                            "num_labels", 0) == 1)
+    enable_serving_reranking = ("classify" in supported_tasks and getattr(
+        model_config.hf_config, "num_labels", 0) == 1)
     openai_serving_scores = ServingScores(
         engine_client,
         model_config,
         openai_serving_models,
         request_logger=request_logger,
-    ) if ("embed" in model_config.supported_tasks
-          or enable_serving_reranking) else None
+    ) if ("embed" in supported_tasks or enable_serving_reranking) else None
 
     tracker = BatchProgressTracker()
     logger.info("Reading batch from %s...", args.input_file)
diff --git a/vllm/executor/executor_base.py b/vllm/executor/executor_base.py
index 483fdb1486..97d0d6f08b 100644
--- a/vllm/executor/executor_base.py
+++ b/vllm/executor/executor_base.py
@@ -16,8 +16,8 @@ from vllm.config import VllmConfig
 from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
 from vllm.model_executor.layers.sampler import SamplerOutput
-from vllm.pooling_params import PoolingTask
 from vllm.sequence import ExecuteModelRequest, PoolerOutput
+from vllm.tasks import SupportedTask
 from vllm.utils import make_async
 from vllm.worker.worker_base import WorkerBase
@@ -136,9 +136,9 @@ class ExecutorBase(ABC):
         return self.collective_rpc(rpc_func)
 
     @cached_property  # Avoid unnecessary RPC calls
-    def supported_pooling_tasks(self) -> tuple[PoolingTask, ...]:
-        output = self.collective_rpc("get_supported_pooling_tasks")
-        return tuple({task for tasks in output for task in tasks})
+    def supported_tasks(self) -> tuple[SupportedTask, ...]:
+        output = self.collective_rpc("get_supported_tasks")
+        return output[0]
 
     def execute_model(
         self, execute_model_req: ExecuteModelRequest
diff --git a/vllm/model_executor/layers/pooler.py b/vllm/model_executor/layers/pooler.py
index c06cca0802..5bfd4aaccc 100644
--- a/vllm/model_executor/layers/pooler.py
+++ b/vllm/model_executor/layers/pooler.py
@@ -16,8 +16,9 @@ from vllm.config import ModelConfig, PoolerConfig
 from vllm.model_executor.pooling_metadata import (  # noqa: E501
     PoolingMetadata as V0PoolingMetadata)
 from vllm.model_executor.pooling_metadata import PoolingTensors
-from vllm.pooling_params import PoolingParams, PoolingTask
+from vllm.pooling_params import PoolingParams
 from vllm.sequence import PoolerOutput, PoolingSequenceGroupOutput
+from vllm.tasks import PoolingTask
 from vllm.utils import resolve_obj_by_qualname
 from vllm.v1.pool.metadata import PoolingMetadata as V1PoolingMetadata
diff --git a/vllm/model_executor/models/bert.py b/vllm/model_executor/models/bert.py
index 9dc6115f85..c3066aaa2b 100644
--- a/vllm/model_executor/models/bert.py
+++ b/vllm/model_executor/models/bert.py
@@ -26,8 +26,8 @@ from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding)
 from vllm.model_executor.pooling_metadata import PoolingMetadata
-from vllm.pooling_params import PoolingTask
 from vllm.sequence import IntermediateTensors
+from vllm.tasks import PoolingTask
 
 from .interfaces import SupportsCrossEncoding, SupportsQuant, SupportsV0Only
 from .utils import AutoWeightsLoader, WeightsMapper, maybe_prefix
diff --git a/vllm/model_executor/models/gritlm.py b/vllm/model_executor/models/gritlm.py
index 8a3fbc6a49..c99970284a 100644
--- a/vllm/model_executor/models/gritlm.py
+++ b/vllm/model_executor/models/gritlm.py
@@ -16,8 +16,8 @@ from vllm.model_executor.layers.pooler import (DispatchPooler, Pooler,
                                                get_prompt_token_ids)
 from vllm.model_executor.models.llama import LlamaForCausalLM
 from vllm.model_executor.pooling_metadata import PoolingMetadata
-from vllm.pooling_params import PoolingTask
 from vllm.sequence import PoolerOutput
+from vllm.tasks import PoolingTask
 from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config
 
 from .interfaces import SupportsV0Only
diff --git a/vllm/model_executor/models/modernbert.py b/vllm/model_executor/models/modernbert.py
index be1c3438d9..fc2b0c1f51 100644
--- a/vllm/model_executor/models/modernbert.py
+++ b/vllm/model_executor/models/modernbert.py
@@ -23,8 +23,8 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding)
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.pooling_metadata import PoolingMetadata
-from vllm.pooling_params import PoolingTask
 from vllm.sequence import IntermediateTensors
+from vllm.tasks import PoolingTask
 
 from .interfaces import SupportsCrossEncoding, SupportsV0Only
 from .utils import WeightsMapper, maybe_prefix
diff --git a/vllm/pooling_params.py b/vllm/pooling_params.py
index 868facbe25..23eb775f2d 100644
--- a/vllm/pooling_params.py
+++ b/vllm/pooling_params.py
@@ -1,17 +1,16 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
-from typing import TYPE_CHECKING, Literal, Optional
+from typing import TYPE_CHECKING, Optional
 
 import msgspec
 
 from vllm.sampling_params import RequestOutputKind
+from vllm.tasks import PoolingTask
 
 if TYPE_CHECKING:
     from vllm.config import ModelConfig
 
-PoolingTask = Literal["encode", "embed", "classify", "score"]
-
 
 class PoolingParams(
     msgspec.Struct,
diff --git a/vllm/tasks.py b/vllm/tasks.py
new file mode 100644
index 0000000000..85c5c6e436
--- /dev/null
+++ b/vllm/tasks.py
@@ -0,0 +1,11 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+from typing import Literal, get_args
+
+GenerationTask = Literal["generate", "transcription"]
+GENERATION_TASKS = get_args(GenerationTask)
+
+PoolingTask = Literal["encode", "embed", "classify", "score"]
+POOLING_TASKS = get_args(PoolingTask)
+
+SupportedTask = Literal[GenerationTask, PoolingTask]
diff --git a/vllm/v1/engine/async_llm.py b/vllm/v1/engine/async_llm.py
index 02cb80197f..ed0d9620f4 100644
--- a/vllm/v1/engine/async_llm.py
+++ b/vllm/v1/engine/async_llm.py
@@ -21,6 +21,7 @@ from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
 from vllm.outputs import PoolingRequestOutput, RequestOutput
 from vllm.pooling_params import PoolingParams
 from vllm.sampling_params import SamplingParams
+from vllm.tasks import SupportedTask
 from vllm.transformers_utils.config import (
     maybe_register_config_serialize_by_value)
 from vllm.transformers_utils.tokenizer import AnyTokenizer
@@ -211,6 +212,9 @@ class AsyncLLM(EngineClient):
         if handler := getattr(self, "output_handler", None):
             handler.cancel()
 
+    async def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
+        return await self.engine_core.get_supported_tasks_async()
+
     async def add_request(
         self,
         request_id: str,
diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py
index 88c511606d..4124ee0532 100644
--- a/vllm/v1/engine/core.py
+++ b/vllm/v1/engine/core.py
@@ -23,6 +23,7 @@ from vllm.executor.multiproc_worker_utils import _add_prefix
 from vllm.logger import init_logger
 from vllm.logging_utils.dump_input import dump_engine_exception
 from vllm.lora.request import LoRARequest
+from vllm.tasks import POOLING_TASKS, SupportedTask
 from vllm.transformers_utils.config import (
     maybe_register_config_serialize_by_value)
 from vllm.utils import (bind_process_name, make_zmq_socket,
@@ -195,11 +196,17 @@ class EngineCore:
                      "warmup model) took %.2f seconds"), elapsed)
         return num_gpu_blocks, num_cpu_blocks, scheduler_kv_cache_config
 
+    def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
+        return self.model_executor.supported_tasks
+
     def add_request(self, request: EngineCoreRequest):
         """Add request to the scheduler."""
         if pooling_params := request.pooling_params:
-            supported_pooling_tasks = (
-                self.model_executor.supported_pooling_tasks)
+            supported_pooling_tasks = [
+                task for task in self.get_supported_tasks()
+                if task in POOLING_TASKS
+            ]
+
             if pooling_params.task not in supported_pooling_tasks:
                 raise ValueError(f"Unsupported task: {pooling_params.task!r} "
                                  f"Supported tasks: {supported_pooling_tasks}")
diff --git a/vllm/v1/engine/core_client.py b/vllm/v1/engine/core_client.py
index 69ae3690d0..b14d85bbf8 100644
--- a/vllm/v1/engine/core_client.py
+++ b/vllm/v1/engine/core_client.py
@@ -21,6 +21,7 @@ import zmq.asyncio
 from vllm.config import VllmConfig
 from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
+from vllm.tasks import SupportedTask
 from vllm.utils import get_open_port, get_open_zmq_inproc_path, make_zmq_socket
 from vllm.v1.engine import (EngineCoreOutputs, EngineCoreRequest,
                             EngineCoreRequestType,
@@ -104,6 +105,9 @@ class EngineCoreClient(ABC):
     def get_output(self) -> EngineCoreOutputs:
         raise NotImplementedError
 
+    def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
+        raise NotImplementedError
+
     def add_request(self, request: EngineCoreRequest) -> None:
         raise NotImplementedError
 
@@ -170,6 +174,9 @@ class EngineCoreClient(ABC):
     async def get_output_async(self) -> EngineCoreOutputs:
         raise NotImplementedError
 
+    async def get_supported_tasks_async(self) -> tuple[SupportedTask, ...]:
+        raise NotImplementedError
+
     async def add_request_async(self, request: EngineCoreRequest) -> None:
         raise NotImplementedError
 
@@ -238,6 +245,9 @@ class InprocClient(EngineCoreClient):
         outputs, _ = self.engine_core.step()
         return outputs.get(0) or EngineCoreOutputs()
 
+    def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
+        return self.engine_core.get_supported_tasks()
+
     def add_request(self, request: EngineCoreRequest) -> None:
         self.engine_core.add_request(request)
 
@@ -608,6 +618,9 @@ class SyncMPClient(MPClient):
 
         return future.result()
 
+    def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
+        return self.call_utility("get_supported_tasks")
+
     def add_request(self, request: EngineCoreRequest) -> None:
         if self.is_dp:
             self.engines_running = True
@@ -802,6 +815,9 @@ class AsyncMPClient(MPClient):
         self._ensure_output_queue_task()
         return await future
 
+    async def get_supported_tasks_async(self) -> tuple[SupportedTask, ...]:
+        return await self.call_utility_async("get_supported_tasks")
+
     async def add_request_async(self, request: EngineCoreRequest) -> None:
         request.client_index = self.client_index
         await self._send_input(EngineCoreRequestType.ADD, request)
diff --git a/vllm/v1/engine/llm_engine.py b/vllm/v1/engine/llm_engine.py
index 991242e182..efbdffbc09 100644
--- a/vllm/v1/engine/llm_engine.py
+++ b/vllm/v1/engine/llm_engine.py
@@ -18,6 +18,7 @@ from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
 from vllm.outputs import PoolingRequestOutput, RequestOutput
 from vllm.pooling_params import PoolingParams
 from vllm.sampling_params import SamplingParams
+from vllm.tasks import SupportedTask
 from vllm.transformers_utils.tokenizer_group import (
     TokenizerGroup, init_tokenizer_from_configs)
 from vllm.usage.usage_lib import UsageContext
@@ -176,6 +177,9 @@ class LLMEngine:
     def validate_outputs(cls, outputs, output_type):
         return outputs
 
+    def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
+        return self.engine_core.get_supported_tasks()
+
     def abort_request(self, request_ids: list[str]) -> None:
         """Remove request_ids from EngineCore and Detokenizer."""
 
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index 32004ced4a..5fe594db66 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -30,15 +30,17 @@ from vllm.logger import init_logger
 from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaBase
 from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
 from vllm.model_executor.model_loader import TensorizerLoader, get_model_loader
-from vllm.model_executor.models.interfaces import is_mixture_of_experts
-from vllm.model_executor.models.interfaces_base import (VllmModelForPooling,
-                                                        is_pooling_model)
+from vllm.model_executor.models.interfaces import (is_mixture_of_experts,
+                                                   supports_transcription)
+from vllm.model_executor.models.interfaces_base import (
+    VllmModelForPooling, is_pooling_model, is_text_generation_model)
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
 from vllm.multimodal.utils import group_mm_inputs_by_modality
-from vllm.pooling_params import PoolingParams, PoolingTask
+from vllm.pooling_params import PoolingParams
 from vllm.sampling_params import SamplingType
 from vllm.sequence import IntermediateTensors, PoolerOutput
+from vllm.tasks import GenerationTask, PoolingTask, SupportedTask
 from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
                         GiB_bytes, LazyLoader, check_use_alibi, get_dtype_size,
                         is_pin_memory_available, round_up)
@@ -1153,6 +1155,21 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
     def get_model(self) -> nn.Module:
         return self.model
 
+    def get_supported_generation_tasks(self) -> list[GenerationTask]:
+        model = self.get_model()
+        supported_tasks = list[GenerationTask]()
+
+        if is_text_generation_model(model):
+            supported_tasks.append("generate")
+
+        if supports_transcription(model):
+            if model.supports_transcription_only:
+                return ["transcription"]
+
+            supported_tasks.append("transcription")
+
+        return supported_tasks
+
     def get_supported_pooling_tasks(self) -> list[PoolingTask]:
         model = self.get_model()
         if not is_pooling_model(model):
@@ -1160,6 +1177,16 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
 
         return list(model.pooler.get_supported_tasks())
 
+    def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
+        tasks = list[SupportedTask]()
+
+        if self.model_config.runner_type == "generate":
+            tasks.extend(self.get_supported_generation_tasks())
+        if self.model_config.runner_type == "pooling":
+            tasks.extend(self.get_supported_pooling_tasks())
+
+        return tuple(tasks)
+
     def apply_grammar_bitmask(
         self,
         scheduler_output: "SchedulerOutput",
diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py
index 5229463511..dcfb038d28 100644
--- a/vllm/v1/worker/gpu_worker.py
+++ b/vllm/v1/worker/gpu_worker.py
@@ -23,8 +23,8 @@ from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
 from vllm.model_executor import set_random_seed
 from vllm.platforms import current_platform
-from vllm.pooling_params import PoolingTask
 from vllm.sequence import IntermediateTensors
+from vllm.tasks import SupportedTask
 from vllm.utils import GiB_bytes, MemorySnapshot, memory_profiling
 from vllm.v1.engine import ReconfigureDistributedRequest, ReconfigureRankType
 from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
@@ -320,8 +320,8 @@ class Worker(WorkerBase):
     def get_model(self) -> nn.Module:
         return self.model_runner.get_model()
 
-    def get_supported_pooling_tasks(self) -> list[PoolingTask]:
-        return self.model_runner.get_supported_pooling_tasks()
+    def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
+        return self.model_runner.get_supported_tasks()
 
     @torch.inference_mode()
     def execute_model(
diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py
index e8c8008458..59cbb01505 100644
--- a/vllm/v1/worker/tpu_model_runner.py
+++ b/vllm/v1/worker/tpu_model_runner.py
@@ -27,13 +27,15 @@ from vllm.logger import init_logger
 from vllm.lora.layers import BaseLayerWithLoRA
 from vllm.model_executor.model_loader import get_model_loader
 from vllm.model_executor.model_loader.tpu import TPUModelLoader
-from vllm.model_executor.models.interfaces_base import is_pooling_model
+from vllm.model_executor.models.interfaces import supports_transcription
+from vllm.model_executor.models.interfaces_base import (
+    is_pooling_model, is_text_generation_model)
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.inputs import (BatchedTensorInputs, MultiModalKwargs,
                                     PlaceholderRange)
 from vllm.multimodal.utils import group_mm_inputs_by_modality
-from vllm.pooling_params import PoolingTask
 from vllm.sequence import IntermediateTensors
+from vllm.tasks import GenerationTask, PoolingTask, SupportedTask
 from vllm.utils import (LayerBlockType, cdiv, is_pin_memory_available,
                         prev_power_of_2)
 from vllm.v1.attention.backends.pallas import (TPU_STR_DTYPE_TO_TORCH_DTYPE,
@@ -489,6 +491,21 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
     def get_model(self) -> nn.Module:
         return self.model
 
+    def get_supported_generation_tasks(self) -> list[GenerationTask]:
+        model = self.get_model()
+        supported_tasks = list[GenerationTask]()
+
+        if is_text_generation_model(model):
+            supported_tasks.append("generate")
+
+        if supports_transcription(model):
+            if model.supports_transcription_only:
+                return ["transcription"]
+
+            supported_tasks.append("transcription")
+
+        return supported_tasks
+
     def get_supported_pooling_tasks(self) -> list[PoolingTask]:
         model = self.get_model()
         if not is_pooling_model(model):
@@ -496,6 +513,16 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
 
         return list(model.pooler.get_supported_tasks())
 
+    def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
+        tasks = list[SupportedTask]()
+
+        if self.model_config.runner_type == "generate":
+            tasks.extend(self.get_supported_generation_tasks())
+        if self.model_config.runner_type == "pooling":
+            tasks.extend(self.get_supported_pooling_tasks())
+
+        return tuple(tasks)
+
     def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
         """
         Generates the KVCacheSpec by parsing the kv cache format from each
diff --git a/vllm/v1/worker/tpu_worker.py b/vllm/v1/worker/tpu_worker.py
index 254b058d2c..72e0e4230a 100644
--- a/vllm/v1/worker/tpu_worker.py
+++ b/vllm/v1/worker/tpu_worker.py
@@ -21,7 +21,7 @@ from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
 from vllm.model_executor import set_random_seed
 from vllm.platforms import current_platform
-from vllm.pooling_params import PoolingTask
+from vllm.tasks import SupportedTask
 from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, cdiv
 from vllm.v1.attention.backends.pallas import TPU_HEAD_SIZE_ALIGNMENT
 from vllm.v1.core.sched.output import SchedulerOutput
@@ -282,8 +282,8 @@ class TPUWorker:
     def get_model(self) -> nn.Module:
         return self.model_runner.get_model()
 
-    def get_supported_pooling_tasks(self) -> list[PoolingTask]:
-        return self.model_runner.get_supported_pooling_tasks()
+    def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
+        return self.model_runner.get_supported_tasks()
 
     def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
         return self.model_runner.get_kv_cache_spec()
diff --git a/vllm/worker/model_runner_base.py b/vllm/worker/model_runner_base.py
index feca8a7a1e..7b8fe2f802 100644
--- a/vllm/worker/model_runner_base.py
+++ b/vllm/worker/model_runner_base.py
@@ -12,9 +12,11 @@ import torch.nn as nn
 from vllm.config import VllmConfig
 from vllm.logger import init_logger
 from vllm.model_executor.layers.sampler import SamplerOutput
-from vllm.model_executor.models.interfaces_base import is_pooling_model
-from vllm.pooling_params import PoolingTask
+from vllm.model_executor.models.interfaces import supports_transcription
+from vllm.model_executor.models.interfaces_base import (
+    is_pooling_model, is_text_generation_model)
 from vllm.sequence import IntermediateTensors, SequenceGroupMetadata
+from vllm.tasks import GenerationTask, PoolingTask, SupportedTask
 
 if TYPE_CHECKING:
     from vllm.attention import AttentionMetadata
@@ -224,6 +226,21 @@ class ModelRunnerBase(ABC, Generic[T]):
     def get_model(self) -> nn.Module:
         raise NotImplementedError
 
+    def get_supported_generation_tasks(self) -> list[GenerationTask]:
+        model = self.get_model()
+        supported_tasks = list[GenerationTask]()
+
+        if is_text_generation_model(model):
+            supported_tasks.append("generate")
+
+        if supports_transcription(model):
+            if model.supports_transcription_only:
+                return ["transcription"]
+
+            supported_tasks.append("transcription")
+
+        return supported_tasks
+
     def get_supported_pooling_tasks(self) -> list[PoolingTask]:
         model = self.get_model()
         if not is_pooling_model(model):
@@ -231,6 +248,16 @@ class ModelRunnerBase(ABC, Generic[T]):
 
         return list(model.pooler.get_supported_tasks())
 
+    def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
+        tasks = list[SupportedTask]()
+
+        if self.model_config.runner_type == "generate":
+            tasks.extend(self.get_supported_generation_tasks())
+        if self.model_config.runner_type == "pooling":
+            tasks.extend(self.get_supported_pooling_tasks())
+
+        return tuple(tasks)
+
     def execute_model(
         self,
         model_input: T,
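
The diff above wires a single query path: worker `get_supported_tasks()` -> executor `supported_tasks` -> `EngineCore.get_supported_tasks()` -> the sync/async engine-core clients -> `LLM.__init__` and `init_app_state`. Below is a minimal usage sketch of how a caller might consume that plumbing; it is not part of the diff, and the model name and the surrounding script are assumptions.

# Illustrative sketch only -- mirrors how LLM.__init__ and init_app_state in this
# diff gate APIs on the queried task set; the model choice is hypothetical.
from vllm import LLM
from vllm.tasks import POOLING_TASKS  # ("encode", "embed", "classify", "score")

llm = LLM(model="BAAI/bge-base-en-v1.5", task="embed")  # assumed pooling model

# LLM.__init__ now caches the result of get_supported_tasks() (V1 path) or
# model_config.supported_tasks (V0 fallback) on the instance.
print(llm.supported_tasks)  # e.g. a tuple containing "embed"

# Entrypoints check membership instead of reaching into model_config:
if "embed" in llm.supported_tasks:
    outputs = llm.embed(["task introspection in vLLM"])

# EngineCore.add_request() validates pooling requests against only the pooling
# subset, filtered with the POOLING_TASKS literal from the new vllm/tasks.py:
pooling_subset = [t for t in llm.supported_tasks if t in POOLING_TASKS]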