[Core] Add engine option to return only deltas or final output (#7381)
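Editor's note: a rough sketch of how the new option is meant to be used, based on the test and SamplingParams changes in the diff below rather than on code from the commit itself; the prompt and request id are placeholders.

# Editor's sketch (not part of the commit): streaming delta-only outputs from
# the async engine, mirroring the new test_output_kinds test added below.
import asyncio

from vllm import SamplingParams
from vllm.engine.async_llm_engine import AsyncEngineArgs, AsyncLLMEngine
from vllm.sampling_params import RequestOutputKind


async def main():
    engine = AsyncLLMEngine.from_engine_args(
        AsyncEngineArgs(model="facebook/opt-125m", enforce_eager=True))
    params = SamplingParams(max_tokens=32,
                            output_kind=RequestOutputKind.DELTA)
    text = ""
    async for out in engine.generate("Hello, my name is",
                                     params,
                                     request_id="example-0"):
        # With DELTA, each RequestOutput carries only the text/tokens
        # generated since the previous RequestOutput.
        text += out.outputs[0].text
    print(text)


if __name__ == "__main__":
    asyncio.run(main())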
@@ -50,6 +50,7 @@ steps:
- tests/worker
commands:
- pytest -v -s async_engine # Async Engine
- NUM_SCHEDULER_STEPS=4 pytest -v -s async_engine/test_async_llm_engine.py
- pytest -v -s test_inputs.py
- pytest -v -s multimodal
- pytest -v -s test_utils.py # Utils
@@ -1,7 +1,10 @@
import asyncio
import os
import uuid
from asyncio import CancelledError
from copy import copy
from dataclasses import dataclass
from typing import Optional
from typing import List, Optional

import pytest
import pytest_asyncio
@@ -11,6 +14,7 @@ from vllm import SamplingParams
from vllm.config import ParallelConfig
from vllm.engine.async_llm_engine import AsyncEngineArgs, AsyncLLMEngine
from vllm.outputs import RequestOutput as RealRequestOutput
from vllm.sampling_params import RequestOutputKind

from ..conftest import cleanup
from ..utils import wait_for_gpu_memory_to_clear
@@ -122,8 +126,17 @@ def start_engine():
timeout_s=60,
)

num_scheduler_steps = int(os.getenv("NUM_SCHEDULER_STEPS", "1"))
print(f"Starting engine with num_scheduler_steps={num_scheduler_steps}")

return AsyncLLMEngine.from_engine_args(
AsyncEngineArgs(model="facebook/opt-125m", enforce_eager=True))
AsyncEngineArgs(model="facebook/opt-125m",
enforce_eager=True,
num_scheduler_steps=num_scheduler_steps))


def uid() -> str:
return str(uuid.uuid4())


@pytest_asyncio.fixture(scope="module")
@@ -148,57 +161,177 @@ def should_do_global_cleanup_after_test(request) -> bool:
@pytest.mark.asyncio(scope="module")
async def test_asyncio_run(async_engine):

scheduler_config = await async_engine.get_scheduler_config()
num_scheduler_steps = scheduler_config.num_scheduler_steps

async def run(prompt: str):
sampling_params = SamplingParams(
temperature=0,
max_tokens=32,
min_tokens=32,
)

output_count = 0
final_output = None
async for output in async_engine.generate(prompt,
sampling_params,
request_id=prompt):
request_id=uid()):
output_count += 1
final_output = output
return final_output
return final_output, output_count

results = await asyncio.gather(
run("test0"),
run("test1"),
run("test0"),
)
assert len(results) == 2
first, second = results

# remove nondeterministic fields for comparison
first[0].metrics = None
second[0].metrics = None
first[0].request_id = None
second[0].request_id = None

assert str(first) == str(second)

output_count = results[0][1]
if num_scheduler_steps == 1:
assert output_count == 32
else:
assert 1 < output_count < 32


@pytest.mark.asyncio(scope="module")
async def test_output_kinds(async_engine):
"""Test that output_kind works as expected and that
results are equivalent across different kinds."""

scheduler_config = await async_engine.get_scheduler_config()
num_scheduler_steps = scheduler_config.num_scheduler_steps

sampling_params = SamplingParams(
temperature=0,
max_tokens=32,
min_tokens=32,
)

async def run(prompt: str, kind: RequestOutputKind):
params = copy(sampling_params)
params.output_kind = kind

output_count = 0
final_output = None
async for output in async_engine.generate(prompt,
params,
request_id=uid()):
output_count += 1
final_output = output

assert final_output is not None
return (final_output.prompt_token_ids,
final_output.outputs[0].token_ids,
final_output.outputs[0].text, output_count)

async def run_deltas(prompt: str):
params = copy(sampling_params)
params.output_kind = RequestOutputKind.DELTA

prompt_tokens = None
output_tokens: List[int] = []
output_text = ""
output_count = 0
async for output in async_engine.generate(prompt,
params,
request_id=uid()):
token_ids = output.outputs[0].token_ids
text = output.outputs[0].text

# Ensure we get prompt ids iff we haven't yet received output tokens
if output_tokens:
assert 1 <= len(token_ids) <= num_scheduler_steps
assert text
assert not output.prompt_token_ids
else:
assert output.prompt_token_ids
prompt_tokens = output.prompt_token_ids

output_tokens.extend(token_ids)
output_text += text

output_count += 1
return prompt_tokens, output_tokens, output_text, output_count

results = await asyncio.gather(
run("common input prompt", RequestOutputKind.CUMULATIVE),
run("common input prompt", RequestOutputKind.FINAL_ONLY),
run_deltas("common input prompt"))

# Make sure outputs are the same
prompt_set = set(tuple(prompt_ids) for prompt_ids, _, _, _ in results)
assert len(prompt_set) == 1

text_set = set(text for _, _, text, _ in results)
assert len(text_set) == 1

tokens_set = set(tuple(ids) for _, ids, _, _ in results)
assert len(tokens_set) == 1

cumulative, final, deltas = results

# output message counts
assert cumulative[3] == deltas[3]

if num_scheduler_steps == 1:
assert cumulative[3] == 32
else:
assert 1 < cumulative[3] < 32

assert final[3] == 1


@pytest.mark.asyncio(scope="module")
async def test_cancellation(async_engine):
scheduler_config = await async_engine.get_scheduler_config()
num_scheduler_steps = scheduler_config.num_scheduler_steps

sampling_params = SamplingParams(
temperature=0,
min_tokens=10,
max_tokens=10,
min_tokens=13,
max_tokens=13,
)

stop_at = 5 if num_scheduler_steps == 1 else 1

request_id = uid()

i = 0
with pytest.raises(CancelledError):
async for output in async_engine.generate("test2",
sampling_params,
request_id="test2"):
request_id=request_id):
assert not output.finished
i += 1
if i == 5:
await async_engine.abort("test2")
if i == stop_at:
await async_engine.abort(request_id)

assert i == 5
assert i == stop_at


@pytest.mark.asyncio(scope="module")
async def test_delayed_generator(async_engine):
scheduler_config = await async_engine.get_scheduler_config()

if scheduler_config.num_scheduler_steps != 1:
pytest.skip("no need to test this one with multistep")

sampling_params = SamplingParams(
temperature=0,
min_tokens=10,
max_tokens=10,
)

stream = async_engine.generate("test3",
sampling_params,
request_id="test3")
stream = async_engine.generate("test3", sampling_params, request_id=uid())
i = 0
final_output: Optional[RealRequestOutput] = None
async for output in stream:
@@ -39,7 +39,7 @@ from vllm.outputs import (EmbeddingRequestOutput, RequestOutput,
RequestOutputFactory)
from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams
from vllm.sampling_params import RequestOutputKind, SamplingParams
from vllm.sequence import (EmbeddingSequenceGroupOutput, ExecuteModelRequest,
Sequence, SequenceGroup, SequenceGroupMetadata,
SequenceStatus)
@@ -225,9 +225,6 @@ class LLMEngine:
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
input_registry: InputRegistry = INPUT_REGISTRY,
# To improve performance, only final requests outputs may be required.
# If this set to true, then no intermediate outputs will be returned.
step_return_finished_only: bool = False,
) -> None:
logger.info(
"Initializing an LLM engine (v%s) with config: "
@@ -295,7 +292,6 @@ class LLMEngine:
self.observability_config = observability_config or ObservabilityConfig(
)
self.log_stats = log_stats
self.step_return_finished_only = step_return_finished_only

if not self.model_config.skip_tokenizer_init:
self.tokenizer = self._init_tokenizer()
@@ -1273,7 +1269,7 @@ class LLMEngine:

ctx: The virtual engine context to work on
request_id: If provided, then only this request is going to be processed


"""
now = time.time()

@@ -1378,7 +1374,8 @@ class LLMEngine:
seq_group = scheduled_seq_group.seq_group
seq_group.maybe_set_first_token_time(now)
request_output = RequestOutputFactory.create(seq_group)
ctx.request_outputs.append(request_output)
if request_output:
ctx.request_outputs.append(request_output)

# When we process a single request, we skip it for the next time,
# and invoke the request output callback (if there was final output)
@@ -1415,14 +1412,19 @@ class LLMEngine:

seq_group = scheduled_seq_group.seq_group
seq_group.maybe_set_first_token_time(now)
if (seq_group.is_finished()
if self.step_return_finished_only else True):
request_output = RequestOutputFactory.create(seq_group)
request_output = RequestOutputFactory.create(seq_group)
if request_output:
ctx.request_outputs.append(request_output)

for seq_group in scheduler_outputs.ignored_seq_groups:
params = seq_group.sampling_params
if params is not None and params.output_kind == (
RequestOutputKind.DELTA) and not seq_group.is_finished():
continue

request_output = RequestOutputFactory.create(seq_group)
ctx.request_outputs.append(request_output)
if request_output:
ctx.request_outputs.append(request_output)

# Immediately process request outputs here (if callback is given)
if (ctx.request_outputs
@@ -19,7 +19,7 @@ from vllm.model_executor.guided_decoding.guided_fields import LLMGuidedOptions
from vllm.outputs import EmbeddingRequestOutput, RequestOutput
from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams
from vllm.sampling_params import RequestOutputKind, SamplingParams
from vllm.transformers_utils.tokenizer import (AnyTokenizer, MistralTokenizer,
get_cached_tokenizer)
from vllm.transformers_utils.tokenizer_group import TokenizerGroup
@@ -642,14 +642,12 @@ class LLM:
raise ValueError("The lengths of prompts and lora_request "
"must be the same.")

if isinstance(params, list):
params = [
self._add_guided_processor(param, guided_options)
if isinstance(param, SamplingParams) else param
for param in params
]
elif isinstance(params, SamplingParams):
params = self._add_guided_processor(params, guided_options)
for sp in params if isinstance(params, list) else (params, ):
if isinstance(sp, SamplingParams):
self._add_guided_processor(sp, guided_options)

# We only care about the final output
sp.output_kind = RequestOutputKind.FINAL_ONLY

# Add requests to the engine.
for i, request_inputs in enumerate(inputs):
@@ -709,9 +707,6 @@ class LLM:
f"output: {0:.2f} toks/s"),
)

# In the loop below, only finished outputs are used
self.llm_engine.step_return_finished_only = True

# Run the engine.
outputs: List[Union[RequestOutput, EmbeddingRequestOutput]] = []
total_in_toks = 0
@@ -724,6 +719,7 @@ class LLM:
if use_tqdm:
if isinstance(output, RequestOutput):
# Calculate tokens only for RequestOutput
assert output.prompt_token_ids is not None
total_in_toks += len(output.prompt_token_ids)
in_spd = total_in_toks / pbar.format_dict["elapsed"]
total_out_toks += sum(
@@ -735,9 +731,6 @@ class LLM:
f"output: {out_spd:.2f} toks/s")
pbar.update(1)

# Restore original behavior
self.llm_engine.step_return_finished_only = False

if use_tqdm:
pbar.close()
# Sort the outputs by request ID.
@@ -12,7 +12,8 @@ from typing_extensions import Annotated, Required, TypedDict
from vllm.entrypoints.chat_utils import ChatCompletionMessageParam
from vllm.entrypoints.openai.logits_processors import get_logits_processors
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import LogitsProcessor, SamplingParams
from vllm.sampling_params import (LogitsProcessor, RequestOutputKind,
SamplingParams)
from vllm.sequence import Logprob
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils import random_uuid
@@ -316,6 +317,8 @@ class ChatCompletionRequest(OpenAIBaseModel):
length_penalty=self.length_penalty,
logits_processors=logits_processors,
truncate_prompt_tokens=self.truncate_prompt_tokens,
output_kind=RequestOutputKind.DELTA if self.stream \
else RequestOutputKind.FINAL_ONLY,
)

@model_validator(mode="before")
@@ -559,6 +562,8 @@ class CompletionRequest(OpenAIBaseModel):
length_penalty=self.length_penalty,
logits_processors=logits_processors,
truncate_prompt_tokens=self.truncate_prompt_tokens,
output_kind=RequestOutputKind.DELTA if self.stream \
else RequestOutputKind.FINAL_ONLY,
)

@model_validator(mode="before")
@@ -246,8 +246,7 @@ class OpenAIServingChat(OpenAIServing):
def get_chat_request_role(self, request: ChatCompletionRequest) -> str:
if request.add_generation_prompt:
return self.response_role
else:
return request.messages[-1]["role"]
return request.messages[-1]["role"]

async def chat_completion_stream_generator(
self,
@@ -264,15 +263,37 @@ class OpenAIServingChat(OpenAIServing):

# Send response for each token for each request.n (index)
num_choices = 1 if request.n is None else request.n
previous_texts = [""] * num_choices
previous_num_tokens = [0] * num_choices
finish_reason_sent = [False] * num_choices

num_prompt_tokens = 0

tool_parser: Optional[ToolParser] = self.tool_parser(
tokenizer) if self.tool_parser else None

if isinstance(request.tool_choice, ChatCompletionNamedToolChoiceParam):
tool_choice_function_name = request.tool_choice.function.name
else:
tool_choice_function_name = None

# Determine whether tools are in use with "auto" tool choice
tool_choice_auto = (
not tool_choice_function_name
and self._should_stream_with_auto_tool_parsing(request))

all_previous_token_ids: Optional[List[List[int]]]
if tool_choice_auto:
# These are only required in "auto" tool choice case
previous_texts = [""] * num_choices
all_previous_token_ids = [[]] * num_choices
else:
previous_texts, all_previous_token_ids = None, None

try:
async for res in result_generator:
if res.prompt_token_ids is not None:
num_prompt_tokens = len(res.prompt_token_ids)

# We need to do it here, because if there are exceptions in
# the result_generator, it needs to be sent as the FIRST
# response (by the try...catch).
@@ -305,10 +326,10 @@ class OpenAIServingChat(OpenAIServing):
and request.stream_options.include_usage):
# if continuous usage stats are requested, add it
if request.stream_options.continuous_usage_stats:
prompt_tokens = len(res.prompt_token_ids)
usage = UsageInfo(prompt_tokens=prompt_tokens,
completion_tokens=0,
total_tokens=prompt_tokens)
usage = UsageInfo(
prompt_tokens=num_prompt_tokens,
completion_tokens=0,
total_tokens=num_prompt_tokens)
chunk.usage = usage
# otherwise don't
else:
@@ -344,12 +365,10 @@ class OpenAIServingChat(OpenAIServing):
request.stream_options.include_usage):
if (request.stream_options.
continuous_usage_stats):
prompt_tokens = len(
res.prompt_token_ids)
usage = UsageInfo(
prompt_tokens=prompt_tokens,
prompt_tokens=num_prompt_tokens,
completion_tokens=0,
total_tokens=prompt_tokens)
total_tokens=num_prompt_tokens)
chunk.usage = usage
else:
chunk.usage = None
@@ -360,65 +379,66 @@ class OpenAIServingChat(OpenAIServing):
first_iteration = False

for output in res.outputs:

i = output.index

if finish_reason_sent[i]:
continue

delta_token_ids = output.token_ids[previous_num_tokens[i]:]
out_logprobs = output.logprobs[
previous_num_tokens[i]:] if output.logprobs else None

if request.logprobs and request.top_logprobs is not None:
assert out_logprobs is not None, (
assert output.logprobs is not None, (
"Did not output logprobs")
logprobs = self._create_chat_logprobs(
token_ids=delta_token_ids,
top_logprobs=out_logprobs,
token_ids=output.token_ids,
top_logprobs=output.logprobs,
tokenizer=tokenizer,
num_output_top_logprobs=request.top_logprobs,
)
else:
logprobs = None

delta_text = output.text[len(previous_texts[i]):]
delta_message: Optional[DeltaMessage] = None
delta_text = output.text
delta_message: Optional[DeltaMessage]

# handle streaming deltas for tools with named tool_choice
if (request.tool_choice and type(request.tool_choice) is
ChatCompletionNamedToolChoiceParam):
if tool_choice_function_name:
delta_message = DeltaMessage(tool_calls=[
DeltaToolCall(function=DeltaFunctionCall(
name=request.tool_choice.function.name,
name=tool_choice_function_name,
arguments=delta_text),
index=i)
])

# handle streaming deltas for tools with "auto" tool choice
elif (self._should_stream_with_auto_tool_parsing(request)
and tool_parser):
elif tool_choice_auto:
assert previous_texts is not None
assert all_previous_token_ids is not None
assert tool_parser is not None
#TODO optimize manipulation of these lists
previous_text = previous_texts[i]
previous_token_ids = all_previous_token_ids[i]
current_text = previous_text + delta_text
current_token_ids = previous_token_ids + list(
output.token_ids)

delta_message = (
tool_parser.extract_tool_calls_streaming(
previous_text=previous_texts[i],
current_text=output.text,
previous_text=previous_text,
current_text=current_text,
delta_text=delta_text,
previous_token_ids= \
output.token_ids[
:-1 * len(delta_token_ids)
],
current_token_ids=output.token_ids,
delta_token_ids=delta_token_ids
)
)
previous_token_ids=previous_token_ids,
current_token_ids=current_token_ids,
delta_token_ids=output.token_ids))

# update the previous values for the next iteration
previous_texts[i] = current_text
all_previous_token_ids[i] = current_token_ids

# handle streaming just a content delta
else:
delta_message = DeltaMessage(content=delta_text)

# set the previous values for the next iteration
previous_texts[i] = output.text
previous_num_tokens[i] = len(output.token_ids)
previous_num_tokens[i] += len(output.token_ids)

# if the message delta is None (e.g. because it was a
# "control token" for tool calls or the parser otherwise
@@ -445,13 +465,12 @@ class OpenAIServingChat(OpenAIServing):
# handle usage stats if requested & if continuous
if (request.stream_options
and request.stream_options.include_usage):
if (request.stream_options.continuous_usage_stats):
prompt_tokens = len(res.prompt_token_ids)
if request.stream_options.continuous_usage_stats:
completion_tokens = len(output.token_ids)
usage = UsageInfo(
prompt_tokens=prompt_tokens,
prompt_tokens=num_prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=prompt_tokens +
total_tokens=num_prompt_tokens +
completion_tokens,
)
chunk.usage = usage
@@ -482,7 +501,7 @@ class OpenAIServingChat(OpenAIServing):
tool_parser.prev_tool_call_arr[index].get(
"arguments", {}))

# get what we've streamed so for for arguments
# get what we've streamed so far for arguments
# for the current tool
actual_call = tool_parser.streamed_args_for_tool[
index]
@@ -500,7 +519,6 @@ class OpenAIServingChat(OpenAIServing):
])

# Send the finish response for each request.n only once
prompt_tokens = len(res.prompt_token_ids)
choice_data = ChatCompletionResponseStreamChoice(
index=i,
delta=delta_message,
@@ -518,13 +536,12 @@ class OpenAIServingChat(OpenAIServing):
model=model_name)
if (request.stream_options
and request.stream_options.include_usage):
if (request.stream_options.continuous_usage_stats):
prompt_tokens = len(res.prompt_token_ids)
if request.stream_options.continuous_usage_stats:
completion_tokens = len(output.token_ids)
usage = UsageInfo(
prompt_tokens=prompt_tokens,
prompt_tokens=num_prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=prompt_tokens +
total_tokens=num_prompt_tokens +
completion_tokens,
)
chunk.usage = usage
@@ -538,10 +555,11 @@ class OpenAIServingChat(OpenAIServing):
# is sent, send the usage
if (request.stream_options
and request.stream_options.include_usage):
completion_tokens = previous_num_tokens[i]
final_usage = UsageInfo(
prompt_tokens=prompt_tokens,
completion_tokens=previous_num_tokens[i],
total_tokens=prompt_tokens + previous_num_tokens[i],
prompt_tokens=num_prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=num_prompt_tokens + completion_tokens,
)

final_usage_chunk = ChatCompletionStreamResponse(
@@ -680,6 +698,7 @@ class OpenAIServingChat(OpenAIServing):
or "")
choice.message.content = full_message

assert final_res.prompt_token_ids is not None
num_prompt_tokens = len(final_res.prompt_token_ids)
num_generated_tokens = sum(
len(output.token_ids) for output in final_res.outputs)
@@ -789,9 +808,9 @@ class OpenAIServingChat(OpenAIServing):
return bool(
# if there is a delta message that includes tool calls which
# include a function that has arguments
self.enable_auto_tools and self.tool_parser and delta_message
output.finish_reason is not None
and self.enable_auto_tools and self.tool_parser and delta_message
and delta_message.tool_calls and delta_message.tool_calls[0]
and delta_message.tool_calls[0].function
and delta_message.tool_calls[0].function.arguments is not None
and output.finish_reason is not None
)
@@ -223,9 +223,10 @@ class OpenAIServingCompletion(OpenAIServing):
tokenizer: AnyTokenizer,
) -> AsyncGenerator[str, None]:
num_choices = 1 if request.n is None else request.n
previous_texts = [""] * num_choices * num_prompts
previous_text_lens = [0] * num_choices * num_prompts
previous_num_tokens = [0] * num_choices * num_prompts
has_echoed = [False] * num_choices * num_prompts
num_prompt_tokens = [0] * num_prompts

try:
async for prompt_idx, res in result_generator:
@@ -233,6 +234,10 @@ class OpenAIServingCompletion(OpenAIServing):
prompt_logprobs = res.prompt_logprobs
prompt_text = res.prompt

# Prompt details are excluded from later streamed outputs
if res.prompt_token_ids is not None:
num_prompt_tokens[prompt_idx] = len(res.prompt_token_ids)

delta_token_ids: GenericSequence[int]
out_logprobs: Optional[GenericSequence[Optional[Dict[
int, Logprob]]]]
@@ -244,6 +249,7 @@ class OpenAIServingCompletion(OpenAIServing):

assert request.max_tokens is not None
if request.echo and request.max_tokens == 0:
assert prompt_token_ids is not None
assert prompt_text is not None
# only return the prompt
delta_text = prompt_text
@@ -252,6 +258,7 @@ class OpenAIServingCompletion(OpenAIServing):
has_echoed[i] = True
elif (request.echo and request.max_tokens > 0
and not has_echoed[i]):
assert prompt_token_ids is not None
assert prompt_text is not None
assert prompt_logprobs is not None
# echo the prompt and first token
@@ -266,11 +273,9 @@ class OpenAIServingCompletion(OpenAIServing):
has_echoed[i] = True
else:
# return just the delta
delta_text = output.text[len(previous_texts[i]):]
delta_token_ids = output.token_ids[
previous_num_tokens[i]:]
out_logprobs = output.logprobs[previous_num_tokens[
i]:] if output.logprobs else None
delta_text = output.text
delta_token_ids = output.token_ids
out_logprobs = output.logprobs

if request.logprobs is not None:
assert out_logprobs is not None, (
@@ -280,13 +285,13 @@ class OpenAIServingCompletion(OpenAIServing):
top_logprobs=out_logprobs,
num_output_top_logprobs=request.logprobs,
tokenizer=tokenizer,
initial_text_offset=len(previous_texts[i]),
initial_text_offset=previous_text_lens[i],
)
else:
logprobs = None

previous_texts[i] = output.text
previous_num_tokens[i] = len(output.token_ids)
previous_text_lens[i] += len(output.text)
previous_num_tokens[i] += len(output.token_ids)
finish_reason = output.finish_reason
stop_reason = output.stop_reason

@@ -307,8 +312,8 @@ class OpenAIServingCompletion(OpenAIServing):
and request.stream_options.include_usage):
if (request.stream_options.continuous_usage_stats
or output.finish_reason is not None):
prompt_tokens = len(prompt_token_ids)
completion_tokens = len(output.token_ids)
prompt_tokens = num_prompt_tokens[prompt_idx]
completion_tokens = previous_num_tokens[i]
usage = UsageInfo(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
@@ -356,6 +361,7 @@ class OpenAIServingCompletion(OpenAIServing):

for final_res in final_res_batch:
prompt_token_ids = final_res.prompt_token_ids
assert prompt_token_ids is not None
prompt_logprobs = final_res.prompt_logprobs
prompt_text = final_res.prompt

@@ -411,9 +417,9 @@ class OpenAIServingCompletion(OpenAIServing):
)
choices.append(choice_data)

num_generated_tokens += len(output.token_ids)

num_prompt_tokens += len(prompt_token_ids)
num_generated_tokens += sum(
len(output.token_ids) for output in final_res.outputs)

usage = UsageInfo(
prompt_tokens=num_prompt_tokens,
@@ -5,6 +5,7 @@ from typing import Sequence as GenericSequence
from typing import Union

from vllm.lora.request import LoRARequest
from vllm.sampling_params import RequestOutputKind
from vllm.sequence import (PromptLogprobs, RequestMetrics, SampleLogprobs,
SequenceGroup, SequenceStatus)

@@ -92,7 +93,7 @@ class RequestOutput:
self,
request_id: str,
prompt: Optional[str],
prompt_token_ids: List[int],
prompt_token_ids: Optional[List[int]],
prompt_logprobs: Optional[PromptLogprobs],
outputs: List[CompletionOutput],
finished: bool,
@@ -113,19 +114,26 @@ class RequestOutput:
self.encoder_prompt_token_ids = encoder_prompt_token_ids

@classmethod
def from_seq_group(cls, seq_group: SequenceGroup) -> "RequestOutput":
if seq_group.sampling_params is None:
def from_seq_group(cls,
seq_group: SequenceGroup) -> Optional["RequestOutput"]:
sampling_params = seq_group.sampling_params
if sampling_params is None:
raise ValueError(
"Sampling parameters are missing for a CompletionRequest.")
finished = seq_group.is_finished()
if sampling_params.output_kind == RequestOutputKind.FINAL_ONLY and (
not finished):
return None

seqs = seq_group.get_seqs()
if len(seqs) == 1:
top_n_seqs = seqs
else:
# Get the top-n sequences.
n = seq_group.sampling_params.n
if seq_group.sampling_params.use_beam_search:
n = sampling_params.n
if sampling_params.use_beam_search:
sorting_key = lambda seq: seq.get_beam_search_score(
seq_group.sampling_params.length_penalty)
sampling_params.length_penalty)
else:
sorting_key = lambda seq: seq.get_cumulative_logprob()
sorted_seqs = sorted(seqs, key=sorting_key, reverse=True)
@@ -135,26 +143,49 @@ class RequestOutput:
# NOTE: We need omit logprobs here explicitly because the sequence
# always has the logprobs of the sampled tokens even if the
# logprobs are not requested.
include_logprobs = seq_group.sampling_params.logprobs is not None
text_buffer_length = seq_group.sampling_params.output_text_buffer_length
outputs = [
CompletionOutput(
seqs.index(seq),
seq.get_output_text_to_return(text_buffer_length),
seq.data._output_token_ids,
seq.get_cumulative_logprob() if include_logprobs else None,
seq.output_logprobs if include_logprobs else None,
SequenceStatus.get_finished_reason(seq.status),
seq.stop_reason) for seq in top_n_seqs
]
include_logprobs = sampling_params.logprobs is not None
text_buffer_length = sampling_params.output_text_buffer_length
delta = sampling_params.output_kind == RequestOutputKind.DELTA

outputs = []
include_prompt = True
for seq in top_n_seqs:
output_text = seq.get_output_text_to_return(
text_buffer_length, delta)
output_token_ids = seq.get_output_token_ids_to_return(delta)
output_logprobs = seq.output_logprobs if include_logprobs else None

if delta:
# Slice logprobs delta if applicable
if output_logprobs:
output_logprobs = output_logprobs[-len(output_token_ids):]
# Don't include prompt if this is after the first output
# containing decode token ids
if include_prompt and seq.get_output_len() > len(
output_token_ids):
include_prompt = False

outputs.append(
CompletionOutput(
seqs.index(seq), output_text, output_token_ids,
seq.get_cumulative_logprob() if include_logprobs else None,
output_logprobs,
SequenceStatus.get_finished_reason(seq.status),
seq.stop_reason))

# Every sequence in the sequence group should have the same prompt.
prompt = seq_group.prompt
prompt_token_ids = seq_group.prompt_token_ids
encoder_prompt = seq_group.encoder_prompt
encoder_prompt_token_ids = seq_group.encoder_prompt_token_ids
prompt_logprobs = seq_group.prompt_logprobs
finished = seq_group.is_finished()
if include_prompt:
prompt = seq_group.prompt
prompt_token_ids = seq_group.prompt_token_ids
encoder_prompt = seq_group.encoder_prompt
encoder_prompt_token_ids = seq_group.encoder_prompt_token_ids
prompt_logprobs = seq_group.prompt_logprobs
else:
prompt = None
prompt_token_ids = None
encoder_prompt = None
encoder_prompt_token_ids = None
prompt_logprobs = None
finished_time = time.time() if finished else None
seq_group.set_finished_time(finished_time)
return cls(seq_group.request_id,
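Editor's note: because RequestOutput.from_seq_group can now return None (FINAL_ONLY before completion, or a step with nothing new to report), callers keep only truthy results, as the llm_engine changes above do. Below is a minimal sketch of rebuilding a full result from DELTA outputs, mirroring the run_deltas helper in the new test; accumulate_deltas is a hypothetical name, not part of the commit.

# Editor's sketch (assumed helper, not in the commit): with DELTA, the prompt
# fields are populated only until the first output that already contains
# decode tokens, and each later output carries only the new tokens/text.
def accumulate_deltas(delta_outputs):
    prompt_token_ids = None
    token_ids, text = [], ""
    for out in delta_outputs:      # each `out` is a DELTA RequestOutput
        if prompt_token_ids is None and out.prompt_token_ids is not None:
            prompt_token_ids = out.prompt_token_ids
        token_ids.extend(out.outputs[0].token_ids)
        text += out.outputs[0].text
    return prompt_token_ids, token_ids, text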
@@ -1,6 +1,6 @@
"""Sampling parameters for text generation."""
import copy
from enum import IntEnum
from enum import Enum, IntEnum
from functools import cached_property
from typing import Any, Callable, Dict, List, Optional, Set, Union

@@ -33,6 +33,15 @@ first argument, and returns a modified tensor of logits
to sample from."""


class RequestOutputKind(Enum):
# Return entire output so far in every RequestOutput
CUMULATIVE = 0
# Return only deltas in each RequestOutput
DELTA = 1
# Do not return intermediate RequestOuputs
FINAL_ONLY = 2


class SamplingParams(
msgspec.Struct,
omit_defaults=True, # type: ignore[call-arg]
@@ -147,6 +156,7 @@ class SamplingParams(
logits_processors: Optional[Any] = None
include_stop_str_in_output: bool = False
truncate_prompt_tokens: Optional[Annotated[int, msgspec.Meta(ge=1)]] = None
output_kind: RequestOutputKind = RequestOutputKind.CUMULATIVE

# The below fields are not supposed to be used as an input.
# They are set in post_init.
@@ -182,6 +192,7 @@ class SamplingParams(
logits_processors: Optional[List[LogitsProcessor]] = None,
truncate_prompt_tokens: Optional[Annotated[int,
msgspec.Meta(ge=1)]] = None,
output_kind: RequestOutputKind = RequestOutputKind.CUMULATIVE,
) -> "SamplingParams":
return SamplingParams(
n=1 if n is None else n,
@@ -213,6 +224,7 @@ class SamplingParams(
spaces_between_special_tokens=spaces_between_special_tokens,
logits_processors=logits_processors,
truncate_prompt_tokens=truncate_prompt_tokens,
output_kind=output_kind,
)

def __post_init__(self) -> None:
@@ -317,6 +329,9 @@ class SamplingParams(
raise ValueError(
"stop strings are only supported when detokenize is True. "
"Set detokenize=True to use stop.")
if self.best_of != self.n and self.output_kind == (
RequestOutputKind.DELTA):
raise ValueError("best_of must equal n to use output_kind=DELTA")

def _verify_beam_search(self) -> None:
if self.best_of == 1:
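Editor's note: a small sketch of the new field in use, assuming only what the hunks above add; the printed message comes from the validation introduced in __post_init__ here.

# Editor's sketch: the three output kinds, and the new check that DELTA
# requires best_of == n.
from vllm import SamplingParams
from vllm.sampling_params import RequestOutputKind

batch_params = SamplingParams(max_tokens=16,
                              output_kind=RequestOutputKind.FINAL_ONLY)
stream_params = SamplingParams(max_tokens=16,
                               output_kind=RequestOutputKind.DELTA)

try:
    SamplingParams(n=1, best_of=2, output_kind=RequestOutputKind.DELTA)
except ValueError as e:
    print(e)  # best_of must equal n to use output_kind=DELTA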
@@ -5,8 +5,9 @@ from abc import ABC, abstractmethod
from array import array
from collections import defaultdict
from dataclasses import dataclass
from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Mapping,
Optional, Set, Tuple, Union, cast)
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Mapping, Optional
from typing import Sequence as GenericSequence
from typing import Set, Tuple, Union, cast

import msgspec
import torch
@@ -407,6 +408,10 @@ class Sequence:
self.status = SequenceStatus.WAITING
self.stop_reason: Union[int, str, None] = None

# These are used to keep track of delta outputs
self._last_token_ids_offset: int = 0
self._last_output_text_offset: int = 0

# Used for incremental detokenization
self.prefix_offset = 0
self.read_offset = 0
@@ -462,11 +467,35 @@ class Sequence:
return self.prompt_adapter_request.prompt_adapter_id \
if self.prompt_adapter_request else 0

def get_output_text_to_return(self, buffer_length: int):
def get_output_text_to_return(self, buffer_length: int,
delta: bool) -> str:
"""If delta is True, only new text since the last call to
this method is returned"""

# We return the full output text if the sequence is finished.
truncate = buffer_length and not self.is_finished()
return self.output_text[:-buffer_length] if truncate else (
self.output_text)
if not delta:
return self.output_text[:-buffer_length] if truncate else (
self.output_text)
length = len(self.output_text) - buffer_length
last_offset = self._last_output_text_offset
if last_offset < length:
self._last_output_text_offset = length
return self.output_text[last_offset:length]
return ""

def get_output_token_ids_to_return(self,
delta: bool) -> GenericSequence[int]:
"""If delta is True, only new tokens since the last call to
this method are returned"""
if not delta:
return self.get_output_token_ids()
length = self.get_output_len()
last_offset = self._last_token_ids_offset
if last_offset < length:
self._last_token_ids_offset = length
return self.data._output_token_ids[last_offset:]
return ()

def hash_of_block(self, logical_idx: int) -> int:
# TODO This can produce incorrect hash when block size > prompt size
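Editor's note: the delta bookkeeping above amounts to a per-sequence watermark: remember how much has already been handed out and return only the suffix past that point. A standalone toy version of the same pattern follows (not vllm code; all names are illustrative).

# Editor's toy version of the watermark pattern used by
# get_output_token_ids_to_return / get_output_text_to_return above.
from typing import List


class DeltaTracker:
    def __init__(self) -> None:
        self._tokens: List[int] = []
        self._last_offset = 0      # how many tokens were already returned

    def append(self, token_id: int) -> None:
        self._tokens.append(token_id)

    def tokens_to_return(self, delta: bool) -> List[int]:
        if not delta:
            return list(self._tokens)          # cumulative view
        last, length = self._last_offset, len(self._tokens)
        if last < length:
            self._last_offset = length
            return self._tokens[last:]         # only the new suffix
        return []                              # nothing new this step


tracker = DeltaTracker()
tracker.append(1)
tracker.append(2)
print(tracker.tokens_to_return(delta=True))   # [1, 2]
tracker.append(3)
print(tracker.tokens_to_return(delta=True))   # [3]
print(tracker.tokens_to_return(delta=True))   # []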