[V1][Mamba] - Enable V1 by default for Mamba Models (#23650)

Signed-off-by: asafg <39553475+Josephasafg@users.noreply.github.com>
Author: Asaf Joseph Gardin
Date: 2025-08-27 23:53:30 +03:00
Committed by: GitHub
Parent: 8bf6266a17
Commit: 853c371fc3

3 changed files with 70 additions and 83 deletions


@@ -100,21 +100,19 @@ def test_models(
     else:
         hf_outputs = None
 
-    if model not in V0_UNSUPPORTED_MODELS:
-        with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model:
-            vllm_v0_outputs = vllm_model.generate_greedy_logprobs(
-                example_prompts, max_tokens, num_logprobs)
-    else:
-        vllm_v0_outputs = None
+    with monkeypatch.context() as m:
+        m.setenv("VLLM_USE_V1", "0")
+        if model not in V0_UNSUPPORTED_MODELS:
+            with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model:
+                vllm_v0_outputs = vllm_model.generate_greedy_logprobs(
+                    example_prompts, max_tokens, num_logprobs)
+        else:
+            vllm_v0_outputs = None
 
     if model in V1_SUPPORTED_MODELS:
-        with monkeypatch.context() as m:
-            m.setenv("VLLM_USE_V1", "1")
-            with vllm_runner(model,
-                             max_num_seqs=MAX_NUM_SEQS,
-                             enable_prefix_caching=False) as vllm_model:
-                vllm_v1_outputs = vllm_model.generate_greedy_logprobs(
-                    example_prompts, max_tokens, num_logprobs)
+        with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model:
+            vllm_v1_outputs = vllm_model.generate_greedy_logprobs(
+                example_prompts, max_tokens, num_logprobs)
     else:
         vllm_v1_outputs = None
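
With V1 now the engine default, only the V0 baseline needs an environment override, so the V1 branch sheds its setenv("VLLM_USE_V1", "1") and enable_prefix_caching=False scaffolding. The override pattern rests on pytest's monkeypatch context manager; a minimal sketch of the scoping behaviour it provides (the test name is made up, and the final assertion assumes VLLM_USE_V1 started out unset in the ambient environment):

    import os

    def test_env_is_scoped(monkeypatch):
        with monkeypatch.context() as m:
            m.setenv("VLLM_USE_V1", "0")  # pin the legacy V0 engine
            assert os.environ["VLLM_USE_V1"] == "0"
        # the override is rolled back on exit, so later runs see the default
        assert "VLLM_USE_V1" not in os.environ  # assumes it began unset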
@@ -137,7 +135,7 @@ def test_models(
 )
-@pytest.mark.parametrize("model", SSM_MODELS + HYBRID_MODELS)
+@pytest.mark.parametrize("model", [SSM_MODELS[0], HYBRID_MODELS[0]])
 @pytest.mark.parametrize("max_tokens", [64])
 @pytest.mark.parametrize("num_logprobs", [5])
 def test_batching(
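
Stacked parametrize decorators multiply out to a cross-product of test cases, so trimming the model list to one SSM and one hybrid representative shrinks the batching matrix without touching the other axes. An illustrative sketch of that mechanic (placeholder model names):

    import pytest

    @pytest.mark.parametrize("model", ["ssm-repr", "hybrid-repr"])  # 2 values
    @pytest.mark.parametrize("max_tokens", [64])                    # x 1
    @pytest.mark.parametrize("num_logprobs", [5])                   # x 1
    def test_matrix(model, max_tokens, num_logprobs):
        assert max_tokens == 64  # collects 2 x 1 x 1 = 2 test cases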
@@ -147,10 +145,6 @@ def test_batching(
     max_tokens: int,
     num_logprobs: int,
 ) -> None:
-    if model in V0_UNSUPPORTED_MODELS:
-        pytest.skip(
-            f"Unsupported V0 Engine. Skipping `test_batching` on {model}.")
-
     try:
         model_info = HF_EXAMPLE_MODELS.find_hf_info(model)
         model_info.check_available_online(on_fail="skip")
@@ -188,29 +182,32 @@ def test_chunked_prefill(
     max_tokens: int,
     num_logprobs: int,
     chunked_prefill_token_size: int,
+    monkeypatch,
 ) -> None:
     max_num_seqs = chunked_prefill_token_size
     max_num_batched_tokens = chunked_prefill_token_size
 
-    with vllm_runner(model,
-                     enable_chunked_prefill=True,
-                     max_num_batched_tokens=max_num_batched_tokens,
-                     max_num_seqs=max_num_seqs) as vllm_model:
-        chunked = vllm_model.generate_greedy_logprobs(example_prompts,
-                                                      max_tokens, num_logprobs)
+    with monkeypatch.context() as m:
+        m.setenv("VLLM_USE_V1", "0")
+        with vllm_runner(model,
+                         enable_chunked_prefill=True,
+                         max_num_batched_tokens=max_num_batched_tokens,
+                         max_num_seqs=max_num_seqs) as vllm_model:
+            chunked = vllm_model.generate_greedy_logprobs(
+                example_prompts, max_tokens, num_logprobs)
 
-    with vllm_runner(model,
-                     enable_chunked_prefill=False,
-                     max_num_seqs=max_num_seqs) as vllm_model:
-        non_chunked = vllm_model.generate_greedy_logprobs(
-            example_prompts, max_tokens, num_logprobs)
+        with vllm_runner(model,
+                         enable_chunked_prefill=False,
+                         max_num_seqs=max_num_seqs) as vllm_model:
+            non_chunked = vllm_model.generate_greedy_logprobs(
+                example_prompts, max_tokens, num_logprobs)
 
-    check_logprobs_close(
-        outputs_0_lst=chunked,
-        outputs_1_lst=non_chunked,
-        name_0="chunked",
-        name_1="non_chunked",
-    )
+        check_logprobs_close(
+            outputs_0_lst=chunked,
+            outputs_1_lst=non_chunked,
+            name_0="chunked",
+            name_1="non_chunked",
+        )
 
 
 @pytest.mark.parametrize("model", [SSM_MODELS[0], HYBRID_MODELS[0]])
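
The whole chunked-prefill comparison now runs pinned to V0: chunked prefill is on by default under V1, so the chunked/non-chunked contrast is only meaningful on the engine where it is opt-in. A simplified stand-in for the invariant that check_logprobs_close enforces (the (token_ids, text, logprobs) tuple layout follows the vLLM test utilities; this is an illustrative reduction, not the real helper):

    def logprobs_close(outputs_0, outputs_1, name_0, name_1):
        for i, ((ids0, _, lp0), (ids1, _, lp1)) in enumerate(
                zip(outputs_0, outputs_1)):
            for t, (tok0, tok1) in enumerate(zip(ids0, ids1)):
                if tok0 == tok1:
                    continue
                # tolerate a divergence only if each token still ranks in the
                # other run's top-k logprobs; afterwards sequences may differ
                assert tok0 in lp1[t] and tok1 in lp0[t], (
                    f"prompt {i}, step {t}: {name_0} vs {name_1}")
                break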
@@ -281,25 +278,29 @@ def test_models_preemption_recompute(
     example_prompts,
     model: str,
     max_tokens: int,
+    monkeypatch,
 ) -> None:
     """
     Tests that outputs are identical with and w/o preemptions (recompute).
     """
-    with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model:
-        scheduler = vllm_model.llm.llm_engine.scheduler[0]
-        scheduler.ENABLE_ARTIFICIAL_PREEMPT = True
-        preempt_vllm_outputs = vllm_model.generate_greedy(
-            example_prompts, max_tokens)
+    with monkeypatch.context() as m:
+        m.setenv("VLLM_USE_V1", "0")
+        with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model:
+            scheduler = vllm_model.llm.llm_engine.scheduler[0]
+            scheduler.ENABLE_ARTIFICIAL_PREEMPT = True
+            preempt_vllm_outputs = vllm_model.generate_greedy(
+                example_prompts, max_tokens)
 
-        scheduler.ENABLE_ARTIFICIAL_PREEMPT = False
-        vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
+            scheduler.ENABLE_ARTIFICIAL_PREEMPT = False
+            vllm_outputs = vllm_model.generate_greedy(example_prompts,
+                                                      max_tokens)
 
-    check_outputs_equal(
-        outputs_0_lst=preempt_vllm_outputs,
-        outputs_1_lst=vllm_outputs,
-        name_0="vllm_preepmtions",
-        name_1="vllm",
-    )
+        check_outputs_equal(
+            outputs_0_lst=preempt_vllm_outputs,
+            outputs_1_lst=vllm_outputs,
+            name_0="vllm_preepmtions",
+            name_1="vllm",
+        )
 
 
 @pytest.mark.parametrize("model", [SSM_MODELS[0], HYBRID_MODELS[0]])
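
This test stays on V0 because it reaches into the V0 scheduler's artificial-preemption switch, which the V1 path does not expose in the same way. Since preemption here is handled by recomputation, the outputs must match exactly, hence check_outputs_equal rather than a logprob tolerance. A simplified stand-in (illustrative only; generate_greedy yields (token_ids, text) pairs in the vLLM test utilities):

    def outputs_equal(outputs_0, outputs_1, name_0, name_1):
        assert len(outputs_0) == len(outputs_1)
        for i, ((ids0, txt0), (ids1, txt1)) in enumerate(
                zip(outputs_0, outputs_1)):
            # recompute-on-preemption must reproduce identical greedy output
            assert ids0 == ids1 and txt0 == txt1, (
                f"prompt {i}: {name_0} != {name_1}")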
@@ -402,24 +403,18 @@ def test_full_cuda_graph(
     else:
         hf_outputs = None
 
-    if model not in V0_UNSUPPORTED_MODELS:
-        with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model:
-            vllm_v0_outputs = vllm_model.generate_greedy_logprobs(
-                example_prompts, max_tokens, num_logprobs)
-    else:
-        vllm_v0_outputs = None
-
     with monkeypatch.context() as m:
-        m.setenv("VLLM_USE_V1", "1")
-        if model in HYBRID_MODELS:
-            # required due to reorder_batch behaviour
-            m.setenv("VLLM_ATTENTION_BACKEND", "FLASHINFER")
-        with vllm_runner(model,
-                         max_num_seqs=MAX_NUM_SEQS,
-                         compilation_config={'full_cuda_graph': True},
-                         enable_prefix_caching=False) as vllm_model:
-            vllm_v1_outputs = vllm_model.generate_greedy_logprobs(
-                example_prompts, max_tokens, num_logprobs)
+        m.setenv("VLLM_USE_V1", "0")
+        if model not in V0_UNSUPPORTED_MODELS:
+            with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model:
+                vllm_v0_outputs = vllm_model.generate_greedy_logprobs(
+                    example_prompts, max_tokens, num_logprobs)
+        else:
+            vllm_v0_outputs = None
+
+    with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model:
+        vllm_v1_outputs = vllm_model.generate_greedy_logprobs(
+            example_prompts, max_tokens, num_logprobs)
 
     if hf_outputs is not None and vllm_v0_outputs is not None:
         check_logprobs_close(
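
The V1 run here no longer forces full_cuda_graph, the FLASHINFER backend, or disabled prefix caching; it exercises plain defaults. The compilation knob itself remains available to callers. A hedged sketch of requesting it through the public API (the model choice is illustrative, and this assumes a vLLM build where LLM accepts a compilation_config dict, the same option the removed lines passed through vllm_runner):

    from vllm import LLM, SamplingParams

    llm = LLM(
        model="state-spaces/mamba-130m-hf",            # illustrative model
        compilation_config={"full_cuda_graph": True},  # one-graph decode
    )
    out = llm.generate(["The capital of France is"],
                       SamplingParams(temperature=0.0, max_tokens=8))
    print(out[0].outputs[0].text)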
@@ -466,24 +461,20 @@ def test_fp32_state(
     else:
         hf_outputs = None
 
+    with monkeypatch.context() as m:
+        m.setenv("VLLM_USE_V1", "0")
+        with vllm_runner(model,
+                         max_num_seqs=MAX_NUM_SEQS,
+                         mamba_ssm_cache_dtype="float32") as vllm_model:
+            vllm_v0_outputs = vllm_model.generate_greedy_logprobs(
+                example_prompts, max_tokens, num_logprobs)
+
     with vllm_runner(model,
                      max_num_seqs=MAX_NUM_SEQS,
                      mamba_ssm_cache_dtype="float32") as vllm_model:
-        vllm_v0_outputs = vllm_model.generate_greedy_logprobs(
+        vllm_v1_outputs = vllm_model.generate_greedy_logprobs(
             example_prompts, max_tokens, num_logprobs)
 
-    with monkeypatch.context() as m:
-        m.setenv("VLLM_USE_V1", "1")
-        if model in HYBRID_MODELS:
-            # required due to reorder_batch behaviour
-            m.setenv("VLLM_ATTENTION_BACKEND", "FLASHINFER")
-        with vllm_runner(model,
-                         max_num_seqs=MAX_NUM_SEQS,
-                         mamba_ssm_cache_dtype="float32",
-                         enable_prefix_caching=False) as vllm_model:
-            vllm_v1_outputs = vllm_model.generate_greedy_logprobs(
-                example_prompts, max_tokens, num_logprobs)
-
     if hf_outputs is not None:
         check_logprobs_close(
             outputs_0_lst=hf_outputs,
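
Keeping the SSM recurrent state in float32 is orthogonal to the engine version, so both runs pass the same cache-dtype flag and only the V0 side needs the env pin. A hedged sketch of setting that flag on the public API (illustrative model; assumes mamba_ssm_cache_dtype is forwarded to the engine arguments, as vllm_runner does above):

    from vllm import LLM, SamplingParams

    llm = LLM(
        model="state-spaces/mamba-130m-hf",  # illustrative pure-SSM model
        mamba_ssm_cache_dtype="float32",     # hold the SSM cache in fp32
    )
    print(llm.generate(["def main():"],
                       SamplingParams(temperature=0.0, max_tokens=8)))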


@@ -1463,11 +1463,6 @@ class EngineArgs:
                                recommend_to_remove=False)
             return False
 
-        # V1 mamba models are unoptimized.
-        if model_config.has_inner_state and _warn_or_fallback(
-                feature_name="Mamba"):
-            return False
-
         # No Concurrent Partial Prefills so far.
         if (self.max_num_partial_prefills
                 != SchedulerConfig.max_num_partial_prefills
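
Deleting this guard is what actually flips the default: models with recurrent inner state no longer trip a fallback to V0 inside the V1-support check. A hypothetical sketch of the guard pattern the removed lines relied on (the real _warn_or_fallback lives in vLLM's engine-args module; names and messages here are illustrative):

    import os
    import warnings

    def warn_or_fallback(feature_name: str) -> bool:
        """Return True to fall back to V0 unless V1 was explicitly requested."""
        if os.environ.get("VLLM_USE_V1") == "1":
            warnings.warn(f"{feature_name} is not fully optimized on V1.")
            return False  # user opted in: stay on V1, but warn
        return True       # no explicit opt-in: fall back to V0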


@@ -417,4 +417,5 @@ MODELS_CONFIG_MAP: dict[str, type[VerifyAndUpdateConfig]] = {
     "GptOssForCausalLM": GptOssForCausalLMConfig,
     "MambaForCausalLM": MambaModelConfig,
     "Mamba2ForCausalLM": MambaModelConfig,
+    "FalconMambaForCausalLM": MambaModelConfig,
 }
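
The new entry routes FalconMamba through the same Mamba-specific config verification as the other pure-SSM architectures. A simplified sketch of the dispatch this map drives (the lookup site and method name follow the VerifyAndUpdateConfig pattern named in the map's type hint, but treat those details as an assumption):

    def apply_model_overrides(architecture: str, vllm_config) -> None:
        # the HF config's architecture string selects the override class
        cls = MODELS_CONFIG_MAP.get(architecture)
        if cls is not None:
            cls.verify_and_update_config(vllm_config)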