[ Frontend ] Multiprocessing for OpenAI Server with zeromq (#6883)

Signed-off-by: Joe Runde <Joseph.Runde@ibm.com>
Co-authored-by: Joe Runde <Joseph.Runde@ibm.com>
Co-authored-by: Joe Runde <joe@joerun.de>
Co-authored-by: Nick Hill <nickhill@us.ibm.com>
Co-authored-by: Simon Mo <simon.mo@hey.com>
This commit is contained in:
Robert Shaw
2024-08-02 21:27:28 -04:00
committed by GitHub
parent 708989341e
commit ed812a73fa
20 changed files with 1567 additions and 101 deletions


@ -0,0 +1,715 @@
"""
Repeat of tests in test_completion.py with the non-mp backend.
"""
# imports for guided decoding tests
import json
import re
import shutil
from tempfile import TemporaryDirectory
from typing import List
import jsonschema
import openai # use the official client for correctness check
import pytest
# downloading lora to test lora requests
from huggingface_hub import snapshot_download
from openai import BadRequestError
from transformers import AutoTokenizer
from vllm.transformers_utils.tokenizer import get_tokenizer
from ...utils import RemoteOpenAIServer
# any model with a chat template should work here
MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta"
# technically these adapters use a different base model,
# but we're not testing generation quality here
LORA_NAME = "typeof/zephyr-7b-beta-lora"
PA_NAME = "swapnilbp/llama_tweet_ptune"
# if PA_NAME changes, PA_NUM_VIRTUAL_TOKENS might also
# need to change to match the prompt adapter
PA_NUM_VIRTUAL_TOKENS = 8
@pytest.fixture(scope="module")
def zephyr_lora_files():
return snapshot_download(repo_id=LORA_NAME)
@pytest.fixture(scope="module")
def zephyr_lora_added_tokens_files(zephyr_lora_files):
tmp_dir = TemporaryDirectory()
tmp_model_dir = f"{tmp_dir.name}/zephyr"
shutil.copytree(zephyr_lora_files, tmp_model_dir)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
# Copy tokenizer to adapter and add some unique tokens
# 32000, 32001, 32002
added = tokenizer.add_tokens(["vllm1", "vllm2", "vllm3"],
special_tokens=True)
assert added == 3
tokenizer.save_pretrained(tmp_model_dir)
yield tmp_model_dir
tmp_dir.cleanup()
@pytest.fixture(scope="module")
def zephyr_pa_files():
return snapshot_download(repo_id=PA_NAME)
@pytest.fixture(scope="module")
def default_server_args(zephyr_lora_files, zephyr_lora_added_tokens_files,
zephyr_pa_files):
return [
# use half precision for speed and memory savings in CI environment
"--dtype",
"bfloat16",
"--max-model-len",
"8192",
"--max-num-seqs",
"128",
"--enforce-eager",
# lora config
"--enable-lora",
"--lora-modules",
f"zephyr-lora={zephyr_lora_files}",
f"zephyr-lora2={zephyr_lora_added_tokens_files}",
"--max-lora-rank",
"64",
"--max-cpu-loras",
"2",
# pa config
"--enable-prompt-adapter",
"--prompt-adapters",
f"zephyr-pa={zephyr_pa_files}",
f"zephyr-pa2={zephyr_pa_files}",
"--max-prompt-adapters",
"2",
"--max-prompt-adapter-token",
"128",
"--disable-frontend-multiprocessing"
]
@pytest.fixture(scope="module")
def server(default_server_args):
with RemoteOpenAIServer(MODEL_NAME, default_server_args) as remote_server:
yield remote_server
@pytest.fixture(scope="module")
def client(server):
return server.get_async_client()
@pytest.mark.asyncio
@pytest.mark.parametrize(
# first test base model, then test loras, then test prompt adapters
"model_name,num_virtual_tokens",
[(MODEL_NAME, 0), ("zephyr-lora", 0), ("zephyr-lora2", 0),
("zephyr-pa", PA_NUM_VIRTUAL_TOKENS),
("zephyr-pa2", PA_NUM_VIRTUAL_TOKENS)],
)
async def test_single_completion(client: openai.AsyncOpenAI, model_name: str,
num_virtual_tokens: int):
completion = await client.completions.create(model=model_name,
prompt="Hello, my name is",
max_tokens=5,
temperature=0.0)
assert completion.id is not None
assert completion.choices is not None and len(completion.choices) == 1
choice = completion.choices[0]
assert len(choice.text) >= 5
assert choice.finish_reason == "length"
assert completion.usage == openai.types.CompletionUsage(
completion_tokens=5,
prompt_tokens=6 + num_virtual_tokens,
total_tokens=11 + num_virtual_tokens)
# test using token IDs
completion = await client.completions.create(
model=model_name,
prompt=[0, 0, 0, 0, 0],
max_tokens=5,
temperature=0.0,
)
assert len(completion.choices[0].text) >= 1
@pytest.mark.asyncio
async def test_added_lora_tokens(client: openai.AsyncOpenAI):
# test using token IDs
completion = await client.completions.create(
model="zephyr-lora2",
prompt=[0, 0, 32000, 32001, 32002],
echo=True,
max_tokens=5,
temperature=0.0,
)
# Added tokens should appear in tokenized prompt
assert completion.choices[0].text.startswith("<unk><unk>vllm1vllm2vllm3")
@pytest.mark.asyncio
async def test_added_lora_tokens_base_model(client: openai.AsyncOpenAI):
# test using token IDs
completion = await client.completions.create(
model=MODEL_NAME,
prompt=[0, 0, 32000, 32001, 32002],
echo=True,
max_tokens=5,
temperature=0.0,
)
# Added tokens should not appear in tokenized prompt
assert "vllm" not in completion.choices[0].text
@pytest.mark.asyncio
@pytest.mark.parametrize(
# first test base model, then test loras, then test prompt adapters
"model_name",
[MODEL_NAME, "zephyr-lora", "zephyr-lora2", "zephyr-pa", "zephyr-pa2"],
)
async def test_no_logprobs(client: openai.AsyncOpenAI, model_name: str):
# test using token IDs
completion = await client.completions.create(
model=model_name,
prompt=[0, 0, 0, 0, 0],
max_tokens=5,
temperature=0.0,
logprobs=None,
)
choice = completion.choices[0]
assert choice.logprobs is None
@pytest.mark.asyncio
@pytest.mark.parametrize(
# just test 1 lora and 1 pa hereafter
"model_name",
[MODEL_NAME, "zephyr-lora", "zephyr-pa"],
)
async def test_zero_logprobs(client: openai.AsyncOpenAI, model_name: str):
# test using token IDs
completion = await client.completions.create(
model=model_name,
prompt=[0, 0, 0, 0, 0],
max_tokens=5,
temperature=0.0,
logprobs=0,
)
choice = completion.choices[0]
assert choice.logprobs is not None
assert choice.logprobs.token_logprobs is not None
assert choice.logprobs.top_logprobs is not None
assert len(choice.logprobs.top_logprobs[0]) == 1
@pytest.mark.asyncio
@pytest.mark.parametrize(
"model_name",
[MODEL_NAME, "zephyr-lora", "zephyr-pa"],
)
async def test_some_logprobs(client: openai.AsyncOpenAI, model_name: str):
# test using token IDs
completion = await client.completions.create(
model=model_name,
prompt=[0, 0, 0, 0, 0],
max_tokens=5,
temperature=0.0,
logprobs=5,
)
choice = completion.choices[0]
assert choice.logprobs is not None
assert choice.logprobs.token_logprobs is not None
assert choice.logprobs.top_logprobs is not None
assert 5 <= len(choice.logprobs.top_logprobs[0]) <= 6
@pytest.mark.asyncio
@pytest.mark.parametrize(
"model_name",
[MODEL_NAME, "zephyr-lora", "zephyr-pa"],
)
async def test_too_many_completion_logprobs(client: openai.AsyncOpenAI,
model_name: str):
with pytest.raises(
(openai.BadRequestError, openai.APIError)): # test using token IDs
await client.completions.create(
model=model_name,
prompt=[0, 0, 0, 0, 0],
max_tokens=5,
temperature=0.0,
# vLLM has higher default max_logprobs (20 instead of 5) to support
# both Completion API and Chat Completion API
logprobs=21,
)
...
with pytest.raises(
(openai.BadRequestError, openai.APIError)): # test using token IDs
stream = await client.completions.create(
model=model_name,
prompt=[0, 0, 0, 0, 0],
max_tokens=5,
temperature=0.0,
# vLLM has higher default max_logprobs (20 instead of 5) to support
# both Completion API and Chat Completion API
logprobs=30,
stream=True,
)
async for chunk in stream:
...
# the server should still work afterwards
completion = await client.completions.create(
model=model_name,
prompt=[0, 0, 0, 0, 0],
max_tokens=5,
temperature=0.0,
)
assert len(completion.choices[0].text) >= 0
@pytest.mark.asyncio
@pytest.mark.parametrize(
"model_name",
[MODEL_NAME, "zephyr-lora", "zephyr-pa"],
)
async def test_completion_streaming(client: openai.AsyncOpenAI,
model_name: str):
prompt = "What is an LLM?"
single_completion = await client.completions.create(
model=model_name,
prompt=prompt,
max_tokens=5,
temperature=0.0,
)
single_output = single_completion.choices[0].text
stream = await client.completions.create(model=model_name,
prompt=prompt,
max_tokens=5,
temperature=0.0,
stream=True)
chunks: List[str] = []
finish_reason_count = 0
async for chunk in stream:
chunks.append(chunk.choices[0].text)
if chunk.choices[0].finish_reason is not None:
finish_reason_count += 1
# finish_reason should only be returned in the last block
assert finish_reason_count == 1
assert chunk.choices[0].finish_reason == "length"
assert chunk.choices[0].text
assert "".join(chunks) == single_output
@pytest.mark.asyncio
@pytest.mark.parametrize(
"model_name",
[MODEL_NAME, "zephyr-lora", "zephyr-pa"],
)
async def test_completion_stream_options(client: openai.AsyncOpenAI,
model_name: str):
prompt = "What is the capital of France?"
# Test stream=True, stream_options=
# {"include_usage": False, "continuous_usage_stats": False}
stream = await client.completions.create(model=model_name,
prompt=prompt,
max_tokens=5,
temperature=0.0,
stream=True,
stream_options={
"include_usage": False,
"continuous_usage_stats":
False,
})
async for chunk in stream:
assert chunk.usage is None
# Test stream=True, stream_options=
# {"include_usage": False, "continuous_usage_stats": True}
stream = await client.completions.create(model=model_name,
prompt=prompt,
max_tokens=5,
temperature=0.0,
stream=True,
stream_options={
"include_usage": False,
"continuous_usage_stats":
True,
})
async for chunk in stream:
assert chunk.usage is None
# Test stream=True, stream_options=
# {"include_usage": True, "continuous_usage_stats": False}
stream = await client.completions.create(model=model_name,
prompt=prompt,
max_tokens=5,
temperature=0.0,
stream=True,
stream_options={
"include_usage": True,
"continuous_usage_stats":
False,
})
async for chunk in stream:
if chunk.choices[0].finish_reason is None:
assert chunk.usage is None
else:
assert chunk.usage is None
final_chunk = await stream.__anext__()
assert final_chunk.usage is not None
assert final_chunk.usage.prompt_tokens > 0
assert final_chunk.usage.completion_tokens > 0
assert final_chunk.usage.total_tokens == (
final_chunk.usage.prompt_tokens +
final_chunk.usage.completion_tokens)
assert final_chunk.choices == []
# Test stream=True, stream_options=
# {"include_usage": True, "continuous_usage_stats": True}
stream = await client.completions.create(model=model_name,
prompt=prompt,
max_tokens=5,
temperature=0.0,
stream=True,
stream_options={
"include_usage": True,
"continuous_usage_stats":
True,
})
async for chunk in stream:
assert chunk.usage is not None
assert chunk.usage.prompt_tokens > 0
assert chunk.usage.completion_tokens > 0
assert chunk.usage.total_tokens == (chunk.usage.prompt_tokens +
chunk.usage.completion_tokens)
if chunk.choices[0].finish_reason is not None:
final_chunk = await stream.__anext__()
assert final_chunk.usage is not None
assert final_chunk.usage.prompt_tokens > 0
assert final_chunk.usage.completion_tokens > 0
assert final_chunk.usage.total_tokens == (
final_chunk.usage.prompt_tokens +
final_chunk.usage.completion_tokens)
assert final_chunk.choices == []
# Test stream=False, stream_options=
# {"include_usage": None}
with pytest.raises(BadRequestError):
await client.completions.create(model=model_name,
prompt=prompt,
max_tokens=5,
temperature=0.0,
stream=False,
stream_options={"include_usage": None})
# Test stream=False, stream_options=
# {"include_usage": True}
with pytest.raises(BadRequestError):
await client.completions.create(model=model_name,
prompt=prompt,
max_tokens=5,
temperature=0.0,
stream=False,
stream_options={"include_usage": True})
# Test stream=False, stream_options=
# {"continuous_usage_stats": None}
with pytest.raises(BadRequestError):
await client.completions.create(
model=model_name,
prompt=prompt,
max_tokens=5,
temperature=0.0,
stream=False,
stream_options={"continuous_usage_stats": None})
# Test stream=False, stream_options=
# {"continuous_usage_stats": True}
with pytest.raises(BadRequestError):
await client.completions.create(
model=model_name,
prompt=prompt,
max_tokens=5,
temperature=0.0,
stream=False,
stream_options={"continuous_usage_stats": True})
@pytest.mark.asyncio
@pytest.mark.parametrize(
"model_name",
[MODEL_NAME, "zephyr-lora", "zephyr-pa"],
)
async def test_batch_completions(client: openai.AsyncOpenAI, model_name: str):
# test both text and token IDs
for prompts in (["Hello, my name is"] * 2, [[0, 0, 0, 0, 0]] * 2):
# test simple list
batch = await client.completions.create(
model=model_name,
prompt=prompts,
max_tokens=5,
temperature=0.0,
)
assert len(batch.choices) == 2
assert batch.choices[0].text == batch.choices[1].text
# test n = 2
batch = await client.completions.create(
model=model_name,
prompt=prompts,
n=2,
max_tokens=5,
temperature=0.0,
extra_body=dict(
# NOTE: this has to be true for n > 1 in vLLM, but not necessary
# for official client.
use_beam_search=True),
)
assert len(batch.choices) == 4
assert batch.choices[0].text != batch.choices[
1].text, "beam search should be different"
assert batch.choices[0].text == batch.choices[
2].text, "two copies of the same prompt should be the same"
assert batch.choices[1].text == batch.choices[
3].text, "two copies of the same prompt should be the same"
# test streaming
batch = await client.completions.create(
model=model_name,
prompt=prompts,
max_tokens=5,
temperature=0.0,
stream=True,
)
texts = [""] * 2
async for chunk in batch:
assert len(chunk.choices) == 1
choice = chunk.choices[0]
texts[choice.index] += choice.text
assert texts[0] == texts[1]
@pytest.mark.asyncio
async def test_logits_bias(client: openai.AsyncOpenAI):
prompt = "Hello, my name is"
max_tokens = 5
tokenizer = get_tokenizer(tokenizer_name=MODEL_NAME)
# Test exclusive selection
token_id = 1000
completion = await client.completions.create(
model=MODEL_NAME,
prompt=prompt,
max_tokens=max_tokens,
temperature=0.0,
logit_bias={str(token_id): 100},
seed=42,
)
assert len(completion.choices[0].text) >= 5
response_tokens = tokenizer(completion.choices[0].text,
add_special_tokens=False)["input_ids"]
expected_tokens = tokenizer(tokenizer.decode([token_id] * 5),
add_special_tokens=False)["input_ids"]
assert all([
response == expected
for response, expected in zip(response_tokens, expected_tokens)
])
# Test ban
completion = await client.completions.create(
model=MODEL_NAME,
prompt=prompt,
max_tokens=max_tokens,
temperature=0.0,
)
response_tokens = tokenizer(completion.choices[0].text,
add_special_tokens=False)["input_ids"]
first_response = completion.choices[0].text
completion = await client.completions.create(
model=MODEL_NAME,
prompt=prompt,
max_tokens=max_tokens,
temperature=0.0,
logit_bias={str(token): -100
for token in response_tokens},
)
assert first_response != completion.choices[0].text
@pytest.mark.asyncio
async def test_allowed_token_ids(client: openai.AsyncOpenAI):
prompt = "Hello, my name is"
max_tokens = 1
tokenizer = get_tokenizer(tokenizer_name=MODEL_NAME)
# Test exclusive selection
allowed_ids = [21555, 21557, 21558]
completion = await client.completions.create(
model=MODEL_NAME,
prompt=prompt,
max_tokens=max_tokens,
temperature=0.0,
seed=42,
extra_body=dict(allowed_token_ids=allowed_ids),
logprobs=1,
)
response_tokens = completion.choices[0].logprobs.tokens
assert len(response_tokens) == 1
assert tokenizer.convert_tokens_to_ids(response_tokens)[0] in allowed_ids
@pytest.mark.asyncio
@pytest.mark.parametrize("guided_decoding_backend",
["outlines", "lm-format-enforcer"])
async def test_guided_json_completion(client: openai.AsyncOpenAI,
guided_decoding_backend: str,
sample_json_schema):
completion = await client.completions.create(
model=MODEL_NAME,
prompt=f"Give an example JSON for an employee profile "
f"that fits this schema: {sample_json_schema}",
n=3,
temperature=1.0,
max_tokens=500,
extra_body=dict(guided_json=sample_json_schema,
guided_decoding_backend=guided_decoding_backend))
assert completion.id is not None
assert len(completion.choices) == 3
for i in range(3):
output_json = json.loads(completion.choices[i].text)
jsonschema.validate(instance=output_json, schema=sample_json_schema)
@pytest.mark.asyncio
@pytest.mark.parametrize("guided_decoding_backend",
["outlines", "lm-format-enforcer"])
async def test_guided_regex_completion(client: openai.AsyncOpenAI,
guided_decoding_backend: str,
sample_regex):
completion = await client.completions.create(
model=MODEL_NAME,
prompt=f"Give an example IPv4 address with this regex: {sample_regex}",
n=3,
temperature=1.0,
max_tokens=20,
extra_body=dict(guided_regex=sample_regex,
guided_decoding_backend=guided_decoding_backend))
assert completion.id is not None
assert len(completion.choices) == 3
for i in range(3):
assert re.fullmatch(sample_regex,
completion.choices[i].text) is not None
@pytest.mark.asyncio
@pytest.mark.parametrize("guided_decoding_backend",
["outlines", "lm-format-enforcer"])
async def test_guided_choice_completion(client: openai.AsyncOpenAI,
guided_decoding_backend: str,
sample_guided_choice):
completion = await client.completions.create(
model=MODEL_NAME,
prompt="The best language for type-safe systems programming is ",
n=2,
temperature=1.0,
max_tokens=10,
extra_body=dict(guided_choice=sample_guided_choice,
guided_decoding_backend=guided_decoding_backend))
assert completion.id is not None
assert len(completion.choices) == 2
for i in range(2):
assert completion.choices[i].text in sample_guided_choice
@pytest.mark.asyncio
async def test_guided_grammar(client: openai.AsyncOpenAI,
sample_sql_statements):
completion = await client.completions.create(
model=MODEL_NAME,
prompt=("Generate a sql state that select col_1 from "
"table_1 where it is equals to 1"),
temperature=1.0,
max_tokens=500,
extra_body=dict(guided_grammar=sample_sql_statements))
content = completion.choices[0].text
# use Lark to parse the output, and make sure it's a valid parse tree
from lark import Lark
parser = Lark(sample_sql_statements)
parser.parse(content)
# remove spaces for comparison b/c we removed them in the grammar
ground_truth = "SELECT col_1 from table_1 where col_1 = 1".replace(" ", "")
assert content.strip() == ground_truth
@pytest.mark.asyncio
@pytest.mark.parametrize(
# first test base model, then test loras
"model_name",
[MODEL_NAME, "zephyr-lora", "zephyr-lora2"],
)
@pytest.mark.parametrize("logprobs_arg", [1, 0])
async def test_echo_logprob_completion(client: openai.AsyncOpenAI,
model_name: str, logprobs_arg: int):
tokenizer = get_tokenizer(tokenizer_name=MODEL_NAME)
# test using text and token IDs
for prompt in ("Hello, my name is", [0, 0, 0, 0, 0]):
completion = await client.completions.create(model=model_name,
prompt=prompt,
max_tokens=5,
temperature=0.0,
echo=True,
logprobs=logprobs_arg)
prompt_text = tokenizer.decode(prompt) if isinstance(prompt,
list) else prompt
assert re.search(r"^" + prompt_text, completion.choices[0].text)
logprobs = completion.choices[0].logprobs
assert logprobs is not None
assert len(logprobs.text_offset) > 5
assert (len(logprobs.token_logprobs) > 5
and logprobs.token_logprobs[0] is None)
assert (len(logprobs.top_logprobs) > 5
and logprobs.top_logprobs[0] is None)
for top_logprobs in logprobs.top_logprobs[1:]:
assert max(logprobs_arg,
1) <= len(top_logprobs) <= logprobs_arg + 1
assert len(logprobs.tokens) > 5
@pytest.mark.asyncio
@pytest.mark.parametrize("guided_decoding_backend",
["outlines", "lm-format-enforcer"])
async def test_guided_decoding_type_error(client: openai.AsyncOpenAI,
guided_decoding_backend: str,
sample_json_schema, sample_regex):
with pytest.raises(openai.BadRequestError):
_ = await client.completions.create(
model=MODEL_NAME,
prompt="Give an example JSON that fits this schema: 42",
extra_body=dict(guided_json=42,
guided_decoding_backend=guided_decoding_backend))
with pytest.raises(openai.BadRequestError):
_ = await client.completions.create(
model=MODEL_NAME,
prompt="Give an example string that fits this regex",
extra_body=dict(guided_regex=sample_regex,
guided_json=sample_json_schema))


@ -7,7 +7,8 @@ from typing import (AsyncIterator, Callable, Dict, Iterable, List, Mapping,
from transformers import PreTrainedTokenizer
import vllm.envs as envs
from vllm.config import DecodingConfig, EngineConfig, ModelConfig
from vllm.config import (DecodingConfig, EngineConfig, LoRAConfig, ModelConfig,
ParallelConfig, SchedulerConfig)
from vllm.core.scheduler import SchedulerOutputs
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_timeout import asyncio_timeout
@ -928,6 +929,14 @@ class AsyncLLMEngine:
else:
return self.engine.get_model_config()
async def get_parallel_config(self) -> ParallelConfig:
"""Get the parallel configuration of the vLLM engine."""
if self.engine_use_ray:
return await self.engine.get_parallel_config.remote( # type: ignore
)
else:
return self.engine.get_parallel_config()
async def get_decoding_config(self) -> DecodingConfig:
"""Get the decoding configuration of the vLLM engine."""
if self.engine_use_ray:
@ -936,6 +945,22 @@ class AsyncLLMEngine:
else:
return self.engine.get_decoding_config()
async def get_scheduler_config(self) -> SchedulerConfig:
"""Get the scheduling configuration of the vLLM engine."""
if self.engine_use_ray:
return await self.engine.get_scheduler_config.remote( # type: ignore
)
else:
return self.engine.get_scheduler_config()
async def get_lora_config(self) -> LoRAConfig:
"""Get the lora configuration of the vLLM engine."""
if self.engine_use_ray:
return await self.engine.get_lora_config.remote( # type: ignore
)
else:
return self.engine.get_lora_config()
async def do_log_stats(
self,
scheduler_outputs: Optional[SchedulerOutputs] = None,


@ -38,9 +38,8 @@ from vllm.tracing import (SpanAttributes, SpanKind, extract_trace_context,
init_tracer)
from vllm.transformers_utils.config import try_get_generation_config
from vllm.transformers_utils.detokenizer import Detokenizer
from vllm.transformers_utils.tokenizer_group import (AnyTokenizer,
BaseTokenizerGroup,
get_tokenizer_group)
from vllm.transformers_utils.tokenizer_group import (
AnyTokenizer, BaseTokenizerGroup, init_tokenizer_from_configs)
from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled,
usage_message)
from vllm.utils import Counter
@ -485,19 +484,12 @@ class LLMEngine:
return self.get_tokenizer_group().get_lora_tokenizer(
sequence.lora_request)
def _init_tokenizer(self, **tokenizer_init_kwargs) -> BaseTokenizerGroup:
init_kwargs = dict(
tokenizer_id=self.model_config.tokenizer,
enable_lora=bool(self.lora_config),
max_num_seqs=self.scheduler_config.max_num_seqs,
max_input_length=None,
tokenizer_mode=self.model_config.tokenizer_mode,
trust_remote_code=self.model_config.trust_remote_code,
revision=self.model_config.tokenizer_revision)
init_kwargs.update(tokenizer_init_kwargs)
return get_tokenizer_group(self.parallel_config.tokenizer_pool_config,
**init_kwargs)
def _init_tokenizer(self) -> BaseTokenizerGroup:
return init_tokenizer_from_configs(
model_config=self.model_config,
scheduler_config=self.scheduler_config,
parallel_config=self.parallel_config,
enable_lora=bool(self.lora_config))
def _verify_args(self) -> None:
self.model_config.verify_with_parallel_config(self.parallel_config)
@ -759,10 +751,22 @@ class LLMEngine:
"""Gets the model configuration."""
return self.model_config
def get_parallel_config(self) -> ParallelConfig:
"""Gets the parallel configuration."""
return self.parallel_config
def get_decoding_config(self) -> DecodingConfig:
"""Gets the decoding configuration."""
return self.decoding_config
def get_scheduler_config(self) -> SchedulerConfig:
"""Gets the scheduler configuration."""
return self.scheduler_config
def get_lora_config(self) -> LoRAConfig:
"""Gets the LoRA configuration."""
return self.lora_config
def get_num_unfinished_requests(self) -> int:
"""Gets the number of unfinished requests."""
return sum(scheduler.get_num_unfinished_seq_groups()

vllm/engine/protocol.py (new file, 84 lines)

@ -0,0 +1,84 @@
from typing import (AsyncIterator, List, Mapping, Optional, Protocol,
runtime_checkable)
from transformers import PreTrainedTokenizer
from vllm.config import DecodingConfig, ModelConfig
from vllm.core.scheduler import SchedulerOutputs
from vllm.inputs.data import PromptInputs
from vllm.lora.request import LoRARequest
from vllm.outputs import EmbeddingRequestOutput, RequestOutput
from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams
from vllm.sequence import SamplerOutput
@runtime_checkable
class AsyncEngineClient(Protocol):
"""Protocol class for Clients to AsyncLLMEngine"""
@property
def is_running(self) -> bool:
...
@property
def is_stopped(self) -> bool:
...
@property
def errored(self) -> bool:
...
async def generate(
self,
inputs: PromptInputs,
sampling_params: SamplingParams,
request_id: str,
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None
) -> AsyncIterator[RequestOutput]:
"""Generates outputs for a request"""
async def encode(
self,
inputs: PromptInputs,
pooling_params: PoolingParams,
request_id: str,
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
) -> AsyncIterator[EmbeddingRequestOutput]:
"""Generate outputs for a request from an embedding model."""
async def abort(self, request_id: str) -> None:
"""Abort a request.
Args:
request_id: The unique id of the request.
"""
async def get_model_config(self) -> ModelConfig:
"""Get the model configuration of the vLLM engine."""
async def get_decoding_config(self) -> DecodingConfig:
"""Get the decoding configuration of the vLLM engine."""
async def get_tokenizer(
self,
lora_request: Optional[LoRARequest] = None,
) -> PreTrainedTokenizer:
"""Get the appropriate Tokenizer for the request"""
async def is_tracing_enabled(self) -> bool:
pass
async def do_log_stats(
self,
scheduler_outputs: Optional[SchedulerOutputs] = None,
model_output: Optional[List[SamplerOutput]] = None,
) -> None:
pass
async def check_health(self) -> None:
"""Raise if unhealthy"""


@ -5,7 +5,8 @@ import re
import signal
from contextlib import asynccontextmanager
from http import HTTPStatus
from typing import Optional, Set
from multiprocessing import Process
from typing import AsyncIterator, Set
import fastapi
import uvicorn
@ -17,8 +18,10 @@ from prometheus_client import make_asgi_app
from starlette.routing import Mount
import vllm.envs as envs
from vllm.config import ModelConfig
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.engine.protocol import AsyncEngineClient
from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.cli_args import make_arg_parser
# yapf conflicts with isort for this block
@ -31,6 +34,8 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
EmbeddingRequest, ErrorResponse,
TokenizeRequest,
TokenizeResponse)
from vllm.entrypoints.openai.rpc.client import AsyncEngineRPCClient
from vllm.entrypoints.openai.rpc.server import run_rpc_server
# yapf: enable
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
@ -39,12 +44,12 @@ from vllm.entrypoints.openai.serving_tokenization import (
OpenAIServingTokenization)
from vllm.logger import init_logger
from vllm.usage.usage_lib import UsageContext
from vllm.utils import FlexibleArgumentParser
from vllm.utils import FlexibleArgumentParser, get_open_port
from vllm.version import __version__ as VLLM_VERSION
TIMEOUT_KEEP_ALIVE = 5 # seconds
engine: AsyncLLMEngine
async_engine_client: AsyncEngineClient
engine_args: AsyncEngineArgs
openai_serving_chat: OpenAIServingChat
openai_serving_completion: OpenAIServingCompletion
@ -56,13 +61,22 @@ logger = init_logger('vllm.entrypoints.openai.api_server')
_running_tasks: Set[asyncio.Task] = set()
def model_is_embedding(model_name: str) -> bool:
return ModelConfig(model=model_name,
tokenizer=model_name,
tokenizer_mode="auto",
trust_remote_code=False,
seed=0,
dtype="float16").embedding_mode
@asynccontextmanager
async def lifespan(app: fastapi.FastAPI):
async def _force_log():
while True:
await asyncio.sleep(10)
await engine.do_log_stats()
await async_engine_client.do_log_stats()
if not engine_args.disable_log_stats:
task = asyncio.create_task(_force_log())
@ -72,6 +86,52 @@ async def lifespan(app: fastapi.FastAPI):
yield
@asynccontextmanager
async def build_async_engine_client(args) -> AsyncIterator[AsyncEngineClient]:
# Context manager to handle async_engine_client lifecycle
# Ensures everything is shutdown and cleaned up on error/exit
global engine_args
engine_args = AsyncEngineArgs.from_cli_args(args)
# Backend itself still global for the silly lil' health handler
global async_engine_client
# If manually triggered or embedding model, use AsyncLLMEngine in process.
# TODO: support embedding model via RPC.
if (model_is_embedding(args.model)
or args.disable_frontend_multiprocessing):
async_engine_client = AsyncLLMEngine.from_engine_args(
engine_args, usage_context=UsageContext.OPENAI_API_SERVER)
yield async_engine_client
return
# Otherwise, use the multiprocessing AsyncLLMEngine.
else:
# Start RPCServer in separate process (holds the AsyncLLMEngine).
port = get_open_port(envs.VLLM_RPC_PORT)
rpc_server_process = Process(target=run_rpc_server,
args=(engine_args,
UsageContext.OPENAI_API_SERVER,
port))
rpc_server_process.start()
# Build RPCClient, which conforms to AsyncEngineClient Protocol.
async_engine_client = AsyncEngineRPCClient(port)
await async_engine_client.setup()
try:
yield async_engine_client
finally:
# Ensure rpc server process was terminated
rpc_server_process.terminate()
# Close all open connections to the backend
async_engine_client.close()
# Wait for server process to join
rpc_server_process.join()
router = APIRouter()
@ -86,7 +146,7 @@ def mount_metrics(app: fastapi.FastAPI):
@router.get("/health")
async def health() -> Response:
"""Health check."""
await openai_serving_chat.engine.check_health()
await async_engine_client.check_health()
return Response(status_code=200)
@ -215,8 +275,8 @@ def build_app(args):
async def build_server(
async_engine_client: AsyncEngineClient,
args,
llm_engine: Optional[AsyncLLMEngine] = None,
**uvicorn_kwargs,
) -> uvicorn.Server:
app = build_app(args)
@ -226,14 +286,7 @@ async def build_server(
else:
served_model_names = [args.model]
global engine, engine_args
engine_args = AsyncEngineArgs.from_cli_args(args)
engine = (llm_engine
if llm_engine is not None else AsyncLLMEngine.from_engine_args(
engine_args, usage_context=UsageContext.OPENAI_API_SERVER))
model_config = await engine.get_model_config()
model_config = await async_engine_client.get_model_config()
if args.disable_log_requests:
request_logger = None
@ -246,7 +299,7 @@ async def build_server(
global openai_serving_tokenization
openai_serving_chat = OpenAIServingChat(
engine,
async_engine_client,
model_config,
served_model_names,
args.response_role,
@ -257,7 +310,7 @@ async def build_server(
return_tokens_as_token_ids=args.return_tokens_as_token_ids,
)
openai_serving_completion = OpenAIServingCompletion(
engine,
async_engine_client,
model_config,
served_model_names,
lora_modules=args.lora_modules,
@ -266,13 +319,13 @@ async def build_server(
return_tokens_as_token_ids=args.return_tokens_as_token_ids,
)
openai_serving_embedding = OpenAIServingEmbedding(
engine,
async_engine_client,
model_config,
served_model_names,
request_logger=request_logger,
)
openai_serving_tokenization = OpenAIServingTokenization(
engine,
async_engine_client,
model_config,
served_model_names,
lora_modules=args.lora_modules,
@ -304,32 +357,39 @@ async def build_server(
return uvicorn.Server(config)
async def run_server(args, llm_engine=None, **uvicorn_kwargs) -> None:
async def run_server(args, **uvicorn_kwargs) -> None:
logger.info("vLLM API server version %s", VLLM_VERSION)
logger.info("args: %s", args)
server = await build_server(
args,
llm_engine,
**uvicorn_kwargs,
)
shutdown_task = None
async with build_async_engine_client(args) as async_engine_client:
loop = asyncio.get_running_loop()
server = await build_server(
async_engine_client,
args,
**uvicorn_kwargs,
)
server_task = loop.create_task(server.serve())
loop = asyncio.get_running_loop()
def signal_handler() -> None:
# prevents the uvicorn signal handler from exiting early
server_task.cancel()
server_task = loop.create_task(server.serve())
loop.add_signal_handler(signal.SIGINT, signal_handler)
loop.add_signal_handler(signal.SIGTERM, signal_handler)
def signal_handler() -> None:
# prevents the uvicorn signal handler from exiting early
server_task.cancel()
try:
await server_task
except asyncio.CancelledError:
print("Gracefully stopping http server")
await server.shutdown()
loop.add_signal_handler(signal.SIGINT, signal_handler)
loop.add_signal_handler(signal.SIGTERM, signal_handler)
try:
await server_task
except asyncio.CancelledError:
logger.info("Gracefully stopping http server")
shutdown_task = server.shutdown()
if shutdown_task:
# NB: Await server shutdown only after the backend context is exited
await shutdown_task
if __name__ == "__main__":


@ -131,9 +131,14 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
parser.add_argument(
"--return-tokens-as-token-ids",
action="store_true",
help="When --max-logprobs is specified, represents single tokens as"
"strings of the form 'token_id:{token_id}' so that tokens that"
help="When --max-logprobs is specified, represents single tokens as "
"strings of the form 'token_id:{token_id}' so that tokens that "
"are not JSON-encodable can be identified.")
parser.add_argument(
"--disable-frontend-multiprocessing",
action="store_true",
help="If specified, will run the OpenAI frontend server in the same "
"process as the model serving engine.")
parser = AsyncEngineArgs.add_cli_args(parser)
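
Not part of the commit: a hedged sketch of launching the OpenAI-compatible server programmatically with the new flag, using the entrypoints shown above (the model name is only an example; omit the flag to get the default multiprocessing frontend).

import asyncio

from vllm.entrypoints.openai.api_server import run_server
from vllm.entrypoints.openai.cli_args import make_arg_parser
from vllm.utils import FlexibleArgumentParser

# Build the same parser the CLI uses, then opt out of the separate
# zmq-backed engine process so everything runs in a single process.
parser = make_arg_parser(FlexibleArgumentParser())
args = parser.parse_args([
    "--model", "HuggingFaceH4/zephyr-7b-beta",  # example model
    "--disable-frontend-multiprocessing",
])

asyncio.run(run_server(args))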


@ -1,4 +1,4 @@
from functools import lru_cache
from functools import lru_cache, partial
from typing import Dict, FrozenSet, Iterable, List, Optional, Union
import torch
@ -40,6 +40,14 @@ def _get_allowed_token_ids_logits_processor(
return AllowedTokenIdsLogitsProcessor(allowed_token_ids)
def logit_bias_logits_processor(logit_bias: Dict[str,
float], token_ids: List[int],
logits: torch.Tensor) -> torch.Tensor:
for token_id, bias in logit_bias.items():
logits[token_id] += bias
return logits
def get_logits_processors(
logit_bias: Optional[Union[Dict[int, float], Dict[str, float]]],
allowed_token_ids: Optional[List[int]],
@ -64,13 +72,8 @@ def get_logits_processors(
raise ValueError("token_id in logit_bias contains "
"out-of-vocab token id")
def logit_bias_logits_processor(token_ids: List[int],
logits: torch.Tensor) -> torch.Tensor:
for token_id, bias in clamped_logit_bias.items():
logits[token_id] += bias
return logits
logits_processors.append(logit_bias_logits_processor)
logits_processors.append(
partial(logit_bias_logits_processor, clamped_logit_bias))
if allowed_token_ids is not None:
logits_processors.append(
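
Not part of the diff: the nested closure above is replaced by a module-level function plus functools.partial. A plausible motivation (hedged, not stated in the commit) is picklability: a partial over a top-level function survives plain pickling, while a locally defined closure does not, which matters once sampling parameters may cross a process boundary. A small illustration:

import pickle
from functools import partial

import torch


def logit_bias_logits_processor(logit_bias, token_ids, logits):
    # Same shape as the helper above: add a fixed bias to selected token ids.
    for token_id, bias in logit_bias.items():
        logits[token_id] += bias
    return logits


proc = partial(logit_bias_logits_processor, {3: 2.5})
proc = pickle.loads(pickle.dumps(proc))  # round-trips with plain pickle

logits = proc(token_ids=[], logits=torch.zeros(8))
assert logits[3].item() == 2.5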


@ -0,0 +1,42 @@
from dataclasses import dataclass
from enum import Enum
from typing import Mapping, Optional, Union
from vllm.inputs import PromptInputs
from vllm.lora.request import LoRARequest
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams
VLLM_RPC_SUCCESS_STR = "SUCCESS"
VLLM_RPC_HEALTHY_STR = "HEALTHY"
@dataclass
class RPCGenerateRequest:
inputs: PromptInputs
sampling_params: SamplingParams
request_id: str
lora_request: Optional[LoRARequest] = None
trace_headers: Optional[Mapping[str, str]] = None
prompt_adapter_request: Optional[PromptAdapterRequest] = None
@dataclass
class RPCAbortRequest:
request_id: str
class RPCUtilityRequest(Enum):
IS_SERVER_READY = 1
GET_MODEL_CONFIG = 2
GET_DECODING_CONFIG = 3
GET_PARALLEL_CONFIG = 4
GET_SCHEDULER_CONFIG = 5
GET_LORA_CONFIG = 6
DO_LOG_STATS = 7
CHECK_HEALTH = 8
IS_TRACING_ENABLED = 9
RPC_REQUEST_TYPE = Union[RPCGenerateRequest, RPCAbortRequest,
RPCUtilityRequest]
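
Not part of the diff: a hedged sketch of what actually travels over the wire. Each message is one of these objects serialized with cloudpickle, and the server dispatches on its type.

import cloudpickle

from vllm.entrypoints.openai.rpc import (RPCAbortRequest, RPCGenerateRequest,
                                         RPCUtilityRequest)
from vllm.sampling_params import SamplingParams

# A generate request: prompt + sampling params + an id to stream results under.
frame = cloudpickle.dumps(
    RPCGenerateRequest(inputs="Hello, my name is",
                       sampling_params=SamplingParams(max_tokens=5),
                       request_id="req-0"))

# The server side just unpickles the frame and branches on the type.
request = cloudpickle.loads(frame)
assert isinstance(request, RPCGenerateRequest)

# Control messages are bare enum members, pickled the same way.
ready = cloudpickle.loads(cloudpickle.dumps(RPCUtilityRequest.IS_SERVER_READY))
assert ready == RPCUtilityRequest.IS_SERVER_READY
abort = RPCAbortRequest(request_id="req-0")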


@ -0,0 +1,248 @@
from contextlib import contextmanager
from typing import Any, AsyncIterator, Mapping, Optional
import cloudpickle
import zmq
import zmq.asyncio
from vllm.config import (DecodingConfig, LoRAConfig, ModelConfig,
ParallelConfig, SchedulerConfig)
from vllm.entrypoints.openai.rpc import (RPC_REQUEST_TYPE,
VLLM_RPC_HEALTHY_STR,
VLLM_RPC_SUCCESS_STR, RPCAbortRequest,
RPCGenerateRequest, RPCUtilityRequest)
from vllm.inputs import PromptInputs
from vllm.lora.request import LoRARequest
from vllm.outputs import EmbeddingRequestOutput, RequestOutput
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams
from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
class AsyncEngineRPCClient:
def __init__(self, port: int):
self.context = zmq.asyncio.Context()
self.path = f"tcp://localhost:{port}"
async def setup(self):
"""Setup the client before it starts sending server requests."""
# Wait until server is ready.
await self.wait_for_server()
# Get the configs.
self.model_config = await self._get_model_config_rpc()
self.decoding_config = await self._get_decoding_config_rpc()
self.tracing_flag = await self._is_tracing_enabled_rpc()
# Create the tokenizer group.
# TODO: refactor OAI server to avoid needing this info.
self.tokenizer = init_tokenizer_from_configs(
model_config=self.model_config,
scheduler_config=(await self._get_scheduler_config_rpc()),
parallel_config=(await self._get_parallel_config_rpc()),
enable_lora=bool(await self._get_lora_config_rpc()),
)
def close(self):
"""Destroy the ZeroMQ Context."""
self.context.destroy()
@contextmanager
def socket(self):
# Ensure client sockets are always closed after use
# Connect to RPC socket for Request-Reply pattern,
# Note that we use DEALER to enable asynchronous communication
# to enable streaming.
socket = self.context.socket(zmq.constants.DEALER)
try:
socket.connect(self.path)
yield socket
finally:
socket.close()
async def _send_get_data_rpc_request(self, request: RPCUtilityRequest,
expected_type: Any,
error_message: str) -> Any:
"""Send an RPC request that is expecting data back."""
with self.socket() as socket:
# Ping RPCServer with a request.
await socket.send(cloudpickle.dumps(request))
# Await the data from the Server.
data = cloudpickle.loads(await socket.recv())
if not isinstance(data, expected_type):
# LoRAConfig can be None.
if expected_type == LoRAConfig and data is None:
pass
else:
raise ValueError(error_message)
return data
async def _send_one_way_rpc_request(self, request: RPC_REQUEST_TYPE,
error_message: str):
"""Send one-way RPC request to trigger an action."""
with self.socket() as socket:
# Ping RPC Server with request.
await socket.send(cloudpickle.dumps(request))
# Await acknowledgement from RPCServer.
response = cloudpickle.loads(await socket.recv())
if not isinstance(response, str) or response != VLLM_RPC_SUCCESS_STR:
raise ValueError(error_message)
return response
async def get_tokenizer(self, lora_request: LoRARequest):
return await self.tokenizer.get_lora_tokenizer_async(lora_request)
async def get_decoding_config(self) -> DecodingConfig:
return self.decoding_config
async def get_model_config(self) -> ModelConfig:
return self.model_config
async def is_tracing_enabled(self) -> bool:
return self.tracing_flag
async def wait_for_server(self):
"""Wait for the RPCServer to start up."""
await self._send_one_way_rpc_request(
request=RPCUtilityRequest.IS_SERVER_READY,
error_message="Unable to start RPC Server.")
async def _get_model_config_rpc(self) -> ModelConfig:
"""Get the ModelConfig object from the RPC Server"""
return await self._send_get_data_rpc_request(
RPCUtilityRequest.GET_MODEL_CONFIG,
expected_type=ModelConfig,
error_message="Could not get ModelConfig from RPC Server")
async def _get_decoding_config_rpc(self) -> DecodingConfig:
"""Get DecodingConfig from the RPCServer"""
return await self._send_get_data_rpc_request(
RPCUtilityRequest.GET_DECODING_CONFIG,
expected_type=DecodingConfig,
error_message="Could not get DecodingConfig from RPC Server")
async def _get_parallel_config_rpc(self) -> ParallelConfig:
"""Get ParallelConfig from the RPCServer"""
return await self._send_get_data_rpc_request(
RPCUtilityRequest.GET_PARALLEL_CONFIG,
expected_type=ParallelConfig,
error_message="Could not get ParallelConfig from RPC Server")
async def _get_scheduler_config_rpc(self) -> SchedulerConfig:
"""Get SchedulerConfig from the RPCServer"""
return await self._send_get_data_rpc_request(
RPCUtilityRequest.GET_SCHEDULER_CONFIG,
expected_type=SchedulerConfig,
error_message="Could not get SchedulerConfig from RPC Server")
async def _get_lora_config_rpc(self):
"""Get LoRAConfig from the RPCServer"""
return await self._send_get_data_rpc_request(
RPCUtilityRequest.GET_LORA_CONFIG,
expected_type=LoRAConfig,
error_message="Could not get LoRAConfig from RPC Server")
async def _is_tracing_enabled_rpc(self) -> bool:
"""Get is_tracing_enabled flag from the RPCServer"""
return await self._send_get_data_rpc_request(
RPCUtilityRequest.IS_TRACING_ENABLED,
expected_type=bool,
error_message="Could not get is_tracing_enabled flag from RPC "
"Server")
async def abort(self, request_id: str):
"""Send an ABORT_REQUEST signal to the RPC Server"""
await self._send_one_way_rpc_request(
request=RPCAbortRequest(request_id),
error_message=f"RPCAbortRequest {request_id} failed")
async def do_log_stats(self):
"""Send a DO_LOG_STATS signal to the RPC Server"""
await self._send_one_way_rpc_request(
request=RPCUtilityRequest.DO_LOG_STATS,
error_message="RPCRequest DO_LOG_STATS failed.")
async def generate(
self,
inputs: PromptInputs,
sampling_params: SamplingParams,
request_id: str,
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None
) -> AsyncIterator[RequestOutput]:
"""Send an RPCGenerateRequest to the RPCServer and stream responses."""
with self.socket() as socket:
# Send RPCGenerateRequest to the RPCServer.
await socket.send_multipart([
cloudpickle.dumps(
RPCGenerateRequest(
inputs=inputs,
sampling_params=sampling_params,
request_id=request_id,
lora_request=lora_request,
trace_headers=trace_headers,
prompt_adapter_request=prompt_adapter_request))
])
# Stream back the results from the RPC Server.
while True:
message = await socket.recv()
request_output = cloudpickle.loads(message)
if isinstance(request_output, Exception):
raise request_output
if request_output.finished:
break
yield request_output
yield request_output
async def check_health(self) -> None:
"""Raise if unhealthy"""
with self.socket() as socket:
# Ping RPCServer with CHECK_HEALTH request.
await socket.send(cloudpickle.dumps(RPCUtilityRequest.CHECK_HEALTH)
)
# Await the reply from the server.
# TODO: do we need an internal timeout here?
# Or do we expect the external probe to timeout and let this chill?
health_message = cloudpickle.loads(await socket.recv())
if isinstance(health_message, Exception):
raise health_message
if health_message != VLLM_RPC_HEALTHY_STR:
raise ValueError("Expected healthy response from backend but got "
"f{health_message}")
async def encode(self, *args,
**kwargs) -> AsyncIterator[EmbeddingRequestOutput]:
raise NotImplementedError(
"Embeddings not supported with multiprocessing backend")


@ -0,0 +1,216 @@
import asyncio
import signal
from typing import Any, Coroutine
import cloudpickle
import zmq
import zmq.asyncio
from typing_extensions import Never
from vllm import AsyncEngineArgs, AsyncLLMEngine
from vllm.entrypoints.openai.rpc import (VLLM_RPC_HEALTHY_STR,
VLLM_RPC_SUCCESS_STR, RPCAbortRequest,
RPCGenerateRequest, RPCUtilityRequest)
from vllm.logger import init_logger
from vllm.usage.usage_lib import UsageContext
logger = init_logger(__name__)
class AsyncEngineRPCServer:
def __init__(self, async_engine_args: AsyncEngineArgs,
usage_context: UsageContext, port: int):
# Initialize engine first.
self.engine = AsyncLLMEngine.from_engine_args(async_engine_args,
usage_context)
# Initialize context.
self.context = zmq.asyncio.Context()
# Init socket for readiness state.
self.socket = self.context.socket(zmq.constants.ROUTER)
self.socket.bind(f"tcp://localhost:{port}")
def cleanup(self):
"""Cleanup all resources."""
self.socket.close()
self.context.destroy()
async def get_model_config(self, identity):
"""Send the ModelConfig"""
model_config = await self.engine.get_model_config()
await self.socket.send_multipart(
[identity, cloudpickle.dumps(model_config)])
async def get_decoding_config(self, identity):
"""Send the DecodingConfig"""
decoding_config = await self.engine.get_decoding_config()
await self.socket.send_multipart(
[identity, cloudpickle.dumps(decoding_config)])
async def get_lora_config(self, identity):
lora_config = await self.engine.get_lora_config()
await self.socket.send_multipart(
[identity, cloudpickle.dumps(lora_config)])
async def get_scheduler_config(self, identity):
"""Send the SchedulerConfig"""
parallel_config = await self.engine.get_scheduler_config()
await self.socket.send_multipart(
[identity, cloudpickle.dumps(parallel_config)])
async def get_parallel_config(self, identity):
"""Send the ParallelConfig"""
parallel_config = await self.engine.get_parallel_config()
await self.socket.send_multipart(
[identity, cloudpickle.dumps(parallel_config)])
async def is_tracing_enabled(self, identity):
"""Send the is_tracing_enabled flag"""
tracing_flag = await self.engine.is_tracing_enabled()
await self.socket.send_multipart(
[identity, cloudpickle.dumps(tracing_flag)])
async def do_log_stats(self, identity):
"""Log stats and confirm success."""
await self.engine.do_log_stats()
await self.socket.send_multipart([
identity,
cloudpickle.dumps(VLLM_RPC_SUCCESS_STR),
])
async def is_server_ready(self, identity):
"""Notify the client that we are ready."""
await self.socket.send_multipart([
identity,
cloudpickle.dumps(VLLM_RPC_SUCCESS_STR),
])
async def abort(self, identity, request: RPCAbortRequest):
"""Abort request and notify the client of success."""
# Abort the request in the llm engine.
await self.engine.abort(request.request_id)
# Send confirmation to the client.
await self.socket.send_multipart([
identity,
cloudpickle.dumps(VLLM_RPC_SUCCESS_STR),
])
async def generate(self, identity, generate_request: RPCGenerateRequest):
try:
results_generator = self.engine.generate(
generate_request.inputs,
sampling_params=generate_request.sampling_params,
request_id=generate_request.request_id,
lora_request=generate_request.lora_request,
trace_headers=generate_request.trace_headers,
prompt_adapter_request=generate_request.prompt_adapter_request)
async for request_output in results_generator:
await self.socket.send_multipart(
[identity, cloudpickle.dumps(request_output)])
except Exception as e:
### Notify client of all failures
await self.socket.send_multipart([identity, cloudpickle.dumps(e)])
async def check_health(self, identity):
try:
await self.engine.check_health()
await self.socket.send_multipart(
[identity, cloudpickle.dumps(VLLM_RPC_HEALTHY_STR)])
except Exception as e:
await self.socket.send_multipart([identity, cloudpickle.dumps(e)])
def _make_handler_coro(self, identity,
message) -> Coroutine[Any, Any, Never]:
"""Route the zmq message to the handler coroutine."""
request = cloudpickle.loads(message)
if isinstance(request, RPCGenerateRequest):
return self.generate(identity, request)
elif isinstance(request, RPCAbortRequest):
return self.abort(identity, request)
elif isinstance(request, RPCUtilityRequest):
if request == RPCUtilityRequest.GET_MODEL_CONFIG:
return self.get_model_config(identity)
elif request == RPCUtilityRequest.GET_PARALLEL_CONFIG:
return self.get_parallel_config(identity)
elif request == RPCUtilityRequest.GET_DECODING_CONFIG:
return self.get_decoding_config(identity)
elif request == RPCUtilityRequest.GET_SCHEDULER_CONFIG:
return self.get_scheduler_config(identity)
elif request == RPCUtilityRequest.GET_LORA_CONFIG:
return self.get_lora_config(identity)
elif request == RPCUtilityRequest.DO_LOG_STATS:
return self.do_log_stats(identity)
elif request == RPCUtilityRequest.IS_SERVER_READY:
return self.is_server_ready(identity)
elif request == RPCUtilityRequest.CHECK_HEALTH:
return self.check_health(identity)
elif request == RPCUtilityRequest.IS_TRACING_ENABLED:
return self.is_tracing_enabled(identity)
else:
raise ValueError(f"Unknown RPCUtilityRequest type: {request}")
else:
raise ValueError(f"Unknown RPCRequest type: {request}")
async def run_server_loop(self):
"""Inner RPC Server Loop"""
running_tasks = set()
while True:
# Wait for a request.
identity, message = await self.socket.recv_multipart()
# Process the request async.
task = asyncio.create_task(
self._make_handler_coro(identity, message))
# We need to keep around a strong reference to the task,
# to avoid the task disappearing mid-execution, since running tasks
# can be GC'ed. This is the common "fire-and-forget" pattern:
# https://docs.python.org/3/library/asyncio-task.html#asyncio.create_task
running_tasks.add(task)
task.add_done_callback(running_tasks.discard)
async def run_server(server: AsyncEngineRPCServer):
# Put the server task into the asyncio loop.
loop = asyncio.get_running_loop()
server_task = loop.create_task(server.run_server_loop())
# Interruption handling.
def signal_handler() -> None:
# Kill the server on interrupt / terminate
server_task.cancel()
loop.add_signal_handler(signal.SIGINT, signal_handler)
loop.add_signal_handler(signal.SIGTERM, signal_handler)
try:
await server_task
except asyncio.CancelledError:
logger.info("vLLM ZMQ RPC Server was interrupted.")
finally:
# Clean up all resources.
server.cleanup()
def run_rpc_server(async_engine_args: AsyncEngineArgs,
usage_context: UsageContext, port: int):
server = AsyncEngineRPCServer(async_engine_args, usage_context, port)
asyncio.run(run_server(server))
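
Not part of the diff: a hedged end-to-end sketch that mirrors build_async_engine_client in api_server.py above — start the RPC server in a child process, then drive it through AsyncEngineRPCClient. The model name and port are examples only.

import asyncio
from multiprocessing import Process

from vllm import AsyncEngineArgs, SamplingParams
from vllm.entrypoints.openai.rpc.client import AsyncEngineRPCClient
from vllm.entrypoints.openai.rpc.server import run_rpc_server
from vllm.usage.usage_lib import UsageContext


async def main() -> None:
    port = 5570  # default VLLM_RPC_PORT; any free port works
    engine_args = AsyncEngineArgs(model="facebook/opt-125m")  # example model

    # The engine lives in its own process and serves requests over zmq.
    rpc_server = Process(target=run_rpc_server,
                         args=(engine_args, UsageContext.OPENAI_API_SERVER,
                               port))
    rpc_server.start()

    client = AsyncEngineRPCClient(port)
    await client.setup()  # waits until IS_SERVER_READY is acknowledged

    text = ""
    async for output in client.generate("Hello, my name is",
                                        SamplingParams(max_tokens=5),
                                        request_id="req-0"):
        text = output.outputs[0].text
    print(text)

    client.close()
    rpc_server.terminate()
    rpc_server.join()


asyncio.run(main())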


@ -8,7 +8,7 @@ from fastapi import Request
from transformers import PreTrainedTokenizer
from vllm.config import ModelConfig
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.engine.protocol import AsyncEngineClient
from vllm.entrypoints.chat_utils import (ConversationMessage,
load_chat_template,
parse_chat_message_content)
@ -39,7 +39,7 @@ class OpenAIServingChat(OpenAIServing):
def __init__(
self,
engine: AsyncLLMEngine,
async_engine_client: AsyncEngineClient,
model_config: ModelConfig,
served_model_names: List[str],
response_role: str,
@ -50,7 +50,7 @@ class OpenAIServingChat(OpenAIServing):
chat_template: Optional[str],
return_tokens_as_token_ids: bool = False,
):
super().__init__(engine=engine,
super().__init__(async_engine_client=async_engine_client,
model_config=model_config,
served_model_names=served_model_names,
lora_modules=lora_modules,
@ -89,7 +89,8 @@ class OpenAIServingChat(OpenAIServing):
) = self._maybe_get_adapters(request)
model_config = self.model_config
tokenizer = await self.engine.get_tokenizer(lora_request)
tokenizer = await self.async_engine_client.get_tokenizer(
lora_request)
conversation: List[ConversationMessage] = []
mm_futures: List[Awaitable[MultiModalDataDict]] = []
@ -161,7 +162,8 @@ class OpenAIServingChat(OpenAIServing):
if mm_data is not None:
engine_inputs["multi_modal_data"] = mm_data
is_tracing_enabled = await self.engine.is_tracing_enabled()
is_tracing_enabled = (
await self.async_engine_client.is_tracing_enabled())
trace_headers = None
if is_tracing_enabled and raw_request:
trace_headers = extract_trace_headers(raw_request.headers)
@ -169,7 +171,7 @@ class OpenAIServingChat(OpenAIServing):
and contains_trace_headers(raw_request.headers)):
log_tracing_disabled_warning()
result_generator = self.engine.generate(
result_generator = self.async_engine_client.generate(
engine_inputs,
sampling_params,
request_id,
@ -441,7 +443,7 @@ class OpenAIServingChat(OpenAIServing):
async for res in result_generator:
if raw_request is not None and await raw_request.is_disconnected():
# Abort the request if the client disconnects.
await self.engine.abort(request_id)
await self.async_engine_client.abort(request_id)
return self.create_error_response("Client disconnected")
final_res = res
assert final_res is not None


@ -8,7 +8,7 @@ from fastapi import Request
from transformers import PreTrainedTokenizer
from vllm.config import ModelConfig
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.engine.protocol import AsyncEngineClient
from vllm.entrypoints.logger import RequestLogger
# yapf conflicts with isort for this block
# yapf: disable
@ -42,7 +42,7 @@ class OpenAIServingCompletion(OpenAIServing):
def __init__(
self,
engine: AsyncLLMEngine,
async_engine_client: AsyncEngineClient,
model_config: ModelConfig,
served_model_names: List[str],
*,
@ -51,7 +51,7 @@ class OpenAIServingCompletion(OpenAIServing):
request_logger: Optional[RequestLogger],
return_tokens_as_token_ids: bool = False,
):
super().__init__(engine=engine,
super().__init__(async_engine_client=async_engine_client,
model_config=model_config,
served_model_names=served_model_names,
lora_modules=lora_modules,
@ -91,7 +91,8 @@ class OpenAIServingCompletion(OpenAIServing):
prompt_adapter_request,
) = self._maybe_get_adapters(request)
tokenizer = await self.engine.get_tokenizer(lora_request)
tokenizer = await self.async_engine_client.get_tokenizer(
lora_request)
guided_decode_logits_processor = (
await self._guided_decode_logits_processor(request, tokenizer))
@ -119,7 +120,8 @@ class OpenAIServingCompletion(OpenAIServing):
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request)
is_tracing_enabled = await self.engine.is_tracing_enabled()
is_tracing_enabled = (
await self.async_engine_client.is_tracing_enabled())
trace_headers = None
if is_tracing_enabled:
trace_headers = extract_trace_headers(raw_request.headers)
@ -127,7 +129,7 @@ class OpenAIServingCompletion(OpenAIServing):
raw_request.headers):
log_tracing_disabled_warning()
generator = self.engine.generate(
generator = self.async_engine_client.generate(
{"prompt_token_ids": prompt_inputs["prompt_token_ids"]},
sampling_params,
request_id_item,
@ -168,7 +170,7 @@ class OpenAIServingCompletion(OpenAIServing):
async for i, res in result_generator:
if await raw_request.is_disconnected():
# Abort the request if the client disconnects.
await self.engine.abort(f"{request_id}-{i}")
await self.async_engine_client.abort(f"{request_id}-{i}")
return self.create_error_response("Client disconnected")
final_res_batch[i] = res
@ -230,7 +232,8 @@ class OpenAIServingCompletion(OpenAIServing):
# Abort the request if the client disconnects.
if await raw_request.is_disconnected():
await self.engine.abort(f"{request_id}-{prompt_idx}")
await self.async_engine_client.abort(
f"{request_id}-{prompt_idx}")
raise StopAsyncIteration()
for output in res.outputs:


@ -6,7 +6,7 @@ import numpy as np
from fastapi import Request
from vllm.config import ModelConfig
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.engine.protocol import AsyncEngineClient
from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.protocol import (EmbeddingRequest,
EmbeddingResponse,
@ -56,13 +56,13 @@ class OpenAIServingEmbedding(OpenAIServing):
def __init__(
self,
engine: AsyncLLMEngine,
async_engine_client: AsyncEngineClient,
model_config: ModelConfig,
served_model_names: List[str],
*,
request_logger: Optional[RequestLogger],
):
super().__init__(engine=engine,
super().__init__(async_engine_client=async_engine_client,
model_config=model_config,
served_model_names=served_model_names,
lora_modules=None,
@ -99,7 +99,8 @@ class OpenAIServingEmbedding(OpenAIServing):
prompt_adapter_request,
) = self._maybe_get_adapters(request)
tokenizer = await self.engine.get_tokenizer(lora_request)
tokenizer = await self.async_engine_client.get_tokenizer(
lora_request)
pooling_params = request.to_pooling_params()
@ -124,7 +125,7 @@ class OpenAIServingEmbedding(OpenAIServing):
"Prompt adapter is not supported "
"for embedding models")
generator = self.engine.encode(
generator = self.async_engine_client.encode(
{"prompt_token_ids": prompt_inputs["prompt_token_ids"]},
pooling_params,
request_id_item,
@ -146,7 +147,7 @@ class OpenAIServingEmbedding(OpenAIServing):
async for i, res in result_generator:
if await raw_request.is_disconnected():
# Abort the request if the client disconnects.
await self.engine.abort(f"{request_id}-{i}")
await self.async_engine_client.abort(f"{request_id}-{i}")
return self.create_error_response("Client disconnected")
final_res_batch[i] = res


@ -8,7 +8,7 @@ from pydantic import Field
from typing_extensions import Annotated
from vllm.config import ModelConfig
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.engine.protocol import AsyncEngineClient
from vllm.entrypoints.logger import RequestLogger
# yapf conflicts with isort for this block
# yapf: disable
@@ -61,7 +61,7 @@ class OpenAIServing:
def __init__(
self,
engine: AsyncLLMEngine,
async_engine_client: AsyncEngineClient,
model_config: ModelConfig,
served_model_names: List[str],
*,
@@ -72,7 +72,7 @@ class OpenAIServing:
):
super().__init__()
self.engine = engine
self.async_engine_client = async_engine_client
self.model_config = model_config
self.max_model_len = model_config.max_model_len
@@ -155,7 +155,7 @@ class OpenAIServing:
async def _guided_decode_logits_processor(
self, request: Union[ChatCompletionRequest, CompletionRequest],
tokenizer: AnyTokenizer) -> Optional[LogitsProcessor]:
decoding_config = await self.engine.get_decoding_config()
decoding_config = await self.async_engine_client.get_decoding_config()
guided_decoding_backend = request.guided_decoding_backend \
or decoding_config.guided_decoding_backend
return await get_guided_decoding_logits_processor(
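Note: the serving classes in this commit now program against the AsyncEngineClient protocol rather than the concrete AsyncLLMEngine, so the frontend only needs the coroutines exercised above (get_tokenizer, get_decoding_config, is_tracing_enabled, generate/encode, abort). The following is a minimal sketch of a hypothetical stand-in client, not part of this diff, showing how small that surface is; both the in-process engine and the new ZMQ-backed RPC client can sit behind it.

# Sketch only: a hypothetical stand-in satisfying the subset of the
# AsyncEngineClient surface used by the serving classes above.
class EngineClientStub:

    def __init__(self, engine):
        # engine: any object with the same coroutines (e.g. AsyncLLMEngine)
        self._engine = engine

    async def get_tokenizer(self, lora_request=None):
        return await self._engine.get_tokenizer(lora_request)

    async def get_decoding_config(self):
        return await self._engine.get_decoding_config()

    async def is_tracing_enabled(self) -> bool:
        return await self._engine.is_tracing_enabled()

    def generate(self, inputs, sampling_params, request_id, **kwargs):
        # An RPC-backed client would forward this call over ZMQ instead of
        # delegating in-process; the async-generator interface stays the same.
        return self._engine.generate(inputs, sampling_params, request_id,
                                     **kwargs)

    async def abort(self, request_id: str) -> None:
        await self._engine.abort(request_id)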

View File

@@ -1,9 +1,9 @@
from typing import List, Optional, Union
from vllm.config import ModelConfig
from vllm.engine.async_llm_engine import AsyncLLMEngine
# yapf conflicts with isort for this block
# yapf: disable
from vllm.engine.protocol import AsyncEngineClient
from vllm.entrypoints.chat_utils import (ConversationMessage,
load_chat_template,
parse_chat_message_content)
@@ -24,7 +24,7 @@ class OpenAIServingTokenization(OpenAIServing):
def __init__(
self,
engine: AsyncLLMEngine,
async_engine_client: AsyncEngineClient,
model_config: ModelConfig,
served_model_names: List[str],
*,
@@ -32,7 +32,7 @@ class OpenAIServingTokenization(OpenAIServing):
request_logger: Optional[RequestLogger],
chat_template: Optional[str],
):
super().__init__(engine=engine,
super().__init__(async_engine_client=async_engine_client,
model_config=model_config,
served_model_names=served_model_names,
lora_modules=lora_modules,
@@ -57,7 +57,7 @@ class OpenAIServingTokenization(OpenAIServing):
prompt_adapter_request,
) = self._maybe_get_adapters(request)
tokenizer = await self.engine.get_tokenizer(lora_request)
tokenizer = await self.async_engine_client.get_tokenizer(lora_request)
if isinstance(request, TokenizeChatRequest):
model_config = self.model_config
@@ -113,7 +113,7 @@ class OpenAIServingTokenization(OpenAIServing):
prompt_adapter_request,
) = self._maybe_get_adapters(request)
tokenizer = await self.engine.get_tokenizer(lora_request)
tokenizer = await self.async_engine_client.get_tokenizer(lora_request)
self._log_inputs(request_id,
request.tokens,

View File

@@ -4,6 +4,7 @@ from typing import TYPE_CHECKING, Any, Callable, Dict, Optional
if TYPE_CHECKING:
VLLM_HOST_IP: str = ""
VLLM_PORT: Optional[int] = None
VLLM_RPC_PORT: int = 5570
VLLM_USE_MODELSCOPE: bool = False
VLLM_RINGBUFFER_WARNING_INTERVAL: int = 60
VLLM_INSTANCE_ID: Optional[str] = None
@@ -140,6 +141,11 @@ environment_variables: Dict[str, Callable[[], Any]] = {
lambda: int(os.getenv('VLLM_PORT', '0'))
if 'VLLM_PORT' in os.environ else None,
# used when the frontend api server is running in multi-processing mode,
# to communicate with the backend engine process over ZMQ.
'VLLM_RPC_PORT':
lambda: int(os.getenv('VLLM_RPC_PORT', '5570')),
# If true, will load models from ModelScope instead of Hugging Face Hub.
# note that the value is true or false, not numbers
"VLLM_USE_MODELSCOPE":

View File

@@ -21,6 +21,8 @@ from functools import lru_cache
from typing import Callable, DefaultDict, Dict, List, Union
import torch
from lark import Lark
from outlines import grammars
from outlines.caching import cache
from outlines.fsm.guide import CFGGuide, Generate, Guide, RegexGuide, Write
from outlines.fsm.json_schema import build_regex_from_schema
@@ -44,6 +46,23 @@ class BaseLogitsProcessor:
last_seq_id = hash(tuple(input_ids[:-1]))
self._fsm_state[seq_id] = self._guide.get_next_state(
state=self._fsm_state[last_seq_id], token_id=last_token)
else:
# Note: this is a hack.
# Lark pickling does not work properly (silent failure),
# which breaks the RPC (which uses Python pickling).
# We need to find a better solution.
# The first time this is called, we simply re-create
# the Lark object.
if isinstance(self._guide, CFGGuide):
self._guide.parser = Lark(
self._guide.cfg_string,
parser="lalr",
lexer="contextual",
propagate_positions=False,
maybe_placeholders=False,
regex=True,
import_paths=[grammars.GRAMMAR_PATH],
)
instruction = self._guide.get_next_instruction(
state=self._fsm_state[seq_id])
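Note: the workaround above exists because the guided-decoding guide now crosses the new RPC boundary via pickle, and the Lark parser inside a CFGGuide does not survive that round trip, so it is rebuilt on first use. The self-contained sketch below illustrates the same lazily-rebuild-after-unpickling pattern with a stand-in attribute (a compiled regex, chosen only so the example runs; it is not the vLLM code).

# Sketch of the pattern above: drop a member that does not pickle cleanly
# and re-create it lazily on the receiving side of the process boundary.
import pickle
import re


class PatternHolder:

    def __init__(self, pattern_text: str):
        self.pattern_text = pattern_text
        self._compiled = re.compile(pattern_text)

    def __getstate__(self):
        # Ship only the cheap, picklable description of the object.
        state = self.__dict__.copy()
        state["_compiled"] = None
        return state

    @property
    def compiled(self) -> re.Pattern:
        if self._compiled is None:
            # First use after unpickling: rebuild from the retained description.
            self._compiled = re.compile(self.pattern_text)
        return self._compiled


holder = pickle.loads(pickle.dumps(PatternHolder(r"\d+")))
assert holder.compiled.match("123") is not None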

View File

@@ -60,7 +60,7 @@ def get_span_exporter(endpoint):
OTLPSpanExporter)
elif protocol == "http/protobuf":
from opentelemetry.exporter.otlp.proto.http.trace_exporter import (
OTLPSpanExporter)
OTLPSpanExporter) # type: ignore
else:
raise ValueError(
f"Unsupported OTLP protocol '{protocol}' is configured")

View File

@@ -1,6 +1,7 @@
from typing import Optional, Type
from vllm.config import TokenizerPoolConfig
from vllm.config import (ModelConfig, ParallelConfig, SchedulerConfig,
TokenizerPoolConfig)
from vllm.executor.ray_utils import ray
from .base_tokenizer_group import AnyTokenizer, BaseTokenizerGroup
@@ -13,6 +14,22 @@ else:
RayTokenizerGroupPool = None # type: ignore
def init_tokenizer_from_configs(model_config: ModelConfig,
scheduler_config: SchedulerConfig,
parallel_config: ParallelConfig,
enable_lora: bool):
init_kwargs = dict(tokenizer_id=model_config.tokenizer,
enable_lora=enable_lora,
max_num_seqs=scheduler_config.max_num_seqs,
max_input_length=None,
tokenizer_mode=model_config.tokenizer_mode,
trust_remote_code=model_config.trust_remote_code,
revision=model_config.tokenizer_revision)
return get_tokenizer_group(parallel_config.tokenizer_pool_config,
**init_kwargs)
def get_tokenizer_group(tokenizer_pool_config: Optional[TokenizerPoolConfig],
**init_kwargs) -> BaseTokenizerGroup:
tokenizer_cls: Type[BaseTokenizerGroup]
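Note: init_tokenizer_from_configs lets a caller build a tokenizer group straight from the engine configs without instantiating an engine. A usage sketch follows; it assumes the helper is importable from vllm.transformers_utils.tokenizer_group and that the configs come from EngineArgs.create_engine_config(), which exposes model_config, scheduler_config, parallel_config and lora_config. The model name is illustrative.

# Usage sketch; config plumbing and model name are illustrative.
from vllm.engine.arg_utils import EngineArgs
from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs

engine_config = EngineArgs(model="facebook/opt-125m").create_engine_config()
tokenizer_group = init_tokenizer_from_configs(
    model_config=engine_config.model_config,
    scheduler_config=engine_config.scheduler_config,
    parallel_config=engine_config.parallel_config,
    enable_lora=bool(engine_config.lora_config),
)
# With no LoRA request, this returns the base tokenizer.
tokenizer = tokenizer_group.get_lora_tokenizer(None)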

View File

@@ -290,6 +290,10 @@ def make_async(func: Callable[P, T]) -> Callable[P, Awaitable[T]]:
return _async_wrapper
class ProducerFinished:
pass
def merge_async_iterators(
*iterators: AsyncIterator[T]) -> AsyncIterator[Tuple[int, T]]:
"""Merge multiple asynchronous iterators into a single iterator.
@@ -298,9 +302,10 @@ def merge_async_iterators(
When it yields, it yields a tuple (i, item) where i is the index of the
iterator that yields the item.
"""
queue: asyncio.Queue[Union[Tuple[int, T], Exception]] = asyncio.Queue()
queue: asyncio.Queue[Union[Tuple[int, T], ProducerFinished,
Exception]] = asyncio.Queue()
finished = [False] * len(iterators)
producers = len(iterators)
async def producer(i: int, iterator: AsyncIterator[T]):
try:
@@ -308,7 +313,8 @@ def merge_async_iterators(
await queue.put((i, item))
except Exception as e:
await queue.put(e)
finished[i] = True
# Signal to the consumer that we've finished
await queue.put(ProducerFinished())
_tasks = [
asyncio.create_task(producer(i, iterator))
@@ -316,9 +322,17 @@ def merge_async_iterators(
]
async def consumer():
remaining = producers
try:
while not all(finished) or not queue.empty():
while remaining or not queue.empty():
# we think there is a race condition here
item = await queue.get()
if isinstance(item, ProducerFinished):
# Signal that a producer finished - not a real item
remaining -= 1
continue
if isinstance(item, Exception):
raise item
yield item
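Note: with the ProducerFinished sentinel, the consumer tracks how many producers are still running instead of combining a list of finished flags with queue.empty(), so it only stops once every producer has enqueued its sentinel and the queue has drained. A standalone usage sketch of the merged iterator follows (assuming merge_async_iterators is importable from vllm.utils, where this hunk lives).

# Standalone usage sketch for merge_async_iterators.
import asyncio

from vllm.utils import merge_async_iterators


async def numbers(prefix: str, count: int):
    for k in range(count):
        await asyncio.sleep(0)  # yield control so the streams interleave
        yield f"{prefix}{k}"


async def main():
    merged = merge_async_iterators(numbers("a", 2), numbers("b", 3))
    async for source_index, item in merged:
        # Each item is tagged with the index of the iterator that produced it.
        print(source_index, item)


asyncio.run(main())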
@@ -374,8 +388,10 @@ def get_distributed_init_method(ip: str, port: int) -> str:
return f"tcp://[{ip}]:{port}" if ":" in ip else f"tcp://{ip}:{port}"
def get_open_port() -> int:
port = envs.VLLM_PORT
def get_open_port(port: Optional[int] = None) -> int:
if port is None:
# Default behavior here is to return a port for multi-gpu communication
port = envs.VLLM_PORT
if port is not None:
while True:
try: