Support chat template and echo for chat API (#1756)

commit 66785cc05c
parent 05a38612b0
Author: Adam Brusselback
Date:   2023-11-30 19:43:13 -05:00
Committed by: GitHub
7 changed files with 440 additions and 181 deletions

View File

@ -107,6 +107,7 @@ OpenAI-Compatible Server
------------------------

vLLM can be deployed as a server that mimics the OpenAI API protocol. This allows vLLM to be used as a drop-in replacement for applications using the OpenAI API.
By default, it starts the server at ``http://localhost:8000``. You can specify the address with ``--host`` and ``--port`` arguments. The server currently hosts one model at a time (OPT-125M in the above command) and implements `list models <https://platform.openai.com/docs/api-reference/models/list>`_, `create chat completion <https://platform.openai.com/docs/api-reference/chat/completions/create>`_, and `create completion <https://platform.openai.com/docs/api-reference/completions/create>`_ endpoints. We are actively adding support for more endpoints.
Start the server:
@ -122,7 +123,13 @@ Use model from www.modelscope.cn
$ VLLM_USE_MODELSCOPE=True python -m vllm.entrypoints.openai.api_server \
$ --model="qwen/Qwen-7B-Chat" --revision="v1.1.8" --trust-remote-code

By default, the server uses a predefined chat template stored in the tokenizer. You can override this template by using the ``--chat-template`` argument:
.. code-block:: console
$ python -m vllm.entrypoints.openai.api_server \
$ --model facebook/opt-125m \
$ --chat-template ./examples/template_chatml.jinja
This server can be queried in the same format as OpenAI API. For example, list the models:
@ -130,6 +137,9 @@ This server can be queried in the same format as OpenAI API. For example, list t
$ curl http://localhost:8000/v1/models
Using OpenAI Completions API with vLLM
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Query the model with input prompts:

.. code-block:: console
@ -156,3 +166,45 @@ Since this server is compatible with OpenAI API, you can use it as a drop-in rep
print("Completion result:", completion)

For a more detailed client example, refer to `examples/openai_completion_client.py <https://github.com/vllm-project/vllm/blob/main/examples/openai_completion_client.py>`_.
Using OpenAI Chat API with vLLM
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
The vLLM server is designed to support the OpenAI Chat API, allowing you to engage in dynamic conversations with the model. The chat interface is a more interactive way to communicate with the model, allowing back-and-forth exchanges that can be stored in the chat history. This is useful for tasks that require context or more detailed explanations.
Querying the model using OpenAI Chat API:
You can use the `create chat completion <https://platform.openai.com/docs/api-reference/chat/completions/create>`_ endpoint to communicate with the model in a chat-like interface:
.. code-block:: console
$ curl http://localhost:8000/v1/chat/completions \
$ -H "Content-Type: application/json" \
$ -d '{
$ "model": "facebook/opt-125m",
$ "messages": [
$ {"role": "system", "content": "You are a helpful assistant."},
$ {"role": "user", "content": "Who won the world series in 2020?"}
$ ]
$ }'
Python Client Example:
Using the `openai` python package, you can also communicate with the model in a chat-like manner:
.. code-block:: python
import openai
# Set OpenAI's API key and API base to use vLLM's API server.
openai.api_key = "EMPTY"
openai.api_base = "http://localhost:8000/v1"
chat_response = openai.ChatCompletion.create(
model="facebook/opt-125m",
messages=[
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "Tell me a joke."},
]
)
print("Chat response:", chat_response)
For more in-depth examples and advanced features of the chat API, you can refer to the official OpenAI documentation.
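This commit also adds two vLLM-specific fields to the chat request schema, ``add_generation_prompt`` and ``echo`` (see the request-model change at the end of this diff). As an illustrative sketch only, and assuming the server was started with a chat template as shown above, a raw HTTP request can combine them to continue a partially written assistant message and have the echoed prefix returned together with the continuation:

.. code-block:: python

    import requests

    # Sketch: "echo" and "add_generation_prompt" are vLLM-specific request
    # fields added by this commit, not standard OpenAI chat parameters.
    response = requests.post(
        "http://localhost:8000/v1/chat/completions",
        json={
            "model": "facebook/opt-125m",
            "messages": [
                {"role": "user", "content": "Tell me a joke."},
                {"role": "assistant", "content": "Sure, here is one:"},
            ],
            # Do not open a fresh assistant turn; continue the last message.
            "add_generation_prompt": False,
            # Prepend the last message's content to the generated continuation.
            "echo": True,
        },
    )
    print(response.json())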

View File

@ -0,0 +1,29 @@
{{ (messages|selectattr('role', 'equalto', 'system')|list|last).content|trim if (messages|selectattr('role', 'equalto', 'system')|list) else '' }}
{% for message in messages %}
{% if message['role'] == 'user' %}
### Instruction:
{{ message['content']|trim -}}
{% if not loop.last %}
{% endif %}
{% elif message['role'] == 'assistant' %}
### Response:
{{ message['content']|trim -}}
{% if not loop.last %}
{% endif %}
{% elif message['role'] == 'user_context' %}
### Input:
{{ message['content']|trim -}}
{% if not loop.last %}
{% endif %}
{% endif %}
{% endfor %}
{% if add_generation_prompt and messages[-1]['role'] != 'assistant' %}
### Response:
{% endif %}

View File

@ -0,0 +1,2 @@
{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content']}}{% if (loop.last and add_generation_prompt) or not loop.last %}{{ '<|im_end|>' + '\n'}}{% endif %}{% endfor %}
{% if add_generation_prompt and messages[-1]['role'] != 'assistant' %}{{ '<|im_start|>assistant\n' }}{% endif %}
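As a rough, self-contained sketch of what this ChatML-style template produces, the snippet below renders it directly with jinja2 (rather than through the tokenizer) using a made-up message list; with add_generation_prompt=True the output ends with an opened assistant turn, which is the shape the test file further down expects:

    import jinja2

    # The template above as one Python string; "\n" stays escaped exactly as
    # in the file so Jinja interprets it as a newline inside its string literals.
    chatml_template = (
        "{% for message in messages %}"
        "{{'<|im_start|>' + message['role'] + '\\n' + message['content']}}"
        "{% if (loop.last and add_generation_prompt) or not loop.last %}"
        "{{ '<|im_end|>' + '\\n'}}{% endif %}"
        "{% endfor %}"
        "{% if add_generation_prompt and messages[-1]['role'] != 'assistant' %}"
        "{{ '<|im_start|>assistant\\n' }}{% endif %}")

    messages = [
        {"role": "user", "content": "Hello"},
        {"role": "assistant", "content": "Hi there!"},
        {"role": "user", "content": "What is the capital of"},
    ]

    print(jinja2.Template(chatml_template).render(
        messages=messages, add_generation_prompt=True))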

View File

@ -0,0 +1,30 @@
<#meta#>
- Date: {{ (messages|selectattr('role', 'equalto', 'meta-current_date')|list|last).content|trim if (messages|selectattr('role', 'equalto', 'meta-current_date')|list) else '' }}
- Task: {{ (messages|selectattr('role', 'equalto', 'meta-task_name')|list|last).content|trim if (messages|selectattr('role', 'equalto', 'meta-task_name')|list) else '' }}
<#system#>
{{ (messages|selectattr('role', 'equalto', 'system')|list|last).content|trim if (messages|selectattr('role', 'equalto', 'system')|list) else '' }}
<#chat#>
{% for message in messages %}
{% if message['role'] == 'user' %}
<#user#>
{{ message['content']|trim -}}
{% if not loop.last %}
{% endif %}
{% elif message['role'] == 'assistant' %}
<#bot#>
{{ message['content']|trim -}}
{% if not loop.last %}
{% endif %}
{% elif message['role'] == 'user_context' %}
<#user_context#>
{{ message['content']|trim -}}
{% if not loop.last %}
{% endif %}
{% endif %}
{% endfor %}
{% if add_generation_prompt and messages[-1]['role'] != 'assistant' %}
<#bot#>
{% endif %}
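The "<#meta#>" header in this template pulls its values out of the message list with selectattr filters rather than from dedicated arguments. A small self-contained sketch of that lookup pattern, with invented message contents:

    import jinja2

    # The same selectattr lookup the template's "- Date:" line uses.
    date_line = jinja2.Template(
        "- Date: {{ (messages|selectattr('role', 'equalto', 'meta-current_date')"
        "|list|last).content|trim"
        " if (messages|selectattr('role', 'equalto', 'meta-current_date')|list)"
        " else '' }}")

    messages = [
        {"role": "meta-current_date", "content": " 2023-11-30 "},
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Hello"},
    ]

    print(date_line.render(messages=messages))  # -> "- Date: 2023-11-30"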

View File

@ -0,0 +1,119 @@
from argparse import Namespace
from dataclasses import dataclass
import pytest
from fastapi.testclient import TestClient
from vllm.entrypoints.openai.api_server import *
# Define models, templates, and their corresponding expected outputs
MODEL_TEMPLATE_GENERATION_OUTPUT = [
("facebook/opt-125m", None, True,
"Hello</s>Hi there!</s>What is the capital of</s>"),
("facebook/opt-125m", None, False,
"Hello</s>Hi there!</s>What is the capital of</s>"),
("facebook/opt-125m", "../../examples/template_chatml.jinja", True,
"""<|im_start|>user
Hello<|im_end|>
<|im_start|>assistant
Hi there!<|im_end|>
<|im_start|>user
What is the capital of<|im_end|>
<|im_start|>assistant
"""),
("facebook/opt-125m", "../../examples/template_chatml.jinja", False,
"""<|im_start|>user
Hello<|im_end|>
<|im_start|>assistant
Hi there!<|im_end|>
<|im_start|>user
What is the capital of""")
]
TEST_MESSAGES = [
{
'role': 'user',
'content': 'Hello'
},
{
'role': 'assistant',
'content': 'Hi there!'
},
{
'role': 'user',
'content': 'What is the capital of'
},
]
client = TestClient(app)
@dataclass
class MockTokenizer:
chat_template = None
def test_load_chat_template():
# Testing chatml template
template = "../../examples/template_chatml.jinja"
mock_args = Namespace(chat_template=template)
tokenizer = MockTokenizer()
# Call the function with the mocked args
load_chat_template(mock_args, tokenizer)
template_content = tokenizer.chat_template
# Test assertions
assert template_content is not None
# Hard coded value for template_chatml.jinja
assert template_content == """{% for message in messages %}{{'<|im_start|>' + message['role'] + '\\n' + message['content']}}{% if (loop.last and add_generation_prompt) or not loop.last %}{{ '<|im_end|>' + '\\n'}}{% endif %}{% endfor %}
{% if add_generation_prompt and messages[-1]['role'] != 'assistant' %}{{ '<|im_start|>assistant\\n' }}{% endif %}"""
def test_no_load_chat_template():
# Testing chatml template
template = "../../examples/does_not_exist"
mock_args = Namespace(chat_template=template)
tokenizer = MockTokenizer()
# Call the function with the mocked args
load_chat_template(mock_args, tokenizer=tokenizer)
template_content = tokenizer.chat_template
# Test assertions
assert template_content is not None
# The nonexistent path is stored verbatim as the template string
assert template_content == """../../examples/does_not_exist"""
@pytest.mark.asyncio
@pytest.mark.parametrize(
"model,template,add_generation_prompt,expected_output",
MODEL_TEMPLATE_GENERATION_OUTPUT)
async def test_get_gen_prompt(model, template, add_generation_prompt,
expected_output):
# Initialize the tokenizer
tokenizer = get_tokenizer(tokenizer_name=model)
mock_args = Namespace(chat_template=template)
load_chat_template(mock_args, tokenizer)
# Create a mock request object using keyword arguments
mock_request = ChatCompletionRequest(
model=model,
messages=TEST_MESSAGES,
add_generation_prompt=add_generation_prompt)
# Call the function and get the result
result = tokenizer.apply_chat_template(
conversation=mock_request.messages,
tokenize=False,
add_generation_prompt=mock_request.add_generation_prompt)
# Test assertion
assert result == expected_output, f"The generated prompt does not match the expected output for model {model} and template {template}"
def test_health_endpoint():
response = client.get("/health")
assert response.status_code == 200

View File

@ -3,6 +3,7 @@
import argparse
import asyncio
import codecs
import json
import time
from http import HTTPStatus
@ -14,7 +15,6 @@ from fastapi import Request
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, StreamingResponse, Response
from packaging import version
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
@ -31,20 +31,55 @@ from vllm.sampling_params import SamplingParams
from vllm.transformers_utils.tokenizer import get_tokenizer
from vllm.utils import random_uuid
try:
import fastchat
from fastchat.conversation import Conversation, SeparatorStyle
from fastchat.model.model_adapter import get_conversation_template
_fastchat_available = True
except ImportError:
_fastchat_available = False
TIMEOUT_KEEP_ALIVE = 5  # seconds

logger = init_logger(__name__)
served_model = None
app = fastapi.FastAPI()
engine = None
response_role = None
def parse_args():
parser = argparse.ArgumentParser(
description="vLLM OpenAI-Compatible RESTful API server.")
parser.add_argument("--host", type=str, default=None, help="host name")
parser.add_argument("--port", type=int, default=8000, help="port number")
parser.add_argument("--allow-credentials",
action="store_true",
help="allow credentials")
parser.add_argument("--allowed-origins",
type=json.loads,
default=["*"],
help="allowed origins")
parser.add_argument("--allowed-methods",
type=json.loads,
default=["*"],
help="allowed methods")
parser.add_argument("--allowed-headers",
type=json.loads,
default=["*"],
help="allowed headers")
parser.add_argument("--served-model-name",
type=str,
default=None,
help="The model name used in the API. If not "
"specified, the model name will be the same as "
"the huggingface name.")
parser.add_argument("--chat-template",
type=str,
default=None,
help="The file path to the chat template, "
"or the template in single-line form "
"for the specified model")
parser.add_argument("--response-role",
type=str,
default="assistant",
help="The role name to return if "
"`request.add_generation_prompt=true`.")
parser = AsyncEngineArgs.add_cli_args(parser)
return parser.parse_args()
def create_error_response(status_code: HTTPStatus,
@ -54,6 +89,25 @@ def create_error_response(status_code: HTTPStatus,
                        status_code=status_code.value)
def load_chat_template(args, tokenizer):
if args.chat_template is not None:
try:
with open(args.chat_template, "r") as f:
chat_template = f.read()
except OSError:
            # If opening the file fails, treat the argument itself as the
            # inline template and decode it so escape sequences are
            # interpreted correctly.
chat_template = codecs.decode(args.chat_template, "unicode_escape")
tokenizer.chat_template = chat_template
logger.info(
f"Using supplied chat template:\n{tokenizer.chat_template}")
elif tokenizer.chat_template is not None:
logger.info(f"Using default chat template:\n{tokenizer.chat_template}")
else:
logger.warning("No chat template provided. Chat API will not work.")
@app.exception_handler(RequestValidationError)
async def validation_exception_handler(_, exc):
    return create_error_response(HTTPStatus.BAD_REQUEST, str(exc))
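To illustrate the fallback branch of load_chat_template above: when the --chat-template value is not a readable file path, the string itself is used as the template, and the "unicode_escape" decode turns literal backslash-n sequences typed on the command line into real newlines. A minimal, self-contained sketch with a made-up inline template:

    import codecs

    # An inline template as it might arrive from the command line, where "\n"
    # is two characters (a backslash followed by "n").
    inline_template = (
        "{% for m in messages %}{{ m['role'] }}: {{ m['content'] }}\\n{% endfor %}")

    decoded = codecs.decode(inline_template, "unicode_escape")
    print(decoded.count("\n"))  # the escaped sequence is now a real newline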
@ -69,53 +123,6 @@ async def check_model(request) -> Optional[JSONResponse]:
    return ret
async def get_gen_prompt(request) -> str:
if not _fastchat_available:
raise ModuleNotFoundError(
"fastchat is not installed. Please install fastchat to use "
"the chat completion and conversation APIs: `$ pip install fschat`"
)
if version.parse(fastchat.__version__) < version.parse("0.2.23"):
raise ImportError(
f"fastchat version is low. Current version: {fastchat.__version__} "
"Please upgrade fastchat to use: `$ pip install -U fschat`")
conv = get_conversation_template(request.model)
conv = Conversation(
name=conv.name,
system_template=conv.system_template,
system_message=conv.system_message,
roles=conv.roles,
messages=list(conv.messages), # prevent in-place modification
offset=conv.offset,
sep_style=SeparatorStyle(conv.sep_style),
sep=conv.sep,
sep2=conv.sep2,
stop_str=conv.stop_str,
stop_token_ids=conv.stop_token_ids,
)
if isinstance(request.messages, str):
prompt = request.messages
else:
for message in request.messages:
msg_role = message["role"]
if msg_role == "system":
conv.system_message = message["content"]
elif msg_role == "user":
conv.append_message(conv.roles[0], message["content"])
elif msg_role == "assistant":
conv.append_message(conv.roles[1], message["content"])
else:
raise ValueError(f"Unknown role: {msg_role}")
# Add a blank message for the assistant.
conv.append_message(conv.roles[1], None)
prompt = conv.get_prompt()
return prompt
async def check_length(
        request: Union[ChatCompletionRequest, CompletionRequest],
        prompt: Optional[str] = None,
@ -207,7 +214,6 @@ async def create_chat_completion(request: ChatCompletionRequest,
    - function_call (Users should implement this by themselves)
    - logit_bias (to be supported by vLLM engine)
    """
    error_check_ret = await check_model(request)
    if error_check_ret is not None:
        return error_check_ret
@ -217,7 +223,15 @@ async def create_chat_completion(request: ChatCompletionRequest,
        return create_error_response(HTTPStatus.BAD_REQUEST,
                                     "logit_bias is not currently supported")

    try:
        prompt = tokenizer.apply_chat_template(
            conversation=request.messages,
            tokenize=False,
            add_generation_prompt=request.add_generation_prompt)
    except Exception as e:
        logger.error(f"Error in applying chat template from request: {str(e)}")
        return create_error_response(HTTPStatus.BAD_REQUEST, str(e))
    token_ids, error_check_ret = await check_length(request, prompt=prompt)
    if error_check_ret is not None:
        return error_check_ret
@ -225,6 +239,7 @@ async def create_chat_completion(request: ChatCompletionRequest,
    model_name = request.model
    request_id = f"cmpl-{random_uuid()}"
    created_time = int(time.monotonic())
chunk_object_type = "chat.completion.chunk"
    try:
        spaces_between_special_tokens = request.spaces_between_special_tokens
        sampling_params = SamplingParams(
@ -249,81 +264,106 @@ async def create_chat_completion(request: ChatCompletionRequest,
    result_generator = engine.generate(prompt, sampling_params, request_id,
                                       token_ids)

    def get_role() -> str:
        if request.add_generation_prompt:
            return response_role
        else:
            return request.messages[-1]["role"]

    async def completion_stream_generator() -> AsyncGenerator[str, None]:
        # Send first response for each request.n (index) with the role
        role = get_role()
        for i in range(request.n):
            choice_data = ChatCompletionResponseStreamChoice(
                index=i, delta=DeltaMessage(role=role), finish_reason=None)
            chunk = ChatCompletionStreamResponse(id=request_id,
                                                 object=chunk_object_type,
                                                 created=created_time,
                                                 choices=[choice_data],
                                                 model=model_name)
            data = chunk.json(exclude_unset=True, ensure_ascii=False)
            yield f"data: {data}\n\n"

        # Send response to echo the input portion of the last message
        if request.echo:
            last_msg_content = ""
            if request.messages and isinstance(
                    request.messages, list) and request.messages[-1].get(
                        "content") and request.messages[-1].get(
                            "role") == role:
                last_msg_content = request.messages[-1]["content"]
            if last_msg_content:
                for i in range(request.n):
                    choice_data = ChatCompletionResponseStreamChoice(
                        index=i,
                        delta=DeltaMessage(content=last_msg_content),
                        finish_reason=None)
                    chunk = ChatCompletionStreamResponse(
                        id=request_id,
                        object=chunk_object_type,
                        created=created_time,
                        choices=[choice_data],
                        model=model_name)
                    data = chunk.json(exclude_unset=True, ensure_ascii=False)
                    yield f"data: {data}\n\n"

        # Send response for each token for each request.n (index)
        previous_texts = [""] * request.n
        previous_num_tokens = [0] * request.n
        finish_reason_sent = [False] * request.n
        async for res in result_generator:
            res: RequestOutput
            for output in res.outputs:
                i = output.index

                if finish_reason_sent[i]:
                    continue

                if output.finish_reason is None:
                    # Send token-by-token response for each request.n
                    delta_text = output.text[len(previous_texts[i]):]
                    previous_texts[i] = output.text
                    completion_tokens = len(output.token_ids)
                    previous_num_tokens[i] = completion_tokens
                    choice_data = ChatCompletionResponseStreamChoice(
                        index=i,
                        delta=DeltaMessage(content=delta_text),
                        finish_reason=None)
                    chunk = ChatCompletionStreamResponse(
                        id=request_id,
                        object=chunk_object_type,
                        created=created_time,
                        choices=[choice_data],
                        model=model_name)
                    data = chunk.json(exclude_unset=True, ensure_ascii=False)
                    yield f"data: {data}\n\n"
                else:
                    # Send the finish response for each request.n only once
                    prompt_tokens = len(res.prompt_token_ids)
                    final_usage = UsageInfo(
                        prompt_tokens=prompt_tokens,
                        completion_tokens=completion_tokens,
                        total_tokens=prompt_tokens + completion_tokens,
                    )
                    choice_data = ChatCompletionResponseStreamChoice(
                        index=i, delta=[], finish_reason=output.finish_reason)
                    chunk = ChatCompletionStreamResponse(
                        id=request_id,
                        object=chunk_object_type,
                        created=created_time,
                        choices=[choice_data],
                        model=model_name)
                    if final_usage is not None:
                        chunk.usage = final_usage
                    data = chunk.json(exclude_unset=True,
                                      exclude_none=True,
                                      ensure_ascii=False)
                    yield f"data: {data}\n\n"
                    finish_reason_sent[i] = True
        # Send the final done message after all response.n are finished
        yield "data: [DONE]\n\n"

    async def completion_full_generator():
        final_res: RequestOutput = None
        async for res in result_generator:
            if await raw_request.is_disconnected():
@ -333,15 +373,29 @@ async def create_chat_completion(request: ChatCompletionRequest,
                                             "Client disconnected")
            final_res = res
        assert final_res is not None

        choices = []
        role = get_role()
        for output in final_res.outputs:
            choice_data = ChatCompletionResponseChoice(
                index=output.index,
                message=ChatMessage(role=role, content=output.text),
                finish_reason=output.finish_reason,
            )
            choices.append(choice_data)

        if request.echo:
            last_msg_content = ""
            if request.messages and isinstance(
                    request.messages, list) and request.messages[-1].get(
                        "content") and request.messages[-1].get(
                            "role") == role:
                last_msg_content = request.messages[-1]["content"]

            for choice in choices:
                full_message = last_msg_content + choice.message.content
                choice.message.content = full_message

        num_prompt_tokens = len(final_res.prompt_token_ids)
        num_generated_tokens = sum(
            len(output.token_ids) for output in final_res.outputs)
@ -358,20 +412,15 @@ async def create_chat_completion(request: ChatCompletionRequest,
            usage=usage,
        )
if request.stream:
# When user requests streaming but we don't stream, we still need to
# return a streaming response with a single event.
response_json = response.json(ensure_ascii=False)
async def fake_stream_generator() -> AsyncGenerator[str, None]:
yield f"data: {response_json}\n\n"
yield "data: [DONE]\n\n"
return StreamingResponse(fake_stream_generator(),
media_type="text/event-stream")
        return response
# Streaming response
if request.stream:
return StreamingResponse(completion_stream_generator(),
media_type="text/event-stream")
else:
return await completion_full_generator()
@app.post("/v1/completions")
async def create_completion(request: CompletionRequest, raw_request: Request):
@ -642,34 +691,7 @@ async def create_completion(request: CompletionRequest, raw_request: Request):
if __name__ == "__main__":
    args = parse_args()
    app.add_middleware(
        CORSMiddleware,
@ -686,6 +708,8 @@ if __name__ == "__main__":
    else:
        served_model = args.model
response_role = args.response_role
    engine_args = AsyncEngineArgs.from_cli_args(args)
    engine = AsyncLLMEngine.from_engine_args(engine_args)
    engine_model_config = asyncio.run(engine.get_model_config())
@ -696,6 +720,7 @@ if __name__ == "__main__":
        engine_model_config.tokenizer,
        tokenizer_mode=engine_model_config.tokenizer_mode,
        trust_remote_code=engine_model_config.trust_remote_code)
load_chat_template(args, tokenizer)
    uvicorn.run(app,
                host=args.host,

View File

@ -73,6 +73,8 @@ class ChatCompletionRequest(BaseModel):
    stop_token_ids: Optional[List[int]] = Field(default_factory=list)
    skip_special_tokens: Optional[bool] = True
    spaces_between_special_tokens: Optional[bool] = True
add_generation_prompt: Optional[bool] = True
echo: Optional[bool] = False
class CompletionRequest(BaseModel):
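A hedged sketch of the two new request fields in use, constructing the request model directly (assuming the usual vllm.entrypoints.openai.protocol module path for the file this hunk modifies):

    from vllm.entrypoints.openai.protocol import ChatCompletionRequest

    # add_generation_prompt=False keeps the chat template from opening a new
    # assistant turn; echo=True asks the server to prepend the last message's
    # content to the generated text.
    request = ChatCompletionRequest(
        model="facebook/opt-125m",
        messages=[
            {"role": "user", "content": "Tell me a joke."},
            {"role": "assistant", "content": "Sure, here is one:"},
        ],
        add_generation_prompt=False,
        echo=True,
    )
    print(request.add_generation_prompt, request.echo)  # False True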