mirror of
https://github.com/vllm-project/vllm.git
synced 2025-10-20 14:53:52 +08:00
[BugFix] fix num_lookahead_slots missing in async executor (#4165)
Co-authored-by: Lei Wen <wenlei03@qiyi.com>
This commit is contained in:
@ -1,10 +1,127 @@
|
||||
from typing import List, Tuple
|
||||
import asyncio
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import pytest
|
||||
import ray
|
||||
|
||||
from tests.conftest import cleanup
|
||||
from vllm import LLM
|
||||
from vllm.engine.arg_utils import AsyncEngineArgs
|
||||
from vllm.engine.async_llm_engine import AsyncLLMEngine
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.model_executor.utils import set_random_seed
|
||||
from vllm.outputs import RequestOutput
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.sequence import MultiModalData
|
||||
from vllm.usage.usage_lib import UsageContext
|
||||
from vllm.utils import Counter, random_uuid
|
||||
|
||||
|
||||
class AsyncLLM:
|
||||
"""AsyncLLM
|
||||
|
||||
Note: Current LLM class in vllm don't support async mode, for test purpose,
|
||||
we implement async one in here. Maybe we could move to
|
||||
vllm/entrypoints/llm.py in future.
|
||||
|
||||
Below AsyncLLM is directly borrow from vllm/entrypoints/llm.py with changes
|
||||
to make to work in async mode.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: str,
|
||||
tokenizer: Optional[str] = None,
|
||||
tokenizer_mode: str = "auto",
|
||||
skip_tokenizer_init: bool = False,
|
||||
trust_remote_code: bool = False,
|
||||
tensor_parallel_size: int = 1,
|
||||
dtype: str = "auto",
|
||||
quantization: Optional[str] = None,
|
||||
revision: Optional[str] = None,
|
||||
tokenizer_revision: Optional[str] = None,
|
||||
seed: int = 0,
|
||||
gpu_memory_utilization: float = 0.9,
|
||||
swap_space: int = 4,
|
||||
enforce_eager: bool = False,
|
||||
max_context_len_to_capture: int = 8192,
|
||||
disable_custom_all_reduce: bool = False,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
if "disable_log_stats" not in kwargs:
|
||||
kwargs["disable_log_stats"] = True
|
||||
self.engine_args = AsyncEngineArgs(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
tokenizer_mode=tokenizer_mode,
|
||||
skip_tokenizer_init=skip_tokenizer_init,
|
||||
trust_remote_code=trust_remote_code,
|
||||
tensor_parallel_size=tensor_parallel_size,
|
||||
dtype=dtype,
|
||||
quantization=quantization,
|
||||
revision=revision,
|
||||
tokenizer_revision=tokenizer_revision,
|
||||
seed=seed,
|
||||
gpu_memory_utilization=gpu_memory_utilization,
|
||||
swap_space=swap_space,
|
||||
enforce_eager=enforce_eager,
|
||||
max_context_len_to_capture=max_context_len_to_capture,
|
||||
engine_use_ray=True,
|
||||
disable_custom_all_reduce=disable_custom_all_reduce,
|
||||
**kwargs,
|
||||
)
|
||||
self.request_counter = Counter()
|
||||
|
||||
def generate(
|
||||
self,
|
||||
prompts: Optional[Union[str, List[str]]] = None,
|
||||
sampling_params: Optional[Union[SamplingParams,
|
||||
List[SamplingParams]]] = None,
|
||||
prompt_token_ids: Optional[List[List[int]]] = None,
|
||||
use_tqdm: bool = True,
|
||||
lora_request: Optional[LoRARequest] = None,
|
||||
multi_modal_data: Optional[MultiModalData] = None,
|
||||
) -> List[RequestOutput]:
|
||||
|
||||
llm_engine = AsyncLLMEngine.from_engine_args(
|
||||
self.engine_args, usage_context=UsageContext.LLM_CLASS)
|
||||
|
||||
if prompts is None:
|
||||
raise ValueError("prompts must be provided.")
|
||||
if isinstance(prompts, str):
|
||||
# Convert a single prompt to a list.
|
||||
prompts = [prompts]
|
||||
|
||||
if prompts is not None:
|
||||
num_requests = len(prompts)
|
||||
|
||||
if sampling_params is None:
|
||||
# Use default sampling params.
|
||||
sampling_params = SamplingParams()
|
||||
|
||||
elif isinstance(sampling_params,
|
||||
list) and len(sampling_params) != num_requests:
|
||||
raise ValueError("The lengths of prompts and "
|
||||
"sampling_params must be the same.")
|
||||
|
||||
async def get_output(prompt, sampling_param) -> str:
|
||||
request_id = random_uuid()
|
||||
results_generator = llm_engine.generate(prompt, sampling_param,
|
||||
request_id)
|
||||
final_output = None
|
||||
async for request_output in results_generator:
|
||||
final_output = request_output
|
||||
return final_output
|
||||
|
||||
outputs = []
|
||||
try:
|
||||
for i in range(num_requests):
|
||||
prompt = prompts[i] if prompts is not None else None
|
||||
res = asyncio.run(get_output(prompt, sampling_params))
|
||||
outputs.append(res)
|
||||
finally:
|
||||
ray.shutdown()
|
||||
return outputs
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@ -36,8 +153,12 @@ def create_llm_generator(baseline_or_test, request, common_llm_kwargs,
|
||||
|
||||
def generator_inner():
|
||||
print(f'Creating {baseline_or_test=} LLM for {test_name=}. {kwargs=}')
|
||||
llm = LLM(**kwargs)
|
||||
|
||||
use_async = False
|
||||
if "use_async" in kwargs:
|
||||
use_async = kwargs.pop("use_async")
|
||||
|
||||
llm = AsyncLLM(**kwargs) if use_async else LLM(**kwargs)
|
||||
set_random_seed(seed)
|
||||
|
||||
yield llm
|
||||
|
@ -42,10 +42,17 @@ def test_spec_decode_xfail_ray(test_llm_generator):
|
||||
temperature=temperature,
|
||||
)
|
||||
|
||||
with pytest.raises(AssertionError,
|
||||
match="Speculative decoding not yet supported for "):
|
||||
get_output_from_llm_generator(test_llm_generator, prompts,
|
||||
sampling_params)
|
||||
try:
|
||||
with pytest.raises(
|
||||
AssertionError,
|
||||
match="Speculative decoding not yet supported for "):
|
||||
get_output_from_llm_generator(test_llm_generator, prompts,
|
||||
sampling_params)
|
||||
finally:
|
||||
# we need to free up ray resource,
|
||||
# so that latter test could use the gpu we allocated here
|
||||
import ray
|
||||
ray.shutdown()
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
|
@ -40,17 +40,24 @@ from .conftest import get_output_from_llm_generator
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"common_llm_kwargs",
|
||||
[{
|
||||
# Use a small model for a fast test.
|
||||
# Note this is repeated in the test body; to initialize a tokenizer.
|
||||
"model": "JackFram/llama-68m",
|
||||
[
|
||||
{
|
||||
# Use a small model for a fast test.
|
||||
# Note this is repeated in the test body; to initialize a tokenizer.
|
||||
"model": "JackFram/llama-68m",
|
||||
|
||||
# Skip cuda graph recording for fast test.
|
||||
"enforce_eager": True,
|
||||
# Skip cuda graph recording for fast test.
|
||||
"enforce_eager": True,
|
||||
|
||||
# Required for spec decode.
|
||||
"use_v2_block_manager": True
|
||||
}])
|
||||
# Required for spec decode.
|
||||
"use_v2_block_manager": True,
|
||||
|
||||
# whether use AsyncLLM engine
|
||||
"use_async": async_mode,
|
||||
}
|
||||
# Try both async and sync engine execution
|
||||
for async_mode in [True, False]
|
||||
])
|
||||
@pytest.mark.parametrize(
|
||||
"per_test_common_llm_kwargs",
|
||||
[
|
||||
|
@ -211,9 +211,11 @@ class _AsyncLLMEngine(LLMEngine):
|
||||
if not scheduler_outputs.is_empty():
|
||||
# Execute the model.
|
||||
output = await self.model_executor.execute_model_async(
|
||||
seq_group_metadata_list, scheduler_outputs.blocks_to_swap_in,
|
||||
seq_group_metadata_list,
|
||||
scheduler_outputs.blocks_to_swap_in,
|
||||
scheduler_outputs.blocks_to_swap_out,
|
||||
scheduler_outputs.blocks_to_copy)
|
||||
scheduler_outputs.blocks_to_copy,
|
||||
num_lookahead_slots=scheduler_outputs.num_lookahead_slots)
|
||||
else:
|
||||
output = []
|
||||
|
||||
|
@ -109,12 +109,14 @@ class CPUExecutorAsync(CPUExecutor, ExecutorAsyncBase):
|
||||
blocks_to_swap_in: Dict[int, int],
|
||||
blocks_to_swap_out: Dict[int, int],
|
||||
blocks_to_copy: Dict[int, List[int]],
|
||||
num_lookahead_slots: int,
|
||||
) -> List[SamplerOutput]:
|
||||
output = await make_async(self.driver_worker.execute_model)(
|
||||
seq_group_metadata_list=seq_group_metadata_list,
|
||||
blocks_to_swap_in=blocks_to_swap_in,
|
||||
blocks_to_swap_out=blocks_to_swap_out,
|
||||
blocks_to_copy=blocks_to_copy)
|
||||
blocks_to_copy=blocks_to_copy,
|
||||
num_lookahead_slots=num_lookahead_slots)
|
||||
return output
|
||||
|
||||
async def check_health_async(self) -> None:
|
||||
|
@ -112,6 +112,7 @@ class ExecutorAsyncBase(ExecutorBase):
|
||||
blocks_to_swap_in: Dict[int, int],
|
||||
blocks_to_swap_out: Dict[int, int],
|
||||
blocks_to_copy: Dict[int, List[int]],
|
||||
num_lookahead_slots: int,
|
||||
) -> List[SamplerOutput]:
|
||||
"""Executes one model step on the given sequences."""
|
||||
raise NotImplementedError
|
||||
|
@ -163,10 +163,12 @@ class GPUExecutorAsync(GPUExecutor, ExecutorAsyncBase):
|
||||
blocks_to_swap_in: Dict[int, int],
|
||||
blocks_to_swap_out: Dict[int, int],
|
||||
blocks_to_copy: Dict[int, List[int]],
|
||||
num_lookahead_slots: int,
|
||||
) -> List[SamplerOutput]:
|
||||
output = await make_async(self.driver_worker.execute_model)(
|
||||
seq_group_metadata_list=seq_group_metadata_list,
|
||||
blocks_to_swap_in=blocks_to_swap_in,
|
||||
blocks_to_swap_out=blocks_to_swap_out,
|
||||
blocks_to_copy=blocks_to_copy)
|
||||
blocks_to_copy=blocks_to_copy,
|
||||
num_lookahead_slots=num_lookahead_slots)
|
||||
return output
|
||||
|
@ -84,6 +84,7 @@ class NeuronExecutorAsync(NeuronExecutor, ExecutorAsyncBase):
|
||||
blocks_to_swap_in: Dict[int, int],
|
||||
blocks_to_swap_out: Dict[int, int],
|
||||
blocks_to_copy: Dict[int, List[int]],
|
||||
num_lookahead_slots: int,
|
||||
) -> List[SamplerOutput]:
|
||||
output = await make_async(self.driver_worker.execute_model)(
|
||||
seq_group_metadata_list=seq_group_metadata_list, )
|
||||
|
@ -196,6 +196,7 @@ class RayGPUExecutor(DistributedGPUExecutor):
|
||||
"blocks_to_swap_in": blocks_to_swap_in,
|
||||
"blocks_to_swap_out": blocks_to_swap_out,
|
||||
"blocks_to_copy": blocks_to_copy,
|
||||
"num_lookahead_slots": num_lookahead_slots,
|
||||
},
|
||||
use_ray_compiled_dag=USE_RAY_COMPILED_DAG)
|
||||
|
||||
|
Reference in New Issue
Block a user