[V1][Perf] Simpler request output queues (#15156)
Signed-off-by: Nick Hill <nhill@redhat.com>
Signed-off-by: rshaw@neuralmagic.com <robertgshaw2@gmail.com>
Co-authored-by: rshaw@neuralmagic.com <robertgshaw2@gmail.com>
@@ -11,11 +11,13 @@ from tests.v1.engine.utils import (NUM_PROMPT_LOGPROBS_UNDER_TEST,
                                     STOP_STRINGS,
                                     DummyOutputProcessorTestVectors,
                                     MockEngineCore)
 from vllm.outputs import CompletionOutput, RequestOutput
 from vllm.sampling_params import RequestOutputKind, SamplingParams
 from vllm.sequence import PromptLogprobs, SampleLogprobs
 from vllm.transformers_utils.tokenizer import AnyTokenizer
 from vllm.v1.engine import EngineCoreRequest
-from vllm.v1.engine.output_processor import OutputProcessor
+from vllm.v1.engine.output_processor import (OutputProcessor,
+                                              RequestOutputCollector)
 from vllm.v1.metrics.stats import IterationStats

@@ -834,3 +836,88 @@ def test_iteration_stats(dummy_test_vectors):

     assert iteration_stats.num_prompt_tokens == 0
     assert iteration_stats.num_generation_tokens == num_active
+
+
+@pytest.mark.asyncio
+async def test_request_output_collector():
+    NUM_REQS = 3
+    TEXT = "a"
+
+    def make_outputs() -> list[RequestOutput]:
+        return [
+            RequestOutput(
+                request_id="my-request-id",
+                prompt=None,
+                prompt_token_ids=[1, 2, 3],
+                prompt_logprobs=None,
+                outputs=[
+                    CompletionOutput(
+                        index=0,
+                        text=TEXT,
+                        token_ids=[idx],
+                        cumulative_logprob=(idx + 1 * 1.0),
+                        logprobs=[{
+                            "a": idx,
+                            "b": idx
+                        }],
+                        finish_reason="length" if
+                        (idx == NUM_REQS - 1) else None,
+                    )
+                ],
+                finished=(idx == NUM_REQS - 1),
+            ) for idx in range(NUM_REQS)
+        ]
+
+    collector = RequestOutputCollector(RequestOutputKind.DELTA)
+
+    # CASE 1: Put then get.
+    outputs = make_outputs()
+    collector.put(outputs[0])
+    output = await collector.get()
+    assert not collector.ready.is_set()
+    assert collector.output is None
+    assert output.outputs[0].text == "a"
+    assert output.outputs[0].token_ids == [0]
+
+    # CASE 2: 2 puts then get.
+    num_to_put = 2
+    outputs = make_outputs()
+    for i in range(num_to_put):
+        collector.put(outputs[i])
+    output = await collector.get()
+    assert not collector.ready.is_set()
+    assert collector.output is None
+
+    assert not output.finished
+    # Text, token_ids, and logprobs should get merged.
+    assert output.outputs[0].text == TEXT * num_to_put
+    for tok_0, tok_1 in zip(output.outputs[0].token_ids,
+                            list(range(num_to_put))):
+        assert tok_0 == tok_1
+    assert len(output.outputs[0].logprobs) == num_to_put
+
+    # Cumulative logprobs should be the last one.
+    cumulative_logprob_expected = 1.0 * num_to_put
+    assert output.outputs[0].cumulative_logprob == cumulative_logprob_expected
+
+    # CASE 3: Put all 3 (including a finished).
+    num_to_put = 3
+    outputs = make_outputs()
+    for i in range(num_to_put):
+        collector.put(outputs[i])
+    output = await collector.get()
+    assert not collector.ready.is_set()
+    assert collector.output is None
+
+    assert output.finished
+    assert output.outputs[0].finish_reason == "length"
+    # Text, token_ids, and logprobs should get merged.
+    assert output.outputs[0].text == TEXT * num_to_put
+    for tok_0, tok_1 in zip(output.outputs[0].token_ids,
+                            list(range(num_to_put))):
+        assert tok_0 == tok_1
+    assert len(output.outputs[0].logprobs) == num_to_put
+
+    # Cumulative logprobs should be the last one.
+    cumulative_logprob_expected = 1.0 * num_to_put
+    assert output.outputs[0].cumulative_logprob == cumulative_logprob_expected

@@ -21,14 +21,15 @@ from vllm.lora.request import LoRARequest
 from vllm.outputs import RequestOutput
 from vllm.pooling_params import PoolingParams
 from vllm.prompt_adapter.request import PromptAdapterRequest
-from vllm.sampling_params import RequestOutputKind, SamplingParams
+from vllm.sampling_params import SamplingParams
 from vllm.transformers_utils.tokenizer import AnyTokenizer
 from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
 from vllm.usage.usage_lib import UsageContext
 from vllm.utils import Device, cdiv, kill_process_tree
 from vllm.v1.engine import EngineCoreRequest
 from vllm.v1.engine.core_client import EngineCoreClient
-from vllm.v1.engine.output_processor import OutputProcessor
+from vllm.v1.engine.output_processor import (OutputProcessor,
+                                              RequestOutputCollector)
 from vllm.v1.engine.parallel_sampling import ParentRequest
 from vllm.v1.engine.processor import Processor
 from vllm.v1.executor.abstract import Executor

@@ -176,11 +177,14 @@ class AsyncLLM(EngineClient):
         trace_headers: Optional[Mapping[str, str]] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
         priority: int = 0,
-    ) -> asyncio.Queue[RequestOutput]:
+    ) -> RequestOutputCollector:
         """Add new request to the AsyncLLM."""

-        # Create a new output queue for the request.
-        queue: asyncio.Queue[RequestOutput] = asyncio.Queue()
+        assert isinstance(params, SamplingParams), \
+            "Pooling is not supported in V1"
+
+        # Create a new output collector for the request.
+        queue = RequestOutputCollector(output_kind=params.output_kind)

         # Convert Input --> Request.
         request = self.processor.process_inputs(request_id, prompt, params,

@@ -189,17 +193,15 @@ class AsyncLLM(EngineClient):
                                                 prompt_adapter_request,
                                                 priority)

-        n = params.n if isinstance(params, SamplingParams) else 1
-
-        if n == 1:
+        if params.n == 1:
             await self._add_request(request, None, 0, queue)
             return queue

         # Fan out child requests (for n>1).
         parent_request = ParentRequest(request_id, params)
-        for idx in range(n):
+        for idx in range(params.n):
             request_id, params = parent_request.get_child_info(idx)
-            child_request = request if idx == n - 1 else copy(request)
+            child_request = request if idx == params.n - 1 else copy(request)
             child_request.request_id = request_id
             child_request.sampling_params = params
             await self._add_request(child_request, parent_request, idx, queue)

@@ -207,7 +209,7 @@ class AsyncLLM(EngineClient):

     async def _add_request(self, request: EngineCoreRequest,
                            parent_req: Optional[ParentRequest], index: int,
-                           queue: asyncio.Queue[RequestOutput]):
+                           queue: RequestOutputCollector):

         # Add the request to OutputProcessor (this process).
         self.output_processor.add_request(request, parent_req, index, queue)

@@ -272,15 +274,7 @@ class AsyncLLM(EngineClient):
             while not finished:
                 # Note: drain queue without await if possible (avoids
                 # task switching under load which helps performance).
-                out = q.get_nowait() if not q.empty() else await q.get()
-
-                # Coalesce any additional queued outputs
-                while not q.empty():
-                    next_out = q.get_nowait()
-                    if sampling_params.output_kind == RequestOutputKind.DELTA:
-                        out.add(next_out)
-                    else:
-                        out = next_out
+                out = q.get_nowait() or await q.get()

                 # Note: both OutputProcessor and EngineCore handle their
                 # own request cleanup based on finished.

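The single replacement line above relies on two properties of the collector introduced later in this diff: get_nowait() returns either an already-coalesced output or None, and a returned output object is truthy, so the `or` only falls through to the awaiting get() when nothing is buffered. Below is a minimal standalone sketch of that short-circuit; it is not vLLM code, and drain plus its lambda arguments are stand-ins for illustration only.

import asyncio
from typing import Awaitable, Callable, Optional


async def drain(get_nowait: Callable[[], Optional[str]],
                get: Callable[[], Awaitable[str]]) -> str:
    # Same shape as the loop body above: only an empty collector suspends.
    return get_nowait() or await get()


async def main() -> None:
    # Something is already buffered: the await is skipped entirely, so the
    # consumer never yields to the event loop here.
    print(await drain(lambda: "buffered", lambda: asyncio.sleep(0, "awaited")))
    # Nothing buffered (get_nowait() returns None): fall through to the await.
    print(await drain(lambda: None, lambda: asyncio.sleep(0, "awaited")))


asyncio.run(main())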
@@ -17,6 +17,46 @@ from vllm.v1.metrics.stats import (IterationStats, LoRARequestStates,
                                     RequestStateStats)


+class RequestOutputCollector:
+    """
+    Collects streamed RequestOutputs per individual request,
+    for hand-off to the consuming asyncio generate task.
+
+    When streaming deltas, RequestOutputs are merged if the
+    producer gets ahead of the consumer.
+    """
+
+    def __init__(self, output_kind: RequestOutputKind):
+        self.aggregate = output_kind == RequestOutputKind.DELTA
+        self.output: Optional[RequestOutput] = None
+        self.ready = asyncio.Event()
+
+    def put(self, output: RequestOutput) -> None:
+        if self.output is None:
+            self.output = output
+            self.ready.set()
+        elif self.aggregate:
+            # Coalesce the outputs in delta case.
+            self.output.add(output)
+        else:
+            # Just replace latest in non-delta case.
+            self.output = output
+
+    async def get(self) -> RequestOutput:
+        while (output := self.output) is None:
+            await self.ready.wait()
+        self.output = None
+        self.ready.clear()
+        return output
+
+    def get_nowait(self) -> Optional[RequestOutput]:
+        output = self.output
+        if output is not None:
+            self.output = None
+            self.ready.clear()
+        return output
+
+
 @dataclass
 class OutputProcessorOutput:

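As the docstring and comments above indicate, put() merges a new output into the buffered one via RequestOutput.add() when streaming deltas (RequestOutputKind.DELTA) and simply replaces it otherwise, so a consumer that falls behind still sees every token in delta mode but only the latest snapshot in the other modes. The following standalone sketch illustrates that difference in isolation; it is not vLLM code: Chunk and Collector are simplified stand-ins that mirror the class above with a generic payload so the example runs without constructing real RequestOutputs.

import asyncio
from dataclasses import dataclass
from typing import Optional


@dataclass
class Chunk:
    """Stand-in for RequestOutput: just accumulates text."""
    text: str

    def add(self, other: "Chunk") -> None:
        self.text += other.text


class Collector:
    """Mirror of the RequestOutputCollector above, with a generic payload."""

    def __init__(self, aggregate: bool):
        self.aggregate = aggregate  # True corresponds to RequestOutputKind.DELTA.
        self.output: Optional[Chunk] = None
        self.ready = asyncio.Event()

    def put(self, output: Chunk) -> None:
        if self.output is None:
            self.output = output
            self.ready.set()
        elif self.aggregate:
            self.output.add(output)  # Merge deltas while the consumer lags.
        else:
            self.output = output     # Keep only the latest full output.

    async def get(self) -> Chunk:
        while (output := self.output) is None:
            await self.ready.wait()
        self.output = None
        self.ready.clear()
        return output


async def main() -> None:
    delta = Collector(aggregate=True)
    delta.put(Chunk("a"))
    delta.put(Chunk("b"))    # Producer got ahead of the consumer: merged, not dropped.
    assert (await delta.get()).text == "ab"

    latest = Collector(aggregate=False)
    latest.put(Chunk("a"))
    latest.put(Chunk("ab"))  # Cumulative/final kinds: replace with the latest snapshot.
    assert (await latest.get()).text == "ab"


asyncio.run(main())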
@@ -39,7 +79,7 @@ class RequestState:
         detokenizer: IncrementalDetokenizer,
         max_tokens_param: Optional[int],
         arrival_time: float,
-        queue: Optional[asyncio.Queue[RequestOutput]],
+        queue: Optional[RequestOutputCollector],
         log_stats: bool,
     ):
         self.request_id = request_id

@@ -66,7 +106,7 @@
         request: EngineCoreRequest,
         parent_req: Optional[ParentRequest],
         request_index: int,
-        queue: Optional[asyncio.Queue[RequestOutput]],
+        queue: Optional[RequestOutputCollector],
         log_stats: bool,
     ) -> "RequestState":
         if not request.sampling_params.detokenize:

@@ -217,7 +257,7 @@ class OutputProcessor:
         request: EngineCoreRequest,
         parent_req: Optional[ParentRequest] = None,
         request_index: int = 0,
-        queue: Optional[asyncio.Queue[RequestOutput]] = None,
+        queue: Optional[RequestOutputCollector] = None,
     ) -> None:
         request_id = request.request_id
         if request_id in self.request_states:

@@ -300,7 +340,7 @@ class OutputProcessor:
                     new_token_ids, finish_reason, stop_reason):
                 if req_state.queue is not None:
                     # AsyncLLM: put into queue for handling by generate().
-                    req_state.queue.put_nowait(request_output)
+                    req_state.queue.put(request_output)
                 else:
                     # LLMEngine: return list of RequestOutputs.
                     request_outputs.append(request_output)