Support chat template and echo for chat API (#1756)
@@ -107,6 +107,7 @@ OpenAI-Compatible Server
 
 vLLM can be deployed as a server that mimics the OpenAI API protocol. This allows vLLM to be used as a drop-in replacement for applications using OpenAI API.
 
+By default, it starts the server at ``http://localhost:8000``. You can specify the address with ``--host`` and ``--port`` arguments. The server currently hosts one model at a time (OPT-125M in the above command) and implements `list models <https://platform.openai.com/docs/api-reference/models/list>`_, `create chat completion <https://platform.openai.com/docs/api-reference/chat/completions/create>`_, and `create completion <https://platform.openai.com/docs/api-reference/completions/create>`_ endpoints. We are actively adding support for more endpoints.
 
 Start the server:
 
@@ -122,7 +123,13 @@ Use model from www.modelscope.cn
     $ VLLM_USE_MODELSCOPE=True python -m vllm.entrypoints.openai.api_server \
     $ --model="qwen/Qwen-7B-Chat" --revision="v1.1.8" --trust-remote-code
 
-By default, it starts the server at ``http://localhost:8000``. You can specify the address with ``--host`` and ``--port`` arguments. The server currently hosts one model at a time (OPT-125M in the above command) and implements `list models <https://platform.openai.com/docs/api-reference/models/list>`_ and `create completion <https://platform.openai.com/docs/api-reference/completions/create>`_ endpoints. We are actively adding support for more endpoints.
+By default, the server uses a predefined chat template stored in the tokenizer. You can override this template by using the ``--chat-template`` argument:
+
+.. code-block:: console
+
+    $ python -m vllm.entrypoints.openai.api_server \
+    $ --model facebook/opt-125m \
+    $ --chat-template ./examples/template_chatml.jinja
 
 This server can be queried in the same format as OpenAI API. For example, list the models:
@@ -130,6 +137,9 @@ This server can be queried in the same format as OpenAI API. For example, list the models:
 
     $ curl http://localhost:8000/v1/models
 
+Using OpenAI Completions API with vLLM
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
 Query the model with input prompts:
 
 .. code-block:: console
@@ -156,3 +166,45 @@ Since this server is compatible with OpenAI API, you can use it as a drop-in replacement
     print("Completion result:", completion)
 
 For a more detailed client example, refer to `examples/openai_completion_client.py <https://github.com/vllm-project/vllm/blob/main/examples/openai_completion_client.py>`_.
+
+Using OpenAI Chat API with vLLM
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+The vLLM server is designed to support the OpenAI Chat API, allowing you to engage in dynamic conversations with the model. The chat interface is a more interactive way to communicate with the model, allowing back-and-forth exchanges that can be stored in the chat history. This is useful for tasks that require context or more detailed explanations.
+
+Querying the model using OpenAI Chat API:
+
+You can use the `create chat completion <https://platform.openai.com/docs/api-reference/chat/completions/create>`_ endpoint to communicate with the model in a chat-like interface:
+
+.. code-block:: console
+
+    $ curl http://localhost:8000/v1/chat/completions \
+    $     -H "Content-Type: application/json" \
+    $     -d '{
+    $         "model": "facebook/opt-125m",
+    $         "messages": [
+    $             {"role": "system", "content": "You are a helpful assistant."},
+    $             {"role": "user", "content": "Who won the world series in 2020?"}
+    $         ]
+    $     }'
+
+Python Client Example:
+
+Using the `openai` python package, you can also communicate with the model in a chat-like manner:
+
+.. code-block:: python
+
+    import openai
+    # Set OpenAI's API key and API base to use vLLM's API server.
+    openai.api_key = "EMPTY"
+    openai.api_base = "http://localhost:8000/v1"
+    chat_response = openai.ChatCompletion.create(
+        model="facebook/opt-125m",
+        messages=[
+            {"role": "system", "content": "You are a helpful assistant."},
+            {"role": "user", "content": "Tell me a joke."},
+        ]
+    )
+    print("Chat response:", chat_response)
+
+For more in-depth examples and advanced features of the chat API, you can refer to the official OpenAI documentation.
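Not shown in the commit's docs: the PR also adds two request fields, ``add_generation_prompt`` and ``echo`` (see the ``protocol.py`` hunk at the end of this diff). A minimal sketch of how they might be used together: with a trailing assistant message and ``add_generation_prompt`` set to false, the chat template stops at that message so the model continues it, and with ``echo`` set to true the returned message content includes the echoed prefix. The prompt below is purely illustrative.

.. code-block:: console

    $ curl http://localhost:8000/v1/chat/completions \
    $     -H "Content-Type: application/json" \
    $     -d '{
    $         "model": "facebook/opt-125m",
    $         "messages": [
    $             {"role": "user", "content": "Write a haiku about the sea."},
    $             {"role": "assistant", "content": "Waves fold into foam"}
    $         ],
    $         "add_generation_prompt": false,
    $         "echo": true
    $     }'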
examples/template_alpaca.jinja (new file, 29 lines)
@@ -0,0 +1,29 @@
+{{ (messages|selectattr('role', 'equalto', 'system')|list|last).content|trim if (messages|selectattr('role', 'equalto', 'system')|list) else '' }}
+
+{% for message in messages %}
+{% if message['role'] == 'user' %}
+### Instruction:
+{{ message['content']|trim -}}
+{% if not loop.last %}
+
+
+{% endif %}
+{% elif message['role'] == 'assistant' %}
+### Response:
+{{ message['content']|trim -}}
+{% if not loop.last %}
+
+
+{% endif %}
+{% elif message['role'] == 'user_context' %}
+### Input:
+{{ message['content']|trim -}}
+{% if not loop.last %}
+
+
+{% endif %}
+{% endif %}
+{% endfor %}
+{% if add_generation_prompt and messages[-1]['role'] != 'assistant' %}
+### Response:
+{% endif %}
examples/template_chatml.jinja (new file, 2 lines)
@@ -0,0 +1,2 @@
+{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content']}}{% if (loop.last and add_generation_prompt) or not loop.last %}{{ '<|im_end|>' + '\n'}}{% endif %}{% endfor %}
+{% if add_generation_prompt and messages[-1]['role'] != 'assistant' %}{{ '<|im_start|>assistant\n' }}{% endif %}
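As a sketch of what the server does with this template (mirroring ``load_chat_template`` plus ``tokenizer.apply_chat_template`` from the diff below), the following is one way to render it offline. It assumes the ``facebook/opt-125m`` tokenizer can be downloaded, a ``transformers`` version with chat-template support, and that the path resolves from the repository root; the expected output matches the fixture in the new test file.

.. code-block:: python

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")
    # Install the template exactly as load_chat_template does for a file path.
    with open("examples/template_chatml.jinja") as f:
        tokenizer.chat_template = f.read()

    messages = [
        {"role": "user", "content": "Hello"},
        {"role": "assistant", "content": "Hi there!"},
        {"role": "user", "content": "What is the capital of"},
    ]
    prompt = tokenizer.apply_chat_template(
        conversation=messages, tokenize=False, add_generation_prompt=True)
    print(prompt)
    # <|im_start|>user
    # Hello<|im_end|>
    # <|im_start|>assistant
    # Hi there!<|im_end|>
    # <|im_start|>user
    # What is the capital of<|im_end|>
    # <|im_start|>assistant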
examples/template_inkbot.jinja (new file, 30 lines)
@@ -0,0 +1,30 @@
+<#meta#>
+- Date: {{ (messages|selectattr('role', 'equalto', 'meta-current_date')|list|last).content|trim if (messages|selectattr('role', 'equalto', 'meta-current_date')|list) else '' }}
+- Task: {{ (messages|selectattr('role', 'equalto', 'meta-task_name')|list|last).content|trim if (messages|selectattr('role', 'equalto', 'meta-task_name')|list) else '' }}
+<#system#>
+{{ (messages|selectattr('role', 'equalto', 'system')|list|last).content|trim if (messages|selectattr('role', 'equalto', 'system')|list) else '' }}
+<#chat#>
+{% for message in messages %}
+{% if message['role'] == 'user' %}
+<#user#>
+{{ message['content']|trim -}}
+{% if not loop.last %}
+
+{% endif %}
+{% elif message['role'] == 'assistant' %}
+<#bot#>
+{{ message['content']|trim -}}
+{% if not loop.last %}
+
+{% endif %}
+{% elif message['role'] == 'user_context' %}
+<#user_context#>
+{{ message['content']|trim -}}
+{% if not loop.last %}
+
+{% endif %}
+{% endif %}
+{% endfor %}
+{% if add_generation_prompt and messages[-1]['role'] != 'assistant' %}
+<#bot#>
+{% endif %}
tests/async_engine/test_openai_server.py (new file, 119 lines)
@@ -0,0 +1,119 @@
+from argparse import Namespace
+from dataclasses import dataclass
+
+import pytest
+from fastapi.testclient import TestClient
+
+from vllm.entrypoints.openai.api_server import *
+
+# Define models, templates, and their corresponding expected outputs
+MODEL_TEMPLATE_GENERATON_OUTPUT = [
+    ("facebook/opt-125m", None, True,
+     "Hello</s>Hi there!</s>What is the capital of</s>"),
+    ("facebook/opt-125m", None, False,
+     "Hello</s>Hi there!</s>What is the capital of</s>"),
+    ("facebook/opt-125m", "../../examples/template_chatml.jinja", True,
+     """<|im_start|>user
+Hello<|im_end|>
+<|im_start|>assistant
+Hi there!<|im_end|>
+<|im_start|>user
+What is the capital of<|im_end|>
+<|im_start|>assistant
+"""),
+    ("facebook/opt-125m", "../../examples/template_chatml.jinja", False,
+     """<|im_start|>user
+Hello<|im_end|>
+<|im_start|>assistant
+Hi there!<|im_end|>
+<|im_start|>user
+What is the capital of""")
+]
+
+TEST_MESSAGES = [
+    {
+        'role': 'user',
+        'content': 'Hello'
+    },
+    {
+        'role': 'assistant',
+        'content': 'Hi there!'
+    },
+    {
+        'role': 'user',
+        'content': 'What is the capital of'
+    },
+]
+client = TestClient(app)
+
+
+@dataclass
+class MockTokenizer:
+    chat_template = None
+
+
+def test_load_chat_template():
+    # Testing chatml template
+    template = "../../examples/template_chatml.jinja"
+    mock_args = Namespace(chat_template=template)
+    tokenizer = MockTokenizer()
+
+    # Call the function with the mocked args
+    load_chat_template(mock_args, tokenizer)
+
+    template_content = tokenizer.chat_template
+
+    # Test assertions
+    assert template_content is not None
+    # Hard coded value for template_chatml.jinja
+    assert template_content == """{% for message in messages %}{{'<|im_start|>' + message['role'] + '\\n' + message['content']}}{% if (loop.last and add_generation_prompt) or not loop.last %}{{ '<|im_end|>' + '\\n'}}{% endif %}{% endfor %}
+{% if add_generation_prompt and messages[-1]['role'] != 'assistant' %}{{ '<|im_start|>assistant\\n' }}{% endif %}"""
+
+
+def test_no_load_chat_template():
+    # Testing chatml template
+    template = "../../examples/does_not_exist"
+    mock_args = Namespace(chat_template=template)
+    tokenizer = MockTokenizer()
+
+    # Call the function with the mocked args
+    load_chat_template(mock_args, tokenizer=tokenizer)
+    template_content = tokenizer.chat_template
+
+    # Test assertions
+    assert template_content is not None
+    # Hard coded value for template_chatml.jinja
+    assert template_content == """../../examples/does_not_exist"""
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+    "model,template,add_generation_prompt,expected_output",
+    MODEL_TEMPLATE_GENERATON_OUTPUT)
+async def test_get_gen_prompt(model, template, add_generation_prompt,
+                              expected_output):
+    # Initialize the tokenizer
+    tokenizer = get_tokenizer(tokenizer_name=model)
+
+    mock_args = Namespace(chat_template=template)
+    load_chat_template(mock_args, tokenizer)
+
+    # Create a mock request object using keyword arguments
+    mock_request = ChatCompletionRequest(
+        model=model,
+        messages=TEST_MESSAGES,
+        add_generation_prompt=add_generation_prompt)
+
+    # Call the function and get the result
+    result = tokenizer.apply_chat_template(
+        conversation=mock_request.messages,
+        tokenize=False,
+        add_generation_prompt=mock_request.add_generation_prompt)
+
+    # Test assertion
+    assert result == expected_output, f"The generated prompt does not match the expected output for model {model} and template {template}"
+
+
+def test_health_endpoint():
+    response = client.get("/health")
+    assert response.status_code == 200
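Not part of the diff: the new tests are plain pytest tests. Assuming ``pytest`` and ``pytest-asyncio`` are installed (the latter for the ``@pytest.mark.asyncio`` case, which also downloads the ``facebook/opt-125m`` tokenizer) and that the relative ``../../examples`` paths resolve, something like this should run them:

.. code-block:: console

    $ pip install pytest pytest-asyncio
    $ cd tests/async_engine
    $ pytest test_openai_server.py -v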
vllm/entrypoints/openai/api_server.py
@@ -3,6 +3,7 @@
 
 import argparse
 import asyncio
+import codecs
 import json
 import time
 from http import HTTPStatus
@@ -14,7 +15,6 @@ from fastapi import Request
 from fastapi.exceptions import RequestValidationError
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.responses import JSONResponse, StreamingResponse, Response
-from packaging import version
 
 from vllm.engine.arg_utils import AsyncEngineArgs
 from vllm.engine.async_llm_engine import AsyncLLMEngine
@@ -31,20 +31,55 @@ from vllm.sampling_params import SamplingParams
 from vllm.transformers_utils.tokenizer import get_tokenizer
 from vllm.utils import random_uuid
 
-try:
-    import fastchat
-    from fastchat.conversation import Conversation, SeparatorStyle
-    from fastchat.model.model_adapter import get_conversation_template
-    _fastchat_available = True
-except ImportError:
-    _fastchat_available = False
-
 TIMEOUT_KEEP_ALIVE = 5  # seconds
 
 logger = init_logger(__name__)
 served_model = None
 app = fastapi.FastAPI()
 engine = None
+response_role = None
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(
+        description="vLLM OpenAI-Compatible RESTful API server.")
+    parser.add_argument("--host", type=str, default=None, help="host name")
+    parser.add_argument("--port", type=int, default=8000, help="port number")
+    parser.add_argument("--allow-credentials",
+                        action="store_true",
+                        help="allow credentials")
+    parser.add_argument("--allowed-origins",
+                        type=json.loads,
+                        default=["*"],
+                        help="allowed origins")
+    parser.add_argument("--allowed-methods",
+                        type=json.loads,
+                        default=["*"],
+                        help="allowed methods")
+    parser.add_argument("--allowed-headers",
+                        type=json.loads,
+                        default=["*"],
+                        help="allowed headers")
+    parser.add_argument("--served-model-name",
+                        type=str,
+                        default=None,
+                        help="The model name used in the API. If not "
+                        "specified, the model name will be the same as "
+                        "the huggingface name.")
+    parser.add_argument("--chat-template",
+                        type=str,
+                        default=None,
+                        help="The file path to the chat template, "
+                        "or the template in single-line form "
+                        "for the specified model")
+    parser.add_argument("--response-role",
+                        type=str,
+                        default="assistant",
+                        help="The role name to return if "
+                        "`request.add_generation_prompt=true`.")
+
+    parser = AsyncEngineArgs.add_cli_args(parser)
+    return parser.parse_args()
 
 
 def create_error_response(status_code: HTTPStatus,
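For reference (not part of the diff), the two new flags registered above might be passed like this when launching the server:

.. code-block:: console

    $ python -m vllm.entrypoints.openai.api_server \
    $     --model facebook/opt-125m \
    $     --chat-template ./examples/template_alpaca.jinja \
    $     --response-role assistant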
@@ -54,6 +89,25 @@ def create_error_response(status_code: HTTPStatus,
                         status_code=status_code.value)
 
 
+def load_chat_template(args, tokenizer):
+    if args.chat_template is not None:
+        try:
+            with open(args.chat_template, "r") as f:
+                chat_template = f.read()
+        except OSError:
+            # If opening the file fails, treat args.chat_template as the
+            # template string itself; decode escapes so e.g. \n is interpreted.
+            chat_template = codecs.decode(args.chat_template, "unicode_escape")
+
+        tokenizer.chat_template = chat_template
+        logger.info(
+            f"Using supplied chat template:\n{tokenizer.chat_template}")
+    elif tokenizer.chat_template is not None:
+        logger.info(f"Using default chat template:\n{tokenizer.chat_template}")
+    else:
+        logger.warning("No chat template provided. Chat API will not work.")
+
+
 @app.exception_handler(RequestValidationError)
 async def validation_exception_handler(_, exc):
     return create_error_response(HTTPStatus.BAD_REQUEST, str(exc))
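A minimal sketch (not in the diff) of the fallback branch of ``load_chat_template`` above: when ``--chat-template`` is not a readable file, the argument itself is treated as the template and its escape sequences are decoded. The dummy tokenizer and template string here are illustrative only.

.. code-block:: python

    from argparse import Namespace

    from vllm.entrypoints.openai.api_server import load_chat_template


    class DummyTokenizer:
        # Stand-in object; load_chat_template only touches this attribute.
        chat_template = None


    # An inline, single-line template; "\n" is written as a two-character
    # escape, exactly as it would appear on a shell command line.
    args = Namespace(
        chat_template="{% for m in messages %}"
        "{{ m['role'] }}: {{ m['content'] }}\\n"
        "{% endfor %}")
    tokenizer = DummyTokenizer()
    load_chat_template(args, tokenizer)  # open() fails -> unicode_escape decode
    assert "\n" in tokenizer.chat_template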
@@ -69,53 +123,6 @@ async def check_model(request) -> Optional[JSONResponse]:
     return ret
 
 
-async def get_gen_prompt(request) -> str:
-    if not _fastchat_available:
-        raise ModuleNotFoundError(
-            "fastchat is not installed. Please install fastchat to use "
-            "the chat completion and conversation APIs: `$ pip install fschat`"
-        )
-    if version.parse(fastchat.__version__) < version.parse("0.2.23"):
-        raise ImportError(
-            f"fastchat version is low. Current version: {fastchat.__version__} "
-            "Please upgrade fastchat to use: `$ pip install -U fschat`")
-
-    conv = get_conversation_template(request.model)
-    conv = Conversation(
-        name=conv.name,
-        system_template=conv.system_template,
-        system_message=conv.system_message,
-        roles=conv.roles,
-        messages=list(conv.messages),  # prevent in-place modification
-        offset=conv.offset,
-        sep_style=SeparatorStyle(conv.sep_style),
-        sep=conv.sep,
-        sep2=conv.sep2,
-        stop_str=conv.stop_str,
-        stop_token_ids=conv.stop_token_ids,
-    )
-
-    if isinstance(request.messages, str):
-        prompt = request.messages
-    else:
-        for message in request.messages:
-            msg_role = message["role"]
-            if msg_role == "system":
-                conv.system_message = message["content"]
-            elif msg_role == "user":
-                conv.append_message(conv.roles[0], message["content"])
-            elif msg_role == "assistant":
-                conv.append_message(conv.roles[1], message["content"])
-            else:
-                raise ValueError(f"Unknown role: {msg_role}")
-
-        # Add a blank message for the assistant.
-        conv.append_message(conv.roles[1], None)
-        prompt = conv.get_prompt()
-
-    return prompt
-
-
 async def check_length(
     request: Union[ChatCompletionRequest, CompletionRequest],
     prompt: Optional[str] = None,
@@ -207,7 +214,6 @@ async def create_chat_completion(request: ChatCompletionRequest,
         - function_call (Users should implement this by themselves)
         - logit_bias (to be supported by vLLM engine)
     """
-
     error_check_ret = await check_model(request)
     if error_check_ret is not None:
         return error_check_ret
@@ -217,7 +223,15 @@ async def create_chat_completion(request: ChatCompletionRequest,
         return create_error_response(HTTPStatus.BAD_REQUEST,
                                      "logit_bias is not currently supported")
 
-    prompt = await get_gen_prompt(request)
+    try:
+        prompt = tokenizer.apply_chat_template(
+            conversation=request.messages,
+            tokenize=False,
+            add_generation_prompt=request.add_generation_prompt)
+    except Exception as e:
+        logger.error(f"Error in applying chat template from request: {str(e)}")
+        return create_error_response(HTTPStatus.BAD_REQUEST, str(e))
+
     token_ids, error_check_ret = await check_length(request, prompt=prompt)
     if error_check_ret is not None:
         return error_check_ret
@@ -225,6 +239,7 @@ async def create_chat_completion(request: ChatCompletionRequest,
     model_name = request.model
     request_id = f"cmpl-{random_uuid()}"
    created_time = int(time.monotonic())
+    chunk_object_type = "chat.completion.chunk"
     try:
         spaces_between_special_tokens = request.spaces_between_special_tokens
         sampling_params = SamplingParams(
@@ -249,81 +264,106 @@ async def create_chat_completion(request: ChatCompletionRequest,
     result_generator = engine.generate(prompt, sampling_params, request_id,
                                        token_ids)
 
-    def create_stream_response_json(
-        index: int,
-        text: str,
-        finish_reason: Optional[str] = None,
-        usage: Optional[UsageInfo] = None,
-    ) -> str:
-        choice_data = ChatCompletionResponseStreamChoice(
-            index=index,
-            delta=DeltaMessage(content=text),
-            finish_reason=finish_reason,
-        )
-        response = ChatCompletionStreamResponse(
-            id=request_id,
-            created=created_time,
-            model=model_name,
-            choices=[choice_data],
-        )
-        if usage is not None:
-            response.usage = usage
-        # exclude unset to leave details out of each sse
-        response_json = response.json(exclude_unset=True, ensure_ascii=False)
-
-        return response_json
+    def get_role() -> str:
+        if request.add_generation_prompt:
+            return response_role
+        else:
+            return request.messages[-1]["role"]
 
     async def completion_stream_generator() -> AsyncGenerator[str, None]:
-        # First chunk with role
+        # Send first response for each request.n (index) with the role
+        role = get_role()
         for i in range(request.n):
             choice_data = ChatCompletionResponseStreamChoice(
-                index=i,
-                delta=DeltaMessage(role="assistant"),
-                finish_reason=None,
-            )
+                index=i, delta=DeltaMessage(role=role), finish_reason=None)
             chunk = ChatCompletionStreamResponse(id=request_id,
+                                                 object=chunk_object_type,
+                                                 created=created_time,
                                                  choices=[choice_data],
                                                  model=model_name)
             data = chunk.json(exclude_unset=True, ensure_ascii=False)
             yield f"data: {data}\n\n"
 
+        # Send response to echo the input portion of the last message
+        if request.echo:
+            last_msg_content = ""
+            if request.messages and isinstance(
+                    request.messages, list) and request.messages[-1].get(
+                        "content") and request.messages[-1].get(
+                            "role") == role:
+                last_msg_content = request.messages[-1]["content"]
+            if last_msg_content:
+                for i in range(request.n):
+                    choice_data = ChatCompletionResponseStreamChoice(
+                        index=i,
+                        delta=DeltaMessage(content=last_msg_content),
+                        finish_reason=None)
+                    chunk = ChatCompletionStreamResponse(
+                        id=request_id,
+                        object=chunk_object_type,
+                        created=created_time,
+                        choices=[choice_data],
+                        model=model_name)
+                    data = chunk.json(exclude_unset=True, ensure_ascii=False)
+                    yield f"data: {data}\n\n"
+
+        # Send response for each token for each request.n (index)
         previous_texts = [""] * request.n
         previous_num_tokens = [0] * request.n
+        finish_reason_sent = [False] * request.n
         async for res in result_generator:
             res: RequestOutput
             for output in res.outputs:
                 i = output.index
-                delta_text = output.text[len(previous_texts[i]):]
-                previous_texts[i] = output.text
-                completion_tokens = len(output.token_ids)
-                previous_num_tokens[i] = completion_tokens
-                response_json = create_stream_response_json(
-                    index=i,
-                    text=delta_text,
-                )
-                yield f"data: {response_json}\n\n"
-                if output.finish_reason is not None:
-                    prompt_tokens = len(res.prompt_token_ids)
-                    final_usage = UsageInfo(
-                        prompt_tokens=prompt_tokens,
-                        completion_tokens=completion_tokens,
-                        total_tokens=prompt_tokens + completion_tokens,
-                    )
-                    response_json = create_stream_response_json(
-                        index=i,
-                        text="",
-                        finish_reason=output.finish_reason,
-                        usage=final_usage,
-                    )
-                    yield f"data: {response_json}\n\n"
+
+                if finish_reason_sent[i]:
+                    continue
+
+                if output.finish_reason is None:
+                    # Send token-by-token response for each request.n
+                    delta_text = output.text[len(previous_texts[i]):]
+                    previous_texts[i] = output.text
+                    completion_tokens = len(output.token_ids)
+                    previous_num_tokens[i] = completion_tokens
+                    choice_data = ChatCompletionResponseStreamChoice(
+                        index=i,
+                        delta=DeltaMessage(content=delta_text),
+                        finish_reason=None)
+                    chunk = ChatCompletionStreamResponse(
+                        id=request_id,
+                        object=chunk_object_type,
+                        created=created_time,
+                        choices=[choice_data],
+                        model=model_name)
+                    data = chunk.json(exclude_unset=True, ensure_ascii=False)
+                    yield f"data: {data}\n\n"
+                else:
+                    # Send the finish response for each request.n only once
+                    prompt_tokens = len(res.prompt_token_ids)
+                    final_usage = UsageInfo(
+                        prompt_tokens=prompt_tokens,
+                        completion_tokens=completion_tokens,
+                        total_tokens=prompt_tokens + completion_tokens,
+                    )
+                    choice_data = ChatCompletionResponseStreamChoice(
+                        index=i, delta=[], finish_reason=output.finish_reason)
+                    chunk = ChatCompletionStreamResponse(
+                        id=request_id,
+                        object=chunk_object_type,
+                        created=created_time,
+                        choices=[choice_data],
+                        model=model_name)
+                    if final_usage is not None:
+                        chunk.usage = final_usage
+                    data = chunk.json(exclude_unset=True,
+                                      exclude_none=True,
+                                      ensure_ascii=False)
+                    yield f"data: {data}\n\n"
+                    finish_reason_sent[i] = True
+        # Send the final done message after all response.n are finished
         yield "data: [DONE]\n\n"
 
-    # Streaming response
-    if request.stream:
-        return StreamingResponse(completion_stream_generator(),
-                                 media_type="text/event-stream")
-
-    # Non-streaming response
+    async def completion_full_generator():
         final_res: RequestOutput = None
         async for res in result_generator:
             if await raw_request.is_disconnected():
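For orientation (not part of the diff), with ``n=1`` and ``echo`` left off, the stream emitted by ``completion_stream_generator`` above now looks roughly like this; ids, timestamps, token counts, and generated text are abbreviated:

.. code-block:: console

    data: {"id": "cmpl-...", "object": "chat.completion.chunk", "created": ..., "model": "facebook/opt-125m", "choices": [{"index": 0, "delta": {"role": "assistant"}, "finish_reason": null}]}
    data: {"id": "cmpl-...", "object": "chat.completion.chunk", "created": ..., "model": "facebook/opt-125m", "choices": [{"index": 0, "delta": {"content": "..."}, "finish_reason": null}]}
    ...
    data: {"id": "cmpl-...", "object": "chat.completion.chunk", "created": ..., "model": "facebook/opt-125m", "choices": [{"index": 0, "delta": [], "finish_reason": "stop"}], "usage": {"prompt_tokens": ..., "completion_tokens": ..., "total_tokens": ...}}
    data: [DONE]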
@@ -333,15 +373,29 @@ async def create_chat_completion(request: ChatCompletionRequest,
                                              "Client disconnected")
             final_res = res
         assert final_res is not None
 
         choices = []
+        role = get_role()
         for output in final_res.outputs:
             choice_data = ChatCompletionResponseChoice(
                 index=output.index,
-                message=ChatMessage(role="assistant", content=output.text),
+                message=ChatMessage(role=role, content=output.text),
                 finish_reason=output.finish_reason,
             )
             choices.append(choice_data)
 
+        if request.echo:
+            last_msg_content = ""
+            if request.messages and isinstance(
+                    request.messages, list) and request.messages[-1].get(
+                        "content") and request.messages[-1].get(
+                            "role") == role:
+                last_msg_content = request.messages[-1]["content"]
+
+            for choice in choices:
+                full_message = last_msg_content + choice.message.content
+                choice.message.content = full_message
+
         num_prompt_tokens = len(final_res.prompt_token_ids)
         num_generated_tokens = sum(
             len(output.token_ids) for output in final_res.outputs)
@@ -358,20 +412,15 @@ async def create_chat_completion(request: ChatCompletionRequest,
             usage=usage,
         )
 
-    if request.stream:
-        # When user requests streaming but we don't stream, we still need to
-        # return a streaming response with a single event.
-        response_json = response.json(ensure_ascii=False)
-
-        async def fake_stream_generator() -> AsyncGenerator[str, None]:
-            yield f"data: {response_json}\n\n"
-            yield "data: [DONE]\n\n"
-
-        return StreamingResponse(fake_stream_generator(),
-                                 media_type="text/event-stream")
-
         return response
 
+    # Streaming response
+    if request.stream:
+        return StreamingResponse(completion_stream_generator(),
+                                 media_type="text/event-stream")
+    else:
+        return await completion_full_generator()
+
 
 @app.post("/v1/completions")
 async def create_completion(request: CompletionRequest, raw_request: Request):
@@ -642,34 +691,7 @@ async def create_completion(request: CompletionRequest, raw_request: Request):
 
 
 if __name__ == "__main__":
-    parser = argparse.ArgumentParser(
-        description="vLLM OpenAI-Compatible RESTful API server.")
-    parser.add_argument("--host", type=str, default=None, help="host name")
-    parser.add_argument("--port", type=int, default=8000, help="port number")
-    parser.add_argument("--allow-credentials",
-                        action="store_true",
-                        help="allow credentials")
-    parser.add_argument("--allowed-origins",
-                        type=json.loads,
-                        default=["*"],
-                        help="allowed origins")
-    parser.add_argument("--allowed-methods",
-                        type=json.loads,
-                        default=["*"],
-                        help="allowed methods")
-    parser.add_argument("--allowed-headers",
-                        type=json.loads,
-                        default=["*"],
-                        help="allowed headers")
-    parser.add_argument("--served-model-name",
-                        type=str,
-                        default=None,
-                        help="The model name used in the API. If not "
-                        "specified, the model name will be the same as "
-                        "the huggingface name.")
-
-    parser = AsyncEngineArgs.add_cli_args(parser)
-    args = parser.parse_args()
+    args = parse_args()
 
     app.add_middleware(
         CORSMiddleware,
@@ -686,6 +708,8 @@ if __name__ == "__main__":
     else:
         served_model = args.model
 
+    response_role = args.response_role
+
     engine_args = AsyncEngineArgs.from_cli_args(args)
     engine = AsyncLLMEngine.from_engine_args(engine_args)
     engine_model_config = asyncio.run(engine.get_model_config())
@@ -696,6 +720,7 @@ if __name__ == "__main__":
         engine_model_config.tokenizer,
         tokenizer_mode=engine_model_config.tokenizer_mode,
         trust_remote_code=engine_model_config.trust_remote_code)
+    load_chat_template(args, tokenizer)
 
     uvicorn.run(app,
                 host=args.host,
vllm/entrypoints/openai/protocol.py
@@ -73,6 +73,8 @@ class ChatCompletionRequest(BaseModel):
     stop_token_ids: Optional[List[int]] = Field(default_factory=list)
     skip_special_tokens: Optional[bool] = True
     spaces_between_special_tokens: Optional[bool] = True
+    add_generation_prompt: Optional[bool] = True
+    echo: Optional[bool] = False
 
 
 class CompletionRequest(BaseModel):