mirror of https://github.com/vllm-project/vllm.git (synced 2025-10-20 14:53:52 +08:00)

[Core] Gate prompt_embeds behind a feature flag (#17607)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>

This commit is contained in:

tests/engine/test_options.py (new file, 60 lines)
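For orientation, here is a minimal sketch of what the gate changes for callers, pieced together from the tests added below. The model name, prompt, and dtype choice are illustrative, and the sketch assumes the V0 engine, since the preprocessor hunk in this commit rejects `prompt_embeds` under V1.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from vllm import LLM

MODEL = "distilbert/distilgpt2"  # illustrative; any supported decoder-only model

# Build prompt embeddings offline from the HF weights, mirroring the new test.
tokenizer = AutoTokenizer.from_pretrained(MODEL)
hf_model = AutoModelForCausalLM.from_pretrained(MODEL)
token_ids = tokenizer("Hello, world", return_tensors="pt").input_ids
with torch.no_grad():
    prompt_embeds = hf_model.get_input_embeddings()(token_ids).squeeze(0)

# The new gate: without enable_prompt_embeds=True (--enable-prompt-embeds on
# the CLI), passing `prompt_embeds` now raises a ValueError.
llm = LLM(
    model=MODEL,
    enable_prompt_embeds=True,
    enforce_eager=True,
    dtype="float32",  # keep the engine dtype aligned with the HF embeddings
)
outputs = llm.generate({"prompt_embeds": prompt_embeds})
print(outputs[0].outputs[0].text)

With the flag left at its default of False, the same generate call is expected to fail with a ValueError asking for `--enable-prompt-embeds`, which is exactly what `test_enable_prompt_embeds` below asserts.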
tests/engine/test_options.py (new file)
@@ -0,0 +1,60 @@
+# SPDX-License-Identifier: Apache-2.0
+from contextlib import nullcontext
+
+import pytest
+
+from vllm.entrypoints.llm import LLM
+from vllm.sampling_params import SamplingParams
+
+
+@pytest.mark.parametrize("model", ["distilbert/distilgpt2"])
+def test_skip_tokenizer_initialization(model: str):
+    # This test checks if the flag skip_tokenizer_init skips the initialization
+    # of tokenizer and detokenizer. The generated output is expected to contain
+    # token ids.
+    llm = LLM(
+        model=model,
+        skip_tokenizer_init=True,
+        enforce_eager=True,
+    )
+    sampling_params = SamplingParams(prompt_logprobs=True, detokenize=True)
+
+    with pytest.raises(ValueError, match="cannot pass text prompts when"):
+        llm.generate("abc", sampling_params)
+
+    outputs = llm.generate({"prompt_token_ids": [1, 2, 3]},
+                           sampling_params=sampling_params)
+    assert len(outputs) > 0
+    completions = outputs[0].outputs
+    assert len(completions) > 0
+    assert completions[0].text == ""
+    assert completions[0].token_ids
+
+
+@pytest.mark.parametrize("model", ["distilbert/distilgpt2"])
+@pytest.mark.parametrize("enable_prompt_embeds", [True, False])
+def test_enable_prompt_embeds(hf_runner, model: str,
+                              enable_prompt_embeds: bool):
+    prompt = "abc"
+
+    with hf_runner(model) as hf_model:
+        token_ids = hf_model.tokenizer(prompt, return_tensors="pt").input_ids
+        token_ids = token_ids.to(hf_model.model.device)
+
+        embed_layer = hf_model.model.get_input_embeddings()
+        prompt_embeds = embed_layer(token_ids).squeeze(0)
+
+    ctx = (nullcontext() if enable_prompt_embeds else pytest.raises(
+        ValueError, match="set `--enable-prompt-embeds`"))
+
+    # Passing `prompt_embeds` is only accepted when the engine is constructed
+    # with `enable_prompt_embeds=True`; otherwise generation is expected to
+    # raise a ValueError (see `ctx` above).
+    llm = LLM(
+        model=model,
+        enable_prompt_embeds=enable_prompt_embeds,
+        enforce_eager=True,
+    )
+
+    with ctx:
+        llm.generate({"prompt_embeds": prompt_embeds})
(deleted file; its test now lives in tests/engine/test_options.py above)
@@ -1,29 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-
-import pytest
-
-from vllm.entrypoints.llm import LLM
-from vllm.sampling_params import SamplingParams
-
-
-@pytest.mark.parametrize("model", ["distilbert/distilgpt2"])
-def test_skip_tokenizer_initialization(model: str):
-    # This test checks if the flag skip_tokenizer_init skips the initialization
-    # of tokenizer and detokenizer. The generated output is expected to contain
-    # token ids.
-    llm = LLM(
-        model=model,
-        skip_tokenizer_init=True,
-    )
-    sampling_params = SamplingParams(prompt_logprobs=True, detokenize=True)
-
-    with pytest.raises(ValueError, match="cannot pass text prompts when"):
-        llm.generate("abc", sampling_params)
-
-    outputs = llm.generate({"prompt_token_ids": [1, 2, 3]},
-                           sampling_params=sampling_params)
-    assert len(outputs) > 0
-    completions = outputs[0].outputs
-    assert len(completions) > 0
-    assert completions[0].text == ""
-    assert completions[0].token_ids
@@ -109,12 +109,15 @@ def test_models(hf_runner, vllm_runner, example_prompts, model: str,
         # in parts of the operators
         pytest.skip(f"Skipping '{model}' model test with AITER kernel.")
 
+    use_prompt_embeds = os.getenv("VLLM_USE_V1") == "0"
+
     with hf_runner(model) as hf_model:
         hf_outputs = hf_model.generate_greedy_logprobs_limit(
             example_prompts, max_tokens, num_logprobs)
 
-        prompt_embeds: Optional[list[torch.Tensor]] = [] if os.getenv(
-            "VLLM_USE_V1") == "0" else None
+        prompt_embeds: Optional[list[torch.Tensor]] = ([] if use_prompt_embeds
+                                                        else None)
+
         prompt_token_ids = []
         for prompt in example_prompts:
             token_ids = hf_model.tokenizer(prompt,
@@ -131,6 +134,7 @@ def test_models(hf_runner, vllm_runner, example_prompts, model: str,
             tokenizer_mode=model_info.tokenizer_mode,
             trust_remote_code=model_info.trust_remote_code,
             max_num_seqs=2,
+            enable_prompt_embeds=use_prompt_embeds,
     ) as vllm_model:
         vllm_outputs = vllm_model.generate_greedy_logprobs(
             example_prompts, max_tokens, num_logprobs)
@@ -43,6 +43,7 @@ def test_prepare_prompt(batch_size, use_prompt_embeds, monkeypatch):
         max_num_batched_tokens=100000,
         max_num_seqs=100000,
         enable_chunked_prefill=False,
+        enable_prompt_embeds=True,
     )
 
     seq_lens: list[int] = []
@@ -179,6 +180,7 @@ def test_prepare_decode_cuda_graph(batch_size, use_prompt_embeds, monkeypatch):
         max_num_batched_tokens=100000,
         max_num_seqs=100000,
         enable_chunked_prefill=False,
+        enable_prompt_embeds=True,
     )
 
     context_lens: list[int] = []
@@ -359,6 +361,7 @@ def test_hybrid_batches(batch_size, enforce_eager, use_prompt_embeds,
         max_num_batched_tokens=100000,
         max_num_seqs=100000,
         enable_chunked_prefill=True,
+        enable_prompt_embeds=True,
     )
 
     # Add prefill requests.
@@ -321,6 +321,10 @@ class ModelConfig:
     """Skip initialization of tokenizer and detokenizer. Expects valid
     `prompt_token_ids` and `None` for prompt from the input. The generated
     output will contain token ids."""
+    enable_prompt_embeds: bool = False
+    """If `True`, enables passing text embeddings as inputs via the
+    `prompt_embeds` key. Note that enabling this will double the time required
+    for graph compilation."""
     served_model_name: Optional[Union[str, list[str]]] = None
     """The model name(s) used in the API. If multiple names are provided, the
     server will respond to any of the provided names. The model name in the
@@ -234,6 +234,7 @@ class EngineArgs:
     hf_config_path: Optional[str] = ModelConfig.hf_config_path
     task: TaskOption = ModelConfig.task
     skip_tokenizer_init: bool = ModelConfig.skip_tokenizer_init
+    enable_prompt_embeds: bool = ModelConfig.enable_prompt_embeds
     tokenizer_mode: TokenizerMode = ModelConfig.tokenizer_mode
     trust_remote_code: bool = ModelConfig.trust_remote_code
     allowed_local_media_path: str = ModelConfig.allowed_local_media_path
@@ -445,6 +446,8 @@ class EngineArgs:
                                  **model_kwargs["disable_cascade_attn"])
         model_group.add_argument("--skip-tokenizer-init",
                                  **model_kwargs["skip_tokenizer_init"])
+        model_group.add_argument("--enable-prompt-embeds",
+                                 **model_kwargs["enable_prompt_embeds"])
         model_group.add_argument("--served-model-name",
                                  **model_kwargs["served_model_name"])
         # This one is a special case because it is the
@@ -874,6 +877,7 @@ class EngineArgs:
             disable_sliding_window=self.disable_sliding_window,
             disable_cascade_attn=self.disable_cascade_attn,
             skip_tokenizer_init=self.skip_tokenizer_init,
+            enable_prompt_embeds=self.enable_prompt_embeds,
             served_model_name=self.served_model_name,
             limit_mm_per_prompt=self.limit_mm_per_prompt,
             use_async_output_proc=not self.disable_async_output_proc,
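The same switch is exposed on the CLI and at the engine-args layer. A small, untested sketch of the plumbing these hunks add; the model name is illustrative:

from vllm.engine.arg_utils import EngineArgs

# Programmatic equivalent of the new --enable-prompt-embeds CLI flag. Per the
# hunks above, EngineArgs.create_engine_config() forwards this value to
# ModelConfig.enable_prompt_embeds.
engine_args = EngineArgs(
    model="distilbert/distilgpt2",  # illustrative model choice
    enable_prompt_embeds=True,
)
print(engine_args.enable_prompt_embeds)  # True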
@@ -303,8 +303,11 @@ class InputPreprocessor:
         self,
         parsed_content: EmbedsPrompt,
     ) -> EmbedsInputs:
+        if not self.model_config.enable_prompt_embeds:
+            raise ValueError("You must set `--enable-prompt-embeds` to input "
+                             "`prompt_embeds`.")
         if envs.VLLM_USE_V1:
-            raise ValueError("prompt_embeds is only available in V0.")
+            raise ValueError("`prompt_embeds` is only available in V0.")
 
         prompt_embeds = parsed_content["prompt_embeds"]
 
@@ -1565,7 +1565,9 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
             # product.
             cudagraph_capture_sizes = self.vllm_config.compilation_config\
                 .cudagraph_capture_sizes
-            cudagraph_inputs_embeds = (True, False)
+            cudagraph_inputs_embeds = ((
+                True, False) if self.model_config.enable_prompt_embeds else
+                (False, ))
             compilation_cases = itertools.product(
                 cudagraph_capture_sizes,
                 cudagraph_inputs_embeds,
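This is the hunk behind the note in the ModelConfig docstring above that the flag "will double the time required for graph compilation": each CUDA graph capture size is now captured once per entry of cudagraph_inputs_embeds. A standalone sketch of that counting, with made-up capture sizes:

import itertools

# Hypothetical capture sizes; in vLLM they come from CompilationConfig.
cudagraph_capture_sizes = [1, 2, 4, 8]

for enable_prompt_embeds in (False, True):
    # Mirrors the new runner logic: capture graphs both with and without
    # input embeddings only when the feature flag is enabled.
    cudagraph_inputs_embeds = ((True, False) if enable_prompt_embeds else
                               (False, ))
    compilation_cases = list(
        itertools.product(cudagraph_capture_sizes, cudagraph_inputs_embeds))
    print(enable_prompt_embeds, len(compilation_cases))  # 4 cases, then 8

Keeping (False, ) as the only case when the flag is off is what makes the feature opt-in without adding capture time for everyone else.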