[Core] Remove tokenizer group in vLLM (#24078)
Signed-off-by: Zhuohan Li <zhuohan123@gmail.com>
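For orientation, the change below removes the TokenizerGroup indirection and has callers work with the tokenizer directly. A minimal before/after sketch of that pattern, using only the constructor arguments and helpers that appear in the diff itself (the model name is illustrative, not taken from this commit):

# Old pattern (removed by this commit):
#
#     from vllm.transformers_utils.tokenizer_group import TokenizerGroup
#     tokenizer_group = TokenizerGroup(tokenizer_id="facebook/opt-125m",
#                                      enable_lora=False,
#                                      max_num_seqs=5,
#                                      max_input_length=None)
#     tokenizer = tokenizer_group.tokenizer
#
# New pattern, as used throughout the updated tests below:
from vllm.transformers_utils.tokenizer import get_tokenizer

tokenizer = get_tokenizer("facebook/opt-125m")
print(tokenizer.eos_token_id)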
@@ -1,10 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from unittest.mock import MagicMock

import pytest
from transformers import PreTrainedTokenizer

from vllm.engine.output_processor.stop_checker import StopChecker
from vllm.inputs import token_inputs
@@ -54,10 +51,7 @@ def test_stop_on_eos_token(text_wo_eos: str, eos_token: str, eos_token_id: int,
- When the EOS token should be ignored, and the sequence continues
"""

tokenizer = MagicMock(spec=PreTrainedTokenizer)
get_tokenizer_for_seq = MagicMock(return_value=tokenizer)
stop_checker = StopChecker(max_model_len=1024,
get_tokenizer_for_seq=get_tokenizer_for_seq)
stop_checker = StopChecker(max_model_len=1024)

seq = sequence_with_eos(
text=text_wo_eos,
@@ -58,16 +58,13 @@ def deepseek_r1_qwen_tokenizer():

@pytest.fixture
def stop_checker():
return StopChecker(max_model_len=10,
get_tokenizer_for_seq=deepseek_r1_qwen_tokenizer)
return StopChecker(max_model_len=10)


@pytest.fixture
def stop_checker_with_reasoner():
reasoner = MockReasoningParser(deepseek_r1_qwen_tokenizer)
return StopChecker(max_model_len=10,
get_tokenizer_for_seq=deepseek_r1_qwen_tokenizer,
reasoner=reasoner)
return StopChecker(max_model_len=10, reasoner=reasoner)


def test_eos_token_stopping(stop_checker):
@@ -208,25 +208,3 @@ def zephyr_lora_files():
"""Download zephyr LoRA files once per test session."""
from huggingface_hub import snapshot_download
return snapshot_download(repo_id="typeof/zephyr-7b-beta-lora")


@pytest.fixture(scope="session")
def zephyr_lora_added_tokens_files(zephyr_lora_files):
"""Create zephyr LoRA files with added tokens once per test session."""
import shutil
from tempfile import TemporaryDirectory

from transformers import AutoTokenizer

tmp_dir = TemporaryDirectory()
tmp_model_dir = f"{tmp_dir.name}/zephyr"
shutil.copytree(zephyr_lora_files, tmp_model_dir)
tokenizer = AutoTokenizer.from_pretrained("HuggingFaceH4/zephyr-7b-beta")
# Copy tokenizer to adapter and add some unique tokens
# 32000, 32001, 32002
added = tokenizer.add_tokens(["vllm1", "vllm2", "vllm3"],
special_tokens=True)
assert added == 3
tokenizer.save_pretrained(tmp_model_dir)
yield tmp_model_dir
tmp_dir.cleanup()
@@ -29,11 +29,7 @@ def monkeypatch_module():


@pytest.fixture(scope="module", params=[False, True])
def server(
request,
monkeypatch_module,
zephyr_lora_files, #noqa: F811
zephyr_lora_added_tokens_files): # noqa: F811
def server(request, monkeypatch_module, zephyr_lora_files): #noqa: F811

use_v1 = request.param
monkeypatch_module.setenv('VLLM_USE_V1', '1' if use_v1 else '0')
@@ -49,7 +45,6 @@ def server(
"--enable-lora",
"--lora-modules",
f"zephyr-lora={zephyr_lora_files}",
f"zephyr-lora2={zephyr_lora_added_tokens_files}",
"--max-lora-rank",
"64",
"--max-cpu-loras",
@@ -79,7 +74,7 @@ async def client(server):
@pytest.mark.parametrize(
# first test base model, then test loras
"model_name",
[MODEL_NAME, "zephyr-lora", "zephyr-lora2"],
[MODEL_NAME, "zephyr-lora"],
)
async def test_no_logprobs_chat(client: openai.AsyncOpenAI, model_name: str):
messages = [{
@@ -27,7 +27,7 @@ GUIDED_DECODING_BACKENDS = ["outlines", "xgrammar", "guidance"]


@pytest.fixture(scope="module")
def default_server_args(zephyr_lora_files, zephyr_lora_added_tokens_files):
def default_server_args(zephyr_lora_files):
return [
# use half precision for speed and memory savings in CI environment
"--dtype",
@@ -41,7 +41,6 @@ def default_server_args(zephyr_lora_files, zephyr_lora_added_tokens_files):
"--enable-lora",
"--lora-modules",
f"zephyr-lora={zephyr_lora_files}",
f"zephyr-lora2={zephyr_lora_added_tokens_files}",
"--max-lora-rank",
"64",
"--max-cpu-loras",
@@ -87,7 +86,7 @@ async def client(server):
@pytest.mark.parametrize(
# first test base model, then test loras
"model_name",
[MODEL_NAME, "zephyr-lora", "zephyr-lora2"],
[MODEL_NAME, "zephyr-lora"],
)
async def test_single_completion(client: openai.AsyncOpenAI, model_name: str):
completion = await client.completions.create(model=model_name,
@@ -115,20 +114,6 @@ async def test_single_completion(client: openai.AsyncOpenAI, model_name: str):
assert completion.choices[0].prompt_logprobs is None


@pytest.mark.asyncio
async def test_added_lora_tokens(client: openai.AsyncOpenAI):
# test using token IDs
completion = await client.completions.create(
model="zephyr-lora2",
prompt=[0, 0, 32000, 32001, 32002],
echo=True,
max_tokens=5,
temperature=0.0,
)
# Added tokens should appear in tokenized prompt
assert completion.choices[0].text.startswith("<unk><unk>vllm1vllm2vllm3")


@pytest.mark.asyncio
async def test_added_lora_tokens_base_model(client: openai.AsyncOpenAI):
# test using token IDs
@@ -147,7 +132,7 @@ async def test_added_lora_tokens_base_model(client: openai.AsyncOpenAI):
@pytest.mark.parametrize(
# first test base model, then test loras
"model_name",
[MODEL_NAME, "zephyr-lora", "zephyr-lora2"],
[MODEL_NAME, "zephyr-lora"],
)
async def test_no_logprobs(client: openai.AsyncOpenAI, model_name: str):
# test using token IDs
@@ -713,7 +698,7 @@ async def test_guided_grammar(client: openai.AsyncOpenAI,
@pytest.mark.parametrize(
# first test base model, then test loras
"model_name",
[MODEL_NAME, "zephyr-lora", "zephyr-lora2"],
[MODEL_NAME, "zephyr-lora"],
)
@pytest.mark.parametrize("logprobs_arg", [1, 0])
async def test_echo_logprob_completion(client: openai.AsyncOpenAI,
@@ -21,10 +21,7 @@ CONFIG = AutoConfig.from_pretrained(MODEL_NAME)


@pytest.fixture(scope="module")
def default_server_args(
zephyr_lora_files,
zephyr_lora_added_tokens_files,
) -> list[str]:
def default_server_args() -> list[str]:
return [
# use half precision for speed and memory savings in CI environment
"--dtype",
@@ -67,12 +67,6 @@ def server_with_lora_modules_json(request, monkeypatch_module,
"base_model_name": MODEL_NAME
}

lora_module_2 = {
"name": "zephyr-lora2",
"path": zephyr_lora_files,
"base_model_name": MODEL_NAME
}

args = [
# use half precision for speed and memory savings in CI environment
"--dtype",
@@ -84,7 +78,6 @@ def server_with_lora_modules_json(request, monkeypatch_module,
"--enable-lora",
"--lora-modules",
json.dumps(lora_module_1),
json.dumps(lora_module_2),
"--max-lora-rank",
"64",
"--max-cpu-loras",
@@ -121,7 +114,6 @@ async def test_static_lora_lineage(client: openai.AsyncOpenAI,
for lora_model in lora_models)
assert all(lora_model.parent == MODEL_NAME for lora_model in lora_models)
assert lora_models[0].id == "zephyr-lora"
assert lora_models[1].id == "zephyr-lora2"


@pytest.mark.asyncio
@@ -209,7 +201,7 @@ async def test_dynamic_lora_badrequests(client: openai.AsyncOpenAI, tmp_path,
@pytest.mark.asyncio
async def test_multiple_lora_adapters(client: openai.AsyncOpenAI, tmp_path,
zephyr_lora_files):
"""Validate that many loras can be dynamically registered and inferenced
"""Validate that many loras can be dynamically registered and inferenced
with concurrently"""

# This test file configures the server with --max-cpu-loras=2 and this test
@@ -26,7 +26,6 @@ def server(zephyr_lora_files):
"--enable-lora",
"--lora-modules",
f"zephyr-lora={zephyr_lora_files}",
f"zephyr-lora2={zephyr_lora_files}",
"--max-lora-rank",
"64",
"--max-cpu-loras",
@@ -56,4 +55,3 @@ async def test_check_models(client: openai.AsyncOpenAI, zephyr_lora_files):
assert all(lora_model.root == zephyr_lora_files
for lora_model in lora_models)
assert lora_models[0].id == "zephyr-lora"
assert lora_models[1].id == "zephyr-lora2"
@@ -14,7 +14,7 @@ MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta"


@pytest.fixture(scope="module")
def server(zephyr_lora_added_tokens_files: str): # noqa: F811
def server():
args = [
# use half precision for speed and memory savings in CI environment
"--dtype",
@@ -24,12 +24,6 @@ def server(zephyr_lora_added_tokens_files: str): # noqa: F811
"--enforce-eager",
"--max-num-seqs",
"128",
# lora config
"--enable-lora",
"--lora-modules",
f"zephyr-lora2={zephyr_lora_added_tokens_files}",
"--max-lora-rank",
"64",
"--enable-tokenizer-info-endpoint",
]

@@ -38,10 +32,8 @@ def server(zephyr_lora_added_tokens_files: str): # noqa: F811


@pytest.fixture(scope="module")
def tokenizer_name(model_name: str,
zephyr_lora_added_tokens_files: str): # noqa: F811
return zephyr_lora_added_tokens_files if (
model_name == "zephyr-lora2") else model_name
def tokenizer_name(model_name: str):
return model_name


@pytest_asyncio.fixture
@@ -53,7 +45,7 @@ async def client(server):
@pytest.mark.asyncio
@pytest.mark.parametrize(
"model_name,tokenizer_name",
[(MODEL_NAME, MODEL_NAME), ("zephyr-lora2", "zephyr-lora2")],
[(MODEL_NAME, MODEL_NAME)],
indirect=["tokenizer_name"],
)
async def test_tokenize_completions(
@@ -86,7 +78,7 @@ async def test_tokenize_completions(
@pytest.mark.asyncio
@pytest.mark.parametrize(
"model_name,tokenizer_name",
[(MODEL_NAME, MODEL_NAME), ("zephyr-lora2", "zephyr-lora2")],
[(MODEL_NAME, MODEL_NAME)],
indirect=["tokenizer_name"],
)
async def test_tokenize_chat(
@@ -148,7 +140,7 @@ async def test_tokenize_chat(
@pytest.mark.asyncio
@pytest.mark.parametrize(
"model_name,tokenizer_name",
[(MODEL_NAME, MODEL_NAME), ("zephyr-lora2", "zephyr-lora2")],
[(MODEL_NAME, MODEL_NAME)],
indirect=["tokenizer_name"],
)
async def test_tokenize_chat_with_tools(
@@ -225,7 +217,7 @@ async def test_tokenize_chat_with_tools(
@pytest.mark.asyncio
@pytest.mark.parametrize(
"model_name, tokenizer_name",
[(MODEL_NAME, MODEL_NAME), ("zephyr-lora2", "zephyr-lora2")],
[(MODEL_NAME, MODEL_NAME)],
indirect=["tokenizer_name"],
)
async def test_tokenize_with_return_token_strs(
@@ -260,7 +252,7 @@ async def test_tokenize_with_return_token_strs(
@pytest.mark.asyncio
@pytest.mark.parametrize(
"model_name,tokenizer_name",
[(MODEL_NAME, MODEL_NAME), ("zephyr-lora2", "zephyr-lora2")],
[(MODEL_NAME, MODEL_NAME)],
indirect=["tokenizer_name"],
)
async def test_detokenize(
@@ -287,7 +279,7 @@ async def test_detokenize(
@pytest.mark.asyncio
@pytest.mark.parametrize(
"model_name,tokenizer_name",
[(MODEL_NAME, MODEL_NAME), ("zephyr-lora2", "zephyr-lora2")],
[(MODEL_NAME, MODEL_NAME)],
indirect=["tokenizer_name"],
)
async def test_tokenizer_info_basic(
@@ -384,4 +376,4 @@ async def test_tokenizer_info_chat_template(server: RemoteOpenAIServer):
if chat_template:
assert isinstance(chat_template,
str), ("Chat template should be a string")
assert chat_template.strip(), "Chat template should not be empty"
assert chat_template.strip(), "Chat template should not be empty"
@@ -18,6 +18,8 @@ SERVER_ARGS = [
"--enable-lora",
"--lora-modules",
f"{LORA_MODEL}={LORA_MODEL}",
"--tokenizer",
f"{LORA_MODEL}",
]

TOOLS = [{

@@ -23,7 +23,7 @@ from vllm.entrypoints.chat_utils import (_try_extract_ast, load_chat_template,
from vllm.multimodal import MultiModalDataDict, MultiModalUUIDDict
from vllm.multimodal.utils import (encode_audio_base64, encode_image_base64,
encode_video_base64)
from vllm.transformers_utils.tokenizer_group import TokenizerGroup
from vllm.transformers_utils.tokenizer import get_tokenizer
from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer

from ..models.registry import HF_EXAMPLE_MODELS
@@ -69,12 +69,7 @@ def phi3v_model_config_mm_interleaved():

@pytest.fixture(scope="module")
def phi3v_tokenizer():
return TokenizerGroup(
tokenizer_id=PHI3V_MODEL_ID,
enable_lora=False,
max_num_seqs=5,
max_input_length=None,
)
return get_tokenizer(PHI3V_MODEL_ID)


@pytest.fixture(scope="function")
@@ -91,12 +86,7 @@ def qwen2_audio_model_config():

@pytest.fixture(scope="module")
def qwen2_audio_tokenizer():
return TokenizerGroup(
tokenizer_id=QWEN2AUDIO_MODEL_ID,
enable_lora=False,
max_num_seqs=5,
max_input_length=None,
)
return get_tokenizer(QWEN2AUDIO_MODEL_ID)


@pytest.fixture(scope="function")
@@ -115,12 +105,7 @@ def qwen25omni_model_config_mm_interleaved():

@pytest.fixture(scope="module")
def qwen25omni_tokenizer():
return TokenizerGroup(
tokenizer_id=QWEN25OMNI_MODEL_ID,
enable_lora=False,
max_num_seqs=5,
max_input_length=None,
)
return get_tokenizer(QWEN25OMNI_MODEL_ID)


@pytest.fixture(scope="function")
@@ -136,12 +121,7 @@ def mistral_model_config():

@pytest.fixture(scope="module")
def mistral_tokenizer():
return TokenizerGroup(
tokenizer_id=MISTRAL_MODEL_ID,
enable_lora=False,
max_num_seqs=5,
max_input_length=None,
)
return get_tokenizer(MISTRAL_MODEL_ID)


@pytest.fixture(scope="module")
@ -2250,15 +2230,11 @@ def test_resolve_hf_chat_template(sample_json_schema, model, use_tools):
|
||||
enforce_eager=model_info.enforce_eager,
|
||||
dtype=model_info.dtype)
|
||||
|
||||
# Build the tokenizer group and grab the underlying tokenizer
|
||||
tokenizer_group = TokenizerGroup(
|
||||
# Build the tokenizer
|
||||
tokenizer = get_tokenizer(
|
||||
model,
|
||||
enable_lora=False,
|
||||
max_num_seqs=5,
|
||||
max_input_length=None,
|
||||
trust_remote_code=model_config.trust_remote_code,
|
||||
)
|
||||
tokenizer = tokenizer_group.tokenizer
|
||||
|
||||
tools = ([{
|
||||
"type": "function",
|
||||
@ -2307,14 +2283,10 @@ def test_resolve_content_format_hf_defined(model, expected_format):
|
||||
enforce_eager=model_info.enforce_eager,
|
||||
dtype=model_info.dtype)
|
||||
|
||||
tokenizer_group = TokenizerGroup(
|
||||
tokenizer = get_tokenizer(
|
||||
model,
|
||||
enable_lora=False,
|
||||
max_num_seqs=5,
|
||||
max_input_length=None,
|
||||
trust_remote_code=model_config.trust_remote_code,
|
||||
)
|
||||
tokenizer = tokenizer_group.tokenizer
|
||||
|
||||
# Test detecting the tokenizer's chat_template
|
||||
chat_template = resolve_hf_chat_template(
|
||||
@ -2368,14 +2340,10 @@ def test_resolve_content_format_fallbacks(model, expected_format):
|
||||
enforce_eager=model_info.enforce_eager,
|
||||
dtype=model_info.dtype)
|
||||
|
||||
tokenizer_group = TokenizerGroup(
|
||||
tokenizer = get_tokenizer(
|
||||
model_config.tokenizer,
|
||||
enable_lora=False,
|
||||
max_num_seqs=5,
|
||||
max_input_length=None,
|
||||
trust_remote_code=model_config.trust_remote_code,
|
||||
)
|
||||
tokenizer = tokenizer_group.tokenizer
|
||||
|
||||
# Test detecting the tokenizer's chat_template
|
||||
chat_template = resolve_hf_chat_template(
|
||||
@ -2432,14 +2400,10 @@ def test_resolve_content_format_examples(template_path, expected_format):
|
||||
trust_remote_code=True,
|
||||
)
|
||||
|
||||
tokenizer_group = TokenizerGroup(
|
||||
dummy_tokenizer = get_tokenizer(
|
||||
PHI3V_MODEL_ID, # Dummy
|
||||
enable_lora=False,
|
||||
max_num_seqs=5,
|
||||
max_input_length=None,
|
||||
trust_remote_code=model_config.trust_remote_code,
|
||||
)
|
||||
dummy_tokenizer = tokenizer_group.tokenizer
|
||||
dummy_tokenizer.chat_template = None
|
||||
|
||||
chat_template = load_chat_template(EXAMPLES_DIR / template_path)
|
||||
|
@ -13,14 +13,6 @@ from ..utils import VLLM_PATH, create_new_process_for_each_test, multi_gpu_test
|
||||
|
||||
MODEL_PATH = "meta-llama/Llama-2-7b-hf"
|
||||
|
||||
EXPECTED_NO_LORA_OUTPUT = [
|
||||
"\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_75 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_76 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_77 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_78 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user]", # noqa: E501
|
||||
" Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_11 (nationality VARCHAR, elector VARCHAR)\n\n question: When Anchero Pantaleone was the elector what is under nationality? ", # noqa: E501
|
||||
"\n\n answer: 1\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_96 (one_mora VARCHAR, gloss VARCHAR, accented_mora VARCHAR)\n\n question: What is the one mora for a high tone mora with a gloss of /˧kot/ [kòt]? [/user] [assistant]\n\n answer: 2\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_97 (one_mora VARCHAR, gloss VARCHAR, accented_mora VARCHAR)\n\n question: What is the one mora for a high tone mora with a gloss of /˧kot/ [kòt]? [/user] [assistant]\n\n answer: 2\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_98 (one_mora VARCHAR, gloss VARCHAR, accented_mora VARCHAR)\n\n question: What is the one m", # noqa: E501
|
||||
" Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE candidate (people_id VARCHAR, unsure_rate INTEGER); CREATE TABLE people (sex VARCHAR, people_id VARCHAR)\n\n question: which gender got the highest average uncertain ratio. ", # noqa: E501
|
||||
" Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_60 (pick INTEGER, former_wnba_team VARCHAR)\n\n question: What pick was a player that previously played for the Minnesota Lynx? ", # noqa: E501
|
||||
"\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_28138035_4 (womens_doubles VARCHAR, mens_singles VARCHAR)\n\n question: Name the women's doubles for werner schlager [/user] [assistant]\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_28138035_4 (womens_doubles VARCHAR, mens_singles VARCHAR)\n\n question: Name the women's doubles for werner schlager [/user] [assistant]\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_28138035_4 (womens_doubles VARCHAR, mens_singles VARCHAR)\n\n question: Name the women's doubles for werner schlager [/user] [assistant]\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE", # noqa: E501
|
||||
]
|
||||
EXPECTED_LORA_OUTPUT = [
|
||||
" SELECT icao FROM table_name_74 WHERE airport = 'lilongwe international airport' ", # noqa: E501
|
||||
" SELECT nationality FROM table_name_11 WHERE elector = 'anchero pantaleone' ", # noqa: E501
|
||||
@ -79,23 +71,12 @@ def generate_and_test(llm,
|
||||
sql_lora_files,
|
||||
tensorizer_config_dict: Union[dict, None] = None):
|
||||
print("lora adapter created")
|
||||
assert do_sample(llm,
|
||||
sql_lora_files,
|
||||
tensorizer_config_dict=tensorizer_config_dict,
|
||||
lora_id=0) == EXPECTED_NO_LORA_OUTPUT
|
||||
|
||||
print("lora 1")
|
||||
assert do_sample(llm,
|
||||
sql_lora_files,
|
||||
tensorizer_config_dict=tensorizer_config_dict,
|
||||
lora_id=1) == EXPECTED_LORA_OUTPUT
|
||||
|
||||
print("no lora")
|
||||
assert do_sample(llm,
|
||||
sql_lora_files,
|
||||
tensorizer_config_dict=tensorizer_config_dict,
|
||||
lora_id=0) == EXPECTED_NO_LORA_OUTPUT
|
||||
|
||||
print("lora 2")
|
||||
assert do_sample(llm,
|
||||
sql_lora_files,
|
||||
@ -110,6 +91,7 @@ def test_llama_lora(sql_lora_files):
|
||||
|
||||
llm = vllm.LLM(
|
||||
MODEL_PATH,
|
||||
tokenizer=sql_lora_files,
|
||||
enable_lora=True,
|
||||
# also test odd max_num_seqs
|
||||
max_num_seqs=13,
|
||||
@ -123,6 +105,7 @@ def test_llama_lora_tp4(sql_lora_files):
|
||||
|
||||
llm = vllm.LLM(
|
||||
MODEL_PATH,
|
||||
tokenizer=sql_lora_files,
|
||||
enable_lora=True,
|
||||
max_num_seqs=16,
|
||||
max_loras=4,
|
||||
@ -137,6 +120,7 @@ def test_llama_lora_tp4_fully_sharded_loras(sql_lora_files):
|
||||
|
||||
llm = vllm.LLM(
|
||||
MODEL_PATH,
|
||||
tokenizer=sql_lora_files,
|
||||
enable_lora=True,
|
||||
max_num_seqs=16,
|
||||
max_loras=4,
|
||||
@ -184,6 +168,7 @@ def test_tp2_serialize_and_deserialize_lora(tmp_path, sql_lora_files,
|
||||
tensorizer_config = TensorizerConfig(tensorizer_uri=str(model_uri))
|
||||
|
||||
loaded_llm = LLM(model=model_ref,
|
||||
tokenizer=sql_lora_files,
|
||||
load_format="tensorizer",
|
||||
enable_lora=True,
|
||||
enforce_eager=True,
|
||||
@ -195,11 +180,6 @@ def test_tp2_serialize_and_deserialize_lora(tmp_path, sql_lora_files,
|
||||
tc_as_dict = tensorizer_config.to_serializable()
|
||||
|
||||
print("lora adapter created")
|
||||
assert do_sample(loaded_llm,
|
||||
sql_lora_files,
|
||||
tensorizer_config_dict=tc_as_dict,
|
||||
lora_id=0) == EXPECTED_NO_LORA_OUTPUT
|
||||
|
||||
print("lora 1")
|
||||
assert do_sample(loaded_llm,
|
||||
sql_lora_files,
|
||||
|
@ -1,135 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm.config import CacheConfig, DeviceConfig, ModelConfig, VllmConfig
|
||||
from vllm.config.lora import LoRAConfig
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
|
||||
from vllm.v1.engine.processor import Processor
|
||||
|
||||
|
||||
def test_allowed_token_ids_with_lora_vocab(llama_2_7b_base_huggingface_id,
|
||||
sql_lora_files):
|
||||
"""
|
||||
Test that we properly resolve the range of allowed token ids for lora
|
||||
adapters that define additional tokens.
|
||||
"""
|
||||
|
||||
# Set up a base model compatible with the sql_lora_files adapter and
|
||||
# a known number of tokens in the base model.
|
||||
model_config = ModelConfig(
|
||||
model=llama_2_7b_base_huggingface_id,
|
||||
tokenizer=llama_2_7b_base_huggingface_id,
|
||||
tokenizer_mode="auto",
|
||||
)
|
||||
|
||||
vllm_config = VllmConfig(
|
||||
model_config=model_config,
|
||||
cache_config=CacheConfig(),
|
||||
device_config=DeviceConfig(),
|
||||
lora_config=LoRAConfig(),
|
||||
)
|
||||
|
||||
tokenizer = init_tokenizer_from_configs(
|
||||
model_config=vllm_config.model_config,
|
||||
scheduler_config=vllm_config.scheduler_config,
|
||||
lora_config=vllm_config.lora_config)
|
||||
processor = Processor(vllm_config, tokenizer)
|
||||
|
||||
lora_request = LoRARequest("1", 1, str(sql_lora_files))
|
||||
request_id = "1"
|
||||
prompt = "a prompt"
|
||||
|
||||
# tokens added in the lora adapter should not raise an error
|
||||
lora_token_ids = [32000, 32001, 32002, 32003]
|
||||
processor.process_inputs(
|
||||
request_id,
|
||||
prompt,
|
||||
params=SamplingParams(allowed_token_ids=lora_token_ids),
|
||||
lora_request=lora_request)
|
||||
|
||||
# tokens in the base model should not raise an error
|
||||
base_token_ids = [1000, 1001, 1002, 1003]
|
||||
processor.process_inputs(
|
||||
request_id,
|
||||
prompt,
|
||||
params=SamplingParams(allowed_token_ids=base_token_ids),
|
||||
lora_request=lora_request)
|
||||
|
||||
# tokens not in the lora adapter should raise an error
|
||||
invalid_token_ids = [35000, 35001, 35002, 35003]
|
||||
with pytest.raises(ValueError):
|
||||
processor.process_inputs(
|
||||
request_id,
|
||||
prompt,
|
||||
params=SamplingParams(allowed_token_ids=invalid_token_ids),
|
||||
lora_request=lora_request)
|
||||
|
||||
# tokens in the lora adapter with no lora request should raise an error
|
||||
with pytest.raises(ValueError):
|
||||
processor.process_inputs(
|
||||
request_id,
|
||||
prompt,
|
||||
params=SamplingParams(allowed_token_ids=lora_token_ids),
|
||||
)
|
||||
|
||||
|
||||
def test_allowed_token_ids_with_lora_adapter_no_vocab(
|
||||
qwen25vl_base_huggingface_id, qwen25vl_lora_files):
|
||||
"""
|
||||
Test that we properly resolve the range of allowed token ids for lora
|
||||
adapters that do not define additional tokens.
|
||||
"""
|
||||
|
||||
# Set up a base model compatible with the qwen25vl_lora_files adapter and
|
||||
# a known number of tokens in the base model.
|
||||
model_config = ModelConfig(
|
||||
model=qwen25vl_base_huggingface_id,
|
||||
tokenizer=qwen25vl_base_huggingface_id,
|
||||
tokenizer_mode="auto",
|
||||
)
|
||||
|
||||
vllm_config = VllmConfig(
|
||||
model_config=model_config,
|
||||
cache_config=CacheConfig(),
|
||||
device_config=DeviceConfig(),
|
||||
lora_config=LoRAConfig(),
|
||||
)
|
||||
|
||||
tokenizer = init_tokenizer_from_configs(
|
||||
model_config=vllm_config.model_config,
|
||||
scheduler_config=vllm_config.scheduler_config,
|
||||
lora_config=vllm_config.lora_config)
|
||||
processor = Processor(vllm_config, tokenizer)
|
||||
|
||||
lora_request = LoRARequest("1", 1, str(qwen25vl_lora_files))
|
||||
request_id = "1"
|
||||
prompt = "a prompt"
|
||||
|
||||
# tokens in the base model should not raise an error
|
||||
base_token_ids = [1000, 1001, 1002, 1003]
|
||||
processor.process_inputs(
|
||||
request_id,
|
||||
prompt,
|
||||
params=SamplingParams(allowed_token_ids=base_token_ids),
|
||||
lora_request=lora_request)
|
||||
|
||||
# tokens in the base model with no lora request should not raise an error
|
||||
base_token_ids = [1000, 1001, 1002, 1003]
|
||||
processor.process_inputs(
|
||||
request_id,
|
||||
prompt,
|
||||
params=SamplingParams(allowed_token_ids=base_token_ids),
|
||||
)
|
||||
|
||||
# tokens not in the base model should raise an error
|
||||
invalid_token_ids = [200000, 200001, 200002, 200003]
|
||||
with pytest.raises(ValueError):
|
||||
processor.process_inputs(
|
||||
request_id,
|
||||
prompt,
|
||||
params=SamplingParams(allowed_token_ids=invalid_token_ids),
|
||||
lora_request=lora_request)
|
@ -82,31 +82,20 @@ def test_quant_model_lora(tinyllama_lora_files, model):
|
||||
gpu_memory_utilization=0.2, #avoid OOM
|
||||
quantization=model.quantization,
|
||||
trust_remote_code=True,
|
||||
enable_chunked_prefill=True)
|
||||
enable_chunked_prefill=True,
|
||||
tokenizer=tinyllama_lora_files)
|
||||
|
||||
if model.quantization is None:
|
||||
expected_no_lora_output = [
|
||||
"Here are some examples of orange-brown colors",
|
||||
"I'm sorry, I don't have"
|
||||
]
|
||||
expected_lora_output = [
|
||||
"#ff8050",
|
||||
"#ff8080",
|
||||
]
|
||||
elif model.quantization == "awq":
|
||||
expected_no_lora_output = [
|
||||
"I'm sorry, I don't understand",
|
||||
"I'm sorry, I don't understand",
|
||||
]
|
||||
expected_lora_output = [
|
||||
"#f07700: A v",
|
||||
"#f00000: A v",
|
||||
]
|
||||
elif model.quantization == "gptq":
|
||||
expected_no_lora_output = [
|
||||
"I'm sorry, I don't have",
|
||||
"I'm sorry, I don't have",
|
||||
]
|
||||
expected_lora_output = [
|
||||
"#f08800: This is",
|
||||
"#f07788 \n#",
|
||||
@ -117,7 +106,6 @@ def test_quant_model_lora(tinyllama_lora_files, model):
|
||||
# Assert that the outputs changed.
|
||||
if (model.quantization == "gptq"
|
||||
and expected_output is expected_lora_output):
|
||||
assert output != expected_no_lora_output
|
||||
for i, o in enumerate(output):
|
||||
assert o.startswith(
|
||||
'#'), f"Expected example {i} to start with # but got {o}"
|
||||
@ -127,12 +115,6 @@ def test_quant_model_lora(tinyllama_lora_files, model):
|
||||
max_tokens = 10
|
||||
|
||||
print("lora adapter created")
|
||||
output = do_sample(llm,
|
||||
tinyllama_lora_files,
|
||||
lora_id=0,
|
||||
max_tokens=max_tokens)
|
||||
expect_match(output, expected_no_lora_output)
|
||||
|
||||
print("lora 1")
|
||||
output = do_sample(llm,
|
||||
tinyllama_lora_files,
|
||||
@ -140,13 +122,6 @@ def test_quant_model_lora(tinyllama_lora_files, model):
|
||||
max_tokens=max_tokens)
|
||||
expect_match(output, expected_lora_output)
|
||||
|
||||
print("no lora")
|
||||
output = do_sample(llm,
|
||||
tinyllama_lora_files,
|
||||
lora_id=0,
|
||||
max_tokens=max_tokens)
|
||||
expect_match(output, expected_no_lora_output)
|
||||
|
||||
print("lora 2")
|
||||
output = do_sample(llm,
|
||||
tinyllama_lora_files,
|
||||
|
@ -1,72 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import pytest
|
||||
from transformers import AutoTokenizer, PreTrainedTokenizerBase
|
||||
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.transformers_utils.tokenizer import get_lora_tokenizer
|
||||
from vllm.transformers_utils.tokenizer_group import TokenizerGroup
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("tokenizer_group_type", [None, "ray"])
|
||||
async def test_tokenizer_group_lora(sql_lora_files, tokenizer_group_type):
|
||||
reference_tokenizer = AutoTokenizer.from_pretrained(sql_lora_files)
|
||||
tokenizer_group = TokenizerGroup(
|
||||
tokenizer_id="gpt2",
|
||||
enable_lora=True,
|
||||
max_num_seqs=1,
|
||||
max_loras=1,
|
||||
max_input_length=None,
|
||||
)
|
||||
lora_request = LoRARequest("1", 1, sql_lora_files)
|
||||
assert reference_tokenizer.encode("prompt") == tokenizer_group.encode(
|
||||
prompt="prompt", lora_request=lora_request)
|
||||
assert reference_tokenizer.encode(
|
||||
"prompt") == await tokenizer_group.encode_async(
|
||||
prompt="prompt", lora_request=lora_request)
|
||||
assert isinstance(tokenizer_group.get_lora_tokenizer(None),
|
||||
PreTrainedTokenizerBase)
|
||||
assert tokenizer_group.get_lora_tokenizer(
|
||||
None) == await tokenizer_group.get_lora_tokenizer_async(None)
|
||||
|
||||
assert isinstance(tokenizer_group.get_lora_tokenizer(lora_request),
|
||||
PreTrainedTokenizerBase)
|
||||
assert tokenizer_group.get_lora_tokenizer(
|
||||
lora_request) != tokenizer_group.get_lora_tokenizer(None)
|
||||
assert tokenizer_group.get_lora_tokenizer(
|
||||
lora_request) == await tokenizer_group.get_lora_tokenizer_async(
|
||||
lora_request)
|
||||
|
||||
|
||||
def test_get_lora_tokenizer(sql_lora_files, tmp_path):
|
||||
lora_request = None
|
||||
tokenizer = get_lora_tokenizer(lora_request)
|
||||
assert not tokenizer
|
||||
|
||||
lora_request = LoRARequest("1", 1, sql_lora_files)
|
||||
tokenizer = get_lora_tokenizer(lora_request)
|
||||
assert tokenizer.get_added_vocab()
|
||||
|
||||
lora_request = LoRARequest("1", 1, str(tmp_path))
|
||||
tokenizer = get_lora_tokenizer(lora_request)
|
||||
assert not tokenizer
|
||||
|
||||
|
||||
@pytest.mark.parametrize("enable_lora", [True, False])
|
||||
@pytest.mark.parametrize("max_num_seqs", [1, 2])
|
||||
@pytest.mark.parametrize("max_loras", [1, 2])
|
||||
def test_lora_tokenizers(enable_lora, max_num_seqs, max_loras):
|
||||
tokenizer_group = TokenizerGroup(
|
||||
tokenizer_id="gpt2",
|
||||
enable_lora=enable_lora,
|
||||
max_num_seqs=max_num_seqs,
|
||||
max_loras=max_loras,
|
||||
max_input_length=None,
|
||||
)
|
||||
if enable_lora:
|
||||
assert tokenizer_group.lora_tokenizers.capacity == max(
|
||||
max_num_seqs, max_loras)
|
||||
else:
|
||||
assert tokenizer_group.lora_tokenizers.capacity == 0
|
@ -11,7 +11,7 @@ import pytest
|
||||
from vllm.inputs import token_inputs
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.sequence import Sequence
|
||||
from vllm.transformers_utils.tokenizer_group import TokenizerGroup
|
||||
from vllm.transformers_utils.tokenizer import get_tokenizer
|
||||
|
||||
# Make two prefixes with different first blocks.
|
||||
prefix_start = [("You are an expert"), ("You are a")]
|
||||
@ -47,12 +47,7 @@ def flatten_2d(li):
|
||||
def test_auto_prefix_caching(model: str, block_size: int, max_num_seqs: int,
|
||||
concurrent_lora_int_ids: list[Optional[int]]):
|
||||
|
||||
tokenizer = TokenizerGroup(
|
||||
tokenizer_id="facebook/opt-125m",
|
||||
enable_lora=False,
|
||||
max_num_seqs=max_num_seqs,
|
||||
max_input_length=None,
|
||||
)
|
||||
tokenizer = get_tokenizer("facebook/opt-125m")
|
||||
|
||||
hashes: list[list[list[int]]] = []
|
||||
|
||||
@ -76,7 +71,7 @@ def test_auto_prefix_caching(model: str, block_size: int, max_num_seqs: int,
|
||||
inputs=token_inputs(prompt_token_ids,
|
||||
prompt=prompt),
|
||||
block_size=block_size,
|
||||
eos_token_id=tokenizer.tokenizer.eos_token_id,
|
||||
eos_token_id=tokenizer.eos_token_id,
|
||||
lora_request=lora_request)
|
||||
|
||||
num_blocks = len(prompt_token_ids) // block_size
|
||||
|
@ -11,7 +11,7 @@ from transformers import (AutoTokenizer, PreTrainedTokenizer,
|
||||
from vllm.inputs import token_inputs
|
||||
from vllm.sequence import Logprob, SamplingParams, Sequence, SequenceGroup
|
||||
from vllm.transformers_utils.detokenizer import Detokenizer
|
||||
from vllm.transformers_utils.tokenizer_group import TokenizerGroup
|
||||
from vllm.transformers_utils.tokenizer import get_tokenizer
|
||||
from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer
|
||||
from vllm.v1.engine import EngineCoreRequest
|
||||
from vllm.v1.engine.detokenizer import (FastIncrementalDetokenizer,
|
||||
@ -221,17 +221,14 @@ def test_oov_decode(tokenizer, fast):
|
||||
|
||||
@pytest.fixture
|
||||
def detokenizer(tokenizer_name: str) -> Detokenizer:
|
||||
tokenizer_group = TokenizerGroup(
|
||||
tokenizer_id=tokenizer_name,
|
||||
enable_lora=False,
|
||||
max_num_seqs=100,
|
||||
max_input_length=None,
|
||||
tokenizer = get_tokenizer(
|
||||
tokenizer_name,
|
||||
tokenizer_mode="mistral" if "mistral" in tokenizer_name else "auto",
|
||||
trust_remote_code=False,
|
||||
revision=None,
|
||||
)
|
||||
|
||||
return Detokenizer(tokenizer_group)
|
||||
return Detokenizer(tokenizer)
|
||||
|
||||
|
||||
@pytest.fixture(name="complete_sequence_token_ids")
|
||||
@ -312,8 +309,7 @@ def test_decode_prompt_logprobs(complete_sequence: str,
|
||||
# don't support that.
|
||||
if complete_sequence not in SPECIAL_TOKS_TRUTH:
|
||||
skip_special_tokens = True
|
||||
elif not isinstance(detokenizer.tokenizer_group.get_lora_tokenizer(None),
|
||||
MistralTokenizer):
|
||||
elif not isinstance(detokenizer.tokenizer, MistralTokenizer):
|
||||
skip_special_tokens = False
|
||||
else:
|
||||
pytest.skip("MistralTokenizers don't support "
|
||||
@ -339,7 +335,7 @@ def test_decode_prompt_logprobs(complete_sequence: str,
|
||||
|
||||
# decoded_prompt_logprobs doesn't contain the first token.
|
||||
token_ids = complete_sequence_token_ids
|
||||
tokenizer = detokenizer.get_tokenizer_for_seq(seq)
|
||||
tokenizer = detokenizer.tokenizer
|
||||
text_full = tokenizer.decode(token_ids,
|
||||
skip_special_tokens=skip_special_tokens)
|
||||
text_first = tokenizer.decode(token_ids[0],
|
||||
|
@ -1,27 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import pytest
|
||||
from transformers import AutoTokenizer, PreTrainedTokenizerBase
|
||||
|
||||
from vllm.transformers_utils.tokenizer_group import TokenizerGroup
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tokenizer_group():
|
||||
reference_tokenizer = AutoTokenizer.from_pretrained("gpt2")
|
||||
tokenizer_group = TokenizerGroup(
|
||||
tokenizer_id="gpt2",
|
||||
enable_lora=False,
|
||||
max_num_seqs=1,
|
||||
max_input_length=None,
|
||||
)
|
||||
assert reference_tokenizer.encode("prompt") == tokenizer_group.encode(
|
||||
prompt="prompt", lora_request=None)
|
||||
assert reference_tokenizer.encode(
|
||||
"prompt") == await tokenizer_group.encode_async(prompt="prompt",
|
||||
lora_request=None)
|
||||
assert isinstance(tokenizer_group.get_lora_tokenizer(None),
|
||||
PreTrainedTokenizerBase)
|
||||
assert tokenizer_group.get_lora_tokenizer(
|
||||
None) == await tokenizer_group.get_lora_tokenizer_async(None)
|
@ -57,6 +57,10 @@ class TestTokenizer(TokenizerBase):
|
||||
def max_token_id(self) -> int:
|
||||
raise NotImplementedError()
|
||||
|
||||
@property
|
||||
def truncation_side(self) -> str:
|
||||
raise NotImplementedError()
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
text: Union[str, list[str], list[int]],
|
||||
|
@ -12,7 +12,6 @@ from tests.v1.engine.utils import (NUM_PROMPT_LOGPROBS_UNDER_TEST,
|
||||
generate_dummy_prompt_logprobs_tensors,
|
||||
generate_dummy_sample_logprobs)
|
||||
from vllm.engine.arg_utils import EngineArgs
|
||||
from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
|
||||
|
||||
from ...distributed.conftest import publisher_config, random_port # noqa: F401
|
||||
|
||||
@ -24,7 +23,7 @@ EngineCorePromptLogprobsType = tuple[torch.Tensor, torch.Tensor]
|
||||
|
||||
def _build_test_vectors_no_logprobs() -> DummyOutputProcessorTestVectors:
|
||||
"""Generate output processor dummy test vectors, without logprobs
|
||||
|
||||
|
||||
Returns:
|
||||
DummyOutputProcessorTestVectors instance with no logprobs
|
||||
"""
|
||||
@ -48,9 +47,6 @@ def _build_test_vectors_no_logprobs() -> DummyOutputProcessorTestVectors:
|
||||
]
|
||||
return DummyOutputProcessorTestVectors(
|
||||
tokenizer=tokenizer,
|
||||
tokenizer_group=init_tokenizer_from_configs(
|
||||
vllm_config.model_config, vllm_config.scheduler_config,
|
||||
vllm_config.lora_config),
|
||||
vllm_config=vllm_config,
|
||||
full_tokens=[tokenizer(text).input_ids for text in FULL_STRINGS],
|
||||
prompt_tokens=prompt_tokens,
|
||||
@ -68,7 +64,7 @@ def _build_test_vectors_no_logprobs() -> DummyOutputProcessorTestVectors:
|
||||
@pytest.fixture
|
||||
def dummy_test_vectors() -> DummyOutputProcessorTestVectors:
|
||||
"""Generate output processor dummy test vectors, with logprobs
|
||||
|
||||
|
||||
Returns:
|
||||
DummyOutputProcessorTestVectors instance with logprobs
|
||||
"""
|
||||
|
@ -43,7 +43,7 @@ def _ref_convert_id_to_token(
|
||||
[RequestOutputKind.DELTA, RequestOutputKind.FINAL_ONLY])
|
||||
def test_incremental_detokenization(request_output_kind: RequestOutputKind,
|
||||
dummy_test_vectors):
|
||||
output_processor = OutputProcessor(dummy_test_vectors.tokenizer_group,
|
||||
output_processor = OutputProcessor(dummy_test_vectors.tokenizer,
|
||||
log_stats=False)
|
||||
engine_core = MockEngineCore(
|
||||
tokens_list=dummy_test_vectors.generation_tokens)
|
||||
@ -382,7 +382,7 @@ def test_logprobs_processor(request_output_kind: RequestOutputKind,
|
||||
num_sample_logprobs: Optional[int],
|
||||
num_prompt_logprobs: Optional[int],
|
||||
dummy_test_vectors):
|
||||
output_processor = OutputProcessor(dummy_test_vectors.tokenizer_group,
|
||||
output_processor = OutputProcessor(dummy_test_vectors.tokenizer,
|
||||
log_stats=False)
|
||||
engine_core = MockEngineCore(
|
||||
tokens_list=dummy_test_vectors.generation_tokens,
|
||||
@ -535,7 +535,7 @@ def test_stop_token(include_stop_str_in_output: bool,
|
||||
) # '<|end_of_text|>'
|
||||
stop_token_ids = [128009] if not is_eos_test else None # '<|eot_id|>'
|
||||
|
||||
output_processor = OutputProcessor(dummy_test_vectors.tokenizer_group,
|
||||
output_processor = OutputProcessor(dummy_test_vectors.tokenizer,
|
||||
log_stats=False)
|
||||
# Dummy engine core outputs, with control tokens suffixed to test stops
|
||||
suffix_token = ([eos_token_id] if is_eos_test else stop_token_ids)
|
||||
@ -642,7 +642,7 @@ def test_stop_token(include_stop_str_in_output: bool,
|
||||
[None, NUM_SAMPLE_LOGPROBS_UNDER_TEST])
|
||||
def test_stop_string(include_stop_str_in_output: bool,
|
||||
num_sample_logprobs: Optional[int], dummy_test_vectors):
|
||||
output_processor = OutputProcessor(dummy_test_vectors.tokenizer_group,
|
||||
output_processor = OutputProcessor(dummy_test_vectors.tokenizer,
|
||||
log_stats=False)
|
||||
engine_core = MockEngineCore(
|
||||
tokens_list=dummy_test_vectors.generation_tokens,
|
||||
@ -763,7 +763,7 @@ def test_stop_string(include_stop_str_in_output: bool,
|
||||
|
||||
|
||||
def test_iteration_stats(dummy_test_vectors):
|
||||
output_processor = OutputProcessor(dummy_test_vectors.tokenizer_group,
|
||||
output_processor = OutputProcessor(dummy_test_vectors.tokenizer,
|
||||
log_stats=True)
|
||||
engine_core = MockEngineCore(dummy_test_vectors.generation_tokens)
|
||||
engine_core_timestamp = time.monotonic()
|
||||
|
@ -9,7 +9,6 @@ import torch
|
||||
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
|
||||
|
||||
from vllm.engine.arg_utils import EngineArgs
|
||||
from vllm.transformers_utils.tokenizer_group import TokenizerGroup
|
||||
from vllm.v1.engine import EngineCoreOutput, FinishReason
|
||||
from vllm.v1.outputs import LogprobsLists, LogprobsTensors
|
||||
|
||||
@ -39,7 +38,7 @@ def _create_random_top_logprob_test_vector(
|
||||
upper: float,
|
||||
) -> torch.Tensor:
|
||||
"""Create a random vector of top logprob float values.
|
||||
|
||||
|
||||
Use to create fake sample logprobs for testing.
|
||||
|
||||
Note that a real production scenario would require
|
||||
@ -63,7 +62,7 @@ def _create_random_top_logprob_test_matrix(
|
||||
upper: float,
|
||||
) -> torch.Tensor:
|
||||
"""Create a random matrix of top logprob float values.
|
||||
|
||||
|
||||
Use to create fake prompt logprobs for testing.
|
||||
|
||||
Note that a real production scenario would require
|
||||
@ -296,7 +295,6 @@ def generate_dummy_prompt_logprobs_tensors(
|
||||
class DummyOutputProcessorTestVectors:
|
||||
"""Dummy test vectors for output processor tests"""
|
||||
tokenizer: GeneralTokenizerType
|
||||
tokenizer_group: TokenizerGroup
|
||||
vllm_config: EngineArgs
|
||||
full_tokens: list[list[int]] # Prompt + generated tokens
|
||||
prompt_tokens: list[list[int]]
|
||||
|
@ -582,7 +582,7 @@ def test_structured_output_with_reasoning_matrices(
|
||||
reasoning_parser=reasoning_parser,
|
||||
speculative_config=speculative_config,
|
||||
)
|
||||
tokenizer = llm.get_tokenizer(None)
|
||||
tokenizer = llm.get_tokenizer()
|
||||
reasoner = ReasoningParserManager.get_reasoning_parser(reasoning_parser)(
|
||||
tokenizer=tokenizer)
|
||||
|
||||
|
@ -37,7 +37,7 @@ from vllm.lora.request import LoRARequest
|
||||
from vllm.lora.utils import get_adapter_absolute_path
|
||||
from vllm.multimodal import MultiModalDataDict
|
||||
from vllm.multimodal.image import convert_image_mode
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer, get_lora_tokenizer
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||
from vllm.utils import PlaceholderModule
|
||||
|
||||
try:
|
||||
@ -100,8 +100,8 @@ class BenchmarkDataset(ABC):
|
||||
) -> None:
|
||||
"""
|
||||
Initialize the BenchmarkDataset with an optional dataset path and random
|
||||
seed.
|
||||
|
||||
seed.
|
||||
|
||||
Args:
|
||||
dataset_path (Optional[str]): Path to the dataset. If None, it
|
||||
indicates that a default or random dataset might be used.
|
||||
@ -133,10 +133,10 @@ class BenchmarkDataset(ABC):
|
||||
elif isinstance(mm_content, dict):
|
||||
content.append(mm_content)
|
||||
else:
|
||||
raise TypeError(
|
||||
raise TypeError(
|
||||
"Could not process multimodal content of type: " +
|
||||
f"{type(mm_content)}"
|
||||
)
|
||||
f"{type(mm_content)}"
|
||||
)
|
||||
return [{"role": "user", "content": content}]
|
||||
|
||||
def load_data(self) -> None:
|
||||
@ -155,34 +155,26 @@ class BenchmarkDataset(ABC):
|
||||
|
||||
def get_random_lora_request(
|
||||
self,
|
||||
tokenizer: PreTrainedTokenizerBase,
|
||||
max_loras: Optional[int] = None,
|
||||
lora_path: Optional[str] = None,
|
||||
) -> tuple[Optional[LoRARequest], AnyTokenizer]:
|
||||
) -> Optional[LoRARequest]:
|
||||
"""
|
||||
Optionally select a random LoRA request and return its associated
|
||||
tokenizer.
|
||||
Optionally select a random LoRA request.
|
||||
|
||||
This method is used when LoRA parameters are provided. It randomly
|
||||
selects a LoRA based on max_loras and retrieves a cached tokenizer for
|
||||
that LoRA if available. Otherwise, it returns the base tokenizer.
|
||||
selects a LoRA based on max_loras.
|
||||
|
||||
Args:
|
||||
tokenizer (PreTrainedTokenizerBase): The base tokenizer to use if no
|
||||
LoRA is selected.
|
||||
max_loras (Optional[int]): The maximum number of LoRAs available.
|
||||
If `None`, LoRA is not used.
|
||||
lora_path (Optional[str]): Path to the LoRA parameters on disk.
|
||||
If `None`, LoRA is not used.
|
||||
|
||||
Returns:
|
||||
A tuple with the following elements:
|
||||
- A new [LoRARequest][] (or `None` if not applicable).
|
||||
- The tokenizer associated with the LoRA request
|
||||
(or the base tokenizer).
|
||||
A new [LoRARequest][] (or `None` if not applicable).
|
||||
"""
|
||||
if max_loras is None or lora_path is None:
|
||||
return None, tokenizer
|
||||
return None
|
||||
|
||||
# Generate a random LoRA ID in the range [1, max_loras].
|
||||
lora_id = random.randint(1, max_loras)
|
||||
@ -191,11 +183,7 @@ class BenchmarkDataset(ABC):
|
||||
lora_int_id=lora_id,
|
||||
lora_path=lora_path_on_disk(lora_path),
|
||||
)
|
||||
if lora_id not in lora_tokenizer_cache:
|
||||
lora_tokenizer_cache[lora_id] = get_lora_tokenizer(lora_request)
|
||||
# Return lora_request and the cached tokenizer if available; otherwise,
|
||||
# return the base tokenizer
|
||||
return lora_request, lora_tokenizer_cache[lora_id] or tokenizer
|
||||
return lora_request
|
||||
|
||||
@abstractmethod
|
||||
def sample(self, tokenizer: PreTrainedTokenizerBase,
|
||||
@ -213,7 +201,7 @@ class BenchmarkDataset(ABC):
|
||||
for processing the dataset's text.
|
||||
num_requests (int): The number of sample requests to generate.
|
||||
request_id_prefix (str) The prefix of request_id.
|
||||
|
||||
|
||||
|
||||
Returns:
|
||||
list[SampleRequest]: A list of sample requests generated from the
|
||||
@ -527,7 +515,7 @@ class RandomDataset(BenchmarkDataset):
|
||||
size=num_requests)
|
||||
output_lens = self._rng.integers(output_low, output_high + 1,
|
||||
size=num_requests)
|
||||
offsets = self._rng.integers(0, tokenizer.vocab_size,
|
||||
offsets = self._rng.integers(0, tokenizer.vocab_size,
|
||||
size=num_requests)
|
||||
return input_lens, output_lens, offsets
|
||||
|
||||
@ -555,7 +543,7 @@ class RandomDataset(BenchmarkDataset):
|
||||
the encoded sequence is truncated before being decoded again.
|
||||
"""
|
||||
# Build the inner sequence by sampling sequentially from the vocab
|
||||
inner_seq = ((offset + index + np.arange(input_len))
|
||||
inner_seq = ((offset + index + np.arange(input_len))
|
||||
% vocab_size).tolist()
|
||||
token_sequence = prefix_token_ids + inner_seq
|
||||
|
||||
@ -590,9 +578,9 @@ class RandomMultiModalDataset(RandomDataset):
|
||||
`num_mm_items_range_ratio` in [0, 1]. r=0 keeps it fixed; r=1 allows 0.
|
||||
The maximum is further clamped to the sum of per-modality limits.
|
||||
2) Each item’s modality and shape is sampled from `bucket_config`, a dict
|
||||
mapping (height, width, num_frames) → probability. We treat
|
||||
`num_frames`=1 as image and and `num_frames` > 1 as video.
|
||||
Entries with zero probability are removed and the rest are renormalized
|
||||
mapping (height, width, num_frames) → probability. We treat
|
||||
`num_frames`=1 as image and and `num_frames` > 1 as video.
|
||||
Entries with zero probability are removed and the rest are renormalized
|
||||
to sum to 1.
|
||||
3) Per-modality hard caps are enforced via `limit_mm_per_prompt`.
|
||||
When a modality reaches its cap, all of its buckets are excluded and the
|
||||
@ -600,8 +588,8 @@ class RandomMultiModalDataset(RandomDataset):
|
||||
|
||||
Example bucket configuration:
|
||||
{(256, 256, 1): 0.5, (720, 1280, 1): 0.4, (720, 1280, 16): 0.1}
|
||||
- Two image buckets (`num_frames`=1) and one video bucket
|
||||
(`num_frames`=16).
|
||||
- Two image buckets (`num_frames`=1) and one video bucket
|
||||
(`num_frames`=16).
|
||||
OBS.: Only image sampling is supported for now.
|
||||
"""
|
||||
|
||||
@ -624,9 +612,9 @@ class RandomMultiModalDataset(RandomDataset):
|
||||
|
||||
def generate_synthetic_image(self, width: int, height: int) -> Image.Image:
|
||||
"""Generate synthetic PIL image with random RGB values.
|
||||
|
||||
NOTE: iid pixel sampling results in worst-case compression
|
||||
(good for stressing I/O), but very unlike real photos.
|
||||
|
||||
NOTE: iid pixel sampling results in worst-case compression
|
||||
(good for stressing I/O), but very unlike real photos.
|
||||
We could consider a “low-freq” mode (e.g., noise blur)
|
||||
to emulate network realism instead of max stress.
|
||||
"""
|
||||
@ -638,11 +626,11 @@ class RandomMultiModalDataset(RandomDataset):
|
||||
)
|
||||
return Image.fromarray(random_pixels)
|
||||
|
||||
def generate_synthetic_video(self, width: int,
|
||||
height: int,
|
||||
def generate_synthetic_video(self, width: int,
|
||||
height: int,
|
||||
num_frames: int) -> Any:
|
||||
"""Generate synthetic video with random values.
|
||||
|
||||
|
||||
TODO: Finish this method.
|
||||
"""
|
||||
raise NotImplementedError("Video sampling is WIP.")
|
||||
@ -656,7 +644,7 @@ class RandomMultiModalDataset(RandomDataset):
|
||||
else:
|
||||
raise ValueError(f"Invalid multimodal item configuration: {config}")
|
||||
|
||||
def normalize_bucket_config(self, bucket_config: dict[tuple[int, int, int],
|
||||
def normalize_bucket_config(self, bucket_config: dict[tuple[int, int, int],
|
||||
float]) -> dict[tuple[int, int, int], float]:
|
||||
"""
|
||||
Remove zero probability entries
|
||||
@ -676,24 +664,24 @@ class RandomMultiModalDataset(RandomDataset):
|
||||
return {k: v / total for k, v in bucket_config.items()}
|
||||
|
||||
|
||||
def generate_mm_item(self,
|
||||
def generate_mm_item(self,
|
||||
mm_item_config: tuple[int, int, int],
|
||||
) -> Mapping[str, Any]:
|
||||
"""
|
||||
Create synthetic images and videos and
|
||||
Create synthetic images and videos and
|
||||
apply process_image/process_video respectively.
|
||||
This follows the OpenAI API chat completions
|
||||
https://github.com/openai/openai-python
|
||||
"""
|
||||
|
||||
|
||||
if self.map_config_to_modality(mm_item_config) == "image":
|
||||
return process_image(self.generate_synthetic_image(
|
||||
mm_item_config[1],
|
||||
mm_item_config[0]))
|
||||
elif self.map_config_to_modality(mm_item_config) == "video":
|
||||
return process_video(self.generate_synthetic_video(
|
||||
mm_item_config[1],
|
||||
mm_item_config[0],
|
||||
mm_item_config[1],
|
||||
mm_item_config[0],
|
||||
mm_item_config[2]))
|
||||
else:
|
||||
raise ValueError(f"Invalid multimodal item configuration: "
|
||||
@ -723,17 +711,17 @@ class RandomMultiModalDataset(RandomDataset):
|
||||
f"limit_mm_per_prompt: "
|
||||
f"{limit_mm_per_prompt.keys()}")
|
||||
|
||||
# Remove zero probability entries
|
||||
# Remove zero probability entries
|
||||
# and normalize bucket config to sum to 1
bucket_config = self.normalize_bucket_config(bucket_config)
logger.info(
"Normalized bucket config: %s", bucket_config,
)
# Only consider limit per prompt for modalities in bucket config
allowed_modalities = {self.map_config_to_modality(cfg)
for cfg in bucket_config}
limit_mm_per_prompt = {
k: v for k, v in limit_mm_per_prompt.items()
if k in allowed_modalities}
if not limit_mm_per_prompt:
raise ValueError("No valid limits for modalities present in "
@ -746,19 +734,19 @@ class RandomMultiModalDataset(RandomDataset):
# Get max and min num mm items and ensure
# it is at most the sum of limit_mm_per_prompt for all modalities
max_num_mm_items = min(
sum(limit_mm_per_prompt.values()),
math.ceil(base_items_per_request * (1 + num_mm_items_range_ratio))
)
# Ensure min num mm items is at least 0
min_num_mm_items = max(
0,
math.floor(base_items_per_request * (1 - num_mm_items_range_ratio))
)
# Raise error if min num mm items is greater than max num mm items
if min_num_mm_items > max_num_mm_items:
raise ValueError(f"Min num mm items is greater than max mm items: "
f"{min_num_mm_items} > {max_num_mm_items}")

logger.info(
"Sampling number of multimodal items from [%s, %s]",
min_num_mm_items, max_num_mm_items,
@ -783,8 +771,8 @@ class RandomMultiModalDataset(RandomDataset):
whose size is between min_num_mm_items and max_num_mm_items.

Loop over the bucket config and sample a multimodal item.
Loop until the number of multimodal items sampled is equal to
request_num_mm_items or limit of multimodal items per prompt
for all modalities is reached.

Note:
@ -796,19 +784,19 @@ class RandomMultiModalDataset(RandomDataset):
# Get the number of multimodal items to sample
request_num_mm_items = int(
self._rng.integers(min_num_mm_items, max_num_mm_items + 1)
)
# If request_num_mm_items is 0, yield an empty iterator
if request_num_mm_items == 0:
return
# Initialize modality counters
modality_counter = {self.map_config_to_modality(k): 0
for k in bucket_config}
# Copy the bucket config to avoid modifying the original
bucket_config_copy = bucket_config.copy()
# Loop over the number of multimodal items to sample
while sum(modality_counter.values()) < request_num_mm_items:
# Sample a multimodal item config
mm_item_config = self._rng.choice(list(bucket_config_copy.keys()),
p=list(bucket_config_copy.values()))
modality = self.map_config_to_modality(mm_item_config)
# Check that modality count is less than limit per prompt
@ -849,7 +837,7 @@ class RandomMultiModalDataset(RandomDataset):
limit_mm_per_prompt: dict[str, int] = DEFAULT_LIMIT_MM_PER_PROMPT,
base_items_per_request: int = DEFAULT_BASE_ITEMS_PER_REQUEST,
num_mm_items_range_ratio: float = DEFAULT_NUM_MM_ITEMS_RANGE_RATIO,
bucket_config: dict[tuple[int, int, int], float] =
DEFAULT_MM_ITEM_BUCKET_CONFIG,
enable_multimodal_chat: bool = DEFAULT_ENABLE_MULTIMODAL_CHAT,
**kwargs,
@ -857,7 +845,7 @@ class RandomMultiModalDataset(RandomDataset):

# NOTE: Video sampling is WIP. Raise error if video is in bucket config
# and probability is non-zero.
if any(self.map_config_to_modality(cfg) == "video" and p > 0
for cfg, p in bucket_config.items()):
raise NotImplementedError("Video sampling not implemented; "
"set its probability to 0.")
@ -908,7 +896,7 @@ class RandomMultiModalDataset(RandomDataset):
])

if enable_multimodal_chat:
# NOTE: For now this option is only provided for completeness
# given that the serve.py benchmark currently does not use it.
mm_chat_prompt: Any = prompt
mm_chat_prompt = self.apply_multimodal_chat_transformation(
@ -982,8 +970,8 @@ class ShareGPTDataset(BenchmarkDataset):
entry["conversations"][1]["value"],
)

lora_request, tokenizer = self.get_random_lora_request(
tokenizer=tokenizer, max_loras=max_loras, lora_path=lora_path)
lora_request = self.get_random_lora_request(
max_loras=max_loras, lora_path=lora_path)
prompt_ids = tokenizer(prompt).input_ids
completion_ids = tokenizer(completion).input_ids
prompt_len = len(prompt_ids)
@ -994,11 +982,11 @@ class ShareGPTDataset(BenchmarkDataset):
skip_min_output_len_check=output_len
is not None):
continue
if image_path := entry.get("image"):
mm_content = process_image(image_path)
elif video_path := entry.get("video"):
mm_content = process_video(video_path)
else:
mm_content = None
if enable_multimodal_chat:
prompt = self.apply_multimodal_chat_transformation(
@ -1013,9 +1001,9 @@ class ShareGPTDataset(BenchmarkDataset):
request_id=request_id_prefix + str(ind),
))
ind += 1
self.maybe_oversample_requests(samples,
num_requests,
request_id_prefix,
no_oversample)
return samples

@ -1024,11 +1012,11 @@ class _ValidateDatasetArgs(argparse.Action):
"""Argparse action to validate dataset name and path compatibility."""
def __call__(self, parser, namespace, values, option_string=None):
setattr(namespace, self.dest, values)

# Get current values of both dataset_name and dataset_path
dataset_name = getattr(namespace, 'dataset_name', 'random')
dataset_path = getattr(namespace, 'dataset_path', None)

# Validate the combination
if dataset_name == "random" and dataset_path is not None:
parser.error(
@ -1053,7 +1041,7 @@ def add_dataset_parser(parser: FlexibleArgumentParser):
default="random",
action=_ValidateDatasetArgs,
choices=[
"sharegpt", "burstgpt", "sonnet", "random", "random-mm", "hf",
"custom", "prefix_repetition", "spec_bench"
],
help="Name of the dataset to benchmark on.",
@ -1502,7 +1490,7 @@ def get_samples(args, tokenizer) -> list[SampleRequest]:
# For datasets that follow a similar structure, use a mapping.
dataset_mapping = {
"spec_bench":
lambda: SpecBench(dataset_path=args.dataset_path,
category=args.spec_bench_category).sample(
num_requests=args.num_prompts,
tokenizer=tokenizer,
@ -1660,7 +1648,7 @@ class CustomDataset(BenchmarkDataset):
logger.info("num_requests is set to 0 or negative, "
"so using all available samples: %d",
num_requests)

sampled_requests = []
for i, item in enumerate(self.data):
if len(sampled_requests) >= num_requests:
@ -1686,7 +1674,7 @@ class CustomDataset(BenchmarkDataset):
expected_output_len=output_len,
request_id=request_id_prefix + str(i),
))
self.maybe_oversample_requests(sampled_requests, num_requests,
request_id_prefix, no_oversample)

return sampled_requests
@ -1700,7 +1688,7 @@ class CustomDataset(BenchmarkDataset):
class SpecBench(CustomDataset):
"""
Implements the SpecBench dataset: https://github.com/hemingkx/Spec-Bench
Download the dataset using:
wget https://raw.githubusercontent.com/hemingkx/Spec-Bench/refs/heads/main/data/spec_bench/question.jsonl
""" # noqa: E501

@ -1736,8 +1724,8 @@ class SpecBench(CustomDataset):
# leverage CustomDataset sample
kwargs["skip_chat_template"] = False
return super().sample(**kwargs)


# -----------------------------------------------------------------------------
# Sonnet Dataset Implementation
# -----------------------------------------------------------------------------
@ -1882,8 +1870,8 @@ class BurstGPTDataset(BenchmarkDataset):
for i in range(num_requests):
input_len = int(data[i][2])
output_len = int(data[i][3])
lora_req, tokenizer = self.get_random_lora_request(
tokenizer=tokenizer, max_loras=max_loras, lora_path=lora_path)
lora_req = self.get_random_lora_request(
max_loras=max_loras, lora_path=lora_path)
vocab_size = tokenizer.vocab_size
# Generate a synthetic prompt: a list of token IDs computed as (i +
# j) modulo vocab_size.
@ -1995,7 +1983,7 @@ class ConversationDataset(HuggingFaceDataset):
request_id=request_id_prefix + str(ind),
))
ind += 1
self.maybe_oversample_requests(sampled_requests, num_requests,
request_id_prefix, no_oversample)
return sampled_requests

@ -2055,7 +2043,7 @@ class VisionArenaDataset(HuggingFaceDataset):
multi_modal_data=mm_content,
request_id=request_id_prefix + str(i),
))
self.maybe_oversample_requests(sampled_requests, num_requests,
request_id_prefix, no_oversample)
return sampled_requests

@ -2172,7 +2160,7 @@ class InstructCoderDataset(HuggingFaceDataset):
expected_output_len=output_len,
request_id=request_id_prefix + str(i),
))
self.maybe_oversample_requests(sampled_requests, num_requests,
request_id_prefix, no_oversample)
return sampled_requests

@ -2234,7 +2222,7 @@ class MTBenchDataset(HuggingFaceDataset):
expected_output_len=output_len,
request_id=request_id_prefix + str(i),
))
self.maybe_oversample_requests(sampled_requests, num_requests,
request_id_prefix, no_oversample)
return sampled_requests

@ -2288,8 +2276,8 @@ class BlazeditDataset(HuggingFaceDataset):
# compare the levenshtein distance normalized by code length
if norm_distance < min_distance or norm_distance > max_distance:
continue

# template copied from
# https://github.com/ise-uiuc/blazedit/blob/7765137e656fd62de877422d2e4cf8de51228054/dataset/create_refined_dataset.py#L94-L105 # noqa: E501
instruction = f"""Given a code file, please apply the change requests and generate the new file.

@ -2322,9 +2310,9 @@ Please generate the new code file in the "New file" section below.""" # noqa: E5
expected_output_len=output_len,
request_id=request_id_prefix + str(i),
))
self.maybe_oversample_requests(sampled_requests, num_requests,
request_id_prefix, no_oversample)

return sampled_requests

@ -2376,7 +2364,6 @@ class AIMODataset(HuggingFaceDataset):
expected_output_len=output_len,
multi_modal_data=None,
request_id=request_id_prefix + str(ind),
))
ind += 1
self.maybe_oversample_requests(sampled_requests, num_requests,
@ -2470,9 +2457,9 @@ class NextEditPredictionDataset(HuggingFaceDataset):
))
if len(samples) >= num_requests:
break
self.maybe_oversample_requests(samples,
num_requests,
request_id_prefix,
no_oversample)
return samples

@ -2562,7 +2549,7 @@ class ASRDataset(HuggingFaceDataset):
" what Whisper supports.",
skipped,
)
self.maybe_oversample_requests(sampled_requests, num_requests,
request_id_prefix, no_oversample)
return sampled_requests

@ -2647,7 +2634,7 @@ class MLPerfDataset(HuggingFaceDataset):
)
ind += 1

self.maybe_oversample_requests(sampled_requests, num_requests,
request_id_prefix, no_oversample)
return sampled_requests

@ -2658,7 +2645,7 @@ class MLPerfDataset(HuggingFaceDataset):


class PrefixRepetitionRandomDataset(BenchmarkDataset):
# Default values copied from benchmark_serving.py for the repeated prefix
# dataset.
DEFAULT_PREFIX_LEN = 256
DEFAULT_SUFFIX_LEN = 256
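Note on the benchmark dataset hunks above: the helpers now return only the LoRARequest, so every sampled request reuses the base tokenizer. A rough usage sketch (not part of the diff; `dataset`, `tokenizer`, `prompt`, `max_loras`, and `lora_path` are assumed to exist):

# Hypothetical sketch of the new call pattern:
lora_request = dataset.get_random_lora_request(max_loras=max_loras,
                                               lora_path=lora_path)
prompt_ids = tokenizer(prompt).input_ids  # same base tokenizer for every request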
@ -390,11 +390,8 @@ class _AsyncLLMEngine(LLMEngine):
"""Stop the remote worker execution loop."""
await self.model_executor.stop_remote_worker_execution_loop_async()

async def get_tokenizer_async(self,
lora_request: Optional[LoRARequest] = None
) -> AnyTokenizer:
return await (
self.get_tokenizer_group().get_lora_tokenizer_async(lora_request))
async def get_tokenizer_async(self) -> AnyTokenizer:
return self.get_tokenizer()

async def add_request_async(
self,
@ -435,7 +432,6 @@ class _AsyncLLMEngine(LLMEngine):

processed_inputs = await self.input_preprocessor.preprocess_async(
prompt,
lora_request=lora_request,
tokenization_kwargs=tokenization_kwargs,
)

@ -614,11 +610,8 @@ class AsyncLLMEngine(EngineClient):
async def get_input_preprocessor(self) -> InputPreprocessor:
return self.engine.input_preprocessor

async def get_tokenizer(
self,
lora_request: Optional[LoRARequest] = None,
) -> AnyTokenizer:
return await self.engine.get_tokenizer_async(lora_request)
async def get_tokenizer(self) -> AnyTokenizer:
return self.engine.get_tokenizer()

def start_background_loop(self) -> None:
"""Start the background loop."""
@ -49,9 +49,8 @@ from vllm.sequence import (ExecuteModelRequest, ParallelSampleSequenceGroup,
from vllm.tracing import (SpanAttributes, SpanKind, extract_trace_context,
init_tracer)
from vllm.transformers_utils.detokenizer import Detokenizer
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.transformers_utils.tokenizer_group import (
TokenizerGroup, init_tokenizer_from_configs)
from vllm.transformers_utils.tokenizer import (AnyTokenizer,
init_tokenizer_from_configs)
from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled,
usage_message)
from vllm.utils import Counter, Device, resolve_obj_by_qualname, weak_bind
@ -186,7 +185,7 @@ class LLMEngine:

return outputs_

tokenizer: Optional[TokenizerGroup]
tokenizer: Optional[AnyTokenizer]

def __init__(
self,
@ -233,18 +232,9 @@ class LLMEngine:
if self.model_config.skip_tokenizer_init:
self.tokenizer = None
self.detokenizer = None
tokenizer_group = None
else:
self.tokenizer = self._init_tokenizer()
self.detokenizer = Detokenizer(self.tokenizer)
tokenizer_group = self.get_tokenizer_group()

# Ensure that the function doesn't contain a reference to self,
# to avoid engine GC issues
def get_tokenizer_for_seq(sequence: Sequence) -> AnyTokenizer:
assert tokenizer_group, ("tokenizer_group cannot be None, "
"make sure skip_tokenizer_init is False")
return tokenizer_group.get_lora_tokenizer(sequence.lora_request)

self.seq_counter = Counter()
self.generation_config_fields = (
@ -389,10 +379,8 @@ class LLMEngine:
self.detokenizer,
self.scheduler,
self.seq_counter,
get_tokenizer_for_seq,
stop_checker=StopChecker(
self.scheduler_config.max_model_len,
get_tokenizer_for_seq,
self.reasoner if self.decoding_config.reasoning_backend
and self.tokenizer else None,
),
@ -521,24 +509,15 @@ class LLMEngine:
if model_executor := getattr(self, "model_executor", None):
model_executor.shutdown()

def get_tokenizer_group(self) -> TokenizerGroup:
def get_tokenizer(self) -> AnyTokenizer:
if self.tokenizer is None:
raise ValueError("Unable to get tokenizer because "
"skip_tokenizer_init is True")

return self.tokenizer

def get_tokenizer(
self,
lora_request: Optional[LoRARequest] = None,
) -> AnyTokenizer:
return self.get_tokenizer_group().get_lora_tokenizer(lora_request)

def _init_tokenizer(self) -> TokenizerGroup:
return init_tokenizer_from_configs(
model_config=self.model_config,
scheduler_config=self.scheduler_config,
lora_config=self.lora_config)
def _init_tokenizer(self) -> AnyTokenizer:
return init_tokenizer_from_configs(model_config=self.model_config)

def _verify_args(self) -> None:
self.model_config.verify_with_parallel_config(self.parallel_config)
@ -574,11 +553,11 @@ class LLMEngine:
)
return None

self._validate_model_inputs(processed_inputs, lora_request)
self._validate_model_inputs(processed_inputs)
# Create the sequences.
block_size = self.cache_config.block_size
seq_id = next(self.seq_counter)
eos_token_id = self.input_preprocessor.get_eos_token_id(lora_request)
eos_token_id = self.input_preprocessor.get_eos_token_id()

encoder_inputs, decoder_inputs = split_enc_dec_inputs(processed_inputs)

@ -700,7 +679,6 @@ class LLMEngine:
processed_inputs = self.input_preprocessor.preprocess(
prompt,
tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request,
)

self._add_processed_request(
@ -1739,29 +1717,22 @@ class LLMEngine:
SpanAttributes.GEN_AI_LATENCY_TIME_IN_MODEL_EXECUTE,
metrics.model_execute_time)

def _validate_model_inputs(self, inputs: ProcessorInputs,
lora_request: Optional[LoRARequest]):
def _validate_model_inputs(self, inputs: ProcessorInputs):
encoder_inputs, decoder_inputs = split_enc_dec_inputs(inputs)

if encoder_inputs is not None:
self._validate_model_input(encoder_inputs,
lora_request,
prompt_type="encoder")
self._validate_model_input(encoder_inputs, prompt_type="encoder")

self._validate_model_input(decoder_inputs,
lora_request,
prompt_type="decoder")
self._validate_model_input(decoder_inputs, prompt_type="decoder")

def _validate_model_input(
self,
prompt_inputs: SingletonInputs,
lora_request: Optional[LoRARequest],
*,
prompt_type: Literal["encoder", "decoder"],
):
model_config = self.model_config
tokenizer = (None if self.tokenizer is None else
self.tokenizer.get_lora_tokenizer(lora_request))
tokenizer = self.tokenizer

prompt_ids = prompt_inputs.get("prompt_token_ids", [])
if not prompt_ids:
@ -1822,7 +1793,7 @@ class LLMEngine:
logits_processors = []

if (sampling_params.logit_bias or sampling_params.allowed_token_ids):
tokenizer = self.get_tokenizer(lora_request=lora_request)
tokenizer = self.get_tokenizer()

processors = get_openai_logits_processors(
logit_bias=sampling_params.logit_bias,
@ -1835,7 +1806,7 @@ class LLMEngine:
sampling_params.allowed_token_ids = None

if len(sampling_params.bad_words) > 0:
tokenizer = self.get_tokenizer(lora_request)
tokenizer = self.get_tokenizer()
processors = get_bad_words_logits_processors(
bad_words=sampling_params.bad_words, tokenizer=tokenizer)
logits_processors.extend(processors)
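A minimal sketch of the simplified LLMEngine tokenizer access shown above (not part of the diff; the `engine` instance is assumed to exist with skip_tokenizer_init=False):

# Hypothetical sketch:
tokenizer = engine.get_tokenizer()  # raises ValueError when skip_tokenizer_init=True
eos_id = engine.input_preprocessor.get_eos_token_id()  # no lora_request argument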
@ -2,14 +2,13 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from abc import ABC, abstractmethod
from typing import Callable, List
from typing import List

from vllm.config import SchedulerConfig
from vllm.core.scheduler import Scheduler
from vllm.engine.output_processor.stop_checker import StopChecker
from vllm.sequence import Sequence, SequenceGroup, SequenceGroupOutput
from vllm.sequence import SequenceGroup, SequenceGroupOutput
from vllm.transformers_utils.detokenizer import Detokenizer
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils import Counter


@ -31,7 +30,6 @@ class SequenceGroupOutputProcessor(ABC):
detokenizer: Detokenizer,
scheduler: List[Scheduler],
seq_counter: Counter,
get_tokenizer_for_seq: Callable[[Sequence], AnyTokenizer],
stop_checker: "StopChecker",
):
"""Create an output processor.
@ -1,13 +1,12 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from typing import Callable, List, Optional, Tuple
from typing import List, Optional, Tuple

from vllm.lora.request import LoRARequest
from vllm.reasoning import ReasoningParser
from vllm.sampling_params import SamplingParams
from vllm.sequence import Sequence, SequenceStatus
from vllm.transformers_utils.tokenizer import AnyTokenizer


class StopChecker:
@ -20,12 +19,10 @@ class StopChecker:
def __init__(
self,
max_model_len: int,
get_tokenizer_for_seq: Callable[[Sequence], AnyTokenizer],
reasoner: Optional[ReasoningParser] = None,
):
# Do not use it directly, but use `self._get_max_model_len`.
self._max_model_len = max_model_len
self.get_tokenizer_for_seq = get_tokenizer_for_seq
self.reasoner = reasoner

def _get_max_model_len(self, lora_req: Optional[LoRARequest]):
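Sketch of the slimmer StopChecker constructor after this hunk (not part of the diff; values and `my_reasoner` are illustrative):

# Hypothetical sketch:
stop_checker = StopChecker(max_model_len=4096)  # no tokenizer callback needed
stop_checker_with_reasoner = StopChecker(max_model_len=4096,
                                         reasoner=my_reasoner)  # reasoner stays optional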
@ -76,8 +76,7 @@ class EngineClient(ABC):
include_stop_str_in_output = params.include_stop_str_in_output

preprocessor = await self.get_input_preprocessor()
tokenizer_group = preprocessor.get_tokenizer_group()
tokenizer = await tokenizer_group.get_lora_tokenizer_async()
tokenizer = preprocessor.get_tokenizer()
eos_token_id = tokenizer.eos_token_id

if is_explicit_encoder_decoder_prompt(prompt):
@ -260,11 +259,8 @@
...

@abstractmethod
async def get_tokenizer(
self,
lora_request: Optional[LoRARequest] = None,
) -> AnyTokenizer:
"""Get the appropriate tokenizer for the request"""
async def get_tokenizer(self) -> AnyTokenizer:
"""Get the tokenizer"""
...

async def get_io_processor(self) -> IOProcessor:
@ -301,23 +301,17 @@ class LLM:
self.io_processor = get_io_processor(self.llm_engine.vllm_config,
io_processor_plugin)

def get_tokenizer(
self,
lora_request: Optional[LoRARequest] = None,
) -> AnyTokenizer:
return self.llm_engine.get_tokenizer_group().get_lora_tokenizer(
lora_request)
def get_tokenizer(self) -> AnyTokenizer:
return self.llm_engine.get_tokenizer()

def set_tokenizer(self, tokenizer: AnyTokenizer) -> None:
tokenizer_group = self.llm_engine.get_tokenizer_group()

# While CachedTokenizer is dynamic, have no choice but
# compare class name. Misjudgment will arise from
# user-defined tokenizer started with 'Cached'
if tokenizer.__class__.__name__.startswith("Cached"):
tokenizer_group.tokenizer = tokenizer
self.llm_engine.tokenizer = tokenizer
else:
tokenizer_group.tokenizer = get_cached_tokenizer(tokenizer)
self.llm_engine.tokenizer = get_cached_tokenizer(tokenizer)

def get_default_sampling_params(self) -> SamplingParams:
if self.default_sampling_params is None:
@ -707,7 +701,6 @@ class LLM:
self,
messages: Union[list[ChatCompletionMessageParam],
list[list[ChatCompletionMessageParam]]],
lora_request: Optional[LoRARequest] = None,
chat_template: Optional[str] = None,
chat_template_content_format: ChatTemplateContentFormatOption = "auto",
add_generation_prompt: bool = True,
@ -739,7 +732,7 @@ class LLM:
cast(list[ChatCompletionMessageParam], messages)
]

tokenizer = self.get_tokenizer(lora_request)
tokenizer = self.get_tokenizer()
model_config = self.llm_engine.get_model_config()
resolved_content_format = resolve_chat_template_content_format(
chat_template,
@ -872,7 +865,6 @@ class LLM:

prompts = self.preprocess_chat(
messages=messages,
lora_request=lora_request,
chat_template=chat_template,
chat_template_content_format=chat_template_content_format,
add_generation_prompt=add_generation_prompt,
@ -1519,7 +1511,7 @@ class LLM:
):
"""
Validate that if any multi-modal data is skipped (i.e. None),
then its corresponding UUID must be set.
"""
if multi_modal_data is None:
return
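Illustrative use of the simplified LLM facade after the hunks above (not part of the diff; the model name is an example):

# Hypothetical sketch:
from vllm import LLM

llm = LLM(model="facebook/opt-125m")
tokenizer = llm.get_tokenizer()  # single AnyTokenizer, no lora_request argument
llm.set_tokenizer(tokenizer)     # re-wrapped via get_cached_tokenizer() when needed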
@ -188,7 +188,7 @@ class OpenAIServingChat(OpenAIServing):

model_name = self.models.model_name(lora_request)

tokenizer = await self.engine_client.get_tokenizer(lora_request)
tokenizer = await self.engine_client.get_tokenizer()

tool_parser = self.tool_parser

@ -50,10 +50,7 @@ class ClassificationMixin(OpenAIServing):
return None

try:
ctx.lora_request = self._maybe_get_adapters(ctx.request)

ctx.tokenizer = await self.engine_client.get_tokenizer(
ctx.lora_request)
ctx.tokenizer = await self.engine_client.get_tokenizer()

renderer = self._get_renderer(ctx.tokenizer)
ctx.engine_prompts = await renderer.render_prompt(
@ -127,8 +127,7 @@ class OpenAIServingCompletion(OpenAIServing):
if self.model_config.skip_tokenizer_init:
tokenizer = None
else:
tokenizer = await self.engine_client.get_tokenizer(lora_request
)
tokenizer = await self.engine_client.get_tokenizer()
renderer = self._get_renderer(tokenizer)

engine_prompts = await renderer.render_prompt_and_embeds(
@ -76,8 +76,7 @@ class EmbeddingMixin(OpenAIServing):
try:
ctx.lora_request = self._maybe_get_adapters(ctx.request)

tokenizer = await self.engine_client.get_tokenizer(ctx.lora_request
)
tokenizer = await self.engine_client.get_tokenizer()
renderer = self._get_renderer(tokenizer)

if isinstance(ctx.request, EmbeddingChatRequest):
@ -394,8 +393,8 @@
) -> Optional[ErrorResponse]:
"""Collect and aggregate batch results
with support for chunked processing.

For chunked requests, performs online aggregation to
minimize memory usage.
For regular requests, collects results normally.
"""
@ -103,8 +103,7 @@ class OpenAIServingPooling(OpenAIServing):
if self.model_config.skip_tokenizer_init:
tokenizer = None
else:
tokenizer = await self.engine_client.get_tokenizer(lora_request
)
tokenizer = await self.engine_client.get_tokenizer()
renderer = self._get_renderer(tokenizer)

if getattr(request, "dimensions", None) is not None:
@ -240,7 +240,7 @@ class OpenAIServingResponses(OpenAIServing):
try:
lora_request = self._maybe_get_adapters(request)
model_name = self.models.model_name(lora_request)
tokenizer = await self.engine_client.get_tokenizer(lora_request)
tokenizer = await self.engine_client.get_tokenizer()

if self.use_harmony:
messages, request_prompts, engine_prompts = (
@ -269,7 +269,7 @@ class ServingScores(OpenAIServing):
) -> Union[list[PoolingRequestOutput], ErrorResponse]:
lora_request = self._maybe_get_adapters(request)

tokenizer = await self.engine_client.get_tokenizer(lora_request)
tokenizer = await self.engine_client.get_tokenizer()

truncate_prompt_tokens = getattr(request, "truncate_prompt_tokens",
None)
@ -65,7 +65,7 @@ class OpenAIServingTokenization(OpenAIServing):
try:
lora_request = self._maybe_get_adapters(request)

tokenizer = await self.engine_client.get_tokenizer(lora_request)
tokenizer = await self.engine_client.get_tokenizer()
renderer = self._get_renderer(tokenizer)

if isinstance(request, TokenizeChatRequest):
@ -130,7 +130,7 @@ class OpenAIServingTokenization(OpenAIServing):

lora_request = self._maybe_get_adapters(request)

tokenizer = await self.engine_client.get_tokenizer(lora_request)
tokenizer = await self.engine_client.get_tokenizer()

self._log_inputs(request_id,
request.tokens,
@ -9,13 +9,11 @@ from typing_extensions import assert_never

from vllm.config import ModelConfig
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
from vllm.multimodal.cache import BaseMultiModalProcessorCache
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalEncDecInputs,
MultiModalInputs, MultiModalUUIDDict)
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.transformers_utils.tokenizer_group import TokenizerGroup

from .data import (DecoderOnlyInputs, EmbedsInputs, EmbedsPrompt,
EncoderDecoderInputs, ProcessorInputs, PromptType,
@ -31,7 +29,7 @@ class InputPreprocessor:
def __init__(
self,
model_config: ModelConfig,
tokenizer: Optional[TokenizerGroup],
tokenizer: Optional[AnyTokenizer],
mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
mm_processor_cache: Optional[BaseMultiModalProcessorCache] = None,
) -> None:
@ -42,32 +40,28 @@ class InputPreprocessor:
self.mm_registry = mm_registry
self.mm_processor_cache = mm_processor_cache

def get_tokenizer_group(self) -> TokenizerGroup:
def get_tokenizer(self) -> AnyTokenizer:
if self.tokenizer is None:
raise ValueError("You cannot pass text prompts when "
"`skip_tokenizer_init` is True")

return self.tokenizer

def get_bos_token_id(self,
lora_request: Optional[LoRARequest] = None
) -> Optional[int]:
def get_bos_token_id(self) -> Optional[int]:
if self.tokenizer is None:
logger.warning("Using None for BOS token id because tokenizer "
"is not initialized")
return None

return self.tokenizer.get_lora_tokenizer(lora_request).bos_token_id
return self.tokenizer.bos_token_id

def get_eos_token_id(self,
lora_request: Optional[LoRARequest] = None
) -> Optional[int]:
def get_eos_token_id(self) -> Optional[int]:
if self.tokenizer is None:
logger.warning("Using None for EOS token id because tokenizer "
"is not initialized")
return None

return self.tokenizer.get_lora_tokenizer(lora_request).eos_token_id
return self.tokenizer.eos_token_id

def get_decoder_start_token_id(self) -> Optional[int]:
"""
@ -190,14 +184,13 @@ class InputPreprocessor:
def _tokenize_prompt(
self,
prompt: str,
lora_request: Optional[LoRARequest],
tokenization_kwargs: Optional[dict[str, Any]] = None,
) -> list[int]:
"""
Apply the model's tokenizer to a text prompt, returning the
corresponding token IDs.
"""
tokenizer = self.get_tokenizer_group()
tokenizer = self.get_tokenizer()
tokenization_kwargs = self._get_tokenization_kw(tokenization_kwargs)

encoder_config = self.model_config.encoder_config
@ -205,50 +198,39 @@ class InputPreprocessor:
if encoder_config and encoder_config.get("do_lower_case", False):
prompt = prompt.lower()

return tokenizer.encode(prompt=prompt,
lora_request=lora_request,
**tokenization_kwargs)
return tokenizer.encode(prompt, **tokenization_kwargs)

async def _tokenize_prompt_async(
self,
prompt: str,
lora_request: Optional[LoRARequest],
tokenization_kwargs: Optional[dict[str, Any]] = None,
) -> list[int]:
"""
Async version of
[`_tokenize_prompt`][vllm.inputs.preprocess.InputPreprocessor._tokenize_prompt].
"""
tokenizer = self.get_tokenizer_group()
tokenizer = self.get_tokenizer()
tokenization_kwargs = self._get_tokenization_kw(tokenization_kwargs)

return await tokenizer.encode_async(prompt=prompt,
lora_request=lora_request,
**tokenization_kwargs)
return tokenizer.encode(prompt, **tokenization_kwargs)

def _get_mm_tokenizer(
self,
lora_request: Optional[LoRARequest],
) -> AnyTokenizer:
def _get_mm_tokenizer(self) -> AnyTokenizer:
# PrithviGeoSpatialMAE needs to be initialized without a tokenizer
# while using also multi-modal input
if not self.tokenizer:
return cast(AnyTokenizer, object())  # Dummy

tokenizer_group = self.get_tokenizer_group()
return tokenizer_group.get_lora_tokenizer(lora_request)
tokenizer = self.get_tokenizer()
return tokenizer

async def _get_mm_tokenizer_async(
self,
lora_request: Optional[LoRARequest],
) -> AnyTokenizer:
async def _get_mm_tokenizer_async(self) -> AnyTokenizer:
# PrithviGeoSpatialMAE needs to be initialized without a tokenizer
# while using also multi-modal input
if not self.tokenizer:
return cast(AnyTokenizer, object())  # Dummy

tokenizer_group = self.get_tokenizer_group()
return await tokenizer_group.get_lora_tokenizer_async(lora_request)
tokenizer = self.get_tokenizer()
return tokenizer

def _process_multimodal(
self,
@ -256,7 +238,6 @@ class InputPreprocessor:
mm_data: MultiModalDataDict,
mm_processor_kwargs: Optional[Mapping[str, object]],
tokenization_kwargs: Optional[dict[str, Any]] = None,
lora_request: Optional[LoRARequest] = None,
*,
mm_uuids: Optional[MultiModalUUIDDict] = None,
) -> MultiModalInputs:
@ -264,7 +245,7 @@ class InputPreprocessor:
Apply the model's multi-modal processor to a multi-modal prompt,
returning the corresponding token IDs and metadata.
"""
tokenizer = self._get_mm_tokenizer(lora_request)
tokenizer = self._get_mm_tokenizer()

mm_processor = self.mm_registry.create_processor(
self.model_config,
@ -299,7 +280,6 @@ class InputPreprocessor:
mm_data: MultiModalDataDict,
mm_processor_kwargs: Optional[Mapping[str, object]],
tokenization_kwargs: Optional[dict[str, Any]] = None,
lora_request: Optional[LoRARequest] = None,
*,
mm_uuids: Optional[MultiModalUUIDDict] = None,
) -> MultiModalInputs:
@ -307,7 +287,7 @@ class InputPreprocessor:
Async version of
[`_process_multimodal`][vllm.inputs.preprocess.InputPreprocessor._process_multimodal].
"""
tokenizer = await self._get_mm_tokenizer_async(lora_request)
tokenizer = await self._get_mm_tokenizer_async()

mm_processor = self.mm_registry.create_processor(
self.model_config,
@ -386,7 +366,6 @@ class InputPreprocessor:
self,
parsed_content: TokensPrompt,
tokenization_kwargs: Optional[dict[str, Any]] = None,
lora_request: Optional[LoRARequest] = None,
*,
mm_uuids: Optional[MultiModalUUIDDict] = None,
) -> Union[TokenInputs, MultiModalInputs]:
@ -400,7 +379,6 @@ class InputPreprocessor:
multi_modal_data,
parsed_content.get("mm_processor_kwargs"),
tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request,
mm_uuids=mm_uuids,
)
else:
@ -415,7 +393,6 @@ class InputPreprocessor:
self,
parsed_content: TokensPrompt,
tokenization_kwargs: Optional[dict[str, Any]] = None,
lora_request: Optional[LoRARequest] = None,
*,
mm_uuids: Optional[MultiModalUUIDDict] = None,
) -> Union[TokenInputs, MultiModalInputs]:
@ -429,7 +406,6 @@ class InputPreprocessor:
multi_modal_data,
parsed_content.get("mm_processor_kwargs"),
tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request,
mm_uuids=mm_uuids,
)
else:
@ -444,7 +420,6 @@ class InputPreprocessor:
self,
parsed_content: TextPrompt,
tokenization_kwargs: Optional[dict[str, Any]] = None,
lora_request: Optional[LoRARequest] = None,
*,
mm_uuids: Optional[MultiModalUUIDDict] = None,
) -> Union[TokenInputs, MultiModalInputs]:
@ -457,13 +432,11 @@ class InputPreprocessor:
multi_modal_data,
parsed_content.get("mm_processor_kwargs"),
tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request,
mm_uuids=mm_uuids,
)
else:
prompt_token_ids = self._tokenize_prompt(
prompt_text,
lora_request=lora_request,
tokenization_kwargs=tokenization_kwargs,
)
inputs = token_inputs(
@ -480,7 +453,6 @@ class InputPreprocessor:
self,
parsed_content: TextPrompt,
tokenization_kwargs: Optional[dict[str, Any]] = None,
lora_request: Optional[LoRARequest] = None,
*,
mm_uuids: Optional[MultiModalUUIDDict] = None,
) -> Union[TokenInputs, MultiModalInputs]:
@ -493,13 +465,11 @@ class InputPreprocessor:
multi_modal_data,
parsed_content.get("mm_processor_kwargs"),
tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request,
mm_uuids=mm_uuids,
)
else:
prompt_token_ids = await self._tokenize_prompt_async(
prompt_text,
lora_request=lora_request,
tokenization_kwargs=tokenization_kwargs,
)
inputs = token_inputs(
@ -516,7 +486,6 @@ class InputPreprocessor:
self,
prompt: SingletonPrompt,
tokenization_kwargs: Optional[dict[str, Any]] = None,
lora_request: Optional[LoRARequest] = None,
*,
mm_uuids: Optional[MultiModalUUIDDict] = None,
) -> SingletonInputs:
@ -526,7 +495,6 @@ class InputPreprocessor:
Arguments:

* prompt: single encoder or decoder input prompt
* lora_request: this is only valid for decoder prompts

Returns:

@ -539,21 +507,18 @@ class InputPreprocessor:
if parsed["type"] == "tokens":
return self._process_tokens(
parsed["content"],
lora_request=lora_request,
mm_uuids=mm_uuids,
)
if parsed["type"] == "text":
return self._process_text(
parsed["content"],
tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request,
mm_uuids=mm_uuids,
)
if parsed["type"] == "str":
return self._process_text(
TextPrompt(prompt=parsed["content"]),
tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request,
mm_uuids=mm_uuids,
)

@ -563,7 +528,6 @@ class InputPreprocessor:
self,
prompt: SingletonPrompt,
tokenization_kwargs: Optional[dict[str, Any]] = None,
lora_request: Optional[LoRARequest] = None,
*,
mm_uuids: Optional[MultiModalUUIDDict] = None,
) -> SingletonInputs:
@ -578,21 +542,18 @@ class InputPreprocessor:
if parsed["type"] == "tokens":
return await self._process_tokens_async(
parsed["content"],
lora_request=lora_request,
mm_uuids=mm_uuids,
)
if parsed["type"] == "text":
return await self._process_text_async(
parsed["content"],
tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request,
mm_uuids=mm_uuids,
)
if parsed["type"] == "str":
return await self._process_text_async(
TextPrompt(prompt=parsed["content"]),
tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request,
mm_uuids=mm_uuids,
)

@ -844,7 +805,6 @@ class InputPreprocessor:
self,
prompt: SingletonPrompt,
tokenization_kwargs: Optional[dict[str, Any]] = None,
lora_request: Optional[LoRARequest] = None,
*,
mm_uuids: Optional[MultiModalUUIDDict] = None,
) -> DecoderOnlyInputs:
@ -856,7 +816,6 @@ class InputPreprocessor:
Arguments:

* prompt: input prompt
* lora_request

Returns:

@ -866,7 +825,6 @@ class InputPreprocessor:
prompt_comps = self._prompt_to_llm_inputs(
prompt,
tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request,
mm_uuids=mm_uuids,
)

@ -876,7 +834,6 @@ class InputPreprocessor:
self,
prompt: SingletonPrompt,
tokenization_kwargs: Optional[dict[str, Any]] = None,
lora_request: Optional[LoRARequest] = None,
*,
mm_uuids: Optional[MultiModalUUIDDict] = None,
) -> DecoderOnlyInputs:
@ -887,7 +844,6 @@ class InputPreprocessor:
prompt_comps = await self._prompt_to_llm_inputs_async(
prompt,
tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request,
mm_uuids=mm_uuids,
)

@ -897,7 +853,6 @@ class InputPreprocessor:
self,
prompt: PromptType,
tokenization_kwargs: Optional[dict[str, Any]] = None,
lora_request: Optional[LoRARequest] = None,
*,
mm_uuids: Optional[MultiModalUUIDDict] = None,
) -> ProcessorInputs:
@ -919,7 +874,6 @@ class InputPreprocessor:
return self._process_decoder_only_prompt(
prompt,
tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request,
mm_uuids=mm_uuids,
)

@ -927,7 +881,6 @@ class InputPreprocessor:
self,
prompt: PromptType,
tokenization_kwargs: Optional[dict[str, Any]] = None,
lora_request: Optional[LoRARequest] = None,
*,
mm_uuids: Optional[MultiModalUUIDDict] = None,
) -> ProcessorInputs:
@ -952,7 +905,6 @@ class InputPreprocessor:
return await self._process_decoder_only_prompt_async(
prompt,
tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request,
mm_uuids=mm_uuids,
)
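Sketch of the preprocessor after the LoRA plumbing above is removed (not part of the diff; `model_config` and `tokenizer` are assumed to come from an existing engine):

# Hypothetical sketch:
preprocessor = InputPreprocessor(model_config, tokenizer)
processed = preprocessor.preprocess({"prompt": "Hello, world!"})  # no lora_request kwarg
bos_id, eos_id = preprocessor.get_bos_token_id(), preprocessor.get_eos_token_id()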
|
@ -10,18 +10,13 @@ from vllm.sequence import (VLLM_INVALID_TOKEN_ID, SamplingParams, Sequence,
|
||||
from .detokenizer_utils import (convert_prompt_ids_to_tokens,
|
||||
detokenize_incrementally)
|
||||
from .tokenizer import AnyTokenizer
|
||||
from .tokenizer_group import TokenizerGroup
|
||||
|
||||
|
||||
class Detokenizer:
|
||||
"""Provides methods to decode the output of a model into text."""
|
||||
|
||||
def __init__(self, tokenizer_group: TokenizerGroup):
|
||||
self.tokenizer_group = tokenizer_group
|
||||
|
||||
def get_tokenizer_for_seq(self, sequence: Sequence) -> AnyTokenizer:
|
||||
"""Returns the HF tokenizer to use for a given sequence."""
|
||||
return self.tokenizer_group.get_lora_tokenizer(sequence.lora_request)
|
||||
def __init__(self, tokenizer: AnyTokenizer):
|
||||
self.tokenizer = tokenizer
|
||||
|
||||
def decode_prompt_logprobs_inplace(self, seq_group: SequenceGroup,
|
||||
prompt_logprobs: list[Optional[dict[
|
||||
@ -32,9 +27,9 @@ class Detokenizer:
|
||||
Args:
|
||||
seq_group: The sequence group to decode.
|
||||
prompt_logprobs: The logprobs to decode.
|
||||
position_offset: Offset of the first index of the logprobs
|
||||
position_offset: Offset of the first index of the logprobs
|
||||
relative to the start of the sequence (for chunked prefill).
|
||||
|
||||
|
||||
Returns:
|
||||
The prompt logprobs with the decoded tokens.
|
||||
"""
|
||||
@ -46,7 +41,6 @@ class Detokenizer:
|
||||
# Only prompt, without the generated token.
|
||||
all_token_ids = seq.get_token_ids()
|
||||
prompt_token_ids = all_token_ids[:-1]
|
||||
tokenizer = self.get_tokenizer_for_seq(seq)
|
||||
prefix_offset = 0
|
||||
read_offset = 0
|
||||
next_iter_prefix_offset = 0
|
||||
@ -70,7 +64,7 @@ class Detokenizer:
|
||||
prompt_token_ids[:token_position] + [token_id])
|
||||
(new_tokens, new_text, new_prefix_offset,
|
||||
new_read_offset) = detokenize_incrementally(
|
||||
tokenizer=tokenizer,
|
||||
tokenizer=self.tokenizer,
|
||||
all_input_ids=prompt_token_ids_with_token,
|
||||
prev_tokens=prev_tokens,
|
||||
prefix_offset=prefix_offset,
|
||||
@ -111,7 +105,6 @@ class Detokenizer:
|
||||
"""
|
||||
all_input_ids = seq.get_token_ids()
|
||||
token_id_generated_this_iteration = all_input_ids[-1]
|
||||
tokenizer = self.get_tokenizer_for_seq(seq)
|
||||
|
||||
# Convert prompt token IDs to tokens if necessary.
|
||||
# Do it here so that we don't have to repeat this
|
||||
@ -119,14 +112,14 @@ class Detokenizer:
|
||||
if seq.tokens is None:
|
||||
(seq.tokens, seq.prefix_offset,
|
||||
seq.read_offset) = convert_prompt_ids_to_tokens(
|
||||
tokenizer=tokenizer,
|
||||
tokenizer=self.tokenizer,
|
||||
prompt_ids=all_input_ids[:-1],
|
||||
skip_special_tokens=prms.skip_special_tokens,
|
||||
)
|
||||
|
||||
(new_tokens, new_decoded_token_text, prefix_offset,
|
||||
read_offset) = detokenize_incrementally(
|
||||
tokenizer=tokenizer,
|
||||
tokenizer=self.tokenizer,
|
||||
all_input_ids=all_input_ids,
|
||||
prev_tokens=seq.tokens,
|
||||
prefix_offset=seq.prefix_offset,
|
||||
@ -150,7 +143,7 @@ class Detokenizer:
|
||||
and token_id != VLLM_INVALID_TOKEN_ID):
|
||||
all_input_ids_with_logprob = previous_tokens + [token_id]
|
||||
(_, new_text, _, _) = detokenize_incrementally(
|
||||
tokenizer=tokenizer,
|
||||
tokenizer=self.tokenizer,
|
||||
all_input_ids=all_input_ids_with_logprob,
|
||||
prev_tokens=seq.tokens,
|
||||
prefix_offset=seq.prefix_offset,
|
||||
|
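The Detokenizer above now wraps a single tokenizer instead of a TokenizerGroup; a minimal construction sketch (not part of the diff; `tokenizer` is assumed to exist):

# Hypothetical sketch:
detokenizer = Detokenizer(tokenizer)  # previously Detokenizer(tokenizer_group)
# every decode path now uses self.tokenizer directly, with no per-sequence lookup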
@ -12,6 +12,7 @@ from typing import TYPE_CHECKING, Any, Optional, Union
import huggingface_hub
from transformers import (AutoTokenizer, PreTrainedTokenizer,
PreTrainedTokenizerFast)
from typing_extensions import assert_never

from vllm import envs
from vllm.logger import init_logger
@ -19,7 +20,6 @@ from vllm.transformers_utils.config import (
get_sentence_transformer_tokenizer_config)
from vllm.transformers_utils.tokenizers import MistralTokenizer
from vllm.transformers_utils.utils import check_gguf_file
from vllm.utils import make_async

if TYPE_CHECKING:
from vllm.config import ModelConfig
@ -274,20 +274,19 @@ def cached_tokenizer_from_config(
)


def get_lora_tokenizer(lora_request: LoRARequest, *args,
**kwargs) -> Optional[AnyTokenizer]:
if lora_request is None:
return None
try:
tokenizer = get_tokenizer(lora_request.lora_path, *args, **kwargs)
except Exception as e:
# No tokenizer was found in the LoRA folder,
# use base model tokenizer
logger.warning(
"No tokenizer found in %s, using base model tokenizer instead. "
"(Exception: %s)", lora_request.lora_path, e)
tokenizer = None
return tokenizer
def init_tokenizer_from_configs(model_config: ModelConfig):
runner_type = model_config.runner_type
if runner_type == "generate" or runner_type == "draft":
truncation_side = "left"
elif runner_type == "pooling":
truncation_side = "right"
else:
assert_never(runner_type)

get_lora_tokenizer_async = make_async(get_lora_tokenizer)
return get_tokenizer(
model_config.tokenizer,
tokenizer_mode=model_config.tokenizer_mode,
trust_remote_code=model_config.trust_remote_code,
revision=model_config.tokenizer_revision,
truncation_side=truncation_side,
)
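Sketch of the relocated helper: init_tokenizer_from_configs now lives in vllm.transformers_utils.tokenizer and only needs the model config (not part of the diff; the ModelConfig instance is assumed to exist):

# Hypothetical sketch:
from vllm.transformers_utils.tokenizer import init_tokenizer_from_configs

tokenizer = init_tokenizer_from_configs(model_config=model_config)
# truncation_side is "left" for generate/draft runners and "right" for pooling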
@ -61,6 +61,11 @@ class TokenizerBase(ABC):
def max_token_id(self) -> int:
raise NotImplementedError()

@property
@abstractmethod
def truncation_side(self) -> str:
raise NotImplementedError()

def __len__(self) -> int:
return self.vocab_size
@ -1,132 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from typing import Optional

from typing_extensions import assert_never

from vllm.config import ModelConfig, SchedulerConfig
from vllm.config.lora import LoRAConfig
from vllm.lora.request import LoRARequest
from vllm.transformers_utils.tokenizer import (AnyTokenizer, encode_tokens,
get_lora_tokenizer,
get_lora_tokenizer_async,
get_tokenizer)
from vllm.utils import LRUCache


class TokenizerGroup:
"""A group of tokenizers that can be used for LoRA adapters."""

def __init__(self, tokenizer_id: str, enable_lora: bool, max_num_seqs: int,
max_input_length: Optional[int], **tokenizer_config):
self.tokenizer_id = tokenizer_id
self.tokenizer_config = tokenizer_config
self.enable_lora = enable_lora
self.max_input_length = max_input_length
self.truncation_side = tokenizer_config.get("truncation_side", "left")
self.tokenizer = get_tokenizer(self.tokenizer_id, **tokenizer_config)
max_loras = tokenizer_config.get("max_loras", 0)
self.lora_tokenizers = LRUCache[int, AnyTokenizer](
capacity=max(max_loras, max_num_seqs) if enable_lora else 0)

def get_max_input_len(self,
lora_request: Optional[LoRARequest] = None
) -> Optional[int]:
"""Get the maximum input length for the LoRA request."""
return self.max_input_length

def _raise_if_input_too_long(self,
encoded_tokens: list[int],
lora_request: Optional[LoRARequest] = None):
input_length = len(encoded_tokens)
if lora_request:
max_input_length = (lora_request.long_lora_max_len
or self.max_input_length)
else:
max_input_length = self.max_input_length
if max_input_length is not None and input_length > max_input_length:
raise ValueError("Input too long.", input_length, max_input_length)

def encode(self,
prompt: str,
max_length: Optional[int] = None,
truncation: Optional[bool] = None,
lora_request: Optional[LoRARequest] = None,
add_special_tokens: Optional[bool] = None) -> list[int]:

tokenizer = self.get_lora_tokenizer(lora_request)
ret = encode_tokens(tokenizer,
prompt,
max_length=max_length,
truncation=truncation,
add_special_tokens=add_special_tokens)
self._raise_if_input_too_long(ret, lora_request)
return ret

async def encode_async(
self,
prompt: str,
max_length: Optional[int] = None,
truncation: Optional[bool] = None,
lora_request: Optional[LoRARequest] = None,
add_special_tokens: Optional[bool] = None) -> list[int]:
tokenizer = await self.get_lora_tokenizer_async(lora_request)
ret = encode_tokens(tokenizer,
prompt,
max_length=max_length,
truncation=truncation,
add_special_tokens=add_special_tokens)
self._raise_if_input_too_long(ret, lora_request)
return ret

def get_lora_tokenizer(
self,
lora_request: Optional[LoRARequest] = None,
) -> AnyTokenizer:
if not lora_request or not self.enable_lora:
return self.tokenizer
if lora_request.lora_int_id not in self.lora_tokenizers:
tokenizer = (get_lora_tokenizer(
lora_request, **self.tokenizer_config) or self.tokenizer)
self.lora_tokenizers.put(lora_request.lora_int_id, tokenizer)
return tokenizer
else:
return self.lora_tokenizers[lora_request.lora_int_id]

async def get_lora_tokenizer_async(
self,
lora_request: Optional[LoRARequest] = None,
) -> AnyTokenizer:
if not lora_request or not self.enable_lora:
return self.tokenizer
if lora_request.lora_int_id not in self.lora_tokenizers:
tokenizer = (await get_lora_tokenizer_async(
lora_request, **self.tokenizer_config) or self.tokenizer)
self.lora_tokenizers.put(lora_request.lora_int_id, tokenizer)
return tokenizer
else:
return self.lora_tokenizers[lora_request.lora_int_id]


def init_tokenizer_from_configs(model_config: ModelConfig,
scheduler_config: SchedulerConfig,
lora_config: Optional[LoRAConfig]):
runner_type = model_config.runner_type
if runner_type == "generate" or runner_type == "draft":
truncation_side = "left"
elif runner_type == "pooling":
truncation_side = "right"
else:
assert_never(runner_type)

return TokenizerGroup(
tokenizer_id=model_config.tokenizer,
enable_lora=bool(lora_config),
max_num_seqs=scheduler_config.max_num_seqs,
max_loras=lora_config.max_loras if lora_config else 0,
max_input_length=None,
tokenizer_mode=model_config.tokenizer_mode,
trust_remote_code=model_config.trust_remote_code,
revision=model_config.tokenizer_revision,
truncation_side=truncation_side)
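With TokenizerGroup deleted above, code that used to resolve a per-LoRA tokenizer can simply fall back to the base tokenizer; a rough equivalent of the old lookup (not part of the diff; the function name is illustrative):

# Hypothetical sketch:
def get_tokenizer_for_request(engine, lora_request=None):
    # LoRA adapters now share the base model tokenizer, so the request is ignored.
    return engine.get_tokenizer()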
@ -327,6 +327,10 @@ class MistralTokenizer(TokenizerBase):
def max_token_id(self) -> int:
return self._max_token_id

@property
def truncation_side(self) -> str:
raise NotImplementedError()

def __len__(self) -> int:
return self.vocab_size
@ -29,8 +29,8 @@ from vllm.tasks import SupportedTask
from vllm.tracing import init_tracer
from vllm.transformers_utils.config import (
maybe_register_config_serialize_by_value)
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
from vllm.transformers_utils.tokenizer import (AnyTokenizer,
init_tokenizer_from_configs)
from vllm.usage.usage_lib import UsageContext
from vllm.utils import (Device, as_list, cancel_task_threadsafe, cdiv,
deprecate_kwargs)
@ -112,9 +112,7 @@ class AsyncLLM(EngineClient):
else:
# Tokenizer (+ ensure liveness if running in another process).
self.tokenizer = init_tokenizer_from_configs(
model_config=vllm_config.model_config,
scheduler_config=vllm_config.scheduler_config,
lora_config=vllm_config.lora_config)
model_config=vllm_config.model_config)

# Processor (converts Inputs --> EngineCoreRequests).
self.processor = Processor(
@ -596,15 +594,12 @@ class AsyncLLM(EngineClient):
async def get_input_preprocessor(self) -> InputPreprocessor:
return self.processor.input_preprocessor

async def get_tokenizer(
self,
lora_request: Optional[LoRARequest] = None,
) -> AnyTokenizer:
async def get_tokenizer(self) -> AnyTokenizer:
if self.tokenizer is None:
raise ValueError("Unable to get tokenizer because "
"skip_tokenizer_init is True")

return self.tokenizer.get_lora_tokenizer(lora_request)
return self.tokenizer

async def is_tracing_enabled(self) -> bool:
return self.observability_config.otlp_traces_endpoint is not None
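Assumed async usage after the change above (not part of the diff; engine construction is omitted):

# Hypothetical sketch:
async def fetch_eos_id(async_llm):
    tokenizer = await async_llm.get_tokenizer()  # no lora_request parameter anymore
    return tokenizer.eos_token_id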
@ -20,8 +20,8 @@ from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams
from vllm.tasks import SupportedTask
from vllm.tracing import init_tracer
from vllm.transformers_utils.tokenizer_group import (
TokenizerGroup, init_tokenizer_from_configs)
from vllm.transformers_utils.tokenizer import (AnyTokenizer,
init_tokenizer_from_configs)
from vllm.usage.usage_lib import UsageContext
from vllm.utils import Device
from vllm.v1.engine.core_client import EngineCoreClient
@ -89,9 +89,7 @@ class LLMEngine:
else:
# Tokenizer (+ ensure liveness if running in another process).
self.tokenizer = init_tokenizer_from_configs(
model_config=vllm_config.model_config,
scheduler_config=vllm_config.scheduler_config,
lora_config=vllm_config.lora_config)
model_config=vllm_config.model_config)

# Processor (convert Inputs --> EngineCoreRequests)
self.processor = Processor(vllm_config=vllm_config,
@ -297,7 +295,7 @@ class LLMEngine:
assert self.log_stats, "Stat logging disabled"
return get_metrics_snapshot()

def get_tokenizer_group(self) -> TokenizerGroup:
def get_tokenizer(self) -> AnyTokenizer:
if self.tokenizer is None:
raise ValueError("Unable to get tokenizer because "
"skip_tokenizer_init is True")
@ -14,7 +14,6 @@ from vllm.sampling_params import RequestOutputKind
from vllm.tracing import (SpanAttributes, SpanKind, Tracer,
                          extract_trace_context)
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.transformers_utils.tokenizer_group import TokenizerGroup
from vllm.v1.engine import EngineCoreOutput, EngineCoreRequest, FinishReason
from vllm.v1.engine.detokenizer import IncrementalDetokenizer
from vllm.v1.engine.logprobs import LogprobsProcessor
@ -290,7 +289,7 @@ class RequestState:
class OutputProcessor:
    """Process EngineCoreOutputs into RequestOutputs."""

    def __init__(self, tokenizer: TokenizerGroup, log_stats: bool):
    def __init__(self, tokenizer: AnyTokenizer, log_stats: bool):
        self.log_stats = log_stats
        self.tokenizer = tokenizer
        self.request_states: dict[str, RequestState] = {}
@ -347,10 +346,7 @@ class OutputProcessor:
        if request_id in self.request_states:
            raise ValueError(f"Request id {request_id} already running.")

        tokenizer = None if not self.tokenizer else \
            self.tokenizer.get_lora_tokenizer(request.lora_request)

        req_state = RequestState.from_new_request(tokenizer=tokenizer,
        req_state = RequestState.from_new_request(tokenizer=self.tokenizer,
                                                  request=request,
                                                  prompt=prompt,
                                                  parent_req=parent_req,
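Because OutputProcessor now takes a plain tokenizer, test or embedding code can hand it any Hugging Face tokenizer; a minimal construction sketch follows. The model name is illustrative, and inside the engine this wiring happens automatically.

from transformers import AutoTokenizer

from vllm.v1.engine.output_processor import OutputProcessor

# Any HF tokenizer satisfies AnyTokenizer; the model name is an example.
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")

# Old: OutputProcessor(tokenizer_group, log_stats=False)
# New: the base-model tokenizer is shared by every request, LoRA or not.
output_processor = OutputProcessor(tokenizer, log_stats=False)
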
@ -9,6 +9,7 @@ from vllm.config import VllmConfig
from vllm.inputs import ProcessorInputs, PromptType, SingletonInputs
from vllm.inputs.parse import split_enc_dec_inputs
from vllm.inputs.preprocess import InputPreprocessor
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
from vllm.multimodal.cache import processor_cache_from_config
@ -17,7 +18,7 @@ from vllm.multimodal.processing import EncDecMultiModalProcessor
from vllm.multimodal.utils import argsort_mm_positions
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams
from vllm.transformers_utils.tokenizer_group import TokenizerGroup
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.v1.engine import EngineCoreRequest
from vllm.v1.structured_output.backend_guidance import (
    validate_guidance_grammar)
@ -28,13 +29,15 @@ from vllm.v1.structured_output.backend_outlines import (
from vllm.v1.structured_output.backend_xgrammar import (
    validate_xgrammar_grammar)

logger = init_logger(__name__)


class Processor:

    def __init__(
        self,
        vllm_config: VllmConfig,
        tokenizer: TokenizerGroup,
        tokenizer: AnyTokenizer,
        mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
    ):

@ -90,7 +93,6 @@ class Processor:
    def _validate_sampling_params(
        self,
        params: SamplingParams,
        lora_request: Optional[LoRARequest],
    ) -> None:
        self._validate_structured_output(params)
        self._validate_logit_bias(params)
@ -103,8 +105,7 @@ class Processor:
            # When skip_tokenizer_init=True, we can't validate token IDs
            # Skip validation and let the model handle invalid tokens
            return
        tokenizer = self.tokenizer.get_lora_tokenizer(lora_request)
        vocab_size = len(tokenizer)
        vocab_size = len(self.tokenizer)
        if not all(0 <= tid < vocab_size for tid in params.allowed_token_ids):
            raise ValueError(
                "allowed_token_ids contains out-of-vocab token id!")
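The allowed_token_ids check above now sizes the vocabulary from the single base tokenizer instead of a per-LoRA lookup. A standalone sketch of the same check outside the Processor, with an illustrative tokenizer:

from collections.abc import Iterable

from transformers import AutoTokenizer


def validate_allowed_token_ids(allowed_token_ids: Iterable[int],
                               tokenizer) -> None:
    # One vocab size, taken from the base tokenizer (len() includes any
    # added tokens), rather than from tokenizer_group.get_lora_tokenizer().
    vocab_size = len(tokenizer)
    if not all(0 <= tid < vocab_size for tid in allowed_token_ids):
        raise ValueError("allowed_token_ids contains out-of-vocab token id!")


tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")  # example
validate_allowed_token_ids([0, 1, 2], tokenizer)    # passes
# validate_allowed_token_ids([10**9], tokenizer)    # would raise ValueError
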
@ -144,7 +145,6 @@ class Processor:
    def _validate_params(
        self,
        params: Union[SamplingParams, PoolingParams],
        lora_request: Optional[LoRARequest],
    ):
        """
        Validate supported SamplingParam.
@ -155,14 +155,14 @@ class Processor:
            return

        self._validate_logprobs(params)
        self._validate_sampling_params(params, lora_request)
        self._validate_sampling_params(params)
        self._validate_supported_sampling_params(params)

    def _validate_multi_modal_uuids(self, prompt: PromptType) -> None:
        """
        Validate that user-provided multi_modal_uuids align with
        multi_modal_data in the incoming request prompt(s).
        Only checks lengths; `None` entries are allowed and will be
        Only checks lengths; `None` entries are allowed and will be
        auto-hashed downstream.
        """

@ -202,10 +202,22 @@ class Processor:
            _validate_single_prompt(prompt)  # type: ignore[arg-type]

    def _validate_lora(self, lora_request: Optional[LoRARequest]) -> None:
        if lora_request is not None and not self.lora_config:
        if lora_request is None:
            return

        # LoRA request passed in while LoRA is not enabled
        if not self.lora_config:
            raise ValueError(f"Got lora_request {lora_request} but LoRA is "
                             "not enabled!")

        if self.tokenizer is not None:
            logger.warning_once(
                "vLLM has deprecated support for different tokenizers per "
                "LoRA adapter. By default, vLLM uses the base model's "
                "tokenizer. If you are using a LoRA with its own tokenizer, "
                "consider specifying `--tokenizer [lora_path]` to use the "
                "LoRA tokenizer.")

    def _validate_structured_output(self, params: SamplingParams) -> None:
        if not params.guided_decoding or not self.decoding_config:
            return
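For users who relied on a LoRA adapter shipping its own tokenizer, the warning above points at the workaround: pass the adapter path as the tokenizer. A hedged sketch with the offline LLM API follows; the model name and adapter path are placeholders.

from vllm import LLM
from vllm.lora.request import LoRARequest

BASE_MODEL = "meta-llama/Llama-2-7b-hf"   # placeholder base model
LORA_PATH = "/path/to/my-lora-adapter"    # placeholder local adapter dir

# `tokenizer=` overrides the tokenizer loaded for the base model, so any
# tokens added by the adapter's tokenizer are available; per-LoRA
# tokenizers are no longer resolved automatically.
llm = LLM(model=BASE_MODEL, tokenizer=LORA_PATH, enable_lora=True)

outputs = llm.generate(
    "Hello",
    lora_request=LoRARequest("my-lora", 1, LORA_PATH),
)
print(outputs[0].outputs[0].text)
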
@ -326,7 +338,7 @@ class Processor:

        # TODO(woosuk): Support pooling models.
        self._validate_lora(lora_request)
        self._validate_params(params, lora_request)
        self._validate_params(params)

        data_parallel_size = self.vllm_config.parallel_config.data_parallel_size
        if data_parallel_rank is not None and not (0 <= data_parallel_rank <
@ -365,7 +377,6 @@ class Processor:
        processed_inputs: ProcessorInputs = self.input_preprocessor.preprocess(
            prompt,
            tokenization_kwargs=tokenization_kwargs,
            lora_request=lora_request,
            mm_uuids=mm_uuids,
        )
        from vllm.platforms import current_platform
@ -375,9 +386,9 @@ class Processor:
            processed_inputs=processed_inputs,
        )

        eos_token_id = self.input_preprocessor.get_eos_token_id(lora_request)
        eos_token_id = self.input_preprocessor.get_eos_token_id()

        self._validate_model_inputs(processed_inputs, lora_request)
        self._validate_model_inputs(processed_inputs)

        encoder_inputs, decoder_inputs = split_enc_dec_inputs(processed_inputs)

@ -394,8 +405,7 @@ class Processor:
            sampling_params.update_from_generation_config(
                self.generation_config_fields, eos_token_id)
            if self.tokenizer is not None:
                sampling_params.update_from_tokenizer(
                    self.tokenizer.get_lora_tokenizer(lora_request))
                sampling_params.update_from_tokenizer(self.tokenizer)
        else:
            pooling_params = params.clone()

@ -436,24 +446,17 @@ class Processor:
            trace_headers=trace_headers,
        )

    def _validate_model_inputs(self,
                               inputs: ProcessorInputs,
                               lora_request: Optional[LoRARequest] = None):
    def _validate_model_inputs(self, inputs: ProcessorInputs):
        encoder_inputs, decoder_inputs = split_enc_dec_inputs(inputs)

        if encoder_inputs is not None:
            self._validate_model_input(encoder_inputs,
                                       lora_request,
                                       prompt_type="encoder")
            self._validate_model_input(encoder_inputs, prompt_type="encoder")

        self._validate_model_input(decoder_inputs,
                                   lora_request,
                                   prompt_type="decoder")
        self._validate_model_input(decoder_inputs, prompt_type="decoder")

    def _validate_model_input(
        self,
        prompt_inputs: SingletonInputs,
        lora_request: Optional[LoRARequest],
        *,
        prompt_type: Literal["encoder", "decoder"],
    ):
@ -469,7 +472,7 @@ class Processor:
        if self.model_config.skip_tokenizer_init:
            tokenizer = None
        else:
            tokenizer = self.tokenizer.get_lora_tokenizer(lora_request)
            tokenizer = self.tokenizer
        max_input_id = max(prompt_ids, default=0)

        # NOTE: tokenizer.max_token_id is the tokenizer's vocab size while
@ -9,7 +9,7 @@ from typing import TYPE_CHECKING, Optional
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.reasoning import ReasoningParserManager
from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
from vllm.transformers_utils.tokenizer import init_tokenizer_from_configs
from vllm.utils import LazyLoader
from vllm.v1.structured_output.backend_guidance import GuidanceBackend
from vllm.v1.structured_output.backend_types import (StructuredOutputBackend,
@ -60,10 +60,7 @@ class StructuredOutputManager:
            max_workers = max(1, (multiprocessing.cpu_count() + 1) // 2)
            self.executor = ThreadPoolExecutor(max_workers=max_workers)
            self.tokenizer = init_tokenizer_from_configs(
                model_config=self.vllm_config.model_config,
                scheduler_config=self.vllm_config.scheduler_config,
                lora_config=self.vllm_config.lora_config,
            ).get_lora_tokenizer(None)
                model_config=self.vllm_config.model_config)
            reasoning_backend = \
                self.vllm_config.decoding_config.reasoning_backend
            if reasoning_backend:
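
With the group gone, init_tokenizer_from_configs takes only the model config and returns the tokenizer itself, so the trailing .get_lora_tokenizer(None) disappears. A minimal sketch, assuming a VllmConfig built from EngineArgs; the model name is illustrative.

from vllm.engine.arg_utils import EngineArgs
from vllm.transformers_utils.tokenizer import init_tokenizer_from_configs

# Build a VllmConfig the same way the engine does; the model is an example.
vllm_config = EngineArgs(model="facebook/opt-125m").create_engine_config()

# Old: init_tokenizer_from_configs(model_config=..., scheduler_config=...,
#      lora_config=...).get_lora_tokenizer(None)
# New: a single call that returns an AnyTokenizer directly.
tokenizer = init_tokenizer_from_configs(
    model_config=vllm_config.model_config)
print(type(tokenizer).__name__, len(tokenizer))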