Mirror of https://github.com/vllm-project/vllm.git (synced 2025-10-20 14:53:52 +08:00)
[V0 Deprecation] Remove V0 Sequence class & Sampler (#25332)

Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
Signed-off-by: Woosuk Kwon <woosuk@thinkingmachines.ai>
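Every hunk in this diff follows the same migration pattern: symbols that used to live in the V0 modules move to their V1 homes. A minimal before/after sketch, using only module paths that appear in the hunks below:

# Old V0 imports removed by this commit:
#   from vllm.sequence import Logprob, PromptLogprobs, SampleLogprobs
#   from vllm.model_executor.layers.sampler import SamplerOutput
#   from vllm.transformers_utils.detokenizer import detokenize_incrementally
# Their replacements:
from vllm.logprobs import Logprob, PromptLogprobs, SampleLogprobs
from vllm.v1.outputs import SamplerOutput
from vllm.transformers_utils.detokenizer_utils import detokenize_incrementally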
@@ -48,10 +48,10 @@ from vllm.distributed import (cleanup_dist_env_and_memory,
                               initialize_model_parallel)
 from vllm.inputs import TextPrompt
 from vllm.logger import init_logger
+from vllm.logprobs import Logprob
 from vllm.multimodal.utils import fetch_image
 from vllm.outputs import RequestOutput
 from vllm.sampling_params import BeamSearchParams
-from vllm.sequence import Logprob
 from vllm.transformers_utils.utils import maybe_model_redirect
 from vllm.utils import set_default_torch_num_threads
@ -7,8 +7,8 @@ from typing import Optional
|
||||
import pytest
|
||||
from transformers import AutoModelForSpeechSeq2Seq
|
||||
|
||||
from vllm.logprobs import SampleLogprobs
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.sequence import SampleLogprobs
|
||||
|
||||
from ....conftest import (AudioTestAssets, HfRunner, PromptAudioInput,
|
||||
VllmRunner)
|
||||
|
@@ -12,10 +12,10 @@ from huggingface_hub import snapshot_download
 from transformers import AutoTokenizer

 from vllm.assets.image import ImageAsset
+from vllm.logprobs import SampleLogprobs
 from vllm.lora.request import LoRARequest
 from vllm.multimodal.image import convert_image_mode, rescale_image_size
 from vllm.platforms import current_platform
-from vllm.sequence import SampleLogprobs

 from ....conftest import (IMAGE_ASSETS, HfRunner, PromptAudioInput,
                           PromptImageInput, VllmRunner)
@@ -13,8 +13,8 @@ from mistral_common.tokens.tokenizers.multimodal import image_from_chunk
 from transformers import AutoProcessor

 from vllm import SamplingParams, TextPrompt, TokensPrompt
+from vllm.logprobs import Logprob, SampleLogprobs
 from vllm.multimodal import MultiModalDataBuiltins
-from vllm.sequence import Logprob, SampleLogprobs

 from ....utils import VLLM_PATH, large_gpu_test
 from ...utils import check_logprobs_close
@@ -19,7 +19,7 @@ from transformers import (AutoConfig, AutoTokenizer, BatchFeature,
                           GenerationConfig, GenerationMixin)
 from transformers.video_utils import VideoMetadata

-from vllm.sequence import SampleLogprobs
+from vllm.logprobs import SampleLogprobs
 from vllm.utils import is_list_of

 from .....conftest import HfRunner, ImageAsset, ImageTestAssets
@@ -12,7 +12,7 @@ from transformers import AutoModelForCausalLM
 from transformers.models.auto.auto_factory import _BaseAutoModelClass

 from vllm.config import RunnerOption
-from vllm.sequence import SampleLogprobs
+from vllm.logprobs import SampleLogprobs
 from vllm.transformers_utils.tokenizer import AnyTokenizer

 from .....conftest import (AUDIO_ASSETS, IMAGE_ASSETS, HfRunner, ImageAsset,
@@ -12,7 +12,7 @@ from transformers import PretrainedConfig

 from vllm.config import ModelConfig, ModelDType, RunnerOption
 from vllm.inputs import InputContext
-from vllm.sequence import Logprob, PromptLogprobs, SampleLogprobs
+from vllm.logprobs import Logprob, PromptLogprobs, SampleLogprobs

 from .registry import HF_EXAMPLE_MODELS
@@ -8,10 +8,7 @@ import pytest
 from transformers import (AutoTokenizer, PreTrainedTokenizer,
                           PreTrainedTokenizerFast)

-from vllm.inputs import token_inputs
-from vllm.sequence import Logprob, SamplingParams, Sequence, SequenceGroup
-from vllm.transformers_utils.detokenizer import Detokenizer
-from vllm.transformers_utils.tokenizer import get_tokenizer
+from vllm.sampling_params import SamplingParams
 from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer
 from vllm.v1.engine import EngineCoreRequest
 from vllm.v1.engine.detokenizer import (FastIncrementalDetokenizer,
@@ -217,138 +214,3 @@ def test_oov_decode(tokenizer, fast):

     assert decoded_text == ''
     assert out_ids == [len(tokenizer)]
-
-
-@pytest.fixture
-def detokenizer(tokenizer_name: str) -> Detokenizer:
-    tokenizer = get_tokenizer(
-        tokenizer_name,
-        tokenizer_mode="mistral" if "mistral" in tokenizer_name else "auto",
-        trust_remote_code=False,
-        revision=None,
-    )
-
-    return Detokenizer(tokenizer)
-
-
-@pytest.fixture(name="complete_sequence_token_ids")
-def create_complete_sequence_token_ids(complete_sequence: str,
-                                       tokenizer) -> list[int]:
-    return tokenizer(complete_sequence, add_special_tokens=False).input_ids
-
-
-def create_sequence(prompt_token_ids=None):
-    prompt_token_ids = prompt_token_ids or []
-    return Sequence(
-        seq_id=0,
-        inputs=token_inputs(prompt_token_ids),
-        block_size=16,
-    )
-
-
-def create_dummy_logprobs(
-        complete_sequence_token_ids: list[int]) -> list[dict[int, Logprob]]:
-    return [{
-        token_id: Logprob(logprob=0.0),
-        token_id + 1: Logprob(logprob=0.1)
-    } for token_id in complete_sequence_token_ids]
-
-
-def create_dummy_prompt_logprobs(
-        complete_sequence_token_ids: list[int]
-) -> list[Optional[dict[int, Any]]]:
-    # logprob for the first prompt token is None.
-    logprobs: list[Optional[dict[int, Any]]] = [None]
-    logprobs.extend(create_dummy_logprobs(complete_sequence_token_ids)[1:])
-    return logprobs
-
-
-@pytest.mark.parametrize("complete_sequence", TRUTH)
-@pytest.mark.parametrize("tokenizer_name", TOKENIZERS)
-@pytest.mark.parametrize("skip_special_tokens", [True, False], indirect=True)
-def test_decode_sequence_logprobs(complete_sequence: str,
-                                  complete_sequence_token_ids: list[int],
-                                  detokenizer: Detokenizer,
-                                  skip_special_tokens: bool):
-    """Verify Detokenizer decodes logprobs correctly."""
-    sampling_params = SamplingParams(skip_special_tokens=skip_special_tokens,
-                                     logprobs=2)
-
-    # Run sequentially.
-    seq = create_sequence()
-    dummy_logprobs = create_dummy_logprobs(complete_sequence_token_ids)
-    sequential_logprobs_text_chosen_token: list[str] = []
-    sequential_logprobs_text_other_token: list[str] = []
-    for new_token, logprobs in zip(complete_sequence_token_ids,
-                                   dummy_logprobs):
-        seq.append_token_id(new_token, logprobs)
-        detokenizer.decode_sequence_inplace(seq, sampling_params)
-        sequential_logprobs_text_chosen_token.append(
-            seq.output_logprobs[-1][new_token].decoded_token)
-        sequential_logprobs_text_other_token.append(
-            seq.output_logprobs[-1][new_token + 1].decoded_token)
-    sequential_result = seq.output_text
-
-    assert sequential_result == "".join(sequential_logprobs_text_chosen_token)
-    assert sequential_result != "".join(sequential_logprobs_text_other_token)
-
-    if not skip_special_tokens:
-        # Text for logprobs for the chosen token should be the same as the
-        # generated text. Note that this will only be true if we skip
-        # special tokens.
-        assert sequential_result == complete_sequence
-
-
-@pytest.mark.parametrize("complete_sequence", TRUTH)
-@pytest.mark.parametrize("tokenizer_name", TOKENIZERS)
-def test_decode_prompt_logprobs(complete_sequence: str,
-                                complete_sequence_token_ids: list[int],
-                                detokenizer: Detokenizer):
-
-    # We want to use skip_special_tokens=False here but Mistral tokenizers
-    # don't support that.
-    if complete_sequence not in SPECIAL_TOKS_TRUTH:
-        skip_special_tokens = True
-    elif not isinstance(detokenizer.tokenizer, MistralTokenizer):
-        skip_special_tokens = False
-    else:
-        pytest.skip("MistralTokenizers don't support "
-                    "skip_special_tokens=False")
-        return
-    """Verify Detokenizer decodes prompt logprobs correctly."""
-    sampling_params = SamplingParams(skip_special_tokens=skip_special_tokens,
-                                     prompt_logprobs=1)
-
-    # Run sequentially.
-    seq = create_sequence(complete_sequence_token_ids)
-    seq_group = SequenceGroup(request_id="1",
-                              seqs=[seq],
-                              sampling_params=sampling_params,
-                              arrival_time=0.0)
-    dummy_logprobs = create_dummy_prompt_logprobs(complete_sequence_token_ids)
-    detokenizer.decode_prompt_logprobs_inplace(seq_group,
-                                               dummy_logprobs,
-                                               position_offset=0)
-    # First logprob is None.
-    decoded_prompt_logprobs: list[dict[int, Any]] = dummy_logprobs[
-        1:]  # type: ignore
-
-    # decoded_prompt_logprobs doesn't contain the first token.
-    token_ids = complete_sequence_token_ids
-    tokenizer = detokenizer.tokenizer
-    text_full = tokenizer.decode(token_ids,
-                                 skip_special_tokens=skip_special_tokens)
-    text_first = tokenizer.decode(token_ids[0],
-                                  skip_special_tokens=skip_special_tokens)
-    text = text_full[len(text_first):]
-
-    # Text for logprobs for the chosen token should be the same as the
-    # prompt text. Note that the first logprob is None.
-    assert text == "".join([
-        logprobs[token_id].decoded_token
-        for token_id, logprobs in zip(token_ids[1:], decoded_prompt_logprobs)
-    ])
-    assert text != "".join([
-        logprobs[token_id + 1].decoded_token
-        for token_id, logprobs in zip(token_ids[1:], decoded_prompt_logprobs)
-    ])
@@ -12,7 +12,7 @@ from partial_json_parser.core.options import Allow
 from vllm.entrypoints.openai.protocol import (DeltaMessage, FunctionCall,
                                               ToolCall)
 from vllm.entrypoints.openai.tool_parsers import JambaToolParser
-from vllm.transformers_utils.detokenizer import detokenize_incrementally
+from vllm.transformers_utils.detokenizer_utils import detokenize_incrementally
 from vllm.transformers_utils.tokenizer import AnyTokenizer, get_tokenizer

 MODEL = "ai21labs/Jamba-tiny-dev"
@@ -13,7 +13,7 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
                                               ToolCall)
 from vllm.entrypoints.openai.tool_parsers.qwen3coder_tool_parser import (
     Qwen3CoderToolParser)
-from vllm.transformers_utils.detokenizer import detokenize_incrementally
+from vllm.transformers_utils.detokenizer_utils import detokenize_incrementally
 from vllm.transformers_utils.tokenizer import AnyTokenizer, get_tokenizer

 MODEL = "Qwen/Qwen3-Coder-30B-A3B-Instruct-FP8"
@@ -13,7 +13,7 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
                                               DeltaMessage, FunctionCall,
                                               ToolCall)
 from vllm.entrypoints.openai.tool_parsers import SeedOssToolParser
-from vllm.transformers_utils.detokenizer import detokenize_incrementally
+from vllm.transformers_utils.detokenizer_utils import detokenize_incrementally
 from vllm.transformers_utils.tokenizer import AnyTokenizer, get_tokenizer

 # Use a common model that is likely to be available
@@ -11,7 +11,7 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
                                               DeltaMessage, FunctionCall,
                                               ToolCall)
 from vllm.entrypoints.openai.tool_parsers import xLAMToolParser
-from vllm.transformers_utils.detokenizer import detokenize_incrementally
+from vllm.transformers_utils.detokenizer_utils import detokenize_incrementally
 from vllm.transformers_utils.tokenizer import AnyTokenizer, get_tokenizer

 # Use a common model that is likely to be available
@@ -12,9 +12,9 @@ from tests.v1.engine.utils import (NUM_PROMPT_LOGPROBS_UNDER_TEST,
                                    STOP_STRINGS,
                                    DummyOutputProcessorTestVectors,
                                    MockEngineCore)
+from vllm.logprobs import PromptLogprobs, SampleLogprobs
 from vllm.outputs import CompletionOutput, RequestOutput
 from vllm.sampling_params import RequestOutputKind, SamplingParams
-from vllm.sequence import PromptLogprobs, SampleLogprobs
 from vllm.transformers_utils.tokenizer import AnyTokenizer
 from vllm.v1.engine import EngineCoreRequest
 from vllm.v1.engine.output_processor import (OutputProcessor,
@@ -15,10 +15,10 @@ from vllm.config import VllmConfig
 from vllm.distributed.kv_transfer.kv_connector.utils import KVOutputAggregator
 from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
-from vllm.model_executor.layers.sampler import SamplerOutput
 from vllm.sequence import ExecuteModelRequest, PoolerOutput
 from vllm.tasks import SupportedTask
 from vllm.utils import make_async
+from vllm.v1.outputs import SamplerOutput
 from vllm.worker.worker_base import WorkerBase

 logger = init_logger(__name__)
@@ -17,12 +17,12 @@ from vllm.executor.msgspec_utils import encode_hook
 from vllm.executor.ray_utils import (RayWorkerWrapper, initialize_ray_cluster,
                                      ray)
 from vllm.logger import init_logger
-from vllm.model_executor.layers.sampler import SamplerOutput
 from vllm.platforms import current_platform
 from vllm.ray.ray_env import get_env_vars_to_copy
 from vllm.sequence import ExecuteModelRequest
 from vllm.utils import (_run_task_with_lock, get_distributed_init_method,
                         get_ip, get_open_port, make_async)
+from vllm.v1.outputs import SamplerOutput

 if ray is not None:
     from ray.actor import ActorHandle
@@ -7,15 +7,7 @@ from .data import (DataPrompt, DecoderOnlyInputs, EmbedsInputs, EmbedsPrompt,
                    SingletonPrompt, TextPrompt, TokenInputs, TokensPrompt,
                    build_explicit_enc_dec_prompt, embeds_inputs,
                    to_enc_dec_tuple_list, token_inputs, zip_enc_dec_prompts)
-from .registry import (DummyData, InputContext, InputProcessingContext,
-                       InputRegistry)
-
-INPUT_REGISTRY = InputRegistry()
-"""
-The global [`InputRegistry`][vllm.inputs.registry.InputRegistry] which is used
-by [`LLMEngine`][vllm.LLMEngine] to dispatch data processing according to the
-target model.
-"""
+from .registry import InputContext, InputProcessingContext

 __all__ = [
     "DataPrompt",
@@ -36,9 +28,6 @@ __all__ = [
     "build_explicit_enc_dec_prompt",
     "to_enc_dec_tuple_list",
     "zip_enc_dec_prompts",
-    "INPUT_REGISTRY",
-    "DummyData",
     "InputContext",
     "InputProcessingContext",
-    "InputRegistry",
 ]
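After the two hunks above, vllm.inputs no longer exports the V0 profiling registry. A hedged sketch of what still imports cleanly versus what now fails, with names taken from the hunks above:

# Still available from vllm.inputs:
from vllm.inputs import TextPrompt, TokensPrompt, token_inputs
from vllm.inputs import InputContext, InputProcessingContext
# Removed by this commit -- these would now raise ImportError:
#   from vllm.inputs import INPUT_REGISTRY, DummyData, InputRegistry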
@@ -2,7 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 from collections.abc import Mapping
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Any, NamedTuple, Optional, Union
+from typing import TYPE_CHECKING, Any, Union

 import torch
 from transformers import BatchFeature, PretrainedConfig, ProcessorMixin
@@ -15,16 +15,9 @@ from vllm.utils.jsontree import JSONTree, json_map_leaves

 if TYPE_CHECKING:
     from vllm.config import ModelConfig
-    from vllm.multimodal import (MultiModalDataDict, MultiModalPlaceholderDict,
-                                 MultiModalRegistry)
-    from vllm.sequence import SequenceData
     from vllm.transformers_utils.tokenizer import AnyTokenizer
 else:
     ModelConfig = Any
-    MultiModalDataDict = Any
-    MultiModalPlaceholderDict = Any
-    MultiModalRegistry = Any
-    SequenceData = Any
     AnyTokenizer = Any

 _T = TypeVar("_T")
@@ -191,61 +184,3 @@ class InputProcessingContext(InputContext):
                     f"on data={data} with kwargs={allowed_kwargs}")

             raise ValueError(msg) from exc
-
-
-class DummyData(NamedTuple):
-    """
-    Dummy data used for profiling.
-
-    Note: This is only used in V0.
-    """
-
-    seq_data: SequenceData
-    multi_modal_data: Optional[MultiModalDataDict] = None
-    multi_modal_placeholders: Optional[MultiModalPlaceholderDict] = None
-
-
-class InputRegistry:
-    """
-    Note: This is only used in V0.
-    """
-
-    def dummy_data_for_profiling(
-        self,
-        model_config: ModelConfig,
-        seq_len: int,
-        mm_registry: MultiModalRegistry,
-        is_encoder_data: bool = False,
-    ) -> DummyData:
-        """
-        Create dummy data for profiling the memory usage of a model.
-
-        The model is identified by ``model_config``.
-        """
-        # Avoid circular import
-        from vllm.multimodal.cache import processor_only_cache_from_config
-        from vllm.sequence import SequenceData
-
-        if not model_config.is_multimodal_model:
-            seq_data = SequenceData.from_prompt_token_counts((0, seq_len))
-            return DummyData(seq_data=seq_data)
-
-        cache = processor_only_cache_from_config(model_config, mm_registry)
-
-        # Encoder dummy data does not contain multi-modal data
-        if is_encoder_data:
-            enc_data = mm_registry.get_encoder_dummy_data(model_config,
-                                                          seq_len,
-                                                          cache=cache)
-            seq_data = SequenceData.from_seqs(enc_data.prompt_token_ids)
-            return DummyData(seq_data=seq_data)
-
-        dec_data = mm_registry.get_decoder_dummy_data(model_config,
-                                                      seq_len,
-                                                      cache=cache)
-
-        return DummyData(
-            seq_data=SequenceData.from_seqs(dec_data.prompt_token_ids),
-            multi_modal_data=dec_data.multi_modal_data.get_data(),
-            multi_modal_placeholders=dec_data.multi_modal_placeholders,
-        )
@@ -3,13 +3,11 @@

 from vllm.model_executor.parameter import (BasevLLMParameter,
                                            PackedvLLMParameter)
-from vllm.model_executor.sampling_metadata import (SamplingMetadata,
-                                                   SamplingMetadataCache)
+from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.model_executor.utils import set_random_seed

 __all__ = [
     "SamplingMetadata",
-    "SamplingMetadataCache",
     "set_random_seed",
     "BasevLLMParameter",
     "PackedvLLMParameter",
@@ -1,13 +1,10 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 """A layer that compute logits from hidden_stats."""
-import inspect
-from concurrent.futures import ThreadPoolExecutor
 from typing import Optional

 import torch

-import vllm.envs as envs
 from vllm.distributed import (tensor_model_parallel_all_gather,
                               tensor_model_parallel_gather)
 from vllm.model_executor.custom_op import CustomOp
@@ -16,11 +13,6 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.platforms import current_platform

-_logits_processor_threadpool: Optional[ThreadPoolExecutor] = None
-if envs.VLLM_LOGITS_PROCESSOR_THREADS is not None:
-    _logits_processor_threadpool = ThreadPoolExecutor(
-        envs.VLLM_LOGITS_PROCESSOR_THREADS)
-

 @CustomOp.register("logits_processor")
 class LogitsProcessor(CustomOp):
@@ -60,15 +52,10 @@ class LogitsProcessor(CustomOp):
         hidden_states: torch.Tensor,
         sampling_metadata: Optional[SamplingMetadata] = None,
         embedding_bias: Optional[torch.Tensor] = None,
-        prune_hidden_states: bool = True,
     ) -> Optional[torch.Tensor]:
         if self.logits_as_input:
             logits = hidden_states
         else:
-            if sampling_metadata is not None and prune_hidden_states:
-                hidden_states = _prune_hidden_states(hidden_states,
-                                                     sampling_metadata)
-
             # Get the logits for the next tokens.
             logits = self._get_logits(hidden_states, lm_head, embedding_bias)
             if logits is not None:
@@ -79,12 +66,6 @@ class LogitsProcessor(CustomOp):

         if self.scale != 1.0:
             logits *= self.scale

-        # Apply logits processors (if any).
-        if sampling_metadata is not None and \
-                sampling_metadata.seq_groups is not None:
-            logits = _apply_logits_processors(logits, sampling_metadata)
-
         return logits

     def _gather_logits(self, logits: torch.Tensor) -> torch.Tensor:
@@ -125,75 +106,3 @@ class LogitsProcessor(CustomOp):
         s += f", org_vocab_size={self.org_vocab_size}"
         s += f", scale={self.scale}, logits_as_input={self.logits_as_input}"
         return s
-
-
-def _prune_hidden_states(
-    hidden_states: torch.Tensor,
-    sampling_metadata: SamplingMetadata,
-) -> torch.Tensor:
-    # NOTE(kzawora): The if guard is needed for Gaudi - in some scenarios
-    # (warmup, profile_run) we might not have selected_token_indices,
-    # so we skip pruning.
-    if sampling_metadata.selected_token_indices is not None:
-        return hidden_states.index_select(
-            0, sampling_metadata.selected_token_indices)
-    else:
-        return hidden_states
-
-
-def _apply_logits_processors(
-    logits: torch.Tensor,
-    sampling_metadata: SamplingMetadata,
-) -> torch.Tensor:
-    found_logits_processors = False
-    logits_processed = 0
-    logits_row_ids_and_logits_row_futures = []
-    for seq_group in sampling_metadata.seq_groups:
-        seq_ids = seq_group.seq_ids
-        sampling_params = seq_group.sampling_params
-        logits_processors = sampling_params.logits_processors
-        if logits_processors:
-            found_logits_processors = True
-
-            for seq_id, logits_row_idx in zip(seq_ids,
-                                              seq_group.sample_indices):
-                logits_row = logits[logits_row_idx]
-                past_tokens_ids = seq_group.seq_data[seq_id].output_token_ids
-                prompt_tokens_ids = seq_group.seq_data[seq_id].prompt_token_ids
-
-                if _logits_processor_threadpool is not None:
-                    logits_row_ids_and_logits_row_futures.append(
-                        (logits_row_idx,
-                         _logits_processor_threadpool.submit(
-                             _apply_logits_processors_single_seq, logits_row,
-                             logits_processors, past_tokens_ids,
-                             prompt_tokens_ids)))
-                else:
-                    logits[logits_row_idx] = \
-                        _apply_logits_processors_single_seq(
-                            logits_row, logits_processors, past_tokens_ids,
-                            prompt_tokens_ids)
-
-        logits_processed += len(seq_group.sample_indices) + len(
-            seq_group.prompt_logprob_indices)
-
-    for logits_row_idx, future in logits_row_ids_and_logits_row_futures:
-        logits[logits_row_idx] = future.result()
-
-    if found_logits_processors:
-        # verifies that no rows in logits were missed unexpectedly
-        assert logits_processed == logits.shape[0]
-    return logits
-
-
-def _apply_logits_processors_single_seq(logits_row, logits_processors,
-                                        past_tokens_ids,
-                                        prompt_tokens_ids) -> torch.Tensor:
-    for logits_processor in logits_processors:
-        parameters = inspect.signature(logits_processor).parameters
-        if len(parameters) == 3:
-            logits_row = logits_processor(prompt_tokens_ids, past_tokens_ids,
-                                          logits_row)
-        else:
-            logits_row = logits_processor(past_tokens_ids, logits_row)
-    return logits_row
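For reference, the deleted _apply_logits_processors_single_seq dispatched on the callable's arity: three parameters meant (prompt_tokens_ids, past_tokens_ids, logits_row), two meant (past_tokens_ids, logits_row). A hedged sketch of a V0-style per-request logits processor matching that contract (the banned token id is illustrative only, not from this diff):

import torch

BANNED_TOKEN_ID = 42  # illustrative value


def ban_token(past_tokens_ids: list[int],
              logits_row: torch.Tensor) -> torch.Tensor:
    # Two-argument form handled by the removed dispatch code.
    logits_row[BANNED_TOKEN_ID] = -float("inf")
    return logits_row


def ban_token_unless_in_prompt(prompt_tokens_ids: list[int],
                               past_tokens_ids: list[int],
                               logits_row: torch.Tensor) -> torch.Tensor:
    # Three-argument form: the prompt token ids are passed first.
    if BANNED_TOKEN_ID not in prompt_tokens_ids:
        logits_row[BANNED_TOKEN_ID] = -float("inf")
    return logits_row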
[File diff suppressed because it is too large]
@@ -2,18 +2,15 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project

 from collections.abc import Iterable
-from typing import Optional

 import torch
 import torch.nn as nn

 from vllm.config import VllmConfig
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
-from vllm.model_executor.layers.sampler import SamplerOutput
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead)
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
-from vllm.model_executor.sampling_metadata import SamplingMetadata

 from .utils import maybe_prefix
@@ -105,8 +102,10 @@ class Medusa(nn.Module):
         return [block(hidden_states) for block in self.blocks]

     def compute_logits(
-            self, hidden_states: list[torch.Tensor],
-            sampling_metadata: SamplingMetadata) -> list[torch.Tensor]:
+        self,
+        hidden_states: list[torch.Tensor],
+        sampling_metadata,
+    ) -> list[torch.Tensor]:
         logits_lst: list[torch.Tensor] = []

         for hs, lm_head in zip(hidden_states, self.lm_heads):
@@ -130,57 +129,6 @@

         return logits_lst

-    def sample(
-        self,
-        logits: list[torch.Tensor],
-        sampling_metadata: SamplingMetadata,
-    ) -> list[SamplerOutput]:
-        logits = torch.stack(logits, dim=0).float()
-        logprobs = torch.log_softmax(logits, dim=-1)
-        token_ids = logits.argmax(-1)  # support only top-1 for now
-        probs = torch.softmax(logits, dim=-1)
-
-        token_id_list = []
-        token_prob_list = []
-        token_logprob_list = []
-
-        for idx, seq_group in enumerate(sampling_metadata.seq_groups):
-            token_id_list.append(token_ids[:, seq_group.sample_indices])
-            token_prob_list.append(probs[:, seq_group.sample_indices])
-            token_logprob_list.append(logprobs[:, seq_group.sample_indices])
-
-        outputs: list[Optional[SamplerOutput]] = []
-        for idx in range(len(sampling_metadata.seq_groups)):
-            outputs.append(
-                SamplerOutput(
-                    outputs=None,
-                    sampled_token_probs=token_prob_list[idx].squeeze(1),
-                    logprobs=token_logprob_list[idx].squeeze(1),
-                    sampled_token_ids=token_id_list[idx].squeeze(1),
-                ))
-
-        return outputs
-
-    def generate_proposals(
-        self,
-        previous_hidden_states: torch.Tensor,
-        sampling_metadata: SamplingMetadata,
-    ) -> Optional[list[SamplerOutput]]:
-        # During preemption, we may receive an empty tensor (batch_size=0)
-        if previous_hidden_states.size(0) == 0:
-            # Return None to signal the Top1Proposer that no proposals
-            # were generated for this batch, allowing it to handle this
-            # special case appropriately
-            return None
-
-        return self.sample(
-            logits=self.compute_logits(
-                hidden_states=self.forward(previous_hidden_states),
-                sampling_metadata=sampling_metadata,
-            ),
-            sampling_metadata=sampling_metadata,
-        )
-
     def load_weights(self, weights: Iterable[tuple[str,
                                                    torch.Tensor]]) -> set[str]:
         params_dict = dict(self.named_parameters())
@@ -8,9 +8,7 @@ import torch
 import torch.nn as nn

 from vllm.config import VllmConfig
-from vllm.model_executor import SamplingMetadata
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
-from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead, VocabParallelEmbedding)
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
@@ -141,55 +139,57 @@ class MLPSpeculator(nn.Module):
         self.config = config
         self.logits_processor = LogitsProcessor(config.vocab_size,
                                                 config.vocab_size, 1.0)
-        self.sampler = get_sampler()

-    def generate_proposals(
-        self,
-        input_ids: torch.Tensor,
-        previous_hidden_states: torch.Tensor,
-        num_predict_tokens: int,
-        sampling_metadata: SamplingMetadata,
-    ) -> list[SamplerOutput]:
-        if num_predict_tokens > self.max_speculative_tokens:
-            raise ValueError(f"Max speculative tokens for model is "
-                             f"{self.max_speculative_tokens}, but "
-                             f"{num_predict_tokens} were requested")
+    # NOTE(woosuk): This method is commented out because it is old code
+    # using V0. We should either port it to V1 or remove it.

-        # b x 1 x d
-        previous_hidden_states = previous_hidden_states.unsqueeze(1)
+    # def generate_proposals(
+    #     self,
+    #     input_ids: torch.Tensor,
+    #     previous_hidden_states: torch.Tensor,
+    #     num_predict_tokens: int,
+    #     sampling_metadata: SamplingMetadata,
+    # ) -> list[SamplerOutput]:
+    #     if num_predict_tokens > self.max_speculative_tokens:
+    #         raise ValueError(f"Max speculative tokens for model is "
+    #                          f"{self.max_speculative_tokens}, but "
+    #                          f"{num_predict_tokens} were requested")

-        if self.scale_input:
-            previous_hidden_states = self.ln0(previous_hidden_states) / SQRT2
+    #     # b x 1 x d
+    #     previous_hidden_states = previous_hidden_states.unsqueeze(1)

-        # b x 1
-        last_tokens = input_ids.unsqueeze(1)
+    #     if self.scale_input:
+    #         previous_hidden_states = self.ln0(previous_hidden_states) / SQRT2

-        next_tokens = []
+    #     # b x 1
+    #     last_tokens = input_ids.unsqueeze(1)

-        for head_index in range(num_predict_tokens):
+    #     next_tokens = []

-            # Project and predict
-            z = self.emb[head_index](last_tokens)  # b k d
-            states = self.proj[head_index](previous_hidden_states)
+    #     for head_index in range(num_predict_tokens):

-            # Weighted add of state_weight*state and emb_weight*z
-            # Let subsequent LN take care of denominator
-            # state_weight is close to 1, so shouldn't be any precision issues
-            states.add_(z, alpha=self.emb_weight / self.state_weight)
+    #         # Project and predict
+    #         z = self.emb[head_index](last_tokens)  # b k d
+    #         states = self.proj[head_index](previous_hidden_states)

-            states = self.activation(self.ln[head_index](states))  # b k d
-            previous_hidden_states = states
-            # TODO: not yet supporting top_k_tokens_per_head
-            states = states.flatten(0, 1)
+    #         # Weighted add of state_weight*state and emb_weight*z
+    #         # Let subsequent LN take care of denominator
+    #         # state_weight is close to 1, so shouldn't be any precision issues
+    #         states.add_(z, alpha=self.emb_weight / self.state_weight)

-            logits = self.logits_processor(self.head[head_index], states,
-                                           sampling_metadata)
+    #         states = self.activation(self.ln[head_index](states))  # b k d
+    #         previous_hidden_states = states
+    #         # TODO: not yet supporting top_k_tokens_per_head
+    #         states = states.flatten(0, 1)

-            output = self.sampler(logits, sampling_metadata)
-            last_tokens = output.sampled_token_ids
-            next_tokens.append(output)
+    #         logits = self.logits_processor(self.head[head_index], states,
+    #                                        sampling_metadata)

-        return next_tokens
+    #         output = self.sampler(logits, sampling_metadata)
+    #         last_tokens = output.sampled_token_ids
+    #         next_tokens.append(output)

+    #     return next_tokens

     def load_weights(self, weights: Iterable[tuple[str,
                                                    torch.Tensor]]) -> set[str]:
@@ -697,16 +697,12 @@ class Phi4FlashForCausalLM(nn.Module, HasInnerState, IsHybrid, SupportsV0Only):
         hidden_states: torch.Tensor,
         sampling_metadata: SamplingMetadata,
     ) -> Optional[torch.Tensor]:
-        # If the shape is the same, it means that we have already
-        # prune hidden states manually.
-        prune_hidden_states = hidden_states.size(
-            0) != sampling_metadata.selected_token_indices.size(0)
         processed_logits = self.logits_processor(
             self.lm_head,
             hidden_states,
             sampling_metadata,
             self.embedding_bias,
-            prune_hidden_states=prune_hidden_states)
+        )
         return processed_logits

     def load_weights(
@@ -1,597 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-
-from array import array
-from dataclasses import dataclass
-from typing import Optional
-
-import torch
-
-from vllm.sampling_params import SamplingParams, SamplingType
-from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, SequenceData,
-                           SequenceGroupMetadata)
-from vllm.utils import (PyObjectCache, async_tensor_h2d,
-                        is_pin_memory_available, make_tensor_with_pad)
-
-_SAMPLING_EPS = 1e-5
-
-
-@dataclass
-class SequenceGroupToSample:
-    # |---------- N-1 iteration --------|
-    # |---------------- N iteration ---------------------|
-    # |- tokenA -|......................|-- newTokens ---|
-    # |---------- context_len ----------|
-    # |-------------------- seq_len ----------------------|
-    #                                   |-- query_len ---|
-
-    # Sequence ids for the sequence group in a previous step.
-    seq_ids: list[int]
-    sampling_params: SamplingParams
-    # seq_id -> sequence data.
-    seq_data: dict[int, SequenceData]
-    # The length of the sequence (all tokens seen in the past + new token to
-    # compute attention) of the sequence group. None if it is in a decode
-    # stage.
-    seq_len: Optional[int]
-    # The length of new query tokens to compute in the current step. None if it
-    # is in a decode stage. The length of query_len <= seq_len if chunked
-    # prefill is enabled.
-    query_len: Optional[int]
-    # A random number generator for sampling.
-    generator: Optional[torch.Generator]
-    # True if the sequence group is in prefill stage. False if it is in a
-    # decode stage.
-    is_prompt: bool
-    # Query token indices from logits. to compute prompt logprob. Empty if
-    # prompt logprob is not required.
-    prompt_logprob_indices: list[int]
-    # Sample token indices from logits. Empty if sampling is not required.
-    sample_indices: list[int]
-
-    @property
-    def do_sample(self):
-        return len(self.sample_indices) > 0
-
-    def __post_init__(self):
-        if len(self.prompt_logprob_indices) > 0:
-            assert self.sampling_params.prompt_logprobs is not None
-        if self.is_prompt:
-            assert self.seq_len is not None
-            assert self.query_len is not None
-
-
-def gen_seq_group_to_sample_builder(num_seqs: int):
-    return lambda: SequenceGroupToSample(
-        seq_ids=[0] * num_seqs,
-        sampling_params=None,
-        seq_data=None,  # type: ignore
-        seq_len=0,
-        query_len=0,
-        generator=None,
-        is_prompt=True,
-        prompt_logprob_indices=[],
-        sample_indices=[],
-    )
-
-
-class SamplingMetadataCache:
-    """Used to cache SamplingMetadata objects between scheduler iterations"""
-
-    def __init__(self):
-        self._seq_group_to_sample_cache: dict[int, PyObjectCache] = {}
-
-    def get_cached_seq_group_to_sample(self, num_seqs):
-        if num_seqs not in self._seq_group_to_sample_cache:
-            self._seq_group_to_sample_cache[num_seqs] = PyObjectCache(
-                gen_seq_group_to_sample_builder(num_seqs))
-
-        obj = self._seq_group_to_sample_cache[num_seqs].get_object()
-        return obj
-
-    def reset(self):
-        for cache in self._seq_group_to_sample_cache.values():
-            cache.reset()
-
-
-class SamplingMetadata:
-    """Metadata for input sequences. Used in sampler.
-
-    The usage is as follows;
-    ```
-    hidden_states = execute_model(...)
-    logits = hidden_states[sampling_metadata.selected_token_indices]
-    sample(logits)
-
-    def sample(logits):
-        # Use categorized_sample_indices for sampling....
-    ```
-
-    Args:
-        seq_groups: List of batched sequence groups.
-        selected_token_indices: (num_query_tokens_to_logprob). Indices to find
-            logits from the initial model output hidden states.
-        categorized_sample_indices: SamplingType -> token indices to sample.
-            Each token indices is 2D tensor of (num_indices, num_indices) where
-            the first item means the sample index within the returned logit
-            (before pruning padding), and the second item means the sample
-            index after pruning using selected_token_indices.
-            For example, if the returned logit is [1, 2, 3], and we select
-            [1, 2] for sampling, the pruned logit will be [2, 3]. In this case,
-            The first tuple is [1, 2] (sampled index within original logit),
-            and the second tuple is [0, 1] (sampled index within pruned logit).
-        num_prompts: Number of prompt sequence groups in seq_groups.
-        skip_sampler_cpu_output: Indicates if we want to skip the GPU=>CPU
-            serialization of token outputs.
-        reuse_sampling_tensors: Indicates if we want to reuse sampling
-            tensors that are part of the sampler forward pass. Currently,
-            it is mainly used for multi-step decode.
-
-    """
-
-    def __init__(
-        self,
-        seq_groups: list[SequenceGroupToSample],
-        selected_token_indices: torch.Tensor,
-        categorized_sample_indices: dict[SamplingType, torch.Tensor],
-        num_prompts: int,
-        skip_sampler_cpu_output: bool = False,
-        reuse_sampling_tensors: bool = False,
-    ) -> None:
-        self.seq_groups = seq_groups
-        self.selected_token_indices = selected_token_indices
-        self.categorized_sample_indices = categorized_sample_indices
-        self.num_prompts = num_prompts
-        self.skip_sampler_cpu_output = skip_sampler_cpu_output
-        self.reuse_sampling_tensors = reuse_sampling_tensors
-
-    @staticmethod
-    def prepare(
-        seq_group_metadata_list: list[SequenceGroupMetadata],
-        seq_lens: list[int],
-        query_lens: list[int],
-        device: str,
-        pin_memory: bool,
-        generators: Optional[dict[str, torch.Generator]] = None,
-        cache: Optional[SamplingMetadataCache] = None,
-    ) -> "SamplingMetadata":
-        (
-            seq_groups,
-            selected_token_indices,
-            categorized_sample_indices,
-            num_prompts,
-        ) = _prepare_seq_groups(seq_group_metadata_list, seq_lens, query_lens,
-                                device, generators, cache)
-        selected_token_indices = async_tensor_h2d(
-            selected_token_indices,
-            dtype=torch.long,
-            target_device=device,
-            pin_memory=pin_memory,
-        )
-        categorized_sample_indices = {
-            t:
-            async_tensor_h2d(
-                seq_ids,
-                dtype=torch.int,
-                target_device=device,
-                pin_memory=pin_memory,
-            )
-            for t, seq_ids in categorized_sample_indices.items()
-        }
-
-        sampling_metadata = SamplingMetadata(
-            seq_groups=seq_groups,
-            selected_token_indices=selected_token_indices,
-            categorized_sample_indices=categorized_sample_indices,
-            num_prompts=num_prompts,
-        )
-        return sampling_metadata
-
-    def __repr__(self) -> str:
-        return (
-            "SamplingMetadata("
-            f"seq_groups={self.seq_groups}, "
-            f"selected_token_indices={self.selected_token_indices}, "
-            f"categorized_sample_indices={self.categorized_sample_indices})")
-
-
-def _prepare_seq_groups(
-    seq_group_metadata_list: list[SequenceGroupMetadata],
-    seq_lens: list[int],
-    query_lens: list[int],
-    device: str,
-    generators: Optional[dict[str, torch.Generator]] = None,
-    cache: Optional[SamplingMetadataCache] = None,
-) -> tuple[
-        list[SequenceGroupToSample],
-        list[int],
-        dict[SamplingType, list[int]],
-        int,
-]:
-    """Prepare sequence groups and indices for sampling.
-
-    Args:
-        seq_group_metadata_list: A list of sequence group to batch.
-        seq_lens: A list of sequence lens per sequence group.
-            Index of prompt len should match with seq_group_metadata_list.
-        query_lens: A list of query lengths. Prompt lens include the length
-            of entire prompt tokens, and it could be shorter.
-        device: A device to use for random number generators,
-            `SequenceGroupToSample.generator`.
-        generators: A store of per-request random number generators used
-            for seeded requests.
-
-    Returns:
-        seq_groups: A list of sequence group to sample.
-        selected_token_indices: See the definition from `SamplingMetadata`.
-        categorized_sample_indices: See the definition from `SamplingMetadata`.
-        num_prompts: Total number of prompts from `seq_group_metadata_list`.
-    """
-    # Batched sequence groups for the current model forward stsep.
-    seq_groups: list[SequenceGroupToSample] = []
-    # A list of token indices to sample/compute logprob. It is used to
-    # prune the outcome logits from the model for the performance.
-    selected_token_indices: list[int] = []
-    # Used for selected_token_indices.
-    model_output_idx = 0
-
-    # Sampling type -> (
-    # indices to sample/prompt logprob within pruned output logits,
-    # indices to sample within pruned logits)
-    categorized_sample_indices: dict[SamplingType, list[int]] = {
-        t: []
-        for t in SamplingType
-    }
-    # Index of logits to compute logprob. Logits include both prompt logprob
-    # and sample logprob indices.
-    logit_idx = 0
-    # Total number of prompts from given sequence groups.
-    num_prompts = 0
-
-    for i, seq_group_metadata in enumerate(seq_group_metadata_list):
-        seq_ids = seq_group_metadata.seq_data.keys()
-
-        if cache is not None:
-            sample_obj = cache.get_cached_seq_group_to_sample(len(seq_ids))
-
-            for j, seq_id in enumerate(seq_ids):
-                sample_obj.seq_ids[j] = seq_id
-
-            sample_obj.prompt_logprob_indices.clear()
-            sample_obj.sample_indices.clear()
-
-        sampling_params = seq_group_metadata.sampling_params
-        is_prompt = seq_group_metadata.is_prompt
-        generator: Optional[torch.Generator] = None
-        # If the current seq group is in decode stage, it is None.
-        seq_len: Optional[int] = None
-        query_len: Optional[int] = None
-        prompt_logprob_indices: list[int] = (sample_obj.prompt_logprob_indices
-                                             if cache is not None else [])
-        sample_indices: list[int] = (sample_obj.sample_indices
-                                     if cache is not None else [])
-        do_sample = seq_group_metadata.do_sample
-
-        if seq_group_metadata.is_prompt:
-            if sampling_params.seed is not None:
-                generator = torch.Generator(device=device).manual_seed(
-                    sampling_params.seed)
-                if generators is not None:
-                    generators[seq_group_metadata.request_id] = generator
-
-            num_prompts += 1
-            num_prefill_sample = len(seq_ids)
-            assert num_prefill_sample == 1
-            assert query_lens is not None and seq_lens is not None
-            query_len, seq_len = query_lens[i], seq_lens[i]
-            # If we need sampling, exclude num_prefill_sample tokens from
-            # prompt logprob.
-            prompt_logprob_len = (query_len - num_prefill_sample
-                                  if do_sample else query_len)
-            sample_len = num_prefill_sample if do_sample else 0
-        else:
-            # Decode
-            prompt_logprob_len = 0
-            query_len = query_lens[i] if query_lens is not None and len(
-                query_lens) > 0 else 1
-            sample_len = len(seq_ids) * query_len if do_sample else 0
-
-            if sampling_params.seed is not None and generators is not None:
-                generator = generators.get(seq_group_metadata.request_id)
-
-        # Update indices to select from the model output.
-        """
-        This blocks computes selected_token_indices which is used in the
-        following way.
-
-        hidden_states = model(...)
-        logits = hidden_states[selected_token_indices]
-        """
-
-        if sampling_params.prompt_logprobs is not None:
-            selected_token_indices.extend(
-                range(model_output_idx, model_output_idx + prompt_logprob_len))
-        model_output_idx += prompt_logprob_len
-        if do_sample:
-            selected_token_indices.extend(
-                range(model_output_idx, model_output_idx + sample_len))
-        model_output_idx += sample_len
-
-        # We now find indices for logprob computation and sampling.
-        """
-        This block computes categorized_sample_indices which is used in the
-        following way.
-
-        hidden_states = model(...)
-        logits = hidden_states[selected_token_indices]
-        def sample(logits):
-            # Use categorized_sample_indices for sampling.
-            # prompt_logprob_indices to find prompt logprob indices.
-            # sample_indices to find sample indices.
-        """
-
-        if sampling_params.prompt_logprobs is not None:
-            prompt_logprob_indices.extend(
-                range(logit_idx, logit_idx + prompt_logprob_len))
-            logit_idx += prompt_logprob_len
-        if do_sample:
-            sample_indices.extend(range(logit_idx, logit_idx + sample_len))
-            categorized_sample_indices[sampling_params.sampling_type].extend(
-                list(range(logit_idx, logit_idx + sample_len)))
-            logit_idx += sample_len
-
-        if cache is not None:
-            sample_obj.sampling_params = sampling_params
-            sample_obj.seq_data = seq_group_metadata.seq_data
-            sample_obj.seq_len = seq_len
-            sample_obj.query_len = query_len
-            sample_obj.generator = generator
-            sample_obj.is_prompt = is_prompt
-        else:
-            sample_obj = SequenceGroupToSample(
-                seq_ids=list(seq_ids),
-                sampling_params=sampling_params,
-                seq_data=seq_group_metadata.seq_data,
-                seq_len=seq_len,
-                query_len=query_len,
-                generator=generator,
-                is_prompt=is_prompt,
-                prompt_logprob_indices=list(prompt_logprob_indices),
-                sample_indices=list(sample_indices),
-            )
-
-        seq_groups.append(sample_obj)
-
-    if cache is not None:
-        cache.reset()
-
-    return (seq_groups, selected_token_indices, categorized_sample_indices,
-            num_prompts)
-
-
-@dataclass
-class SamplingTensors:
-    """Tensors for sampling."""
-
-    temperatures: torch.Tensor
-    top_ps: torch.Tensor
-    top_ks: torch.Tensor
-    min_ps: torch.Tensor
-    presence_penalties: torch.Tensor
-    frequency_penalties: torch.Tensor
-    repetition_penalties: torch.Tensor
-    prompt_tokens: torch.Tensor
-    output_tokens: torch.Tensor
-
-    @classmethod
-    def from_sampling_metadata(
-        cls,
-        sampling_metadata: "SamplingMetadata",
-        vocab_size: int,
-        device: torch.device,
-        dtype: torch.dtype,
-    ) -> tuple["SamplingTensors", bool, bool, bool]:
-        prompt_tokens: list[array] = []
-        output_tokens: list[array] = []
-        top_ks: list[int] = []
-        temperatures: list[float] = []
-        top_ps: list[float] = []
-        min_ps: list[float] = []
-        presence_penalties: list[float] = []
-        frequency_penalties: list[float] = []
-        repetition_penalties: list[float] = []
-        do_penalties = False
-        do_top_p_top_k = False
-        do_min_p = False
-
-        assert sampling_metadata.seq_groups is not None
-        for seq_group in sampling_metadata.seq_groups:
-            seq_ids = seq_group.seq_ids
-            sampling_params = seq_group.sampling_params
-            temperature = sampling_params.temperature
-            p = sampling_params.presence_penalty
-            f = sampling_params.frequency_penalty
-            r = sampling_params.repetition_penalty
-            top_p = sampling_params.top_p
-            min_p = sampling_params.min_p
-
-            # k should not be greater than the vocab size.
-            top_k = min(sampling_params.top_k, vocab_size)
-            top_k = vocab_size if top_k < 1 else top_k
-            if temperature < _SAMPLING_EPS:
-                # NOTE: Zero temperature means deterministic sampling
-                # (i.e., greedy sampling or beam search).
-                # Set the temperature to 1 to avoid division by zero.
-                temperature = 1.0
-            if not do_top_p_top_k and (top_p < 1.0 - _SAMPLING_EPS
-                                       or top_k != vocab_size):
-                do_top_p_top_k = True
-            if not do_min_p and min_p > _SAMPLING_EPS:
-                do_min_p = True
-            if not do_penalties and (abs(p) >= _SAMPLING_EPS
-                                     or abs(f) >= _SAMPLING_EPS
-                                     or abs(r - 1.0) >= _SAMPLING_EPS):
-                do_penalties = True
-
-            is_prompt = seq_group.is_prompt
-            if is_prompt and sampling_params.prompt_logprobs is not None:
-                # For tokens in the prompt that we only need to get
-                # their logprobs
-                query_len = seq_group.query_len
-                assert query_len is not None
-                prefill_len = len(seq_group.prompt_logprob_indices)
-                temperatures += [temperature] * prefill_len
-                top_ps += [top_p] * prefill_len
-                top_ks += [top_k] * prefill_len
-                min_ps += [min_p] * prefill_len
-                presence_penalties += [0] * prefill_len
-                frequency_penalties += [0] * prefill_len
-                repetition_penalties += [1] * prefill_len
-
-            if seq_group.do_sample:
-                sample_lens = len(seq_group.sample_indices)
-                assert sample_lens >= len(seq_ids)
-                temperatures += [temperature] * sample_lens
-                top_ps += [top_p] * sample_lens
-                top_ks += [top_k] * sample_lens
-                min_ps += [min_p] * sample_lens
-                presence_penalties += [p] * sample_lens
-                frequency_penalties += [f] * sample_lens
-                repetition_penalties += [r] * sample_lens
-
-        if do_penalties:
-            for seq_group in sampling_metadata.seq_groups:
-                seq_ids = seq_group.seq_ids
-                sampling_params = seq_group.sampling_params
-                if (seq_group.is_prompt
-                        and sampling_params.prompt_logprobs is not None):
-                    prefill_len = len(seq_group.prompt_logprob_indices)
-                    prompt_tokens.extend(
-                        array(VLLM_TOKEN_ID_ARRAY_TYPE)
-                        for _ in range(prefill_len))
-                    output_tokens.extend(
-                        array(VLLM_TOKEN_ID_ARRAY_TYPE)
-                        for _ in range(prefill_len))
-                if seq_group.do_sample:
-                    for seq_id in seq_ids:
-                        seq_data = seq_group.seq_data[seq_id]
-                        prompt_tokens.append(seq_data.prompt_token_ids_array)
-                        output_tokens.append(seq_data.output_token_ids_array)
-
-        sampling_tensors = SamplingTensors.from_lists(
-            temperatures,
-            top_ps,
-            top_ks,
-            min_ps,
-            presence_penalties,
-            frequency_penalties,
-            repetition_penalties,
-            prompt_tokens,
-            output_tokens,
-            vocab_size,
-            device,
-            dtype,
-        )
-        return (sampling_tensors, do_penalties, do_top_p_top_k, do_min_p)
-
-    @classmethod
-    def from_lists(
-        cls,
-        temperatures: list[float],
-        top_ps: list[float],
-        top_ks: list[int],
-        min_ps: list[float],
-        presence_penalties: list[float],
-        frequency_penalties: list[float],
-        repetition_penalties: list[float],
-        prompt_tokens: list[array],
-        output_tokens: list[array],
-        vocab_size: int,
-        device: torch.device,
-        dtype: torch.dtype,
-    ) -> "SamplingTensors":
-        # Note that the performance will be very bad without
-        # pinned memory.
-        pin_memory = is_pin_memory_available()
-
-        do_penalties = prompt_tokens or output_tokens
-
-        if do_penalties:
-            prompt_t = make_tensor_with_pad(
-                prompt_tokens,
-                vocab_size,
-                device="cpu",
-                dtype=torch.int64,
-                pin_memory=pin_memory,
-            )
-            output_t = make_tensor_with_pad(
-                output_tokens,
-                vocab_size,
-                device="cpu",
-                dtype=torch.int64,
-                pin_memory=pin_memory,
-            )
-        else:
-            empty_tensor = torch.empty(0, device=device, dtype=torch.long)
-            prompt_t = empty_tensor
-            output_t = empty_tensor
-
-        temperatures_t = torch.tensor(
-            temperatures,
-            device="cpu",
-            dtype=dtype,
-            pin_memory=pin_memory,
-        )
-        top_ps_t = torch.tensor(
-            top_ps,
-            device="cpu",
-            dtype=dtype,
-            pin_memory=pin_memory,
-        )
-        min_ps_t = torch.tensor(
-            min_ps,
-            device="cpu",
-            dtype=dtype,
-            pin_memory=pin_memory,
-        )
-        presence_penalties_t = torch.tensor(
-            presence_penalties,
-            device="cpu",
-            dtype=dtype,
-            pin_memory=pin_memory,
-        )
-        frequency_penalties_t = torch.tensor(
-            frequency_penalties,
-            device="cpu",
-            dtype=dtype,
-            pin_memory=pin_memory,
-        )
-        repetition_penalties_t = torch.tensor(
-            repetition_penalties,
-            device="cpu",
-            dtype=dtype,
-            pin_memory=pin_memory,
-        )
-        top_ks_t = torch.tensor(
-            top_ks,
-            device="cpu",
-            dtype=torch.int,
-            pin_memory=pin_memory,
-        )
-        # Because the memory is pinned, we can do non-blocking
-        # transfer to device.
-
-        return cls(
-            temperatures=temperatures_t.to(device=device, non_blocking=True),
-            top_ps=top_ps_t.to(device=device, non_blocking=True),
-            top_ks=top_ks_t.to(device=device, non_blocking=True),
-            min_ps=min_ps_t.to(device=device, non_blocking=True),
-            presence_penalties=presence_penalties_t.to(device=device,
-                                                       non_blocking=True),
-            frequency_penalties=frequency_penalties_t.to(device=device,
-                                                         non_blocking=True),
-            repetition_penalties=repetition_penalties_t.to(device=device,
-                                                           non_blocking=True),
-            prompt_tokens=prompt_t.to(device=device, non_blocking=True),
-            output_tokens=output_t.to(device=device, non_blocking=True),
-        )
+
+# Placeholder until it can be safely removed.
+pass
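The removed SamplingTensors.from_lists above staged every list as a pinned CPU tensor and then moved it with non_blocking=True, as its final comment notes. A standalone sketch of that host-to-device pattern (values are placeholders):

import torch

pin = torch.cuda.is_available()
temperatures_cpu = torch.tensor([1.0, 0.7, 1.0],
                                device="cpu",
                                dtype=torch.float32,
                                pin_memory=pin)
if pin:
    # Non-blocking copies only overlap with other work when the source is pinned.
    temperatures = temperatures_cpu.to("cuda", non_blocking=True)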
vllm/sequence.py: 1322 lines changed [file diff suppressed because it is too large]
@@ -1,162 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-
-from typing import Optional
-
-from vllm.logprobs import Logprob
-from vllm.sequence import (VLLM_INVALID_TOKEN_ID, SamplingParams, Sequence,
-                           SequenceGroup)
-
-from .detokenizer_utils import (convert_prompt_ids_to_tokens,
-                                detokenize_incrementally)
-from .tokenizer import AnyTokenizer
-
-
-class Detokenizer:
-    """Provides methods to decode the output of a model into text."""
-
-    def __init__(self, tokenizer: AnyTokenizer):
-        self.tokenizer = tokenizer
-
-    def decode_prompt_logprobs_inplace(self, seq_group: SequenceGroup,
-                                       prompt_logprobs: list[Optional[dict[
-                                           int, Logprob]]],
-                                       position_offset: int) -> None:
-        """Decodes the logprobs for the prompt of a sequence group.
-
-        Args:
-            seq_group: The sequence group to decode.
-            prompt_logprobs: The logprobs to decode.
-            position_offset: Offset of the first index of the logprobs
-                relative to the start of the sequence (for chunked prefill).
-
-        Returns:
-            The prompt logprobs with the decoded tokens.
-        """
-        prms = seq_group.sampling_params
-        assert prms is not None
-
-        # We can pick any sequence for the prompt.
-        seq = seq_group.get_seqs()[0]
-        # Only prompt, without the generated token.
-        all_token_ids = seq.get_token_ids()
-        prompt_token_ids = all_token_ids[:-1]
-        prefix_offset = 0
-        read_offset = 0
-        next_iter_prefix_offset = 0
-        next_iter_read_offset = 0
-        next_iter_tokens: list[str] = []
-        prev_tokens = None
-
-        for token_position_in_logprob, prompt_logprobs_for_token in enumerate(
-                prompt_logprobs):
-
-            # Absolute token position equals the index in the logprobs
-            # list plus the offset of the entire logprobs list relative
-            # to the start of the sequence.
-            token_position = token_position_in_logprob + position_offset
-            if not prompt_logprobs_for_token:
-                continue
-            for token_id, sample_logprob in prompt_logprobs_for_token.items():
-                if (sample_logprob.decoded_token is None
-                        and token_id != VLLM_INVALID_TOKEN_ID):
-                    prompt_token_ids_with_token = (
-                        prompt_token_ids[:token_position] + [token_id])
-                    (new_tokens, new_text, new_prefix_offset,
-                     new_read_offset) = detokenize_incrementally(
-                         tokenizer=self.tokenizer,
-                         all_input_ids=prompt_token_ids_with_token,
-                         prev_tokens=prev_tokens,
-                         prefix_offset=prefix_offset,
-                         read_offset=read_offset,
-                         skip_special_tokens=prms.skip_special_tokens,
-                         spaces_between_special_tokens=prms.
-                         spaces_between_special_tokens,
-                     )
-
-                    sample_logprob.decoded_token = new_text
-
-                    # Use the offsets & prev tokens corresponding to
-                    # real tokens to ensure detokenization is consistent
-                    # actual with prompt.
-                    if token_id == all_token_ids[token_position]:
-                        next_iter_prefix_offset = new_prefix_offset
-                        next_iter_read_offset = new_read_offset
-                        next_iter_tokens = new_tokens
-
-            # Advance to the next token position.
-            prefix_offset = next_iter_prefix_offset
-            read_offset = next_iter_read_offset
-            if prev_tokens is None:
-                prev_tokens = next_iter_tokens.copy()
-            else:
-                prev_tokens.extend(next_iter_tokens)
-
-    def decode_sequence_inplace(self, seq: Sequence,
-                                prms: SamplingParams) -> int:
-        """Decodes the new token for a sequence. In-place operation.
-
-        Args:
-            seq: The sequence to decode.
-            prms: The sampling parameters used to generate the sequence.
-
-        Returns:
-            The number of characters added to the output text.
-        """
-        all_input_ids = seq.get_token_ids()
-        token_id_generated_this_iteration = all_input_ids[-1]
-
-        # Convert prompt token IDs to tokens if necessary.
-        # Do it here so that we don't have to repeat this
-        # computation for each logprob.
-        if seq.tokens is None:
-            (seq.tokens, seq.prefix_offset,
-             seq.read_offset) = convert_prompt_ids_to_tokens(
-                 tokenizer=self.tokenizer,
-                 prompt_ids=all_input_ids[:-1],
-                 skip_special_tokens=prms.skip_special_tokens,
-             )
-
-        (new_tokens, new_decoded_token_text, prefix_offset,
-         read_offset) = detokenize_incrementally(
-             tokenizer=self.tokenizer,
-             all_input_ids=all_input_ids,
-             prev_tokens=seq.tokens,
-             prefix_offset=seq.prefix_offset,
-             read_offset=seq.read_offset,
-             skip_special_tokens=prms.skip_special_tokens,
-             spaces_between_special_tokens=prms.spaces_between_special_tokens,
-         )
-
-        # Decode logprobs
-        logprobs = seq.output_logprobs[-1]
-        if logprobs:
-            previous_tokens = all_input_ids[:-1]
-            for token_id, sample_logprob in logprobs.items():
-                # If the token was generated this iteration,
-                # use the provided text.
-                if token_id == token_id_generated_this_iteration:
-                    sample_logprob.decoded_token = new_decoded_token_text
-                    continue
-
-                if (sample_logprob.decoded_token is None
-                        and token_id != VLLM_INVALID_TOKEN_ID):
-                    all_input_ids_with_logprob = previous_tokens + [token_id]
-                    (_, new_text, _, _) = detokenize_incrementally(
-                        tokenizer=self.tokenizer,
-                        all_input_ids=all_input_ids_with_logprob,
-                        prev_tokens=seq.tokens,
-                        prefix_offset=seq.prefix_offset,
-                        read_offset=seq.read_offset,
-                        skip_special_tokens=prms.skip_special_tokens,
-                        spaces_between_special_tokens=prms.
-                        spaces_between_special_tokens,
-                    )
-                    sample_logprob.decoded_token = new_text
-
-        seq.tokens.extend(new_tokens)
-        seq.prefix_offset = prefix_offset
-        seq.read_offset = read_offset
-        seq.output_text += new_decoded_token_text
-
-        return len(new_decoded_token_text)
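The deleted V0 Detokenizer above wrapped detokenize_incrementally, which remains available from vllm.transformers_utils.detokenizer_utils (see the tool-parser hunks earlier). A hedged sketch of the incremental loop, using the call signatures exactly as they appear in the deleted code; the tokenizer name is only an example:

from transformers import AutoTokenizer

from vllm.transformers_utils.detokenizer_utils import (
    convert_prompt_ids_to_tokens, detokenize_incrementally)

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # example tokenizer
prompt_ids = tokenizer("Hello", add_special_tokens=False).input_ids
new_ids = tokenizer(" world", add_special_tokens=False).input_ids

tokens, prefix_offset, read_offset = convert_prompt_ids_to_tokens(
    tokenizer=tokenizer, prompt_ids=prompt_ids, skip_special_tokens=True)

all_ids, text = list(prompt_ids), ""
for token_id in new_ids:
    all_ids.append(token_id)
    new_tokens, new_text, prefix_offset, read_offset = \
        detokenize_incrementally(
            tokenizer=tokenizer,
            all_input_ids=all_ids,
            prev_tokens=tokens,
            prefix_offset=prefix_offset,
            read_offset=read_offset,
            skip_special_tokens=True,
            spaces_between_special_tokens=True,
        )
    tokens.extend(new_tokens)
    text += new_text  # incrementally decoded continuation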
@@ -11,12 +11,12 @@ import torch.nn as nn
 from vllm.config import VllmConfig, set_current_vllm_config
 from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
-from vllm.model_executor.layers.sampler import SamplerOutput
 from vllm.sequence import ExecuteModelRequest
 from vllm.utils import (enable_trace_function_call_for_thread,
                         resolve_obj_by_qualname, run_method,
                         update_environment_variables,
                         warn_for_unimplemented_methods)
+from vllm.v1.outputs import SamplerOutput

 logger = init_logger(__name__)