[BugFix] fix num_lookahead_slots missing in async executor (#4165)

Co-authored-by: Lei Wen <wenlei03@qiyi.com>
This commit is contained in:
leiwen83
2024-05-01 01:12:59 +08:00
committed by GitHub
parent 26f2fb5113
commit 4bb53e2dde
9 changed files with 163 additions and 19 deletions
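
For orientation before the diff: the synchronous engine already forwards scheduler_outputs.num_lookahead_slots into the executor, but the async wrappers dropped it, which surfaces once speculative decoding is driven through AsyncLLMEngine. A hypothetical reproduction sketch, not quoted from the linked issue; the engine arguments mirror the spec-decode test settings used later in this PR:

# Hypothetical reproduction sketch -- not part of this commit.
# Speculative decoding makes the scheduler request lookahead slots, so the
# async execute path must forward num_lookahead_slots to the workers.
import asyncio

from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.sampling_params import SamplingParams
from vllm.utils import random_uuid


async def main():
    engine = AsyncLLMEngine.from_engine_args(
        AsyncEngineArgs(
            model="JackFram/llama-68m",
            speculative_model="JackFram/llama-68m",
            num_speculative_tokens=5,
            use_v2_block_manager=True,
            enforce_eager=True,
        ))
    final = None
    # Before this fix, stepping the async engine here failed because
    # execute_model_async() was invoked without the num_lookahead_slots
    # value the scheduler produced.
    async for out in engine.generate("Hello, my name is",
                                     SamplingParams(max_tokens=8),
                                     random_uuid()):
        final = out
    print(final.outputs[0].text)


asyncio.run(main())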


@@ -1,10 +1,127 @@
from typing import List, Tuple
import asyncio
from typing import List, Optional, Tuple, Union
import pytest
import ray
from tests.conftest import cleanup
from vllm import LLM
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.lora.request import LoRARequest
from vllm.model_executor.utils import set_random_seed
from vllm.outputs import RequestOutput
from vllm.sampling_params import SamplingParams
from vllm.sequence import MultiModalData
from vllm.usage.usage_lib import UsageContext
from vllm.utils import Counter, random_uuid
class AsyncLLM:
    """AsyncLLM

    Note: The current LLM class in vllm does not support async mode. For test
    purposes we implement an async variant here; it may move into
    vllm/entrypoints/llm.py in the future.

    AsyncLLM below is borrowed directly from vllm/entrypoints/llm.py, with
    changes to make it work in async mode.
    """
    def __init__(
        self,
        model: str,
        tokenizer: Optional[str] = None,
        tokenizer_mode: str = "auto",
        skip_tokenizer_init: bool = False,
        trust_remote_code: bool = False,
        tensor_parallel_size: int = 1,
        dtype: str = "auto",
        quantization: Optional[str] = None,
        revision: Optional[str] = None,
        tokenizer_revision: Optional[str] = None,
        seed: int = 0,
        gpu_memory_utilization: float = 0.9,
        swap_space: int = 4,
        enforce_eager: bool = False,
        max_context_len_to_capture: int = 8192,
        disable_custom_all_reduce: bool = False,
        **kwargs,
    ) -> None:
        if "disable_log_stats" not in kwargs:
            kwargs["disable_log_stats"] = True
        self.engine_args = AsyncEngineArgs(
            model=model,
            tokenizer=tokenizer,
            tokenizer_mode=tokenizer_mode,
            skip_tokenizer_init=skip_tokenizer_init,
            trust_remote_code=trust_remote_code,
            tensor_parallel_size=tensor_parallel_size,
            dtype=dtype,
            quantization=quantization,
            revision=revision,
            tokenizer_revision=tokenizer_revision,
            seed=seed,
            gpu_memory_utilization=gpu_memory_utilization,
            swap_space=swap_space,
            enforce_eager=enforce_eager,
            max_context_len_to_capture=max_context_len_to_capture,
            engine_use_ray=True,
            disable_custom_all_reduce=disable_custom_all_reduce,
            **kwargs,
        )
        self.request_counter = Counter()
    def generate(
        self,
        prompts: Optional[Union[str, List[str]]] = None,
        sampling_params: Optional[Union[SamplingParams,
                                        List[SamplingParams]]] = None,
        prompt_token_ids: Optional[List[List[int]]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[LoRARequest] = None,
        multi_modal_data: Optional[MultiModalData] = None,
    ) -> List[RequestOutput]:

        llm_engine = AsyncLLMEngine.from_engine_args(
            self.engine_args, usage_context=UsageContext.LLM_CLASS)

        if prompts is None:
            raise ValueError("prompts must be provided.")
        if isinstance(prompts, str):
            # Convert a single prompt to a list.
            prompts = [prompts]

        if prompts is not None:
            num_requests = len(prompts)

        if sampling_params is None:
            # Use default sampling params.
            sampling_params = SamplingParams()

        elif isinstance(sampling_params,
                        list) and len(sampling_params) != num_requests:
            raise ValueError("The lengths of prompts and "
                             "sampling_params must be the same.")

        async def get_output(prompt, sampling_param) -> str:
            request_id = random_uuid()
            results_generator = llm_engine.generate(prompt, sampling_param,
                                                    request_id)
            final_output = None
            async for request_output in results_generator:
                final_output = request_output
            return final_output

        outputs = []
        try:
            for i in range(num_requests):
                prompt = prompts[i] if prompts is not None else None
                res = asyncio.run(get_output(prompt, sampling_params))
                outputs.append(res)
        finally:
            ray.shutdown()
        return outputs
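
For readers of the helper above, a minimal usage sketch; the model name and sampling values are illustrative, and it assumes the imports at the top of this conftest:

# Illustrative usage of the AsyncLLM helper above -- not part of the diff.
llm = AsyncLLM(model="JackFram/llama-68m", enforce_eager=True)
outputs = llm.generate(
    prompts=["Hello, my name is", "The president of the United States is"],
    sampling_params=SamplingParams(temperature=0.0, max_tokens=16))
for output in outputs:
    # Each entry is the final RequestOutput drained from the async generator.
    print(output.outputs[0].text)
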
@pytest.fixture
@@ -36,8 +153,12 @@ def create_llm_generator(baseline_or_test, request, common_llm_kwargs,
    def generator_inner():
        print(f'Creating {baseline_or_test=} LLM for {test_name=}. {kwargs=}')
        llm = LLM(**kwargs)
        use_async = False
        if "use_async" in kwargs:
            use_async = kwargs.pop("use_async")
        llm = AsyncLLM(**kwargs) if use_async else LLM(**kwargs)
        set_random_seed(seed)

        yield llm
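
The fixture pops the test-only "use_async" key before the remaining kwargs reach the engine constructor, since the engine-args classes would not recognize it. A small illustrative sketch of that opt-in path (the kwargs values are made up):

# Illustrative sketch of the opt-in path -- not part of the diff.
kwargs = {
    "model": "JackFram/llama-68m",
    "enforce_eager": True,
    "use_async": True,  # test-only flag, consumed by the fixture
}
use_async = kwargs.pop("use_async", False)  # popped so engine args never see it
llm = AsyncLLM(**kwargs) if use_async else LLM(**kwargs)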


@@ -42,10 +42,17 @@ def test_spec_decode_xfail_ray(test_llm_generator):
        temperature=temperature,
    )

    with pytest.raises(AssertionError,
                       match="Speculative decoding not yet supported for "):
        get_output_from_llm_generator(test_llm_generator, prompts,
                                      sampling_params)
    try:
        with pytest.raises(
                AssertionError,
                match="Speculative decoding not yet supported for "):
            get_output_from_llm_generator(test_llm_generator, prompts,
                                          sampling_params)
    finally:
        # Free up the Ray resources here so that later tests can use
        # the GPU allocated by this test.
        import ray
        ray.shutdown()
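
The try/finally above guarantees the Ray resources created by this xfail test are released even if the assertion inside it fails. The same cleanup can be expressed as a fixture; a hypothetical sketch, not part of this commit:

# Hypothetical pytest fixture expressing the same Ray cleanup pattern.
import pytest
import ray


@pytest.fixture
def ray_cleanup():
    yield
    # Runs even if the test body raised, freeing the GPU for later tests.
    if ray.is_initialized():
        ray.shutdown()
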
@pytest.mark.parametrize(


@@ -40,17 +40,24 @@ from .conftest import get_output_from_llm_generator
@pytest.mark.parametrize(
    "common_llm_kwargs",
    [{
        # Use a small model for a fast test.
        # Note this is repeated in the test body; to initialize a tokenizer.
        "model": "JackFram/llama-68m",
    [
        {
            # Use a small model for a fast test.
            # Note this is repeated in the test body; to initialize a tokenizer.
            "model": "JackFram/llama-68m",
        # Skip cuda graph recording for fast test.
        "enforce_eager": True,
            # Skip cuda graph recording for fast test.
            "enforce_eager": True,
        # Required for spec decode.
        "use_v2_block_manager": True
    }])
            # Required for spec decode.
            "use_v2_block_manager": True,
            # Whether to use the AsyncLLM engine.
            "use_async": async_mode,
        }
        # Try both async and sync engine execution.
        for async_mode in [True, False]
    ])
@pytest.mark.parametrize(
    "per_test_common_llm_kwargs",
    [


@@ -211,9 +211,11 @@ class _AsyncLLMEngine(LLMEngine):
        if not scheduler_outputs.is_empty():
            # Execute the model.
            output = await self.model_executor.execute_model_async(
                seq_group_metadata_list, scheduler_outputs.blocks_to_swap_in,
                seq_group_metadata_list,
                scheduler_outputs.blocks_to_swap_in,
                scheduler_outputs.blocks_to_swap_out,
                scheduler_outputs.blocks_to_copy)
                scheduler_outputs.blocks_to_copy,
                num_lookahead_slots=scheduler_outputs.num_lookahead_slots)
        else:
            output = []
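
A self-contained toy sketch (not vLLM code) of the failure mode being fixed above: when the blocking worker method requires num_lookahead_slots, an async wrapper that does not forward it fails as soon as the scheduler produces lookahead slots.

# Toy illustration only; the names below are invented stand-ins.
import asyncio


def execute_model(*, seq_group_metadata_list, num_lookahead_slots):
    # Stand-in for the blocking worker call; lookahead slots are required.
    return [f"ran {len(seq_group_metadata_list)} groups, "
            f"{num_lookahead_slots} lookahead slots"]


async def broken_async_path():
    # Mirrors the old step_async(): the keyword is simply not forwarded,
    # so this raises TypeError when awaited.
    return execute_model(seq_group_metadata_list=[])


async def fixed_async_path():
    # Mirrors the fixed step_async(): the scheduler's value is threaded through.
    return execute_model(seq_group_metadata_list=[], num_lookahead_slots=2)


print(asyncio.run(fixed_async_path()))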


@@ -109,12 +109,14 @@ class CPUExecutorAsync(CPUExecutor, ExecutorAsyncBase):
        blocks_to_swap_in: Dict[int, int],
        blocks_to_swap_out: Dict[int, int],
        blocks_to_copy: Dict[int, List[int]],
        num_lookahead_slots: int,
    ) -> List[SamplerOutput]:
        output = await make_async(self.driver_worker.execute_model)(
            seq_group_metadata_list=seq_group_metadata_list,
            blocks_to_swap_in=blocks_to_swap_in,
            blocks_to_swap_out=blocks_to_swap_out,
            blocks_to_copy=blocks_to_copy)
            blocks_to_copy=blocks_to_copy,
            num_lookahead_slots=num_lookahead_slots)
        return output

    async def check_health_async(self) -> None:


@@ -112,6 +112,7 @@ class ExecutorAsyncBase(ExecutorBase):
        blocks_to_swap_in: Dict[int, int],
        blocks_to_swap_out: Dict[int, int],
        blocks_to_copy: Dict[int, List[int]],
        num_lookahead_slots: int,
    ) -> List[SamplerOutput]:
        """Executes one model step on the given sequences."""
        raise NotImplementedError
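
Because ExecutorAsyncBase defines the async contract, every concrete executor below has to accept the new parameter even if it does not use it yet. A generic, self-contained sketch of that interface pattern; the class names here are invented, not vLLM's:

# Generic sketch: adding a parameter to an async interface means every
# override must accept it, and every caller must forward it.
from abc import ABC, abstractmethod
from typing import List


class ExecutorAsyncInterface(ABC):
    """Invented stand-in for an async executor contract."""

    @abstractmethod
    async def execute_model_async(self, batch: List[str],
                                  num_lookahead_slots: int) -> List[str]:
        """Executes one model step on the given sequences."""
        raise NotImplementedError


class GPUExecutorLike(ExecutorAsyncInterface):

    async def execute_model_async(self, batch: List[str],
                                  num_lookahead_slots: int) -> List[str]:
        # A concrete executor must both accept and forward the new argument;
        # the real ones do so via make_async(self.driver_worker.execute_model).
        return [f"{item} (+{num_lookahead_slots} lookahead slots)"
                for item in batch]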


@@ -163,10 +163,12 @@ class GPUExecutorAsync(GPUExecutor, ExecutorAsyncBase):
        blocks_to_swap_in: Dict[int, int],
        blocks_to_swap_out: Dict[int, int],
        blocks_to_copy: Dict[int, List[int]],
        num_lookahead_slots: int,
    ) -> List[SamplerOutput]:
        output = await make_async(self.driver_worker.execute_model)(
            seq_group_metadata_list=seq_group_metadata_list,
            blocks_to_swap_in=blocks_to_swap_in,
            blocks_to_swap_out=blocks_to_swap_out,
            blocks_to_copy=blocks_to_copy)
            blocks_to_copy=blocks_to_copy,
            num_lookahead_slots=num_lookahead_slots)
        return output
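
make_async (imported from vllm.utils in these executors) is what turns the blocking driver_worker.execute_model call into an awaitable. A rough sketch of that idea, assuming the usual run_in_executor approach rather than quoting the actual vLLM helper:

# Rough sketch of the make_async idea; see vllm/utils.py for the real helper.
import asyncio
import functools
from typing import Any, Callable


def make_async_sketch(func: Callable[..., Any]) -> Callable[..., Any]:
    """Illustrative wrapper, not the vLLM implementation."""

    async def _wrapper(*args: Any, **kwargs: Any) -> Any:
        loop = asyncio.get_running_loop()
        # Bind all arguments (including num_lookahead_slots) and run the
        # blocking call in the default thread pool, keeping the loop free.
        call = functools.partial(func, *args, **kwargs)
        return await loop.run_in_executor(None, call)

    return _wrapper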


@@ -84,6 +84,7 @@ class NeuronExecutorAsync(NeuronExecutor, ExecutorAsyncBase):
        blocks_to_swap_in: Dict[int, int],
        blocks_to_swap_out: Dict[int, int],
        blocks_to_copy: Dict[int, List[int]],
        num_lookahead_slots: int,
    ) -> List[SamplerOutput]:
        # The Neuron backend accepts num_lookahead_slots only to satisfy the
        # ExecutorAsyncBase signature; it forwards just the metadata list.
        output = await make_async(self.driver_worker.execute_model)(
            seq_group_metadata_list=seq_group_metadata_list, )


@@ -196,6 +196,7 @@ class RayGPUExecutor(DistributedGPUExecutor):
                "blocks_to_swap_in": blocks_to_swap_in,
                "blocks_to_swap_out": blocks_to_swap_out,
                "blocks_to_copy": blocks_to_copy,
                "num_lookahead_slots": num_lookahead_slots,
            },
            use_ray_compiled_dag=USE_RAY_COMPILED_DAG)
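
In the Ray backend the same keyword travels inside the driver_kwargs dict that RayGPUExecutor._run_workers fans out to the driver and remote workers. A self-contained Ray sketch of that fan-out; the WorkerActor class is hypothetical, not the vLLM worker:

# Hypothetical sketch of fanning execute_model kwargs out to Ray actors;
# the real fan-out lives in RayGPUExecutor._run_workers.
import ray


@ray.remote
class WorkerActor:

    def execute_model(self, seq_group_metadata_list, num_lookahead_slots):
        return (f"ran {len(seq_group_metadata_list)} groups, "
                f"{num_lookahead_slots} lookahead slots")


ray.init(ignore_reinit_error=True)
workers = [WorkerActor.remote() for _ in range(2)]
kwargs = {"seq_group_metadata_list": [], "num_lookahead_slots": 1}
# Every worker receives the same keyword arguments, including the newly
# forwarded num_lookahead_slots.
print(ray.get([w.execute_model.remote(**kwargs) for w in workers]))
ray.shutdown()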