Mirror of https://github.com/vllm-project/vllm.git (synced 2025-10-20 14:53:52 +08:00)
[Bugfix] Fix hybrid model tests (#17182)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
@@ -531,7 +531,10 @@ class HfRunner:
        for _, hidden_state in enumerate(hidden_states):
            last_hidden_states = hidden_state[-1][0]
            logits = torch.matmul(
                last_hidden_states.to(output_embeddings.weight.device),
                last_hidden_states.to(
                    device=output_embeddings.weight.device,
                    dtype=output_embeddings.weight.dtype,
                ),
                output_embeddings.weight.t(),
            )
            if getattr(output_embeddings, "bias", None) is not None:
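The change above casts the hidden states to both the device and the dtype of the output embedding weights before the matmul; matching only the device (as before) breaks once the embeddings are in half precision. A minimal standalone PyTorch sketch of the failure mode and the fix (not vLLM code; shapes and dtypes are made up for illustration):

import torch

hidden = torch.randn(4, 8)                        # float32 activations
weight = torch.randn(16, 8, dtype=torch.float16)  # fp16 output embedding weight

# torch.matmul(hidden, weight.t())  # raises RuntimeError: operand dtypes differ
logits = torch.matmul(hidden.to(device=weight.device, dtype=weight.dtype),
                      weight.t())
print(logits.dtype)  # torch.float16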
@@ -6,71 +6,84 @@ from tests.utils import multi_gpu_test
from vllm.engine.arg_utils import EngineArgs
from vllm.sampling_params import SamplingParams

from ...utils import check_outputs_equal
from ...utils import check_logprobs_close, check_outputs_equal

# This test is for the hybrid models
MODELS = [
    "ai21labs/Jamba-tiny-dev", "Zyphra/Zamba2-1.2B-instruct",
    "pfnet/plamo-2-1b"
# NOTE: The first model in each list is taken as the primary model,
# meaning that it will be used in all tests in this file
# The rest of the models will only be tested by test_models

SSM_MODELS = [
    "state-spaces/mamba-130m-hf",
    "tiiuae/falcon-mamba-tiny-dev",
    # TODO: Compare to a Mamba2 model. The HF transformers implementation of
    # Mamba2 is buggy for Codestral as it doesn't handle n_groups.
    # See https://github.com/huggingface/transformers/pull/35943
    # "mistralai/Mamba-Codestral-7B-v0.1",
]
# Bamba at Fp32 is too big for the CI (L4 GPU).
# MODELS = ["ai21labs/Jamba-tiny-dev", "ibm-ai-platform/Bamba-9B"]
# Note: Running Plamo2 in transformers implementation requires to install
# causal-conv1d package, which is not listed as a test dependency as it's
# not compatible with pip-compile.

HYBRID_MODELS = [
    "ai21labs/Jamba-tiny-dev",
    # NOTE: Running Plamo2 in transformers implementation requires to install
    # causal-conv1d package, which is not listed as a test dependency as it's
    # not compatible with pip-compile.
    "pfnet/plamo-2-1b",
    "Zyphra/Zamba2-1.2B-instruct",
    "ibm-ai-platform/Bamba-9B",
]

# Avoid OOM
MAX_NUM_SEQS = 4


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["float"])
@pytest.mark.parametrize("max_tokens", [96])
@pytest.mark.parametrize("model", SSM_MODELS + HYBRID_MODELS)
@pytest.mark.parametrize("max_tokens", [64])
@pytest.mark.parametrize("num_logprobs", [5])
def test_models(
    hf_runner,
    vllm_runner,
    example_prompts,
    model: str,
    dtype: str,
    max_tokens: int,
    num_logprobs: int,
) -> None:
    # numeric error produces different generation
    if "Bamba" in model:
        example_prompts.pop(3)
    with hf_runner(model) as hf_model:
        hf_outputs = hf_model.generate_greedy_logprobs_limit(
            example_prompts, max_tokens, num_logprobs)

    with hf_runner(model, dtype=dtype) as hf_model:
        hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)
    with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model:
        vllm_outputs = vllm_model.generate_greedy_logprobs(
            example_prompts, max_tokens, num_logprobs)

    with vllm_runner(model, dtype=dtype) as vllm_model:
        vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)

    for i in range(len(example_prompts)):
        hf_output_ids, hf_output_str = hf_outputs[i]
        vllm_output_ids, vllm_output_str = vllm_outputs[i]
        assert hf_output_str == vllm_output_str, (
            f"Test{i}:\nHF: {hf_output_str!r}\nvLLM: {vllm_output_str!r}")
        assert hf_output_ids == vllm_output_ids, (
            f"Test{i}:\nHF: {hf_output_ids}\nvLLM: {vllm_output_ids}")
    check_logprobs_close(
        outputs_0_lst=hf_outputs,
        outputs_1_lst=vllm_outputs,
        name_0="hf",
        name_1="vllm",
    )


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["float"])
@pytest.mark.parametrize("max_tokens", [96])
@pytest.mark.parametrize("model", SSM_MODELS + HYBRID_MODELS)
@pytest.mark.parametrize("max_tokens", [64])
@pytest.mark.parametrize("num_logprobs", [5])
def test_batching(
    vllm_runner,
    example_prompts,
    model: str,
    dtype: str,
    max_tokens: int,
    num_logprobs: int,
) -> None:
    # To pass the small model tests, we need full precision.
    for_loop_outputs = []
    with vllm_runner(model, dtype=dtype) as vllm_model:
    with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model:
        for prompt in example_prompts:
            for_loop_outputs.append(
                vllm_model.generate_greedy([prompt], max_tokens)[0])
            single_output, = vllm_model.generate_greedy_logprobs([prompt],
                                                                 max_tokens,
                                                                 num_logprobs)
            for_loop_outputs.append(single_output)

        batched_outputs = vllm_model.generate_greedy(example_prompts,
                                                     max_tokens)
        batched_outputs = vllm_model.generate_greedy_logprobs(
            example_prompts, max_tokens, num_logprobs)

    check_outputs_equal(
    check_logprobs_close(
        outputs_0_lst=for_loop_outputs,
        outputs_1_lst=batched_outputs,
        name_0="for_loop_vllm",
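The move from exact string/token-id assertions to check_logprobs_close above acknowledges that HF and vLLM may diverge by numerical noise: a mismatching token is tolerated as long as it still appears among the other run's top candidates. A rough, simplified sketch of that acceptance criterion (this is not the real helper from the tests' utils module, just an illustration):

def logprobs_roughly_close(ids_0, topk_ids_0, ids_1, topk_ids_1):
    """ids_*: generated token ids; topk_ids_*: per-position sets of top-k candidate ids."""
    for pos, (tok_0, tok_1) in enumerate(zip(ids_0, ids_1)):
        if tok_0 == tok_1:
            continue
        # Tolerate the mismatch only if each token was a top candidate of the other run.
        if tok_0 not in topk_ids_1[pos] or tok_1 not in topk_ids_0[pos]:
            return False
    return True

print(logprobs_roughly_close([5, 7], [{5}, {7, 8}], [5, 8], [{5}, {8, 7}]))  # True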
@@ -78,72 +91,35 @@ def test_batching(
    )


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["float16"])
@pytest.mark.parametrize("max_tokens", [10])
def test_mamba_prefill_chunking_with_parallel_sampling(
        hf_runner, vllm_runner, example_prompts, model: str, dtype: str,
        max_tokens: int) -> None:
    # Tests prefill chunking in conjunction with n>1, in this case,
    # prefill is populated with decoding tokens and we test that it
    # doesn't fail This test might fail if cache is not allocated
    # correctly for n > 1 decoding steps inside a
    # chunked prefill forward pass (where we have both prefills
    # and decoding together )

    if 'plamo-2' in model:
        dtype = "float"  # use a different dtype for plamo

    sampling_params = SamplingParams(n=3,
                                     temperature=1,
                                     seed=0,
                                     max_tokens=max_tokens)
    with vllm_runner(
            model,
            dtype=dtype,
            enable_chunked_prefill=True,
            max_num_batched_tokens=30,
            max_num_seqs=10  # forces prefill chunks with decoding
    ) as vllm_model:
        vllm_model.generate(example_prompts, sampling_params)


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["bfloat16"])
@pytest.mark.parametrize("max_tokens", [7])
def test_mamba_prefill_chunking(hf_runner, vllm_runner, example_prompts,
                                model: str, dtype: str,
                                max_tokens: int) -> None:
    # numeric error during prefill chunking produces different generation
    # compared to w/o prefill chunking for those examples, removed them for now
    if "Jamba" in model:
        example_prompts.pop(7)
        example_prompts.pop(2)
        example_prompts.pop(1)
    elif "Bamba" in model:
        example_prompts.pop(6)
        example_prompts.pop(3)
        example_prompts.pop(2)
        dtype = "half"  # use a different dtype for Bamba

    elif "Zamba2" in model:
        example_prompts.pop(7)
        dtype = "half"
    elif "plamo-2-1b" in model:
        example_prompts.pop(7)

    with hf_runner(model, dtype=dtype) as hf_model:
        non_chunked = hf_model.generate_greedy(example_prompts, max_tokens)
@pytest.mark.parametrize("model", [SSM_MODELS[0], HYBRID_MODELS[0]])
@pytest.mark.parametrize("max_tokens", [32])
@pytest.mark.parametrize("num_logprobs", [5])
@pytest.mark.parametrize("chunked_prefill_token_size", [1, 4, 16])
def test_chunked_prefill(
    vllm_runner,
    example_prompts,
    model: str,
    max_tokens: int,
    num_logprobs: int,
    chunked_prefill_token_size: int,
) -> None:
    max_num_seqs = chunked_prefill_token_size
    max_num_batched_tokens = chunked_prefill_token_size

    with vllm_runner(model,
                     dtype=dtype,
                     enable_chunked_prefill=True,
                     max_num_batched_tokens=5,
                     max_num_seqs=2) as vllm_model:
        chunked = vllm_model.generate_greedy(example_prompts,
                                             max_tokens=max_tokens)
                     max_num_batched_tokens=max_num_batched_tokens,
                     max_num_seqs=max_num_seqs) as vllm_model:
        chunked = vllm_model.generate_greedy_logprobs(example_prompts,
                                                      max_tokens, num_logprobs)

    check_outputs_equal(
    with vllm_runner(model,
                     enable_chunked_prefill=False,
                     max_num_seqs=max_num_seqs) as vllm_model:
        non_chunked = vllm_model.generate_greedy_logprobs(
            example_prompts, max_tokens, num_logprobs)

    check_logprobs_close(
        outputs_0_lst=chunked,
        outputs_1_lst=non_chunked,
        name_0="chunked",
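With chunked_prefill_token_size parametrized down to 1, 4, and 16, a single prompt's prefill has to be spread over several scheduler steps. A hypothetical, much-simplified illustration of that splitting (the real vLLM scheduler also interleaves decode tokens and multiple sequences):

def split_prefill(prompt_token_ids, max_num_batched_tokens):
    # Each inner list is (roughly) what one engine step would prefill.
    return [
        prompt_token_ids[i:i + max_num_batched_tokens]
        for i in range(0, len(prompt_token_ids), max_num_batched_tokens)
    ]

print(split_prefill(list(range(10)), 4))  # [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]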
@@ -151,57 +127,51 @@ def test_mamba_prefill_chunking(hf_runner, vllm_runner, example_prompts,
    )


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["float"])
@pytest.mark.parametrize("max_tokens", [15])
def test_parallel_sampling(
@pytest.mark.parametrize("model", [SSM_MODELS[0], HYBRID_MODELS[0]])
@pytest.mark.parametrize("max_tokens", [10])
def test_chunked_prefill_with_parallel_sampling(
    vllm_runner,
    example_prompts,
    model: str,
    dtype: str,
    max_tokens: int,
) -> None:
    """
    Tests chunked prefill in conjunction with n > 1.

    In this case, prefill is populated with decoding tokens and
    we test that it doesn't fail.

    with vllm_runner(model, dtype=dtype) as vllm_model:
        for_loop_outputs = []
        for _ in range(10):
            for_loop_outputs.append(
                # using example_prompts index 1 instead of 0 since with 0 the
                # logprobs get really close and the test doesn't pass
                vllm_model.generate_greedy([example_prompts[1]], max_tokens)
                [0])
        sampling_params = SamplingParams(n=10,
                                         temperature=0.001,
                                         seed=0,
                                         max_tokens=max_tokens)
        n_lt_1_outputs = vllm_model.generate([example_prompts[1]],
                                             sampling_params)
        token_ids, texts = n_lt_1_outputs[0]
        n_lt_1_outputs = [(token_id, text)
                          for token_id, text in zip(token_ids, texts)]

    check_outputs_equal(
        outputs_0_lst=n_lt_1_outputs,
        outputs_1_lst=for_loop_outputs,
        name_0="vllm_n_lt_1_outputs",
        name_1="vllm",
    )
    This test might fail if cache is not allocated correctly for n > 1
    decoding steps inside a chunked prefill forward pass
    (where we have both prefill and decode together)
    """
    sampling_params = SamplingParams(n=3,
                                     temperature=1,
                                     seed=0,
                                     max_tokens=max_tokens)
    with vllm_runner(
            model,
            enable_chunked_prefill=True,
            # forces prefill chunks with decoding
            max_num_batched_tokens=MAX_NUM_SEQS * 3,
            max_num_seqs=MAX_NUM_SEQS,
    ) as vllm_model:
        vllm_model.generate(example_prompts, sampling_params)


@pytest.mark.skip(reason="RE-ENABLE: test is currently failing on main.")
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["bfloat16"])
@pytest.mark.parametrize("model", [SSM_MODELS[0], HYBRID_MODELS[0]])
@pytest.mark.parametrize("max_tokens", [20])
def test_mamba_cache_cg_padding(
    vllm_runner,
    example_prompts,
    model: str,
    dtype: str,
    max_tokens: int,
) -> None:
    # This test is for verifying that mamba cache is padded to CG captured
    # batch size. If it's not, a torch RuntimeError will be raised because
    # tensor dimensions aren't compatible
    """
    This test is for verifying that mamba cache is padded to CG captured
    batch size. If it's not, a torch RuntimeError will be raised because
    tensor dimensions aren't compatible.
    """
    vllm_config = EngineArgs(model=model,
                             trust_remote_code=True).create_engine_config()
    while len(example_prompts) == vllm_config.pad_for_cudagraph(
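For reference, a hedged sketch of the same n > 1 sampling pattern driven through the public vllm.LLM entry point instead of the vllm_runner test fixture; the model name, prompt, and tiny token budget are placeholders chosen only to force prefill chunks to mix with decodes:

from vllm import LLM, SamplingParams

llm = LLM(model="ai21labs/Jamba-tiny-dev",
          enable_chunked_prefill=True,
          max_num_batched_tokens=12,  # MAX_NUM_SEQS * 3, as in the test above
          max_num_seqs=4)
params = SamplingParams(n=3, temperature=1, seed=0, max_tokens=10)
for completion in llm.generate(["The future of AI is"], params)[0].outputs:
    print(completion.text)  # three sampled continuations for the one prompt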
@@ -209,7 +179,7 @@ def test_mamba_cache_cg_padding(
        example_prompts.append(example_prompts[0])

    try:
        with vllm_runner(model, dtype=dtype) as vllm_model:
        with vllm_runner(model) as vllm_model:
            vllm_model.generate_greedy(example_prompts, max_tokens)
    except RuntimeError:
        pytest.fail(
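The while-loop feeding this hunk relies on VllmConfig.pad_for_cudagraph to learn which batch sizes get padded for CUDA graph capture. A small hedged sketch of querying it directly (model name is a placeholder; the padded values depend on the engine configuration):

from vllm.engine.arg_utils import EngineArgs

vllm_config = EngineArgs(model="state-spaces/mamba-130m-hf").create_engine_config()
for batch_size in (1, 3, 7, 13):
    print(batch_size, "->", vllm_config.pad_for_cudagraph(batch_size))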
@@ -218,28 +188,24 @@ def test_mamba_cache_cg_padding(
            "Could be related to mamba cache not padded correctly")


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["float"])
@pytest.mark.parametrize("model", [SSM_MODELS[0], HYBRID_MODELS[0]])
@pytest.mark.parametrize("max_tokens", [20])
def test_models_preemption_recompute(
    hf_runner,
    vllm_runner,
    example_prompts,
    model: str,
    dtype: str,
    max_tokens: int,
) -> None:
    # Tests that outputs are identical with and w/o preemtions (recompute)
    assert dtype == "float"

    with vllm_runner(model, dtype=dtype) as vllm_model:
        vllm_model.model.llm_engine.scheduler[
            0].ENABLE_ARTIFICIAL_PREEMPT = True
    """
    Tests that outputs are identical with and w/o preemptions (recompute).
    """
    with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model:
        scheduler = vllm_model.model.llm_engine.scheduler[0]
        scheduler.ENABLE_ARTIFICIAL_PREEMPT = True
        preempt_vllm_outputs = vllm_model.generate_greedy(
            example_prompts, max_tokens)

        vllm_model.model.llm_engine.scheduler[
            0].ENABLE_ARTIFICIAL_PREEMPT = False
        scheduler.ENABLE_ARTIFICIAL_PREEMPT = False
        vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)

    check_outputs_equal(
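Note that the preemption test still goes through check_outputs_equal, i.e. it demands exact agreement rather than the logprob closeness used elsewhere in this diff. A simplified sketch of what that exact check asserts (not the real utility from the tests' utils module):

def outputs_equal(outputs_0_lst, outputs_1_lst, name_0, name_1):
    assert len(outputs_0_lst) == len(outputs_1_lst)
    for i, ((ids_0, str_0), (ids_1, str_1)) in enumerate(
            zip(outputs_0_lst, outputs_1_lst)):
        # Token ids and decoded text must match exactly, pairwise.
        assert str_0 == str_1, f"Test{i}:\n{name_0}: {str_0!r}\n{name_1}: {str_1!r}"
        assert ids_0 == ids_1, f"Test{i}:\n{name_0}: {ids_0}\n{name_1}: {ids_1}"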
@@ -250,40 +216,43 @@ def test_models_preemption_recompute(
    )


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["float"])
@pytest.mark.parametrize("model", [SSM_MODELS[0], HYBRID_MODELS[0]])
def test_fail_upon_inc_requests_and_finished_requests_lt_available_blocks(
    vllm_runner,
    model: str,
    dtype: str,
    example_prompts,
    model: str,
) -> None:
    # This test is for verifying that the hybrid inner state management doesn't
    # collapse in case where the number of incoming requests and
    # finished_requests_ids is larger than the maximum mamba block capacity.
    # This could generally happen due to the fact that hybrid does support
    # statelessness mechanism where it can cleanup new incoming requests in
    # a single step.
    """
    This test is for verifying that the hybrid inner state management doesn't
    collapse in case where the number of incoming requests and
    finished_requests_ids is larger than the maximum mamba block capacity.

    This could generally happen due to the fact that hybrid does support
    statelessness mechanism where it can cleanup new incoming requests in
    a single step.
    """
    try:
        with vllm_runner(model, dtype=dtype, max_num_seqs=10) as vllm_model:
        with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model:
            vllm_model.generate_greedy([example_prompts[0]] * 100, 10)
    except ValueError:
        pytest.fail("Hybrid inner state wasn't cleaned up properly between"
                    "steps finished requests registered unnecessarily ")


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["float"])
@pytest.mark.parametrize("model", [SSM_MODELS[0], HYBRID_MODELS[0]])
def test_state_cleanup(
    vllm_runner,
    model: str,
    dtype: str,
    example_prompts,
    model: str,
) -> None:
    # This test is for verifying that the Hybrid state is cleaned up between
    # steps, If its not cleaned, an error would be expected.
    """
    This test is for verifying that the Hybrid state is cleaned up between
    steps.

    If its not cleaned, an error would be expected.
    """
    try:
        with vllm_runner(model, dtype=dtype) as vllm_model:
        with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model:
            for _ in range(10):
                vllm_model.generate_greedy([example_prompts[0]] * 100, 1)
    except ValueError:
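The two state-management tests above push far more requests than max_num_seqs through the engine so that finished requests must be cleaned up on every step. A hedged equivalent of that stress pattern through the public API (model name and prompt are placeholders):

from vllm import LLM, SamplingParams

llm = LLM(model="ai21labs/Jamba-tiny-dev", max_num_seqs=4)
outputs = llm.generate(["Hello"] * 100, SamplingParams(temperature=0, max_tokens=10))
assert len(outputs) == 100  # every request finished and its state was released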
@@ -291,28 +260,14 @@ def test_state_cleanup(
                    "could be related to finished_requests_ids")


@pytest.mark.skip(reason="RE-ENABLE: test is currently failing on main.")
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["float"])
def test_multistep(
    vllm_runner,
    model: str,
    dtype: str,
    example_prompts,
) -> None:
    # This test is verifying that multistep works correctly
    #on mamba-like models
    with vllm_runner(model, num_scheduler_steps=8,
                     max_num_seqs=2) as vllm_model:
        vllm_model.generate_greedy([example_prompts[0]] * 10, 1)


@pytest.mark.skip(reason="RE-ENABLE: test is currently failing on main.")
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["float"])
@pytest.mark.parametrize("model", [SSM_MODELS[0], HYBRID_MODELS[0]])
@pytest.mark.parametrize("max_tokens", [64])
def test_multistep_correctness(vllm_runner, model: str, dtype: str,
                               max_tokens: int, example_prompts) -> None:
def test_multistep_correctness(
    vllm_runner,
    example_prompts,
    model: str,
    max_tokens: int,
) -> None:
    with vllm_runner(model, num_scheduler_steps=8,
                     max_num_seqs=2) as vllm_model:
        vllm_outputs_multistep = vllm_model.generate_greedy(
@@ -332,18 +287,21 @@ def test_multistep_correctness(vllm_runner, model: str, dtype: str,


@multi_gpu_test(num_gpus=2)
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["float"])
@pytest.mark.parametrize("model", [SSM_MODELS[0], HYBRID_MODELS[0]])
@pytest.mark.parametrize("max_tokens", [64])
def test_hybrid_distributed_produces_identical_generation(
        vllm_runner, model: str, dtype: str, max_tokens: int,
        example_prompts) -> None:

    with vllm_runner(model, dtype=dtype, tensor_parallel_size=2) as vllm_model:
    vllm_runner,
    example_prompts,
    model: str,
    max_tokens: int,
) -> None:
    with vllm_runner(model, tensor_parallel_size=2,
                     max_num_seqs=2) as vllm_model:
        vllm_outputs_tp_2 = vllm_model.generate_greedy(example_prompts,
                                                       max_tokens)

    with vllm_runner(model, dtype=dtype, tensor_parallel_size=1) as vllm_model:
    with vllm_runner(model, tensor_parallel_size=1,
                     max_num_seqs=2) as vllm_model:
        vllm_outputs_tp_1 = vllm_model.generate_greedy(example_prompts,
                                                       max_tokens)
@@ -1,337 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
"""Compare the outputs of HF and vLLM when using greedy sampling for Mamba.

Run `pytest tests/models/test_mamba.py`.
"""
import pytest
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from vllm.engine.arg_utils import EngineArgs
from vllm.sampling_params import SamplingParams

from ...utils import check_outputs_equal

MODELS = [
    "state-spaces/mamba-130m-hf",
    "tiiuae/falcon-mamba-tiny-dev",
    # TODO: Compare to a Mamba2 model. The HF transformers implementation of
    # Mamba2 is buggy for Codestral as it doesn't handle n_groups.
    # See https://github.com/huggingface/transformers/pull/35943
    # "mistralai/Mamba-Codestral-7B-v0.1",
]


# Use lower-level interfaces to create this greedy generator, as mamba will
# choke on the model_kwarg 'attention_mask' if hf_model.generate_greedy is used.
def generate_greedy(model_name, example_prompts, max_tokens):
    # Create a text generation pipeline
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name)

    # Set the device (GPU if available, else CPU)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    # Generate texts from the prompts
    outputs = []
    for prompt in example_prompts:
        # Tokenize the input prompt with truncation
        inputs = tokenizer(prompt, return_tensors="pt", truncation=True)
        input_ids = inputs["input_ids"].to(model.device)

        # Generate text using the model's generate method directly
        generated_ids = model.generate(input_ids,
                                       max_new_tokens=max_tokens,
                                       do_sample=False)
        generated_text = tokenizer.decode(generated_ids[0],
                                          skip_special_tokens=True)

        outputs.append((generated_ids[0].tolist(), generated_text))

    return outputs


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["float"])
@pytest.mark.parametrize("max_tokens", [96])
def test_models(
    vllm_runner,
    example_prompts,
    model: str,
    dtype: str,
    max_tokens: int,
) -> None:
    hf_outputs = generate_greedy(model, example_prompts, max_tokens)

    # Set max_num_seqs to keep Codestral from going OOM at fp32
    with vllm_runner(model, dtype=dtype, max_num_seqs=16) as vllm_model:
        vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)

    for i in range(len(example_prompts)):
        hf_output_ids, hf_output_str = hf_outputs[i]
        vllm_output_ids, vllm_output_str = vllm_outputs[i]
        assert hf_output_str == vllm_output_str, (
            f"Test{i}:\nHF: {hf_output_str!r}\nvLLM: {vllm_output_str!r}")
        assert hf_output_ids == vllm_output_ids, (
            f"Test{i}:\nHF: {hf_output_ids}\nvLLM: {vllm_output_ids}")


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["float"])
@pytest.mark.parametrize("max_tokens", [96])
def test_batching(
    vllm_runner,
    example_prompts,
    model: str,
    dtype: str,
    max_tokens: int,
) -> None:
    # To pass the small model tests, we need full precision.
    for_loop_outputs = []
    with vllm_runner(model, dtype=dtype, max_num_seqs=16) as vllm_model:
        for prompt in example_prompts:
            for_loop_outputs.append(
                vllm_model.generate_greedy([prompt], max_tokens)[0])

        batched_outputs = vllm_model.generate_greedy(example_prompts,
                                                     max_tokens)

    check_outputs_equal(
        outputs_0_lst=for_loop_outputs,
        outputs_1_lst=batched_outputs,
        name_0="for_loop_vllm",
        name_1="batched_vllm",
    )


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["float"])
@pytest.mark.parametrize("max_tokens", [10])
def test_chunked_prefill_with_parallel_sampling(vllm_runner, example_prompts,
                                                model: str, dtype: str,
                                                max_tokens: int) -> None:
    # Tests chunked prefill in conjunction with n>1. In this case, prefill is
    # populated with decoding tokens and we test that it doesn't fail.
    # This test might fail if cache is not allocated correctly for n > 1
    # decoding steps inside a chunked prefill forward pass (where we have both
    # prefill and decode together )
    sampling_params = SamplingParams(n=3,
                                     temperature=1,
                                     seed=0,
                                     max_tokens=max_tokens)
    with vllm_runner(
            model,
            dtype=dtype,
            enable_chunked_prefill=True,
            max_num_batched_tokens=30,
            max_num_seqs=10  # forces prefill chunks with decoding
    ) as vllm_model:
        vllm_model.generate(example_prompts, sampling_params)


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["float"])
@pytest.mark.parametrize("max_tokens", [32])
@pytest.mark.parametrize("chunked_prefill_token_size", [1, 4, 16])
def test_chunked_prefill(vllm_runner, example_prompts, model: str, dtype: str,
                         max_tokens: int,
                         chunked_prefill_token_size: int) -> None:
    """
    Checks exact match decode between huggingface model and vllm runner with
    chunked prefill.
    """
    max_num_seqs = chunked_prefill_token_size
    max_num_batched_tokens = chunked_prefill_token_size

    non_chunked = generate_greedy(model, example_prompts, max_tokens)

    with vllm_runner(model,
                     dtype=dtype,
                     enable_chunked_prefill=True,
                     max_num_batched_tokens=max_num_batched_tokens,
                     max_num_seqs=max_num_seqs) as vllm_model:
        chunked = vllm_model.generate_greedy(example_prompts,
                                             max_tokens=max_tokens)

    check_outputs_equal(
        outputs_0_lst=chunked,
        outputs_1_lst=non_chunked,
        name_0="chunked",
        name_1="non_chunked",
    )


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["float"])
@pytest.mark.parametrize("max_tokens", [15])
def test_parallel_sampling(
    vllm_runner,
    example_prompts,
    model: str,
    dtype: str,
    max_tokens: int,
) -> None:

    # Numerical differences produce slightly different output for these
    if 'state-spaces' in model:
        example_prompts.pop(0)
        example_prompts.pop(0)
        example_prompts.pop(0)

    with vllm_runner(model, dtype=dtype, max_num_seqs=16) as vllm_model:
        for_loop_outputs = []
        for _ in range(10):
            for_loop_outputs.append(
                vllm_model.generate_greedy(example_prompts, max_tokens)[0])
        sampling_params = SamplingParams(n=10,
                                         temperature=0.001,
                                         seed=0,
                                         max_tokens=max_tokens)
        n_lt_1_outputs = vllm_model.generate(example_prompts, sampling_params)
        token_ids, texts = n_lt_1_outputs[0]
        n_lt_1_outputs = [(token_id, text)
                          for token_id, text in zip(token_ids, texts)]

    check_outputs_equal(
        outputs_0_lst=n_lt_1_outputs,
        outputs_1_lst=for_loop_outputs,
        name_0="vllm_n_lt_1_outputs",
        name_1="vllm",
    )


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["bfloat16"])
@pytest.mark.parametrize("max_tokens", [20])
def test_mamba_cache_cg_padding(
    vllm_runner,
    example_prompts,
    model: str,
    dtype: str,
    max_tokens: int,
) -> None:
    # This test is for verifying that mamba cache is padded to CG captured
    # batch size. If it's not, a torch RuntimeError will be raised because
    # tensor dimensions aren't compatible
    vllm_config = EngineArgs(model=model).create_engine_config()
    while len(example_prompts) == vllm_config.pad_for_cudagraph(
            len(example_prompts)):
        example_prompts.append(example_prompts[0])

    try:
        with vllm_runner(model, dtype=dtype) as vllm_model:
            vllm_model.generate_greedy(example_prompts, max_tokens)
    except RuntimeError:
        pytest.fail(
            "Couldn't run batch size which is not equal to a Cuda Graph "
            "captured batch size. "
            "Could be related to mamba cache not padded correctly")


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["float"])
@pytest.mark.parametrize("max_tokens", [20])
def test_models_preemption_recompute(
    vllm_runner,
    example_prompts,
    model: str,
    dtype: str,
    max_tokens: int,
) -> None:
    # Tests that outputs are identical with and w/o preemtions (recompute)
    assert dtype == "float"

    with vllm_runner(model, dtype=dtype, max_num_seqs=16) as vllm_model:
        vllm_model.model.llm_engine.scheduler[
            0].ENABLE_ARTIFICIAL_PREEMPT = True
        preempt_vllm_outputs = vllm_model.generate_greedy(
            example_prompts, max_tokens)

        vllm_model.model.llm_engine.scheduler[
            0].ENABLE_ARTIFICIAL_PREEMPT = False
        vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)

    check_outputs_equal(
        outputs_0_lst=preempt_vllm_outputs,
        outputs_1_lst=vllm_outputs,
        name_0="vllm_preepmtions",
        name_1="vllm",
    )


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["float"])
def test_fail_upon_inc_requests_and_finished_requests_lt_available_blocks(
    vllm_runner,
    model: str,
    dtype: str,
    example_prompts,
) -> None:
    # This test is for verifying that the Mamba inner state management doesn't
    # collapse in case where the number of incoming requests and
    # finished_requests_ids is larger than the maximum Mamba block capacity.
    # This could generally happen due to the fact that Mamba does support
    # statelessness mechanism where it can cleanup new incoming requests in
    # a single step.
    try:
        with vllm_runner(model, dtype=dtype, max_num_seqs=10) as vllm_model:
            vllm_model.generate_greedy([example_prompts[0]] * 100, 10)
    except ValueError:
        pytest.fail("Mamba inner state wasn't cleaned up properly between"
                    "steps finished requests registered unnecessarily ")


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["float"])
def test_state_cleanup(
    vllm_runner,
    model: str,
    dtype: str,
    example_prompts,
) -> None:
    # This test is for verifying that the Mamba state is cleaned up between
    # steps, If its not cleaned, an error would be expected.
    try:
        with vllm_runner(model, dtype=dtype, max_num_seqs=16) as vllm_model:
            for _ in range(10):
                vllm_model.generate_greedy([example_prompts[0]] * 100, 1)
    except ValueError:
        pytest.fail("Mamba inner state wasn't cleaned up between states, "
                    "could be related to finished_requests_ids")


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["float"])
def test_multistep(
    vllm_runner,
    model: str,
    dtype: str,
    example_prompts,
) -> None:
    with vllm_runner(model, num_scheduler_steps=8,
                     max_num_seqs=2) as vllm_model:
        vllm_model.generate_greedy([example_prompts[0]] * 10, 1)


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["float"])
@pytest.mark.parametrize("max_tokens", [64])
def test_multistep_correctness(vllm_runner, model: str, dtype: str,
                               max_tokens: int, example_prompts) -> None:
    with vllm_runner(model, num_scheduler_steps=8,
                     max_num_seqs=2) as vllm_model:
        vllm_outputs_multistep = vllm_model.generate_greedy(
            example_prompts, max_tokens)

    with vllm_runner(model, num_scheduler_steps=1,
                     max_num_seqs=2) as vllm_model:
        vllm_outputs_single_step = vllm_model.generate_greedy(
            example_prompts, max_tokens)

    check_outputs_equal(
        outputs_0_lst=vllm_outputs_multistep,
        outputs_1_lst=vllm_outputs_single_step,
        name_0="vllm_outputs_multistep",
        name_1="vllm_outputs_single_step",
    )