Mirror of https://github.com/vllm-project/vllm.git, synced 2025-10-20 14:53:52 +08:00
[BugFix] Fix async scheduling + request preemption (#26385)
Signed-off-by: Nick Hill <nhill@redhat.com>
tests/v1/e2e/test_async_sched_and_preempt.py (new file, 96 lines added)
@@ -0,0 +1,96 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any

import pytest

from vllm import SamplingParams

from ...conftest import VllmRunner
from ...models.utils import check_outputs_equal

MODEL = "Qwen/Qwen3-0.6B"


def test_preempt_and_async_scheduling_e2e(monkeypatch: pytest.MonkeyPatch):
    """Test consistency of combos of async scheduling, preemption,
    uni/multiproc executor, and various sampling parameters."""

    first_prompt = (
        "The following numbers of the sequence "
        + ", ".join(str(i) for i in range(10))
        + " are:"
    )
    example_prompts = [first_prompt, "In one word, the capital of France is "] + [
        f"Tell me about the number {i}: " for i in range(32)
    ]

    sampling_param_tests: list[dict[str, Any]] = [
        dict(),
        # dict(min_tokens=20),
        # TODO enable these with https://github.com/vllm-project/vllm/pull/26467.
        # dict(repetition_penalty=0.1),
        # dict(bad_words=[]),
    ]

    default_params = dict(
        temperature=0.0,  # greedy
        max_tokens=20,
    )

    with monkeypatch.context() as m:
        m.setenv("VLLM_ATTENTION_BACKEND", "FLEX_ATTENTION")
        # m.setenv("VLLM_KERNEL_OVERRIDE_BATCH_INVARIANT", "1")

        outputs = []
        for test_preemption in [False, True]:
            for executor in ["uni", "mp"]:
                for async_scheduling in [False, True]:
                    cache_arg: dict[str, Any] = (
                        dict(num_gpu_blocks_override=32)
                        if test_preemption
                        else dict(gpu_memory_utilization=0.7)
                    )
                    test_config = (
                        f"executor={executor}, preemption={test_preemption},"
                        f" async_sched={async_scheduling}"
                    )
                    print("-" * 80)
                    print(f"---- TESTING: {test_config}")
                    print("-" * 80)
                    with VllmRunner(
                        MODEL,
                        max_model_len=512,
                        enforce_eager=True,
                        async_scheduling=async_scheduling,
                        distributed_executor_backend=executor,
                        dtype="float32",  # avoid precision errors
                        **cache_arg,
                    ) as vllm_model:
                        results = []
                        for override_params in sampling_param_tests:
                            print(f"----------- RUNNING PARAMS: {override_params}")
                            results.append(
                                vllm_model.generate(
                                    example_prompts,
                                    sampling_params=SamplingParams(
                                        **default_params, **override_params
                                    ),
                                )
                            )
                        outputs.append((test_config, results))

        baseline_config, baseline_tests = outputs[0]

        for test_config, test_outputs in outputs[1:]:
            for base_outs, test_outs, params in zip(
                baseline_tests, test_outputs, sampling_param_tests
            ):
                check_outputs_equal(
                    outputs_0_lst=base_outs,
                    outputs_1_lst=test_outs,
                    name_0=f"baseline=[{baseline_config}], params={params}",
                    name_1=f"config=[{test_config}], params={params}",
                )

                print(f"PASSED: config=[{test_config}], params={params}")
@@ -754,6 +754,12 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             # Replace the existing block IDs with the new ones.
             req_state.block_ids = new_block_ids
 
+            if self.use_async_scheduling and num_output_tokens > 0:
+                # We must recover the output token ids for resumed requests in the
+                # async scheduling case, so that correct input_ids are obtained.
+                resumed_token_ids = req_data.resumed_req_token_ids[i]
+                assert resumed_token_ids is not None
+                req_state.output_token_ids = resumed_token_ids[-num_output_tokens:]
             if req_index is None:
                 # The request is not in the persistent batch.
                 # The request was either preempted and resumed later, or was not
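The added block recovers the previously generated output tokens of a resumed request from the tail of its full token list, per the comment above. A minimal standalone sketch of that tail-slice recovery, using hypothetical token values rather than vLLM's actual request structures:

# Standalone illustration of recovering output tokens for a resumed request;
# all values here are hypothetical.
prompt_token_ids = [101, 102, 103]            # prompt tokens
generated_before_preemption = [7, 8, 9, 10]   # outputs produced before preemption

# On resume, only the concatenated token list is available to the runner.
resumed_token_ids = prompt_token_ids + generated_before_preemption
num_output_tokens = len(generated_before_preemption)

# The tail slice restores the output tokens so subsequent input_ids are correct.
recovered_output_token_ids = resumed_token_ids[-num_output_tokens:]
assert recovered_output_token_ids == generated_before_preemption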
@@ -991,7 +997,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         self.is_token_ids.copy_to_gpu(total_num_scheduled_tokens)
         if num_commmon_tokens == 0:
             # No requests in common with the previous iteration
-            # So input_ids_cpu will have all the input ids.
+            # So input_ids.cpu will have all the input ids.
             return
         if indices_match and max_flattened_index == (num_commmon_tokens - 1):
             # Common-case optimization: the batch is unchanged
@@ -1005,8 +1011,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             if self.enable_prompt_embeds:
                 self.is_token_ids.gpu[:num_commmon_tokens] = True
             return
-        # Upload the index tensors asynchronously
-        # so the scatter can be non-blocking.
+        # Upload the index tensors asynchronously so the scatter can be non-blocking.
         input_ids_index_tensor = torch.tensor(
             flattened_indices, dtype=torch.int64, pin_memory=self.pin_memory
         ).to(self.device, non_blocking=True)
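The comment touched in this hunk relies on the index tensor being allocated in pinned host memory, which is what allows the host-to-device copy to be truly asynchronous before the scatter. A small self-contained PyTorch sketch of the same general pattern (a generic example, not the runner's actual buffers):

# Pinned-memory async-upload pattern: copy indices to the GPU without blocking,
# then scatter values at those indices on the device.
import torch

if torch.cuda.is_available():
    device = torch.device("cuda")
    # Pinned (page-locked) host tensor enables a non-blocking H2D copy.
    indices = torch.tensor([0, 2, 5], dtype=torch.int64, pin_memory=True)
    indices_gpu = indices.to(device, non_blocking=True)

    src = torch.arange(3, dtype=torch.int32, device=device)
    dst = torch.zeros(8, dtype=torch.int32, device=device)
    # Scatter the new values into the destination buffer at the uploaded indices.
    dst[indices_gpu] = src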