mirror of
https://github.com/vllm-project/vllm.git
synced 2025-10-20 14:53:52 +08:00
[Chore] Separate out vllm.utils.async_utils
(#26913)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
@ -12,7 +12,7 @@ from vllm.entrypoints.openai.api_server import (
|
||||
from vllm.inputs import TextPrompt
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.utils import merge_async_iterators
|
||||
from vllm.utils.async_utils import merge_async_iterators
|
||||
|
||||
MODEL_PATH = "zai-org/chatglm3-6b"
|
||||
LORA_RANK = 64
|
||||
|
42
tests/utils_/test_async_utils.py
Normal file
42
tests/utils_/test_async_utils.py
Normal file
@ -0,0 +1,42 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import asyncio
|
||||
from collections.abc import AsyncIterator
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm.utils.async_utils import merge_async_iterators
|
||||
|
||||
|
||||
async def _mock_async_iterator(idx: int):
|
||||
try:
|
||||
while True:
|
||||
yield f"item from iterator {idx}"
|
||||
await asyncio.sleep(0.1)
|
||||
except asyncio.CancelledError:
|
||||
print(f"iterator {idx} cancelled")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_merge_async_iterators():
|
||||
iterators = [_mock_async_iterator(i) for i in range(3)]
|
||||
merged_iterator = merge_async_iterators(*iterators)
|
||||
|
||||
async def stream_output(generator: AsyncIterator[tuple[int, str]]):
|
||||
async for idx, output in generator:
|
||||
print(f"idx: {idx}, output: {output}")
|
||||
|
||||
task = asyncio.create_task(stream_output(merged_iterator))
|
||||
await asyncio.sleep(0.5)
|
||||
task.cancel()
|
||||
with pytest.raises(asyncio.CancelledError):
|
||||
await task
|
||||
|
||||
for iterator in iterators:
|
||||
try:
|
||||
await asyncio.wait_for(anext(iterator), 1)
|
||||
except StopAsyncIteration:
|
||||
# All iterators should be cancelled and print this message.
|
||||
print("Iterator was cancelled normally")
|
||||
except (Exception, asyncio.CancelledError) as e:
|
||||
raise AssertionError() from e
|
@ -2,14 +2,12 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
# ruff: noqa
|
||||
|
||||
import asyncio
|
||||
import hashlib
|
||||
import json
|
||||
import os
|
||||
import pickle
|
||||
import socket
|
||||
import tempfile
|
||||
from collections.abc import AsyncIterator
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch
|
||||
|
||||
@ -37,7 +35,6 @@ from vllm.utils import (
|
||||
make_zmq_path,
|
||||
make_zmq_socket,
|
||||
memory_profiling,
|
||||
merge_async_iterators,
|
||||
sha256,
|
||||
split_host_port,
|
||||
split_zmq_path,
|
||||
@ -48,39 +45,6 @@ from vllm.utils import (
|
||||
from ..utils import create_new_process_for_each_test
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_merge_async_iterators():
|
||||
async def mock_async_iterator(idx: int):
|
||||
try:
|
||||
while True:
|
||||
yield f"item from iterator {idx}"
|
||||
await asyncio.sleep(0.1)
|
||||
except asyncio.CancelledError:
|
||||
print(f"iterator {idx} cancelled")
|
||||
|
||||
iterators = [mock_async_iterator(i) for i in range(3)]
|
||||
merged_iterator = merge_async_iterators(*iterators)
|
||||
|
||||
async def stream_output(generator: AsyncIterator[tuple[int, str]]):
|
||||
async for idx, output in generator:
|
||||
print(f"idx: {idx}, output: {output}")
|
||||
|
||||
task = asyncio.create_task(stream_output(merged_iterator))
|
||||
await asyncio.sleep(0.5)
|
||||
task.cancel()
|
||||
with pytest.raises(asyncio.CancelledError):
|
||||
await task
|
||||
|
||||
for iterator in iterators:
|
||||
try:
|
||||
await asyncio.wait_for(anext(iterator), 1)
|
||||
except StopAsyncIteration:
|
||||
# All iterators should be cancelled and print this message.
|
||||
print("Iterator was cancelled normally")
|
||||
except (Exception, asyncio.CancelledError) as e:
|
||||
raise AssertionError() from e
|
||||
|
||||
|
||||
def test_get_open_port(monkeypatch: pytest.MonkeyPatch):
|
||||
with monkeypatch.context() as m:
|
||||
m.setenv("VLLM_PORT", "5678")
|
||||
|
@ -34,7 +34,7 @@ from vllm.inputs import TextPrompt, TokensPrompt
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.outputs import RequestOutput
|
||||
from vllm.sampling_params import BeamSearchParams
|
||||
from vllm.utils import merge_async_iterators
|
||||
from vllm.utils.async_utils import merge_async_iterators
|
||||
|
||||
|
||||
def run_vllm(
|
||||
|
@ -34,7 +34,8 @@ from vllm.logprobs import Logprob
|
||||
from vllm.outputs import RequestOutput
|
||||
from vllm.sampling_params import BeamSearchParams, SamplingParams
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||
from vllm.utils import as_list, merge_async_iterators
|
||||
from vllm.utils import as_list
|
||||
from vllm.utils.async_utils import merge_async_iterators
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
@ -40,6 +40,7 @@ from vllm.outputs import (
|
||||
)
|
||||
from vllm.pooling_params import PoolingParams
|
||||
from vllm.utils import chunk_list
|
||||
from vllm.utils.async_utils import merge_async_iterators
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@ -387,8 +388,6 @@ class EmbeddingMixin(OpenAIServing):
|
||||
)
|
||||
generators.append(generator)
|
||||
|
||||
from vllm.utils import merge_async_iterators
|
||||
|
||||
ctx.result_generator = merge_async_iterators(*generators)
|
||||
|
||||
return None
|
||||
|
@ -90,14 +90,13 @@ from vllm.tracing import (
|
||||
log_tracing_disabled_warning,
|
||||
)
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
|
||||
from vllm.utils import (
|
||||
from vllm.utils import is_list_of, random_uuid
|
||||
from vllm.utils.async_utils import (
|
||||
AsyncMicrobatchTokenizer,
|
||||
collect_from_async_generator,
|
||||
is_list_of,
|
||||
make_async,
|
||||
merge_async_iterators,
|
||||
random_uuid,
|
||||
)
|
||||
from vllm.utils.func import make_async
|
||||
from vllm.v1.engine import EngineCoreRequest
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
@ -36,7 +36,7 @@ from vllm.entrypoints.utils import _validate_truncation_size
|
||||
from vllm.logger import init_logger
|
||||
from vllm.outputs import PoolingOutput, PoolingRequestOutput
|
||||
from vllm.tasks import SupportedTask
|
||||
from vllm.utils import merge_async_iterators
|
||||
from vllm.utils.async_utils import merge_async_iterators
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
@ -37,8 +37,7 @@ from vllm.logger import init_logger
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.outputs import PoolingRequestOutput, ScoringRequestOutput
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
|
||||
from vllm.utils import merge_async_iterators
|
||||
from vllm.utils.func import make_async
|
||||
from vllm.utils.async_utils import make_async, merge_async_iterators
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
@ -17,7 +17,7 @@ from vllm.inputs.data import TextPrompt as EngineTextPrompt
|
||||
from vllm.inputs.data import TokensPrompt as EngineTokensPrompt
|
||||
from vllm.inputs.parse import get_prompt_components, parse_raw_prompts
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||
from vllm.utils import AsyncMicrobatchTokenizer
|
||||
from vllm.utils.async_utils import AsyncMicrobatchTokenizer
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
|
@ -17,7 +17,7 @@ from vllm.logger import init_logger
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.sequence import ExecuteModelRequest
|
||||
from vllm.tasks import SupportedTask
|
||||
from vllm.utils.func import make_async
|
||||
from vllm.utils.async_utils import make_async
|
||||
from vllm.v1.outputs import SamplerOutput
|
||||
from vllm.v1.worker.worker_base import WorkerBase
|
||||
|
||||
|
@ -20,12 +20,11 @@ from vllm.platforms import current_platform
|
||||
from vllm.ray.ray_env import get_env_vars_to_copy
|
||||
from vllm.sequence import ExecuteModelRequest
|
||||
from vllm.utils import (
|
||||
_run_task_with_lock,
|
||||
get_distributed_init_method,
|
||||
get_ip,
|
||||
get_open_port,
|
||||
)
|
||||
from vllm.utils.func import make_async
|
||||
from vllm.utils.async_utils import make_async
|
||||
from vllm.v1.outputs import SamplerOutput
|
||||
|
||||
if ray is not None:
|
||||
@ -748,3 +747,9 @@ class RayDistributedExecutor(DistributedExecutorBase):
|
||||
# Assume that the Ray workers are healthy.
|
||||
# TODO: check the health of the Ray workers
|
||||
return
|
||||
|
||||
|
||||
async def _run_task_with_lock(task: Callable, lock: asyncio.Lock, *args, **kwargs):
|
||||
"""Utility function to run async task in a lock"""
|
||||
async with lock:
|
||||
return await task(*args, **kwargs)
|
||||
|
@ -1,7 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import asyncio
|
||||
import contextlib
|
||||
import datetime
|
||||
import enum
|
||||
@ -38,10 +37,8 @@ from argparse import (
|
||||
RawDescriptionHelpFormatter,
|
||||
_ArgumentGroup,
|
||||
)
|
||||
from asyncio import FIRST_COMPLETED, AbstractEventLoop, Task
|
||||
from collections import UserDict, defaultdict
|
||||
from collections.abc import (
|
||||
AsyncGenerator,
|
||||
Callable,
|
||||
Collection,
|
||||
Generator,
|
||||
@ -51,7 +48,6 @@ from collections.abc import (
|
||||
Mapping,
|
||||
Sequence,
|
||||
)
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from concurrent.futures.process import ProcessPoolExecutor
|
||||
from dataclasses import dataclass, field
|
||||
from functools import cache, lru_cache, partial, wraps
|
||||
@ -82,7 +78,6 @@ import zmq.asyncio
|
||||
from packaging import version
|
||||
from packaging.version import Version
|
||||
from torch.library import Library
|
||||
from transformers.tokenization_utils_base import BatchEncoding
|
||||
from typing_extensions import Never, TypeIs, assert_never
|
||||
|
||||
import vllm.envs as envs
|
||||
@ -223,278 +218,12 @@ def random_uuid() -> str:
|
||||
return str(uuid.uuid4().hex)
|
||||
|
||||
|
||||
class AsyncMicrobatchTokenizer:
|
||||
"""Asynchronous tokenizer with micro-batching.
|
||||
|
||||
Pulls pending encode/decode requests from a queue and batches them
|
||||
up to reduce overhead. A single-thread ThreadPoolExecutor is used
|
||||
so the event loop stays responsive.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
tokenizer,
|
||||
max_batch_size: int = 32,
|
||||
batch_wait_timeout_s: float = 0.002,
|
||||
) -> None:
|
||||
self.tokenizer = tokenizer
|
||||
self.max_batch_size = max_batch_size
|
||||
self.batch_wait_timeout_s = batch_wait_timeout_s
|
||||
|
||||
self._loop = asyncio.get_running_loop()
|
||||
self._queues: dict[
|
||||
tuple,
|
||||
asyncio.Queue[
|
||||
tuple[str, dict, asyncio.Future] | tuple[list[int], asyncio.Future]
|
||||
],
|
||||
] = {}
|
||||
self._batcher_tasks: list[asyncio.Task] = []
|
||||
|
||||
# Single-thread executor for blocking tokenizer calls.
|
||||
self._executor = ThreadPoolExecutor(max_workers=1)
|
||||
|
||||
# === Public async API ===
|
||||
async def __call__(self, prompt, **kwargs):
|
||||
result_future: asyncio.Future = self._loop.create_future()
|
||||
key = self._queue_key("encode", kwargs)
|
||||
queue = self._get_queue(self._loop, key)
|
||||
await queue.put((prompt, kwargs, result_future))
|
||||
return await result_future
|
||||
|
||||
async def decode(self, token_ids, **kwargs):
|
||||
result_future: asyncio.Future = self._loop.create_future()
|
||||
key = self._queue_key("decode", kwargs)
|
||||
queue = self._get_queue(self._loop, key)
|
||||
await queue.put((token_ids, result_future))
|
||||
return await result_future
|
||||
|
||||
# === Internal helpers ===
|
||||
def _get_queue(
|
||||
self, loop: asyncio.AbstractEventLoop, key: tuple
|
||||
) -> asyncio.Queue[
|
||||
tuple[str, dict, asyncio.Future] | tuple[list[int], asyncio.Future]
|
||||
]:
|
||||
"""Get the request queue for the given operation key, creating a new
|
||||
queue and batcher task if needed."""
|
||||
queue = self._queues.get(key)
|
||||
if queue is None:
|
||||
self._queues[key] = queue = asyncio.Queue()
|
||||
if key[0] == "encode":
|
||||
can_batch = key[1] != "other"
|
||||
coro = self._batch_encode_loop(queue, can_batch)
|
||||
else:
|
||||
assert key[0] == "decode", f"Unknown operation type: {key[0]}."
|
||||
coro = self._batch_decode_loop(queue)
|
||||
self._batcher_tasks.append(loop.create_task(coro))
|
||||
return queue
|
||||
|
||||
async def _batch_encode_loop(self, queue: asyncio.Queue, can_batch: bool):
|
||||
"""Batch incoming encode requests for efficiency."""
|
||||
while True:
|
||||
prompt, kwargs, result_future = await queue.get()
|
||||
prompts = [prompt]
|
||||
kwargs_list = [kwargs]
|
||||
result_futures = [result_future]
|
||||
deadline = self._loop.time() + self.batch_wait_timeout_s
|
||||
|
||||
while len(prompts) < self.max_batch_size:
|
||||
timeout = deadline - self._loop.time()
|
||||
if timeout <= 0:
|
||||
break
|
||||
try:
|
||||
prompt, kwargs, result_future = await asyncio.wait_for(
|
||||
queue.get(), timeout
|
||||
)
|
||||
prompts.append(prompt)
|
||||
result_futures.append(result_future)
|
||||
if not can_batch:
|
||||
kwargs_list.append(kwargs)
|
||||
except asyncio.TimeoutError:
|
||||
break
|
||||
|
||||
try:
|
||||
# If every request uses identical kwargs we can run a single
|
||||
# batched tokenizer call for a big speed-up.
|
||||
if can_batch and len(prompts) > 1:
|
||||
batch_encode_fn = partial(self.tokenizer, prompts, **kwargs)
|
||||
results = await self._loop.run_in_executor(
|
||||
self._executor, batch_encode_fn
|
||||
)
|
||||
|
||||
for i, fut in enumerate(result_futures):
|
||||
if not fut.done():
|
||||
data = {k: v[i] for k, v in results.items()}
|
||||
fut.set_result(BatchEncoding(data))
|
||||
else:
|
||||
encode_fn = lambda prompts=prompts, kwargs=kwargs_list: [
|
||||
self.tokenizer(p, **kw) for p, kw in zip(prompts, kwargs)
|
||||
]
|
||||
results = await self._loop.run_in_executor(
|
||||
self._executor, encode_fn
|
||||
)
|
||||
|
||||
for fut, res in zip(result_futures, results):
|
||||
if not fut.done():
|
||||
fut.set_result(res)
|
||||
except Exception as e:
|
||||
for fut in result_futures:
|
||||
if not fut.done():
|
||||
fut.set_exception(e)
|
||||
|
||||
async def _batch_decode_loop(self, queue: asyncio.Queue):
|
||||
"""Batch incoming decode requests for efficiency."""
|
||||
while True:
|
||||
token_ids, result_future = await queue.get()
|
||||
token_ids_list = [token_ids]
|
||||
result_futures = [result_future]
|
||||
deadline = self._loop.time() + self.batch_wait_timeout_s
|
||||
|
||||
while len(token_ids_list) < self.max_batch_size:
|
||||
timeout = deadline - self._loop.time()
|
||||
if timeout <= 0:
|
||||
break
|
||||
try:
|
||||
token_ids, result_future = await asyncio.wait_for(
|
||||
queue.get(), timeout
|
||||
)
|
||||
token_ids_list.append(token_ids)
|
||||
result_futures.append(result_future)
|
||||
except asyncio.TimeoutError:
|
||||
break
|
||||
|
||||
try:
|
||||
# Perform a single batched decode call for all requests
|
||||
results = await self._loop.run_in_executor(
|
||||
self._executor, self.tokenizer.batch_decode, token_ids_list
|
||||
)
|
||||
for fut, res in zip(result_futures, results):
|
||||
if not fut.done():
|
||||
fut.set_result(res)
|
||||
except Exception as e:
|
||||
for fut in result_futures:
|
||||
if not fut.done():
|
||||
fut.set_exception(e)
|
||||
|
||||
def _queue_key(self, op: str, kwargs: dict) -> tuple:
|
||||
"""
|
||||
Return a normalized key describing operation + kwargs.
|
||||
|
||||
- `add_special_tokens`: {True/False}
|
||||
- `truncation`: {True/False}
|
||||
- If `truncation` is False (`max_length` is None),
|
||||
returns a key for a can_batch queue.
|
||||
- If `truncation` is True and `max_length` is None or equals
|
||||
`tokenizer.model_max_length`, returns a key for a can_batch queue.
|
||||
- Otherwise, returns a key for a cannot_batch queue.
|
||||
|
||||
Examples:
|
||||
- Decode: ("decode",)
|
||||
- Encode typical:
|
||||
("encode", add_special_tokens, bool_truncation, max_length_label)
|
||||
- Fallback: ("encode", "other")
|
||||
"""
|
||||
|
||||
if op == "decode":
|
||||
return ("decode",)
|
||||
|
||||
add_special_tokens = kwargs.get("add_special_tokens", True)
|
||||
truncation = kwargs.get("truncation", False)
|
||||
max_length = kwargs.get("max_length")
|
||||
|
||||
if not truncation:
|
||||
return "encode", add_special_tokens, False, None
|
||||
|
||||
model_max = getattr(self.tokenizer, "model_max_length", None)
|
||||
if max_length is None or (model_max is not None and max_length == model_max):
|
||||
return "encode", add_special_tokens, True, "model_max"
|
||||
|
||||
return "encode", "other"
|
||||
|
||||
def __del__(self):
|
||||
if (
|
||||
(tasks := getattr(self, "_batcher_tasks", None))
|
||||
and (loop := getattr(self, "_loop", None))
|
||||
and not loop.is_closed()
|
||||
):
|
||||
|
||||
def cancel_tasks():
|
||||
for task in tasks:
|
||||
task.cancel()
|
||||
|
||||
loop.call_soon_threadsafe(cancel_tasks)
|
||||
|
||||
|
||||
def cancel_task_threadsafe(task: Task):
|
||||
if task and not task.done():
|
||||
run_in_loop(task.get_loop(), task.cancel)
|
||||
|
||||
|
||||
def close_sockets(sockets: Sequence[zmq.Socket | zmq.asyncio.Socket]):
|
||||
for sock in sockets:
|
||||
if sock is not None:
|
||||
sock.close(linger=0)
|
||||
|
||||
|
||||
def run_in_loop(loop: AbstractEventLoop, function: Callable, *args):
|
||||
if in_loop(loop):
|
||||
function(*args)
|
||||
elif not loop.is_closed():
|
||||
loop.call_soon_threadsafe(function, *args)
|
||||
|
||||
|
||||
def in_loop(event_loop: AbstractEventLoop) -> bool:
|
||||
try:
|
||||
return asyncio.get_running_loop() == event_loop
|
||||
except RuntimeError:
|
||||
return False
|
||||
|
||||
|
||||
async def merge_async_iterators(
|
||||
*iterators: AsyncGenerator[T, None],
|
||||
) -> AsyncGenerator[tuple[int, T], None]:
|
||||
"""Merge multiple asynchronous iterators into a single iterator.
|
||||
|
||||
This method handle the case where some iterators finish before others.
|
||||
When it yields, it yields a tuple (i, item) where i is the index of the
|
||||
iterator that yields the item.
|
||||
"""
|
||||
if len(iterators) == 1:
|
||||
# Fast-path single iterator case.
|
||||
async for item in iterators[0]:
|
||||
yield 0, item
|
||||
return
|
||||
|
||||
loop = asyncio.get_running_loop()
|
||||
|
||||
awaits = {loop.create_task(anext(it)): (i, it) for i, it in enumerate(iterators)}
|
||||
try:
|
||||
while awaits:
|
||||
done, _ = await asyncio.wait(awaits.keys(), return_when=FIRST_COMPLETED)
|
||||
for d in done:
|
||||
pair = awaits.pop(d)
|
||||
try:
|
||||
item = await d
|
||||
i, it = pair
|
||||
awaits[loop.create_task(anext(it))] = pair
|
||||
yield i, item
|
||||
except StopAsyncIteration:
|
||||
pass
|
||||
finally:
|
||||
# Cancel any remaining iterators
|
||||
for f, (_, it) in awaits.items():
|
||||
with contextlib.suppress(BaseException):
|
||||
f.cancel()
|
||||
await it.aclose()
|
||||
|
||||
|
||||
async def collect_from_async_generator(iterator: AsyncGenerator[T, None]) -> list[T]:
|
||||
"""Collect all items from an async generator into a list."""
|
||||
items = []
|
||||
async for item in iterator:
|
||||
items.append(item)
|
||||
return items
|
||||
|
||||
|
||||
def get_ip() -> str:
|
||||
host_ip = envs.VLLM_HOST_IP
|
||||
if "HOST_IP" in os.environ and "VLLM_HOST_IP" not in os.environ:
|
||||
@ -1803,12 +1532,6 @@ class FlexibleArgumentParser(ArgumentParser):
|
||||
return processed_args
|
||||
|
||||
|
||||
async def _run_task_with_lock(task: Callable, lock: asyncio.Lock, *args, **kwargs):
|
||||
"""Utility function to run async task in a lock"""
|
||||
async with lock:
|
||||
return await task(*args, **kwargs)
|
||||
|
||||
|
||||
# Using dynamo with vLLM doesn't really work well with PyTorch versions < 2.4.0.
|
||||
# In particular, the FakeScalarType is not supported for earlier versions of
|
||||
# PyTorch which breaks dynamo for any ops registered using ScalarType.
|
||||
|
299
vllm/utils/async_utils.py
Normal file
299
vllm/utils/async_utils.py
Normal file
@ -0,0 +1,299 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Contains helpers related to asynchronous code."""
|
||||
|
||||
import asyncio
|
||||
import contextlib
|
||||
from asyncio import FIRST_COMPLETED, AbstractEventLoop, Future, Task
|
||||
from collections.abc import AsyncGenerator, Awaitable, Callable
|
||||
from concurrent.futures import Executor, ThreadPoolExecutor
|
||||
from functools import partial
|
||||
from typing import TypeVar
|
||||
|
||||
from transformers.tokenization_utils_base import BatchEncoding
|
||||
from typing_extensions import ParamSpec
|
||||
|
||||
P = ParamSpec("P")
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class AsyncMicrobatchTokenizer:
|
||||
"""Asynchronous tokenizer with micro-batching.
|
||||
|
||||
Pulls pending encode/decode requests from a queue and batches them
|
||||
up to reduce overhead. A single-thread ThreadPoolExecutor is used
|
||||
so the event loop stays responsive.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
tokenizer,
|
||||
max_batch_size: int = 32,
|
||||
batch_wait_timeout_s: float = 0.002,
|
||||
) -> None:
|
||||
self.tokenizer = tokenizer
|
||||
self.max_batch_size = max_batch_size
|
||||
self.batch_wait_timeout_s = batch_wait_timeout_s
|
||||
|
||||
self._loop = asyncio.get_running_loop()
|
||||
self._queues: dict[
|
||||
tuple,
|
||||
asyncio.Queue[tuple[str, dict, Future] | tuple[list[int], Future]],
|
||||
] = {}
|
||||
self._batcher_tasks: list[Task] = []
|
||||
|
||||
# Single-thread executor for blocking tokenizer calls.
|
||||
self._executor = ThreadPoolExecutor(max_workers=1)
|
||||
|
||||
# === Public async API ===
|
||||
async def __call__(self, prompt, **kwargs):
|
||||
result_future: Future = self._loop.create_future()
|
||||
key = self._queue_key("encode", kwargs)
|
||||
queue = self._get_queue(self._loop, key)
|
||||
await queue.put((prompt, kwargs, result_future))
|
||||
return await result_future
|
||||
|
||||
async def decode(self, token_ids, **kwargs):
|
||||
result_future: Future = self._loop.create_future()
|
||||
key = self._queue_key("decode", kwargs)
|
||||
queue = self._get_queue(self._loop, key)
|
||||
await queue.put((token_ids, result_future))
|
||||
return await result_future
|
||||
|
||||
# === Internal helpers ===
|
||||
def _get_queue(
|
||||
self, loop: asyncio.AbstractEventLoop, key: tuple
|
||||
) -> asyncio.Queue[tuple[str, dict, Future] | tuple[list[int], Future]]:
|
||||
"""Get the request queue for the given operation key, creating a new
|
||||
queue and batcher task if needed."""
|
||||
queue = self._queues.get(key)
|
||||
if queue is None:
|
||||
self._queues[key] = queue = asyncio.Queue()
|
||||
if key[0] == "encode":
|
||||
can_batch = key[1] != "other"
|
||||
coro = self._batch_encode_loop(queue, can_batch)
|
||||
else:
|
||||
assert key[0] == "decode", f"Unknown operation type: {key[0]}."
|
||||
coro = self._batch_decode_loop(queue)
|
||||
self._batcher_tasks.append(loop.create_task(coro))
|
||||
return queue
|
||||
|
||||
async def _batch_encode_loop(self, queue: asyncio.Queue, can_batch: bool):
|
||||
"""Batch incoming encode requests for efficiency."""
|
||||
while True:
|
||||
prompt, kwargs, result_future = await queue.get()
|
||||
prompts = [prompt]
|
||||
kwargs_list = [kwargs]
|
||||
result_futures = [result_future]
|
||||
deadline = self._loop.time() + self.batch_wait_timeout_s
|
||||
|
||||
while len(prompts) < self.max_batch_size:
|
||||
timeout = deadline - self._loop.time()
|
||||
if timeout <= 0:
|
||||
break
|
||||
try:
|
||||
prompt, kwargs, result_future = await asyncio.wait_for(
|
||||
queue.get(), timeout
|
||||
)
|
||||
prompts.append(prompt)
|
||||
result_futures.append(result_future)
|
||||
if not can_batch:
|
||||
kwargs_list.append(kwargs)
|
||||
except asyncio.TimeoutError:
|
||||
break
|
||||
|
||||
try:
|
||||
# If every request uses identical kwargs we can run a single
|
||||
# batched tokenizer call for a big speed-up.
|
||||
if can_batch and len(prompts) > 1:
|
||||
batch_encode_fn = partial(self.tokenizer, prompts, **kwargs)
|
||||
results = await self._loop.run_in_executor(
|
||||
self._executor, batch_encode_fn
|
||||
)
|
||||
|
||||
for i, fut in enumerate(result_futures):
|
||||
if not fut.done():
|
||||
data = {k: v[i] for k, v in results.items()}
|
||||
fut.set_result(BatchEncoding(data))
|
||||
else:
|
||||
encode_fn = lambda prompts=prompts, kwargs=kwargs_list: [
|
||||
self.tokenizer(p, **kw) for p, kw in zip(prompts, kwargs)
|
||||
]
|
||||
results = await self._loop.run_in_executor(
|
||||
self._executor, encode_fn
|
||||
)
|
||||
|
||||
for fut, res in zip(result_futures, results):
|
||||
if not fut.done():
|
||||
fut.set_result(res)
|
||||
except Exception as e:
|
||||
for fut in result_futures:
|
||||
if not fut.done():
|
||||
fut.set_exception(e)
|
||||
|
||||
async def _batch_decode_loop(self, queue: asyncio.Queue):
|
||||
"""Batch incoming decode requests for efficiency."""
|
||||
while True:
|
||||
token_ids, result_future = await queue.get()
|
||||
token_ids_list = [token_ids]
|
||||
result_futures = [result_future]
|
||||
deadline = self._loop.time() + self.batch_wait_timeout_s
|
||||
|
||||
while len(token_ids_list) < self.max_batch_size:
|
||||
timeout = deadline - self._loop.time()
|
||||
if timeout <= 0:
|
||||
break
|
||||
try:
|
||||
token_ids, result_future = await asyncio.wait_for(
|
||||
queue.get(), timeout
|
||||
)
|
||||
token_ids_list.append(token_ids)
|
||||
result_futures.append(result_future)
|
||||
except asyncio.TimeoutError:
|
||||
break
|
||||
|
||||
try:
|
||||
# Perform a single batched decode call for all requests
|
||||
results = await self._loop.run_in_executor(
|
||||
self._executor, self.tokenizer.batch_decode, token_ids_list
|
||||
)
|
||||
for fut, res in zip(result_futures, results):
|
||||
if not fut.done():
|
||||
fut.set_result(res)
|
||||
except Exception as e:
|
||||
for fut in result_futures:
|
||||
if not fut.done():
|
||||
fut.set_exception(e)
|
||||
|
||||
def _queue_key(self, op: str, kwargs: dict) -> tuple:
|
||||
"""
|
||||
Return a normalized key describing operation + kwargs.
|
||||
|
||||
- `add_special_tokens`: {True/False}
|
||||
- `truncation`: {True/False}
|
||||
- If `truncation` is False (`max_length` is None),
|
||||
returns a key for a can_batch queue.
|
||||
- If `truncation` is True and `max_length` is None or equals
|
||||
`tokenizer.model_max_length`, returns a key for a can_batch queue.
|
||||
- Otherwise, returns a key for a cannot_batch queue.
|
||||
|
||||
Examples:
|
||||
- Decode: ("decode",)
|
||||
- Encode typical:
|
||||
("encode", add_special_tokens, bool_truncation, max_length_label)
|
||||
- Fallback: ("encode", "other")
|
||||
"""
|
||||
|
||||
if op == "decode":
|
||||
return ("decode",)
|
||||
|
||||
add_special_tokens = kwargs.get("add_special_tokens", True)
|
||||
truncation = kwargs.get("truncation", False)
|
||||
max_length = kwargs.get("max_length")
|
||||
|
||||
if not truncation:
|
||||
return "encode", add_special_tokens, False, None
|
||||
|
||||
model_max = getattr(self.tokenizer, "model_max_length", None)
|
||||
if max_length is None or (model_max is not None and max_length == model_max):
|
||||
return "encode", add_special_tokens, True, "model_max"
|
||||
|
||||
return "encode", "other"
|
||||
|
||||
def __del__(self):
|
||||
if (
|
||||
(tasks := getattr(self, "_batcher_tasks", None))
|
||||
and (loop := getattr(self, "_loop", None))
|
||||
and not loop.is_closed()
|
||||
):
|
||||
|
||||
def cancel_tasks():
|
||||
for task in tasks:
|
||||
task.cancel()
|
||||
|
||||
loop.call_soon_threadsafe(cancel_tasks)
|
||||
|
||||
|
||||
def cancel_task_threadsafe(task: Task):
|
||||
if task and not task.done():
|
||||
run_in_loop(task.get_loop(), task.cancel)
|
||||
|
||||
|
||||
def make_async(
|
||||
func: Callable[P, T],
|
||||
executor: Executor | None = None,
|
||||
) -> Callable[P, Awaitable[T]]:
|
||||
"""
|
||||
Take a blocking function, and run it on in an executor thread.
|
||||
|
||||
This function prevents the blocking function from blocking the
|
||||
asyncio event loop.
|
||||
The code in this function needs to be thread safe.
|
||||
"""
|
||||
|
||||
def _async_wrapper(*args: P.args, **kwargs: P.kwargs) -> Future[T]:
|
||||
loop = asyncio.get_event_loop()
|
||||
p_func = partial(func, *args, **kwargs)
|
||||
return loop.run_in_executor(executor=executor, func=p_func)
|
||||
|
||||
return _async_wrapper
|
||||
|
||||
|
||||
def run_in_loop(loop: AbstractEventLoop, function: Callable, *args):
|
||||
if in_loop(loop):
|
||||
function(*args)
|
||||
elif not loop.is_closed():
|
||||
loop.call_soon_threadsafe(function, *args)
|
||||
|
||||
|
||||
def in_loop(event_loop: AbstractEventLoop) -> bool:
|
||||
try:
|
||||
return asyncio.get_running_loop() == event_loop
|
||||
except RuntimeError:
|
||||
return False
|
||||
|
||||
|
||||
async def merge_async_iterators(
|
||||
*iterators: AsyncGenerator[T, None],
|
||||
) -> AsyncGenerator[tuple[int, T], None]:
|
||||
"""Merge multiple asynchronous iterators into a single iterator.
|
||||
|
||||
This method handle the case where some iterators finish before others.
|
||||
When it yields, it yields a tuple (i, item) where i is the index of the
|
||||
iterator that yields the item.
|
||||
"""
|
||||
if len(iterators) == 1:
|
||||
# Fast-path single iterator case.
|
||||
async for item in iterators[0]:
|
||||
yield 0, item
|
||||
return
|
||||
|
||||
loop = asyncio.get_running_loop()
|
||||
|
||||
awaits = {loop.create_task(anext(it)): (i, it) for i, it in enumerate(iterators)}
|
||||
try:
|
||||
while awaits:
|
||||
done, _ = await asyncio.wait(awaits.keys(), return_when=FIRST_COMPLETED)
|
||||
for d in done:
|
||||
pair = awaits.pop(d)
|
||||
try:
|
||||
item = await d
|
||||
i, it = pair
|
||||
awaits[loop.create_task(anext(it))] = pair
|
||||
yield i, item
|
||||
except StopAsyncIteration:
|
||||
pass
|
||||
finally:
|
||||
# Cancel any remaining iterators
|
||||
for f, (_, it) in awaits.items():
|
||||
with contextlib.suppress(BaseException):
|
||||
f.cancel()
|
||||
await it.aclose()
|
||||
|
||||
|
||||
async def collect_from_async_generator(iterator: AsyncGenerator[T, None]) -> list[T]:
|
||||
"""Collect all items from an async generator into a list."""
|
||||
items = []
|
||||
async for item in iterator:
|
||||
items.append(item)
|
||||
return items
|
@ -6,12 +6,10 @@ Contains helpers that are applied to functions.
|
||||
This is similar in concept to the `functools` module.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import concurrent.futures
|
||||
import inspect
|
||||
import threading
|
||||
import warnings
|
||||
from collections.abc import Awaitable, Callable, Mapping
|
||||
from collections.abc import Callable, Mapping
|
||||
from functools import lru_cache, partial, wraps
|
||||
from typing import Any, TypeVar
|
||||
|
||||
@ -32,26 +30,6 @@ def identity(value: T, **kwargs) -> T:
|
||||
return value
|
||||
|
||||
|
||||
def make_async(
|
||||
func: Callable[P, T],
|
||||
executor: concurrent.futures.Executor | None = None,
|
||||
) -> Callable[P, Awaitable[T]]:
|
||||
"""
|
||||
Take a blocking function, and run it on in an executor thread.
|
||||
|
||||
This function prevents the blocking function from blocking the
|
||||
asyncio event loop.
|
||||
The code in this function needs to be thread safe.
|
||||
"""
|
||||
|
||||
def _async_wrapper(*args: P.args, **kwargs: P.kwargs) -> asyncio.Future[T]:
|
||||
loop = asyncio.get_event_loop()
|
||||
p_func = partial(func, *args, **kwargs)
|
||||
return loop.run_in_executor(executor=executor, func=p_func)
|
||||
|
||||
return _async_wrapper
|
||||
|
||||
|
||||
def run_once(f: Callable[P, None]) -> Callable[P, None]:
|
||||
def wrapper(*args: P.args, **kwargs: P.kwargs) -> None:
|
||||
if wrapper.has_run: # type: ignore[attr-defined]
|
||||
|
@ -29,7 +29,8 @@ from vllm.tracing import init_tracer
|
||||
from vllm.transformers_utils.config import maybe_register_config_serialize_by_value
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer, init_tokenizer_from_configs
|
||||
from vllm.usage.usage_lib import UsageContext
|
||||
from vllm.utils import Device, as_list, cancel_task_threadsafe, cdiv
|
||||
from vllm.utils import Device, as_list, cdiv
|
||||
from vllm.utils.async_utils import cancel_task_threadsafe
|
||||
from vllm.utils.func import deprecate_kwargs
|
||||
from vllm.v1.engine import EngineCoreRequest
|
||||
from vllm.v1.engine.core_client import EngineCoreClient
|
||||
|
@ -27,9 +27,9 @@ from vllm.utils import (
|
||||
close_sockets,
|
||||
get_open_port,
|
||||
get_open_zmq_inproc_path,
|
||||
in_loop,
|
||||
make_zmq_socket,
|
||||
)
|
||||
from vllm.utils.async_utils import in_loop
|
||||
from vllm.v1.engine import (
|
||||
EngineCoreOutputs,
|
||||
EngineCoreRequest,
|
||||
|
Reference in New Issue
Block a user