[Bugfix][Mamba] Fix Multistep on Mamba-like models (#10705)

Signed-off-by: mzusman <mor.zusmann@gmail.com>
Author: Mor Zusman
Date: 2024-11-27 21:02:27 +02:00
Committed by: GitHub
Parent: b98c62ba49
Commit: 197b4484a3
4 changed files with 84 additions and 4 deletions
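
The diff below adds multi-step tests to two Mamba-like model test suites and moves the point at which finished request ids are drained from the scheduler in both engine step paths. For context, multi-step scheduling is the feature being exercised; it is enabled through the num_scheduler_steps engine argument, as in this minimal usage sketch (the model name is only an illustrative example, not taken from this commit):

from vllm import LLM, SamplingParams

# Minimal sketch: num_scheduler_steps > 1 enables multi-step scheduling,
# the feature exercised by the new tests. The model below is only an
# example of a Mamba-like model, not one referenced by this commit.
llm = LLM(model="state-spaces/mamba-130m-hf", num_scheduler_steps=8)

params = SamplingParams(temperature=0.0, max_tokens=64)  # greedy decoding
outputs = llm.generate(["The capital of France is"], params)
print(outputs[0].outputs[0].text)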


@@ -275,6 +275,44 @@ def test_state_cleanup(
                    "could be related to finished_requests_ids")


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["float"])
def test_multistep(
    vllm_runner,
    model: str,
    dtype: str,
    example_prompts,
) -> None:
    # This test verifies that multistep works correctly
    # on mamba-like models
    with vllm_runner(model, num_scheduler_steps=8,
                     max_num_seqs=2) as vllm_model:
        vllm_model.generate_greedy([example_prompts[0]] * 10, 1)


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["float"])
@pytest.mark.parametrize("max_tokens", [64])
def test_multistep_correctness(vllm_runner, model: str, dtype: str,
                               max_tokens: int, example_prompts) -> None:
    with vllm_runner(model, num_scheduler_steps=8,
                     max_num_seqs=2) as vllm_model:
        vllm_outputs_multistep = vllm_model.generate_greedy(
            example_prompts, max_tokens)

    with vllm_runner(model, num_scheduler_steps=1,
                     max_num_seqs=2) as vllm_model:
        vllm_outputs_single_step = vllm_model.generate_greedy(
            example_prompts, max_tokens)

    check_outputs_equal(
        outputs_0_lst=vllm_outputs_multistep,
        outputs_1_lst=vllm_outputs_single_step,
        name_0="vllm_outputs_multistep",
        name_1="vllm_outputs_single_step",
    )

@multi_gpu_test(num_gpus=2)
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["float"])


@@ -283,3 +283,39 @@ def test_state_cleanup(
    except ValueError:
        pytest.fail("Mamba inner state wasn't cleaned up between states, "
                    "could be related to finished_requests_ids")


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["float"])
def test_multistep(
    vllm_runner,
    model: str,
    dtype: str,
    example_prompts,
) -> None:
    with vllm_runner(model, num_scheduler_steps=8,
                     max_num_seqs=2) as vllm_model:
        vllm_model.generate_greedy([example_prompts[0]] * 10, 1)


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["float"])
@pytest.mark.parametrize("max_tokens", [64])
def test_multistep_correctness(vllm_runner, model: str, dtype: str,
                               max_tokens: int, example_prompts) -> None:
    with vllm_runner(model, num_scheduler_steps=8,
                     max_num_seqs=2) as vllm_model:
        vllm_outputs_multistep = vllm_model.generate_greedy(
            example_prompts, max_tokens)

    with vllm_runner(model, num_scheduler_steps=1,
                     max_num_seqs=2) as vllm_model:
        vllm_outputs_single_step = vllm_model.generate_greedy(
            example_prompts, max_tokens)

    check_outputs_equal(
        outputs_0_lst=vllm_outputs_multistep,
        outputs_1_lst=vllm_outputs_single_step,
        name_0="vllm_outputs_multistep",
        name_1="vllm_outputs_single_step",
    )


@@ -300,6 +300,9 @@ class _AsyncLLMEngine(LLMEngine):
            ctx.seq_group_metadata_list = seq_group_metadata_list
            ctx.scheduler_outputs = scheduler_outputs

            finished_requests_ids = self.scheduler[
                virtual_engine].get_and_reset_finished_requests_ids()

            # Maybe switch from async mode to sync mode
            if not allow_async_output_proc and len(ctx.output_queue) > 0:
                self._process_model_outputs(ctx=ctx)
@@ -311,13 +314,13 @@ class _AsyncLLMEngine(LLMEngine):
                self._cache_scheduler_outputs_for_multi_step(
                    virtual_engine, seq_group_metadata_list, scheduler_outputs,
                    allow_async_output_proc)
        else:
            finished_requests_ids = list()

        assert seq_group_metadata_list is not None
        assert scheduler_outputs is not None

        if not scheduler_outputs.is_empty():
            finished_requests_ids = self.scheduler[
                virtual_engine].get_and_reset_finished_requests_ids()

            # Check if we have a cached last_output from the previous iteration.
            # For supporting PP this is probably the best way to pass the


@@ -1398,6 +1398,9 @@ class LLMEngine:
            ctx.seq_group_metadata_list = seq_group_metadata_list
            ctx.scheduler_outputs = scheduler_outputs

            finished_requests_ids = self.scheduler[
                virtual_engine].get_and_reset_finished_requests_ids()

            # Maybe switch from async mode to sync mode
            if not allow_async_output_proc and len(ctx.output_queue) > 0:
                self._process_model_outputs(ctx=ctx)
@@ -1409,13 +1412,13 @@ class LLMEngine:
                self._cache_scheduler_outputs_for_multi_step(
                    virtual_engine, seq_group_metadata_list, scheduler_outputs,
                    allow_async_output_proc)
        else:
            finished_requests_ids = list()

        assert seq_group_metadata_list is not None
        assert scheduler_outputs is not None

        if not scheduler_outputs.is_empty():
            finished_requests_ids = self.scheduler[
                virtual_engine].get_and_reset_finished_requests_ids()

            # Check if we have a cached last_output from the previous iteration.
            # For supporting PP this is probably the best way to pass the
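
Both engine hunks apply the same fix to the async and sync step paths: finished request ids, which Mamba-like models consume to release per-request internal state (see the test_state_cleanup context above), are now drained from the scheduler at the moment a new scheduling iteration runs, while the cached multi-step iterations pass an empty list instead of draining and losing them. A small self-contained sketch of this drain-once-per-schedule pattern; it illustrates the idea only and is not vLLM's code:

class FakeScheduler:
    # Stand-in for the scheduler, only to show that the read is destructive.
    def __init__(self):
        self._finished_requests_ids = ["req-0"]

    def get_and_reset_finished_requests_ids(self):
        # Get-and-reset: a second call in the same iteration returns [].
        ids, self._finished_requests_ids = self._finished_requests_ids, []
        return ids


def finished_ids_for_step(scheduler, new_scheduling_iteration):
    # Drain exactly once, when a fresh scheduling iteration builds new model
    # inputs; on cached multi-step iterations pass an empty list so the ids
    # are not consumed before the Mamba state cache ever sees them.
    if new_scheduling_iteration:
        return scheduler.get_and_reset_finished_requests_ids()
    return []


scheduler = FakeScheduler()
print(finished_ids_for_step(scheduler, True))   # ['req-0'] reaches the model
print(finished_ids_for_step(scheduler, False))  # [] on a cached multi-step pass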