[BugFix] Avoid premature async generator exit and raise all exception variations (#7698)

Author: Nick Hill
Date: 2024-08-21 11:45:55 -04:00
Committed by: GitHub
Parent: dd3fa0e430
Commit: c75363fbc0

2 changed files with 101 additions and 21 deletions
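
To make the commit title concrete before the diff: AsyncStream's output generator used to stop looping as soon as the stream was marked finished, even if completed outputs were still buffered in its queue, and its exception check only recognized Exception instances. A minimal, self-contained sketch of the premature-exit half (illustrative only, not code from this commit; drain_buggy, drain_fixed, and STOP are made-up names):

import asyncio

STOP = Exception()  # sentinel, in the spirit of STOP_ITERATION in the real module


def preloaded_queue() -> asyncio.Queue:
    q: asyncio.Queue = asyncio.Queue()
    for item in (1, 2, 3):
        q.put_nowait(item)
    q.put_nowait(STOP)  # producer is done; results are still buffered
    return q


async def drain_buggy(q: asyncio.Queue, finished: bool):
    # Pre-fix shape (simplified: the real code re-checks a _finished
    # attribute each iteration): the loop is guarded by the finished flag,
    # so anything still buffered in the queue is silently dropped.
    while not finished:
        item = await q.get()
        if item is STOP:
            return
        yield item


async def drain_fixed(q: asyncio.Queue):
    # Post-fix shape: loop unconditionally and exit only on the sentinel,
    # so every buffered output is delivered before the generator ends.
    while True:
        item = await q.get()
        if item is STOP:
            return
        yield item


async def main():
    print([x async for x in drain_buggy(preloaded_queue(), True)])  # []
    print([x async for x in drain_fixed(preloaded_queue())])        # [1, 2, 3]


asyncio.run(main())

The fix below applies the drain_fixed shape to AsyncStream's generator, so a consumer that falls behind still receives every buffered output.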

tests/async_engine/test_async_llm_engine.py

@@ -1,14 +1,19 @@
 import asyncio
 import os
+from asyncio import CancelledError
 from dataclasses import dataclass
+from typing import Optional

 import pytest
+import pytest_asyncio
 import torch

 from vllm import SamplingParams
 from vllm.config import ParallelConfig
 from vllm.engine.async_llm_engine import AsyncEngineArgs, AsyncLLMEngine
+from vllm.outputs import RequestOutput as RealRequestOutput

+from ..conftest import cleanup
 from ..utils import wait_for_gpu_memory_to_clear
@@ -118,15 +123,38 @@ async def test_new_requests_event():
     os.environ.pop("VLLM_ALLOW_ENGINE_USE_RAY")


-def test_asyncio_run():
+def start_engine():
     wait_for_gpu_memory_to_clear(
         devices=list(range(torch.cuda.device_count())),
         threshold_bytes=2 * 2**30,
         timeout_s=60,
     )

-    engine = AsyncLLMEngine.from_engine_args(
-        AsyncEngineArgs(model="facebook/opt-125m"))
+    return AsyncLLMEngine.from_engine_args(
+        AsyncEngineArgs(model="facebook/opt-125m", enforce_eager=True))
+
+
+@pytest_asyncio.fixture(scope="module")
+async def async_engine():
+    engine = await asyncio.get_event_loop().run_in_executor(executor=None,
+                                                            func=start_engine)
+    try:
+        yield engine
+    finally:
+        engine.shutdown_background_loop()
+        del engine
+        await asyncio.sleep(0.1)
+        cleanup()
+
+
+@pytest.fixture()
+def should_do_global_cleanup_after_test(request) -> bool:
+    # So we can share the async engine fixture between these tests
+    return False
+
+
+@pytest.mark.asyncio(scope="module")
+async def test_asyncio_run(async_engine):

     async def run(prompt: str):
         sampling_params = SamplingParams(
@@ -134,17 +162,64 @@
             max_tokens=32,
         )

-        async for output in engine.generate(prompt,
-                                            sampling_params,
-                                            request_id=prompt):
+        async for output in async_engine.generate(prompt,
+                                                  sampling_params,
+                                                  request_id=prompt):
             final_output = output
         return final_output

-    async def generate():
-        return await asyncio.gather(
-            run("test0"),
-            run("test1"),
-        )
-
-    results = asyncio.run(generate())
+    results = await asyncio.gather(
+        run("test0"),
+        run("test1"),
+    )
     assert len(results) == 2
+
+
+@pytest.mark.asyncio(scope="module")
+async def test_cancellation(async_engine):
+    sampling_params = SamplingParams(
+        temperature=0,
+        min_tokens=10,
+        max_tokens=10,
+    )
+
+    i = 0
+    with pytest.raises(CancelledError):
+        async for output in async_engine.generate("test2",
+                                                  sampling_params,
+                                                  request_id="test2"):
+            assert not output.finished
+            i += 1
+            if i == 5:
+                await async_engine.abort("test2")
+
+    assert i == 5
+
+
+@pytest.mark.asyncio(scope="module")
+async def test_delayed_generator(async_engine):
+    sampling_params = SamplingParams(
+        temperature=0,
+        min_tokens=10,
+        max_tokens=10,
+    )
+
+    stream = async_engine.generate("test3",
+                                   sampling_params,
+                                   request_id="test3")
+    i = 0
+    final_output: Optional[RealRequestOutput] = None
+    async for output in stream:
+        final_output = output
+        if i == 0:
+            # wait for generation to complete before consuming
+            # the remaining messages
+            await asyncio.sleep(1)
+        if i < 9:
+            assert not output.finished
+        i += 1
+
+    assert i == 10
+    assert final_output is not None
+    assert len(final_output.outputs[0].token_ids) == 10
+    assert final_output.finished
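
The tests above share a single engine through a module-scoped pytest-asyncio fixture, since building the engine per test would dominate runtime. A stripped-down sketch of that pattern, with a hypothetical shared_resource standing in for the engine:

import asyncio

import pytest
import pytest_asyncio


@pytest_asyncio.fixture(scope="module")
async def shared_resource():
    # Built once per module; heavyweight synchronous setup (like engine
    # startup above) can be pushed to a thread so the loop is not blocked.
    resource = await asyncio.get_event_loop().run_in_executor(
        None, lambda: {"ready": True})
    try:
        yield resource
    finally:
        resource.clear()  # teardown runs after the last test in the module


@pytest.mark.asyncio(scope="module")
async def test_uses_shared_resource(shared_resource):
    # scope="module" keeps every test in the module on one event loop,
    # which is what lets them share the same async fixture instance.
    assert shared_resource["ready"]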

vllm/engine/async_llm_engine.py

@@ -2,8 +2,8 @@ import asyncio
 import time
 from dataclasses import dataclass
 from functools import partial
-from typing import (AsyncGenerator, Callable, Dict, Iterable, List, Mapping,
-                    Optional, Set, Tuple, Type, Union)
+from typing import (Any, AsyncGenerator, Callable, Dict, Iterable, List,
+                    Mapping, Optional, Set, Tuple, Type, Union)

 import torch
 from typing_extensions import assert_never
@@ -85,9 +85,8 @@ class AsyncStream:

     def put(self, item: Union[RequestOutput, EmbeddingRequestOutput,
                               Exception]) -> None:
-        if self._finished:
-            return
-        self._queue.put_nowait(item)
+        if not self._finished:
+            self._queue.put_nowait(item)

     def finish(
         self,
@@ -96,7 +95,7 @@
         if not self._finished:
             self._finished = True
             self._queue.put_nowait(
-                exception if exception is not None else STOP_ITERATION)
+                exception if self._is_raisable(exception) else STOP_ITERATION)

     @property
     def finished(self) -> bool:
@@ -106,9 +105,9 @@
             self
     ) -> AsyncGenerator[Union[RequestOutput, EmbeddingRequestOutput], None]:
         try:
-            while not self._finished:
+            while True:
                 result = await self._queue.get()
-                if isinstance(result, Exception):
+                if self._is_raisable(result):
                     if result == STOP_ITERATION:
                         return
                     raise result
@@ -117,6 +116,12 @@
             self._cancel(self.request_id)
             raise asyncio.CancelledError from None

+    @staticmethod
+    def _is_raisable(value: Any):
+        return isinstance(value, BaseException) or \
+                (isinstance(value, type) and \
+                 issubclass(value, BaseException))
+

 class RequestTracker:
     """Synchronous abstraction for tracking requests."""