mirror of
https://github.com/vllm-project/vllm.git
synced 2025-10-20 14:53:52 +08:00
[Bugfix][Mamba] Fix Multistep on Mamba-like models (#10705)
Signed-off-by: mzusman <mor.zusmann@gmail.com>
This commit is contained in:
@ -275,6 +275,44 @@ def test_state_cleanup(
|
||||
"could be related to finished_requests_ids")
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model", MODELS)
|
||||
@pytest.mark.parametrize("dtype", ["float"])
|
||||
def test_multistep(
|
||||
vllm_runner,
|
||||
model: str,
|
||||
dtype: str,
|
||||
example_prompts,
|
||||
) -> None:
|
||||
# This test is verifying that multistep works correctly
|
||||
#on mamba-like models
|
||||
with vllm_runner(model, num_scheduler_steps=8,
|
||||
max_num_seqs=2) as vllm_model:
|
||||
vllm_model.generate_greedy([example_prompts[0]] * 10, 1)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model", MODELS)
|
||||
@pytest.mark.parametrize("dtype", ["float"])
|
||||
@pytest.mark.parametrize("max_tokens", [64])
|
||||
def test_multistep_correctness(vllm_runner, model: str, dtype: str,
|
||||
max_tokens: int, example_prompts) -> None:
|
||||
with vllm_runner(model, num_scheduler_steps=8,
|
||||
max_num_seqs=2) as vllm_model:
|
||||
vllm_outputs_multistep = vllm_model.generate_greedy(
|
||||
example_prompts, max_tokens)
|
||||
|
||||
with vllm_runner(model, num_scheduler_steps=1,
|
||||
max_num_seqs=2) as vllm_model:
|
||||
vllm_outputs_single_step = vllm_model.generate_greedy(
|
||||
example_prompts, max_tokens)
|
||||
|
||||
check_outputs_equal(
|
||||
outputs_0_lst=vllm_outputs_multistep,
|
||||
outputs_1_lst=vllm_outputs_single_step,
|
||||
name_0="vllm_outputs_multistep",
|
||||
name_1="vllm_outputs_single_step",
|
||||
)
|
||||
|
||||
|
||||
@multi_gpu_test(num_gpus=2)
|
||||
@pytest.mark.parametrize("model", MODELS)
|
||||
@pytest.mark.parametrize("dtype", ["float"])
|
||||
|
@ -283,3 +283,39 @@ def test_state_cleanup(
|
||||
except ValueError:
|
||||
pytest.fail("Mamba inner state wasn't cleaned up between states, "
|
||||
"could be related to finished_requests_ids")
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model", MODELS)
|
||||
@pytest.mark.parametrize("dtype", ["float"])
|
||||
def test_multistep(
|
||||
vllm_runner,
|
||||
model: str,
|
||||
dtype: str,
|
||||
example_prompts,
|
||||
) -> None:
|
||||
with vllm_runner(model, num_scheduler_steps=8,
|
||||
max_num_seqs=2) as vllm_model:
|
||||
vllm_model.generate_greedy([example_prompts[0]] * 10, 1)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model", MODELS)
|
||||
@pytest.mark.parametrize("dtype", ["float"])
|
||||
@pytest.mark.parametrize("max_tokens", [64])
|
||||
def test_multistep_correctness(vllm_runner, model: str, dtype: str,
|
||||
max_tokens: int, example_prompts) -> None:
|
||||
with vllm_runner(model, num_scheduler_steps=8,
|
||||
max_num_seqs=2) as vllm_model:
|
||||
vllm_outputs_multistep = vllm_model.generate_greedy(
|
||||
example_prompts, max_tokens)
|
||||
|
||||
with vllm_runner(model, num_scheduler_steps=1,
|
||||
max_num_seqs=2) as vllm_model:
|
||||
vllm_outputs_single_step = vllm_model.generate_greedy(
|
||||
example_prompts, max_tokens)
|
||||
|
||||
check_outputs_equal(
|
||||
outputs_0_lst=vllm_outputs_multistep,
|
||||
outputs_1_lst=vllm_outputs_single_step,
|
||||
name_0="vllm_outputs_multistep",
|
||||
name_1="vllm_outputs_single_step",
|
||||
)
|
||||
|
@ -300,6 +300,9 @@ class _AsyncLLMEngine(LLMEngine):
|
||||
ctx.seq_group_metadata_list = seq_group_metadata_list
|
||||
ctx.scheduler_outputs = scheduler_outputs
|
||||
|
||||
finished_requests_ids = self.scheduler[
|
||||
virtual_engine].get_and_reset_finished_requests_ids()
|
||||
|
||||
# Maybe switch from async mode to sync mode
|
||||
if not allow_async_output_proc and len(ctx.output_queue) > 0:
|
||||
self._process_model_outputs(ctx=ctx)
|
||||
@ -311,13 +314,13 @@ class _AsyncLLMEngine(LLMEngine):
|
||||
self._cache_scheduler_outputs_for_multi_step(
|
||||
virtual_engine, seq_group_metadata_list, scheduler_outputs,
|
||||
allow_async_output_proc)
|
||||
else:
|
||||
finished_requests_ids = list()
|
||||
|
||||
assert seq_group_metadata_list is not None
|
||||
assert scheduler_outputs is not None
|
||||
|
||||
if not scheduler_outputs.is_empty():
|
||||
finished_requests_ids = self.scheduler[
|
||||
virtual_engine].get_and_reset_finished_requests_ids()
|
||||
|
||||
# Check if we have a cached last_output from the previous iteration.
|
||||
# For supporting PP this is probably the best way to pass the
|
||||
|
@ -1398,6 +1398,9 @@ class LLMEngine:
|
||||
ctx.seq_group_metadata_list = seq_group_metadata_list
|
||||
ctx.scheduler_outputs = scheduler_outputs
|
||||
|
||||
finished_requests_ids = self.scheduler[
|
||||
virtual_engine].get_and_reset_finished_requests_ids()
|
||||
|
||||
# Maybe switch from async mode to sync mode
|
||||
if not allow_async_output_proc and len(ctx.output_queue) > 0:
|
||||
self._process_model_outputs(ctx=ctx)
|
||||
@ -1409,13 +1412,13 @@ class LLMEngine:
|
||||
self._cache_scheduler_outputs_for_multi_step(
|
||||
virtual_engine, seq_group_metadata_list, scheduler_outputs,
|
||||
allow_async_output_proc)
|
||||
else:
|
||||
finished_requests_ids = list()
|
||||
|
||||
assert seq_group_metadata_list is not None
|
||||
assert scheduler_outputs is not None
|
||||
|
||||
if not scheduler_outputs.is_empty():
|
||||
finished_requests_ids = self.scheduler[
|
||||
virtual_engine].get_and_reset_finished_requests_ids()
|
||||
|
||||
# Check if we have a cached last_output from the previous iteration.
|
||||
# For supporting PP this is probably the best way to pass the
|
||||
|
Reference in New Issue
Block a user