[ Frontend ] Multiprocessing for OpenAI Server with zeromq (#6883)

Signed-off-by: Joe Runde <Joseph.Runde@ibm.com>
Co-authored-by: Joe Runde <Joseph.Runde@ibm.com>
Co-authored-by: Joe Runde <joe@joerun.de>
Co-authored-by: Nick Hill <nickhill@us.ibm.com>
Co-authored-by: Simon Mo <simon.mo@hey.com>
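The commit moves the OpenAI-compatible HTTP frontend and the AsyncLLMEngine into separate processes connected by a zeromq RPC socket; the HTTP API itself is unchanged. For orientation, this is the kind of client call the new test file below exercises against the server (a hedged sketch: the port, model name, and a running server are assumptions, not part of this commit):

# Minimal client-side view; unchanged by this commit. Assumes a vLLM
# OpenAI-compatible server is already running on localhost:8000.
import asyncio

import openai

client = openai.AsyncOpenAI(base_url="http://localhost:8000/v1",
                            api_key="EMPTY")


async def demo() -> None:
    completion = await client.completions.create(
        model="HuggingFaceH4/zephyr-7b-beta",  # placeholder model name
        prompt="Hello, my name is",
        max_tokens=5,
        temperature=0.0)
    print(completion.choices[0].text)


if __name__ == "__main__":
    asyncio.run(demo())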
tests/entrypoints/openai/test_disable_mp.py (new file, 715 lines)
@@ -0,0 +1,715 @@
"""
Repeat of tests in test_completion.py with the non-mp backend.
"""

# imports for guided decoding tests
import json
import re
import shutil
from tempfile import TemporaryDirectory
from typing import List

import jsonschema
import openai  # use the official client for correctness check
import pytest
# downloading lora to test lora requests
from huggingface_hub import snapshot_download
from openai import BadRequestError
from transformers import AutoTokenizer

from vllm.transformers_utils.tokenizer import get_tokenizer

from ...utils import RemoteOpenAIServer

# any model with a chat template should work here
MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta"
# technically these adapters use a different base model,
# but we're not testing generation quality here
LORA_NAME = "typeof/zephyr-7b-beta-lora"
PA_NAME = "swapnilbp/llama_tweet_ptune"
# if PA_NAME changes, PA_NUM_VIRTUAL_TOKENS might also
# need to change to match the prompt adapter
PA_NUM_VIRTUAL_TOKENS = 8


@pytest.fixture(scope="module")
def zephyr_lora_files():
    return snapshot_download(repo_id=LORA_NAME)


@pytest.fixture(scope="module")
def zephyr_lora_added_tokens_files(zephyr_lora_files):
    tmp_dir = TemporaryDirectory()
    tmp_model_dir = f"{tmp_dir.name}/zephyr"
    shutil.copytree(zephyr_lora_files, tmp_model_dir)
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    # Copy tokenizer to adapter and add some unique tokens
    # 32000, 32001, 32002
    added = tokenizer.add_tokens(["vllm1", "vllm2", "vllm3"],
                                 special_tokens=True)
    assert added == 3
    tokenizer.save_pretrained(tmp_model_dir)
    yield tmp_model_dir
    tmp_dir.cleanup()


@pytest.fixture(scope="module")
def zephyr_pa_files():
    return snapshot_download(repo_id=PA_NAME)


@pytest.fixture(scope="module")
def default_server_args(zephyr_lora_files, zephyr_lora_added_tokens_files,
                        zephyr_pa_files):
    return [
        # use half precision for speed and memory savings in CI environment
        "--dtype",
        "bfloat16",
        "--max-model-len",
        "8192",
        "--max-num-seqs",
        "128",
        "--enforce-eager",
        # lora config
        "--enable-lora",
        "--lora-modules",
        f"zephyr-lora={zephyr_lora_files}",
        f"zephyr-lora2={zephyr_lora_added_tokens_files}",
        "--max-lora-rank",
        "64",
        "--max-cpu-loras",
        "2",
        # pa config
        "--enable-prompt-adapter",
        "--prompt-adapters",
        f"zephyr-pa={zephyr_pa_files}",
        f"zephyr-pa2={zephyr_pa_files}",
        "--max-prompt-adapters",
        "2",
        "--max-prompt-adapter-token",
        "128",
        "--disable-frontend-multiprocessing"
    ]


@pytest.fixture(scope="module")
def server(default_server_args):
    with RemoteOpenAIServer(MODEL_NAME, default_server_args) as remote_server:
        yield remote_server


@pytest.fixture(scope="module")
def client(server):
    return server.get_async_client()
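The suite is a repeat of test_completion.py; the only functionally interesting argument above is `--disable-frontend-multiprocessing`, which forces the old single-process path. As a rough, hypothetical sketch (not part of this commit) of how both frontends could be exercised from one module, reusing the `RemoteOpenAIServer` helper and `MODEL_NAME` defined above:

# Hypothetical sketch only: parametrize the server fixture over the new flag
# so the same tests run against both the in-process and the zeromq frontend.
import pytest

from ...utils import RemoteOpenAIServer  # same helper imported above


@pytest.fixture(scope="module", params=[True, False])
def server_both_frontends(request, default_server_args):
    args = list(default_server_args)
    # Drop the flag for the run that should exercise the multiprocessing path.
    if not request.param and "--disable-frontend-multiprocessing" in args:
        args.remove("--disable-frontend-multiprocessing")
    with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
        yield remote_server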
@pytest.mark.asyncio
@pytest.mark.parametrize(
    # first test base model, then test loras, then test prompt adapters
    "model_name,num_virtual_tokens",
    [(MODEL_NAME, 0), ("zephyr-lora", 0), ("zephyr-lora2", 0),
     ("zephyr-pa", PA_NUM_VIRTUAL_TOKENS),
     ("zephyr-pa2", PA_NUM_VIRTUAL_TOKENS)],
)
async def test_single_completion(client: openai.AsyncOpenAI, model_name: str,
                                 num_virtual_tokens: int):
    completion = await client.completions.create(model=model_name,
                                                 prompt="Hello, my name is",
                                                 max_tokens=5,
                                                 temperature=0.0)

    assert completion.id is not None
    assert completion.choices is not None and len(completion.choices) == 1

    choice = completion.choices[0]
    assert len(choice.text) >= 5
    assert choice.finish_reason == "length"
    assert completion.usage == openai.types.CompletionUsage(
        completion_tokens=5,
        prompt_tokens=6 + num_virtual_tokens,
        total_tokens=11 + num_virtual_tokens)

    # test using token IDs
    completion = await client.completions.create(
        model=model_name,
        prompt=[0, 0, 0, 0, 0],
        max_tokens=5,
        temperature=0.0,
    )
    assert len(completion.choices[0].text) >= 1


@pytest.mark.asyncio
async def test_added_lora_tokens(client: openai.AsyncOpenAI):
    # test using token IDs
    completion = await client.completions.create(
        model="zephyr-lora2",
        prompt=[0, 0, 32000, 32001, 32002],
        echo=True,
        max_tokens=5,
        temperature=0.0,
    )
    # Added tokens should appear in tokenized prompt
    assert completion.choices[0].text.startswith("<unk><unk>vllm1vllm2vllm3")


@pytest.mark.asyncio
async def test_added_lora_tokens_base_model(client: openai.AsyncOpenAI):
    # test using token IDs
    completion = await client.completions.create(
        model=MODEL_NAME,
        prompt=[0, 0, 32000, 32001, 32002],
        echo=True,
        max_tokens=5,
        temperature=0.0,
    )
    # Added tokens should not appear in tokenized prompt
    assert "vllm" not in completion.choices[0].text


@pytest.mark.asyncio
@pytest.mark.parametrize(
    # first test base model, then test loras, then test prompt adapters
    "model_name",
    [MODEL_NAME, "zephyr-lora", "zephyr-lora2", "zephyr-pa", "zephyr-pa2"],
)
async def test_no_logprobs(client: openai.AsyncOpenAI, model_name: str):
    # test using token IDs
    completion = await client.completions.create(
        model=model_name,
        prompt=[0, 0, 0, 0, 0],
        max_tokens=5,
        temperature=0.0,
        logprobs=None,
    )
    choice = completion.choices[0]
    assert choice.logprobs is None


@pytest.mark.asyncio
@pytest.mark.parametrize(
    # just test 1 lora and 1 pa hereafter
    "model_name",
    [MODEL_NAME, "zephyr-lora", "zephyr-pa"],
)
async def test_zero_logprobs(client: openai.AsyncOpenAI, model_name: str):
    # test using token IDs
    completion = await client.completions.create(
        model=model_name,
        prompt=[0, 0, 0, 0, 0],
        max_tokens=5,
        temperature=0.0,
        logprobs=0,
    )
    choice = completion.choices[0]
    assert choice.logprobs is not None
    assert choice.logprobs.token_logprobs is not None
    assert choice.logprobs.top_logprobs is not None
    assert len(choice.logprobs.top_logprobs[0]) == 1


@pytest.mark.asyncio
@pytest.mark.parametrize(
    "model_name",
    [MODEL_NAME, "zephyr-lora", "zephyr-pa"],
)
async def test_some_logprobs(client: openai.AsyncOpenAI, model_name: str):
    # test using token IDs
    completion = await client.completions.create(
        model=model_name,
        prompt=[0, 0, 0, 0, 0],
        max_tokens=5,
        temperature=0.0,
        logprobs=5,
    )
    choice = completion.choices[0]
    assert choice.logprobs is not None
    assert choice.logprobs.token_logprobs is not None
    assert choice.logprobs.top_logprobs is not None
    assert 5 <= len(choice.logprobs.top_logprobs[0]) <= 6


@pytest.mark.asyncio
@pytest.mark.parametrize(
    "model_name",
    [MODEL_NAME, "zephyr-lora", "zephyr-pa"],
)
async def test_too_many_completion_logprobs(client: openai.AsyncOpenAI,
                                            model_name: str):

    with pytest.raises(
        (openai.BadRequestError, openai.APIError)):  # test using token IDs
        await client.completions.create(
            model=model_name,
            prompt=[0, 0, 0, 0, 0],
            max_tokens=5,
            temperature=0.0,
            # vLLM has higher default max_logprobs (20 instead of 5) to support
            # both Completion API and Chat Completion API
            logprobs=21,
        )
        ...
    with pytest.raises(
        (openai.BadRequestError, openai.APIError)):  # test using token IDs
        stream = await client.completions.create(
            model=model_name,
            prompt=[0, 0, 0, 0, 0],
            max_tokens=5,
            temperature=0.0,
            # vLLM has higher default max_logprobs (20 instead of 5) to support
            # both Completion API and Chat Completion API
            logprobs=30,
            stream=True,
        )
        async for chunk in stream:
            ...

    # the server should still work afterwards
    completion = await client.completions.create(
        model=model_name,
        prompt=[0, 0, 0, 0, 0],
        max_tokens=5,
        temperature=0.0,
    )
    assert len(completion.choices[0].text) >= 0
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "model_name",
    [MODEL_NAME, "zephyr-lora", "zephyr-pa"],
)
async def test_completion_streaming(client: openai.AsyncOpenAI,
                                    model_name: str):
    prompt = "What is an LLM?"

    single_completion = await client.completions.create(
        model=model_name,
        prompt=prompt,
        max_tokens=5,
        temperature=0.0,
    )
    single_output = single_completion.choices[0].text
    stream = await client.completions.create(model=model_name,
                                             prompt=prompt,
                                             max_tokens=5,
                                             temperature=0.0,
                                             stream=True)
    chunks: List[str] = []
    finish_reason_count = 0
    async for chunk in stream:
        chunks.append(chunk.choices[0].text)
        if chunk.choices[0].finish_reason is not None:
            finish_reason_count += 1
    # finish reason should only return in last block
    assert finish_reason_count == 1
    assert chunk.choices[0].finish_reason == "length"
    assert chunk.choices[0].text
    assert "".join(chunks) == single_output


@pytest.mark.asyncio
@pytest.mark.parametrize(
    "model_name",
    [MODEL_NAME, "zephyr-lora", "zephyr-pa"],
)
async def test_completion_stream_options(client: openai.AsyncOpenAI,
                                         model_name: str):
    prompt = "What is the capital of France?"

    # Test stream=True, stream_options=
    #     {"include_usage": False, "continuous_usage_stats": False}
    stream = await client.completions.create(model=model_name,
                                             prompt=prompt,
                                             max_tokens=5,
                                             temperature=0.0,
                                             stream=True,
                                             stream_options={
                                                 "include_usage": False,
                                                 "continuous_usage_stats":
                                                 False,
                                             })

    async for chunk in stream:
        assert chunk.usage is None

    # Test stream=True, stream_options=
    #     {"include_usage": False, "continuous_usage_stats": True}
    stream = await client.completions.create(model=model_name,
                                             prompt=prompt,
                                             max_tokens=5,
                                             temperature=0.0,
                                             stream=True,
                                             stream_options={
                                                 "include_usage": False,
                                                 "continuous_usage_stats":
                                                 True,
                                             })
    async for chunk in stream:
        assert chunk.usage is None

    # Test stream=True, stream_options=
    #     {"include_usage": True, "continuous_usage_stats": False}
    stream = await client.completions.create(model=model_name,
                                             prompt=prompt,
                                             max_tokens=5,
                                             temperature=0.0,
                                             stream=True,
                                             stream_options={
                                                 "include_usage": True,
                                                 "continuous_usage_stats":
                                                 False,
                                             })
    async for chunk in stream:
        if chunk.choices[0].finish_reason is None:
            assert chunk.usage is None
        else:
            assert chunk.usage is None
            final_chunk = await stream.__anext__()
            assert final_chunk.usage is not None
            assert final_chunk.usage.prompt_tokens > 0
            assert final_chunk.usage.completion_tokens > 0
            assert final_chunk.usage.total_tokens == (
                final_chunk.usage.prompt_tokens +
                final_chunk.usage.completion_tokens)
            assert final_chunk.choices == []

    # Test stream=True, stream_options=
    #     {"include_usage": True, "continuous_usage_stats": True}
    stream = await client.completions.create(model=model_name,
                                             prompt=prompt,
                                             max_tokens=5,
                                             temperature=0.0,
                                             stream=True,
                                             stream_options={
                                                 "include_usage": True,
                                                 "continuous_usage_stats":
                                                 True,
                                             })
    async for chunk in stream:
        assert chunk.usage is not None
        assert chunk.usage.prompt_tokens > 0
        assert chunk.usage.completion_tokens > 0
        assert chunk.usage.total_tokens == (chunk.usage.prompt_tokens +
                                            chunk.usage.completion_tokens)
        if chunk.choices[0].finish_reason is not None:
            final_chunk = await stream.__anext__()
            assert final_chunk.usage is not None
            assert final_chunk.usage.prompt_tokens > 0
            assert final_chunk.usage.completion_tokens > 0
            assert final_chunk.usage.total_tokens == (
                final_chunk.usage.prompt_tokens +
                final_chunk.usage.completion_tokens)
            assert final_chunk.choices == []

    # Test stream=False, stream_options=
    #     {"include_usage": None}
    with pytest.raises(BadRequestError):
        await client.completions.create(model=model_name,
                                        prompt=prompt,
                                        max_tokens=5,
                                        temperature=0.0,
                                        stream=False,
                                        stream_options={"include_usage": None})

    # Test stream=False, stream_options=
    #     {"include_usage": True}
    with pytest.raises(BadRequestError):
        await client.completions.create(model=model_name,
                                        prompt=prompt,
                                        max_tokens=5,
                                        temperature=0.0,
                                        stream=False,
                                        stream_options={"include_usage": True})

    # Test stream=False, stream_options=
    #     {"continuous_usage_stats": None}
    with pytest.raises(BadRequestError):
        await client.completions.create(
            model=model_name,
            prompt=prompt,
            max_tokens=5,
            temperature=0.0,
            stream=False,
            stream_options={"continuous_usage_stats": None})

    # Test stream=False, stream_options=
    #     {"continuous_usage_stats": True}
    with pytest.raises(BadRequestError):
        await client.completions.create(
            model=model_name,
            prompt=prompt,
            max_tokens=5,
            temperature=0.0,
            stream=False,
            stream_options={"continuous_usage_stats": True})
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "model_name",
    [MODEL_NAME, "zephyr-lora", "zephyr-pa"],
)
async def test_batch_completions(client: openai.AsyncOpenAI, model_name: str):
    # test both text and token IDs
    for prompts in (["Hello, my name is"] * 2, [[0, 0, 0, 0, 0]] * 2):
        # test simple list
        batch = await client.completions.create(
            model=model_name,
            prompt=prompts,
            max_tokens=5,
            temperature=0.0,
        )
        assert len(batch.choices) == 2
        assert batch.choices[0].text == batch.choices[1].text

        # test n = 2
        batch = await client.completions.create(
            model=model_name,
            prompt=prompts,
            n=2,
            max_tokens=5,
            temperature=0.0,
            extra_body=dict(
                # NOTE: this has to be true for n > 1 in vLLM, but not necessary
                # for official client.
                use_beam_search=True),
        )
        assert len(batch.choices) == 4
        assert batch.choices[0].text != batch.choices[
            1].text, "beam search should be different"
        assert batch.choices[0].text == batch.choices[
            2].text, "two copies of the same prompt should be the same"
        assert batch.choices[1].text == batch.choices[
            3].text, "two copies of the same prompt should be the same"

        # test streaming
        batch = await client.completions.create(
            model=model_name,
            prompt=prompts,
            max_tokens=5,
            temperature=0.0,
            stream=True,
        )
        texts = [""] * 2
        async for chunk in batch:
            assert len(chunk.choices) == 1
            choice = chunk.choices[0]
            texts[choice.index] += choice.text
        assert texts[0] == texts[1]


@pytest.mark.asyncio
async def test_logits_bias(client: openai.AsyncOpenAI):
    prompt = "Hello, my name is"
    max_tokens = 5
    tokenizer = get_tokenizer(tokenizer_name=MODEL_NAME)

    # Test exclusive selection
    token_id = 1000
    completion = await client.completions.create(
        model=MODEL_NAME,
        prompt=prompt,
        max_tokens=max_tokens,
        temperature=0.0,
        logit_bias={str(token_id): 100},
        seed=42,
    )
    assert len(completion.choices[0].text) >= 5
    response_tokens = tokenizer(completion.choices[0].text,
                                add_special_tokens=False)["input_ids"]
    expected_tokens = tokenizer(tokenizer.decode([token_id] * 5),
                                add_special_tokens=False)["input_ids"]
    assert all([
        response == expected
        for response, expected in zip(response_tokens, expected_tokens)
    ])

    # Test ban
    completion = await client.completions.create(
        model=MODEL_NAME,
        prompt=prompt,
        max_tokens=max_tokens,
        temperature=0.0,
    )
    response_tokens = tokenizer(completion.choices[0].text,
                                add_special_tokens=False)["input_ids"]
    first_response = completion.choices[0].text
    completion = await client.completions.create(
        model=MODEL_NAME,
        prompt=prompt,
        max_tokens=max_tokens,
        temperature=0.0,
        logit_bias={str(token): -100
                    for token in response_tokens},
    )
    assert first_response != completion.choices[0].text


@pytest.mark.asyncio
async def test_allowed_token_ids(client: openai.AsyncOpenAI):
    prompt = "Hello, my name is"
    max_tokens = 1
    tokenizer = get_tokenizer(tokenizer_name=MODEL_NAME)

    # Test exclusive selection
    allowed_ids = [21555, 21557, 21558]
    completion = await client.completions.create(
        model=MODEL_NAME,
        prompt=prompt,
        max_tokens=max_tokens,
        temperature=0.0,
        seed=42,
        extra_body=dict(allowed_token_ids=allowed_ids),
        logprobs=1,
    )
    response_tokens = completion.choices[0].logprobs.tokens
    assert len(response_tokens) == 1
    assert tokenizer.convert_tokens_to_ids(response_tokens)[0] in allowed_ids
@pytest.mark.asyncio
@pytest.mark.parametrize("guided_decoding_backend",
                         ["outlines", "lm-format-enforcer"])
async def test_guided_json_completion(client: openai.AsyncOpenAI,
                                      guided_decoding_backend: str,
                                      sample_json_schema):
    completion = await client.completions.create(
        model=MODEL_NAME,
        prompt=f"Give an example JSON for an employee profile "
        f"that fits this schema: {sample_json_schema}",
        n=3,
        temperature=1.0,
        max_tokens=500,
        extra_body=dict(guided_json=sample_json_schema,
                        guided_decoding_backend=guided_decoding_backend))

    assert completion.id is not None
    assert len(completion.choices) == 3
    for i in range(3):
        output_json = json.loads(completion.choices[i].text)
        jsonschema.validate(instance=output_json, schema=sample_json_schema)


@pytest.mark.asyncio
@pytest.mark.parametrize("guided_decoding_backend",
                         ["outlines", "lm-format-enforcer"])
async def test_guided_regex_completion(client: openai.AsyncOpenAI,
                                       guided_decoding_backend: str,
                                       sample_regex):
    completion = await client.completions.create(
        model=MODEL_NAME,
        prompt=f"Give an example IPv4 address with this regex: {sample_regex}",
        n=3,
        temperature=1.0,
        max_tokens=20,
        extra_body=dict(guided_regex=sample_regex,
                        guided_decoding_backend=guided_decoding_backend))

    assert completion.id is not None
    assert len(completion.choices) == 3
    for i in range(3):
        assert re.fullmatch(sample_regex,
                            completion.choices[i].text) is not None


@pytest.mark.asyncio
@pytest.mark.parametrize("guided_decoding_backend",
                         ["outlines", "lm-format-enforcer"])
async def test_guided_choice_completion(client: openai.AsyncOpenAI,
                                        guided_decoding_backend: str,
                                        sample_guided_choice):
    completion = await client.completions.create(
        model=MODEL_NAME,
        prompt="The best language for type-safe systems programming is ",
        n=2,
        temperature=1.0,
        max_tokens=10,
        extra_body=dict(guided_choice=sample_guided_choice,
                        guided_decoding_backend=guided_decoding_backend))

    assert completion.id is not None
    assert len(completion.choices) == 2
    for i in range(2):
        assert completion.choices[i].text in sample_guided_choice


@pytest.mark.asyncio
async def test_guided_grammar(client: openai.AsyncOpenAI,
                              sample_sql_statements):

    completion = await client.completions.create(
        model=MODEL_NAME,
        prompt=("Generate a sql state that select col_1 from "
                "table_1 where it is equals to 1"),
        temperature=1.0,
        max_tokens=500,
        extra_body=dict(guided_grammar=sample_sql_statements))

    content = completion.choices[0].text

    # use Lark to parse the output, and make sure it's a valid parse tree
    from lark import Lark
    parser = Lark(sample_sql_statements)
    parser.parse(content)

    # remove spaces for comparison b/c we removed them in the grammar
    ground_truth = "SELECT col_1 from table_1 where col_1 = 1".replace(" ", "")

    assert content.strip() == ground_truth


@pytest.mark.asyncio
@pytest.mark.parametrize(
    # first test base model, then test loras
    "model_name",
    [MODEL_NAME, "zephyr-lora", "zephyr-lora2"],
)
@pytest.mark.parametrize("logprobs_arg", [1, 0])
async def test_echo_logprob_completion(client: openai.AsyncOpenAI,
                                       model_name: str, logprobs_arg: int):
    tokenizer = get_tokenizer(tokenizer_name=MODEL_NAME)
    # test using text and token IDs
    for prompt in ("Hello, my name is", [0, 0, 0, 0, 0]):
        completion = await client.completions.create(model=model_name,
                                                      prompt=prompt,
                                                      max_tokens=5,
                                                      temperature=0.0,
                                                      echo=True,
                                                      logprobs=logprobs_arg)

        prompt_text = tokenizer.decode(prompt) if isinstance(prompt,
                                                             list) else prompt
        assert re.search(r"^" + prompt_text, completion.choices[0].text)
        logprobs = completion.choices[0].logprobs
        assert logprobs is not None
        assert len(logprobs.text_offset) > 5
        assert (len(logprobs.token_logprobs) > 5
                and logprobs.token_logprobs[0] is None)
        assert (len(logprobs.top_logprobs) > 5
                and logprobs.top_logprobs[0] is None)
        for top_logprobs in logprobs.top_logprobs[1:]:
            assert max(logprobs_arg,
                       1) <= len(top_logprobs) <= logprobs_arg + 1
        assert len(logprobs.tokens) > 5


@pytest.mark.asyncio
@pytest.mark.parametrize("guided_decoding_backend",
                         ["outlines", "lm-format-enforcer"])
async def test_guided_decoding_type_error(client: openai.AsyncOpenAI,
                                          guided_decoding_backend: str,
                                          sample_json_schema, sample_regex):
    with pytest.raises(openai.BadRequestError):
        _ = await client.completions.create(
            model=MODEL_NAME,
            prompt="Give an example JSON that fits this schema: 42",
            extra_body=dict(guided_json=42,
                            guided_decoding_backend=guided_decoding_backend))

    with pytest.raises(openai.BadRequestError):
        _ = await client.completions.create(
            model=MODEL_NAME,
            prompt="Give an example string that fits this regex",
            extra_body=dict(guided_regex=sample_regex,
                            guided_json=sample_json_schema))
vllm/engine/async_llm_engine.py
@@ -7,7 +7,8 @@ from typing import (AsyncIterator, Callable, Dict, Iterable, List, Mapping,
from transformers import PreTrainedTokenizer

import vllm.envs as envs
from vllm.config import DecodingConfig, EngineConfig, ModelConfig
from vllm.config import (DecodingConfig, EngineConfig, LoRAConfig, ModelConfig,
                         ParallelConfig, SchedulerConfig)
from vllm.core.scheduler import SchedulerOutputs
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_timeout import asyncio_timeout
@@ -928,6 +929,14 @@ class AsyncLLMEngine:
        else:
            return self.engine.get_model_config()

    async def get_parallel_config(self) -> ParallelConfig:
        """Get the parallel configuration of the vLLM engine."""
        if self.engine_use_ray:
            return await self.engine.get_parallel_config.remote(  # type: ignore
            )
        else:
            return self.engine.get_parallel_config()

    async def get_decoding_config(self) -> DecodingConfig:
        """Get the decoding configuration of the vLLM engine."""
        if self.engine_use_ray:
@@ -936,6 +945,22 @@ class AsyncLLMEngine:
        else:
            return self.engine.get_decoding_config()

    async def get_scheduler_config(self) -> SchedulerConfig:
        """Get the scheduling configuration of the vLLM engine."""
        if self.engine_use_ray:
            return await self.engine.get_scheduler_config.remote(  # type: ignore
            )
        else:
            return self.engine.get_scheduler_config()

    async def get_lora_config(self) -> LoRAConfig:
        """Get the lora configuration of the vLLM engine."""
        if self.engine_use_ray:
            return await self.engine.get_lora_config.remote(  # type: ignore
            )
        else:
            return self.engine.get_lora_config()

    async def do_log_stats(
            self,
            scheduler_outputs: Optional[SchedulerOutputs] = None,
vllm/engine/llm_engine.py
@@ -38,9 +38,8 @@ from vllm.tracing import (SpanAttributes, SpanKind, extract_trace_context,
                          init_tracer)
from vllm.transformers_utils.config import try_get_generation_config
from vllm.transformers_utils.detokenizer import Detokenizer
from vllm.transformers_utils.tokenizer_group import (AnyTokenizer,
                                                     BaseTokenizerGroup,
                                                     get_tokenizer_group)
from vllm.transformers_utils.tokenizer_group import (
    AnyTokenizer, BaseTokenizerGroup, init_tokenizer_from_configs)
from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled,
                                  usage_message)
from vllm.utils import Counter
@@ -485,19 +484,12 @@ class LLMEngine:
        return self.get_tokenizer_group().get_lora_tokenizer(
            sequence.lora_request)

    def _init_tokenizer(self, **tokenizer_init_kwargs) -> BaseTokenizerGroup:
        init_kwargs = dict(
            tokenizer_id=self.model_config.tokenizer,
            enable_lora=bool(self.lora_config),
            max_num_seqs=self.scheduler_config.max_num_seqs,
            max_input_length=None,
            tokenizer_mode=self.model_config.tokenizer_mode,
            trust_remote_code=self.model_config.trust_remote_code,
            revision=self.model_config.tokenizer_revision)
        init_kwargs.update(tokenizer_init_kwargs)

        return get_tokenizer_group(self.parallel_config.tokenizer_pool_config,
                                   **init_kwargs)
    def _init_tokenizer(self) -> BaseTokenizerGroup:
        return init_tokenizer_from_configs(
            model_config=self.model_config,
            scheduler_config=self.scheduler_config,
            parallel_config=self.parallel_config,
            enable_lora=bool(self.lora_config))

    def _verify_args(self) -> None:
        self.model_config.verify_with_parallel_config(self.parallel_config)
@@ -759,10 +751,22 @@ class LLMEngine:
        """Gets the model configuration."""
        return self.model_config

    def get_parallel_config(self) -> ParallelConfig:
        """Gets the parallel configuration."""
        return self.parallel_config

    def get_decoding_config(self) -> DecodingConfig:
        """Gets the decoding configuration."""
        return self.decoding_config

    def get_scheduler_config(self) -> SchedulerConfig:
        """Gets the scheduler configuration."""
        return self.scheduler_config

    def get_lora_config(self) -> LoRAConfig:
        """Gets the LoRA configuration."""
        return self.lora_config

    def get_num_unfinished_requests(self) -> int:
        """Gets the number of unfinished requests."""
        return sum(scheduler.get_num_unfinished_seq_groups()
vllm/engine/protocol.py (new file, 84 lines)
@@ -0,0 +1,84 @@
from typing import (AsyncIterator, List, Mapping, Optional, Protocol,
                    runtime_checkable)

from transformers import PreTrainedTokenizer

from vllm.config import DecodingConfig, ModelConfig
from vllm.core.scheduler import SchedulerOutputs
from vllm.inputs.data import PromptInputs
from vllm.lora.request import LoRARequest
from vllm.outputs import EmbeddingRequestOutput, RequestOutput
from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams
from vllm.sequence import SamplerOutput


@runtime_checkable
class AsyncEngineClient(Protocol):
    """Protocol class for Clients to AsyncLLMEngine"""

    @property
    def is_running(self) -> bool:
        ...

    @property
    def is_stopped(self) -> bool:
        ...

    @property
    def errored(self) -> bool:
        ...

    async def generate(
        self,
        inputs: PromptInputs,
        sampling_params: SamplingParams,
        request_id: str,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None
    ) -> AsyncIterator[RequestOutput]:
        """Generates outputs for a request"""

    async def encode(
        self,
        inputs: PromptInputs,
        pooling_params: PoolingParams,
        request_id: str,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
    ) -> AsyncIterator[EmbeddingRequestOutput]:
        """Generate outputs for a request from an embedding model."""

    async def abort(self, request_id: str) -> None:
        """Abort a request.

        Args:
            request_id: The unique id of the request.
        """

    async def get_model_config(self) -> ModelConfig:
        """Get the model configuration of the vLLM engine."""

    async def get_decoding_config(self) -> DecodingConfig:
        """Get the decoding configuration of the vLLM engine."""

    async def get_tokenizer(
        self,
        lora_request: Optional[LoRARequest] = None,
    ) -> PreTrainedTokenizer:
        """Get the appropriate Tokenizer for the request"""

    async def is_tracing_enabled(self) -> bool:
        pass

    async def do_log_stats(
        self,
        scheduler_outputs: Optional[SchedulerOutputs] = None,
        model_output: Optional[List[SamplerOutput]] = None,
    ) -> None:
        pass

    async def check_health(self) -> None:
        """Raise if unhealthy"""
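AsyncEngineClient is a typing.Protocol: both the in-process AsyncLLMEngine and the new AsyncEngineRPCClient satisfy it structurally, without inheriting from it, and @runtime_checkable permits isinstance() checks against it. A small self-contained illustration of that mechanism (the names below are illustrative, not vLLM's):

from typing import Protocol, runtime_checkable


@runtime_checkable
class Closeable(Protocol):
    """Illustrative protocol: anything with a close() method qualifies."""

    def close(self) -> None:
        ...


class FileLikeBackend:
    """Does not inherit from Closeable, but matches it structurally."""

    def close(self) -> None:
        print("closed")


backend = FileLikeBackend()
# runtime_checkable only inspects member *names*, not signatures or
# async-ness, so the real protocol is mostly a typing aid plus a sanity check.
assert isinstance(backend, Closeable)
backend.close()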
vllm/entrypoints/openai/api_server.py
@@ -5,7 +5,8 @@ import re
import signal
from contextlib import asynccontextmanager
from http import HTTPStatus
from typing import Optional, Set
from multiprocessing import Process
from typing import AsyncIterator, Set

import fastapi
import uvicorn
@@ -17,8 +18,10 @@ from prometheus_client import make_asgi_app
from starlette.routing import Mount

import vllm.envs as envs
from vllm.config import ModelConfig
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.engine.protocol import AsyncEngineClient
from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.cli_args import make_arg_parser
# yapf conflicts with isort for this block
@@ -31,6 +34,8 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
                                              EmbeddingRequest, ErrorResponse,
                                              TokenizeRequest,
                                              TokenizeResponse)
from vllm.entrypoints.openai.rpc.client import AsyncEngineRPCClient
from vllm.entrypoints.openai.rpc.server import run_rpc_server
# yapf: enable
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
@@ -39,12 +44,12 @@ from vllm.entrypoints.openai.serving_tokenization import (
    OpenAIServingTokenization)
from vllm.logger import init_logger
from vllm.usage.usage_lib import UsageContext
from vllm.utils import FlexibleArgumentParser
from vllm.utils import FlexibleArgumentParser, get_open_port
from vllm.version import __version__ as VLLM_VERSION

TIMEOUT_KEEP_ALIVE = 5  # seconds

engine: AsyncLLMEngine
async_engine_client: AsyncEngineClient
engine_args: AsyncEngineArgs
openai_serving_chat: OpenAIServingChat
openai_serving_completion: OpenAIServingCompletion
@@ -56,13 +61,22 @@ logger = init_logger('vllm.entrypoints.openai.api_server')
_running_tasks: Set[asyncio.Task] = set()


def model_is_embedding(model_name: str) -> bool:
    return ModelConfig(model=model_name,
                       tokenizer=model_name,
                       tokenizer_mode="auto",
                       trust_remote_code=False,
                       seed=0,
                       dtype="float16").embedding_mode


@asynccontextmanager
async def lifespan(app: fastapi.FastAPI):

    async def _force_log():
        while True:
            await asyncio.sleep(10)
            await engine.do_log_stats()
            await async_engine_client.do_log_stats()

    if not engine_args.disable_log_stats:
        task = asyncio.create_task(_force_log())
@@ -72,6 +86,52 @@ async def lifespan(app: fastapi.FastAPI):
    yield


@asynccontextmanager
async def build_async_engine_client(args) -> AsyncIterator[AsyncEngineClient]:
    # Context manager to handle async_engine_client lifecycle
    # Ensures everything is shutdown and cleaned up on error/exit
    global engine_args
    engine_args = AsyncEngineArgs.from_cli_args(args)

    # Backend itself still global for the silly lil' health handler
    global async_engine_client

    # If manually triggered or embedding model, use AsyncLLMEngine in process.
    # TODO: support embedding model via RPC.
    if (model_is_embedding(args.model)
            or args.disable_frontend_multiprocessing):
        async_engine_client = AsyncLLMEngine.from_engine_args(
            engine_args, usage_context=UsageContext.OPENAI_API_SERVER)
        yield async_engine_client
        return

    # Otherwise, use the multiprocessing AsyncLLMEngine.
    else:
        # Start RPCServer in separate process (holds the AsyncLLMEngine).
        port = get_open_port(envs.VLLM_RPC_PORT)
        rpc_server_process = Process(target=run_rpc_server,
                                     args=(engine_args,
                                           UsageContext.OPENAI_API_SERVER,
                                           port))
        rpc_server_process.start()

        # Build RPCClient, which conforms to AsyncEngineClient Protocol.
        async_engine_client = AsyncEngineRPCClient(port)
        await async_engine_client.setup()

        try:
            yield async_engine_client
        finally:
            # Ensure rpc server process was terminated
            rpc_server_process.terminate()

            # Close all open connections to the backend
            async_engine_client.close()

            # Wait for server process to join
            rpc_server_process.join()


router = APIRouter()
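build_async_engine_client is the heart of the change: depending on --disable-frontend-multiprocessing (or an embedding model), the engine either runs in-process or in a child process reached through the RPC client. The shutdown ordering it guarantees (terminate the child, close client resources, then join) can be sketched generically; the worker below is a hypothetical stand-in for run_rpc_server, not a vLLM API:

import asyncio
import time
from contextlib import asynccontextmanager
from multiprocessing import Process


def _hypothetical_worker() -> None:
    # Stand-in for the engine-holding child process.
    while True:
        time.sleep(1)


@asynccontextmanager
async def managed_worker():
    proc = Process(target=_hypothetical_worker)
    proc.start()
    try:
        yield proc
    finally:
        # Mirror the ordering above: stop the child first, then release
        # any client-side resources, then reap the process.
        proc.terminate()
        proc.join()


async def main() -> None:
    async with managed_worker() as proc:
        print("worker pid:", proc.pid)


if __name__ == "__main__":
    asyncio.run(main())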
@@ -86,7 +146,7 @@ def mount_metrics(app: fastapi.FastAPI):
@router.get("/health")
async def health() -> Response:
    """Health check."""
    await openai_serving_chat.engine.check_health()
    await async_engine_client.check_health()
    return Response(status_code=200)


@@ -215,8 +275,8 @@ def build_app(args):


async def build_server(
    async_engine_client: AsyncEngineClient,
    args,
    llm_engine: Optional[AsyncLLMEngine] = None,
    **uvicorn_kwargs,
) -> uvicorn.Server:
    app = build_app(args)
@@ -226,14 +286,7 @@ async def build_server(
    else:
        served_model_names = [args.model]

    global engine, engine_args

    engine_args = AsyncEngineArgs.from_cli_args(args)
    engine = (llm_engine
              if llm_engine is not None else AsyncLLMEngine.from_engine_args(
                  engine_args, usage_context=UsageContext.OPENAI_API_SERVER))

    model_config = await engine.get_model_config()
    model_config = await async_engine_client.get_model_config()

    if args.disable_log_requests:
        request_logger = None
@@ -246,7 +299,7 @@ async def build_server(
    global openai_serving_tokenization

    openai_serving_chat = OpenAIServingChat(
        engine,
        async_engine_client,
        model_config,
        served_model_names,
        args.response_role,
@@ -257,7 +310,7 @@ async def build_server(
        return_tokens_as_token_ids=args.return_tokens_as_token_ids,
    )
    openai_serving_completion = OpenAIServingCompletion(
        engine,
        async_engine_client,
        model_config,
        served_model_names,
        lora_modules=args.lora_modules,
@@ -266,13 +319,13 @@ async def build_server(
        return_tokens_as_token_ids=args.return_tokens_as_token_ids,
    )
    openai_serving_embedding = OpenAIServingEmbedding(
        engine,
        async_engine_client,
        model_config,
        served_model_names,
        request_logger=request_logger,
    )
    openai_serving_tokenization = OpenAIServingTokenization(
        engine,
        async_engine_client,
        model_config,
        served_model_names,
        lora_modules=args.lora_modules,
@@ -304,32 +357,39 @@ async def build_server(
    return uvicorn.Server(config)


async def run_server(args, llm_engine=None, **uvicorn_kwargs) -> None:
async def run_server(args, **uvicorn_kwargs) -> None:
    logger.info("vLLM API server version %s", VLLM_VERSION)
    logger.info("args: %s", args)

    server = await build_server(
        args,
        llm_engine,
        **uvicorn_kwargs,
    )
    shutdown_task = None
    async with build_async_engine_client(args) as async_engine_client:

    loop = asyncio.get_running_loop()
        server = await build_server(
            async_engine_client,
            args,
            **uvicorn_kwargs,
        )

    server_task = loop.create_task(server.serve())
        loop = asyncio.get_running_loop()

    def signal_handler() -> None:
        # prevents the uvicorn signal handler to exit early
        server_task.cancel()
        server_task = loop.create_task(server.serve())

    loop.add_signal_handler(signal.SIGINT, signal_handler)
    loop.add_signal_handler(signal.SIGTERM, signal_handler)
        def signal_handler() -> None:
            # prevents the uvicorn signal handler to exit early
            server_task.cancel()

    try:
        await server_task
    except asyncio.CancelledError:
        print("Gracefully stopping http server")
        await server.shutdown()
        loop.add_signal_handler(signal.SIGINT, signal_handler)
        loop.add_signal_handler(signal.SIGTERM, signal_handler)

        try:
            await server_task
        except asyncio.CancelledError:
            logger.info("Gracefully stopping http server")
            shutdown_task = server.shutdown()

    if shutdown_task:
        # NB: Await server shutdown only after the backend context is exited
        await shutdown_task


if __name__ == "__main__":
vllm/entrypoints/openai/cli_args.py
@@ -131,9 +131,14 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
    parser.add_argument(
        "--return-tokens-as-token-ids",
        action="store_true",
        help="When --max-logprobs is specified, represents single tokens as"
        "strings of the form 'token_id:{token_id}' so that tokens that"
        help="When --max-logprobs is specified, represents single tokens as "
        "strings of the form 'token_id:{token_id}' so that tokens that "
        "are not JSON-encodable can be identified.")
    parser.add_argument(
        "--disable-frontend-multiprocessing",
        action="store_true",
        help="If specified, will run the OpenAI frontend server in the same "
        "process as the model serving engine.")

    parser = AsyncEngineArgs.add_cli_args(parser)
vllm/entrypoints/openai/logits_processors.py
@@ -1,4 +1,4 @@
from functools import lru_cache
from functools import lru_cache, partial
from typing import Dict, FrozenSet, Iterable, List, Optional, Union

import torch
@@ -40,6 +40,14 @@ def _get_allowed_token_ids_logits_processor(
    return AllowedTokenIdsLogitsProcessor(allowed_token_ids)


def logit_bias_logits_processor(logit_bias: Dict[str,
                                                 float], token_ids: List[int],
                                logits: torch.Tensor) -> torch.Tensor:
    for token_id, bias in logit_bias.items():
        logits[token_id] += bias
    return logits


def get_logits_processors(
    logit_bias: Optional[Union[Dict[int, float], Dict[str, float]]],
    allowed_token_ids: Optional[List[int]],
@@ -64,13 +72,8 @@ def get_logits_processors(
            raise ValueError("token_id in logit_bias contains "
                             "out-of-vocab token id")

        def logit_bias_logits_processor(token_ids: List[int],
                                        logits: torch.Tensor) -> torch.Tensor:
            for token_id, bias in clamped_logit_bias.items():
                logits[token_id] += bias
            return logits

        logits_processors.append(logit_bias_logits_processor)
        logits_processors.append(
            partial(logit_bias_logits_processor, clamped_logit_bias))

    if allowed_token_ids is not None:
        logits_processors.append(
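This refactor exists because the multiprocessing frontend has to ship logits processors across process boundaries: a function defined inside get_logits_processors cannot be pickled, while a module-level function bound with functools.partial can. A short, self-contained demonstration of the difference (plain pickle here; the commit itself moves objects with cloudpickle over zeromq):

import pickle
from functools import partial
from typing import Dict, List


def biased(logit_bias: Dict[int, float], token_ids: List[int],
           logits: List[float]) -> List[float]:
    # Module-level function: picklable by reference.
    for token_id, bias in logit_bias.items():
        logits[token_id] += bias
    return logits


def make_closure(logit_bias: Dict[int, float]):

    def processor(token_ids: List[int], logits: List[float]) -> List[float]:
        for token_id, bias in logit_bias.items():
            logits[token_id] += bias
        return logits

    return processor


# partial(module_level_fn, state) round-trips through pickle...
restored = pickle.loads(pickle.dumps(partial(biased, {0: 5.0})))
print(restored([], [0.0, 1.0]))  # [5.0, 1.0]

# ...while the locally defined closure does not.
try:
    pickle.dumps(make_closure({0: 5.0}))
except (pickle.PicklingError, AttributeError) as exc:
    print("closure is not picklable:", type(exc).__name__)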
vllm/entrypoints/openai/rpc/__init__.py (new file, 42 lines)
@@ -0,0 +1,42 @@
from dataclasses import dataclass
from enum import Enum
from typing import Mapping, Optional, Union

from vllm.inputs import PromptInputs
from vllm.lora.request import LoRARequest
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams

VLLM_RPC_SUCCESS_STR = "SUCCESS"
VLLM_RPC_HEALTHY_STR = "HEALTHY"


@dataclass
class RPCGenerateRequest:
    inputs: PromptInputs
    sampling_params: SamplingParams
    request_id: str
    lora_request: Optional[LoRARequest] = None
    trace_headers: Optional[Mapping[str, str]] = None
    prompt_adapter_request: Optional[PromptAdapterRequest] = None


@dataclass
class RPCAbortRequest:
    request_id: str


class RPCUtilityRequest(Enum):
    IS_SERVER_READY = 1
    GET_MODEL_CONFIG = 2
    GET_DECODING_CONFIG = 3
    GET_PARALLEL_CONFIG = 4
    GET_SCHEDULER_CONFIG = 5
    GET_LORA_CONFIG = 6
    DO_LOG_STATS = 7
    CHECK_HEALTH = 8
    IS_TRACING_ENABLED = 9


RPC_REQUEST_TYPE = Union[RPCGenerateRequest, RPCAbortRequest,
                         RPCUtilityRequest]
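The RPC message types are deliberately plain dataclasses and an Enum so a whole request can be serialized and dispatched on type by the server. A self-contained round-trip with stand-in types (the classes below are illustrative, not the vLLM ones):

import pickle
from dataclasses import dataclass
from enum import Enum
from typing import Optional


@dataclass
class ToyGenerateRequest:
    prompt: str
    request_id: str
    lora_name: Optional[str] = None


class ToyUtilityRequest(Enum):
    IS_SERVER_READY = 1
    CHECK_HEALTH = 2


# The frontend would serialize a request like this and ship the bytes over
# the socket; the engine process deserializes and dispatches on the type.
wire_bytes = pickle.dumps(ToyGenerateRequest(prompt="Hello", request_id="req-0"))
restored = pickle.loads(wire_bytes)
assert isinstance(restored, ToyGenerateRequest)
assert restored.prompt == "Hello"

# Enum members round-trip to the *same* member, so equality checks like
# `request == RPCUtilityRequest.CHECK_HEALTH` keep working server-side.
assert pickle.loads(pickle.dumps(ToyUtilityRequest.CHECK_HEALTH)) is ToyUtilityRequest.CHECK_HEALTH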
vllm/entrypoints/openai/rpc/client.py (new file, 248 lines)
@@ -0,0 +1,248 @@
from contextlib import contextmanager
from typing import Any, AsyncIterator, Mapping, Optional

import cloudpickle
import zmq
import zmq.asyncio

from vllm.config import (DecodingConfig, LoRAConfig, ModelConfig,
                         ParallelConfig, SchedulerConfig)
from vllm.entrypoints.openai.rpc import (RPC_REQUEST_TYPE,
                                         VLLM_RPC_HEALTHY_STR,
                                         VLLM_RPC_SUCCESS_STR, RPCAbortRequest,
                                         RPCGenerateRequest, RPCUtilityRequest)
from vllm.inputs import PromptInputs
from vllm.lora.request import LoRARequest
from vllm.outputs import EmbeddingRequestOutput, RequestOutput
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams
from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs


class AsyncEngineRPCClient:

    def __init__(self, port: int):
        self.context = zmq.asyncio.Context()
        self.path = f"tcp://localhost:{port}"

    async def setup(self):
        """Setup the client before it starts sending server requests."""

        # Wait until server is ready.
        await self.wait_for_server()

        # Get the configs.
        self.model_config = await self._get_model_config_rpc()
        self.decoding_config = await self._get_decoding_config_rpc()
        self.tracing_flag = await self._is_tracing_enabled_rpc()

        # Create the tokenizer group.
        # TODO: refactor OAI server to avoid needing this info.
        self.tokenizer = init_tokenizer_from_configs(
            model_config=self.model_config,
            scheduler_config=(await self._get_scheduler_config_rpc()),
            parallel_config=(await self._get_parallel_config_rpc()),
            enable_lora=bool(await self._get_lora_config_rpc()),
        )

    def close(self):
        """Destroy the ZeroMQ Context."""
        self.context.destroy()

    @contextmanager
    def socket(self):
        # Ensure client sockets are always closed after use

        # Connect to RPC socket for Request-Reply pattern,
        # Note that we use DEALER to enable asynchronous communication
        # to enable streaming.
        socket = self.context.socket(zmq.constants.DEALER)
        try:
            socket.connect(self.path)
            yield socket
        finally:
            socket.close()

    async def _send_get_data_rpc_request(self, request: RPCUtilityRequest,
                                         expected_type: Any,
                                         error_message: str) -> Any:
        """Send an RPC request that is expecting data back."""

        with self.socket() as socket:

            # Ping RPCServer with a request.
            await socket.send(cloudpickle.dumps(request))

            # Await the data from the Server.
            data = cloudpickle.loads(await socket.recv())

            if not isinstance(data, expected_type):
                # LoRAConfig can be None.
                if expected_type == LoRAConfig and data is None:
                    pass
                else:
                    raise ValueError(error_message)

            return data

    async def _send_one_way_rpc_request(self, request: RPC_REQUEST_TYPE,
                                        error_message: str):
        """Send one-way RPC request to trigger an action."""
        with self.socket() as socket:
            # Ping RPC Server with request.
            await socket.send(cloudpickle.dumps(request))

            # Await acknowledgement from RPCServer.
            response = cloudpickle.loads(await socket.recv())

            if not isinstance(response, str) or response != VLLM_RPC_SUCCESS_STR:
                raise ValueError(error_message)

            return response

    async def get_tokenizer(self, lora_request: LoRARequest):
        return await self.tokenizer.get_lora_tokenizer_async(lora_request)

    async def get_decoding_config(self) -> DecodingConfig:
        return self.decoding_config

    async def get_model_config(self) -> ModelConfig:
        return self.model_config

    async def is_tracing_enabled(self) -> bool:
        return self.tracing_flag

    async def wait_for_server(self):
        """Wait for the RPCServer to start up."""

        await self._send_one_way_rpc_request(
            request=RPCUtilityRequest.IS_SERVER_READY,
            error_message="Unable to start RPC Server.")

    async def _get_model_config_rpc(self) -> ModelConfig:
        """Get the ModelConfig object from the RPC Server"""

        return await self._send_get_data_rpc_request(
            RPCUtilityRequest.GET_MODEL_CONFIG,
            expected_type=ModelConfig,
            error_message="Could not get ModelConfig from RPC Server")

    async def _get_decoding_config_rpc(self) -> DecodingConfig:
        """Get DecodingConfig from the RPCServer"""

        return await self._send_get_data_rpc_request(
            RPCUtilityRequest.GET_DECODING_CONFIG,
            expected_type=DecodingConfig,
            error_message="Could not get DecodingConfig from RPC Server")

    async def _get_parallel_config_rpc(self) -> ParallelConfig:
        """Get ParallelConfig from the RPCServer"""

        return await self._send_get_data_rpc_request(
            RPCUtilityRequest.GET_PARALLEL_CONFIG,
            expected_type=ParallelConfig,
            error_message="Could not get ParallelConfig from RPC Server")

    async def _get_scheduler_config_rpc(self) -> SchedulerConfig:
        """Get SchedulerConfig from the RPCServer"""

        return await self._send_get_data_rpc_request(
            RPCUtilityRequest.GET_SCHEDULER_CONFIG,
            expected_type=SchedulerConfig,
            error_message="Could not get SchedulerConfig from RPC Server")

    async def _get_lora_config_rpc(self):
        """Get LoRAConfig from the RPCServer"""

        return await self._send_get_data_rpc_request(
            RPCUtilityRequest.GET_LORA_CONFIG,
            expected_type=LoRAConfig,
            error_message="Could not get LoRAConfig from RPC Server")

    async def _is_tracing_enabled_rpc(self) -> ParallelConfig:
        """Get is_tracing_enabled flag from the RPCServer"""

        return await self._send_get_data_rpc_request(
            RPCUtilityRequest.IS_TRACING_ENABLED,
            expected_type=bool,
            error_message="Could not get is_tracing_enabled flag from RPC "
            "Server")

    async def abort(self, request_id: str):
        """Send an ABORT_REQUEST signal to the RPC Server"""

        await self._send_one_way_rpc_request(
            request=RPCAbortRequest(request_id),
            error_message=f"RPCAbortRequest {request_id} failed")

    async def do_log_stats(self):
        """Send a DO_LOG_STATS signal to the RPC Server"""

        await self._send_one_way_rpc_request(
            request=RPCUtilityRequest.DO_LOG_STATS,
            error_message="RPCRequest DO_LOG_STATS failed.")

    async def generate(
        self,
        inputs: PromptInputs,
        sampling_params: SamplingParams,
        request_id: str,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None
    ) -> AsyncIterator[RequestOutput]:
        """Send an RPCGenerateRequest to the RPCServer and stream responses."""

        with self.socket() as socket:

            # Send RPCGenerateRequest to the RPCServer.
            await socket.send_multipart([
                cloudpickle.dumps(
                    RPCGenerateRequest(
                        inputs=inputs,
                        sampling_params=sampling_params,
                        request_id=request_id,
                        lora_request=lora_request,
                        trace_headers=trace_headers,
                        prompt_adapter_request=prompt_adapter_request))
            ])

            # Stream back the results from the RPC Server.
            while True:
                message = await socket.recv()
                request_output = cloudpickle.loads(message)

                if isinstance(request_output, Exception):
                    raise request_output

                if request_output.finished:
                    break
                yield request_output

            yield request_output

    async def check_health(self) -> None:
        """Raise if unhealthy"""

        with self.socket() as socket:

            # Ping RPCServer with CHECK_HEALTH request.
            await socket.send(cloudpickle.dumps(RPCUtilityRequest.CHECK_HEALTH)
                              )

            # Await the reply from the server.
            # TODO: do we need an internal timeout here?
            # Or do we expect the external probe to timeout and let this chill?
            health_message = cloudpickle.loads(await socket.recv())

            if isinstance(health_message, Exception):
                raise health_message

            if health_message != VLLM_RPC_HEALTHY_STR:
                raise ValueError("Expected healthy response from backend but got "
                                 f"{health_message}")

    async def encode(self, *args,
                     **kwargs) -> AsyncIterator[EmbeddingRequestOutput]:
        raise NotImplementedError(
            "Embeddings not supported with multiprocessing backend")
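The generate method above leans on the DEALER/ROUTER pairing allowing many replies to a single request: one pickled RPCGenerateRequest goes out, then frames are read until an output arrives with finished=True. A self-contained toy version of that streaming pattern with pyzmq and plain pickle (the address, types, and message shapes are stand-ins, not vLLM's):

import asyncio
import pickle
from dataclasses import dataclass

import zmq
import zmq.asyncio

ADDR = "tcp://127.0.0.1:5599"  # arbitrary port for the demo


@dataclass
class ToyOutput:
    text: str
    finished: bool


async def toy_router_server(ctx: zmq.asyncio.Context) -> None:
    # Stand-in for the engine-side ROUTER socket: one request in,
    # several pickled outputs streamed back to the same identity.
    sock = ctx.socket(zmq.ROUTER)
    sock.bind(ADDR)
    identity, raw = await sock.recv_multipart()
    prompt = pickle.loads(raw)
    for i, finished in enumerate([False, False, True]):
        out = ToyOutput(text=f"{prompt} + token{i}", finished=finished)
        await sock.send_multipart([identity, pickle.dumps(out)])
    sock.close()


async def toy_dealer_client(ctx: zmq.asyncio.Context) -> None:
    sock = ctx.socket(zmq.DEALER)
    sock.connect(ADDR)
    await sock.send(pickle.dumps("Hello"))
    while True:
        out = pickle.loads(await sock.recv())
        print(out)
        if out.finished:
            break
    sock.close()


async def main() -> None:
    ctx = zmq.asyncio.Context()
    await asyncio.gather(toy_router_server(ctx), toy_dealer_client(ctx))
    ctx.destroy()


if __name__ == "__main__":
    asyncio.run(main())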
vllm/entrypoints/openai/rpc/server.py (new file, 216 lines)
@ -0,0 +1,216 @@
import asyncio
import signal
from typing import Any, Coroutine

import cloudpickle
import zmq
import zmq.asyncio
from typing_extensions import Never

from vllm import AsyncEngineArgs, AsyncLLMEngine
from vllm.entrypoints.openai.rpc import (VLLM_RPC_HEALTHY_STR,
                                         VLLM_RPC_SUCCESS_STR, RPCAbortRequest,
                                         RPCGenerateRequest, RPCUtilityRequest)
from vllm.logger import init_logger
from vllm.usage.usage_lib import UsageContext

logger = init_logger(__name__)


class AsyncEngineRPCServer:

    def __init__(self, async_engine_args: AsyncEngineArgs,
                 usage_context: UsageContext, port: int):
        # Initialize engine first.
        self.engine = AsyncLLMEngine.from_engine_args(async_engine_args,
                                                      usage_context)

        # Initialize context.
        self.context = zmq.asyncio.Context()

        # Init ROUTER socket on which all RPC requests are received.
        self.socket = self.context.socket(zmq.constants.ROUTER)
        self.socket.bind(f"tcp://localhost:{port}")

    def cleanup(self):
        """Cleanup all resources."""
        self.socket.close()
        self.context.destroy()

    async def get_model_config(self, identity):
        """Send the ModelConfig"""
        model_config = await self.engine.get_model_config()

        await self.socket.send_multipart(
            [identity, cloudpickle.dumps(model_config)])

    async def get_decoding_config(self, identity):
        """Send the DecodingConfig"""
        decoding_config = await self.engine.get_decoding_config()

        await self.socket.send_multipart(
            [identity, cloudpickle.dumps(decoding_config)])

    async def get_lora_config(self, identity):
        """Send the LoRAConfig"""
        lora_config = await self.engine.get_lora_config()

        await self.socket.send_multipart(
            [identity, cloudpickle.dumps(lora_config)])

    async def get_scheduler_config(self, identity):
        """Send the SchedulerConfig"""
        scheduler_config = await self.engine.get_scheduler_config()

        await self.socket.send_multipart(
            [identity, cloudpickle.dumps(scheduler_config)])

    async def get_parallel_config(self, identity):
        """Send the ParallelConfig"""
        parallel_config = await self.engine.get_parallel_config()

        await self.socket.send_multipart(
            [identity, cloudpickle.dumps(parallel_config)])

    async def is_tracing_enabled(self, identity):
        """Send the is_tracing_enabled flag"""
        tracing_flag = await self.engine.is_tracing_enabled()

        await self.socket.send_multipart(
            [identity, cloudpickle.dumps(tracing_flag)])

    async def do_log_stats(self, identity):
        """Log stats and confirm success."""
        await self.engine.do_log_stats()

        await self.socket.send_multipart([
            identity,
            cloudpickle.dumps(VLLM_RPC_SUCCESS_STR),
        ])

    async def is_server_ready(self, identity):
        """Notify the client that we are ready."""
        await self.socket.send_multipart([
            identity,
            cloudpickle.dumps(VLLM_RPC_SUCCESS_STR),
        ])

    async def abort(self, identity, request: RPCAbortRequest):
        """Abort request and notify the client of success."""
        # Abort the request in the llm engine.
        await self.engine.abort(request.request_id)

        # Send confirmation to the client.
        await self.socket.send_multipart([
            identity,
            cloudpickle.dumps(VLLM_RPC_SUCCESS_STR),
        ])

    async def generate(self, identity, generate_request: RPCGenerateRequest):
        try:
            results_generator = self.engine.generate(
                generate_request.inputs,
                sampling_params=generate_request.sampling_params,
                request_id=generate_request.request_id,
                lora_request=generate_request.lora_request,
                trace_headers=generate_request.trace_headers,
                prompt_adapter_request=generate_request.prompt_adapter_request)

            async for request_output in results_generator:
                await self.socket.send_multipart(
                    [identity, cloudpickle.dumps(request_output)])

        except Exception as e:
            # Notify the client of any failure.
            await self.socket.send_multipart([identity, cloudpickle.dumps(e)])

    async def check_health(self, identity):
        try:
            await self.engine.check_health()
            await self.socket.send_multipart(
                [identity, cloudpickle.dumps(VLLM_RPC_HEALTHY_STR)])
        except Exception as e:
            await self.socket.send_multipart([identity, cloudpickle.dumps(e)])

    def _make_handler_coro(self, identity,
                           message) -> Coroutine[Any, Any, Never]:
        """Route the zmq message to the handler coroutine."""

        request = cloudpickle.loads(message)

        if isinstance(request, RPCGenerateRequest):
            return self.generate(identity, request)

        elif isinstance(request, RPCAbortRequest):
            return self.abort(identity, request)

        elif isinstance(request, RPCUtilityRequest):
            if request == RPCUtilityRequest.GET_MODEL_CONFIG:
                return self.get_model_config(identity)
            elif request == RPCUtilityRequest.GET_PARALLEL_CONFIG:
                return self.get_parallel_config(identity)
            elif request == RPCUtilityRequest.GET_DECODING_CONFIG:
                return self.get_decoding_config(identity)
            elif request == RPCUtilityRequest.GET_SCHEDULER_CONFIG:
                return self.get_scheduler_config(identity)
            elif request == RPCUtilityRequest.GET_LORA_CONFIG:
                return self.get_lora_config(identity)
            elif request == RPCUtilityRequest.DO_LOG_STATS:
                return self.do_log_stats(identity)
            elif request == RPCUtilityRequest.IS_SERVER_READY:
                return self.is_server_ready(identity)
            elif request == RPCUtilityRequest.CHECK_HEALTH:
                return self.check_health(identity)
            elif request == RPCUtilityRequest.IS_TRACING_ENABLED:
                return self.is_tracing_enabled(identity)
            else:
                raise ValueError(f"Unknown RPCUtilityRequest type: {request}")

        else:
            raise ValueError(f"Unknown RPCRequest type: {request}")

    async def run_server_loop(self):
        """Inner RPC Server Loop"""

        running_tasks = set()
        while True:
            # Wait for a request.
            identity, message = await self.socket.recv_multipart()

            # Process the request async.
            task = asyncio.create_task(
                self._make_handler_coro(identity, message))

            # We need to keep around a strong reference to the task,
            # to avoid the task disappearing mid-execution, as running tasks
            # can be GC'ed. This is the common "fire-and-forget" pattern:
            # https://docs.python.org/3/library/asyncio-task.html#asyncio.create_task
            running_tasks.add(task)
            task.add_done_callback(running_tasks.discard)


async def run_server(server: AsyncEngineRPCServer):
    # Put the server task into the asyncio loop.
    loop = asyncio.get_running_loop()
    server_task = loop.create_task(server.run_server_loop())

    # Interruption handling.
    def signal_handler() -> None:
        # Kill the server on interrupt / terminate.
        server_task.cancel()

    loop.add_signal_handler(signal.SIGINT, signal_handler)
    loop.add_signal_handler(signal.SIGTERM, signal_handler)

    try:
        await server_task
    except asyncio.CancelledError:
        logger.info("vLLM ZMQ RPC Server was interrupted.")
    finally:
        # Clean up all resources.
        server.cleanup()


def run_rpc_server(async_engine_args: AsyncEngineArgs,
                   usage_context: UsageContext, port: int):
    server = AsyncEngineRPCServer(async_engine_args, usage_context, port)
    asyncio.run(run_server(server))
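Every handler above replies with send_multipart([identity, payload]) because the server socket is a ROUTER: ZeroMQ prepends an identity frame to each incoming message, and echoing that frame routes the reply back to the DEALER peer that sent the request. Below is a minimal, standalone pyzmq round trip illustrating that pattern; the port and payloads are arbitrary and this is not the vLLM client code.

# Minimal ROUTER/DEALER round trip showing the identity frame that the
# handlers above echo back. Plain pyzmq; port and payloads are arbitrary.
import zmq

ctx = zmq.Context()

router = ctx.socket(zmq.ROUTER)
router.bind("tcp://127.0.0.1:5591")

dealer = ctx.socket(zmq.DEALER)
dealer.connect("tcp://127.0.0.1:5591")

dealer.send(b"ping")                         # DEALER sends a single frame
identity, payload = router.recv_multipart()  # ROUTER sees [identity, b"ping"]
router.send_multipart([identity, b"pong"])   # identity routes the reply back
print(dealer.recv())                         # b"pong"

dealer.close()
router.close()
ctx.term()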
@ -8,7 +8,7 @@ from fastapi import Request
 from transformers import PreTrainedTokenizer

 from vllm.config import ModelConfig
-from vllm.engine.async_llm_engine import AsyncLLMEngine
+from vllm.engine.protocol import AsyncEngineClient
 from vllm.entrypoints.chat_utils import (ConversationMessage,
                                          load_chat_template,
                                          parse_chat_message_content)
@ -39,7 +39,7 @@ class OpenAIServingChat(OpenAIServing):

     def __init__(
         self,
-        engine: AsyncLLMEngine,
+        async_engine_client: AsyncEngineClient,
         model_config: ModelConfig,
         served_model_names: List[str],
         response_role: str,
@ -50,7 +50,7 @@ class OpenAIServingChat(OpenAIServing):
         chat_template: Optional[str],
         return_tokens_as_token_ids: bool = False,
     ):
-        super().__init__(engine=engine,
+        super().__init__(async_engine_client=async_engine_client,
                          model_config=model_config,
                          served_model_names=served_model_names,
                          lora_modules=lora_modules,
@ -89,7 +89,8 @@ class OpenAIServingChat(OpenAIServing):
             ) = self._maybe_get_adapters(request)

             model_config = self.model_config
-            tokenizer = await self.engine.get_tokenizer(lora_request)
+            tokenizer = await self.async_engine_client.get_tokenizer(
+                lora_request)

             conversation: List[ConversationMessage] = []
             mm_futures: List[Awaitable[MultiModalDataDict]] = []
@ -161,7 +162,8 @@ class OpenAIServingChat(OpenAIServing):
         if mm_data is not None:
             engine_inputs["multi_modal_data"] = mm_data

-        is_tracing_enabled = await self.engine.is_tracing_enabled()
+        is_tracing_enabled = (
+            await self.async_engine_client.is_tracing_enabled())
         trace_headers = None
         if is_tracing_enabled and raw_request:
             trace_headers = extract_trace_headers(raw_request.headers)
@ -169,7 +171,7 @@ class OpenAIServingChat(OpenAIServing):
                 and contains_trace_headers(raw_request.headers)):
             log_tracing_disabled_warning()

-        result_generator = self.engine.generate(
+        result_generator = self.async_engine_client.generate(
             engine_inputs,
             sampling_params,
             request_id,
@ -441,7 +443,7 @@ class OpenAIServingChat(OpenAIServing):
         async for res in result_generator:
             if raw_request is not None and await raw_request.is_disconnected():
                 # Abort the request if the client disconnects.
-                await self.engine.abort(request_id)
+                await self.async_engine_client.abort(request_id)
                 return self.create_error_response("Client disconnected")
             final_res = res
         assert final_res is not None
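The serving classes above now depend on an AsyncEngineClient rather than on AsyncLLMEngine directly, so the same code path works whether the engine runs in-process or behind the ZMQ RPC client. The protocol itself lives in vllm/engine/protocol.py and is not included in this excerpt; the sketch below is a hedged reconstruction of the surface these hunks call, and the exact signatures may differ.

# Hedged sketch of the AsyncEngineClient surface used by the OpenAIServing*
# classes above; the real protocol in vllm/engine/protocol.py may declare
# more members and slightly different signatures.
from typing import AsyncIterator, Mapping, Optional, Protocol, runtime_checkable

from transformers import PreTrainedTokenizer

from vllm.config import DecodingConfig, ModelConfig
from vllm.inputs import PromptInputs
from vllm.lora.request import LoRARequest
from vllm.outputs import RequestOutput
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams


@runtime_checkable
class AsyncEngineClient(Protocol):
    """What both AsyncLLMEngine and the ZMQ RPC client must provide."""

    def generate(
        self,
        inputs: PromptInputs,
        sampling_params: SamplingParams,
        request_id: str,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> AsyncIterator[RequestOutput]:
        ...

    async def abort(self, request_id: str) -> None:
        ...

    async def get_model_config(self) -> ModelConfig:
        ...

    async def get_decoding_config(self) -> DecodingConfig:
        ...

    async def get_tokenizer(
        self,
        lora_request: Optional[LoRARequest] = None,
    ) -> PreTrainedTokenizer:  # simplified return type
        ...

    async def is_tracing_enabled(self) -> bool:
        ...

    async def check_health(self) -> None:
        ...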
@ -8,7 +8,7 @@ from fastapi import Request
 from transformers import PreTrainedTokenizer

 from vllm.config import ModelConfig
-from vllm.engine.async_llm_engine import AsyncLLMEngine
+from vllm.engine.protocol import AsyncEngineClient
 from vllm.entrypoints.logger import RequestLogger
 # yapf conflicts with isort for this block
 # yapf: disable
@ -42,7 +42,7 @@ class OpenAIServingCompletion(OpenAIServing):

     def __init__(
         self,
-        engine: AsyncLLMEngine,
+        async_engine_client: AsyncEngineClient,
         model_config: ModelConfig,
         served_model_names: List[str],
         *,
@ -51,7 +51,7 @@ class OpenAIServingCompletion(OpenAIServing):
         request_logger: Optional[RequestLogger],
         return_tokens_as_token_ids: bool = False,
     ):
-        super().__init__(engine=engine,
+        super().__init__(async_engine_client=async_engine_client,
                          model_config=model_config,
                          served_model_names=served_model_names,
                          lora_modules=lora_modules,
@ -91,7 +91,8 @@ class OpenAIServingCompletion(OpenAIServing):
             prompt_adapter_request,
         ) = self._maybe_get_adapters(request)

-        tokenizer = await self.engine.get_tokenizer(lora_request)
+        tokenizer = await self.async_engine_client.get_tokenizer(
+            lora_request)

         guided_decode_logits_processor = (
             await self._guided_decode_logits_processor(request, tokenizer))
@ -119,7 +120,8 @@ class OpenAIServingCompletion(OpenAIServing):
                     lora_request=lora_request,
                     prompt_adapter_request=prompt_adapter_request)

-                is_tracing_enabled = await self.engine.is_tracing_enabled()
+                is_tracing_enabled = (
+                    await self.async_engine_client.is_tracing_enabled())
                 trace_headers = None
                 if is_tracing_enabled:
                     trace_headers = extract_trace_headers(raw_request.headers)
@ -127,7 +129,7 @@ class OpenAIServingCompletion(OpenAIServing):
                                            raw_request.headers):
                     log_tracing_disabled_warning()

-                generator = self.engine.generate(
+                generator = self.async_engine_client.generate(
                     {"prompt_token_ids": prompt_inputs["prompt_token_ids"]},
                     sampling_params,
                     request_id_item,
@ -168,7 +170,7 @@ class OpenAIServingCompletion(OpenAIServing):
         async for i, res in result_generator:
             if await raw_request.is_disconnected():
                 # Abort the request if the client disconnects.
-                await self.engine.abort(f"{request_id}-{i}")
+                await self.async_engine_client.abort(f"{request_id}-{i}")
                 return self.create_error_response("Client disconnected")
             final_res_batch[i] = res

@ -230,7 +232,8 @@ class OpenAIServingCompletion(OpenAIServing):

                 # Abort the request if the client disconnects.
                 if await raw_request.is_disconnected():
-                    await self.engine.abort(f"{request_id}-{prompt_idx}")
+                    await self.async_engine_client.abort(
+                        f"{request_id}-{prompt_idx}")
                     raise StopAsyncIteration()

                 for output in res.outputs:
@ -6,7 +6,7 @@ import numpy as np
 from fastapi import Request

 from vllm.config import ModelConfig
-from vllm.engine.async_llm_engine import AsyncLLMEngine
+from vllm.engine.protocol import AsyncEngineClient
 from vllm.entrypoints.logger import RequestLogger
 from vllm.entrypoints.openai.protocol import (EmbeddingRequest,
                                               EmbeddingResponse,
@ -56,13 +56,13 @@ class OpenAIServingEmbedding(OpenAIServing):

     def __init__(
         self,
-        engine: AsyncLLMEngine,
+        async_engine_client: AsyncEngineClient,
         model_config: ModelConfig,
         served_model_names: List[str],
         *,
         request_logger: Optional[RequestLogger],
     ):
-        super().__init__(engine=engine,
+        super().__init__(async_engine_client=async_engine_client,
                          model_config=model_config,
                          served_model_names=served_model_names,
                          lora_modules=None,
@ -99,7 +99,8 @@ class OpenAIServingEmbedding(OpenAIServing):
             prompt_adapter_request,
         ) = self._maybe_get_adapters(request)

-        tokenizer = await self.engine.get_tokenizer(lora_request)
+        tokenizer = await self.async_engine_client.get_tokenizer(
+            lora_request)

         pooling_params = request.to_pooling_params()

@ -124,7 +125,7 @@ class OpenAIServingEmbedding(OpenAIServing):
                     "Prompt adapter is not supported "
                     "for embedding models")

-            generator = self.engine.encode(
+            generator = self.async_engine_client.encode(
                 {"prompt_token_ids": prompt_inputs["prompt_token_ids"]},
                 pooling_params,
                 request_id_item,
@ -146,7 +147,7 @@ class OpenAIServingEmbedding(OpenAIServing):
         async for i, res in result_generator:
             if await raw_request.is_disconnected():
                 # Abort the request if the client disconnects.
-                await self.engine.abort(f"{request_id}-{i}")
+                await self.async_engine_client.abort(f"{request_id}-{i}")
                 return self.create_error_response("Client disconnected")
             final_res_batch[i] = res

@ -8,7 +8,7 @@ from pydantic import Field
 from typing_extensions import Annotated

 from vllm.config import ModelConfig
-from vllm.engine.async_llm_engine import AsyncLLMEngine
+from vllm.engine.protocol import AsyncEngineClient
 from vllm.entrypoints.logger import RequestLogger
 # yapf conflicts with isort for this block
 # yapf: disable
@ -61,7 +61,7 @@ class OpenAIServing:

     def __init__(
         self,
-        engine: AsyncLLMEngine,
+        async_engine_client: AsyncEngineClient,
         model_config: ModelConfig,
         served_model_names: List[str],
         *,
@ -72,7 +72,7 @@ class OpenAIServing:
     ):
         super().__init__()

-        self.engine = engine
+        self.async_engine_client = async_engine_client
         self.model_config = model_config
         self.max_model_len = model_config.max_model_len

@ -155,7 +155,7 @@ class OpenAIServing:
     async def _guided_decode_logits_processor(
            self, request: Union[ChatCompletionRequest, CompletionRequest],
            tokenizer: AnyTokenizer) -> Optional[LogitsProcessor]:
-        decoding_config = await self.engine.get_decoding_config()
+        decoding_config = await self.async_engine_client.get_decoding_config()
         guided_decoding_backend = request.guided_decoding_backend \
             or decoding_config.guided_decoding_backend
         return await get_guided_decoding_logits_processor(
@ -1,9 +1,9 @@
 from typing import List, Optional, Union

 from vllm.config import ModelConfig
-from vllm.engine.async_llm_engine import AsyncLLMEngine
 # yapf conflicts with isort for this block
 # yapf: disable
+from vllm.engine.protocol import AsyncEngineClient
 from vllm.entrypoints.chat_utils import (ConversationMessage,
                                          load_chat_template,
                                          parse_chat_message_content)
@ -24,7 +24,7 @@ class OpenAIServingTokenization(OpenAIServing):

     def __init__(
         self,
-        engine: AsyncLLMEngine,
+        async_engine_client: AsyncEngineClient,
         model_config: ModelConfig,
         served_model_names: List[str],
         *,
@ -32,7 +32,7 @@ class OpenAIServingTokenization(OpenAIServing):
         request_logger: Optional[RequestLogger],
         chat_template: Optional[str],
     ):
-        super().__init__(engine=engine,
+        super().__init__(async_engine_client=async_engine_client,
                          model_config=model_config,
                          served_model_names=served_model_names,
                          lora_modules=lora_modules,
@ -57,7 +57,7 @@ class OpenAIServingTokenization(OpenAIServing):
             prompt_adapter_request,
         ) = self._maybe_get_adapters(request)

-        tokenizer = await self.engine.get_tokenizer(lora_request)
+        tokenizer = await self.async_engine_client.get_tokenizer(lora_request)

         if isinstance(request, TokenizeChatRequest):
             model_config = self.model_config
@ -113,7 +113,7 @@ class OpenAIServingTokenization(OpenAIServing):
             prompt_adapter_request,
         ) = self._maybe_get_adapters(request)

-        tokenizer = await self.engine.get_tokenizer(lora_request)
+        tokenizer = await self.async_engine_client.get_tokenizer(lora_request)

         self._log_inputs(request_id,
                          request.tokens,
@ -4,6 +4,7 @@ from typing import TYPE_CHECKING, Any, Callable, Dict, Optional
 if TYPE_CHECKING:
     VLLM_HOST_IP: str = ""
     VLLM_PORT: Optional[int] = None
+    VLLM_RPC_PORT: int = 5570
     VLLM_USE_MODELSCOPE: bool = False
     VLLM_RINGBUFFER_WARNING_INTERVAL: int = 60
     VLLM_INSTANCE_ID: Optional[str] = None
@ -140,6 +141,11 @@ environment_variables: Dict[str, Callable[[], Any]] = {
     lambda: int(os.getenv('VLLM_PORT', '0'))
     if 'VLLM_PORT' in os.environ else None,

+    # Used when the frontend api server is running in multi-processing mode,
+    # to communicate with the backend engine process over ZMQ.
+    'VLLM_RPC_PORT':
+    lambda: int(os.getenv('VLLM_RPC_PORT', '5570')),
+
     # If true, will load models from ModelScope instead of Hugging Face Hub.
     # note that the value is true or false, not numbers
     "VLLM_USE_MODELSCOPE":
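The environment_variables table maps names to zero-argument lambdas so values are read when they are accessed, not when vllm.envs is imported. The module-level __getattr__ hook below is an assumption about how an access like envs.VLLM_RPC_PORT resolves at runtime; the real vllm/envs.py is not included in this excerpt and may wire this up differently.

# Hedged sketch of the lazy env-var lookup pattern used by vllm/envs.py.
# The __getattr__ hook (PEP 562) is an assumption about the real module.
import os
from typing import Any, Callable, Dict

environment_variables: Dict[str, Callable[[], Any]] = {
    # Evaluated on access, so the env var only needs to be set before use.
    "VLLM_RPC_PORT": lambda: int(os.getenv("VLLM_RPC_PORT", "5570")),
}


def __getattr__(name: str) -> Any:
    # Module-level __getattr__: evaluates the registered lambda on access.
    if name in environment_variables:
        return environment_variables[name]()
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")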
@ -21,6 +21,8 @@ from functools import lru_cache
 from typing import Callable, DefaultDict, Dict, List, Union

 import torch
+from lark import Lark
+from outlines import grammars
 from outlines.caching import cache
 from outlines.fsm.guide import CFGGuide, Generate, Guide, RegexGuide, Write
 from outlines.fsm.json_schema import build_regex_from_schema
@ -44,6 +46,23 @@ class BaseLogitsProcessor:
             last_seq_id = hash(tuple(input_ids[:-1]))
             self._fsm_state[seq_id] = self._guide.get_next_state(
                 state=self._fsm_state[last_seq_id], token_id=last_token)
         else:
+            # Note: this is a hack.
+            # Lark pickling does not work properly (silent failure),
+            # which breaks the RPC (which uses Python pickling).
+            # We need to find a better solution.
+            # The first time this is called, we simply re-create
+            # the Lark object.
+            if isinstance(self._guide, CFGGuide):
+                self._guide.parser = Lark(
+                    self._guide.cfg_string,
+                    parser="lalr",
+                    lexer="contextual",
+                    propagate_positions=False,
+                    maybe_placeholders=False,
+                    regex=True,
+                    import_paths=[grammars.GRAMMAR_PATH],
+                )
+
         instruction = self._guide.get_next_instruction(
             state=self._fsm_state[seq_id])
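The block above works around the fact that the CFGGuide's embedded Lark parser does not survive pickling, while the RPC layer pickles logits processors to move them between processes. A more conventional fix is to exclude the unpicklable member from the pickled state and rebuild it lazily; the illustration below uses made-up names and is not the outlines or vLLM API.

# Generic illustration of the usual fix for an unpicklable member: drop it
# from the pickled state and rebuild it on first use. Names are hypothetical.
import pickle


class GrammarGuide:

    def __init__(self, grammar: str):
        self.grammar = grammar
        self._parser = self._build_parser(grammar)  # pretend this is unpicklable

    @staticmethod
    def _build_parser(grammar: str):
        # Stand-in for an expensive, unpicklable parser object.
        return {"compiled": grammar.upper()}

    def __getstate__(self):
        state = self.__dict__.copy()
        state["_parser"] = None  # exclude the unpicklable piece
        return state

    def __setstate__(self, state):
        self.__dict__.update(state)

    @property
    def parser(self):
        if self._parser is None:  # rebuild lazily after unpickling
            self._parser = self._build_parser(self.grammar)
        return self._parser


guide = pickle.loads(pickle.dumps(GrammarGuide("expr: NUMBER")))
assert guide.parser["compiled"] == "EXPR: NUMBER"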
@ -60,7 +60,7 @@ def get_span_exporter(endpoint):
             OTLPSpanExporter)
     elif protocol == "http/protobuf":
         from opentelemetry.exporter.otlp.proto.http.trace_exporter import (
-            OTLPSpanExporter)
+            OTLPSpanExporter)  # type: ignore
     else:
         raise ValueError(
             f"Unsupported OTLP protocol '{protocol}' is configured")
@ -1,6 +1,7 @@
 from typing import Optional, Type

-from vllm.config import TokenizerPoolConfig
+from vllm.config import (ModelConfig, ParallelConfig, SchedulerConfig,
+                         TokenizerPoolConfig)
 from vllm.executor.ray_utils import ray

 from .base_tokenizer_group import AnyTokenizer, BaseTokenizerGroup
@ -13,6 +14,22 @@ else:
     RayTokenizerGroupPool = None  # type: ignore


+def init_tokenizer_from_configs(model_config: ModelConfig,
+                                scheduler_config: SchedulerConfig,
+                                parallel_config: ParallelConfig,
+                                enable_lora: bool):
+    init_kwargs = dict(tokenizer_id=model_config.tokenizer,
+                       enable_lora=enable_lora,
+                       max_num_seqs=scheduler_config.max_num_seqs,
+                       max_input_length=None,
+                       tokenizer_mode=model_config.tokenizer_mode,
+                       trust_remote_code=model_config.trust_remote_code,
+                       revision=model_config.tokenizer_revision)
+
+    return get_tokenizer_group(parallel_config.tokenizer_pool_config,
+                               **init_kwargs)
+
+
 def get_tokenizer_group(tokenizer_pool_config: Optional[TokenizerPoolConfig],
                         **init_kwargs) -> BaseTokenizerGroup:
     tokenizer_cls: Type[BaseTokenizerGroup]
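init_tokenizer_from_configs lets a process that has no in-process engine, such as the RPC client, build the same tokenizer group from configs alone. The usage sketch below is hypothetical: the rpc_client object and its getter calls are assumptions standing in for client code that is not shown in this diff, and it assumes the helper is exported from vllm.transformers_utils.tokenizer_group.

# Hedged usage sketch: a frontend process fetching configs over RPC and then
# building a local tokenizer group. The rpc_client object is hypothetical.
from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs


async def build_client_tokenizer(rpc_client):
    model_config = await rpc_client.get_model_config()
    scheduler_config = await rpc_client.get_scheduler_config()
    parallel_config = await rpc_client.get_parallel_config()
    lora_config = await rpc_client.get_lora_config()

    return init_tokenizer_from_configs(
        model_config=model_config,
        scheduler_config=scheduler_config,
        parallel_config=parallel_config,
        enable_lora=bool(lora_config),
    )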
@ -290,6 +290,10 @@ def make_async(func: Callable[P, T]) -> Callable[P, Awaitable[T]]:
     return _async_wrapper


+class ProducerFinished:
+    pass
+
+
 def merge_async_iterators(
         *iterators: AsyncIterator[T]) -> AsyncIterator[Tuple[int, T]]:
     """Merge multiple asynchronous iterators into a single iterator.
@ -298,9 +302,10 @@ def merge_async_iterators(
     When it yields, it yields a tuple (i, item) where i is the index of the
     iterator that yields the item.
     """
-    queue: asyncio.Queue[Union[Tuple[int, T], Exception]] = asyncio.Queue()
+    queue: asyncio.Queue[Union[Tuple[int, T], ProducerFinished,
+                               Exception]] = asyncio.Queue()

-    finished = [False] * len(iterators)
+    producers = len(iterators)

     async def producer(i: int, iterator: AsyncIterator[T]):
         try:
@ -308,7 +313,8 @@ def merge_async_iterators(
                 await queue.put((i, item))
         except Exception as e:
             await queue.put(e)
-        finished[i] = True
+        # Signal to the consumer that we've finished.
+        await queue.put(ProducerFinished())

     _tasks = [
         asyncio.create_task(producer(i, iterator))
@ -316,9 +322,17 @@ def merge_async_iterators(
     ]

     async def consumer():
+        remaining = producers
         try:
-            while not all(finished) or not queue.empty():
+            while remaining or not queue.empty():
+                # TODO: we think there may still be a race condition here.
                 item = await queue.get()
+
+                if isinstance(item, ProducerFinished):
+                    # A producer finished; this is not a real item.
+                    remaining -= 1
+                    continue
+
                 if isinstance(item, Exception):
                     raise item
                 yield item
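The ProducerFinished sentinel replaces the old finished flag list: each producer enqueues a sentinel when it ends, so the consumer counts completions from the same queue it drains instead of racing a separate flag check against late puts. Below is a small, self-contained demo of the same sentinel pattern; it is illustrative only and not the vLLM implementation (the real one is vllm.utils.merge_async_iterators).

# Self-contained demo of the sentinel-based merge pattern used above.
import asyncio
from typing import AsyncIterator, Tuple


class ProducerFinished:
    pass


async def merge(*iterators: AsyncIterator[int]) -> AsyncIterator[Tuple[int, int]]:
    queue: asyncio.Queue = asyncio.Queue()
    remaining = len(iterators)

    async def producer(i: int, it: AsyncIterator[int]):
        try:
            async for item in it:
                await queue.put((i, item))
        except Exception as e:
            await queue.put(e)
        # Always tell the consumer this producer is done.
        await queue.put(ProducerFinished())

    tasks = [asyncio.create_task(producer(i, it)) for i, it in enumerate(iterators)]

    while remaining or not queue.empty():
        item = await queue.get()
        if isinstance(item, ProducerFinished):
            remaining -= 1
            continue
        if isinstance(item, Exception):
            raise item
        yield item

    await asyncio.gather(*tasks)


async def count(n: int) -> AsyncIterator[int]:
    for x in range(n):
        await asyncio.sleep(0)
        yield x


async def main():
    async for idx, value in merge(count(2), count(3)):
        print(f"iterator {idx} -> {value}")


asyncio.run(main())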
@ -374,8 +388,10 @@ def get_distributed_init_method(ip: str, port: int) -> str:
     return f"tcp://[{ip}]:{port}" if ":" in ip else f"tcp://{ip}:{port}"


-def get_open_port() -> int:
-    port = envs.VLLM_PORT
+def get_open_port(port: Optional[int] = None) -> int:
+    if port is None:
+        # Default behavior is to return a port for multi-gpu communication.
+        port = envs.VLLM_PORT
     if port is not None:
         while True:
             try: