[V1] Get supported tasks from model runner instead of model config (#21585)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
@@ -14,6 +14,7 @@ from pydantic import ValidationError
 from tqdm.auto import tqdm
 from typing_extensions import TypeVar, deprecated

+import vllm.envs as envs
 from vllm.beam_search import (BeamSearchInstance, BeamSearchOutput,
                               BeamSearchSequence,
                               create_sort_beams_key_function)
@@ -44,9 +45,10 @@ from vllm.model_executor.layers.quantization import QuantizationMethods
 from vllm.outputs import (ClassificationRequestOutput, EmbeddingRequestOutput,
                           PoolingRequestOutput, RequestOutput,
                           ScoringRequestOutput)
-from vllm.pooling_params import PoolingParams, PoolingTask
+from vllm.pooling_params import PoolingParams
 from vllm.sampling_params import (BeamSearchParams, GuidedDecodingParams,
                                   RequestOutputKind, SamplingParams)
+from vllm.tasks import PoolingTask
 from vllm.transformers_utils.tokenizer import (AnyTokenizer, MistralTokenizer,
                                                get_cached_tokenizer)
 from vllm.usage.usage_lib import UsageContext
@@ -277,6 +279,16 @@ class LLM:
         self.request_counter = Counter()
         self.default_sampling_params: Union[dict[str, Any], None] = None

+        if envs.VLLM_USE_V1:
+            supported_tasks = self.llm_engine \
+                .get_supported_tasks()  # type: ignore
+        else:
+            supported_tasks = self.llm_engine.model_config.supported_tasks
+
+        logger.info("Supported_tasks: %s", supported_tasks)
+
+        self.supported_tasks = supported_tasks
+
     def get_tokenizer(
         self,
         lora_request: Optional[LoRARequest] = None,
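The `supported_tasks` tuple stored here is what the `embed`/`classify`/`score` entry points below check. A minimal usage sketch (the model name and `task` flag are illustrative, and the exact tuple contents depend on the checkpoint):

```python
from vllm import LLM

# Illustrative embedding checkpoint; any pooling model would do.
llm = LLM(model="BAAI/bge-base-en-v1.5", task="embed")

# Under V1 this tuple now comes from the model runner via the engine core,
# rather than from ModelConfig.supported_tasks.
print(llm.supported_tasks)  # e.g. ('encode', 'embed', ...)

if "embed" in llm.supported_tasks:
    (output, ) = llm.embed(["vLLM annotates each request with a task"])
    print(len(output.outputs.embedding))
```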
@@ -1170,8 +1182,7 @@ class LLM:
             A list of `EmbeddingRequestOutput` objects containing the
             embedding vectors in the same order as the input prompts.
         """
-        model_config = self.llm_engine.model_config
-        if "embed" not in model_config.supported_tasks:
+        if "embed" not in self.supported_tasks:
            raise ValueError("Embedding API is not supported by this model. "
                             "Please set `--task embed`.")

@@ -1215,8 +1226,7 @@
             A list of `ClassificationRequestOutput` objects containing the
             embedding vectors in the same order as the input prompts.
         """
-        model_config = self.llm_engine.model_config
-        if "classify" not in model_config.supported_tasks:
+        if "classify" not in self.supported_tasks:
            raise ValueError(
                "Classification API is not supported by this model. "
                "Please set `--task classify`.")
@@ -1397,8 +1407,8 @@

            raise ValueError(" ".join(messages))

-        if all(t not in model_config.supported_tasks
-               for t in ("embed", "classify")):
+        supported_tasks = self.supported_tasks
+        if all(t not in supported_tasks for t in ("embed", "classify")):
            raise ValueError("Score API is not supported by this model. "
                             "Please set `--task embed` or `--task classify`.")

@@ -1586,6 +1586,14 @@ async def init_app_state(
     state.vllm_config = vllm_config
     model_config = vllm_config.model_config

+    if envs.VLLM_USE_V1:
+        supported_tasks = await engine_client \
+            .get_supported_tasks()  # type: ignore
+    else:
+        supported_tasks = model_config.supported_tasks
+
+    logger.info("Supported_tasks: %s", supported_tasks)
+
     resolved_chat_template = load_chat_template(args.chat_template)
     if resolved_chat_template is not None:
         # Get the tokenizer to check official template
@@ -1647,7 +1655,7 @@ async def init_app_state(
         reasoning_parser=args.reasoning_parser,
         enable_prompt_tokens_details=args.enable_prompt_tokens_details,
         enable_force_include_usage=args.enable_force_include_usage,
-    ) if "generate" in model_config.supported_tasks else None
+    ) if "generate" in supported_tasks else None
     state.openai_serving_chat = OpenAIServingChat(
         engine_client,
         model_config,
@@ -1664,7 +1672,7 @@
         reasoning_parser=args.reasoning_parser,
         enable_prompt_tokens_details=args.enable_prompt_tokens_details,
         enable_force_include_usage=args.enable_force_include_usage,
-    ) if "generate" in model_config.supported_tasks else None
+    ) if "generate" in supported_tasks else None
     state.openai_serving_completion = OpenAIServingCompletion(
         engine_client,
         model_config,
@@ -1673,7 +1681,7 @@
         return_tokens_as_token_ids=args.return_tokens_as_token_ids,
         enable_prompt_tokens_details=args.enable_prompt_tokens_details,
         enable_force_include_usage=args.enable_force_include_usage,
-    ) if "generate" in model_config.supported_tasks else None
+    ) if "generate" in supported_tasks else None
     state.openai_serving_pooling = OpenAIServingPooling(
         engine_client,
         model_config,
@@ -1681,7 +1689,7 @@
         request_logger=request_logger,
         chat_template=resolved_chat_template,
         chat_template_content_format=args.chat_template_content_format,
-    ) if "encode" in model_config.supported_tasks else None
+    ) if "encode" in supported_tasks else None
     state.openai_serving_embedding = OpenAIServingEmbedding(
         engine_client,
         model_config,
@@ -1689,24 +1697,22 @@
         request_logger=request_logger,
         chat_template=resolved_chat_template,
         chat_template_content_format=args.chat_template_content_format,
-    ) if "embed" in model_config.supported_tasks else None
+    ) if "embed" in supported_tasks else None
     state.openai_serving_classification = ServingClassification(
         engine_client,
         model_config,
         state.openai_serving_models,
         request_logger=request_logger,
-    ) if "classify" in model_config.supported_tasks else None
+    ) if "classify" in supported_tasks else None

-    enable_serving_reranking = ("classify" in model_config.supported_tasks
-                                and getattr(model_config.hf_config,
-                                            "num_labels", 0) == 1)
+    enable_serving_reranking = ("classify" in supported_tasks and getattr(
+        model_config.hf_config, "num_labels", 0) == 1)
     state.openai_serving_scores = ServingScores(
         engine_client,
         model_config,
         state.openai_serving_models,
         request_logger=request_logger,
-    ) if ("embed" in model_config.supported_tasks
-          or enable_serving_reranking) else None
+    ) if ("embed" in supported_tasks or enable_serving_reranking) else None

     state.openai_serving_tokenization = OpenAIServingTokenization(
         engine_client,
@@ -1721,13 +1727,13 @@
         model_config,
         state.openai_serving_models,
         request_logger=request_logger,
-    ) if "transcription" in model_config.supported_tasks else None
+    ) if "transcription" in supported_tasks else None
     state.openai_serving_translation = OpenAIServingTranslation(
         engine_client,
         model_config,
         state.openai_serving_models,
         request_logger=request_logger,
-    ) if "transcription" in model_config.supported_tasks else None
+    ) if "transcription" in supported_tasks else None
     state.task = model_config.task

     state.enable_server_load_tracking = args.enable_server_load_tracking
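All of the handler constructions above follow the same `Handler(...) if task in supported_tasks else None` pattern: an endpoint whose handler stays `None` simply rejects requests. A generic sketch of that pattern (the helper below is illustrative, not part of the diff):

```python
from typing import Callable, Optional, TypeVar

T = TypeVar("T")

def build_if_supported(task: str,
                       supported_tasks: tuple[str, ...],
                       factory: Callable[[], T]) -> Optional[T]:
    """Construct a serving handler only when the engine reports its task."""
    return factory() if task in supported_tasks else None

# e.g. state.openai_serving_embedding = build_if_supported(
#     "embed", supported_tasks,
#     lambda: OpenAIServingEmbedding(engine_client, model_config, ...))
```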
@@ -14,6 +14,7 @@ import torch
 from prometheus_client import start_http_server
 from tqdm import tqdm

+import vllm.envs as envs
 from vllm.config import VllmConfig
 from vllm.engine.arg_utils import AsyncEngineArgs, optional_type
 from vllm.engine.protocol import EngineClient
@@ -335,6 +336,14 @@ async def run_batch(

     model_config = vllm_config.model_config

+    if envs.VLLM_USE_V1:
+        supported_tasks = await engine_client \
+            .get_supported_tasks()  # type: ignore
+    else:
+        supported_tasks = model_config.supported_tasks
+
+    logger.info("Supported_tasks: %s", supported_tasks)
+
     # Create the openai serving objects.
     openai_serving_models = OpenAIServingModels(
         engine_client=engine_client,
@@ -351,7 +360,7 @@
         chat_template=None,
         chat_template_content_format="auto",
         enable_prompt_tokens_details=args.enable_prompt_tokens_details,
-    ) if "generate" in model_config.supported_tasks else None
+    ) if "generate" in supported_tasks else None
     openai_serving_embedding = OpenAIServingEmbedding(
         engine_client,
         model_config,
@@ -359,19 +368,17 @@
         request_logger=request_logger,
         chat_template=None,
         chat_template_content_format="auto",
-    ) if "embed" in model_config.supported_tasks else None
+    ) if "embed" in supported_tasks else None

-    enable_serving_reranking = ("classify" in model_config.supported_tasks
-                                and getattr(model_config.hf_config,
-                                            "num_labels", 0) == 1)
+    enable_serving_reranking = ("classify" in supported_tasks and getattr(
+        model_config.hf_config, "num_labels", 0) == 1)

     openai_serving_scores = ServingScores(
         engine_client,
         model_config,
         openai_serving_models,
         request_logger=request_logger,
-    ) if ("embed" in model_config.supported_tasks
-          or enable_serving_reranking) else None
+    ) if ("embed" in supported_tasks or enable_serving_reranking) else None

     tracker = BatchProgressTracker()
     logger.info("Reading batch from %s...", args.input_file)
@@ -16,8 +16,8 @@ from vllm.config import VllmConfig
 from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
 from vllm.model_executor.layers.sampler import SamplerOutput
-from vllm.pooling_params import PoolingTask
 from vllm.sequence import ExecuteModelRequest, PoolerOutput
+from vllm.tasks import SupportedTask
 from vllm.utils import make_async
 from vllm.worker.worker_base import WorkerBase

@@ -136,9 +136,9 @@ class ExecutorBase(ABC):
         return self.collective_rpc(rpc_func)

     @cached_property  # Avoid unnecessary RPC calls
-    def supported_pooling_tasks(self) -> tuple[PoolingTask, ...]:
-        output = self.collective_rpc("get_supported_pooling_tasks")
-        return tuple({task for tasks in output for task in tasks})
+    def supported_tasks(self) -> tuple[SupportedTask, ...]:
+        output = self.collective_rpc("get_supported_tasks")
+        return output[0]

     def execute_model(
         self, execute_model_req: ExecuteModelRequest
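Previously the executor unioned the per-worker pooling tasks; now it issues a single `get_supported_tasks` collective RPC and keeps the first reply, which suffices because every worker hosts a replica of the same model. A self-contained sketch of the two aggregation styles (the replies below are made up):

```python
# Hypothetical per-worker replies from collective_rpc("get_supported_tasks").
replies = [("generate", "transcription"), ("generate", "transcription")]

# Old behaviour: union of per-worker pooling task lists.
merged = tuple({task for tasks in replies for task in tasks})

# New behaviour: take the first worker's reply as-is.
supported = replies[0]

assert set(merged) == set(supported)
print(supported)  # ('generate', 'transcription')
```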
@@ -16,8 +16,9 @@ from vllm.config import ModelConfig, PoolerConfig
 from vllm.model_executor.pooling_metadata import (  # noqa: E501
     PoolingMetadata as V0PoolingMetadata)
 from vllm.model_executor.pooling_metadata import PoolingTensors
-from vllm.pooling_params import PoolingParams, PoolingTask
+from vllm.pooling_params import PoolingParams
 from vllm.sequence import PoolerOutput, PoolingSequenceGroupOutput
+from vllm.tasks import PoolingTask
 from vllm.utils import resolve_obj_by_qualname
 from vllm.v1.pool.metadata import PoolingMetadata as V1PoolingMetadata

@@ -26,8 +26,8 @@ from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding)
 from vllm.model_executor.pooling_metadata import PoolingMetadata
-from vllm.pooling_params import PoolingTask
 from vllm.sequence import IntermediateTensors
+from vllm.tasks import PoolingTask

 from .interfaces import SupportsCrossEncoding, SupportsQuant, SupportsV0Only
 from .utils import AutoWeightsLoader, WeightsMapper, maybe_prefix
@@ -16,8 +16,8 @@ from vllm.model_executor.layers.pooler import (DispatchPooler, Pooler,
                                                get_prompt_token_ids)
 from vllm.model_executor.models.llama import LlamaForCausalLM
 from vllm.model_executor.pooling_metadata import PoolingMetadata
-from vllm.pooling_params import PoolingTask
 from vllm.sequence import PoolerOutput
+from vllm.tasks import PoolingTask
 from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config

 from .interfaces import SupportsV0Only
@@ -23,8 +23,8 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding)
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.pooling_metadata import PoolingMetadata
-from vllm.pooling_params import PoolingTask
 from vllm.sequence import IntermediateTensors
+from vllm.tasks import PoolingTask

 from .interfaces import SupportsCrossEncoding, SupportsV0Only
 from .utils import WeightsMapper, maybe_prefix
@@ -1,17 +1,16 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project

-from typing import TYPE_CHECKING, Literal, Optional
+from typing import TYPE_CHECKING, Optional

 import msgspec

 from vllm.sampling_params import RequestOutputKind
+from vllm.tasks import PoolingTask

 if TYPE_CHECKING:
     from vllm.config import ModelConfig

-PoolingTask = Literal["encode", "embed", "classify", "score"]
-

 class PoolingParams(
     msgspec.Struct,
vllm/tasks.py (new file, 11 lines)
@@ -0,0 +1,11 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+from typing import Literal, get_args
+
+GenerationTask = Literal["generate", "transcription"]
+GENERATION_TASKS = get_args(GenerationTask)
+
+PoolingTask = Literal["encode", "embed", "classify", "score"]
+POOLING_TASKS = get_args(PoolingTask)
+
+SupportedTask = Literal[GenerationTask, PoolingTask]
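The new module centralizes the task names as `Literal` aliases. A small sketch of how they compose (assumes vLLM is importable):

```python
from typing import get_args

from vllm.tasks import GENERATION_TASKS, POOLING_TASKS, SupportedTask

# Nested Literals flatten, so SupportedTask covers all six task names.
assert set(get_args(SupportedTask)) == set(GENERATION_TASKS) | set(POOLING_TASKS)

def is_pooling_task(task: str) -> bool:
    # Mirrors the membership checks added elsewhere in this commit.
    return task in POOLING_TASKS

print(is_pooling_task("embed"), is_pooling_task("generate"))  # True False
```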
@@ -21,6 +21,7 @@ from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
 from vllm.outputs import PoolingRequestOutput, RequestOutput
 from vllm.pooling_params import PoolingParams
 from vllm.sampling_params import SamplingParams
+from vllm.tasks import SupportedTask
 from vllm.transformers_utils.config import (
     maybe_register_config_serialize_by_value)
 from vllm.transformers_utils.tokenizer import AnyTokenizer
@@ -211,6 +212,9 @@ class AsyncLLM(EngineClient):
         if handler := getattr(self, "output_handler", None):
             handler.cancel()

+    async def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
+        return await self.engine_core.get_supported_tasks_async()
+
     async def add_request(
         self,
         request_id: str,
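With the new `AsyncLLM.get_supported_tasks()`, a caller can query the engine directly instead of reading `ModelConfig.supported_tasks`. A minimal sketch (engine construction is abbreviated, the model name is illustrative, and exact constructor arguments vary by version):

```python
import asyncio

from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.v1.engine.async_llm import AsyncLLM

async def main() -> None:
    engine = AsyncLLM.from_engine_args(AsyncEngineArgs(model="facebook/opt-125m"))
    tasks = await engine.get_supported_tasks()
    print(tasks)  # e.g. ('generate',)

asyncio.run(main())
```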
@@ -23,6 +23,7 @@ from vllm.executor.multiproc_worker_utils import _add_prefix
 from vllm.logger import init_logger
 from vllm.logging_utils.dump_input import dump_engine_exception
 from vllm.lora.request import LoRARequest
+from vllm.tasks import POOLING_TASKS, SupportedTask
 from vllm.transformers_utils.config import (
     maybe_register_config_serialize_by_value)
 from vllm.utils import (bind_process_name, make_zmq_socket,
@@ -195,11 +196,17 @@ class EngineCore:
                      "warmup model) took %.2f seconds"), elapsed)
         return num_gpu_blocks, num_cpu_blocks, scheduler_kv_cache_config

+    def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
+        return self.model_executor.supported_tasks
+
     def add_request(self, request: EngineCoreRequest):
         """Add request to the scheduler."""
         if pooling_params := request.pooling_params:
-            supported_pooling_tasks = (
-                self.model_executor.supported_pooling_tasks)
+            supported_pooling_tasks = [
+                task for task in self.get_supported_tasks()
+                if task in POOLING_TASKS
+            ]
+
             if pooling_params.task not in supported_pooling_tasks:
                 raise ValueError(f"Unsupported task: {pooling_params.task!r} "
                                  f"Supported tasks: {supported_pooling_tasks}")
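`EngineCore.add_request` now derives the accepted pooling tasks by filtering the full task tuple against `POOLING_TASKS`. A standalone sketch of that validation (the supported tuple below is sample data):

```python
from vllm.tasks import POOLING_TASKS

supported_tasks = ("generate", "transcription")  # e.g. a decoder-only model
supported_pooling_tasks = [
    task for task in supported_tasks if task in POOLING_TASKS
]

requested = "embed"
try:
    if requested not in supported_pooling_tasks:
        raise ValueError(f"Unsupported task: {requested!r} "
                         f"Supported tasks: {supported_pooling_tasks}")
except ValueError as err:
    print(err)  # Unsupported task: 'embed' Supported tasks: []
```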
@@ -21,6 +21,7 @@ import zmq.asyncio
 from vllm.config import VllmConfig
 from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
+from vllm.tasks import SupportedTask
 from vllm.utils import get_open_port, get_open_zmq_inproc_path, make_zmq_socket
 from vllm.v1.engine import (EngineCoreOutputs, EngineCoreRequest,
                             EngineCoreRequestType,
@@ -104,6 +105,9 @@ class EngineCoreClient(ABC):
     def get_output(self) -> EngineCoreOutputs:
         raise NotImplementedError

+    def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
+        raise NotImplementedError
+
     def add_request(self, request: EngineCoreRequest) -> None:
         raise NotImplementedError

@@ -170,6 +174,9 @@
     async def get_output_async(self) -> EngineCoreOutputs:
         raise NotImplementedError

+    async def get_supported_tasks_async(self) -> tuple[SupportedTask, ...]:
+        raise NotImplementedError
+
     async def add_request_async(self, request: EngineCoreRequest) -> None:
         raise NotImplementedError

@@ -238,6 +245,9 @@ class InprocClient(EngineCoreClient):
         outputs, _ = self.engine_core.step()
         return outputs.get(0) or EngineCoreOutputs()

+    def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
+        return self.engine_core.get_supported_tasks()
+
     def add_request(self, request: EngineCoreRequest) -> None:
         self.engine_core.add_request(request)

@@ -608,6 +618,9 @@ class SyncMPClient(MPClient):

         return future.result()

+    def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
+        return self.call_utility("get_supported_tasks")
+
     def add_request(self, request: EngineCoreRequest) -> None:
         if self.is_dp:
             self.engines_running = True
@@ -802,6 +815,9 @@ class AsyncMPClient(MPClient):
         self._ensure_output_queue_task()
         return await future

+    async def get_supported_tasks_async(self) -> tuple[SupportedTask, ...]:
+        return await self.call_utility_async("get_supported_tasks")
+
     async def add_request_async(self, request: EngineCoreRequest) -> None:
         request.client_index = self.client_index
         await self._send_input(EngineCoreRequestType.ADD, request)
@@ -18,6 +18,7 @@ from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
 from vllm.outputs import PoolingRequestOutput, RequestOutput
 from vllm.pooling_params import PoolingParams
 from vllm.sampling_params import SamplingParams
+from vllm.tasks import SupportedTask
 from vllm.transformers_utils.tokenizer_group import (
     TokenizerGroup, init_tokenizer_from_configs)
 from vllm.usage.usage_lib import UsageContext
@@ -176,6 +177,9 @@ class LLMEngine:
     def validate_outputs(cls, outputs, output_type):
         return outputs

+    def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
+        return self.engine_core.get_supported_tasks()
+
     def abort_request(self, request_ids: list[str]) -> None:
         """Remove request_ids from EngineCore and Detokenizer."""

@@ -30,15 +30,17 @@ from vllm.logger import init_logger
 from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaBase
 from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
 from vllm.model_executor.model_loader import TensorizerLoader, get_model_loader
-from vllm.model_executor.models.interfaces import is_mixture_of_experts
-from vllm.model_executor.models.interfaces_base import (VllmModelForPooling,
-                                                         is_pooling_model)
+from vllm.model_executor.models.interfaces import (is_mixture_of_experts,
+                                                    supports_transcription)
+from vllm.model_executor.models.interfaces_base import (
+    VllmModelForPooling, is_pooling_model, is_text_generation_model)
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
 from vllm.multimodal.utils import group_mm_inputs_by_modality
-from vllm.pooling_params import PoolingParams, PoolingTask
+from vllm.pooling_params import PoolingParams
 from vllm.sampling_params import SamplingType
 from vllm.sequence import IntermediateTensors, PoolerOutput
+from vllm.tasks import GenerationTask, PoolingTask, SupportedTask
 from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
                         GiB_bytes, LazyLoader, check_use_alibi, get_dtype_size,
                         is_pin_memory_available, round_up)
@@ -1153,6 +1155,21 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
     def get_model(self) -> nn.Module:
         return self.model

+    def get_supported_generation_tasks(self) -> list[GenerationTask]:
+        model = self.get_model()
+        supported_tasks = list[GenerationTask]()
+
+        if is_text_generation_model(model):
+            supported_tasks.append("generate")
+
+        if supports_transcription(model):
+            if model.supports_transcription_only:
+                return ["transcription"]
+
+            supported_tasks.append("transcription")
+
+        return supported_tasks
+
     def get_supported_pooling_tasks(self) -> list[PoolingTask]:
         model = self.get_model()
         if not is_pooling_model(model):
@@ -1160,6 +1177,16 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):

         return list(model.pooler.get_supported_tasks())

+    def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
+        tasks = list[SupportedTask]()
+
+        if self.model_config.runner_type == "generate":
+            tasks.extend(self.get_supported_generation_tasks())
+        if self.model_config.runner_type == "pooling":
+            tasks.extend(self.get_supported_pooling_tasks())
+
+        return tuple(tasks)
+
     def apply_grammar_bitmask(
         self,
         scheduler_output: "SchedulerOutput",
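The generation-task branch above short-circuits for transcription-only models. A stubbed, standalone sketch of the same dispatch, where the boolean flags stand in for the `is_text_generation_model` / `supports_transcription` / `supports_transcription_only` checks:

```python
def supported_generation_tasks(is_text_gen: bool,
                               supports_asr: bool,
                               asr_only: bool) -> list[str]:
    tasks: list[str] = []
    if is_text_gen:
        tasks.append("generate")
    if supports_asr:
        if asr_only:
            # Transcription-only models expose no other generation task.
            return ["transcription"]
        tasks.append("transcription")
    return tasks

print(supported_generation_tasks(True, False, False))  # ['generate']
print(supported_generation_tasks(False, True, True))   # ['transcription']
print(supported_generation_tasks(True, True, False))   # ['generate', 'transcription']
```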
@@ -23,8 +23,8 @@ from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
 from vllm.model_executor import set_random_seed
 from vllm.platforms import current_platform
-from vllm.pooling_params import PoolingTask
 from vllm.sequence import IntermediateTensors
+from vllm.tasks import SupportedTask
 from vllm.utils import GiB_bytes, MemorySnapshot, memory_profiling
 from vllm.v1.engine import ReconfigureDistributedRequest, ReconfigureRankType
 from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
@@ -320,8 +320,8 @@ class Worker(WorkerBase):
     def get_model(self) -> nn.Module:
         return self.model_runner.get_model()

-    def get_supported_pooling_tasks(self) -> list[PoolingTask]:
-        return self.model_runner.get_supported_pooling_tasks()
+    def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
+        return self.model_runner.get_supported_tasks()

     @torch.inference_mode()
     def execute_model(
@@ -27,13 +27,15 @@ from vllm.logger import init_logger
 from vllm.lora.layers import BaseLayerWithLoRA
 from vllm.model_executor.model_loader import get_model_loader
 from vllm.model_executor.model_loader.tpu import TPUModelLoader
-from vllm.model_executor.models.interfaces_base import is_pooling_model
+from vllm.model_executor.models.interfaces import supports_transcription
+from vllm.model_executor.models.interfaces_base import (
+    is_pooling_model, is_text_generation_model)
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.inputs import (BatchedTensorInputs, MultiModalKwargs,
                                     PlaceholderRange)
 from vllm.multimodal.utils import group_mm_inputs_by_modality
-from vllm.pooling_params import PoolingTask
 from vllm.sequence import IntermediateTensors
+from vllm.tasks import GenerationTask, PoolingTask, SupportedTask
 from vllm.utils import (LayerBlockType, cdiv, is_pin_memory_available,
                         prev_power_of_2)
 from vllm.v1.attention.backends.pallas import (TPU_STR_DTYPE_TO_TORCH_DTYPE,
|
||||
def get_model(self) -> nn.Module:
|
||||
return self.model
|
||||
|
||||
def get_supported_generation_tasks(self) -> list[GenerationTask]:
|
||||
model = self.get_model()
|
||||
supported_tasks = list[GenerationTask]()
|
||||
|
||||
if is_text_generation_model(model):
|
||||
supported_tasks.append("generate")
|
||||
|
||||
if supports_transcription(model):
|
||||
if model.supports_transcription_only:
|
||||
return ["transcription"]
|
||||
|
||||
supported_tasks.append("transcription")
|
||||
|
||||
return supported_tasks
|
||||
|
||||
def get_supported_pooling_tasks(self) -> list[PoolingTask]:
|
||||
model = self.get_model()
|
||||
if not is_pooling_model(model):
|
||||
@@ -496,6 +513,16 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):

         return list(model.pooler.get_supported_tasks())

+    def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
+        tasks = list[SupportedTask]()
+
+        if self.model_config.runner_type == "generate":
+            tasks.extend(self.get_supported_generation_tasks())
+        if self.model_config.runner_type == "pooling":
+            tasks.extend(self.get_supported_pooling_tasks())
+
+        return tuple(tasks)
+
     def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
         """
         Generates the KVCacheSpec by parsing the kv cache format from each
@@ -21,7 +21,7 @@ from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
 from vllm.model_executor import set_random_seed
 from vllm.platforms import current_platform
-from vllm.pooling_params import PoolingTask
+from vllm.tasks import SupportedTask
 from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, cdiv
 from vllm.v1.attention.backends.pallas import TPU_HEAD_SIZE_ALIGNMENT
 from vllm.v1.core.sched.output import SchedulerOutput
@@ -282,8 +282,8 @@ class TPUWorker:
     def get_model(self) -> nn.Module:
         return self.model_runner.get_model()

-    def get_supported_pooling_tasks(self) -> list[PoolingTask]:
-        return self.model_runner.get_supported_pooling_tasks()
+    def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
+        return self.model_runner.get_supported_tasks()

     def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
         return self.model_runner.get_kv_cache_spec()
@@ -12,9 +12,11 @@ import torch.nn as nn
 from vllm.config import VllmConfig
 from vllm.logger import init_logger
 from vllm.model_executor.layers.sampler import SamplerOutput
-from vllm.model_executor.models.interfaces_base import is_pooling_model
-from vllm.pooling_params import PoolingTask
+from vllm.model_executor.models.interfaces import supports_transcription
+from vllm.model_executor.models.interfaces_base import (
+    is_pooling_model, is_text_generation_model)
 from vllm.sequence import IntermediateTensors, SequenceGroupMetadata
+from vllm.tasks import GenerationTask, PoolingTask, SupportedTask

 if TYPE_CHECKING:
     from vllm.attention import AttentionMetadata
@@ -224,6 +226,21 @@ class ModelRunnerBase(ABC, Generic[T]):
     def get_model(self) -> nn.Module:
         raise NotImplementedError

+    def get_supported_generation_tasks(self) -> list[GenerationTask]:
+        model = self.get_model()
+        supported_tasks = list[GenerationTask]()
+
+        if is_text_generation_model(model):
+            supported_tasks.append("generate")
+
+        if supports_transcription(model):
+            if model.supports_transcription_only:
+                return ["transcription"]
+
+            supported_tasks.append("transcription")
+
+        return supported_tasks
+
     def get_supported_pooling_tasks(self) -> list[PoolingTask]:
         model = self.get_model()
         if not is_pooling_model(model):
@@ -231,6 +248,16 @@ class ModelRunnerBase(ABC, Generic[T]):

         return list(model.pooler.get_supported_tasks())

+    def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
+        tasks = list[SupportedTask]()
+
+        if self.model_config.runner_type == "generate":
+            tasks.extend(self.get_supported_generation_tasks())
+        if self.model_config.runner_type == "pooling":
+            tasks.extend(self.get_supported_pooling_tasks())
+
+        return tuple(tasks)
+
     def execute_model(
         self,
         model_input: T,
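Taken together, the commit wires a single delegation chain from the front ends down to the model runner: LLM / api_server / run_batch ask the engine, the engine core asks the executor, and the executor asks the workers' model runners. A toy sketch of that chain (the classes below are stand-ins, not vLLM types):

```python
class ToyRunner:
    def get_supported_tasks(self):
        # In vLLM this inspects the loaded model (text generation, pooling, ASR).
        return ("generate",)

class ToyWorker:
    def __init__(self):
        self.model_runner = ToyRunner()

    def get_supported_tasks(self):
        return self.model_runner.get_supported_tasks()

class ToyExecutor:
    def __init__(self):
        self.workers = [ToyWorker(), ToyWorker()]

    @property
    def supported_tasks(self):
        # Stand-in for collective_rpc("get_supported_tasks"): every worker
        # hosts the same model, so the first reply is taken.
        return [w.get_supported_tasks() for w in self.workers][0]

class ToyEngineCore:
    def __init__(self):
        self.model_executor = ToyExecutor()

    def get_supported_tasks(self):
        return self.model_executor.supported_tasks

print(ToyEngineCore().get_supported_tasks())  # ('generate',)
```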