Mirror of https://github.com/vllm-project/vllm.git (synced 2025-10-20 14:53:52 +08:00)
[Misc] Simplify PoolerOutput and move to v1/outputs (#25629)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
@@ -15,10 +15,10 @@ from vllm.config import VllmConfig
|
||||
from vllm.distributed.kv_transfer.kv_connector.utils import KVOutputAggregator
|
||||
from vllm.logger import init_logger
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.sequence import ExecuteModelRequest, PoolerOutput
|
||||
from vllm.sequence import ExecuteModelRequest
|
||||
from vllm.tasks import SupportedTask
|
||||
from vllm.utils import make_async
|
||||
from vllm.v1.outputs import SamplerOutput
|
||||
from vllm.v1.outputs import PoolerOutput, SamplerOutput
|
||||
from vllm.worker.worker_base import WorkerBase
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
@@ -16,9 +16,9 @@ from vllm.config import ModelConfig, PoolerConfig, get_current_vllm_config
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.models.adapters import _load_st_projector
|
||||
from vllm.pooling_params import PoolingParams
|
||||
from vllm.sequence import PoolerOutput, PoolingSequenceGroupOutput
|
||||
from vllm.tasks import PoolingTask
|
||||
from vllm.utils import current_stream, resolve_obj_by_qualname
|
||||
from vllm.utils import resolve_obj_by_qualname
|
||||
from vllm.v1.outputs import PoolerOutput
|
||||
from vllm.v1.pool.metadata import PoolingCursor, PoolingMetadata
|
||||
|
||||
logger = init_logger(__name__)
|
||||
@@ -190,19 +190,6 @@ def get_cross_encoder_activation_function(config: PretrainedConfig):
|
||||
return PoolerClassify()
|
||||
|
||||
|
||||
def build_output(
|
||||
all_data: Union[torch.Tensor, list[torch.Tensor]], ) -> PoolerOutput:
|
||||
# Pooling models D2H & synchronize occurs here
|
||||
if isinstance(all_data, list):
|
||||
all_data = [d.to("cpu", non_blocking=True) for d in all_data]
|
||||
else:
|
||||
all_data = all_data.to("cpu", non_blocking=True)
|
||||
current_stream().synchronize()
|
||||
|
||||
all_outputs = [PoolingSequenceGroupOutput(data) for data in all_data]
|
||||
return PoolerOutput(outputs=all_outputs)
|
||||
|
||||
|
||||
class PoolingMethod(nn.Module, ABC):
|
||||
|
||||
@staticmethod
|
||||
@@ -556,7 +543,7 @@ class SimplePooler(Pooler):
|
||||
) -> PoolerOutput:
|
||||
pooled_data = self.pooling(hidden_states, pooling_metadata)
|
||||
pooled_data = self.head(pooled_data, pooling_metadata)
|
||||
return build_output(pooled_data)
|
||||
return pooled_data
|
||||
|
||||
|
||||
class StepPooler(Pooler):
|
||||
@@ -607,7 +594,7 @@ class StepPooler(Pooler):
|
||||
) -> PoolerOutput:
|
||||
pooled_data = self.extract_states(hidden_states, pooling_metadata)
|
||||
pooled_data = self.head(pooled_data, pooling_metadata)
|
||||
return build_output(pooled_data)
|
||||
return pooled_data
|
||||
|
||||
|
||||
class ClassifierPooler(Pooler):
|
||||
@@ -678,7 +665,7 @@ class ClassifierPooler(Pooler):
|
||||
]
|
||||
|
||||
# scores shape: [batchsize, num_labels]
|
||||
return build_output(scores)
|
||||
return scores
|
||||
|
||||
|
||||
class DispatchPooler(Pooler):
|
||||
@@ -708,7 +695,7 @@ class DispatchPooler(Pooler):
|
||||
) -> PoolerOutput:
|
||||
poolers_by_task = self.poolers_by_task
|
||||
|
||||
outputs = list[PoolingSequenceGroupOutput]()
|
||||
outputs = list[torch.Tensor]()
|
||||
offset = 0
|
||||
for task, group in groupby(get_tasks(pooling_metadata)):
|
||||
if not (pooler := poolers_by_task.get(task)):
|
||||
@@ -722,10 +709,10 @@ class DispatchPooler(Pooler):
|
||||
pooling_metadata[offset:offset + num_items],
|
||||
)
|
||||
|
||||
outputs.extend(group_output.outputs)
|
||||
outputs.extend(group_output)
|
||||
offset += num_items
|
||||
|
||||
return PoolerOutput(outputs)
|
||||
return outputs
|
||||
|
||||
def extra_repr(self) -> str:
|
||||
s = f"supported_task={self.get_supported_tasks()}"
|
||||
|
@@ -12,12 +12,12 @@ from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.pooler import (DispatchPooler, Pooler,
|
||||
PoolerHead, PoolerNormalize,
|
||||
PoolingParamsUpdate,
|
||||
build_output, get_prompt_lens,
|
||||
get_prompt_lens,
|
||||
get_prompt_token_ids)
|
||||
from vllm.model_executor.models.llama import LlamaForCausalLM
|
||||
from vllm.sequence import PoolerOutput
|
||||
from vllm.tasks import PoolingTask
|
||||
from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config
|
||||
from vllm.v1.outputs import PoolerOutput
|
||||
from vllm.v1.pool.metadata import PoolingMetadata
|
||||
|
||||
from .interfaces_base import default_pooling_type
|
||||
@@ -212,7 +212,7 @@ class GritLMPooler(Pooler):
|
||||
) -> PoolerOutput:
|
||||
pooled_data = self.pooling(hidden_states, pooling_metadata)
|
||||
pooled_data = self.head(pooled_data, pooling_metadata)
|
||||
return build_output(pooled_data)
|
||||
return pooled_data
|
||||
|
||||
|
||||
@default_pooling_type("MEAN")
|
||||
|
@@ -11,7 +11,6 @@ if TYPE_CHECKING:
|
||||
from vllm.v1.worker.kv_connector_model_runner_mixin import (
|
||||
KVConnectorOutput)
|
||||
else:
|
||||
LoRARequest = Any
|
||||
KVConnectorOutput = Any
|
||||
|
||||
VLLM_TOKEN_ID_ARRAY_TYPE = "l"
|
||||
@@ -48,29 +47,6 @@ class RequestMetrics:
|
||||
model_execute_time: Optional[float] = None
|
||||
|
||||
|
||||
class PoolingSequenceGroupOutput(
|
||||
msgspec.Struct,
|
||||
omit_defaults=True, # type: ignore[call-arg]
|
||||
array_like=True, # type: ignore[call-arg]
|
||||
):
|
||||
"""The model output associated with a pooling sequence group."""
|
||||
# Annotated as Any to be compatible with msgspec
|
||||
# The actual type is in SequenceGroup.pooled_data
|
||||
data: Any
|
||||
|
||||
def get_data_nbytes(self) -> int:
|
||||
data: torch.Tensor = self.data
|
||||
return data.nbytes
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"PoolingSequenceGroupOutput(data={self.data}"
|
||||
|
||||
def __eq__(self, other: object) -> bool:
|
||||
if not isinstance(other, PoolingSequenceGroupOutput):
|
||||
raise NotImplementedError()
|
||||
return self.data == other.data
|
||||
|
||||
|
||||
# cannot use msgspec.Struct here because Dynamo does not support it
|
||||
@dataclass
|
||||
class IntermediateTensors:
|
||||
@@ -119,30 +95,6 @@ class IntermediateTensors:
|
||||
return f"IntermediateTensors(tensors={self.tensors})"
|
||||
|
||||
|
||||
class PoolerOutput(
|
||||
msgspec.Struct,
|
||||
omit_defaults=True, # type: ignore[call-arg]
|
||||
array_like=True): # type: ignore[call-arg]
|
||||
"""The output from a pooling operation in the pooling model."""
|
||||
outputs: list[PoolingSequenceGroupOutput]
|
||||
|
||||
def get_data_nbytes(self) -> int:
|
||||
return sum(o.get_data_nbytes() for o in self.outputs)
|
||||
|
||||
def __getitem__(self, idx: int) -> PoolingSequenceGroupOutput:
|
||||
return self.outputs[idx]
|
||||
|
||||
def __setitem__(self, idx: int, value: PoolingSequenceGroupOutput):
|
||||
self.outputs[idx] = value
|
||||
|
||||
def __len__(self):
|
||||
return len(self.outputs)
|
||||
|
||||
def __eq__(self, other: object):
|
||||
return isinstance(other,
|
||||
self.__class__) and self.outputs == other.outputs
|
||||
|
||||
|
||||
class ExecuteModelRequest(
|
||||
msgspec.Struct,
|
||||
array_like=True, # type: ignore[call-arg]
|
||||
|
@@ -3,7 +3,7 @@
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, NamedTuple, Optional
|
||||
from typing import TYPE_CHECKING, NamedTuple, Optional, Union
|
||||
|
||||
import torch
|
||||
|
||||
@@ -65,6 +65,11 @@ class LogprobsTensors(NamedTuple):
|
||||
)
|
||||
|
||||
|
||||
# [num_reqs, <dynamic>]
|
||||
# The shape of each element depends on the pooler used
|
||||
PoolerOutput = Union[torch.Tensor, list[torch.Tensor]]
|
||||
|
||||
|
||||
@dataclass
|
||||
class SamplerOutput:
|
||||
|
||||
|
@@ -52,13 +52,14 @@ from vllm.multimodal.inputs import (BatchedTensorInputs, MultiModalKwargsItem,
|
||||
from vllm.multimodal.utils import group_mm_kwargs_by_modality
|
||||
from vllm.pooling_params import PoolingParams
|
||||
from vllm.sampling_params import SamplingType
|
||||
from vllm.sequence import IntermediateTensors, PoolerOutput
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.tasks import GenerationTask, PoolingTask, SupportedTask
|
||||
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
|
||||
GiB_bytes, cdiv, check_use_alibi, get_dtype_size,
|
||||
is_pin_memory_available,
|
||||
length_from_prompt_token_ids_or_embeds, round_up,
|
||||
supports_dynamo)
|
||||
from vllm.utils.jsontree import json_map_leaves
|
||||
from vllm.v1.attention.backends.flash_attn import AttentionMetadata
|
||||
from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadataBuilder
|
||||
from vllm.v1.attention.backends.utils import (
|
||||
@@ -79,7 +80,7 @@ from vllm.v1.kv_cache_interface import (AttentionSpec,
|
||||
# yapf: enable
|
||||
from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, AsyncModelRunnerOutput,
|
||||
DraftTokenIds, LogprobsLists, LogprobsTensors,
|
||||
ModelRunnerOutput, SamplerOutput)
|
||||
ModelRunnerOutput, PoolerOutput, SamplerOutput)
|
||||
from vllm.v1.pool.metadata import PoolingMetadata
|
||||
from vllm.v1.sample.logits_processor import LogitsProcessors, build_logitsprocs
|
||||
from vllm.v1.sample.metadata import SamplingMetadata
|
||||
@@ -1823,15 +1824,22 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
device=hidden_states.device)
|
||||
seq_lens_cpu = self.seq_lens.cpu[:self.input_batch.num_reqs]
|
||||
|
||||
# Pooling models D2H & synchronize occurs in pooler.py:build_output
|
||||
raw_pooler_output = self.model.pooler(
|
||||
hidden_states=hidden_states, pooling_metadata=pooling_metadata)
|
||||
model = cast(VllmModelForPooling, self.model)
|
||||
raw_pooler_output: PoolerOutput = model.pooler(
|
||||
hidden_states=hidden_states,
|
||||
pooling_metadata=pooling_metadata,
|
||||
)
|
||||
raw_pooler_output = json_map_leaves(
|
||||
lambda x: x.to("cpu", non_blocking=True),
|
||||
raw_pooler_output,
|
||||
)
|
||||
self._sync_device()
|
||||
|
||||
pooler_output: list[Optional[torch.Tensor]] = []
|
||||
for raw_output, seq_len, prompt_len in zip(
|
||||
raw_pooler_output, seq_lens_cpu, pooling_metadata.prompt_lens):
|
||||
|
||||
output = raw_output.data if seq_len == prompt_len else None
|
||||
output = raw_output if seq_len == prompt_len else None
|
||||
pooler_output.append(output)
|
||||
|
||||
return ModelRunnerOutput(
|
||||
@@ -3233,7 +3241,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
for task in self.get_supported_pooling_tasks():
|
||||
# Run a full batch with each task to ensure none of them OOMs
|
||||
output = self._dummy_pooler_run_task(hidden_states, task)
|
||||
output_size[task] = output.get_data_nbytes()
|
||||
output_size[task] = sum(o.nbytes for o in output)
|
||||
del output # Allow GC
|
||||
|
||||
max_task = max(output_size.items(), key=lambda x: x[1])[0]
|
||||
|
Reference in New Issue
Block a user