[Core] Add engine option to return only deltas or final output (#7381)

Author: Nick Hill, 2024-09-12 20:02:00 +01:00 (committed by GitHub)
Parent: a6c0f3658d · Commit: 551ce01078
10 changed files with 371 additions and 137 deletions
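
This change replaces the engine-wide `step_return_finished_only` flag with a per-request `RequestOutputKind` field on `SamplingParams`, so each request can ask for cumulative outputs (the previous behavior), per-step deltas, or only the final output. A minimal sketch of consuming delta outputs from the async engine, modelled on the new test in this diff (the model name and prompt are illustrative):

```python
import asyncio
from uuid import uuid4

from vllm import SamplingParams
from vllm.engine.async_llm_engine import AsyncEngineArgs, AsyncLLMEngine
from vllm.sampling_params import RequestOutputKind


async def stream_deltas(engine: AsyncLLMEngine, prompt: str) -> str:
    # With DELTA, each RequestOutput carries only the text/tokens produced
    # since the previous one, so the caller concatenates the pieces itself.
    params = SamplingParams(max_tokens=32,
                            output_kind=RequestOutputKind.DELTA)
    text = ""
    async for output in engine.generate(prompt, params,
                                        request_id=str(uuid4())):
        text += output.outputs[0].text
    return text


async def main() -> None:
    engine = AsyncLLMEngine.from_engine_args(
        AsyncEngineArgs(model="facebook/opt-125m", enforce_eager=True))
    print(await stream_deltas(engine, "Hello, my name is"))


if __name__ == "__main__":
    asyncio.run(main())
```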

View File

@ -50,6 +50,7 @@ steps:
- tests/worker
commands:
- pytest -v -s async_engine # Async Engine
- NUM_SCHEDULER_STEPS=4 pytest -v -s async_engine/test_async_llm_engine.py
- pytest -v -s test_inputs.py
- pytest -v -s multimodal
- pytest -v -s test_utils.py # Utils

View File

@ -1,7 +1,10 @@
import asyncio
import os
import uuid
from asyncio import CancelledError
from copy import copy
from dataclasses import dataclass
from typing import Optional
from typing import List, Optional
import pytest
import pytest_asyncio
@ -11,6 +14,7 @@ from vllm import SamplingParams
from vllm.config import ParallelConfig
from vllm.engine.async_llm_engine import AsyncEngineArgs, AsyncLLMEngine
from vllm.outputs import RequestOutput as RealRequestOutput
from vllm.sampling_params import RequestOutputKind
from ..conftest import cleanup
from ..utils import wait_for_gpu_memory_to_clear
@ -122,8 +126,17 @@ def start_engine():
timeout_s=60,
)
num_scheduler_steps = int(os.getenv("NUM_SCHEDULER_STEPS", "1"))
print(f"Starting engine with num_scheduler_steps={num_scheduler_steps}")
return AsyncLLMEngine.from_engine_args(
AsyncEngineArgs(model="facebook/opt-125m", enforce_eager=True))
AsyncEngineArgs(model="facebook/opt-125m",
enforce_eager=True,
num_scheduler_steps=num_scheduler_steps))
def uid() -> str:
return str(uuid.uuid4())
@pytest_asyncio.fixture(scope="module")
@ -148,57 +161,177 @@ def should_do_global_cleanup_after_test(request) -> bool:
@pytest.mark.asyncio(scope="module")
async def test_asyncio_run(async_engine):
scheduler_config = await async_engine.get_scheduler_config()
num_scheduler_steps = scheduler_config.num_scheduler_steps
async def run(prompt: str):
sampling_params = SamplingParams(
temperature=0,
max_tokens=32,
min_tokens=32,
)
output_count = 0
final_output = None
async for output in async_engine.generate(prompt,
sampling_params,
request_id=prompt):
request_id=uid()):
output_count += 1
final_output = output
return final_output
return final_output, output_count
results = await asyncio.gather(
run("test0"),
run("test1"),
run("test0"),
)
assert len(results) == 2
first, second = results
# remove nondeterministic fields for comparison
first[0].metrics = None
second[0].metrics = None
first[0].request_id = None
second[0].request_id = None
assert str(first) == str(second)
output_count = results[0][1]
if num_scheduler_steps == 1:
assert output_count == 32
else:
assert 1 < output_count < 32
@pytest.mark.asyncio(scope="module")
async def test_output_kinds(async_engine):
"""Test that output_kind works as expected and that
results are equivalent across different kinds."""
scheduler_config = await async_engine.get_scheduler_config()
num_scheduler_steps = scheduler_config.num_scheduler_steps
sampling_params = SamplingParams(
temperature=0,
max_tokens=32,
min_tokens=32,
)
async def run(prompt: str, kind: RequestOutputKind):
params = copy(sampling_params)
params.output_kind = kind
output_count = 0
final_output = None
async for output in async_engine.generate(prompt,
params,
request_id=uid()):
output_count += 1
final_output = output
assert final_output is not None
return (final_output.prompt_token_ids,
final_output.outputs[0].token_ids,
final_output.outputs[0].text, output_count)
async def run_deltas(prompt: str):
params = copy(sampling_params)
params.output_kind = RequestOutputKind.DELTA
prompt_tokens = None
output_tokens: List[int] = []
output_text = ""
output_count = 0
async for output in async_engine.generate(prompt,
params,
request_id=uid()):
token_ids = output.outputs[0].token_ids
text = output.outputs[0].text
# Ensure we get prompt ids iff we haven't yet received output tokens
if output_tokens:
assert 1 <= len(token_ids) <= num_scheduler_steps
assert text
assert not output.prompt_token_ids
else:
assert output.prompt_token_ids
prompt_tokens = output.prompt_token_ids
output_tokens.extend(token_ids)
output_text += text
output_count += 1
return prompt_tokens, output_tokens, output_text, output_count
results = await asyncio.gather(
run("common input prompt", RequestOutputKind.CUMULATIVE),
run("common input prompt", RequestOutputKind.FINAL_ONLY),
run_deltas("common input prompt"))
# Make sure outputs are the same
prompt_set = set(tuple(prompt_ids) for prompt_ids, _, _, _ in results)
assert len(prompt_set) == 1
text_set = set(text for _, _, text, _ in results)
assert len(text_set) == 1
tokens_set = set(tuple(ids) for _, ids, _, _ in results)
assert len(tokens_set) == 1
cumulative, final, deltas = results
# output message counts
assert cumulative[3] == deltas[3]
if num_scheduler_steps == 1:
assert cumulative[3] == 32
else:
assert 1 < cumulative[3] < 32
assert final[3] == 1
@pytest.mark.asyncio(scope="module")
async def test_cancellation(async_engine):
scheduler_config = await async_engine.get_scheduler_config()
num_scheduler_steps = scheduler_config.num_scheduler_steps
sampling_params = SamplingParams(
temperature=0,
min_tokens=10,
max_tokens=10,
min_tokens=13,
max_tokens=13,
)
stop_at = 5 if num_scheduler_steps == 1 else 1
request_id = uid()
i = 0
with pytest.raises(CancelledError):
async for output in async_engine.generate("test2",
sampling_params,
request_id="test2"):
request_id=request_id):
assert not output.finished
i += 1
if i == 5:
await async_engine.abort("test2")
if i == stop_at:
await async_engine.abort(request_id)
assert i == 5
assert i == stop_at
@pytest.mark.asyncio(scope="module")
async def test_delayed_generator(async_engine):
scheduler_config = await async_engine.get_scheduler_config()
if scheduler_config.num_scheduler_steps != 1:
pytest.skip("no need to test this one with multistep")
sampling_params = SamplingParams(
temperature=0,
min_tokens=10,
max_tokens=10,
)
stream = async_engine.generate("test3",
sampling_params,
request_id="test3")
stream = async_engine.generate("test3", sampling_params, request_id=uid())
i = 0
final_output: Optional[RealRequestOutput] = None
async for output in stream:

View File

@ -39,7 +39,7 @@ from vllm.outputs import (EmbeddingRequestOutput, RequestOutput,
RequestOutputFactory)
from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams
from vllm.sampling_params import RequestOutputKind, SamplingParams
from vllm.sequence import (EmbeddingSequenceGroupOutput, ExecuteModelRequest,
Sequence, SequenceGroup, SequenceGroupMetadata,
SequenceStatus)
@ -225,9 +225,6 @@ class LLMEngine:
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
input_registry: InputRegistry = INPUT_REGISTRY,
# To improve performance, only final requests outputs may be required.
# If this set to true, then no intermediate outputs will be returned.
step_return_finished_only: bool = False,
) -> None:
logger.info(
"Initializing an LLM engine (v%s) with config: "
@ -295,7 +292,6 @@ class LLMEngine:
self.observability_config = observability_config or ObservabilityConfig(
)
self.log_stats = log_stats
self.step_return_finished_only = step_return_finished_only
if not self.model_config.skip_tokenizer_init:
self.tokenizer = self._init_tokenizer()
@ -1273,7 +1269,7 @@ class LLMEngine:
ctx: The virtual engine context to work on
request_id: If provided, then only this request is going to be processed
"""
now = time.time()
@ -1378,7 +1374,8 @@ class LLMEngine:
seq_group = scheduled_seq_group.seq_group
seq_group.maybe_set_first_token_time(now)
request_output = RequestOutputFactory.create(seq_group)
ctx.request_outputs.append(request_output)
if request_output:
ctx.request_outputs.append(request_output)
# When we process a single request, we skip it for the next time,
# and invoke the request output callback (if there was final output)
@ -1415,14 +1412,19 @@ class LLMEngine:
seq_group = scheduled_seq_group.seq_group
seq_group.maybe_set_first_token_time(now)
if (seq_group.is_finished()
if self.step_return_finished_only else True):
request_output = RequestOutputFactory.create(seq_group)
request_output = RequestOutputFactory.create(seq_group)
if request_output:
ctx.request_outputs.append(request_output)
for seq_group in scheduler_outputs.ignored_seq_groups:
params = seq_group.sampling_params
if params is not None and params.output_kind == (
RequestOutputKind.DELTA) and not seq_group.is_finished():
continue
request_output = RequestOutputFactory.create(seq_group)
ctx.request_outputs.append(request_output)
if request_output:
ctx.request_outputs.append(request_output)
# Immediately process request outputs here (if callback is given)
if (ctx.request_outputs
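
With `RequestOutputFactory.create` now allowed to return `None` (for example, a FINAL_ONLY request that has not finished, or an unfinished DELTA request in the ignored set), the output-processing loops above simply skip falsy results instead of consulting the old engine-wide flag. A stripped-down illustration of that control flow; `FakeOutput`, `maybe_create`, and `collect` are hypothetical stand-ins, not vLLM APIs:

```python
from typing import List, Optional, Tuple


class FakeOutput:
    """Stand-in for RequestOutput in this illustration."""

    def __init__(self, request_id: str, finished: bool) -> None:
        self.request_id = request_id
        self.finished = finished


def maybe_create(finished: bool, final_only: bool) -> Optional[FakeOutput]:
    # Mirrors the factory returning None when a FINAL_ONLY request
    # has not produced its final output yet.
    if final_only and not finished:
        return None
    return FakeOutput("req-0", finished)


def collect(step_results: List[Tuple[bool, bool]]) -> List[FakeOutput]:
    outputs: List[FakeOutput] = []
    for finished, final_only in step_results:
        request_output = maybe_create(finished, final_only)
        if request_output:  # same "if request_output:" guard as above
            outputs.append(request_output)
    return outputs


# Only the finished FINAL_ONLY request and the streaming request survive.
print(len(collect([(False, True), (True, True), (False, False)])))  # 2
```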

View File

@ -19,7 +19,7 @@ from vllm.model_executor.guided_decoding.guided_fields import LLMGuidedOptions
from vllm.outputs import EmbeddingRequestOutput, RequestOutput
from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams
from vllm.sampling_params import RequestOutputKind, SamplingParams
from vllm.transformers_utils.tokenizer import (AnyTokenizer, MistralTokenizer,
get_cached_tokenizer)
from vllm.transformers_utils.tokenizer_group import TokenizerGroup
@ -642,14 +642,12 @@ class LLM:
raise ValueError("The lengths of prompts and lora_request "
"must be the same.")
if isinstance(params, list):
params = [
self._add_guided_processor(param, guided_options)
if isinstance(param, SamplingParams) else param
for param in params
]
elif isinstance(params, SamplingParams):
params = self._add_guided_processor(params, guided_options)
for sp in params if isinstance(params, list) else (params, ):
if isinstance(sp, SamplingParams):
self._add_guided_processor(sp, guided_options)
# We only care about the final output
sp.output_kind = RequestOutputKind.FINAL_ONLY
# Add requests to the engine.
for i, request_inputs in enumerate(inputs):
@ -709,9 +707,6 @@ class LLM:
f"output: {0:.2f} toks/s"),
)
# In the loop below, only finished outputs are used
self.llm_engine.step_return_finished_only = True
# Run the engine.
outputs: List[Union[RequestOutput, EmbeddingRequestOutput]] = []
total_in_toks = 0
@ -724,6 +719,7 @@ class LLM:
if use_tqdm:
if isinstance(output, RequestOutput):
# Calculate tokens only for RequestOutput
assert output.prompt_token_ids is not None
total_in_toks += len(output.prompt_token_ids)
in_spd = total_in_toks / pbar.format_dict["elapsed"]
total_out_toks += sum(
@ -735,9 +731,6 @@ class LLM:
f"output: {out_spd:.2f} toks/s")
pbar.update(1)
# Restore original behavior
self.llm_engine.step_return_finished_only = False
if use_tqdm:
pbar.close()
# Sort the outputs by request ID.
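
For the offline entrypoint, the per-request kind replaces the old toggle entirely: `LLM` now stamps `RequestOutputKind.FINAL_ONLY` onto each request's `SamplingParams` before adding it to the engine, so users still receive exactly one finished `RequestOutput` per prompt. A small sketch of the unchanged user-facing call (model and prompts are illustrative):

```python
from vllm import LLM, SamplingParams

llm = LLM(model="facebook/opt-125m", enforce_eager=True)
params = SamplingParams(temperature=0, max_tokens=16)

# output_kind is forced to FINAL_ONLY internally, so each element below is
# the single, finished RequestOutput for its prompt.
for out in llm.generate(["Hello, my name is", "The capital of France is"],
                        params):
    print(out.request_id, out.outputs[0].text)
```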

View File

@ -12,7 +12,8 @@ from typing_extensions import Annotated, Required, TypedDict
from vllm.entrypoints.chat_utils import ChatCompletionMessageParam
from vllm.entrypoints.openai.logits_processors import get_logits_processors
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import LogitsProcessor, SamplingParams
from vllm.sampling_params import (LogitsProcessor, RequestOutputKind,
SamplingParams)
from vllm.sequence import Logprob
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils import random_uuid
@ -316,6 +317,8 @@ class ChatCompletionRequest(OpenAIBaseModel):
length_penalty=self.length_penalty,
logits_processors=logits_processors,
truncate_prompt_tokens=self.truncate_prompt_tokens,
output_kind=RequestOutputKind.DELTA if self.stream \
else RequestOutputKind.FINAL_ONLY,
)
@model_validator(mode="before")
@ -559,6 +562,8 @@ class CompletionRequest(OpenAIBaseModel):
length_penalty=self.length_penalty,
logits_processors=logits_processors,
truncate_prompt_tokens=self.truncate_prompt_tokens,
output_kind=RequestOutputKind.DELTA if self.stream \
else RequestOutputKind.FINAL_ONLY,
)
@model_validator(mode="before")

View File

@ -246,8 +246,7 @@ class OpenAIServingChat(OpenAIServing):
def get_chat_request_role(self, request: ChatCompletionRequest) -> str:
if request.add_generation_prompt:
return self.response_role
else:
return request.messages[-1]["role"]
return request.messages[-1]["role"]
async def chat_completion_stream_generator(
self,
@ -264,15 +263,37 @@ class OpenAIServingChat(OpenAIServing):
# Send response for each token for each request.n (index)
num_choices = 1 if request.n is None else request.n
previous_texts = [""] * num_choices
previous_num_tokens = [0] * num_choices
finish_reason_sent = [False] * num_choices
num_prompt_tokens = 0
tool_parser: Optional[ToolParser] = self.tool_parser(
tokenizer) if self.tool_parser else None
if isinstance(request.tool_choice, ChatCompletionNamedToolChoiceParam):
tool_choice_function_name = request.tool_choice.function.name
else:
tool_choice_function_name = None
# Determine whether tools are in use with "auto" tool choice
tool_choice_auto = (
not tool_choice_function_name
and self._should_stream_with_auto_tool_parsing(request))
all_previous_token_ids: Optional[List[List[int]]]
if tool_choice_auto:
# These are only required in "auto" tool choice case
previous_texts = [""] * num_choices
all_previous_token_ids = [[]] * num_choices
else:
previous_texts, all_previous_token_ids = None, None
try:
async for res in result_generator:
if res.prompt_token_ids is not None:
num_prompt_tokens = len(res.prompt_token_ids)
# We need to do it here, because if there are exceptions in
# the result_generator, it needs to be sent as the FIRST
# response (by the try...catch).
@ -305,10 +326,10 @@ class OpenAIServingChat(OpenAIServing):
and request.stream_options.include_usage):
# if continuous usage stats are requested, add it
if request.stream_options.continuous_usage_stats:
prompt_tokens = len(res.prompt_token_ids)
usage = UsageInfo(prompt_tokens=prompt_tokens,
completion_tokens=0,
total_tokens=prompt_tokens)
usage = UsageInfo(
prompt_tokens=num_prompt_tokens,
completion_tokens=0,
total_tokens=num_prompt_tokens)
chunk.usage = usage
# otherwise don't
else:
@ -344,12 +365,10 @@ class OpenAIServingChat(OpenAIServing):
request.stream_options.include_usage):
if (request.stream_options.
continuous_usage_stats):
prompt_tokens = len(
res.prompt_token_ids)
usage = UsageInfo(
prompt_tokens=prompt_tokens,
prompt_tokens=num_prompt_tokens,
completion_tokens=0,
total_tokens=prompt_tokens)
total_tokens=num_prompt_tokens)
chunk.usage = usage
else:
chunk.usage = None
@ -360,65 +379,66 @@ class OpenAIServingChat(OpenAIServing):
first_iteration = False
for output in res.outputs:
i = output.index
if finish_reason_sent[i]:
continue
delta_token_ids = output.token_ids[previous_num_tokens[i]:]
out_logprobs = output.logprobs[
previous_num_tokens[i]:] if output.logprobs else None
if request.logprobs and request.top_logprobs is not None:
assert out_logprobs is not None, (
assert output.logprobs is not None, (
"Did not output logprobs")
logprobs = self._create_chat_logprobs(
token_ids=delta_token_ids,
top_logprobs=out_logprobs,
token_ids=output.token_ids,
top_logprobs=output.logprobs,
tokenizer=tokenizer,
num_output_top_logprobs=request.top_logprobs,
)
else:
logprobs = None
delta_text = output.text[len(previous_texts[i]):]
delta_message: Optional[DeltaMessage] = None
delta_text = output.text
delta_message: Optional[DeltaMessage]
# handle streaming deltas for tools with named tool_choice
if (request.tool_choice and type(request.tool_choice) is
ChatCompletionNamedToolChoiceParam):
if tool_choice_function_name:
delta_message = DeltaMessage(tool_calls=[
DeltaToolCall(function=DeltaFunctionCall(
name=request.tool_choice.function.name,
name=tool_choice_function_name,
arguments=delta_text),
index=i)
])
# handle streaming deltas for tools with "auto" tool choice
elif (self._should_stream_with_auto_tool_parsing(request)
and tool_parser):
elif tool_choice_auto:
assert previous_texts is not None
assert all_previous_token_ids is not None
assert tool_parser is not None
#TODO optimize manipulation of these lists
previous_text = previous_texts[i]
previous_token_ids = all_previous_token_ids[i]
current_text = previous_text + delta_text
current_token_ids = previous_token_ids + list(
output.token_ids)
delta_message = (
tool_parser.extract_tool_calls_streaming(
previous_text=previous_texts[i],
current_text=output.text,
previous_text=previous_text,
current_text=current_text,
delta_text=delta_text,
previous_token_ids= \
output.token_ids[
:-1 * len(delta_token_ids)
],
current_token_ids=output.token_ids,
delta_token_ids=delta_token_ids
)
)
previous_token_ids=previous_token_ids,
current_token_ids=current_token_ids,
delta_token_ids=output.token_ids))
# update the previous values for the next iteration
previous_texts[i] = current_text
all_previous_token_ids[i] = current_token_ids
# handle streaming just a content delta
else:
delta_message = DeltaMessage(content=delta_text)
# set the previous values for the next iteration
previous_texts[i] = output.text
previous_num_tokens[i] = len(output.token_ids)
previous_num_tokens[i] += len(output.token_ids)
# if the message delta is None (e.g. because it was a
# "control token" for tool calls or the parser otherwise
@ -445,13 +465,12 @@ class OpenAIServingChat(OpenAIServing):
# handle usage stats if requested & if continuous
if (request.stream_options
and request.stream_options.include_usage):
if (request.stream_options.continuous_usage_stats):
prompt_tokens = len(res.prompt_token_ids)
if request.stream_options.continuous_usage_stats:
completion_tokens = len(output.token_ids)
usage = UsageInfo(
prompt_tokens=prompt_tokens,
prompt_tokens=num_prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=prompt_tokens +
total_tokens=num_prompt_tokens +
completion_tokens,
)
chunk.usage = usage
@ -482,7 +501,7 @@ class OpenAIServingChat(OpenAIServing):
tool_parser.prev_tool_call_arr[index].get(
"arguments", {}))
# get what we've streamed so for for arguments
# get what we've streamed so far for arguments
# for the current tool
actual_call = tool_parser.streamed_args_for_tool[
index]
@ -500,7 +519,6 @@ class OpenAIServingChat(OpenAIServing):
])
# Send the finish response for each request.n only once
prompt_tokens = len(res.prompt_token_ids)
choice_data = ChatCompletionResponseStreamChoice(
index=i,
delta=delta_message,
@ -518,13 +536,12 @@ class OpenAIServingChat(OpenAIServing):
model=model_name)
if (request.stream_options
and request.stream_options.include_usage):
if (request.stream_options.continuous_usage_stats):
prompt_tokens = len(res.prompt_token_ids)
if request.stream_options.continuous_usage_stats:
completion_tokens = len(output.token_ids)
usage = UsageInfo(
prompt_tokens=prompt_tokens,
prompt_tokens=num_prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=prompt_tokens +
total_tokens=num_prompt_tokens +
completion_tokens,
)
chunk.usage = usage
@ -538,10 +555,11 @@ class OpenAIServingChat(OpenAIServing):
# is sent, send the usage
if (request.stream_options
and request.stream_options.include_usage):
completion_tokens = previous_num_tokens[i]
final_usage = UsageInfo(
prompt_tokens=prompt_tokens,
completion_tokens=previous_num_tokens[i],
total_tokens=prompt_tokens + previous_num_tokens[i],
prompt_tokens=num_prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=num_prompt_tokens + completion_tokens,
)
final_usage_chunk = ChatCompletionStreamResponse(
@ -680,6 +698,7 @@ class OpenAIServingChat(OpenAIServing):
or "")
choice.message.content = full_message
assert final_res.prompt_token_ids is not None
num_prompt_tokens = len(final_res.prompt_token_ids)
num_generated_tokens = sum(
len(output.token_ids) for output in final_res.outputs)
@ -789,9 +808,9 @@ class OpenAIServingChat(OpenAIServing):
return bool(
# if there is a delta message that includes tool calls which
# include a function that has arguments
self.enable_auto_tools and self.tool_parser and delta_message
output.finish_reason is not None
and self.enable_auto_tools and self.tool_parser and delta_message
and delta_message.tool_calls and delta_message.tool_calls[0]
and delta_message.tool_calls[0].function
and delta_message.tool_calls[0].function.arguments is not None
and output.finish_reason is not None
)
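
Since the engine already emits deltas when `stream=True`, the chat handler no longer slices `output.text` and `output.token_ids` against previously seen values; it only maintains `previous_texts` / `all_previous_token_ids` when the "auto" tool parser needs cumulative context. A condensed, self-contained sketch of that accumulation pattern (`accumulate_for_tool_parser` is an illustrative helper, not part of the handler):

```python
from typing import List, Tuple


def accumulate_for_tool_parser(
        delta_texts: List[str],
        delta_token_ids: List[List[int]]) -> Tuple[str, List[int]]:
    """Rebuild the cumulative text/token state a streaming tool parser
    expects from engine-provided deltas."""
    previous_text = ""
    previous_token_ids: List[int] = []
    for delta_text, token_ids in zip(delta_texts, delta_token_ids):
        current_text = previous_text + delta_text
        current_token_ids = previous_token_ids + list(token_ids)
        # ...hand (previous_*, current_*, delta_*) to the tool parser here...
        previous_text, previous_token_ids = current_text, current_token_ids
    return previous_text, previous_token_ids


print(accumulate_for_tool_parser(["Hel", "lo"], [[1], [2, 3]]))
# ('Hello', [1, 2, 3])
```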

View File

@ -223,9 +223,10 @@ class OpenAIServingCompletion(OpenAIServing):
tokenizer: AnyTokenizer,
) -> AsyncGenerator[str, None]:
num_choices = 1 if request.n is None else request.n
previous_texts = [""] * num_choices * num_prompts
previous_text_lens = [0] * num_choices * num_prompts
previous_num_tokens = [0] * num_choices * num_prompts
has_echoed = [False] * num_choices * num_prompts
num_prompt_tokens = [0] * num_prompts
try:
async for prompt_idx, res in result_generator:
@ -233,6 +234,10 @@ class OpenAIServingCompletion(OpenAIServing):
prompt_logprobs = res.prompt_logprobs
prompt_text = res.prompt
# Prompt details are excluded from later streamed outputs
if res.prompt_token_ids is not None:
num_prompt_tokens[prompt_idx] = len(res.prompt_token_ids)
delta_token_ids: GenericSequence[int]
out_logprobs: Optional[GenericSequence[Optional[Dict[
int, Logprob]]]]
@ -244,6 +249,7 @@ class OpenAIServingCompletion(OpenAIServing):
assert request.max_tokens is not None
if request.echo and request.max_tokens == 0:
assert prompt_token_ids is not None
assert prompt_text is not None
# only return the prompt
delta_text = prompt_text
@ -252,6 +258,7 @@ class OpenAIServingCompletion(OpenAIServing):
has_echoed[i] = True
elif (request.echo and request.max_tokens > 0
and not has_echoed[i]):
assert prompt_token_ids is not None
assert prompt_text is not None
assert prompt_logprobs is not None
# echo the prompt and first token
@ -266,11 +273,9 @@ class OpenAIServingCompletion(OpenAIServing):
has_echoed[i] = True
else:
# return just the delta
delta_text = output.text[len(previous_texts[i]):]
delta_token_ids = output.token_ids[
previous_num_tokens[i]:]
out_logprobs = output.logprobs[previous_num_tokens[
i]:] if output.logprobs else None
delta_text = output.text
delta_token_ids = output.token_ids
out_logprobs = output.logprobs
if request.logprobs is not None:
assert out_logprobs is not None, (
@ -280,13 +285,13 @@ class OpenAIServingCompletion(OpenAIServing):
top_logprobs=out_logprobs,
num_output_top_logprobs=request.logprobs,
tokenizer=tokenizer,
initial_text_offset=len(previous_texts[i]),
initial_text_offset=previous_text_lens[i],
)
else:
logprobs = None
previous_texts[i] = output.text
previous_num_tokens[i] = len(output.token_ids)
previous_text_lens[i] += len(output.text)
previous_num_tokens[i] += len(output.token_ids)
finish_reason = output.finish_reason
stop_reason = output.stop_reason
@ -307,8 +312,8 @@ class OpenAIServingCompletion(OpenAIServing):
and request.stream_options.include_usage):
if (request.stream_options.continuous_usage_stats
or output.finish_reason is not None):
prompt_tokens = len(prompt_token_ids)
completion_tokens = len(output.token_ids)
prompt_tokens = num_prompt_tokens[prompt_idx]
completion_tokens = previous_num_tokens[i]
usage = UsageInfo(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
@ -356,6 +361,7 @@ class OpenAIServingCompletion(OpenAIServing):
for final_res in final_res_batch:
prompt_token_ids = final_res.prompt_token_ids
assert prompt_token_ids is not None
prompt_logprobs = final_res.prompt_logprobs
prompt_text = final_res.prompt
@ -411,9 +417,9 @@ class OpenAIServingCompletion(OpenAIServing):
)
choices.append(choice_data)
num_generated_tokens += len(output.token_ids)
num_prompt_tokens += len(prompt_token_ids)
num_generated_tokens += sum(
len(output.token_ids) for output in final_res.outputs)
usage = UsageInfo(
prompt_tokens=num_prompt_tokens,

View File

@ -5,6 +5,7 @@ from typing import Sequence as GenericSequence
from typing import Union
from vllm.lora.request import LoRARequest
from vllm.sampling_params import RequestOutputKind
from vllm.sequence import (PromptLogprobs, RequestMetrics, SampleLogprobs,
SequenceGroup, SequenceStatus)
@ -92,7 +93,7 @@ class RequestOutput:
self,
request_id: str,
prompt: Optional[str],
prompt_token_ids: List[int],
prompt_token_ids: Optional[List[int]],
prompt_logprobs: Optional[PromptLogprobs],
outputs: List[CompletionOutput],
finished: bool,
@ -113,19 +114,26 @@ class RequestOutput:
self.encoder_prompt_token_ids = encoder_prompt_token_ids
@classmethod
def from_seq_group(cls, seq_group: SequenceGroup) -> "RequestOutput":
if seq_group.sampling_params is None:
def from_seq_group(cls,
seq_group: SequenceGroup) -> Optional["RequestOutput"]:
sampling_params = seq_group.sampling_params
if sampling_params is None:
raise ValueError(
"Sampling parameters are missing for a CompletionRequest.")
finished = seq_group.is_finished()
if sampling_params.output_kind == RequestOutputKind.FINAL_ONLY and (
not finished):
return None
seqs = seq_group.get_seqs()
if len(seqs) == 1:
top_n_seqs = seqs
else:
# Get the top-n sequences.
n = seq_group.sampling_params.n
if seq_group.sampling_params.use_beam_search:
n = sampling_params.n
if sampling_params.use_beam_search:
sorting_key = lambda seq: seq.get_beam_search_score(
seq_group.sampling_params.length_penalty)
sampling_params.length_penalty)
else:
sorting_key = lambda seq: seq.get_cumulative_logprob()
sorted_seqs = sorted(seqs, key=sorting_key, reverse=True)
@ -135,26 +143,49 @@ class RequestOutput:
# NOTE: We need to omit logprobs here explicitly because the sequence
# always has the logprobs of the sampled tokens even if the
# logprobs are not requested.
include_logprobs = seq_group.sampling_params.logprobs is not None
text_buffer_length = seq_group.sampling_params.output_text_buffer_length
outputs = [
CompletionOutput(
seqs.index(seq),
seq.get_output_text_to_return(text_buffer_length),
seq.data._output_token_ids,
seq.get_cumulative_logprob() if include_logprobs else None,
seq.output_logprobs if include_logprobs else None,
SequenceStatus.get_finished_reason(seq.status),
seq.stop_reason) for seq in top_n_seqs
]
include_logprobs = sampling_params.logprobs is not None
text_buffer_length = sampling_params.output_text_buffer_length
delta = sampling_params.output_kind == RequestOutputKind.DELTA
outputs = []
include_prompt = True
for seq in top_n_seqs:
output_text = seq.get_output_text_to_return(
text_buffer_length, delta)
output_token_ids = seq.get_output_token_ids_to_return(delta)
output_logprobs = seq.output_logprobs if include_logprobs else None
if delta:
# Slice logprobs delta if applicable
if output_logprobs:
output_logprobs = output_logprobs[-len(output_token_ids):]
# Don't include prompt if this is after the first output
# containing decode token ids
if include_prompt and seq.get_output_len() > len(
output_token_ids):
include_prompt = False
outputs.append(
CompletionOutput(
seqs.index(seq), output_text, output_token_ids,
seq.get_cumulative_logprob() if include_logprobs else None,
output_logprobs,
SequenceStatus.get_finished_reason(seq.status),
seq.stop_reason))
# Every sequence in the sequence group should have the same prompt.
prompt = seq_group.prompt
prompt_token_ids = seq_group.prompt_token_ids
encoder_prompt = seq_group.encoder_prompt
encoder_prompt_token_ids = seq_group.encoder_prompt_token_ids
prompt_logprobs = seq_group.prompt_logprobs
finished = seq_group.is_finished()
if include_prompt:
prompt = seq_group.prompt
prompt_token_ids = seq_group.prompt_token_ids
encoder_prompt = seq_group.encoder_prompt
encoder_prompt_token_ids = seq_group.encoder_prompt_token_ids
prompt_logprobs = seq_group.prompt_logprobs
else:
prompt = None
prompt_token_ids = None
encoder_prompt = None
encoder_prompt_token_ids = None
prompt_logprobs = None
finished_time = time.time() if finished else None
seq_group.set_finished_time(finished_time)
return cls(seq_group.request_id,
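
Note that in DELTA mode the prompt fields (`prompt`, `prompt_token_ids`, `prompt_logprobs`, and the encoder variants) are populated only until decode tokens start flowing and are `None` afterwards, and only the last output has `finished=True`. A consumer-side sketch of rebuilding a full result from such chunks, mirroring the `run_deltas` test earlier in this diff (`reassemble` is an illustrative helper):

```python
from typing import Iterable, List, Optional, Tuple

from vllm.outputs import RequestOutput


def reassemble(
    chunks: Iterable[RequestOutput]
) -> Tuple[Optional[List[int]], List[int], str]:
    """Rebuild prompt ids, output token ids, and text from DELTA outputs."""
    prompt_token_ids: Optional[List[int]] = None
    token_ids: List[int] = []
    text = ""
    for chunk in chunks:
        # Prompt fields are present only on the first chunk(s); later
        # chunks carry None to keep the messages small.
        if chunk.prompt_token_ids is not None:
            prompt_token_ids = list(chunk.prompt_token_ids)
        token_ids.extend(chunk.outputs[0].token_ids)
        text += chunk.outputs[0].text
    return prompt_token_ids, token_ids, text
```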

View File

@ -1,6 +1,6 @@
"""Sampling parameters for text generation."""
import copy
from enum import IntEnum
from enum import Enum, IntEnum
from functools import cached_property
from typing import Any, Callable, Dict, List, Optional, Set, Union
@ -33,6 +33,15 @@ first argument, and returns a modified tensor of logits
to sample from."""
class RequestOutputKind(Enum):
# Return entire output so far in every RequestOutput
CUMULATIVE = 0
# Return only deltas in each RequestOutput
DELTA = 1
# Do not return intermediate RequestOutputs
FINAL_ONLY = 2
class SamplingParams(
msgspec.Struct,
omit_defaults=True, # type: ignore[call-arg]
@ -147,6 +156,7 @@ class SamplingParams(
logits_processors: Optional[Any] = None
include_stop_str_in_output: bool = False
truncate_prompt_tokens: Optional[Annotated[int, msgspec.Meta(ge=1)]] = None
output_kind: RequestOutputKind = RequestOutputKind.CUMULATIVE
# The below fields are not supposed to be used as an input.
# They are set in post_init.
@ -182,6 +192,7 @@ class SamplingParams(
logits_processors: Optional[List[LogitsProcessor]] = None,
truncate_prompt_tokens: Optional[Annotated[int,
msgspec.Meta(ge=1)]] = None,
output_kind: RequestOutputKind = RequestOutputKind.CUMULATIVE,
) -> "SamplingParams":
return SamplingParams(
n=1 if n is None else n,
@ -213,6 +224,7 @@ class SamplingParams(
spaces_between_special_tokens=spaces_between_special_tokens,
logits_processors=logits_processors,
truncate_prompt_tokens=truncate_prompt_tokens,
output_kind=output_kind,
)
def __post_init__(self) -> None:
@ -317,6 +329,9 @@ class SamplingParams(
raise ValueError(
"stop strings are only supported when detokenize is True. "
"Set detokenize=True to use stop.")
if self.best_of != self.n and self.output_kind == (
RequestOutputKind.DELTA):
raise ValueError("best_of must equal n to use output_kind=DELTA")
def _verify_beam_search(self) -> None:
if self.best_of == 1:
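
A quick sketch of selecting each of the three kinds when building `SamplingParams`, including the new validation that DELTA requires `best_of == n` (no engine needed, this only constructs parameter objects):

```python
from vllm import SamplingParams
from vllm.sampling_params import RequestOutputKind

# Default: every RequestOutput carries the entire output generated so far.
cumulative = SamplingParams(max_tokens=32)

# Each RequestOutput carries only what was generated since the previous one.
delta = SamplingParams(max_tokens=32, output_kind=RequestOutputKind.DELTA)

# No intermediate RequestOutputs; one output once the request finishes.
final_only = SamplingParams(max_tokens=32,
                            output_kind=RequestOutputKind.FINAL_ONLY)

# DELTA cannot be combined with best_of != n: the check above raises.
try:
    SamplingParams(n=1, best_of=2, output_kind=RequestOutputKind.DELTA)
except ValueError as err:
    print(err)  # best_of must equal n to use output_kind=DELTA
```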

View File

@ -5,8 +5,9 @@ from abc import ABC, abstractmethod
from array import array
from collections import defaultdict
from dataclasses import dataclass
from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Mapping,
Optional, Set, Tuple, Union, cast)
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Mapping, Optional
from typing import Sequence as GenericSequence
from typing import Set, Tuple, Union, cast
import msgspec
import torch
@ -407,6 +408,10 @@ class Sequence:
self.status = SequenceStatus.WAITING
self.stop_reason: Union[int, str, None] = None
# These are used to keep track of delta outputs
self._last_token_ids_offset: int = 0
self._last_output_text_offset: int = 0
# Used for incremental detokenization
self.prefix_offset = 0
self.read_offset = 0
@ -462,11 +467,35 @@ class Sequence:
return self.prompt_adapter_request.prompt_adapter_id \
if self.prompt_adapter_request else 0
def get_output_text_to_return(self, buffer_length: int):
def get_output_text_to_return(self, buffer_length: int,
delta: bool) -> str:
"""If delta is True, only new text since the last call to
this method is returned"""
# We return the full output text if the sequence is finished.
truncate = buffer_length and not self.is_finished()
return self.output_text[:-buffer_length] if truncate else (
self.output_text)
if not delta:
return self.output_text[:-buffer_length] if truncate else (
self.output_text)
length = len(self.output_text) - buffer_length
last_offset = self._last_output_text_offset
if last_offset < length:
self._last_output_text_offset = length
return self.output_text[last_offset:length]
return ""
def get_output_token_ids_to_return(self,
delta: bool) -> GenericSequence[int]:
"""If delta is True, only new tokens since the last call to
this method are returned"""
if not delta:
return self.get_output_token_ids()
length = self.get_output_len()
last_offset = self._last_token_ids_offset
if last_offset < length:
self._last_token_ids_offset = length
return self.data._output_token_ids[last_offset:]
return ()
def hash_of_block(self, logical_idx: int) -> int:
# TODO This can produce incorrect hash when block size > prompt size
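
The two `*_to_return` methods above implement delta tracking with per-sequence offsets: each call returns only what was appended since the previous call and then advances the offset. A standalone sketch of that bookkeeping with a simplified stand-in class (`DeltaTracker` is illustrative, not the real `Sequence` API):

```python
from typing import List
from typing import Sequence as GenericSequence


class DeltaTracker:
    """Simplified illustration of the offset bookkeeping in Sequence."""

    def __init__(self) -> None:
        self.output_text = ""
        self.output_token_ids: List[int] = []
        self._last_output_text_offset = 0
        self._last_token_ids_offset = 0

    def append(self, text: str, token_ids: List[int]) -> None:
        self.output_text += text
        self.output_token_ids.extend(token_ids)

    def text_delta(self) -> str:
        # Only the text added since the previous call.
        last = self._last_output_text_offset
        self._last_output_text_offset = len(self.output_text)
        return self.output_text[last:]

    def token_ids_delta(self) -> GenericSequence[int]:
        # Only the token ids added since the previous call.
        last = self._last_token_ids_offset
        self._last_token_ids_offset = len(self.output_token_ids)
        return self.output_token_ids[last:]


t = DeltaTracker()
t.append("Hel", [1])
print(t.text_delta(), list(t.token_ids_delta()))  # Hel [1]
t.append("lo", [2, 3])
print(t.text_delta(), list(t.token_ids_delta()))  # lo [2, 3]
```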