[Frontend] Add tokenize/detokenize endpoints (#5054)

This commit is contained in:
sasha0552
2024-06-26 16:54:22 +00:00
committed by GitHub
parent 5bfd1bbc98
commit c54269d967
5 changed files with 143 additions and 6 deletions

View File

@@ -9,6 +9,7 @@ import pytest
# using Ray for overall ease of process management, parallel requests,
# and debugging.
import ray
import requests
import torch
# downloading lora to test lora requests
from huggingface_hub import snapshot_download
@@ -1366,5 +1367,53 @@ async def test_long_seed(client: openai.AsyncOpenAI):
                or "less_than_equal" in exc_info.value.message)


@pytest.mark.asyncio
@pytest.mark.parametrize(
    "model_name",
    [MODEL_NAME],
)
async def test_tokenize(server, client: openai.AsyncOpenAI, model_name: str):
    base_url = str(client.base_url)[:-3]
    tokenizer = get_tokenizer(tokenizer_name=MODEL_NAME, tokenizer_mode="fast")

    for add_special in [False, True]:
        prompt = "This is a test prompt."
        tokens = tokenizer.encode(prompt, add_special_tokens=add_special)

        response = requests.post(base_url + "tokenize",
                                 json={
                                     "add_special_tokens": add_special,
                                     "model": model_name,
                                     "prompt": prompt
                                 })
        response.raise_for_status()

        assert response.json() == {
            "tokens": tokens,
            "count": len(tokens),
            "max_model_len": 8192
        }


@pytest.mark.asyncio
@pytest.mark.parametrize(
    "model_name",
    [MODEL_NAME],
)
async def test_detokenize(server, client: openai.AsyncOpenAI, model_name: str):
    base_url = str(client.base_url)[:-3]
    tokenizer = get_tokenizer(tokenizer_name=MODEL_NAME, tokenizer_mode="fast")

    prompt = "This is a test prompt."
    tokens = tokenizer.encode(prompt, add_special_tokens=False)

    response = requests.post(base_url + "detokenize",
                             json={
                                 "model": model_name,
                                 "tokens": tokens
                             })
    response.raise_for_status()

    assert response.json() == {"prompt": prompt}


if __name__ == "__main__":
    pytest.main([__file__])
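As a side note, here is a minimal sketch of how the new endpoints can be exercised outside the test suite against an already running OpenAI-compatible vLLM server. The base URL, port, and served model name below are assumptions; only the request and response fields introduced by this commit are used.

# Minimal client sketch for the new /tokenize and /detokenize endpoints.
# BASE_URL and MODEL are assumptions; adjust them to match your deployment.
import requests

BASE_URL = "http://localhost:8000"   # assumed host/port of the running server
MODEL = "my-served-model"            # hypothetical served model name

prompt = "This is a test prompt."

# /tokenize returns the token ids, their count, and the model's context length.
tok = requests.post(BASE_URL + "/tokenize",
                    json={
                        "model": MODEL,
                        "prompt": prompt,
                        "add_special_tokens": False
                    })
tok.raise_for_status()
body = tok.json()
print(body["count"], body["max_model_len"])

# /detokenize converts the token ids back into text; with
# add_special_tokens=False this round-trips to the original prompt.
detok = requests.post(BASE_URL + "/detokenize",
                      json={
                          "model": MODEL,
                          "tokens": body["tokens"]
                      })
detok.raise_for_status()
print(detok.json()["prompt"])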

View File

@@ -19,10 +19,17 @@ import vllm.envs as envs
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.entrypoints.openai.cli_args import make_arg_parser
# yapf conflicts with isort for this block
# yapf: disable
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
                                              ChatCompletionResponse,
                                              CompletionRequest,
                                              EmbeddingRequest, ErrorResponse)
                                              DetokenizeRequest,
                                              DetokenizeResponse,
                                              EmbeddingRequest, ErrorResponse,
                                              TokenizeRequest,
                                              TokenizeResponse)
# yapf: enable
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
@@ -85,6 +92,28 @@ async def health() -> Response:
    return Response(status_code=200)


@app.post("/tokenize")
async def tokenize(request: TokenizeRequest):
    generator = await openai_serving_completion.create_tokenize(request)
    if isinstance(generator, ErrorResponse):
        return JSONResponse(content=generator.model_dump(),
                            status_code=generator.code)
    else:
        assert isinstance(generator, TokenizeResponse)
        return JSONResponse(content=generator.model_dump())


@app.post("/detokenize")
async def detokenize(request: DetokenizeRequest):
    generator = await openai_serving_completion.create_detokenize(request)
    if isinstance(generator, ErrorResponse):
        return JSONResponse(content=generator.model_dump(),
                            status_code=generator.code)
    else:
        assert isinstance(generator, DetokenizeResponse)
        return JSONResponse(content=generator.model_dump())


@app.get("/v1/models")
async def show_available_models():
    models = await openai_serving_chat.show_available_models()
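For completeness, here is a rough sketch of the error path wired up above: when the requested model is not among the served model names, the serving layer returns an ErrorResponse and the route surfaces it as a JSON body with the error's status code. The URL and model name are assumptions, and the exact status code and message are left to the server.

# Sketch of the error path for an unserved model name (hypothetical values).
import requests

resp = requests.post("http://localhost:8000/tokenize",
                     json={
                         "model": "model-that-is-not-served",  # hypothetical
                         "prompt": "hello"
                     })
if resp.status_code != 200:
    # Body is the serialized ErrorResponse returned by the route above.
    print(resp.status_code, resp.json())
else:
    print(resp.json()["tokens"])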

View File

@@ -699,3 +699,24 @@ class BatchRequestOutput(OpenAIBaseModel):
    # For requests that failed with a non-HTTP error, this will contain more
    # information on the cause of the failure.
    error: Optional[Any]


class TokenizeRequest(OpenAIBaseModel):
    model: str
    prompt: str
    add_special_tokens: bool = Field(default=True)


class TokenizeResponse(OpenAIBaseModel):
    tokens: List[int]
    count: int
    max_model_len: int


class DetokenizeRequest(OpenAIBaseModel):
    model: str
    tokens: List[int]


class DetokenizeResponse(OpenAIBaseModel):
    prompt: str
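Because TokenizeResponse reports both the token count and max_model_len, a client can estimate how much of the context window remains for generation before issuing a completion request. A small illustrative sketch, again assuming a local server and a hypothetical served model name:

# Budget a follow-up completion from TokenizeResponse's "count" and
# "max_model_len" fields. BASE_URL and MODEL are assumptions.
import requests

BASE_URL = "http://localhost:8000"
MODEL = "my-served-model"  # hypothetical served model name

prompt = "Summarize the following document: ..."
resp = requests.post(BASE_URL + "/tokenize",
                     json={"model": MODEL, "prompt": prompt})
resp.raise_for_status()
info = resp.json()

remaining = info["max_model_len"] - info["count"]
print(f"prompt uses {info['count']} tokens; "
      f"at most {remaining} tokens are left for generation")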

View File

@@ -16,7 +16,11 @@ from vllm.entrypoints.openai.protocol import (CompletionLogProbs,
                                              CompletionResponseChoice,
                                              CompletionResponseStreamChoice,
                                              CompletionStreamResponse,
                                              UsageInfo)
                                              DetokenizeRequest,
                                              DetokenizeResponse,
                                              TokenizeRequest,
                                              TokenizeResponse, UsageInfo)
# yapf: enable
from vllm.entrypoints.openai.serving_engine import (LoRAModulePath,
                                                    OpenAIServing)
from vllm.logger import init_logger
@@ -442,3 +446,29 @@ class OpenAIServingCompletion(OpenAIServing):
            tokens=out_tokens,
            top_logprobs=out_top_logprobs,
        )

    async def create_tokenize(self,
                              request: TokenizeRequest) -> TokenizeResponse:
        error_check_ret = await self._check_model(request)
        if error_check_ret is not None:
            return error_check_ret

        (input_ids, input_text) = self._validate_prompt_and_tokenize(
            request,
            prompt=request.prompt,
            add_special_tokens=request.add_special_tokens)

        return TokenizeResponse(tokens=input_ids,
                                count=len(input_ids),
                                max_model_len=self.max_model_len)

    async def create_detokenize(
            self, request: DetokenizeRequest) -> DetokenizeResponse:
        error_check_ret = await self._check_model(request)
        if error_check_ret is not None:
            return error_check_ret

        (input_ids, input_text) = self._validate_prompt_and_tokenize(
            request, prompt_ids=request.tokens)

        return DetokenizeResponse(prompt=input_text)

View File

@@ -10,9 +10,10 @@ from vllm.config import ModelConfig
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
                                              CompletionRequest,
                                              DetokenizeRequest,
                                              EmbeddingRequest, ErrorResponse,
                                              ModelCard, ModelList,
                                              ModelPermission)
                                              ModelPermission, TokenizeRequest)
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.sequence import Logprob
@@ -99,8 +100,9 @@ class OpenAIServing:
        return json_str

    async def _check_model(
        self, request: Union[CompletionRequest, ChatCompletionRequest,
                             EmbeddingRequest]
        self, request: Union[ChatCompletionRequest, CompletionRequest,
                             DetokenizeRequest, EmbeddingRequest,
                             TokenizeRequest]
    ) -> Optional[ErrorResponse]:
        if request.model in self.served_model_names:
            return None
@@ -126,7 +128,8 @@ class OpenAIServing:
    def _validate_prompt_and_tokenize(
            self,
            request: Union[ChatCompletionRequest, CompletionRequest,
                           EmbeddingRequest],
                           DetokenizeRequest, EmbeddingRequest,
                           TokenizeRequest],
            prompt: Optional[str] = None,
            prompt_ids: Optional[List[int]] = None,
            truncate_prompt_tokens: Optional[Annotated[int,
@@ -174,6 +177,11 @@ class OpenAIServing:
                    f"generation. Please reduce the length of the input.", )
            return input_ids, input_text

        # Note: TokenizeRequest and DetokenizeRequest don't have max_tokens
        # and do not require model context length validation.
        if isinstance(request, (TokenizeRequest, DetokenizeRequest)):
            return input_ids, input_text

        if request.max_tokens is None:
            if token_num >= self.max_model_len:
                raise ValueError(